In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field, TabularDataset, BucketIterator

from dataset.mtgcards import RuleText
from utils.preprocess import fields_for_rule_text

import random
import math
import time
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed every RNG the notebook relies on so Restart & Run All reproduces its outputs:
# Python's `random` (used below to sample examples) and torch (module parameter
# initialisation and dropout). torch.manual_seed seeds all devices, CUDA included.
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)

# Tokenising fields for the English source / Chinese target rule text.
SRC, TRG = fields_for_rule_text()
fields = {'src': ('src', SRC), 'trg': ('trg', TRG)}

train_data, valid_data, test_data = RuleText.splits(fields=fields, version='v2.2')

In [30]:
# Tokens seen fewer than MIN_FREQ times in train_data are left out of each vocabulary.
MIN_FREQ = 4
for field in (SRC, TRG):
    field.build_vocab(train_data, min_freq=MIN_FREQ)

print(f"Unique tokens in source (en) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (zh) vocabulary: {len(TRG.vocab)}")

# Eyeball a few tokenised (src, trg) training pairs.
for pair in random.sample(list(train_data), 3):
    print(pair.src, pair.trg)

Unique tokens in source (en) vocabulary: 1294
Unique tokens in target (zh) vocabulary: 1949
['whenever', 'a', 'creature', 'you', 'control', 'attacks', 'alone', ',', 'that', 'creature', 'gets', '+', '1', '/', '+', '1', 'until', 'end', 'of', 'turn', '.'] ['每当', '一个', '由', '你', '操控', '的', '生物', '单独', '攻击', '时', '，', '该', '生物', '得', '+', '1', '/', '+', '1', '直到', '回合', '结束', '。']
['it', 'becomes', 'a', 'creature', 'again', 'if', 'it', "'s", 'not', 'attached', 'to', 'a', 'creature', '.'] ['如果', '它', '未', '结', '附于', '生物', '上', '，', '就', '会', '再度', '成为', '生物', '。']
['choose', 'one', '—'] ['选择', '一', '项', '～']


In [31]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

BATCH_SIZE = 128

# Examples are bucketed by source length and sorted inside each batch; presumably the
# encoder packs padded sequences (src is a (tokens, lengths) pair) — confirm in models.model4.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_within_batch=True,
    sort_key=lambda example: len(example.src),
    device=device,
)

# Peek at one batch to sanity-check tensor shapes.
tmp = next(iter(train_iterator))
print(tmp)

cpu

[torchtext.legacy.data.batch.Batch of size 128]
	[.src]:('[torch.LongTensor of size 11x128]', '[torch.LongTensor of size 128]')
	[.trg]:[torch.LongTensor of size 28x128]


In [32]:
from models.model4.definition import Encoder, Attention, Decoder, Seq2Seq
from models.model4.train import init_weights, train, evaluate
from utils import count_parameters, train_loop

In [33]:
# Vocabulary sizes come from the vocabularies built from train_data.
INPUT_DIM, OUTPUT_DIM = len(SRC.vocab), len(TRG.vocab)
ENC_EMB_DIM = DEC_EMB_DIM = 256
ENC_HID_DIM = DEC_HID_DIM = 512
ENC_DROPOUT = DEC_DROPOUT = 0.5
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]

# Keep the construction order attention -> encoder -> decoder: submodules presumably
# draw their initial weights from the torch RNG as they are built.
attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)
model = Seq2Seq(enc, dec, SRC_PAD_IDX, device).to(device)
model.apply(init_weights)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 10,758,045 trainable parameters


In [8]:
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
# Target padding positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
optimizer = optim.Adam(model.parameters())  # library-default Adam hyper-parameters

train_loop(
    model, optimizer, criterion, train, evaluate,
    train_iterator, valid_iterator,
    save_path='result/',
    file_name='model4-rule-v2.2.pt',
    load_before_train=True,
)

model will be saved to result/model4-rule-v2.2.pt


  2%|▏         | 6/291 [00:36<28:36,  6.02s/it]


KeyboardInterrupt: 

In [34]:
from utils.translate import Translator
from models.model4.definition import beam_search
# Restore the checkpoint that train_loop writes to save_path + file_name.
model.load_state_dict(torch.load('result/model4-rule-v2.2.pt', map_location=device))
# Disable dropout for inference: if training was interrupted (as in the recorded run)
# the model is still in train mode. Harmless if Translator also switches to eval mode.
model.eval()
T = Translator(SRC, TRG, model, device, beam_search)

In [35]:
# Translate one pre-tokenised rule sentence (lower-case, space-separated tokens, like
# the training examples) and show the three best beam-search candidates.
data = 'target creature gets - 1 / - 1 until end of turn .'
ret, prob = T.translate(data, max_len=100)
print(*ret[:3], sep='\n')

['目标', '生物', '得', '-', '1', '/', '-', '1', '直到', '回合', '结束', '。', '<eos>']
['目标', '生物', '得', '-', '1', '-', '1', '直到', '回合', '结束', '。', '<eos>']
['目标', '生物', '得', '-', '1', '1', '-', '1', '直到', '回合', '结束', '。', '<eos>']


In [36]:
from utils import show_samples
# Only sentences with more than LONG_SRC_THRESHOLD source tokens are shown.
LONG_SRC_THRESHOLD = 20
long_data = list(filter(lambda ex: len(ex.src) > LONG_SRC_THRESHOLD, test_data.examples))
print(f'Number of samples: {len(long_data)}')
show_samples(long_data, T, n=3, beam_size=3)

Number of samples: 355
src: [whenever you cast an instant or sorcery spell , < 1 > gets + 2 / + 0 until end of turn . ] trg = [每当你施放瞬间或法术咒语时，<1>得+2/+0直到回合结束。]
每当你施放瞬间或法术咒语时，<1>得+2/+0直到回合结束。<eos> 	[probability: 0.68745]
每当你使用瞬间或法术咒语时，<1>得+2/+0直到回合结束。<eos> 	[probability: 0.17417]
每当你施放瞬间或咒语时，<1>得+2/+0直到回合结束。<eos> 	[probability: 0.00862]

src: [when < 0 > enters the battlefield , exile target creature an opponent controls until < 0 > leaves the battlefield . ] trg = [当<0>进战场时，放逐目标由对手操控的生物，直到<0>离开战场为止。]
当<0>进战场时，放逐目标由对手操控的生物，直到<0>离开战场为止。<eos> 	[probability: 0.81881]
当<0>进战场时，放逐目标由对手操控的生物，直到0>离开战场为止。<eos> 	[probability: 0.02603]
当<0>进战场时，放逐目标由对手操控的生物，令<0>离开战场为止。<eos> 	[probability: 0.01527]

src: [at the beginning of combat on your turn , create a < 9 > that 's a copy of target non creature < 1 > you control , except its name is mishra 's warform and it 's a 4 / 4 construct < 1 > creature in addition to its other types . ] trg = [在你回合的战斗开始时，派出一个<9>，其为目标由你操控之非生物<1>的复制品，但名称是米斯拉的战形械，为4/4组构体<1>

In [37]:
from dataset.mtgcards import TestSets
from utils import calculate_bleu
from torchtext.legacy.data import Field
from models.card_name_detector.definition import TrainedDetector
from utils.translate import sentencize, CardTranslator

# Whole-card test set: the English side is split on single spaces, the Chinese side
# keeps Field's default tokenizer.
src_rule_field = Field(tokenize=lambda text: text.split(' '))
trg_rule_field = Field()
fields = {
    'src-rule': ('src', src_rule_field),
    'trg-rule': ('trg', trg_rule_field),
}
test_data = TestSets.load(fields)

D = TrainedDetector()

path: d:\Desktop\mtg-cards-translation\models\card_name_detector


In [51]:
def sentencize(text: str):
    """Split raw card text into the sentence-sized pieces fed to the translator.

    A piece ends at the first '.', newline or '(' that is not inside a pair of
    double quotes. A terminating '.' stays with its piece, while '(' and newline
    are left for the next piece. Spaces, parentheses and newlines at the start of
    each piece are discarded (so parenthesised reminder text becomes its own
    piece). Trailing whitespace is kept, and an input made only of those
    separators yields [].

    This definition shadows the ``sentencize`` imported from utils.translate.
    """
    leading_junk = ' ()\n'
    terminators = '.\n('
    pieces = []
    rest = text
    while True:
        rest = rest.lstrip(leading_junk)
        if not rest:
            return pieces
        in_quotes = False
        cut = len(rest)
        for pos, ch in enumerate(rest):
            if ch == '"':
                in_quotes = not in_quotes
            elif not in_quotes and ch in terminators:
                cut = pos
                break
        if cut < len(rest) and rest[cut] == '.':
            cut += 1  # the full stop belongs to the sentence it closes
        pieces.append(rest[:cut])
        rest = rest[cut:]
def preprocess(x:str):
    """Annotate ``x`` with the global card-name detector ``D`` and echo the result.

    At most one leading space of the annotated text is dropped. (The CardTranslator
    built below uses CTHelper.preprocess instead of this function.)
    """
    annotated = D.annotate(x)
    if annotated.startswith(' '):
        annotated = annotated[1:]
    print(f'[after preprocess]:{annotated}')
    return annotated
def postprocess(x:str):
    """Remove every '<' and '>' from a translation, e.g. '<0>' becomes '0'."""
    return ''.join(ch for ch in x if ch not in '<>')

import re
class CTHelper:
    """Mask card names and dictionary terms with numbered tags around a translator.

    ``preprocess`` runs the card-name detector, then replaces every ``<name>`` span
    and every dictionary term with ``<0>``, ``<1>``, ... and records what each tag
    stood for. ``postprocess`` puts the recorded values back into the model output:
    dictionary terms become their fixed translation, and names are restored
    verbatim (angle brackets included).

    Args:
        name_detector: object with an ``annotate(text) -> str`` method. Its output
            is lower-cased, so dictionary keys should be lower-case.
        dictionary: English term -> fixed Chinese translation. Terms are matched
            literally and only as whole words, so 'oil' does not hit "soil".
        verbose: print the text after preprocessing and before postprocessing.
    """

    # A detector-annotated name: '<' + non-digit, non-'>' chars + '>'. Numeric tags
    # such as '<0>' never match, so already-substituted tags are left alone.
    _NAME_TAG = re.compile(r'<[^0-9>]+>')

    def __init__(self, name_detector, dictionary=None, verbose=True) -> None:
        self.D = name_detector
        # None instead of a shared mutable {} default.
        self.dictionary = {} if dictionary is None else dictionary
        self.verbose = verbose
        # tag -> original text; empty so postprocess() is safe before preprocess().
        self.tag2str = {}

    def preprocess(self, x: str) -> str:
        """Annotate ``x`` and replace names / dictionary terms with numbered tags."""
        self.tag2str = {}
        # Use the injected detector, not a global; its output is lower-case.
        x = self.D.annotate(x).removeprefix(' ')
        next_id = 0

        def _tag_name(match):
            # Assign the next tag number to this detected name, left to right.
            nonlocal next_id
            tag = f'<{next_id}>'
            self.tag2str[tag] = match.group(0)
            next_id += 1
            return tag

        x = self._NAME_TAG.sub(_tag_name, x)

        for term in self.dictionary:
            # Literal, whole-word match: keys are not regexes, and a term must not
            # be adjacent to other word characters ('oil' vs. 'soil'/'spoils').
            pattern = re.compile(r'(?<!\w)' + re.escape(term) + r'(?!\w)')
            if pattern.search(x):
                tag = f'<{next_id}>'
                self.tag2str[tag] = term
                x = pattern.sub(lambda _m, t=tag: t, x)
                next_id += 1

        if self.verbose:
            print(f'[  after preprocess]:{x}')
        return x

    def postprocess(self, x: str) -> str:
        """Replace the tags from the last ``preprocess`` call in translation ``x``."""
        if self.verbose:
            print(f'[before postprocess]:{x}')
        for tag, source in self.tag2str.items():
            x = x.replace(tag, self.dictionary.get(source, source))
        return x

# Fixed renderings for terms the model translates poorly. CTHelper masks them with
# numbered tags before translation and substitutes these values afterwards.
dic = {'oil': '烁油', 'rebel': '反抗军', 'compleated': '完化', 'multicolored': '多色', 'toxic': '下毒'}
helper = CTHelper(D, dic)
CT = CardTranslator(sentencize, T, preprocess=helper.preprocess, postprocess=helper.postprocess)

# Example 237 is pinned so its output can be compared across runs; three random
# examples follow it. Materialise the test set once instead of per use.
test_examples = list(test_data)
showcase = [test_examples[237]] + random.sample(test_examples, 3)
for example in showcase:
    print(vars(example))
    ret = CT.translate(' '.join(example.src))
    print(ret + '\n')

{'src': ['Gain', 'control', 'of', 'target', 'creature', 'with', 'mana', 'value', 'X', 'or', 'less.', 'If', 'X', 'is', '5', 'or', 'more,', 'create', 'a', 'token', "that's", 'a', 'copy', 'of', 'that', 'creature.'], 'trg': ['获得目标法术力值等于或小于X的生物之操控权。如果X等于或大于5，则派出一个衍生物，此衍生物为该生物的复制品。']}
[after preprocess]:gain control of target creature with mana value x or less .
[before postprocess]:获得目标总法术力费用等于或小于x的生物之操控权。
[after preprocess]:if x is 5 or more , create a token that 's a copy of that creature .
[before postprocess]:如果x等于或大于5，则将一个衍生物放进战场，此衍生物为该生物的复制品。
获得目标总法术力费用等于或小于x的生物之操控权。 如果x等于或大于5，则将一个衍生物放进战场，此衍生物为该生物的复制品。

{'src': ['Toxic', '1', '(Players', 'dealt', 'combat', 'damage', 'by', 'this', 'creature', 'also', 'get', 'a', 'poison', 'counter.)\nOther', 'Rats', 'you', 'control', 'have', 'toxic', '1.\nWhen', 'Karumonix', 'enters', 'the', 'battlefield,', 'look', 'at', 'the', 'top', 'five', 'cards', 'of', 'your', 'library.', 'You', 'may', 'reveal', 'any', 'number', 'of', 'Rat', 'cards', 'from', 'amon

In [42]:
from utils import calculate_testset_bleu
# BLEU on the first 100 test cards only (about 40 s on CPU in the recorded run).
bleu_subset = list(test_data)[:100]
calculate_testset_bleu(bleu_subset, CT)

100%|██████████| 100/100 [00:37<00:00,  2.66it/s]


0.6877907101188122