-
Notifications
You must be signed in to change notification settings - Fork 9.5k
/
generate.py
86 lines (74 loc) · 3.42 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
###############################################################################
# Language Modeling on Wikitext-2
#
# This file generates new sentences sampled from the language model.
#
###############################################################################
import argparse
import torch
import data
# Command-line interface for the text generation script.
parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 Language Model')
# Model parameters.
parser.add_argument('--data', type=str, default='./data/wikitext-2',
                    help='location of the data corpus')
parser.add_argument('--checkpoint', type=str, default='./model.pt',
                    help='model checkpoint to use')
parser.add_argument('--outf', type=str, default='generated.txt',
                    help='output file for generated text')
# The default is given as an int so it matches the declared type; argparse
# would convert a string default anyway, so the parsed value is unchanged.
parser.add_argument('--words', type=int, default=1000,
                    help='number of words to generate')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--mps', action='store_true', default=False,
                    help='enables macOS GPU training')
parser.add_argument('--temperature', type=float, default=1.0,
                    help='temperature - higher will increase diversity')
parser.add_argument('--log-interval', type=int, default=100,
                    help='reporting interval')
# Parses sys.argv at import time; exits with a usage message on invalid input.
args = parser.parse_args()
# Seed torch's random number generator from --seed for reproducibility.
torch.manual_seed(args.seed)

# Hint at accelerators that are present but were not requested.
if torch.cuda.is_available() and not args.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda.")

mps_available = torch.backends.mps.is_available()
if mps_available and not args.mps:
    print("WARNING: You have mps device, to enable macOS GPU run with --mps.")

# MPS is only used when it was requested *and* is actually available.
use_mps = args.mps and mps_available

# Priority: CUDA (when requested), then MPS, then CPU as the fallback.
device = torch.device("cuda" if args.cuda else "mps" if use_mps else "cpu")

# Reject temperatures below 1e-3; parser.error prints usage and exits.
if args.temperature < 1e-3:
    parser.error("--temperature has to be greater or equal 1e-3.")
# Load the trained model and switch it to evaluation mode.
#
# The checkpoint is unpickled as a complete model object (it is used below via
# model.eval(), model.init_hidden() and direct calls), not as a state dict.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to
# unpickle arbitrary classes, so full unpickling is requested explicitly
# (the weights_only keyword requires PyTorch >= 1.13).
#
# SECURITY: full unpickling can execute arbitrary code. Only load checkpoints
# that come from a source you trust.
with open(args.checkpoint, 'rb') as f:
    model = torch.load(f, map_location=device, weights_only=False)
model.eval()

# The corpus supplies the vocabulary used to map sampled indices to words.
corpus = data.Corpus(args.data)
ntokens = len(corpus.dictionary)

# Models exposing model_type == 'Transformer' take the whole generated sequence
# as input; all other models carry a hidden state between steps instead.
is_transformer_model = hasattr(model, 'model_type') and model.model_type == 'Transformer'
if not is_transformer_model:
    hidden = model.init_hidden(1)  # presumably batch size 1 -- confirm in model code

# Start from one random token index in [0, ntokens), shaped (1, 1). It is drawn
# on the CPU and then moved, so the first token depends only on the CPU seed.
# NOTE: the name shadows the builtin `input`, but the generation loop later in
# this script refers to it by this name, so it is kept.
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
# Sample args.words tokens one at a time and write them to args.outf,
# 20 words per line. The file is written as UTF-8 explicitly so that
# non-ASCII vocabulary entries do not depend on the platform's default encoding.
with open(args.outf, 'w', encoding='utf-8') as outf, torch.no_grad():  # no tracking history
    for i in range(args.words):
        if is_transformer_model:
            # NOTE(review): the second positional argument is False -- presumably
            # it disables the attention mask; confirm against the model definition.
            output = model(input, False)
            # Only the output for the last position is used for the next token.
            word_weights = output[-1].squeeze().div(args.temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0].item()
            # Build the index tensor directly as an integer tensor; going
            # through a float tensor would be lossy for indices above 2**24.
            word_tensor = torch.tensor([[word_idx]], dtype=torch.long, device=device)
            # Append the new token so the next step sees the full sequence.
            input = torch.cat([input, word_tensor], 0)
        else:
            output, hidden = model(input, hidden)
            word_weights = output.squeeze().div(args.temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0].item()
            # Feed the sampled token back in as the next single-token input.
            input.fill_(word_idx)

        word = corpus.dictionary.idx2word[word_idx]
        # Break the line after every 20th word, otherwise separate with a space.
        outf.write(word + ('\n' if i % 20 == 19 else ' '))

        if i % args.log_interval == 0:
            print('| Generated {}/{} words'.format(i, args.words))