Gradient Checkpointing - LM1B Sampled Softmax
rdspring1 committed May 3, 2019
1 parent a93b420 commit 941bbbf
Showing 4 changed files with 8 additions and 2 deletions.
2 changes: 2 additions & 0 deletions examples/lm1b/argument.py
@@ -49,4 +49,6 @@ def add_transformer_args(parser):
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', default=True, action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--chkpt-grad', default=True, action='store_true',
help='checkpoint gradients to allow for training with larger models and sequences')
return parser
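
Note that default=True combined with action='store_true' means args.chkpt_grad can never be False from the command line. A minimal sketch (not part of this commit; the --no-chkpt-grad flag name is hypothetical) of how the option could be made switchable:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--chkpt-grad', dest='chkpt_grad', default=True, action='store_true',
                    help='checkpoint gradients to allow for training with larger models and sequences')
parser.add_argument('--no-chkpt-grad', dest='chkpt_grad', action='store_false',
                    help='disable gradient checkpointing')

args = parser.parse_args(['--no-chkpt-grad'])
print(args.chkpt_grad)  # False
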
5 changes: 3 additions & 2 deletions examples/lm1b/transformer.py
@@ -259,8 +259,9 @@ def __init__(self, args, embed_tokens, left_pad=False):
TransformerDecoderLayer(args)
for i in range(args.decoder_layers)
])
self.chkpt_grad = args.chkpt_grad

def forward(self, prev_output_tokens, encoder_out, incremental_state=None, chkpt_grad=False, **kwargs):
def forward(self, prev_output_tokens, encoder_out, incremental_state=None, **kwargs):
# embed positions
positions = self.embed_positions(
prev_output_tokens,
@@ -288,7 +289,7 @@ def custom_forward(*inputs):
return x_
return custom_forward

if self.training and chkpt_grad:
if self.training and self.chkpt_grad:
l = 0
num_layers = len(self.layers)
chunk_length = math.ceil(math.sqrt(num_layers))
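
The hunk above stops right after chunk_length is computed. A self-contained sketch of the sqrt(N)-chunk checkpointing pattern it sets up, using torch.utils.checkpoint; the names below (make_custom_forward, run_decoder_layers) are illustrative, not the repo's exact code:

import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def make_custom_forward(layers):
    def custom_forward(x):
        for layer in layers:
            x = layer(x)
        return x
    return custom_forward

def run_decoder_layers(x, layers, training, chkpt_grad):
    if training and chkpt_grad:
        num_layers = len(layers)
        chunk_length = math.ceil(math.sqrt(num_layers))  # ~sqrt(N) layers per checkpoint
        for start in range(0, num_layers, chunk_length):
            chunk = layers[start:start + chunk_length]
            # Activations inside the chunk are recomputed during backward,
            # trading extra compute for lower memory.
            x = checkpoint(make_custom_forward(chunk), x)
    else:
        for layer in layers:
            x = layer(x)
    return x

# Tiny usage example with linear layers standing in for decoder blocks.
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(9)])
x = torch.randn(4, 16, requires_grad=True)
out = run_decoder_layers(x, layers, training=True, chkpt_grad=True)
out.sum().backward()
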
1 change: 1 addition & 0 deletions examples/lm1b/transformer_main.py
@@ -91,6 +91,7 @@

print("Sampled Softmax:", nsampled, "Batch Size:", args.batch_size, "Initial LR:", args.lr)
#optimizer = Adam(net.parameters(), args.lr, betas=(0.0, 0.999))
#optimizer = optim.RMSprop(net.parameters(), args.lr)
optimizer = RMSprop(net.parameters(), args.lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_corpus.batch_num*args.epochs, eta_min=1e-8)
net, optimizer = amp.initialize(net, optimizer, opt_level="O1")
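
For context, a sketch of one training step under apex O1 mixed precision with the per-batch cosine schedule, assuming net, optimizer, and scheduler are set up as above; criterion and train_loader are placeholders, not names from the repo:

from apex import amp

for inputs, targets in train_loader:                # placeholder iterator
    optimizer.zero_grad()
    loss = criterion(net(inputs), targets)          # placeholder loss
    # Scale the loss so fp16 gradients do not underflow; apex unscales them
    # before the optimizer step.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    scheduler.step()  # advance the cosine annealing schedule once per batch
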
2 changes: 2 additions & 0 deletions examples/word_language_model/main.py
@@ -15,6 +15,7 @@
from sgd import SGD # Count-Sketch Momentum optimizer
from adagrad import Adagrad # Count-Sketch Adagrad optimizer
from adam import Adam # Count-Sketch Adam optimizer
from rmsprop import RMSprop # Count-Sketch RMSProp optimizer
#from adam_base import Adam # Baseline Adam optimizer supports sparse gradients
#from adafactor import Adam # Low-Rank Approximation Adam optimizer

@@ -106,6 +107,7 @@ def batchify(data, bsz):
optimizer = SGD(model.parameters(), args.lr, momentum=0.9, nesterov=True)
#optimizer = Adagrad(model.parameters(), args.lr)
#optimizer = Adam(model.parameters(), betas=(0.9, 0.999))
#optimizer = RMSprop(model.parameters())

###############################################################################
# Training code
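
A hedged sketch (not in this commit) of selecting one of the repo's Count-Sketch optimizers by name, assuming a hypothetical args.optimizer string and that the sgd/adagrad/adam/rmsprop modules imported above are on the path; the constructor calls mirror the ones shown in the diff:

from sgd import SGD
from adagrad import Adagrad
from adam import Adam
from rmsprop import RMSprop

def build_optimizer(name, params, lr):
    if name == 'sgd':
        return SGD(params, lr, momentum=0.9, nesterov=True)
    if name == 'adagrad':
        return Adagrad(params, lr)
    if name == 'adam':
        return Adam(params, betas=(0.9, 0.999))
    if name == 'rmsprop':
        return RMSprop(params)
    raise ValueError(f'unknown optimizer: {name}')

# optimizer = build_optimizer(args.optimizer, model.parameters(), args.lr)
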
