yatb - yet another test bench

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field, TabularDataset, BucketIterator

from dataset.mtgcards import RuleText
from utils.preprocess import fields_for_rule_text

import random
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Create the source/target torchtext fields and map the dataset columns
# 'src' / 'trg' onto example attributes of the same names.
SRC, TRG = fields_for_rule_text(batch_first=True, include_lengths=False)
fields = dict(src=('src', SRC), trg=('trg', TRG))

# Load the train / validation / test splits of rule-text dataset 'v2.1'.
rule_text_splits = RuleText.splits(fields=fields, version='v2.1')
train_data, valid_data, test_data = rule_text_splits

In [3]:
# Build both vocabularies from the training split only, passing min_freq=2
# to each field.
for vocab_field in (SRC, TRG):
    vocab_field.build_vocab(train_data, min_freq=2)

print(f"Unique tokens in source (en) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (zh) vocabulary: {len(TRG.vocab)}")

# Show one randomly drawn training pair as a sanity check of the tokenization.
(sample_pair,) = random.sample(list(train_data), 1)
print(sample_pair.src, sample_pair.trg)

Unique tokens in source (en) vocabulary: 1475
Unique tokens in target (zh) vocabulary: 2306
['if', 'this', 'spell', 'was', 'kicked', ',', 'sacrifice', 'that', 'creature', 'at', 'the', 'beginning', 'of', 'the', 'next', 'end', 'step', '.'] ['如果', '此', '咒语', '已', '增幅', '，', '则', '在', '下', '一个', '结束', '步骤', '开始', '时', '牺牲', '该', '生物', '。']


In [4]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)
BATCH_SIZE = 128

# One iterator per split; examples are keyed by source length and sorted
# within each batch by that key.
split_iterators = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_within_batch=True,
    sort_key=lambda example: len(example.src),
    device=device,
)
train_iterator, valid_iterator, test_iterator = split_iterators

# Peek at a single training batch to confirm the tensor layout.
first_batch = next(iter(train_iterator))
print(first_batch)

cpu

[torchtext.legacy.data.batch.Batch of size 128]
	[.src]:[torch.LongTensor of size 128x6]
	[.trg]:[torch.LongTensor of size 128x8]


In [5]:
from models.model6.definition import create_model

# Vocabulary sizes, passed to create_model as its first two arguments.
INPUT_DIM, OUTPUT_DIM = len(SRC.vocab), len(TRG.vocab)

# Indices of the padding token in each vocabulary.
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]

model = create_model(INPUT_DIM, OUTPUT_DIM, SRC_PAD_IDX, TRG_PAD_IDX, device)

In [6]:
from utils import count_parameters, train_loop
from models.model6.train import evaluate, initialize_weights, train

# Run the project's weight initializer over every sub-module of the model.
model.apply(initialize_weights)

# Left as the cell's last expression so that its result is displayed.
count_parameters(model)

5565442

In [None]:
LEARNING_RATE = 5e-4

# Padding positions in the target are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE(review): load_before_train=True presumably resumes from an existing
# checkpoint under save_path/file_name -- confirm in utils.train_loop.
train_loop(
    model, optimizer, criterion, train, evaluate,
    train_iterator, valid_iterator,
    save_path='result/',
    file_name='model6-rule-v2.1.pt',
    load_before_train=True,
)

In [8]:
from models.model6.definition import beam_search
from utils.translate import Translator

# Restore the saved weights, remapping tensors onto the active device so a
# checkpoint written on another device can still be loaded.
checkpoint_state = torch.load('result/model6-rule-v2.1.pt', map_location=device)
model.load_state_dict(checkpoint_state)

# Sentence-level translator built from the fields, the model and beam_search.
T = Translator(SRC, TRG, model, device, beam_search)

In [9]:
# Translate a single rule sentence and print the first three returned
# candidates, one per line.
#
# A second, shorter input used to be assigned to `data` here and was
# immediately overwritten (a dead store). Swap it in to try it:
#   'target creature gets - 1 / - 1 until end of turn .'
data = 'Whenever <1> becomes attached to a creature, for as long as <1> remains attached to it, you may have that creature become a copy of another target creature you control.'
ret, prob = T.translate(data, max_len=100)
print(*ret[:3], sep='\n')

['每当', '<', '1', '>', '贴附于', '其', '上', '结附', '灵气', '的', '时段', '内', '，', '你', '可以', '使', '<', '1', '>', '成为', '另', '一个', '目标', '生物', '，', '且', '你', '可以', '令', '该', '生物', '的', '复制品', '。', '<eos>']
['每当', '<', '1', '>', '贴附于', '其', '上', '结附', '灵气', '的', '时段', '内', '，', '你', '可以', '使', '<', '1', '>', '成为', '另', '一个', '目标', '生物', '的', '复制品', '，', '且', '你', '可以', '令', '此', '衍生物', '成为', '另', '一个', '目标', '生物', '。', '<eos>']
['每当', '<', '1', '>', '贴附于', '其', '上', '结附', '灵气', '的', '时段', '内', '，', '你', '可以', '使', '<', '1', '>', '成为', '另', '一个', '目标', '生物', '的', '复制品', '，', '且', '你', '可以', '令', '此', '衍生物', '成为', '另', '一个', '目标', '生物', '之', '复制品', '。', '<eos>']


In [10]:
from dataset.mtgcards import TestSets
from utils import calculate_bleu
from torchtext.legacy.data import Field
from models.card_name_detector.definition import TrainedDetector
from utils.translate import sentencize, CardTranslator, CTHelper

# NOTE: `fields` and `test_data` are rebound here. From this cell onward they
# refer to the TestSets data, no longer to the RuleText fields / test split
# created earlier in the notebook.
# The source field splits on single spaces; the target field uses the Field
# defaults.
src_rule_field = Field(tokenize=lambda x: x.split(' '))
trg_rule_field = Field()
fields = {
    'src-rule': ('src', src_rule_field),
    'trg-rule': ('trg', trg_rule_field),
}
test_data = TestSets.load(fields)

# Card-name detector, handed to CTHelper in a later cell.
D = TrainedDetector()

path: d:\ddw\school\大三下\语音信息处理技术\期末作业\code\mtg-cards-translation\models\card_name_detector


In [14]:
# Term dictionary handed to CTHelper together with the card-name detector.
# NOTE(review): presumably a glossary of fixed term translations -- confirm
# against utils.translate.CTHelper.
dic = {'oil':'烁油', 'rebel':'反抗军'}
helper = CTHelper(D, dic)

# Whole-card translator: `sentencize` and the sentence-level translator `T`
# are passed in, with the helper's pre/post-processing hooks around them.
# NOTE(review): the meaning of the `False` flag is defined by CTHelper and is
# not visible from this notebook.
CT = CardTranslator(
    sentencize, T,
    preprocess=lambda x: helper.preprocess(x, False),
    postprocess=lambda x: helper.postprocess(x, False),
)

# Materialize the test set once, then translate three random examples and
# print each raw example followed by its translation.
test_examples = list(test_data)
for example in random.sample(test_examples, 3):
    print(vars(example))
    ret = CT.translate(' '.join(example.src))
    print(ret)

{'src': ['Return', 'target', 'nonland', 'permanent', 'to', 'its', "owner's", 'hand.', 'If', 'that', 'permanent', 'had', 'mana', 'value', '3', 'or', 'less,', 'proliferate.', '(Choose', 'any', 'number', 'of', 'permanents', 'and/or', 'players,', 'then', 'give', 'each', 'another', 'counter', 'of', 'each', 'kind', 'already', 'there.)'], 'trg': ['将目标非地永久物移回其拥有者手上。如果该永久物的法术力值等于或小于3，则增殖。（选择任意数量的永久物和／或牌手，然后为其已有之每种指示物各多放置一个同类的指示物。）']}
[after preprocess]:return target non land permanent to its owner 's hand .
[before postprocess]:将目标非地永久物移回其拥有者手上。
[after preprocess]:if that permanent had mana value 3 or less , proliferate .
[before postprocess]:如果该永久物的法术力值等于或小于3，则增殖。
[after preprocess]:choose any number of permanents and / or players , then give each another counter of each kind already there .
[before postprocess]:你选择任意数量其上有指示物的永久物和／或牌手，然后在其上放置一个它已有之类别的指示物。
将目标非地永久物移回其拥有者手上。 如果该永久物的法术力值等于或小于3，则增殖。 你选择任意数量其上有指示物的永久物和／或牌手，然后在其上放置一个它已有之类别的指示物。
{'src': ['Evolved', 'Spinoderm', 'enters', 'the', 'batt

In [None]:
from utils import calculate_testset_bleu

# Score only the first 100 test examples.
bleu_examples = list(test_data)[:100]
calculate_testset_bleu(bleu_examples, CT)