In [1]:
import json
import math
import os
import random
import warnings

import torch
import torch.nn.functional as F
import yaml
from torch import Tensor
from torch import nn
from torch.nn import Transformer
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm
# warnings.filterwarnings("ignore")

In [None]:
class index2char():
    """Decode a sequence of token indices back into a string.

    The lookup table is ``tokenizer['index_2_char']``; each index in the input
    is mapped through it and the pieces are concatenated.
    """

    def __init__(self, root, tokenizer=None):
        """
        root: directory containing ``tokenizer.yaml``; only read when
              ``tokenizer`` is None.
        tokenizer: optional, already-loaded tokenizer mapping. When given, no
              file is opened.
        """
        if tokenizer is None:
            # libyaml (and therefore CLoader) is an optional PyYAML build
            # feature; fall back to the pure-Python loader of the same family.
            loader = getattr(yaml, 'CLoader', yaml.Loader)
            # NOTE(review): this is not PyYAML's *safe* loader. Only point it
            # at a trusted tokenizer file, or switch to yaml.safe_load if the
            # file contains nothing but plain mappings/scalars.
            with open(os.path.join(root, 'tokenizer.yaml'), 'r') as f:
                self.tokenizer = yaml.load(f, Loader=loader)
        else:
            self.tokenizer = tokenizer

    def __call__(self, indices:list, without_token=True):
        """
        indices: list of token indices, or a 1-D tensor of them.
        without_token: when True, cut the text at the first '[eos]' and strip
              the '[sos]' / '[eos]' / '[pad]' markers.

        return: decoded string
        """
        # isinstance (not an exact type comparison) so Tensor subclasses such
        # as nn.Parameter are converted too.
        if isinstance(indices, Tensor):
            indices = indices.tolist()
        index_2_char = self.tokenizer['index_2_char']
        result = ''.join(index_2_char[i] for i in indices)
        if without_token:
            result = result.split('[eos]')[0]
            result = result.replace('[sos]', '').replace('[eos]', '').replace('[pad]', '')
        return result

In [None]:
def metrics(pred:list, target:list) -> float:
    """
    Exact-match accuracy between two equally long lists of strings.

    pred: list of strings
    target: list of strings

    return: accuracy(%), or 0.0 when both lists are empty
    raises: ValueError if the two lists differ in length
    """
    if len(pred) != len(target):
        raise ValueError('length of pred and target must be the same')
    # Nothing to score: report 0% instead of dividing by zero.
    if len(pred) == 0:
        return 0.0
    correct = sum(p == t for p, t in zip(pred, target))
    return correct / len(pred) * 100

In [2]:
# Model hyper-parameters. They are handed positionally to
# TransformerAutoEncoder(num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout).

# -- embedding table --
embedding_dim = 512   # hid_dim: width of every token embedding
embedding_num = 31    # num_emb: number of rows in the embedding table

# -- transformer stack --
num_heads = 8         # n_heads
num_layers = 8        # n_layers
ff_dim = 1024         # ff_dim

# -- regularisation --
dropout = 0.1

In [3]:
class SpellCorrectionDataset(Dataset):
    """Template dataset yielding (input_ids, target_ids) pairs.

    Loading, tokenising and item access are still to be implemented. Until
    then the dataset is empty, and the unimplemented hooks fail loudly instead
    of silently handing ``None`` to the DataLoader.
    """

    def __init__(self, root, split:str = 'train', tokenizer=None, padding:int =0):
        """
        root: data directory.
        split: name of the split to load (the notebook uses 'train', 'test'
               and 'valid').
        tokenizer: optional tokenizer object; stored as given.
        padding: padding setting; stored as given (the notebook passes 22).
        """
        super().__init__()
        # Keep the constructor arguments so the loading code has them at hand.
        self.root = root
        self.split = split
        self.tokenizer = tokenizer
        self.padding = padding
        # __len__ reads self.data, so it has to exist before any data is loaded.
        self.data = []
        # TODO: load your data here (fill self.data)

    def tokenize(self, text:str):
        """Convert ``text`` into a list of token indices.

        ex: "data" -> [4, 1, 20, 1]
        """
        # TODO: tokenize your text here
        raise NotImplementedError('SpellCorrectionDataset.tokenize is not implemented yet')

    def __len__(self):
        """Number of samples currently held in ``self.data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample at ``index``.

        ex: return input_ids, target_ids
        return type: torch.tensor
        """
        # TODO: get your data by index here
        raise NotImplementedError('SpellCorrectionDataset.__getitem__ is not implemented yet')

In [5]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    The table is precomputed once for ``max_len`` positions and stored as the
    buffer ``pe`` with shape (max_len, 1, d_model).
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, batch_first: bool = False):
        """
        d_model: size of the last (feature) dimension of the inputs.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions to precompute.
        batch_first: True if inputs are (batch, seq, d_model), False if they
                     are (seq, batch, d_model).
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine part must be trimmed to d_model // 2 columns.
        pe[:, 0, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``dropout(x + pe)`` for the first ``seq`` positions of the table."""
        if self.batch_first:
            # (seq, 1, d_model) -> (1, seq, d_model) so it broadcasts over the batch.
            x = x + self.pe[:x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [6]:
class Encoder(nn.Module):
    """Token embedding + positional encoding followed by a Transformer encoder stack."""

    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100):
        """
        num_emb: number of rows in the token embedding table.
        hid_dim: embedding / model width.
        n_layers: number of stacked encoder layers.
        n_heads: attention heads per layer.
        ff_dim: width of each layer's feed-forward sub-network.
        dropout: dropout probability (positional encoding and layers).
        max_length: number of positions the positional encoding precomputes.
        """
        super(Encoder, self).__init__()
        self.tok_embedding = nn.Embedding(num_emb, hid_dim)
        self.pos_embedding = PositionalEncoding(hid_dim, dropout, max_length, batch_first=True)
        # Kept local: TransformerEncoder deep-copies the layer, so storing the
        # prototype on self would only register an extra, never-used copy.
        layer = TransformerEncoderLayer(
            d_model=hid_dim,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
        )
        # Nested-tensor conversion is disabled so the output always keeps the
        # full padded sequence length, matching the padding mask the decoder
        # later applies to this memory.
        self.encoder = TransformerEncoder(layer, num_layers=n_layers, enable_nested_tensor=False)

    def forward(self, src, src_pad_mask=None):
        """
        src: (batch, src_len) tensor of token indices.
        src_pad_mask: optional (batch, src_len) bool mask, True at positions
                      attention must ignore.

        return: (batch, src_len, hid_dim) encoded sequence.
        """
        src = self.pos_embedding(self.tok_embedding(src))
        return self.encoder(src, src_key_padding_mask=src_pad_mask)


class Decoder(nn.Module):
    """Token embedding + positional encoding, a Transformer decoder stack and an
    output projection back onto the token table."""

    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100):
        """Arguments have the same meaning as for ``Encoder``."""
        super(Decoder, self).__init__()
        self.tok_embedding = nn.Embedding(num_emb, hid_dim)
        self.pos_embedding = PositionalEncoding(hid_dim, dropout, max_length, batch_first=True)
        layer = TransformerDecoderLayer(
            d_model=hid_dim,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.decoder = TransformerDecoder(layer, num_layers=n_layers)
        # Projects decoder states to one logit per token index.
        self.fc_out = nn.Linear(hid_dim, num_emb)

    def forward(self, tgt, memory, tgt_mask=None, tgt_pad_mask=None, memory_pad_mask=None):
        """
        tgt: (batch, tgt_len) tensor of token indices.
        memory: (batch, src_len, hid_dim) encoder output to attend over.
        tgt_mask: optional (tgt_len, tgt_len) bool mask, True where a query
                  position must not look at a key position (causal mask).
        tgt_pad_mask: optional (batch, tgt_len) bool mask, True at positions
                  to ignore in ``tgt``.
        memory_pad_mask: optional (batch, src_len) bool mask, True at
                  positions to ignore in ``memory``.

        return: (batch, tgt_len, num_emb) logits.
        """
        tgt = self.pos_embedding(self.tok_embedding(tgt))
        out = self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            memory_key_padding_mask=memory_pad_mask,
        )
        return self.fc_out(out)


class TransformerAutoEncoder(nn.Module):
    """Encoder-decoder model mapping a source token sequence to per-position
    logits over the token table for the target sequence."""

    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100, encoder=None):
        """
        encoder: optional ready-made (e.g. pretrained) encoder module. It is
                 called as ``encoder(src, src_pad_mask)``. When None a fresh
                 ``Encoder`` is built from the other arguments, which have the
                 same meaning as for ``Encoder``.
        """
        super(TransformerAutoEncoder, self).__init__()
        if encoder is None:
            self.encoder = Encoder(num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length)
        else:
            self.encoder = encoder
        self.decoder = Decoder(num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length)

    def forward(self, src, tgt, src_pad_mask, tgt_mask, tgt_pad_mask):
        """
        src: (batch, src_len) source token indices.
        tgt: (batch, tgt_len) decoder input token indices.
        src_pad_mask / tgt_mask / tgt_pad_mask: attention masks as described
            on ``Encoder.forward`` and ``Decoder.forward``; each may be None.

        return: (batch, tgt_len, num_emb) logits.
        """
        enc_src = self.encoder(src, src_pad_mask)
        # The source padding mask is reused so cross-attention skips padded
        # encoder positions as well.
        out = self.decoder(tgt, enc_src, tgt_mask, tgt_pad_mask, src_pad_mask)
        return out

In [7]:
def gen_padding_mask(src, pad_idx):
    """Return a bool tensor shaped like ``src`` that is True wherever ``src``
    equals ``pad_idx`` (the convention the torch ``*_key_padding_mask``
    arguments use for positions to ignore)."""
    pad_mask = src == pad_idx
    return pad_mask

def gen_mask(seq):
    """Build the causal (upper-triangular) attention mask for the decoder.

    seq: batch-first tensor whose dimension 1 is the sequence length,
         e.g. (batch, seq_len) token indices.

    return: (seq_len, seq_len) bool tensor on ``seq``'s device; entry [i, j]
            is True when j > i, i.e. position i must not attend to the later
            position j.
    """
    length = seq.size(1)
    allowed_everywhere = torch.ones(length, length, dtype=torch.bool, device=seq.device)
    # diagonal=1 keeps the main diagonal False so a position can see itself.
    mask = torch.triu(allowed_everywhere, diagonal=1)
    return mask

def get_index(pred, dim=2):
    """Return the index of the largest value of ``pred`` along ``dim``.

    argmax only reads its input, so no defensive copy of the (potentially
    large) logits tensor is made.
    """
    return pred.argmax(dim=dim)

def random_change_idx(data: torch.Tensor, prob: float = 0.2, num_emb=None):
    """Randomly replace entries of ``data`` with random token indices.

    data: integer tensor of token indices; it is not modified.
    prob: independent probability, per entry, of being replaced.
    num_emb: replacements are drawn uniformly from [0, num_emb). When None the
             notebook-level ``embedding_num`` is used.

    return: new tensor with the same shape, dtype and device as ``data``.

    NOTE(review): every position is eligible, including any special tokens,
    and a replacement may happen to equal the original value — confirm this
    is what the pretraining task wants.
    """
    if num_emb is None:
        num_emb = embedding_num
    change = torch.rand(data.shape, device=data.device) < prob
    noise = torch.randint(0, num_emb, data.shape, dtype=data.dtype, device=data.device)
    sample = torch.where(change, noise, data)
    return sample

def random_masked(data: torch.Tensor, prob: float = 0.2, mask_idx: int = 3):
    """Randomly overwrite entries of ``data`` with ``mask_idx``.

    data: tensor of token indices; it is not modified.
    prob: independent probability, per entry, of being masked.
    mask_idx: value written at the masked positions.

    return: new tensor with the same shape, dtype and device as ``data``.

    NOTE(review): every position is eligible, including any special tokens —
    confirm this is what the pretraining task wants.
    """
    chosen = torch.rand(data.shape, device=data.device) < prob
    sample = data.masked_fill(chosen, mask_idx)
    return sample

# Pretrain the encoder with random masking

In [None]:
# You can try to pretrain the Encoder here!

# Train our spelling correction transformer

In [8]:
# Settings shared by all three splits, named once so they cannot drift apart.
DATA_ROOT = './data/'
PADDING = 22        # forwarded as ``padding`` to every SpellCorrectionDataset
BATCH_SIZE = 32
IGNORE_INDEX = 0    # target value the loss skips; presumably the [pad] index — confirm against tokenizer.yaml

i2c = index2char(DATA_ROOT)

trainset = SpellCorrectionDataset(DATA_ROOT, padding=PADDING)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testset = SpellCorrectionDataset(DATA_ROOT, split='test', padding=PADDING)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
valset = SpellCorrectionDataset(DATA_ROOT, split='valid', padding=PADDING)
valloader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ce_loss = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

In [9]:
@torch.no_grad()
def validation(dataloader, model, device, logout=False, pad_idx=None):
    """Greedy-decode every batch of ``dataloader`` and print accuracy and loss.

    dataloader: yields (src, tgt) batches of token indices, batch first.
                Column 0 of ``tgt`` is taken to be the <sos> token.
    model: called as model(src, tgt_input, src_pad_mask, tgt_mask, tgt_pad_mask)
           and expected to return (batch, tgt_len, vocab) logits.
    device: device the batches are moved to.
    logout: when True, print input / prediction / target for every sample.
    pad_idx: padding index; defaults to ``ce_loss.ignore_index``.

    Runs without autograd and in eval mode; the model's previous train/eval
    mode is restored on exit. Returns None.
    """
    if pad_idx is None:
        pad_idx = ce_loss.ignore_index

    pred_str_list = []
    tgt_str_list = []
    input_str_list = []
    losses = []

    was_training = model.training
    model.eval()
    try:
        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)
            # All-pad canvas with the same shape as tgt whose first token is <sos>.
            tgt_input = torch.full_like(tgt, pad_idx)
            tgt_input[:, 0] = tgt[:, 0]
            src_pad_mask = gen_padding_mask(src, pad_idx)  # src never changes while decoding
            for i in range(tgt.shape[1]-1):
                tgt_pad_mask = gen_padding_mask(tgt_input, pad_idx)
                tgt_mask = gen_mask(tgt_input)
                pred = model(src, tgt_input, src_pad_mask, tgt_mask, tgt_pad_mask)
                # The logits at position i predict token i+1: write it into the canvas.
                tgt_input[:, i+1] = get_index(pred)[:, i]
            for src_row, pred_row, tgt_row in zip(src.tolist(), tgt_input.tolist(), tgt.tolist()):
                pred_str_list.append(i2c(pred_row))
                tgt_str_list.append(i2c(tgt_row))
                input_str_list.append(i2c(src_row))
                if logout:
                    print('='*30)
                    print(f'input: {input_str_list[-1]}')
                    print(f'pred: {pred_str_list[-1]}')
                    print(f'target: {tgt_str_list[-1]}')
            # Loss of the final decoding pass: logits at position k are scored
            # against target token k+1.
            loss = ce_loss(pred[:, :-1, :].permute(0, 2, 1), tgt[:, 1:])
            losses.append(loss.item())
    finally:
        model.train(was_training)

    if not losses:
        # An empty loader would otherwise divide by zero and index an empty list.
        print('validation: dataloader yielded no batches')
        return
    print(f"test_acc: {metrics(pred_str_list, tgt_str_list):.2f}", f"test_loss: {sum(losses)/len(losses):.2f}", end=' | ')
    print(f"[pred: {pred_str_list[0]} target: {tgt_str_list[0]}]")

In [None]:
# encoder.pretrained_mode = False
NUM_EPOCHS = 1000
LEARNING_RATE = 1e-4  # starting point only — tune for your data

model = TransformerAutoEncoder(embedding_num, embedding_dim, num_layers, num_heads, ff_dim, dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
pad_idx = ce_loss.ignore_index  # the value the loss skips doubles as the padding index
for eps in range(NUM_EPOCHS):
    # train
    losses = []
    model.train()
    i_bar = tqdm(trainloader, unit='iter', desc=f'epoch{eps}')
    for src, tgt in i_bar:
        src, tgt = src.to(device), tgt.to(device)
        # generate the mask and padding mask
        src_pad_mask = gen_padding_mask(src, pad_idx)
        tgt_pad_mask = gen_padding_mask(tgt, pad_idx)
        tgt_mask = gen_mask(tgt)
        optimizer.zero_grad()
        pred = model(src, tgt, src_pad_mask, tgt_mask, tgt_pad_mask)
        # Teacher forcing: logits at position k are scored against target
        # token k+1, the same alignment validation() uses.
        loss = ce_loss(pred[:, :-1, :].permute(0, 2, 1), tgt[:, 1:])
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        i_bar.set_postfix_str(f"loss: {sum(losses)/len(losses):.3f}")
    # evaluate on the test split, then on the validation split
    model.eval()
    with torch.no_grad():
        validation(testloader, model, device)
        validation(valloader, model, device)

In [None]:
# The training cell is usually interrupted by hand, which can leave the model
# in train mode (dropout active); force eval mode and skip autograd here.
model.eval()
with torch.no_grad():
    validation(testloader, model, device, logout=True)

In [None]:
# The training cell is usually interrupted by hand, which can leave the model
# in train mode (dropout active); force eval mode and skip autograd here.
model.eval()
with torch.no_grad():
    validation(valloader, model, device, logout=True)