-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
80 lines (62 loc) · 3.12 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
import matplotlib.pyplot as plt
import argparse
import torch.optim as optim
from random import shuffle
import pickle
from models import *
from utilities import *
# Special vocabulary token indices.
# NOTE(review): not referenced in this file — presumably these must match the
# indices baked into the dictionary built by train.py; confirm against train.py.
PAD_TOKEN = 0  # padding token
SOS_TOKEN = 1  # start-of-sentence token
EOS_TOKEN = 2  # end-of-sentence token
def load_dictionary(directory):
    """Load the pickled vocabulary object from *directory*.

    Args:
        directory: Path of the folder containing ``dictionary.pkl``
            (with or without a trailing separator).

    Returns:
        The unpickled dictionary/vocabulary object saved by training.
    """
    # os.path.join tolerates a missing trailing slash, unlike the previous
    # raw string concatenation (directory + 'dictionary.pkl').
    # NOTE(review): pickle.load executes arbitrary code from the file —
    # only load checkpoints you produced yourself.
    with open(os.path.join(directory, 'dictionary.pkl'), 'rb') as f:
        return pickle.load(f)
def main():
    """Load a trained encoder/decoder checkpoint and run the chatbot.

    Command-line hyperparameters must match the values used by train.py,
    since they determine the shapes of the state dicts being loaded.
    Launches either a GUI loop (``-gui``) or a console REPL.
    """
    parser = argparse.ArgumentParser(description='Hyperparameters for training Transformer')
    parser.add_argument('--model_directory', type=str, default='saved_models/', help='directory where models will be saved')
    # Default hyperparameters — must agree with the values used in train.py.
    parser.add_argument('--MAX_LENGTH', type=int, default=20, help='max length of a sentence')
    parser.add_argument('--epochs_trained', type=int, default=10, help='number of epochs trained in train.py')
    parser.add_argument('--hidden_size', type=int, default=512, help='size of hidden layer for encoder and decoder')
    parser.add_argument('--encoder_n_layers', type=int, default=2, help='number of encoder gru layers')
    parser.add_argument('--decoder_n_layers', type=int, default=2, help='number of decoder gru layers')
    parser.add_argument('--encoder_dropout', type=float, default=0.1, help='dropout in encoder ff')
    parser.add_argument('--decoder_dropout', type=float, default=0.1, help='dropout in decoder ff')
    parser.add_argument('--device', type=str, default='cpu', help='device to run computations')
    parser.add_argument('-gui', action='store_true')
    args = parser.parse_args()

    model_directory = args.model_directory
    MAX_LENGTH = args.MAX_LENGTH
    epochs_trained = args.epochs_trained
    hidden_size = args.hidden_size
    encoder_n_layers = args.encoder_n_layers
    decoder_n_layers = args.decoder_n_layers
    encoder_dropout = args.encoder_dropout
    decoder_dropout = args.decoder_dropout

    # Fall back to CPU when CUDA is unavailable, whatever --device says.
    device = args.device if torch.cuda.is_available() else 'cpu'

    dictionary = load_dictionary(model_directory)

    # Checkpoint files are suffixed with the index of the last trained epoch.
    suffix = '_' + str(epochs_trained - 1) + '.pt'

    def _load_checkpoint(name):
        # One-line purpose: load one state dict from the model directory.
        # map_location fixes loading GPU-saved weights on a CPU-only host.
        return torch.load(os.path.join(model_directory, name + suffix),
                          map_location=device)

    # Load embedder weights (shared between encoder and decoder).
    embedding = nn.Embedding(dictionary.n_count, hidden_size)
    embedding.load_state_dict(_load_checkpoint('embedding'))
    # Load encoder weights.
    encoder = Encoder(hidden_size, embedding, encoder_n_layers, encoder_dropout).to(device)
    encoder.load_state_dict(_load_checkpoint('encoder'))
    # Load decoder weights.
    decoder = Decoder(hidden_size, embedding, dictionary.n_count, decoder_n_layers, decoder_dropout).to(device)
    decoder.load_state_dict(_load_checkpoint('decoder'))

    # Bug fix: modules default to training mode, which would keep the
    # configured dropout active during inference — switch to eval mode.
    encoder.eval()
    decoder.eval()

    model = Seq2Seq(encoder, decoder, device)
    if args.gui:
        evaluateGuiInput(encoder, decoder, model, dictionary, device, MAX_LENGTH)
    else:
        print('To exit chatbot, enter \'q\' or \'quit\'.\n')
        evaluateInput(encoder, decoder, model, dictionary, device, MAX_LENGTH)
# Run the chatbot only when executed as a script, not when imported.
if __name__ == "__main__":
    main()