import logging

import spacy
import torch
from torchtext.data.metrics import bleu_score
from tqdm import tqdm

logger = logging.getLogger(__name__)


def lines_from_file_path(path, strip=True):
    """
    :param path: string or PosixPath pointing to a text file
    :param strip: if True, strip the trailing newline from each yielded line
    :return: generator yielding one line from the file at a time
    """
    logger.info(f'processing lines from file: {path}')
    # newline='\n' is specified explicitly so that some files are still split
    # on new lines properly
    with open(path, encoding='utf-8', mode='r', newline='\n') as file_handle:
        for line in file_handle:
            if strip:
                yield line.rstrip('\n')
            else:
                yield line
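

# Hedged usage sketch (not part of the original module): the data path below is
# hypothetical and only illustrates how the generator above is consumed lazily.
def example_count_corpus_lines(path='data/train.de'):
    """Count lines in a corpus file without loading it all into memory."""
    return sum(1 for _ in lines_from_file_path(path))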


def save_checkpoint(state, filename="model_checkpoint.pth.tar"):
    logger.info(f'Saving checkpoint to {filename}')
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer):
    logger.info('Loading checkpoint')
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
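

# Hedged usage sketch (not part of the original module): shows the checkpoint
# dict layout expected by load_checkpoint, using a throwaway linear model; the
# filename is illustrative.
def example_checkpoint_roundtrip(filename='example_checkpoint.pth.tar'):
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    state = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    save_checkpoint(state, filename=filename)
    load_checkpoint(torch.load(filename), model, optimizer)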


def translate_sentence(model, sentence, src_field, trg_field, device, src_lang='de', max_length=150):
    # Load the German tokenizer; 'de' is the spaCy 2.x shortcut name (on spaCy 3+
    # the full model name, e.g. 'de_core_news_sm', would be needed instead)
    spacy_de = spacy.load(src_lang)

    # Tokenise with spaCy and lower-case everything (which is what our vocab is)
    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in spacy_de(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # Add <sos> and <eos> at the beginning and end respectively
    tokens.insert(0, src_field.init_token)
    tokens.append(src_field.eos_token)

    # Convert each source token to its vocabulary index
    text_to_indices = [src_field.vocab.stoi[token] for token in tokens]

    # Convert to a tensor of shape (batch size, seq len)
    src_tensor = torch.LongTensor(text_to_indices).unsqueeze(0).to(device)
    src_mask = model.make_src_mask(src_tensor)

    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)

    # Greedy decoding: start from <sos> and repeatedly append the most likely next token
    trg_token_indices = [trg_field.vocab.stoi[trg_field.init_token]]
    for _ in range(max_length):
        trg_tensor = torch.LongTensor(trg_token_indices).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)

        with torch.no_grad():
            output = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)

        predicted_token_index = output.argmax(2)[:, -1].item()
        trg_token_indices.append(predicted_token_index)

        if predicted_token_index == trg_field.vocab.stoi[trg_field.eos_token]:
            break

    translated_sentence = [trg_field.vocab.itos[idx] for idx in trg_token_indices]
    # Drop the <sos> token
    return translated_sentence[1:]
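

# Hedged usage sketch (not part of the original module): `model`, `src_field`,
# `trg_field`, and `device` stand for a trained transformer and the torchtext
# fields built during training; the German sentence is only an example input.
def example_greedy_translation(model, src_field, trg_field, device):
    sentence = 'ein pferd geht unter einer bruecke.'
    tokens = translate_sentence(model, sentence, src_field, trg_field, device)
    return ' '.join(tokens)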


def calc_bleu_score(data, model, src_field, trg_field, device):
    targets = []
    outputs = []

    for example in tqdm(data):
        src = vars(example)['src']
        trg = vars(example)['trg']

        prediction = translate_sentence(model, src, src_field, trg_field, device)
        prediction = prediction[:-1]  # remove the <eos> token

        # bleu_score expects a list of candidate corpora and a list of lists of references
        targets.append([trg])
        outputs.append(prediction)

    return bleu_score(outputs, targets)
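

# Hedged usage sketch (not part of the original module): `test_data` stands for
# a torchtext dataset (e.g. the Multi30k test split) whose examples expose
# tokenised `src` and `trg` attributes; the logging format is illustrative.
def example_bleu_report(test_data, model, src_field, trg_field, device):
    score = calc_bleu_score(test_data, model, src_field, trg_field, device)
    logger.info(f'BLEU score: {score * 100:.2f}')
    return score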