Skip to content

Commit

Permalink
Export Word Language Model to ONNX (#348)
Browse files Browse the repository at this point in the history
* Export Word Language Model to ONNX

* Make ONNX export optional
  • Loading branch information
houseroad authored and soumith committed May 4, 2018
1 parent 0604520 commit eee5ca3
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion word_language_model/main.py
Expand Up @@ -2,8 +2,10 @@
import argparse
import time
import math
import os
import torch
import torch.nn as nn
import torch.onnx

import data
import model
Expand Down Expand Up @@ -39,8 +41,10 @@
help='use CUDA')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
help='report interval')
parser.add_argument('--save', type=str, default='model.pt',
parser.add_argument('--save', type=str, default='model.pt',
help='path to save the final model')
parser.add_argument('--onnx-export', type=str, default='',
help='path to export the final model in onnx format')
args = parser.parse_args()

# Set the random seed manually for reproducibility.
Expand Down Expand Up @@ -171,6 +175,16 @@ def train():
total_loss = 0
start_time = time.time()


def export_onnx(path, batch_size, seq_len):
print('The model is also exported in ONNX format at {}'.
format(os.path.realpath(args.onnx_export)))
model.eval()
dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(-1, batch_size).to(device)
hidden = model.init_hidden(batch_size)
torch.onnx.export(model, (dummy_input, hidden), path)


# Loop over epochs.
lr = args.lr
best_val_loss = None
Expand Down Expand Up @@ -211,3 +225,7 @@ def train():
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
test_loss, math.exp(test_loss)))
print('=' * 89)

if len(args.onnx_export) > 0:
# Export the model in ONNX format.
export_onnx(args.onnx_export, batch_size=1, seq_len=args.bptt)

0 comments on commit eee5ca3

Please sign in to comment.