diff --git a/snli/train.py b/snli/train.py
index d80c790bf7..60ec0e75ad 100644
--- a/snli/train.py
+++ b/snli/train.py
@@ -38,6 +38,8 @@
 config.n_embed = len(inputs.vocab)
 config.d_out = len(answers.vocab)
 config.n_cells = config.n_layers
+
+# double the number of cells for bidirectional networks
 if config.birnn:
     config.n_cells *= 2
@@ -66,14 +68,27 @@
     train_iter.init_epoch()
     n_correct, n_total = 0, 0
     for batch_idx, batch in enumerate(train_iter):
+
+        # switch model to training mode, clear gradient accumulators
         model.train(); opt.zero_grad()
+
         iterations += 1
+
+        # forward pass
         answer = model(batch)
+
+        # calculate accuracy of predictions in the current batch
         n_correct += (torch.max(answer, 1)[1].view(batch.label.size()).data == batch.label.data).sum()
         n_total += batch.batch_size
         train_acc = 100. * n_correct/n_total
+
+        # calculate loss of the network output with respect to training labels
         loss = criterion(answer, batch.label)
+
+        # backpropagate and update model parameters via the optimizer
         loss.backward(); opt.step()
+
+        # checkpoint model periodically
         if iterations % args.save_every == 0:
             snapshot_prefix = os.path.join(args.save_path, 'snapshot')
             snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.data[0], iterations)
@@ -81,26 +96,43 @@
             for f in glob.glob(snapshot_prefix + '*'):
                 if f != snapshot_path:
                     os.remove(f)
+
+        # evaluate performance on validation set periodically
         if iterations % args.dev_every == 0:
+
+            # switch model to evaluation mode
             model.eval(); dev_iter.init_epoch()
+
+            # calculate accuracy on validation set
             n_dev_correct, dev_loss = 0, 0
             for dev_batch_idx, dev_batch in enumerate(dev_iter):
                 answer = model(dev_batch)
                 n_dev_correct += (torch.max(answer, 1)[1].view(dev_batch.label.size()).data == dev_batch.label.data).sum()
                 dev_loss = criterion(answer, dev_batch.label)
             dev_acc = 100. * n_dev_correct / len(dev)
+
             print(dev_log_template.format(time.time()-start, epoch, iterations, 1+batch_idx, len(train_iter), 100. * (1+batch_idx) / len(train_iter), loss.data[0], dev_loss.data[0], train_acc, dev_acc))
+
+            # update best validation set accuracy
             if dev_acc > best_dev_acc:
+
+                # found a model with better validation set accuracy
+                best_dev_acc = dev_acc
                 snapshot_prefix = os.path.join(args.save_path, 'best_snapshot')
                 snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss.data[0], iterations)
+
+                # save model, delete previous 'best_snapshot' files
                 torch.save(model, snapshot_path)
                 for f in glob.glob(snapshot_prefix + '*'):
                     if f != snapshot_path:
                         os.remove(f)
+
         elif iterations % args.log_every == 0:
+
+            # print progress message
             print(log_template.format(time.time()-start, epoch, iterations, 1+batch_idx, len(train_iter), 100. * (1+batch_idx) / len(train_iter), loss.data[0], ' '*8, n_correct/n_total*100, ' '*12))
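
The script above reads scalar losses with the pre-0.4 idiom `loss.data[0]`. For reference only (not part of this change), the sketch below shows the same periodic dev-set evaluation as it might look on PyTorch >= 0.4, using `torch.no_grad()` and `.item()`. The helper name `evaluate` and its parameters are hypothetical; it assumes the torchtext-style iterator (`dev_iter.init_epoch()`, batches with a `.label` field) used by snli/train.py.

```python
import torch


def evaluate(model, criterion, dev_iter, n_dev_examples):
    """Sketch of the periodic dev-set evaluation, assuming PyTorch >= 0.4."""
    # switch model to evaluation mode and restart the dev iterator
    model.eval()
    dev_iter.init_epoch()

    n_dev_correct, dev_loss = 0, 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for dev_batch in dev_iter:
            answer = model(dev_batch)
            # count correct predictions in this dev batch
            n_dev_correct += (torch.max(answer, 1)[1]
                              .view(dev_batch.label.size()) == dev_batch.label).sum().item()
            dev_loss = criterion(answer, dev_batch.label)

    dev_acc = 100. * n_dev_correct / n_dev_examples
    return dev_acc, dev_loss.item()
```

In place of the inline block, train.py could call `evaluate(model, criterion, dev_iter, len(dev))` whenever `iterations % args.dev_every == 0` and log or checkpoint on the returned accuracy.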