diff --git a/word_language_model/main.py b/word_language_model/main.py
index 412c5a20a7..a15083bf42 100644
--- a/word_language_model/main.py
+++ b/word_language_model/main.py
@@ -2,8 +2,10 @@
 import argparse
 import time
 import math
+import os
 import torch
 import torch.nn as nn
+import torch.onnx
 
 import data
 import model
@@ -39,8 +41,10 @@
                     help='use CUDA')
 parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                     help='report interval')
-parser.add_argument('--save', type=str, default='model.pt', 
+parser.add_argument('--save', type=str, default='model.pt',
                     help='path to save the final model')
+parser.add_argument('--onnx-export', type=str, default='',
+                    help='path to export the final model in onnx format')
 args = parser.parse_args()
 
 # Set the random seed manually for reproducibility.
@@ -171,6 +175,16 @@ def train():
             total_loss = 0
             start_time = time.time()
 
+
+def export_onnx(path, batch_size, seq_len):
+    print('The model is also exported in ONNX format at {}'.
+          format(os.path.realpath(args.onnx_export)))
+    model.eval()
+    dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(-1, batch_size).to(device)
+    hidden = model.init_hidden(batch_size)
+    torch.onnx.export(model, (dummy_input, hidden), path)
+
+
 # Loop over epochs.
 lr = args.lr
 best_val_loss = None
@@ -211,3 +225,7 @@ def train():
 print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
     test_loss, math.exp(test_loss)))
 print('=' * 89)
+
+if len(args.onnx_export) > 0:
+    # Export the model in ONNX format.
+    export_onnx(args.onnx_export, batch_size=1, seq_len=args.bptt)
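Note on usage: with this patch, passing a non-empty --onnx-export path makes the
script export the trained model after the final test evaluation, e.g.:

    python main.py --onnx-export model.onnx

As a minimal sanity check of the exported file (a sketch, assuming the separate
`onnx` Python package is installed; model.onnx is the hypothetical path passed
above):

    import onnx

    # Load the exported graph and run ONNX's structural validator.
    onnx_model = onnx.load("model.onnx")
    onnx.checker.check_model(onnx_model)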