In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field, TabularDataset, BucketIterator

from dataset.mtgcards import RuleText
from utils.preprocess import fields_for_rule_text

import random
import math
import time
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Source and target Field objects come from the project helper.
SRC, TRG = fields_for_rule_text()

# Keyed by 'src' / 'trg'; each value pairs a name with its Field. This looks like
# torchtext's {column: (attribute_name, Field)} convention — TODO confirm against
# RuleText.splits.
fields = dict(
    src=('src', SRC),
    trg=('trg', TRG),
)

# RuleText.splits hands back the train / validation / test datasets in that order.
train_data, valid_data, test_data = RuleText.splits(fields=fields)

In [3]:
# Build both vocabularies from the training split only. min_freq = 2 is
# torchtext's minimum-frequency threshold for a token to be included in the
# vocabulary.
SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)
# Report the resulting vocabulary sizes (the labels name the source as en and
# the target as zh).
print(f"Unique tokens in source (en) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (zh) vocabulary: {len(TRG.vocab)}")

Unique tokens in source (en) vocabulary: 4991
Unique tokens in target (zh) vocabulary: 2357


In [4]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
BATCH_SIZE = 32

# One BucketIterator per split, all sharing the same batch size and device.
# sort_key orders examples by source length and sort_within_batch = True sorts
# each batch by that key — presumably because the encoder consumes packed
# padded sequences (TODO confirm in models.model4.definition).
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE, 
    sort_within_batch = True,
    sort_key = lambda x: len(x.src),
    device = device)

# Pull a single training batch and print its summary as a quick sanity check of
# the batch layout. `tmp` is not referenced again in the cells visible here.
tmp = next(iter(train_iterator))
print(tmp)

cpu

[torchtext.legacy.data.batch.Batch of size 32]
	[.src]:('[torch.LongTensor of size 3x32]', '[torch.LongTensor of size 32]')
	[.trg]:[torch.LongTensor of size 7x32]


In [5]:
from models.model4.definition import Encoder, Attention, Decoder, Seq2Seq
from models.model4.train import init_weights, train, evaluate
from utils import count_parameters, train_loop

In [6]:
# Model hyperparameters. The input/output dimensions are the sizes of the
# source and target vocabularies built earlier in the notebook.
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
# Index of the source padding token; handed to Seq2Seq below (presumably so it
# can mask padded source positions — TODO confirm in models.model4.definition).
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]

# The attention module is constructed first because the Decoder takes it as an
# argument.
attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

# Assemble the full model and move it to the selected device.
model = Seq2Seq(enc, dec, SRC_PAD_IDX, device).to(device)

# nn.Module.apply calls init_weights on every submodule and on the model itself.
model.apply(init_weights)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 12,540,469 trainable parameters


In [None]:
# Training cell. It carries no execution count in the saved notebook, i.e. it
# was not run in the session that produced the outputs below.

# Adam over all model parameters, using PyTorch's default hyperparameters.
optimizer = optim.Adam(model.parameters())
# Target positions equal to the padding index are ignored by the loss.
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

# train_loop is a project helper that receives the train/evaluate callables and
# the train/validation iterators. load_before_train=True presumably resumes from
# save_path + file_name when that checkpoint exists — TODO confirm in utils.
train_loop(model, optimizer, criterion, train, evaluate,
           train_iterator, valid_iterator, 
           save_path='result/', file_name='tut4-model.pt', load_before_train=True)

In [7]:
from utils.translate import Translator
from models.model4.definition import beam_search

In [8]:
# Restore the saved weights from result/tut4-model.pt. map_location remaps the
# stored tensors onto the device selected earlier, so a checkpoint written on a
# GPU machine also loads on a CPU-only one. `device` is already a torch.device,
# so it is passed directly instead of being re-wrapped.
model.load_state_dict(torch.load('result/tut4-model.pt', map_location=device))

# nn.Module instances start in training mode, and the training cell is not run
# when the notebook is only used for inference. Switch to eval mode so the
# dropout layers configured for the encoder/decoder are disabled and the
# translations are deterministic. eval() is idempotent, so this is harmless if
# Translator also sets it internally.
model.eval()

# Translator bundles the fields, the model, the device and the beam-search
# routine behind the translate* helpers used in the cells below.
T = Translator(SRC, TRG, model, device, beam_search)

In [53]:
# A hand-written rule text, already split into space-separated tokens, used as
# a single-sentence smoke test of the translator.
data = '{ 1 } { b } { b } : if evolved sleeper is a phyrexian , put a + 1 / + 1 counter on it , then you draw a card and you lose 1 life .'
# translate_with_att returns three values; only `ret` is used in this cell
# (prob and att are presumably the scores and attention weights — TODO confirm).
ret, prob, att = T.translate_with_att(data)
# Print at most the first three entries of ret, one per line.
print(*ret[:3], sep='\n')
# Bare last expression: the notebook displays how SRC tokenizes the input.
SRC.preprocess(data)

['{', '<unk>', '1', '}', '{', 'b', '}', '{', 'b', '}', '，', '{', 't', '}', '：', '如果', '<cn>', '是', '非瑞人', '，', '则', '在', '其', '上', '放置', '一个', '+', '1', '/', '+', '1', '指示物', '，', '然后', '你', '抓', '一', '张', '牌且', '失去', '1', '点', '生命', '。', '<eos>']
['{', '<unk>', '1', '}', '{', 'b', '}', '{', 'b', '}', '，', '{', 't', '}', '：', '如果', '<cn>', '是', '非瑞人', '，', '则', '在', '其', '上', '放置', '一个', '+', '1', '/', '+', '1', '指示物', '，', '然后', '你', '抓', '一', '张', '牌且', '你', '失去', '1', '点', '生命', '。', '<eos>']
['{', '<unk>', '1', '}', '{', 'b', '}', '{', 'b', '}', '，', '{', 't', '}', '，', '如果', '<cn>', '是', '非瑞人', '，', '则', '在', '其', '上', '放置', '一个', '+', '1', '/', '+', '1', '指示物', '，', '然后', '你', '抓', '一', '张', '牌且', '你', '失去', '1', '点', '生命', '。', '<eos>']


['{',
 ' ',
 '1',
 ' ',
 '}',
 '{',
 ' ',
 'b',
 ' ',
 '}',
 '{',
 ' ',
 'b',
 ' ',
 '}',
 ':',
 'if',
 'evolved',
 'sleeper',
 'is',
 'a',
 'phyrexian',
 ',',
 'put',
 'a',
 '+',
 '1',
 '/',
 '+',
 '1',
 'counter',
 'on',
 'it',
 ',',
 'then',
 'you',
 'draw',
 'a',
 'card',
 'and',
 'you',
 'lose',
 '1',
 'life',
 '.']

In [52]:
from utils import show_samples
# Keep only the test examples whose source side has more than 30 tokens.
long_data = [
    example
    for example in test_data.examples
    if len(example.src) > 30
]
print(f'Number of samples: {len(long_data)}')

# Hand the long examples to show_samples with n=3 and a beam size of 3.
show_samples(long_data, T, n=3, beam_size=3)

Number of samples: 45
src: [when exuberant fuseling enters the battlefield and whenever another creature or artifact you control is put into a graveyard from the battlefield , put an oil counter on exuberant fuseling . ] trg = [当<cn>进战场和每当另一个由你操控的生物或神器从战场进入坟墓场时，在<cn>上放置一个烁油指示物。]
当<cn>进战场和另一个由你操控的生物或神器从战场进入坟墓场时，在<cn>上放置一个烁油指示物。<eos> 	[probability: 0.00359]
当<cn>进战场时，每当另一个由你操控的其他或神器从战场进入坟墓场时，在<cn>上放置一个烁油指示物。<eos> 	[probability: 0.00037]
当<cn>进战场和另一个由你操控的生物或神器从战场进入坟墓场，在<cn>上放置一个烁油指示物。<eos> 	[probability: 0.00015]

src: [as long as the top card of your library is a creature card , creatures you control that share a color with that card get + 1 / + 1 . ] trg = [只要你的牌库顶牌是生物牌，由你操控，且与该生物牌有共通颜色的生物便得+1/+1。]
只要你的牌库顶牌是生物牌，由你操控且具共通颜色的生物便得+1/+1。<eos> 	[probability: 0.00004]
只要你的牌库顶牌是生物牌，由你操控且具共通的共通颜色的生物便得+1/+1。<eos> 	[probability: 0.00002]
只要你的牌库顶牌是生物牌，由你操控且具共通颜色的生物的生物便得+1/+1。<eos> 	[probability: 0.00000]

src: [{ g } , { t } , sacrifice magus of the order and another green creature : search your li

In [12]:
from utils import calculate_bleu

# BLEU is computed on the long_data subset only, not on the full test split.
# The callable gives calculate_bleu one hypothesis per example: element [0][0]
# of T.translate's result with beam_size=3 — presumably the token sequence of
# the top-ranked beam (TODO confirm in utils.translate).
bleu = calculate_bleu(long_data, lambda x: T.translate(x, beam_size=3)[0][0])
# The score is multiplied by 100 for display (presumably a 0-1 fraction shown
# as percentage points).
print(bleu*100)

100%|██████████| 45/45 [00:06<00:00,  7.18it/s]


74.36702387304524
