In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field, TabularDataset, BucketIterator
from dataset.mtgcards import RuleText

import random
import math
import time
import os
import re
import spacy
from typing import Callable

from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def simple_tokenize(x: str)->list:
    """Break ``x`` into tokens on runs of whitespace.

    Args:
        x: Text to tokenize.

    Returns:
        The whitespace-separated pieces of ``x``; an empty list when ``x``
        is empty or contains only whitespace.
    """
    tokens = x.split()
    return tokens

# Source field: whitespace-tokenized, lower-cased text wrapped in <sos>/<eos>.
# include_lengths=True makes batch.src unpack into a (tensor, lengths) pair.
SRC = Field(
    tokenize=simple_tokenize,
    init_token='<sos>',
    eos_token='<eos>',
    lower=True,
    include_lengths=True,
)

# Target field: one label token per source token. '0' is used as both the
# start and end marker so the label sequence lines up with the wrapped source.
TRG = Field(
    tokenize=simple_tokenize,
    init_token='0',
    eos_token='0',
    lower=True,
)

# Map the 'src' / 'trg' keys of each record onto the two fields.
fields = {'src': ('src', SRC), 'trg': ('trg', TRG)}
train_data, valid_data, test_data = RuleText.splits(fields=fields, version='cnd')

# Each count is labelled with the split it belongs to (the previous message
# called all three "train_data").
print(
    f'Number of train_data: {len(train_data)}  '
    f'Number of valid_data: {len(valid_data)}  '
    f'Number of test_data: {len(test_data)}'
)

Number of train_data: 37240  Number of train_data: 980  Number of train_data: 981


In [3]:
# Build both vocabularies from the training split only; each build_vocab call
# is given min_freq = 2.
SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)
# Report the size of each vocabulary (special tokens included in the count).
print(f"Unique tokens in source (en) vocabulary: {len(SRC.vocab)}")
print(f"Unique tokens in target (0/1) vocabulary: {len(TRG.vocab)}")

Unique tokens in source (en) vocabulary: 4844
Unique tokens in target (0/1) vocabulary: 4


In [4]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
BATCH_SIZE = 128

# One BucketIterator per split, all sharing the same batch size and device.
# sort_key orders examples by source length and sort_within_batch = True sorts
# each batch by that key.
# NOTE(review): Detector.forward packs sequences with the default
# enforce_sorted setting, so it presumably relies on this in-batch sort.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE, 
    sort_within_batch = True,
    sort_key = lambda x: len(x.src),
    device = device)

cpu


In [5]:
class Detector(nn.Module):
    """Per-position tagger: embedding -> bidirectional GRU -> linear layer.

    Each position is scored from the concatenation of its last-layer GRU
    output (both directions) and its dropout-regularised embedding.

    Args:
        input_dim: Number of rows in the embedding table (source vocabulary size).
        emb_dim: Embedding width.
        hid_dim: GRU hidden size per direction.
        output_dim: Number of classes scored at every position.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability applied to the embeddings.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, output_dim, num_layers, dropout):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, emb_dim)

        self.rnn = nn.GRU(emb_dim, hid_dim, num_layers = num_layers, bidirectional = True)

        # Per-position input is [forward state ; backward state ; embedding].
        self.fc_out = nn.Linear(hid_dim * 2 + emb_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_len):
        """Score every position of a padded batch.

        Args:
            src: Token indices, shape [src len, batch size].
            src_len: Unpadded length of each sequence, shape [batch size].
                pack_padded_sequence is called with its default
                enforce_sorted=True, so lengths must be in descending order.

        Returns:
            Tensor of shape [src len, batch size, output_dim].
        """
        # embedded = [src len, batch size, emb dim]
        embedded = self.dropout(self.embedding(src))

        # pack_padded_sequence needs the lengths on the CPU.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, src_len.to('cpu'))

        # The final hidden state is not used; only the per-step outputs are.
        packed_outputs, _ = self.rnn(packed_embedded)

        # Unpack to the full padded length of `src`. Without total_length the
        # result is only max(src_len) steps long, and the concatenation with
        # `embedded` below fails whenever `src` carries extra padding.
        # outputs = [src len, batch size, hid dim * num directions];
        # positions past a sequence's length are all zeros.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, total_length=src.shape[0]
        )

        # [src len, batch size, output_dim]
        return self.fc_out(torch.cat((outputs, embedded), 2))

In [6]:
from models.model4.train import init_weights
from utils import count_parameters, train_loop

In [7]:
# Model hyperparameters. The input and output sizes are taken from the two
# vocabularies; the remaining values are fixed constants.
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
EMB_DIM = 256
HID_DIM = 512
DROPOUT = 0.5
NUM_LAYERS = 2
# Index of the source padding token (not referenced again in the visible cells).
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]

model = Detector(INPUT_DIM, EMB_DIM, HID_DIM, OUTPUT_DIM, NUM_LAYERS, DROPOUT).to(device)

# Run init_weights (imported from models.model4.train) over the model's
# modules; as the cell's last expression, the returned model is displayed.
model.apply(init_weights)

Detector(
  (embedding): Embedding(4844, 256)
  (rnn): GRU(256, 512, num_layers=2, bidirectional=True)
  (fc_out): Linear(in_features=1280, out_features=4, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [8]:
# Print the value returned by count_parameters (imported from utils) for the
# model, formatted with thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 8,335,364 trainable parameters


In [9]:
def train(model, iterator, optimizer, criterion, clip):
    """Run one optimisation pass over ``iterator`` and return the mean batch loss.

    Args:
        model: Called as ``model(src, src_len)``; its output is
            [src len, batch size, output dim].
        iterator: Sized iterable of batches exposing ``src`` (a
            ``(tensor, lengths)`` pair) and ``trg`` ([trg len, batch size]).
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, labels)``.
        clip: Maximum gradient norm passed to ``clip_grad_norm_``.

    Returns:
        Sum of the per-batch losses divided by ``len(iterator)``.
    """
    model.train()

    batch_losses = []

    for batch in tqdm(iterator, total=len(iterator)):
        optimizer.zero_grad()

        tokens, lengths = batch.src
        logits = model(tokens, lengths)

        # Skip position 0 (the start marker) on both sides, then flatten so
        # every remaining position becomes one row for the criterion:
        #   flat_logits = [(src len - 1) * batch size, output dim]
        #   flat_labels = [(trg len - 1) * batch size]
        n_classes = logits.shape[-1]
        flat_logits = logits[1:].view(-1, n_classes)
        flat_labels = batch.trg[1:].view(-1)

        loss = criterion(flat_logits, flat_labels)
        loss.backward()

        # Bound the gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(iterator)

def evaluate(model, iterator, criterion):
    """Compute the mean batch loss over ``iterator`` without updating the model.

    Args:
        model: Called as ``model(src, src_len)``; its output is
            [src len, batch size, output dim].
        iterator: Sized iterable of batches exposing ``src`` (a
            ``(tensor, lengths)`` pair) and ``trg`` ([trg len, batch size]).
        criterion: Loss called as ``criterion(logits, labels)``.

    Returns:
        Sum of the per-batch losses divided by ``len(iterator)``.
    """
    model.eval()

    batch_losses = []

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch in tqdm(iterator, total=len(iterator)):
            tokens, lengths = batch.src
            logits = model(tokens, lengths)

            # Skip position 0 (the start marker) on both sides and flatten:
            #   flat_logits = [(src len - 1) * batch size, output dim]
            #   flat_labels = [(trg len - 1) * batch size]
            n_classes = logits.shape[-1]
            flat_logits = logits[1:].view(-1, n_classes)
            flat_labels = batch.trg[1:].view(-1)

            batch_losses.append(criterion(flat_logits, flat_labels).item())

    return sum(batch_losses) / len(iterator)

In [10]:
# Adam over all model parameters, with the optimizer's default settings.
optimizer = optim.Adam(model.parameters())
# Targets equal to the padding index are ignored by the loss.
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

def epoch_time(start_time, end_time):
    """Express the span between two timestamps as whole minutes and seconds.

    Args:
        start_time: Start timestamp in seconds.
        end_time: End timestamp in seconds.

    Returns:
        ``(minutes, seconds)`` as ints; both parts are truncated, so any
        fractional second is dropped.
    """
    span = end_time - start_time
    whole_minutes = int(span / 60)
    leftover_seconds = int(span - (whole_minutes * 60))
    return (whole_minutes, leftover_seconds)

In [11]:
# Training schedule: number of passes over the training iterator and the
# gradient-norm limit handed to train().
N_EPOCHS = 10
CLIP = 1

# Lowest validation loss seen so far; starts at infinity so the first epoch
# always counts as an improvement.
best_valid_loss = float('inf')

# Checkpoint path; later cells load the weights back from this file.
file_name = 'cn-mask-model.pt'

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Save the weights only when the validation loss improves.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), file_name)
    
    # The "PPL" columns are exp(loss).
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

  2%|▏         | 5/291 [00:11<11:25,  2.40s/it]


KeyboardInterrupt: 

In [79]:
def detect_card_name(sentence: str, src_field: "Field", trg_field: "Field", model, device, **kwargs):
    """Label every position of ``sentence`` with the model's most likely class.

    Args:
        sentence: Raw text; tokenized and numericalized through ``src_field``.
        src_field: Source field. ``preprocess`` turns the text into tokens and
            ``process`` must return a ``(tensor, lengths)`` pair for a batch.
        trg_field: Target field whose vocabulary maps label strings to indices
            and back; it must contain the label ``'1'``.
        model: Called as ``model(tokens, lengths)``; the result is squeezed on
            dim 1 (the single-example batch axis) before scoring.
        device: Device the token tensor is moved to before the forward pass.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        ``([labels], probs)``: a one-element list holding the predicted label
        string per position, and the probability of label ``'1'`` per
        position. Position 0 (the start marker) is dropped from both.
    """
    model.eval()

    with torch.no_grad():
        # Local names deliberately avoid the builtins `input` and `len`.
        tokens = src_field.preprocess(sentence)
        token_ids, lengths = src_field.process([tokens])

        # logits = [src len, n classes] after dropping the batch axis.
        logits = model(token_ids.to(device), lengths).squeeze(dim=1)

        positive_index = trg_field.vocab.stoi['1']
        positive_probs = F.softmax(logits, dim=1)[:, positive_index]
        labels = [trg_field.vocab.itos[idx] for idx in logits.argmax(dim=1).tolist()]

    # NOTE(review): only the leading position is dropped, so the final entry
    # presumably belongs to the end marker and the result is one longer than
    # the token list. Confirm with the callers before trimming it.
    return [labels[1:]], positive_probs[1:].tolist()

# Restore the weights saved by the training loop, mapped onto the current device.
model.load_state_dict(torch.load(file_name, map_location=torch.device(device)))
# Example sentence: already lower-cased, with punctuation separated by spaces.
data = 'when squadron hawk enters the battlefield , you may search your library for up to three cards named squadron hawk , reveal them , put them into your hand , then shuffle .'
pred = detect_card_name(data, SRC, TRG, model, device)
print(pred)

['when', 'squadron', 'hawk', 'enters', 'the', 'battlefield', ',', 'you', 'may', 'search', 'your', 'library', 'for', 'up', 'to', 'three', 'cards', 'named', 'squadron', 'hawk', ',', 'reveal', 'them', ',', 'put', 'them', 'into', 'your', 'hand', ',', 'then', 'shuffle', '.']
([['0', '1', '1', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0']], [0.000128699277411215, 0.9999594688415527, 0.9999862909317017, 1.820418901843368e-06, 2.550347133478681e-08, 4.9496897247536253e-08, 1.0808946626639226e-08, 1.3701204615301776e-09, 1.541576821750823e-08, 1.489890109951375e-05, 9.471022188733969e-09, 6.454686118928521e-09, 4.723096580505626e-09, 3.0016389374054597e-09, 2.962586176380455e-09, 7.407108260082396e-10, 8.776032123236632e-10, 5.687177804247767e-07, 0.10927976667881012, 0.16197699308395386, 1.6219815734075382e-05, 1.3087503303310655e-09, 1.3123041542328906e-10, 5.978242345605622e-10, 1.065

In [72]:
from utils.translate import Translator
# Reload the saved weights, then wrap the fields, the model and
# detect_card_name in a Translator (imported from utils.translate).
model.load_state_dict(torch.load(file_name, map_location=torch.device(device)))
T = Translator(SRC, TRG, model, device, detect_card_name)

In [74]:
# Run one hand-written sentence through the Translator and show both values
# it returns.
data = 'when xbp , the greatest one enters the battlefield , other creatures you control get + 1 / + 1 until end of turn .'
ret, probs = T.translate(data)
print(ret)
print(probs)

['when', 'xbp', ',', 'the', 'greatest', 'one', 'enters', 'the', 'battlefield', ',', 'other', 'creatures', 'you', 'control', 'get', '+', '1', '/', '+', '1', 'until', 'end', 'of', 'turn', '.']
[['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0']]
[2.3747002160234842e-06, 0.4759068787097931, 0.06690655648708344, 0.15563955903053284, 0.0015627037500962615, 0.008867300115525723, 1.6062978502873193e-08, 4.173410417251944e-08, 2.9044116445220425e-07, 6.742133251691484e-08, 3.217904553931561e-10, 1.353527290248735e-10, 3.536294035377807e-10, 1.516903380682777e-09, 2.9533731016329057e-09, 9.621600405296249e-09, 3.1035025660486326e-09, 5.420458526472771e-10, 7.021184078581655e-09, 3.5951785992693885e-09, 6.923982276418883e-09, 5.586129555013031e-05, 2.3767276502439927e-07, 1.3846920410287566e-07, 1.6650687939545605e-08, 3.5055805369665904e-08]


In [78]:
from utils import show_samples
# Keep only the test examples whose source has more than 30 tokens.
long_data = [x for x in test_data.examples if len(x.src) > 30]
print(f'Number of samples: {len(long_data)}')
# NOTE(review): the recorded run of this cell stopped with
# "IndexError: list index out of range" after printing the first sample.
# Presumably show_samples expects several candidates per sentence while
# detect_card_name returns a single one. TODO confirm against utils.show_samples.
show_samples(long_data, T, n=3, beam_size=3)

Number of samples: 49
['when', 'squadron', 'hawk', 'enters', 'the', 'battlefield', ',', 'you', 'may', 'search', 'your', 'library', 'for', 'up', 'to', 'three', 'cards', 'named', 'squadron', 'hawk', ',', 'reveal', 'them', ',', 'put', 'them', 'into', 'your', 'hand', ',', 'then', 'shuffle', '.']
src: [when squadron hawk enters the battlefield , you may search your library for up to three cards named squadron hawk , reveal them , put them into your hand , then shuffle . ] trg = [011000000000000000110000000000000]
0110000000000000000000000000000000 	[probability: 0.00013]


IndexError: list index out of range

In [None]:
from utils import calculate_bleu

# Score the long test examples with calculate_bleu (imported from utils). The
# callback passes each input to T.translate and keeps element [0][0] of the
# result.
bleu = calculate_bleu(long_data, lambda x: T.translate(x, beam_size=3)[0][0])
# Printed as a percentage.
print(bleu*100)

In [14]:
import json

# Export both token -> index tables as JSON. ensure_ascii=False writes
# non-ASCII tokens (the vocabulary contains characters such as '—') verbatim,
# so the files are opened as UTF-8 explicitly rather than in the platform's
# default encoding.
with open('src_vocab.json', 'w', encoding='utf-8') as f:
    json.dump(SRC.vocab.stoi, f, ensure_ascii=False)
with open('trg_vocab.json', 'w', encoding='utf-8') as f:
    json.dump(TRG.vocab.stoi, f, ensure_ascii=False)

In [25]:
# Serialize the complete vocabulary objects (not just the stoi tables) with
# torch.save.
torch.save(SRC.vocab, 'src_vocab.pt')
torch.save(TRG.vocab, 'trg_vocab.pt')

In [24]:
# Read the saved source vocabulary back and print its token -> index mapping
# as a round-trip check.
# NOTE(review): this cell's execution count (24) is lower than that of the
# cell that writes src_vocab.pt (25), so on a fresh kernel the two cells only
# work in the order they appear here.
v = torch.load('src_vocab.pt')
print(v.stoi)

defaultdict(<bound method Vocab._default_unk_index of <torchtext.legacy.vocab.Vocab object at 0x00000225449CF190>>, {'<unk>': 0, '<pad>': 1, '<sos>': 2, '<eos>': 3, '.': 4, ',': 5, 'you': 6, '{': 7, '}': 8, 'the': 9, 'a': 10, 'creature': 11, '1': 12, 'of': 13, '+': 14, 'your': 15, '/': 16, 'to': 17, 'target': 18, 'it': 19, 'card': 20, 'this': 21, 'control': 22, ':': 23, 'battlefield': 24, 'or': 25, 'turn': 26, '2': 27, 'and': 28, 'that': 29, 'if': 30, 'enters': 31, 'on': 32, 'put': 33, 'whenever': 34, 'may': 35, 'with': 36, "'s": 37, 'end': 38, 'each': 39, 'when': 40, 'cast': 41, 'from': 42, 'until': 43, 'damage': 44, 'spell': 45, 'as': 46, 'an': 47, 'player': 48, 'cards': 49, 'flying': 50, 'creatures': 51, 't': 52, 'counter': 53, 'for': 54, '3': 55, 'its': 56, 'hand': 57, 'life': 58, 'library': 59, 'deals': 60, 'graveyard': 61, 'gets': 62, "n't": 63, 'x': 64, 'mana': 65, '-': 66, 'draw': 67, 'at': 68, 'is': 69, 'create': 70, 'exile': 71, 'opponent': 72, 'token': 73, '—': 74, 'any': 75