In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field, TabularDataset, BucketIterator

from dataset.mtgcards import RuleText
from utils.preprocess import fields_for_rule_text

import random
import math
import time
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SRC, TRG = fields_for_rule_text()
fields = {'src': ('src', SRC), 'trg': ('trg', TRG)}

train_data, valid_data, test_data = RuleText.splits(fields=fields, version='v2')

In [3]:
SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)
print(f"Unique tokens in source (en) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (zh) vocabulary: {len(TRG.vocab)}")

Unique tokens in source (en) vocabulary: 1488
Unique tokens in target (zh) vocabulary: 2313


In [4]:
from utils.preprocess import tokenize_en, tokenize_zh

print(tokenize_zh('假设中间夹<xbp, the strongest of all>着<xbp, the strongest of all>一个名字'))
print(tokenize_en('there <xbp, the strongest of all> is a card name <xbp, the strongest of all>.'))

['假设', '中间', '夹', '<', 'x', 'b', 'p', ',', ' ', 't', 'h', 'e', ' ', 's', 't', 'r', 'o', 'n', 'g', 'e', 's', 't', ' ', 'o', 'f', ' ', 'a', 'l', 'l', '>', '着', '<', 'x', 'b', 'p', ',', ' ', 't', 'h', 'e', ' ', 's', 't', 'r', 'o', 'n', 'g', 'e', 's', 't', ' ', 'o', 'f', ' ', 'a', 'l', 'l', '>', '一个', '名字']
['there', '<', 'x', 'b', 'p', ',', ' ', 't', 'h', 'e', ' ', 's', 't', 'r', 'o', 'n', 'g', 'e', 's', 't', ' ', 'o', 'f', ' ', 'a', 'l', 'l', '>', 'is', 'a', 'card', 'name', '<', 'x', 'b', 'p', ',', ' ', 't', 'h', 'e', ' ', 's', 't', 'r', 'o', 'n', 'g', 'e', 's', 't', ' ', 'o', 'f', ' ', 'a', 'l', 'l', '>', '.']


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
BATCH_SIZE = 128

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE, 
    sort_within_batch = True,
    sort_key = lambda x: len(x.src),
    device = device)

tmp = next(iter(train_iterator))
print(tmp)

cpu

[torchtext.legacy.data.batch.Batch of size 128]
	[.src]:('[torch.LongTensor of size 15x128]', '[torch.LongTensor of size 128]')
	[.trg]:[torch.LongTensor of size 33x128]


In [6]:
from models.model4.definition import Encoder, Attention, Decoder, Seq2Seq
from models.model4.train import init_weights, train, evaluate
from utils import count_parameters, train_loop

In [7]:
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]

attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, SRC_PAD_IDX, device).to(device)

model.apply(init_weights)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 11,553,545 trainable parameters


In [None]:
optimizer = optim.Adam(model.parameters())
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

train_loop(model, optimizer, criterion, train, evaluate,
           train_iterator, valid_iterator, 
           save_path='result/', file_name='model4-rule-v2.pt', load_before_train=True)

In [23]:
from utils.translate import Translator
from models.model4.definition import beam_search
model.load_state_dict(torch.load('model4-rule-v2.pt', map_location=torch.device(device)))
T = Translator(SRC, TRG, model, device, beam_search)

In [26]:
data = 'Whenever <1> becomes attached to a creature, for as long as <1> remains attached to it, you may have that creature become a copy of another target creature you control.'
ret, prob = T.translate(data, max_len=100)
print(*ret[:3], sep='\n')

['每当', '<', 'm', '>', '上', '生物', '上', '的', '生物', '上', '时', '，', '只要', '其', '上', '，', '它', '，', '此', '武具', '的', '生物', '上', '，', '该', '生物', '，', '，', '你', '可以', '令', '该', '生物', '成为', '另', '一个', '目标', '由', '你', '操控', '的', '生物', '之', '复制品', '。', '<eos>']
['每当', '<', 'm', '>', '上', '生物', '上', '的', '生物', '上', '时', '，', '只要', '其', '上', '，', '它', '，', '此', '武具', '的', '生物', '上', '，', '该', '生物', '，', '，', '你', '可以', '令', '该', '生物', '成为', '另', '一个', '目标', '由', '你', '操控', '的', '生物', '成为', '另', '。', '<eos>']
['每当', '<', 'm', '>', '上', '生物', '上', '的', '生物', '上', '时', '，', '只要', '其', '上', '，', '它', '，', '此', '武具', '的', '生物', '上', '，', '该', '生物', '，', '，', '你', '可以', '让', '该', '生物', '成为', '另', '一个', '目标', '由', '你', '操控', '的', '生物', '成为', '另', '。', '<eos>']


In [38]:
from utils import show_samples
long_data = [x for x in test_data.examples if len(x.src) > 20]
print(f'Number of samples: {len(long_data)}')
show_samples(long_data, T, n=3, beam_size=3)

Number of samples: 352
src: [spells your opponents cast that target < i c e f a l l   r e g e n t > cost { 2 } more to cast . ] trg = [由对手施放且以<icefall regent>为目标的咒语增加{2}来施放。]
对手施放之以<icefall regent>为目标的咒语增加{2}来施放。<eos> 	[probability: 0.16734]
对手施放之以<icefall regent>为目标的咒语增加{2}来使用。<eos> 	[probability: 0.01427]
对手施放之以<icefall regent>为目标为目标的咒语增加{2}来施放。<eos> 	[probability: 0.00839]

src: [each clashing player reveals the top card of their library , then puts that card on the top or bottom . ] trg = [参与比点的牌手各展示其牌库顶牌，然后将该牌置于牌库顶部或底部。]
参与比点的牌手各展示其牌库顶牌，然后将该牌置于底部底部底部底部。<eos> 	[probability: 0.30383]
参与比点的牌手各展示其牌库顶牌，然后将该牌置于底部底部底部底部底部。<eos> 	[probability: 0.11974]
参与比点的牌手各展示其牌库顶牌，然后将该牌置于底部底部底部底部。底部。<eos> 	[probability: 0.09725]

src: [as this creature enters the battlefield , an opponent of your choice may put two + 1 / + 1 counters on it . ] trg = [于此生物进战场时，选择一位对手，他可以在其上放置两个+1/+1指示物。]
于此生物进战场时，选择一位对手，他可以在其上放置两个+1/+1指示物。<eos> 	[probability: 0.62496]
于此生物进时，选择一位对手，他可以在其上放置两个+1/+1指示物。<eos> 	[probabilit

In [60]:
from dataset.mtgcards import TestSets
from utils import calculate_bleu
from torchtext.legacy.data import Field

fields = {'src-rule': ('src', Field(tokenize=lambda x: x.split(' '))), 'trg-rule': ('trg', Field())}
test_data = TestSets.load(fields)

from models.card_name_detector.definition import TrainedDetector
D = TrainedDetector()

path: d:\ddw\school\大三下\语音信息处理技术\期末作业\code\mtg-cards-translation\models\card_name_detector


In [65]:
def sentencize(text: str):
    ignore = {' ', '(', ')', '\n'}
    while len(text) and text[0] in ignore:
        text = text[1:]
    if len(text) == 0:
        return []
    
    r = 0
    delims = {'.', '\n', '('}
    while r < len(text) and text[r] not in delims:
        r += 1
    
    if r < len(text) and text[r] == '.':
        return [text[:r + 1]] + sentencize(text[r + 1:])
    return [text[:r]] + sentencize(text[r:])

class CardTranslator:
    def __init__(self, sentencize, sent_translator, preprocess=None, postprocess=None) -> None:
        self.sentencize = sentencize
        self.sent_translator = sent_translator
        self.preprocess = preprocess
        self.postprocess = postprocess
    
    def translate(self, text: str)->str:
        sents = self.sentencize(text)
        result = []
        for sent in sents:
            if self.preprocess:
                sent = self.preprocess(sent)
            sent, _ = self.sent_translator.translate(sent)
            sent = ''.join(sent[0][:-1])
            if self.postprocess:
                sent = self.postprocess(sent)
            result.append(sent)
        return ' '.join(result)

In [66]:
CT = CardTranslator(sentencize, T, preprocess=lambda x: D.annotate(x))

example = random.sample(list(test_data), 1)[0]
print(vars(example))
CT.translate(' '.join(example.src))

{'src': ['You', 'gain', 'X', 'life.', 'Create', 'X', '1/1', 'colorless', 'Phyrexian', 'Mite', 'artifact', 'creature', 'tokens', 'with', 'toxic', '1', 'and', '"This', 'creature', "can't", 'block."', 'If', 'X', 'is', '5', 'or', 'more,', 'destroy', 'all', 'other', 'creatures.', '(Players', 'dealt', 'combat', 'damage', 'by', 'a', 'creature', 'with', 'toxic', '1', 'also', 'get', 'a', 'poison', 'counter.)'], 'trg': ['你获得X点生命。派出X个1/1无色非瑞人／虫械衍生神器生物，且具有下毒1与「此生物不能进行阻挡。」如果X等于或大于5，则消灭所有其他生物。（受到具下毒1生物之战斗伤害的牌手还会得到一个中毒指示物。）']}
['you', 'gain', 'x', 'life', '.']
['create', 'x', '1', '/', '1', 'colorless', 'phyrexian', 'mite', 'artifact', 'creature', 'tokens', 'with', 'toxic', '1', 'and', '"', 'this', 'creature', 'ca', "n't", 'block', '.']
['"', 'if', 'x', 'is', '5', 'or', 'more', ',', 'destroy', 'all', 'other', 'creatures', '.']
['players', 'dealt', 'combat', 'damage', 'by', 'a', 'creature', 'with', 'toxic', '1', 'also', 'get', 'a', 'poison', 'counter', '.']


'你你获得x点生命。 将派出x个1/1无色秘耳衍生神器生物放进战场，且具有下毒1且具有"此生物不能进行阻挡。 赋予 xx等于5或更多，则消灭所有其他生物。 当此牌手受过伤害的生物还会触发时上面得到一个中毒指示物。'