yatb - yet another test bench

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field, TabularDataset, BucketIterator

from dataset.mtgcards import RuleText
from utils.preprocess import fields_for_rule_text

import random
from tqdm import tqdm

import models.model6 as model6

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Source/target fields for the rule-text corpus, configured batch-first and
# without a separate lengths tensor.
SRC, TRG = fields_for_rule_text(include_lengths=False, batch_first=True)
# Each data column is mapped to an attribute of the same name on the examples.
fields = {
    column: (column, field)
    for column, field in (('src', SRC), ('trg', TRG))
}

train_data, valid_data, test_data = RuleText.splits(fields=fields, version='v2.1')

In [3]:
# Build both vocabularies from the training split only; min_freq=2 is passed
# for source and target alike.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

src_vocab_size = len(SRC.vocab)
trg_vocab_size = len(TRG.vocab)
print(f"Unique tokens in source (en) vocabulary: {src_vocab_size}")
print(f"Unique tokens in target (zh) vocabulary: {trg_vocab_size}")

# Show one randomly drawn training pair as a sanity check of the tokenisation.
(sample_pair,) = random.sample(list(train_data), 1)
print(sample_pair.src, sample_pair.trg)

Unique tokens in source (en) vocabulary: 1475
Unique tokens in target (zh) vocabulary: 2306
['at', 'the', 'beginning', 'of', 'each', 'opponent', "'s", 'end', 'step', ',', 'that', 'player', 'creates', 'a', '1', '/', '1', 'red', 'goblin', 'creature', 'token', 'with', '"', 'creatures', 'you', 'control', 'attack', 'each', 'combat', 'if', 'able', '.', '"'] ['在', '每', '位', '对手', '的', '结束', '步骤', '开始', '时', '，', '该', '牌手', '派出', '一个', '1', '/', '1', '红色', '鬼怪', '衍生', '生物', '，', '且', '具有', '"由', '你', '操控', '的', '生物', '每次', '战斗', '若', '能', '攻击', '，', '则', '必须', '攻击', '。', '"']


In [4]:
BATCH_SIZE = 128

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# One iterator per split; examples are sorted inside each batch by the number
# of source tokens.
iterators = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_within_batch=True,
    sort_key=lambda example: len(example.src),
    device=device,
)
train_iterator, valid_iterator, test_iterator = iterators

# Print the first training batch to check the tensor layout.
first_batch = next(iter(train_iterator))
print(first_batch)

cpu

[torchtext.legacy.data.batch.Batch of size 128]
	[.src]:[torch.LongTensor of size 128x14]
	[.trg]:[torch.LongTensor of size 128x24]


In [5]:
# Vocabulary sizes and padding indices that are handed to the model factory.
INPUT_DIM, OUTPUT_DIM = len(SRC.vocab), len(TRG.vocab)
SRC_PAD_IDX, TRG_PAD_IDX = (
    SRC.vocab.stoi[SRC.pad_token],
    TRG.vocab.stoi[TRG.pad_token],
)

model = model6.create_model(
    INPUT_DIM,
    OUTPUT_DIM,
    SRC_PAD_IDX,
    TRG_PAD_IDX,
    device,
)

Load parameters from d:\ddw\school\大三下\语音信息处理技术\期末作业\code\mtg-cards-translation\models\model6/configs/default.json
Parameters: {'HID_DIM': 256, 'ENC_LAYERS': 3, 'DEC_LAYERS': 3, 'ENC_HEADS': 8, 'DEC_HEADS': 8, 'ENC_PF_DIM': 512, 'DEC_PF_DIM': 512, 'ENC_DROPOUT': 0.1, 'DEC_DROPOUT': 0.1}


In [6]:
from utils import (
    count_parameters,
    train_loop,
)
from models.model6.train import (
    initialize_weights,
    train,
    evaluate,
)

# Apply the project's weight initialiser to every sub-module of the model.
model.apply(initialize_weights)
# Kept as the cell's final expression so that any value it returns is displayed.
count_parameters(model)

5565442

In [None]:
LEARNING_RATE = 0.0005

# Padding positions in the target are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE(review): load_before_train=True presumably resumes from an existing
# checkpoint under save_path/file_name -- confirm in utils.train_loop.
train_loop(
    model,
    optimizer,
    criterion,
    train,
    evaluate,
    train_iterator,
    valid_iterator,
    save_path='result/',
    file_name='model6-rule-v2.1.pt',
    load_before_train=True,
)

In [7]:
from utils.translate import Translator
from models.model6.definition import beam_search

# Load the saved weights onto the active device, then wrap the model in a
# Translator that is given beam_search as its decoding function.
checkpoint_state = torch.load('result/model6-rule-v2.1.pt', map_location=device)
model.load_state_dict(checkpoint_state)
T = Translator(SRC, TRG, model, device, beam_search)

In [9]:
# Candidate inputs for a quick manual check of the rule-text translator.
# Only the sentence selected below is translated; previously the first one was
# assigned to `data` and immediately overwritten, so it never had any effect.
sample_sentences = (
    'Whenever <1> becomes attached to a creature, for as long as <1> remains attached to it, you may have that creature become a copy of another target creature you control.',
    'target creature gets - 1 / - 1 until end of turn .',
)
data = sample_sentences[1]

# NOTE(review): `ret` appears to be a list of candidate token sequences and
# `prob` their scores -- confirm against Translator.translate.
ret, prob = T.translate(data, max_len=100)
# Print at most the first three candidates, one per line.
print(*ret[:3], sep='\n')

['目标', '生物', '得', '-', '1', '/', '-', '1', '直到', '回合', '结束', '。', '<eos>']
['目标', '结附于', '生物', '得', '-', '1', '/', '-', '1', '直到', '回合', '结束', '。', '<eos>']
['直到', '回合', '结束', '，', '目标', '生物', '得', '-', '1', '/', '-', '1', '直到', '回合', '结束', '。', '<eos>']


In [10]:
from dataset.mtgcards import TestSets
from utils import calculate_bleu
from torchtext.legacy.data import Field
from models.card_name_detector.definition import TrainedDetector
from utils.translate import sentencize, CardTranslator, CTHelper

# Card-level test set: the source text is split on single spaces only, the
# target uses Field's default settings.
# Note: `fields` and `test_data` rebind the names created for the rule-text
# corpus earlier in the notebook.
src_rule_field = Field(tokenize=lambda text: text.split(' '))
trg_rule_field = Field()
fields = {
    'src-rule': ('src', src_rule_field),
    'trg-rule': ('trg', trg_rule_field),
}
test_data = TestSets.load(fields)

D = TrainedDetector()

path: d:\ddw\school\大三下\语音信息处理技术\期末作业\code\mtg-cards-translation\models\card_name_detector


In [11]:
# NOTE(review): presumably an English -> Chinese glossary consulted by
# CTHelper -- confirm against utils.translate.
dic = {'oil': '烁油', 'rebel': '反抗军'}
helper = CTHelper(D, dic)
# The second positional argument of preprocess/postprocess is fixed to False
# for every call made through the CardTranslator.
CT = CardTranslator(sentencize, T,
                    preprocess=lambda x: helper.preprocess(x, False),
                    postprocess=lambda x: helper.postprocess(x, False))

# Materialise the test set once instead of rebuilding the list on every use.
# To inspect one fixed card instead of a random draw, index this list directly
# (earlier experiments looked at positions 13 and 8).
test_examples = list(test_data)

# Translate up to three randomly chosen cards; the sample size is clamped so
# that a test set with fewer than three examples does not raise ValueError.
for example in random.sample(test_examples, min(3, len(test_examples))):
    print(vars(example))
    ret = CT.translate(' '.join(example.src))
    print(ret)

{'src': ['At', 'the', 'beginning', 'of', 'combat', 'on', 'your', 'turn,', 'target', 'creature', 'you', 'control', 'gets', '+1/+1', 'until', 'end', 'of', 'turn.', 'If', 'that', 'creature', 'has', 'toxic,', 'instead', 'it', 'gets', '+2/+2', 'until', 'end', 'of', 'turn.'], 'trg': ['在你回合的战斗开始时，目标由你操控的生物得+1/+1直到回合结束。若该生物具有下毒异能，则改为它得+2/+2直到回合结束。']}
[after preprocess]:at the beginning of combat on your turn , target creature you control gets + 1 / + 1 until end of turn .
[before postprocess]:在你回合的战斗开始时，目标由你操控的生物得+1/+1直到回合结束。
[after preprocess]:if that creature has toxic , instead it gets + 2 / + 2 until end of turn .
[before postprocess]:如果该生物具有<unk>，则改为该生物得+2/+2直到回合结束。
在你回合的战斗开始时，目标由你操控的生物得+1/+1直到回合结束。 如果该生物具有<unk>，则改为该生物得+2/+2直到回合结束。
{'src': ['For', 'Mirrodin!', '(When', 'this', 'Equipment', 'enters', 'the', 'battlefield,', 'create', 'a', '2/2', 'red', 'Rebel', 'creature', 'token,', 'then', 'attach', 'this', 'to', 'it.)\nEquipped', 'creature', 'gets', '+1/-1.\nEquip', '{1}', '({1}:', 'Attac

In [None]:
from utils import calculate_testset_bleu

# Score only the first 100 test examples with the card translator.
# Kept as the cell's final expression so that any returned value is displayed.
scored_examples = list(test_data)[:100]
calculate_testset_bleu(scored_examples, CT)