### Loading data

In [57]:
from data.idiom_data.idiom_translation import TranslationDataset

# Read the two corpora, one sentence per line. They are presumably aligned
# line-by-line (both are handed to TranslationDataset together) -- verify
# against the data files.
with open('data/idiom_data/total_idioms.txt') as f:
    idioms = f.read()
    idiomatic_sentences = idioms.split("\n")

with open('data/idiom_data/total_translated_idioms.txt') as f:
    translated = f.read()
    plain_sentences = translated.split("\n")

# Raw line counts, including any empty entry produced by a trailing newline.
print(len(idiomatic_sentences))
print(len(plain_sentences))

# A file ending in "\n" yields one trailing empty string after split("\n").
# Drop it only when it is really there, so a file WITHOUT a final newline
# does not silently lose its last sentence.
if idiomatic_sentences[-1] == "":
    idiomatic_sentences = idiomatic_sentences[:-1]
if plain_sentences[-1] == "":
    plain_sentences = plain_sentences[:-1]

print(len(idiomatic_sentences))
print(len(plain_sentences))

# Fail early on a length mismatch instead of training on mismatched pairs.
assert len(idiomatic_sentences) == len(plain_sentences), (
    f"corpus size mismatch: {len(idiomatic_sentences)} idiomatic vs "
    f"{len(plain_sentences)} plain sentences"
)

101
101
100
100


### Vocab and Tokenizer Initialization

In [58]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Token stream used to build the vocabulary
def yield_tokens(data_iter, tokenizer):
    """Lazily yield ``tokenizer(text)`` for every ``text`` in ``data_iter``.

    Args:
        data_iter: Iterable of raw text items.
        tokenizer: Callable applied to each item; its return value is
            yielded unchanged.

    Yields:
        The tokenizer's result for each item, in input order.
    """
    yield from map(tokenizer, data_iter)


# spaCy-backed English tokenizer, reused for vocabulary building and the dataset.
tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

# One shared vocabulary built from the tokens of both sentence lists, with the
# four special tokens listed first.
vocab = build_vocab_from_iterator(
    yield_tokens(idiomatic_sentences + plain_sentences, tokenizer),
    specials=["<unk>", "<pad>", "<sos>", "<eos>"],
)

# Use the "<unk>" entry as the vocabulary's default index.
vocab.set_default_index(vocab["<unk>"])


In [59]:
import torch
BATCH_SIZE = 10


def _transpose_batch(samples):
    """Regroup a list of per-sample tuples into one tuple per field."""
    return tuple(zip(*samples))


dataset = TranslationDataset(idiomatic_sentences, plain_sentences, vocab, tokenizer)

# Shuffled mini-batches; samples are transposed rather than stacked, so each
# batch is a tuple of per-field tuples.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=_transpose_batch,
)

### Model Initialization and Training

In [60]:
from models.idiom_model import Seq2Seq, Encoder, Decoder, Attention
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

# Parameters
SEED = 42                   # fixed RNG seed so Restart & Run All is repeatable
INPUT_DIM = len(vocab)      # source and target share one vocabulary
OUTPUT_DIM = len(vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
CLIP = 1                    # passed to train_model as `clip`
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch BEFORE any module is constructed, so weight initialisation (and
# later torch-driven randomness such as DataLoader shuffling and dropout)
# starts from the same state on every fresh run.
# NOTE(review): if models.idiom_model also draws from Python's `random`
# (e.g. for teacher forcing), that generator needs seeding too -- confirm.
torch.manual_seed(SEED)

# Init attention module, encoder, and decoder
attn = Attention(HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT, attn)

# Init model
idiom_model = Seq2Seq(enc, dec, device, vocab, tokenizer).to(device)

# Define optimizer and criterion; "<pad>" targets are excluded from the loss.
optimizer = optim.Adam(idiom_model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"])

In [61]:
# Train loop (the per-epoch logic lives in train_model in models.idiom_model)
MAX_EPOCHS = 1001        # hard upper bound on training epochs
MIN_EPOCHS = 300         # never stop early before this epoch index
PATIENCE = 5             # number of most recent epochs that must all be below threshold
LOSS_THRESHOLD = 0.001   # loss value regarded as converged
LOG_EVERY = 10           # progress line every N epochs

losses = []
for epoch in range(MAX_EPOCHS):
    train_loss = idiom_model.train_model(
        dataloader=dataloader,
        optimizer=optimizer,
        criterion=criterion,
        clip=CLIP,
        vocab=vocab,
    )
    losses.append(train_loss)

    if epoch % LOG_EVERY == 0:
        print(f'Epoch: {epoch+1:02}, Train Loss: {train_loss:.3f}')

    if epoch >= MIN_EPOCHS:
        print("current loss", train_loss)
        # Inspect the LAST `PATIENCE` losses. The previous check indexed
        # losses[-i] for i in range(5); losses[-0] is losses[0] (the very
        # first, large epoch loss), so early stopping could never trigger.
        if all(loss < LOSS_THRESHOLD for loss in losses[-PATIENCE:]):
            print("Early Stopping at epoch", epoch)
            break


Epoch: 01, Train Loss: 5.470
Epoch: 11, Train Loss: 0.764
Epoch: 21, Train Loss: 0.066
Epoch: 31, Train Loss: 0.013
Epoch: 41, Train Loss: 0.005
Epoch: 51, Train Loss: 0.002
Epoch: 61, Train Loss: 0.022
Epoch: 71, Train Loss: 0.004
Epoch: 81, Train Loss: 0.002
Epoch: 91, Train Loss: 0.001
Epoch: 101, Train Loss: 0.001
Epoch: 111, Train Loss: 0.001
Epoch: 121, Train Loss: 0.001
Epoch: 131, Train Loss: 0.001
Epoch: 141, Train Loss: 0.001
Epoch: 151, Train Loss: 0.000
Epoch: 161, Train Loss: 0.216
Epoch: 171, Train Loss: 0.009
Epoch: 181, Train Loss: 0.057
Epoch: 191, Train Loss: 0.013
Epoch: 201, Train Loss: 0.001
Epoch: 211, Train Loss: 0.000
Epoch: 221, Train Loss: 0.000
Epoch: 231, Train Loss: 0.001
Epoch: 241, Train Loss: 0.000
Epoch: 251, Train Loss: 0.000
Epoch: 261, Train Loss: 0.000
Epoch: 271, Train Loss: 0.000
Epoch: 281, Train Loss: 0.000
Epoch: 291, Train Loss: 0.000
Epoch: 301, Train Loss: 0.000
current loss 9.443262315471657e-05
current loss 0.00013130527622706722
current l

### Saving Model

In [62]:
import os

# Checkpoint location (kept as a plain str because later cells pass it to torch.load).
save_file_path = 'models_checkpoint/idiom_model.pth'

# Create the target directory if needed, so a fresh checkout does not fail
# here after the whole training run has already completed.
os.makedirs(os.path.dirname(save_file_path), exist_ok=True)

# Persist only the parameters (state_dict), not the pickled model object.
torch.save(idiom_model.state_dict(), save_file_path)

### Additional Testing section

In [63]:
# Init a new model from FRESH sub-modules. Passing the already-trained
# `enc`/`dec` objects would (unless Seq2Seq copies them) make `new_model`
# share its parameters with `idiom_model`, so loading the checkpoint would
# prove nothing about the saved file.
new_attn = Attention(HID_DIM)
new_enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
new_dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT, new_attn)
new_model = Seq2Seq(new_enc, new_dec, device, vocab, tokenizer).to(device)

# Inference only: switch off dropout.
new_model.eval()

# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
new_model.load_state_dict(torch.load(save_file_path, map_location=device))

<All keys matched successfully>

In [64]:
# Smoke-test the model restored from the checkpoint on one hand-written phrase.
sentence = "in a really bad shape"
generated_sentence = new_model.sample(sentence)
print("Generated Sentence:", generated_sentence)

Generated Sentence: <unk> in poor condition


In [65]:
# Run every training sentence through the model and print the result.
# Iterating the list directly (instead of a hard-coded range(0, 100)) keeps
# this cell correct when the corpus has fewer or more than 100 sentences.
# NOTE(review): this uses the in-memory `idiom_model`, not the reloaded
# `new_model` -- confirm that is intended.
for i, sentence in enumerate(idiomatic_sentences):
    generated_sentence = idiom_model.sample(sentence)
    print(i, "Generated Sentence:", generated_sentence)

0 Generated Sentence: <unk> as a precaution
1 Generated Sentence: <unk> something pitiful or disappointing to see
2 Generated Sentence: <unk> general guideline
3 Generated Sentence: <unk> seize the weather
4 Generated Sentence: <unk> youthful times
5 Generated Sentence: <unk> unofficially
6 Generated Sentence: <unk>
7 Generated Sentence: <unk> lot of money
8 Generated Sentence: <unk> hottest days of summer
9 Generated Sentence: <unk> inexperienced
10 Generated Sentence: <unk> deserved outcome
11 Generated Sentence: <unk> expensive
12 Generated Sentence: <unk> forget it
13 Generated Sentence: <unk> physical buildings
14 Generated Sentence: <unk> narrow escape
15 Generated Sentence: <unk> a welcome sight
16 Generated Sentence: <unk> obvious conflict
17 Generated Sentence: <unk> small amount of money small amount of money small amount of money small amount of money small amount of money small amount of money small amount of money small amount of money small amount of money small amount of