In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field, TabularDataset, BucketIterator

from dataset.mtgcards import RuleText
from utils.preprocess import fields_for_rule_text

import random
from tqdm import tqdm

import models.model6 as model6

In [11]:
# Source/target Field pair for card rule text (batch-first tensors, no lengths).
SRC, TRG = fields_for_rule_text(include_lengths=False, batch_first=True)

# Map each raw dataset column to the (attribute name, Field) it is loaded into.
fields = {
    'src': ('src', SRC),
    'trg': ('trg', TRG),
}

# Load the v2.2 rule-text dataset as train / validation / test splits.
train_data, valid_data, test_data = RuleText.splits(
    fields=fields,
    version='v2.2',
)

In [12]:
# Build both vocabularies from the training split only. Tokens occurring fewer
# than MIN_FREQ times are left out of the vocabulary.
MIN_FREQ = 4
SRC.build_vocab(train_data, min_freq=MIN_FREQ)
TRG.build_vocab(train_data, min_freq=MIN_FREQ)
print(f"Unique tokens in source (en) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (zh) vocabulary: {len(TRG.vocab)}")

# Show one random (src, trg) training pair as a quick sanity check.
x = random.choice(list(train_data))
print(x.src, x.trg)

Unique tokens in source (en) vocabulary: 1294
Unique tokens in target (zh) vocabulary: 1949
['flying', ',', 'vigilance'] ['飞行', '，', '警戒']


In [13]:
# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(device)

BATCH_SIZE = 128

# BucketIterator groups examples by sort_key (source length here) to limit
# padding; sort_within_batch also orders each batch by that key.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_within_batch=True,
    sort_key=lambda example: len(example.src),
    device=device,
)

# Peek at one training batch to sanity-check tensor shapes.
first_batch = next(iter(train_iterator))
print(first_batch)

cpu

[torchtext.legacy.data.batch.Batch of size 128]
	[.src]:[torch.LongTensor of size 128x14]
	[.trg]:[torch.LongTensor of size 128x23]


In [27]:
# Vocabulary sizes give the model's input and output dimensions.
INPUT_DIM, OUTPUT_DIM = len(SRC.vocab), len(TRG.vocab)

# Padding indices; presumably used by the model to mask padded positions.
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]

# Instantiate model6 with its 'wide' hyper-parameter preset.
model = model6.create_model(
    INPUT_DIM, OUTPUT_DIM,
    SRC_PAD_IDX, TRG_PAD_IDX,
    device,
    config='wide',
)

Load parameters from d:\Desktop\mtg-cards-translation\models\model6/configs/wide.json
Parameters: {'HID_DIM': 512, 'ENC_LAYERS': 3, 'DEC_LAYERS': 3, 'ENC_HEADS': 8, 'DEC_HEADS': 8, 'ENC_PF_DIM': 1024, 'DEC_PF_DIM': 1024, 'ENC_DROPOUT': 0.1, 'DEC_DROPOUT': 0.1}


In [28]:
from models.model6.train import evaluate, initialize_weights, train
from utils import count_parameters, train_loop

# Apply model6's initializer to every sub-module, then display the parameter count.
model.apply(initialize_weights)
count_parameters(model)

18534301

In [7]:
# Adam optimiser; target padding tokens are ignored by the loss.
LEARNING_RATE = 5e-4
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

# Train and checkpoint to result/model6-rule-v2.2.pt.
# NOTE(review): load_before_train=True presumably restores an existing
# checkpoint before training starts — confirm in utils.train_loop.
train_loop(
    model, optimizer, criterion, train, evaluate,
    train_iterator, valid_iterator,
    save_path='result/',
    file_name='model6-rule-v2.2.pt',
    load_before_train=True,
)

model will be saved to result/model6-rule-v2.2.pt


100%|██████████| 454/454 [00:19<00:00, 23.19it/s]


Epoch: 01 | Time: 0m 19s
	Train Loss: 2.062 | Train PPL:   7.861
	 Val. Loss: 0.686 |  Val. PPL:   1.985


100%|██████████| 454/454 [00:19<00:00, 23.53it/s]


Epoch: 02 | Time: 0m 19s
	Train Loss: 0.566 | Train PPL:   1.761
	 Val. Loss: 0.335 |  Val. PPL:   1.398


100%|██████████| 454/454 [00:19<00:00, 23.38it/s]


Epoch: 03 | Time: 0m 19s
	Train Loss: 0.341 | Train PPL:   1.406
	 Val. Loss: 0.256 |  Val. PPL:   1.292


100%|██████████| 454/454 [00:19<00:00, 23.70it/s]


Epoch: 04 | Time: 0m 19s
	Train Loss: 0.256 | Train PPL:   1.292
	 Val. Loss: 0.206 |  Val. PPL:   1.229


100%|██████████| 454/454 [00:19<00:00, 23.52it/s]


Epoch: 05 | Time: 0m 19s
	Train Loss: 0.208 | Train PPL:   1.231
	 Val. Loss: 0.174 |  Val. PPL:   1.190


100%|██████████| 454/454 [00:19<00:00, 23.37it/s]


Epoch: 06 | Time: 0m 19s
	Train Loss: 0.178 | Train PPL:   1.195
	 Val. Loss: 0.154 |  Val. PPL:   1.167


100%|██████████| 454/454 [00:19<00:00, 23.67it/s]


Epoch: 07 | Time: 0m 19s
	Train Loss: 0.158 | Train PPL:   1.171
	 Val. Loss: 0.146 |  Val. PPL:   1.157


100%|██████████| 454/454 [00:19<00:00, 23.33it/s]


Epoch: 08 | Time: 0m 19s
	Train Loss: 0.142 | Train PPL:   1.152
	 Val. Loss: 0.137 |  Val. PPL:   1.147


100%|██████████| 454/454 [00:19<00:00, 23.30it/s]


Epoch: 09 | Time: 0m 19s
	Train Loss: 0.128 | Train PPL:   1.137
	 Val. Loss: 0.136 |  Val. PPL:   1.145


100%|██████████| 454/454 [00:19<00:00, 23.22it/s]


Epoch: 10 | Time: 0m 19s
	Train Loss: 0.120 | Train PPL:   1.128
	 Val. Loss: 0.128 |  Val. PPL:   1.136


In [29]:
from models.model6.definition import beam_search
from utils.translate import Translator

# Restore the weights that train_loop saved, mapping tensors onto `device`.
# torch.load unpickles the file, so only load checkpoints you produced yourself.
checkpoint = torch.load('result/model6-rule-v2.2.pt', map_location=device)
model.load_state_dict(checkpoint)

# Sentence-level translator that decodes with beam search.
T = Translator(SRC, TRG, model, device, beam_search)

In [30]:
# Translate one rule sentence. The input appears to need the same space-separated,
# lower-cased tokenisation as the training data (punctuation split off).
data = 'target creature gets - 1 / - 1 until end of turn .'
ret, prob = T.translate(data, max_len=100)

# Print the first three hypotheses returned by the translator, one per line.
print(*ret[:3], sep='\n')

['目标', '生物', '得', '-', '1', '/', '-', '1', '直到', '回合', '结束', '。', '<eos>']
['目标', '生物', '得', '-', '1', '1', '/', '-', '1', '直到', '回合', '结束', '。', '<eos>']
['直到', '回合', '结束', '，', '目标', '生物', '得', '-', '1', '/', '-', '1', '。', '<eos>']


In [31]:
from dataset.mtgcards import TestSets
from models.card_name_detector.definition import TrainedDetector
from torchtext.legacy.data import Field
from utils import calculate_bleu
from utils.translate import CTHelper, CardTranslator, sentencize

# Test-set card text is raw (untokenised). Split on single spaces only, so that
# ' '.join(example.src) rebuilds the original string exactly, including
# newlines embedded inside tokens.
src_rule_field = Field(tokenize=lambda text: text.split(' '))
trg_rule_field = Field()
fields = {
    'src-rule': ('src', src_rule_field),
    'trg-rule': ('trg', trg_rule_field),
}
test_data = TestSets.load(fields)

# Pre-trained card-name detector, later handed to CTHelper.
D = TrainedDetector()

path: d:\Desktop\mtg-cards-translation\models\card_name_detector


In [35]:
# Extra English -> Chinese term mapping passed to CTHelper; presumably pins the
# translation of these words — confirm against CTHelper.
dic = {'oil': '烁油', 'rebel': '反抗军'}
helper = CTHelper(D, dic)
# With silent=False the helper echoes the text after preprocessing and before
# postprocessing (see the "[after preprocess]" lines in the output).
silent = False
CT = CardTranslator(sentencize, T,
                    preprocess=lambda x: helper.preprocess(x, silent),
                    postprocess=lambda x: helper.postprocess(x, silent))

# Translate a few random test cards end to end. Materialise the dataset once
# rather than re-listing it for every access.
test_examples = list(test_data)
for example in random.sample(test_examples, 3):
    print(vars(example))
    ret = CT.translate(' '.join(example.src))
    print(ret)

{'src': ['As', 'long', 'as', 'Mirran', 'Safehouse', 'is', 'on', 'the', 'battlefield,', 'it', 'has', 'all', 'activated', 'abilities', 'of', 'all', 'land', 'cards', 'in', 'all', 'graveyards.'], 'trg': ['只要秘罗避难屋在战场上，它便具有所有坟墓场中所有地牌的所有起动式异能。']}
[after preprocess]:as long as <0> is on the battlefield , it has all activated abilities of all land cards in all graveyards .
[before postprocess]:只要<0>在战场上，它便具有所有坟墓场中每张生物牌的所有起动式异能。
只要<mirran safehouse>在战场上，它便具有所有坟墓场中每张生物牌的所有起动式异能。
{'src': ['+2:', 'Search', 'your', 'library', 'for', 'a', 'basic', 'Mountain', 'card,', 'reveal', 'it,', 'put', 'it', 'into', 'your', 'hand,', 'then', 'shuffle.\n−3:', 'Koth,', 'Fire', 'of', 'Resistance', 'deals', 'damage', 'to', 'target', 'creature', 'equal', 'to', 'the', 'number', 'of', 'Mountains', 'you', 'control.\n−7:', 'You', 'get', 'an', 'emblem', 'with', '"Whenever', 'a', 'Mountain', 'enters', 'the', 'battlefield', 'under', 'your', 'control,', 'this', 'emblem', 'deals', '4', 'damage', 'to', 'any', 'target."'], 'trg

In [34]:
from utils import calculate_testset_bleu

# BLEU of the full card translator on the first N_BLEU_EXAMPLES test cards
# (about 1 s per card in the recorded run, so keep this small while iterating).
N_BLEU_EXAMPLES = 100
bleu_subset = list(test_data)[:N_BLEU_EXAMPLES]
calculate_testset_bleu(bleu_subset, CT)

100%|██████████| 100/100 [01:50<00:00,  1.11s/it]


0.652888170122143