In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field, TabularDataset, BucketIterator

from dataset.mtgcards import RuleText
from utils.preprocess import fields_for_rule_text

import random
import math
import time
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Obtain the source/target field objects and wire them into the `fields`
# mapping (keyed by 'src' and 'trg') that RuleText.splits consumes.
SRC, TRG = fields_for_rule_text()
fields = dict(
    src=('src', SRC),
    trg=('trg', TRG),
)

# Load the train / validation / test splits in one call.
train_data, valid_data, test_data = RuleText.splits(fields=fields)

In [3]:
# Both vocabularies are built from the training split only, with min_freq=2
# passed to each build_vocab call.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

# Report the resulting vocabulary sizes, one per line.
print(
    f"Unique tokens in source (en) vocabulary: {len(SRC.vocab)}",
    f"Unique tokens in target (zh) vocabulary: {len(TRG.vocab)}",
    sep="\n",
)

Unique tokens in source (en) vocabulary: 4991
Unique tokens in target (zh) vocabulary: 2357


In [4]:
# Seed Python's and PyTorch's RNGs before any iterator or model is created, so
# that whatever draws from them afterwards is repeatable after a kernel
# restart. Without this every "Restart & Run All" can produce different
# batches and different initial weights.
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
BATCH_SIZE = 32

# One iterator per split. Examples are keyed by the length of their `src`
# field and sorted within each batch.
# NOTE(review): presumably the sorting is needed because the encoder packs
# padded sequences -- confirm against models.model4.definition.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_within_batch=True,
    sort_key=lambda x: len(x.src),
    device=device)

# Pull a single training batch and print it as a quick sanity check.
tmp = next(iter(train_iterator))
print(tmp)

cpu

[torchtext.legacy.data.batch.Batch of size 32]
	[.src]:('[torch.LongTensor of size 14x32]', '[torch.LongTensor of size 32]')
	[.trg]:[torch.LongTensor of size 20x32]


In [5]:
from models.model4.definition import Encoder, Attention, Decoder, Seq2Seq
from models.model4.train import init_weights, train, evaluate
from utils import count_parameters, train_loop

In [6]:
# Model dimensions: the input/output sizes come from the two vocabularies;
# the embedding, hidden and dropout settings are shared by encoder and decoder.
INPUT_DIM, OUTPUT_DIM = len(SRC.vocab), len(TRG.vocab)
ENC_EMB_DIM = DEC_EMB_DIM = 256
ENC_HID_DIM = DEC_HID_DIM = 512
ENC_DROPOUT = DEC_DROPOUT = 0.5
# Index of the source padding token, handed to Seq2Seq below.
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]

# Assemble the three sub-modules; the decoder receives the attention module.
attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(
    INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT
)
dec = Decoder(
    OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn
)

# Wrap them in the full model and move it onto the selected device.
model = Seq2Seq(enc, dec, SRC_PAD_IDX, device)
model = model.to(device)

# Apply the project's weight initialiser to every sub-module.
model.apply(init_weights)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 12,540,469 trainable parameters


In [None]:
# The loss skips positions holding the target padding index.
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

# Adam with its default settings over all model parameters.
optimizer = optim.Adam(model.parameters())

# Hand everything to the project's training loop; the save location and the
# load_before_train flag are passed through as keyword arguments.
train_loop(
    model, optimizer, criterion, train, evaluate,
    train_iterator, valid_iterator,
    save_path='result/',
    file_name='tut4-model.pt',
    load_before_train=True,
)

In [7]:
from utils.translate import Translator
from models.model4.definition import beam_search

In [8]:
# Restore the saved weights. `device` is already a torch.device, so it is
# passed to map_location directly to remap the stored tensors onto it.
model.load_state_dict(torch.load('result/tut4-model.pt', map_location=device))

# Switch to evaluation mode before any inference: a freshly built module is in
# training mode, which would leave dropout active while translating and make
# the translations and BLEU score non-deterministic.
model.eval()

# Translator bundles the fields, the model, the device and the beam-search
# routine used by the cells below.
T = Translator(SRC, TRG, model, device, beam_search)

In [9]:
# Translate a single sentence, then print the first three entries of the
# first element of the result, one per line.
ret = T.translate('Unlicensed Hearse’s power and toughness are each equal to the number of cards exiled with it.')
print('\n'.join(str(candidate) for candidate in ret[0][:3]))

['<cn>', '的', '力量', '和', '防御力', '各', '等', '同于', '以', '它', '放逐', '的', '牌', '数量', '。', '<eos>']
['<cn>', '的', '力量', '和', '防御力', '各', '等', '同于', '它', '放逐', '之', '牌', '的', '数量', '。', '<eos>']
['<cn>', '的', '力量', '和', '防御力', '各', '等', '同于', '以', '它', '放逐', '之', '牌', '的', '数量', '。', '<eos>']


In [17]:
from utils import show_samples
# Keep only the test examples whose `src` is longer than 30 items.
long_data = list(filter(lambda example: len(example.src) > 30, test_data.examples))
print(f'Number of samples: {len(long_data)}')

# Display three of them, decoding with a beam of size 3.
show_samples(long_data, T, n=3, beam_size=3)

Number of samples: 45
src: [whenever basri 's lieutenant or another creature you control dies , if it had a + 1 / + 1 counter on it , create a 2 / 2 white knight creature token with vigilance . ] trg = [每当<cn>或另一个由你操控的生物死去时，若其上有+1/+1指示物，则派出一个2/2白色，具警戒异能的骑士衍生生物。]
每当<cn>或另一个由你操控的生物死去时，若其上有+1/+1指示物，则派出一个2/2白色，具警戒异能的骑士衍生生物。<eos> 	[probability: 0.11363]
每当<cn>或另一个由你操控的生物死去时，若其上有+1/+1指示物，则派出一个2/2白色，具警戒异能的吸血鬼衍生生物。<eos> 	[probability: 0.01553]
每当<cn>或另一个由你操控的生物死去时，若其上有+1/+1指示物，则派出一个2/2白色，具警戒异能的2/2白色，具警戒异能的骑士衍生生物。<eos> 	[probability: 0.00253]

src: [whenever ghost of ramirez depietro deals combat damage to a player , choose up to one target card in a graveyard that was discarded or put there from a library this turn . ] trg = [每当<cn>对任一牌手造成战斗伤害时，选择至多一张目标在坟墓场中的牌，且须为本回合中弃掉或从牌库进入该处者。]
每当<cn>对任一牌手造成战斗伤害时，选择至多一张目标在坟墓场中的坟墓场，且在本回合中有一张牌。<eos> 	[probability: 0.00000]
每当<cn>对任一牌手造成战斗伤害时，选择至多一张目标在坟墓场中的坟墓场，且在本回合中有一张牌在你手上。<eos> 	[probability: 0.00000]
每当<cn>对任一牌手造成战斗伤害时，选择至多一张目标在坟墓场中的坟墓场，且在本回合中有一张牌在你手上。。<eo

In [12]:
from utils import calculate_bleu

# Score the long-example subset. For each input the callback returns the
# [0][0] entry of a beam-size-3 translation.
bleu = calculate_bleu(
    long_data,
    lambda x: T.translate(x, beam_size=3)[0][0],
)

# Print the score scaled by 100.
print(bleu * 100)

100%|██████████| 45/45 [00:06<00:00,  7.18it/s]


74.36702387304524
