Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions snli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
config.n_embed = len(inputs.vocab)
config.d_out = len(answers.vocab)
config.n_cells = config.n_layers

# double the number of cells for bidirectional networks
if config.birnn:
config.n_cells *= 2

Expand Down Expand Up @@ -66,41 +68,71 @@
train_iter.init_epoch()
n_correct, n_total = 0, 0
for batch_idx, batch in enumerate(train_iter):

# switch model to training mode, clear gradient accumulators
model.train(); opt.zero_grad()

iterations += 1

# forward pass
answer = model(batch)

# calculate accuracy of predictions in the current batch
n_correct += (torch.max(answer, 1)[1].view(batch.label.size()).data == batch.label.data).sum()
n_total += batch.batch_size
train_acc = 100. * n_correct/n_total

# calculate loss of the network output with respect to training labels
loss = criterion(answer, batch.label)

# backpropagate and update optimizer learning rate
loss.backward(); opt.step()

# checkpoint model periodically
if iterations % args.save_every == 0:
snapshot_prefix = os.path.join(args.save_path, 'snapshot')
snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.data[0], iterations)
torch.save(model, snapshot_path)
for f in glob.glob(snapshot_prefix + '*'):
if f != snapshot_path:
os.remove(f)

# evaluate performance on validation set periodically
if iterations % args.dev_every == 0:

# switch model to evaluation mode
model.eval(); dev_iter.init_epoch()

# calculate accuracy on validation set
n_dev_correct, dev_loss = 0, 0
for dev_batch_idx, dev_batch in enumerate(dev_iter):
answer = model(dev_batch)
n_dev_correct += (torch.max(answer, 1)[1].view(dev_batch.label.size()).data == dev_batch.label.data).sum()
dev_loss = criterion(answer, dev_batch.label)
dev_acc = 100. * n_dev_correct / len(dev)

print(dev_log_template.format(time.time()-start,
epoch, iterations, 1+batch_idx, len(train_iter),
100. * (1+batch_idx) / len(train_iter), loss.data[0], dev_loss.data[0], train_acc, dev_acc))

# update best valiation set accuracy
if dev_acc > best_dev_acc:

# found a model with better validation set accuracy

best_dev_acc = dev_acc
snapshot_prefix = os.path.join(args.save_path, 'best_snapshot')
snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss.data[0], iterations)

# save model, delete previous 'best_snapshot' files
torch.save(model, snapshot_path)
for f in glob.glob(snapshot_prefix + '*'):
if f != snapshot_path:
os.remove(f)

elif iterations % args.log_every == 0:

# print progress message
print(log_template.format(time.time()-start,
epoch, iterations, 1+batch_idx, len(train_iter),
100. * (1+batch_idx) / len(train_iter), loss.data[0], ' '*8, n_correct/n_total*100, ' '*12))
Expand Down