In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field, TabularDataset, BucketIterator

from dataset.mtgcards import RuleText
from utils.preprocess import fields_for_rule_text

import random
import math
import time
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Create the source/target field objects and load the three dataset splits.
# Each field is registered under the attribute name ('src' / 'trg') that
# later cells read from examples and batches.
SRC, TRG = fields_for_rule_text()
fields = {
    'src': ('src', SRC),
    'trg': ('trg', TRG),
}

train_data, valid_data, test_data = RuleText.splits(fields=fields)

In [3]:
# Build both vocabularies from the training split only, so validation/test
# tokens never influence the vocabulary. min_freq=2 is passed through to
# torchtext's vocabulary builder (tokens seen fewer than 2 times are dropped).
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

# Report the resulting vocabulary sizes, one per line.
print(
    f"Unique tokens in source (en) vocabulary: {len(SRC.vocab)}",
    f"Unique tokens in target (zh) vocabulary: {len(TRG.vocab)}",
    sep="\n",
)

Unique tokens in source (en) vocabulary: 4991
Unique tokens in target (zh) vocabulary: 2357


In [4]:
# Seed Python's and PyTorch's RNGs before anything stochastic runs (iterator
# shuffling here; weight initialisation and dropout later in the notebook), so
# that a "Restart Kernel -> Run All" is repeatable.
# NOTE(review): GPU kernels can still introduce small run-to-run differences.
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
BATCH_SIZE = 32

# Bucket examples by source length; sort_within_batch orders each batch by
# sort_key (presumably needed by the model's packed-sequence handling --
# confirm against models.model4.definition).
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_within_batch=True,
    sort_key=lambda x: len(x.src),
    device=device)

# Peek at one training batch to sanity-check field layout and tensor sizes.
tmp = next(iter(train_iterator))
print(tmp)

cuda

[torchtext.legacy.data.batch.Batch of size 32]
	[.src]:('[torch.cuda.LongTensor of size 23x32 (GPU 0)]', '[torch.cuda.LongTensor of size 32 (GPU 0)]')
	[.trg]:[torch.cuda.LongTensor of size 31x32 (GPU 0)]


In [5]:
from models.model4.definition import Encoder, Attention, Decoder, Seq2Seq
from models.model4.train import init_weights, train, evaluate
from utils import count_parameters, train_loop

In [6]:
# --- Hyper-parameters -------------------------------------------------------
# Vocabulary sizes determine the embedding input size and the output layer size.
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)

# Encoder settings.
ENC_EMB_DIM = 256
ENC_HID_DIM = 512
ENC_DROPOUT = 0.5

# Decoder settings.
DEC_EMB_DIM = 256
DEC_HID_DIM = 512
DEC_DROPOUT = 0.5

# Index of the padding token in the source vocabulary; handed to Seq2Seq
# (presumably so padded positions can be masked -- confirm in the model code).
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]

# --- Model assembly ---------------------------------------------------------
# Construction order is kept as attention -> encoder -> decoder.
attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, SRC_PAD_IDX, device)
model = model.to(device)

# Apply the project's weight initialisation to every sub-module.
model.apply(init_weights)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 12,540,469 trainable parameters


In [7]:
# Target padding positions are excluded from the loss via ignore_index.
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

# Adam with the library's default hyper-parameters.
optimizer = optim.Adam(model.parameters())

# Run the training loop. With load_before_train=True the recorded output shows
# parameters being loaded from result/tut4-model.pt and a validation pass
# reported before the first epoch.
train_loop(
    model, optimizer, criterion, train, evaluate,
    train_iterator, valid_iterator,
    save_path='result/',
    file_name='tut4-model.pt',
    load_before_train=True,
)

model will be saved to result/tut4-model.pt
load model parameters from result/tut4-model.pt
Val. Loss: 1.269 |  Val. PPL:   3.559


100%|██████████| 1219/1219 [01:26<00:00, 14.14it/s]


Epoch: 01 | Time: 1m 27s
	Train Loss: 0.426 | Train PPL:   1.530
	 Val. Loss: 1.522 |  Val. PPL:   4.580


 22%|██▏       | 266/1219 [00:18<01:07, 14.07it/s]


KeyboardInterrupt: 

In [8]:
from utils.translate import Translator
from models.model4.definition import beam_search

In [9]:
# Restore the weights saved at result/tut4-model.pt. The training cell may have
# been interrupted mid-epoch, so the in-memory parameters are not necessarily
# the ones stored in the checkpoint. `device` is already a torch.device, so it
# is passed to map_location directly.
model.load_state_dict(torch.load('result/tut4-model.pt', map_location=device))

# Switch to inference mode so dropout is disabled while translating; an
# interrupted training run may have left the model in training mode.
model.eval()

# Wrap the fields, the model and the beam-search routine in the translation helper.
T = Translator(SRC, TRG, model, device, beam_search)

In [10]:
# Translate a single rules-text sentence and show up to three candidates,
# one per line.
# NOTE(review): ret[0] appears to hold the candidate token sequences --
# confirm against Translator.translate.
ret = T.translate(
    'Unlicensed Hearse’s power and toughness are each equal to the number of cards exiled with it.'
)
print('\n'.join(str(candidate) for candidate in ret[0][:3]))

['<cn>', '的', '力量', '和', '防御力', '各', '等', '同于', '以', '它', '放逐', '的', '牌', '数量', '。', '<eos>']
['<cn>', '的', '力量', '和', '防御力', '各', '等', '同于', '它', '放逐', '之', '牌', '的', '数量', '。', '<eos>']
['<cn>', '的', '力量', '和', '防御力', '各', '等', '同于', '以', '它', '放逐', '之', '牌', '的', '数量', '。', '<eos>']


In [11]:
from utils import show_samples
# Keep only the test examples whose tokenised source has more than 30 tokens,
# to inspect how the model copes with long inputs.
long_data = list(filter(lambda example: len(example.src) > 30, test_data.examples))
print(f'Number of samples: {len(long_data)}')

# Display three of these samples, each translated with a beam of width 3.
show_samples(
    long_data,
    T,
    n=3,
    beam_size=3,
)

Number of samples: 45
src: [whenever you cast an enchantment spell , if you do n't control a creature named keimi , create keimi , a legendary 3 / 3 black and green frog creature token with " whenever you cast an enchantment spell , each opponent loses 1 life and you gain 1 life . " ] trg = [每当你施放结界咒语时，若你未操控名称为哇魅的生物，则派出传奇衍生生物哇魅，其为3/3，黑绿双色的蛙，且具有"每当你施放结界咒语时，每位对手各失去1点生命且你获得1点生命。"]
每当你施放结界咒语时，若你未操控名称为哇魅的生物哇魅，则派出传奇衍生生物哇魅，，传奇为3/3黑色，黑绿双色，且具有"每当你施放神器，且每位对手各 	[probability: 0.00000]
每当你施放结界咒语时，若你未操控名称为哇魅的生物哇魅，则派出传奇衍生生物哇魅，，传奇为3/3黑色，黑绿双色，且具有"每当你施放一个神器，且每位对手 	[probability: 0.00000]
每当你施放结界咒语时，若你未操控名称为哇魅的生物哇魅，则派出传奇衍生生物哇魅，，传奇为3/3黑色，黑绿双色，且具有"每当你施放神器生物，且每位对手 	[probability: 0.00000]

src: [{ w } { u } { b } { r } { g } , { t } : put three + 1 / + 1 counters on each myr you control . ] trg = [{w}{u}{b}{r}{g}，{t}：在每个由你操控的秘耳上各放置三个+1/+1指示物。]
{w}{u}{b}{r}{g}，{t}：在每个由你操控的秘耳上各放置三个+1/+1指示物。<eos> 	[probability: 0.13671]
{w}{u}{b}{r}，{t}：在每个由你操控的秘耳上各放置三个+1/+1指示物。<eos> 	[probability: 0.09985]
{w}{u}{b}{g}{g}，{t}：在

In [12]:
from utils import calculate_bleu

# Score the long-input subset with BLEU. The callback translates one input with
# beam size 3 and returns the [0][0] entry of the result (presumably the
# top-ranked candidate -- confirm against Translator.translate).
bleu = calculate_bleu(
    long_data,
    lambda x: T.translate(x, beam_size=3)[0][0],
)

# Print the score scaled by 100.
print(bleu * 100)

100%|██████████| 45/45 [00:06<00:00,  7.18it/s]


74.36702387304524
