In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field, TabularDataset, BucketIterator

from dataset.mtgcards import RuleText
from utils.preprocess import fields_for_rule_text

import random
from tqdm import tqdm

import models.model6_1 as model6_1

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the source/target Fields and map the 'src'/'trg' dataset columns onto
# attributes of the same name, then load the v2.2 rule-text splits.
SRC, TRG = fields_for_rule_text(include_lengths=False, batch_first=True)
fields = {column: (column, field) for column, field in (('src', SRC), ('trg', TRG))}

train_data, valid_data, test_data = RuleText.splits(fields=fields, version='v2.2')

In [3]:
# Build both vocabularies from the training split only. Tokens seen fewer than
# MIN_FREQ times are left out of the vocabulary.
MIN_FREQ = 4
for vocab_field in (SRC, TRG):
    vocab_field.build_vocab(train_data, min_freq=MIN_FREQ)

print(f"Unique tokens in source (en) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (zh) vocabulary: {len(TRG.vocab)}")

# Show one random (source, target) training pair as a tokenisation sanity check.
(pair,) = random.sample(list(train_data), 1)
print(pair.src, pair.trg)

Unique tokens in source (en) vocabulary: 1294
Unique tokens in target (zh) vocabulary: 1949
['when', '<', '1', '>', 'enters', 'the', 'battlefield', ',', 'any', 'player', 'may', 'sacrifice', 'a', 'land', '.'] ['当', '<', '1', '>', '进场', '时', '，', '任意', '牌手', '可以', '牺牲', '一个', '地', '。']


In [4]:
# Use the GPU when PyTorch can see one.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

BATCH_SIZE = 128

# Per the torchtext docs, BucketIterator batches examples of similar length
# (measured by sort_key) together. sort_within_batch also orders each batch by that key.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_key=lambda example: len(example.src),
    sort_within_batch=True,
    device=device,
)

# Peek at one training batch to sanity-check tensor shapes and placement.
print(next(iter(train_iterator)))

cuda

[torchtext.legacy.data.batch.Batch of size 128]
	[.src]:[torch.cuda.LongTensor of size 128x34 (GPU 0)]
	[.trg]:[torch.cuda.LongTensor of size 128x46 (GPU 0)]


In [6]:
# Vocabulary sizes and padding indices handed to the model factory
# (presumably for embedding sizes and padding masks).
INPUT_DIM, OUTPUT_DIM = len(SRC.vocab), len(TRG.vocab)
SRC_PAD_IDX, TRG_PAD_IDX = (
    field.vocab.stoi[field.pad_token] for field in (SRC, TRG)
)

# Hyper-parameters are read from the package's default config (see printed output).
model = model6_1.create_model(
    INPUT_DIM, OUTPUT_DIM, SRC_PAD_IDX, TRG_PAD_IDX, device
)

Load parameters from /gemini/code/models/model6_1/configs/default.json
Parameters: {'HID_DIM': 256, 'ENC_LAYERS': 3, 'DEC_LAYERS': 3, 'ENC_HEADS': 8, 'DEC_HEADS': 8, 'ENC_PF_DIM': 512, 'DEC_PF_DIM': 512, 'ENC_DROPOUT': 0.1, 'DEC_DROPOUT': 0.1}


In [7]:
from utils import count_parameters, train_loop
from models.model6.train import initialize_weights, train, evaluate
# Re-initialise every submodule with initialize_weights, then report the parameter
# count. nn.Module.apply returns the module itself, so its result can be passed on.
# NOTE(review): these helpers come from models.model6 while the model is model6_1;
# presumably shared on purpose. Confirm.
count_parameters(model.apply(initialize_weights))

5284765

In [9]:
LEARNING_RATE = 0.0005

# Padding positions in the target sequence are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# load_before_train=True loads result/model6.1-rule-v2.2.pt before training
# (see the "load model parameters" output line). Re-running this cell therefore
# continues from the saved checkpoint rather than from the freshly initialised weights.
train_loop(
    model, optimizer, criterion, train, evaluate,
    train_iterator, valid_iterator,
    save_path='result/',
    file_name='model6.1-rule-v2.2.pt',
    load_before_train=True,
)

model will be saved to result/model6.1-rule-v2.2.pt
load model parameters from result/model6.1-rule-v2.2.pt
Val. Loss: 0.127 |  Val. PPL:   1.136


100%|██████████| 454/454 [00:20<00:00, 22.40it/s]


Epoch: 01 | Time: 0m 20s
	Train Loss: 0.125 | Train PPL:   1.133
	 Val. Loss: 0.129 |  Val. PPL:   1.137


100%|██████████| 454/454 [00:19<00:00, 22.84it/s]


Epoch: 02 | Time: 0m 19s
	Train Loss: 0.112 | Train PPL:   1.119
	 Val. Loss: 0.123 |  Val. PPL:   1.131


100%|██████████| 454/454 [00:19<00:00, 23.31it/s]


Epoch: 03 | Time: 0m 19s
	Train Loss: 0.105 | Train PPL:   1.111
	 Val. Loss: 0.118 |  Val. PPL:   1.125


100%|██████████| 454/454 [00:19<00:00, 23.18it/s]


Epoch: 04 | Time: 0m 19s
	Train Loss: 0.102 | Train PPL:   1.107
	 Val. Loss: 0.112 |  Val. PPL:   1.118


100%|██████████| 454/454 [00:20<00:00, 22.68it/s]


Epoch: 05 | Time: 0m 20s
	Train Loss: 0.097 | Train PPL:   1.102
	 Val. Loss: 0.109 |  Val. PPL:   1.115


100%|██████████| 454/454 [00:19<00:00, 23.02it/s]


Epoch: 06 | Time: 0m 19s
	Train Loss: 0.086 | Train PPL:   1.090
	 Val. Loss: 0.108 |  Val. PPL:   1.114


100%|██████████| 454/454 [00:19<00:00, 22.83it/s]


Epoch: 07 | Time: 0m 20s
	Train Loss: 0.087 | Train PPL:   1.091
	 Val. Loss: 0.103 |  Val. PPL:   1.109


100%|██████████| 454/454 [00:20<00:00, 22.68it/s]


Epoch: 08 | Time: 0m 20s
	Train Loss: 0.081 | Train PPL:   1.084
	 Val. Loss: 0.103 |  Val. PPL:   1.108


100%|██████████| 454/454 [00:20<00:00, 22.40it/s]


Epoch: 09 | Time: 0m 20s
	Train Loss: 0.078 | Train PPL:   1.081
	 Val. Loss: 0.109 |  Val. PPL:   1.116


100%|██████████| 454/454 [00:19<00:00, 22.91it/s]


Epoch: 10 | Time: 0m 19s
	Train Loss: 0.076 | Train PPL:   1.079
	 Val. Loss: 0.101 |  Val. PPL:   1.106


In [10]:
from utils.translate import Translator
from models.model6.definition import beam_search
# Reload the checkpoint written by the training cell and wrap the model for
# beam-search decoding. `device` is already a torch.device, so it serves
# directly as map_location.
model.load_state_dict(
    torch.load('result/model6.1-rule-v2.2.pt', map_location=device)
)
T = Translator(SRC, TRG, model, device, beam_search)

In [11]:
# Translate one sentence and show the top three candidate outputs.
# The input is written in the same form as the training sources: lower-cased,
# with punctuation split into separate tokens. The translator appears to expect
# that form (assumption based on the training samples; confirm).
MAX_TRANSLATION_LEN = 100  # upper bound on generated target tokens

src_sentence = 'target creature gets - 1 / - 1 until end of turn .'
candidates, probs = T.translate(src_sentence, max_len=MAX_TRANSLATION_LEN)
print(*candidates[:3], sep='\n')

['目标', '生物', '得', '-', '1', '/', '-', '1', '直到', '回合', '结束', '。', '<eos>']
['直到', '回合', '结束', '，', '目标', '生物', '得', '-', '1', '/', '-', '1', '。', '<eos>']
['目标', '结附于', '生物', '得', '-', '1', '/', '-', '1', '直到', '回合', '结束', '。', '<eos>']


In [12]:
from dataset.mtgcards import TestSets
from utils import calculate_bleu
from torchtext.legacy.data import Field
from models.card_name_detector.definition import TrainedDetector
from utils.translate import sentencize, CardTranslator, CTHelper

# Field mapping for the held-out card test set. It reads the 'src-rule' /
# 'trg-rule' columns, and gets its own name so the training `fields` mapping
# from the dataset cell is not clobbered.
# The source is split on single spaces, not str.split(), so joining the tokens
# with ' ' reconstructs the original text with newlines intact.
test_fields = {
    'src-rule': ('src', Field(tokenize=lambda x: x.split(' '))),
    'trg-rule': ('trg', Field()),
}

# NOTE: this deliberately replaces the RuleText `test_data` split with the TestSets
# data. `test_iterator` built earlier still iterates the old split.
test_data = TestSets.load(test_fields)

# Card-name detector used by the card-level translation helper.
D = TrainedDetector()

path: /gemini/code/models/card_name_detector


In [13]:
# Extra term overrides handed to the card-translation helper alongside the detector.
term_overrides = {'oil':'烁油', 'rebel':'反抗军'}
helper = CTHelper(D, term_overrides)
silent = False
# The lambdas look up `helper` and `silent` only when called, so changing `silent`
# later alters verbosity without rebuilding CT.
CT = CardTranslator(sentencize, T,
                    preprocess=lambda x: helper.preprocess(x, silent),
                    postprocess=lambda x: helper.postprocess(x, silent))

# NOTE(review): the recorded run of this cell failed with
#   AttributeError: 'str' object has no attribute 'removeprefix'
# str.removeprefix exists only on Python >= 3.9. The call is presumably inside
# the helpers used by CT, not in this cell. Upgrade Python or replace the call.
test_examples = list(test_data)  # materialise once instead of per use
for example in random.sample(test_examples, 3):
    print(vars(example))
    ret = CT.translate(' '.join(example.src))
    print(ret)

{'src': ['Vigilance\nIf', 'a', 'permanent', 'entering', 'the', 'battlefield', 'causes', 'a', 'triggered', 'ability', 'of', 'a', 'permanent', 'you', 'control', 'to', 'trigger,', 'that', 'ability', 'triggers', 'an', 'additional', 'time.\nPermanents', 'entering', 'the', 'battlefield', "don't", 'cause', 'abilities', 'of', 'permanents', 'your', 'opponents', 'control', 'to', 'trigger.'], 'trg': ['警戒', '如果由你操控之永久物具有的触发式异能因永久物进战场而触发，则该异能额外触发一次。', '进战场的永久物不会触发由对手操控之永久物的异能。']}


AttributeError: 'str' object has no attribute 'removeprefix'

In [34]:
from utils import calculate_testset_bleu
# Score the card translator on the first 100 test examples; the BLEU value is the cell output.
bleu_examples = list(test_data)[:100]
calculate_testset_bleu(bleu_examples, CT)

100%|██████████| 100/100 [01:50<00:00,  1.11s/it]


0.652888170122143