In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field, TabularDataset, BucketIterator

from dataset.mtgcards import RuleText
from utils.preprocess import fields_for_rule_text

import random
import math
import time
import os

In [7]:
# Seed the global RNGs before any data loading or model construction so that
# anything drawing from Python's `random` or torch's global generator (e.g.
# iterator shuffling, weight initialisation) is repeatable on a fresh-kernel
# "Run All". `random` and `torch` are imported in the first cell.
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on every device, CUDA included

# Build the source/target fields and map the dataset's 'src'/'trg' columns onto
# them; examples (and batches) expose the data as `.src` / `.trg`.
SRC, TRG = fields_for_rule_text()
fields = {'src': ('src', SRC), 'trg': ('trg', TRG)}

train_data, valid_data, test_data = RuleText.splits(fields=fields, version='v2')

In [8]:
# Vocabularies are built from the training split only; tokens seen fewer than
# 2 times there are left out of the vocabulary.
for vocab_field in (SRC, TRG):
    vocab_field.build_vocab(train_data, min_freq = 2)

print(f"Unique tokens in source (en) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (zh) vocabulary: {len(TRG.vocab)}")

Unique tokens in source (en) vocabulary: 1488
Unique tokens in target (zh) vocabulary: 2313


In [9]:
from utils.preprocess import tokenize_en, tokenize_zh

# Inspect how each tokenizer handles a <...> card-name placeholder embedded in a
# sentence.
# NOTE(review): the recorded output shows the placeholder being split into
# single characters rather than kept as one token.
for tokenizer, sentence in (
    (tokenize_zh, '假设中间夹<xbp, the strongest of all>着<xbp, the strongest of all>一个名字'),
    (tokenize_en, 'there <xbp, the strongest of all> is a card name <xbp, the strongest of all>.'),
):
    print(tokenizer(sentence))

['假设', '中间', '夹', '<', 'x', 'b', 'p', ',', ' ', 't', 'h', 'e', ' ', 's', 't', 'r', 'o', 'n', 'g', 'e', 's', 't', ' ', 'o', 'f', ' ', 'a', 'l', 'l', '>', '着', '<', 'x', 'b', 'p', ',', ' ', 't', 'h', 'e', ' ', 's', 't', 'r', 'o', 'n', 'g', 'e', 's', 't', ' ', 'o', 'f', ' ', 'a', 'l', 'l', '>', '一个', '名字']
['there', '<', 'x', 'b', 'p', ',', ' ', 't', 'h', 'e', ' ', 's', 't', 'r', 'o', 'n', 'g', 'e', 's', 't', ' ', 'o', 'f', ' ', 'a', 'l', 'l', '>', 'is', 'a', 'card', 'name', '<', 'x', 'b', 'p', ',', ' ', 't', 'h', 'e', ' ', 's', 't', 'r', 'o', 'n', 'g', 'e', 's', 't', ' ', 'o', 'f', ' ', 'a', 'l', 'l', '>', '.']


In [10]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)
BATCH_SIZE = 128

# BucketIterator batches examples of similar source length together; with
# sort_within_batch the examples inside each batch are also ordered by source
# length (the `.src` field carries lengths alongside the token tensor).
iterator_options = dict(
    batch_size = BATCH_SIZE,
    sort_within_batch = True,
    sort_key = lambda example: len(example.src),
    device = device,
)
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), **iterator_options)

# Pull one training batch to sanity-check tensor shapes.
tmp = next(iter(train_iterator))
print(tmp)

cpu

[torchtext.legacy.data.batch.Batch of size 128]
	[.src]:('[torch.LongTensor of size 22x128]', '[torch.LongTensor of size 128]')
	[.trg]:[torch.LongTensor of size 39x128]


In [11]:
from models.model4.definition import Encoder, Attention, Decoder, Seq2Seq
from models.model4.train import init_weights, train, evaluate
from utils import count_parameters, train_loop

In [12]:
# Vocabulary sizes set the encoder input and decoder output dimensions.
INPUT_DIM, OUTPUT_DIM = len(SRC.vocab), len(TRG.vocab)
# Encoder and decoder use the same embedding size, hidden size and dropout.
ENC_EMB_DIM = DEC_EMB_DIM = 256
ENC_HID_DIM = DEC_HID_DIM = 512
ENC_DROPOUT = DEC_DROPOUT = 0.5
# Source padding index, passed to Seq2Seq (presumably to mask padded
# positions — confirm in models.model4.definition).
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]

attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, SRC_PAD_IDX, device)
model = model.to(device)

# nn.Module.apply calls init_weights on every submodule recursively.
model.apply(init_weights)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 11,553,545 trainable parameters


In [None]:
# Padded target positions are excluded from the loss.
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)
optimizer = optim.Adam(model.parameters())

# NOTE(review): load_before_train=True presumably resumes from an existing
# checkpoint under save_path/file_name — confirm in utils.train_loop.
train_loop(
    model, optimizer, criterion,
    train, evaluate,
    train_iterator, valid_iterator,
    save_path='result/',
    file_name='model4-rule-v1.pt',
    load_before_train=True,
)

In [13]:
from utils.translate import Translator
from models.model4.definition import beam_search
# Load the checkpoint the training cell works with. train_loop is called with
# save_path='result/' and file_name='model4-rule-v1.pt', so read it from that
# directory rather than the working directory (which would pick up a missing
# or stale file on a fresh "Run All").
# NOTE(review): assumes train_loop stores the file at save_path joined with
# file_name — confirm in utils.train_loop.
checkpoint_path = os.path.join('result/', 'model4-rule-v1.pt')
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Wrap the model with both vocabularies and the beam-search decoder for inference.
T = Translator(SRC, TRG, model, device, beam_search)

In [36]:
# Translate one hand-written sentence containing a <...> card-name placeholder.
data = 'When ever <prototype> enlist a creature, draw a card.'
ret, prob, att = T.translate_with_att(data, max_len=100)
# Print the first three returned hypotheses, one per line.
print('\n'.join(str(hypothesis) for hypothesis in ret[:3]))

['当', '<', 'p', 'r', 'o', 't', 'o', 't', 'y', 'p', 'e', '>', '征列', '生物', '时', '，', '抓', '一', '张', '生物', '，', '抓', '一', '张牌', '。', '<eos>']
['当', '<', '<', 'r', 'r', 'o', 't', 'o', 't', 'y', 'p', 'e', '>', '征列', '生物', '时', '，', '抓', '一', '张', '生物', '，', '抓', '一', '张牌', '。', '<eos>']
['当', '<', 'p', 'r', 'o', 't', 'o', 't', 'y', 'p', 'e', '>', '征列', '生物', '时', '，', '抓', '一', '张', '生物', '，', '抓', '抓', '一', '张牌', '。', '<eos>']


In [38]:
from utils import show_samples
# Keep only test examples whose tokenised source has more than 20 tokens, then
# display a few of them with their beam-search translations.
long_data = list(filter(lambda example: len(example.src) > 20, test_data.examples))
print(f'Number of samples: {len(long_data)}')
show_samples(long_data, T, n=3, beam_size=3)

Number of samples: 352
src: [spells your opponents cast that target < i c e f a l l   r e g e n t > cost { 2 } more to cast . ] trg = [由对手施放且以<icefall regent>为目标的咒语增加{2}来施放。]
对手施放之以<icefall regent>为目标的咒语增加{2}来施放。<eos> 	[probability: 0.16734]
对手施放之以<icefall regent>为目标的咒语增加{2}来使用。<eos> 	[probability: 0.01427]
对手施放之以<icefall regent>为目标为目标的咒语增加{2}来施放。<eos> 	[probability: 0.00839]

src: [each clashing player reveals the top card of their library , then puts that card on the top or bottom . ] trg = [参与比点的牌手各展示其牌库顶牌，然后将该牌置于牌库顶部或底部。]
参与比点的牌手各展示其牌库顶牌，然后将该牌置于底部底部底部底部。<eos> 	[probability: 0.30383]
参与比点的牌手各展示其牌库顶牌，然后将该牌置于底部底部底部底部底部。<eos> 	[probability: 0.11974]
参与比点的牌手各展示其牌库顶牌，然后将该牌置于底部底部底部底部。底部。<eos> 	[probability: 0.09725]

src: [as this creature enters the battlefield , an opponent of your choice may put two + 1 / + 1 counters on it . ] trg = [于此生物进战场时，选择一位对手，他可以在其上放置两个+1/+1指示物。]
于此生物进战场时，选择一位对手，他可以在其上放置两个+1/+1指示物。<eos> 	[probability: 0.62496]
于此生物进时，选择一位对手，他可以在其上放置两个+1/+1指示物。<eos> 	[probabilit