In [151]:
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.nn import Transformer
import torch.nn.functional as F
from torch import Tensor
from torch import nn
import warnings
import random
import torch
import math
import yaml
import json
import os
# warnings.filterwarnings("ignore")

In [152]:
class index2char():
    def __init__(self, root, tokenizer=None):
        if tokenizer is None:
            with open(root + '/tokenizer.yaml', 'r') as f:
                self.tokenizer = yaml.load(f, Loader=yaml.CLoader)
        else:
            self.tokenizer = tokenizer
    
    def __call__(self, indices:list, without_token=True):
        if type(indices) == Tensor:
            indices = indices.tolist()
        result = ''.join([self.tokenizer['index_2_char'][i] for i in indices])
        if without_token:
            result = result.split('[eos]')[0]
            result = result.replace('[sos]', '').replace('[eos]', '').replace('[pad]', '')
        return result

In [153]:
def metrics(pred:list, target:list) -> float:
    """
    pred: list of strings
    target: list of strings

    return: accuracy(%)
    """
    if len(pred) != len(target):
        raise ValueError('length of pred and target must be the same')
    correct = 0
    for i in range(len(pred)):
        if pred[i] == target[i]:
            correct += 1
    return correct / len(pred) * 100

In [154]:
embedding_num = 31
embedding_dim = 512
num_layers = 8
num_heads = 8
ff_dim = 1024
dropout = 0.1

In [155]:
class SpellCorrectionDataset(Dataset):
    def __init__(self, root, split:str = 'train', tokenizer=None, padding:int =0):
        super(SpellCorrectionDataset, self).__init__()
        #load your data here
        self.padding = padding
        
        if tokenizer:
            self.tokenizer = tokenizer
        else:
            with open(os.path.join(root, 'tokenizer.yaml'), 'r') as f:
                self.tokenizer = yaml.load(f, Loader=yaml.CLoader)
        
        data_path = os.path.join(root, f'{split}.json')
        with open(data_path, 'r') as f:
            self.all_data = json.load(f)
        self.data =[]
        for line in range(len(self.all_data)):
            for input in self.all_data[line]['input']:
                self.data.append({'input':input,'target':self.all_data[line]['target']})
    
    def tokenize(self, text:str):
        # tokenize your text here
        # ex: "data" -> [4, 1, 20, 1]
        
        # 將文本轉換為索引序列
        tokens = [self.tokenizer['char_2_index'].get(char, 0) for char in text]  # 0 可以是未識別字符的索引
        # 根據指定的padding進行填充或截斷
        if self.padding > 0:
            tokens = tokens[:self.padding] + [0] * max(0, self.padding - len(tokens))  # 使用 0 進行填充
        return tokens
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # get your data by index here
        # ex: return input_ids, target_ids
        # return type: torch.tensor
        item = self.data[index]
        input_text = item['input']
        input_ids = self.tokenize(input_text)
        target_text = item['target']
        target_ids = self.tokenize(target_text)
        # print(f"input_ids={input_ids}")
        # print(f"input_text={input_text}")
        # print(f"target_text={target_text}")
        # print(f"target_ids={target_ids}")
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)

In [156]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, batch_first: bool = False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
        self.batch_first = batch_first

    def forward(self, x: Tensor) -> Tensor:
        if self.batch_first:
            x = x.transpose(0, 1)
            # print(f"x shape: {x.shape}, pe shape: {self.pe[:x.size(0)].shape}")
            x = x + self.pe[:x.size(0)]
            return self.dropout(x.transpose(0, 1))
        else:
            x = x + self.pe[:x.size(0)]
            return self.dropout(x)

# Transformer

In [157]:
class Encoder(nn.Module):
    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100):
        super(Encoder, self).__init__()
        self.tok_embedding = nn.Embedding(num_emb, hid_dim)
        self.pos_embedding = PositionalEncoding(hid_dim, dropout, max_length, batch_first=True)
        # self.layer = <nn.TransformerEncoderLayer>
        # self.encoder = <nn.TransformerEncoder>
        self.layer = nn.TransformerEncoderLayer(d_model=hid_dim, nhead=n_heads, dim_feedforward=ff_dim, dropout=dropout)
        self.encoder = nn.TransformerEncoder(self.layer, num_layers=n_layers)

    def forward(self, src, src_mask=None, src_pad_mask=None):
        # tgt = your_embeddings(?)
        src_emb = self.tok_embedding(src)
        src_emb = self.pos_embedding(src_emb)
        # print(f"src_emb contains NaN: {torch.isnan(src_emb).any()}")
    
        # src = self.encoder(?)
        enc_src = self.encoder(src_emb, mask=src_mask, src_key_padding_mask=src_pad_mask) #這邊開始出現 nan
        # print(f"enc_src contains NaN: {torch.isnan(enc_src).any()}")
        
        return enc_src

class Decoder(nn.Module):
    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100):
        super(Decoder, self).__init__()
        self.tok_embedding = nn.Embedding(num_emb, hid_dim)
        nn.init.xavier_uniform_(self.tok_embedding.weight)
        self.pos_embedding = PositionalEncoding(hid_dim, dropout, max_length, batch_first=True)
        # self.layer = <nn.TransformerDecoderLayer>
        # self.encoder = <nn.TransformerDecoder>
        self.layer = nn.TransformerDecoderLayer(d_model=hid_dim, nhead=n_heads, dim_feedforward=ff_dim, dropout=dropout)
        self.decoder = nn.TransformerDecoder(self.layer, num_layers=n_layers)

    def forward(self, tgt, enc_src, tgt_mask=None, memory_mask=None, src_pad_mask=None, tgt_key_padding_mask=None):
        # tgt = your_embeddings(?)
        tgt_emb = self.tok_embedding(tgt)
        tgt_emb = self.pos_embedding(tgt_emb)
        
        # tgt = self.decoder(?)
        dec_output = self.decoder(tgt_emb, enc_src, tgt_mask=tgt_mask,
                                  memory_mask=memory_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask,
                                  memory_key_padding_mask=src_pad_mask)
        
        return dec_output

class TransformerAutoEncoder(nn.Module):
    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100, encoder=None):
        super(TransformerAutoEncoder, self).__init__()
        if encoder is None:
            self.encoder = Encoder(num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length)
        else:
            self.encoder = encoder
        self.decoder = Decoder(num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length)
        self.output_layer = nn.Linear(hid_dim, num_emb)

    def forward(self, src, tgt, src_pad_mask=None, tgt_mask=None, tgt_pad_mask=None):
        # enc_src = self.encoder(?)
        enc_src = self.encoder(src, src_pad_mask=src_pad_mask)
        # print(f"enc_src:{enc_src}")
        
        # out = self.decoder(?)
        dec_out = self.decoder(tgt, enc_src, src_pad_mask=src_pad_mask, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_pad_mask)  
        # print(f"dec_out:{dec_out}")
        
        out = self.output_layer(dec_out)
        out = F.softmax(out, dim=-1)
        
        return out

# Mask

In [158]:
def gen_padding_mask(src, pad_idx):
    # detect where the padding value is
    pad_mask = (src == pad_idx).transpose(0, 1)  # 生成布林掩碼，padding 位置為 True
    return pad_mask

def gen_mask(seq):
    # triu mask for decoder
    seq_len = seq.size(0)  # 獲取序列長度
    # 生成上三角掩碼，並擴展以適配注意力的計算
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask

def get_index(pred, dim=1):
    return pred.clone().argmax(dim=dim)

def random_change_idx(data: torch.Tensor, prob: float = 0.2):
    # randomly change the index of the input data
    change_mask = torch.rand(data.size()) < prob  # 隨機生成的布林掩碼
    new_values = torch.randint(0, data.size(-1), data.size(), dtype=torch.long)
    # 用新值替換原始數據中的部分值
    sample = data.clone()
    sample[change_mask] = new_values[change_mask]
    return sample

def random_masked(data: torch.Tensor, prob: float = 0.2, mask_idx: int = 3):
    # randomly mask the input data
    mask = torch.rand(data.size()) < prob  # 隨機生成的布林掩碼
    # 將選中的值替換為指定的 mask_idx
    sample = data.clone()
    sample[mask] = mask_idx
    return sample

# Pretrained encoder with random mask

In [159]:
# You can try to pretrain the Encoder here!
class PretrainedMaskedEncoder(Encoder):
    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100, pretrained=True):
        super(PretrainedMaskedEncoder, self).__init__(num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length)
        self.pretrained = pretrained

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Apply random masking before passing through the encoder
        if self.pretrained:
            src = random_masked(src, prob=0.2, mask_idx=0)  # Apply random masking to input data
        return super().forward(src, src_mask, src_key_padding_mask)

encoder = PretrainedMaskedEncoder(embedding_num, embedding_dim, num_layers, num_heads, ff_dim, dropout,  pretrained=True)

# Define some dummy input (batch_size=4, sequence_length=10)
src = torch.randint(0, embedding_num, (4, 10))

# Forward pass through the encoder
output = encoder(src)

print(output.shape)

torch.Size([4, 10, 512])


# Train our spelling correction transformer

In [160]:
from tqdm import tqdm
i2c = index2char('./data/')

trainset = SpellCorrectionDataset('./data/', padding=22)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
testset = SpellCorrectionDataset('./data/', split='new_test', padding=22)
testloader = DataLoader(testset, batch_size=32, shuffle=False)
valset = SpellCorrectionDataset('./data/', split='test', padding=22)
valloader = DataLoader(valset, batch_size=32, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ce_loss = nn.CrossEntropyLoss(ignore_index=0)

In [161]:
def validation(dataloader, model, device, logout=False):
    pred_str_list = []
    tgt_str_list = []
    input_str_list = []
    losses = []
    for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)
            tgt_input = torch.full((tgt.shape[0], tgt.shape[1]), fill_value=0, device=device)
            tgt_input[:, 0] = 1  # Set first token to <sos>
            
            for i in range(tgt.shape[1]-1):
                # Generate the padding masks
                src_pad_mask = gen_padding_mask(src, pad_idx=0)
                tgt_pad_mask = gen_padding_mask(tgt_input, pad_idx=0)
                tgt_mask = gen_mask(tgt_input).to(device)
                
                # Forward pass
                pred = model(src, tgt_input,
                             src_pad_mask=src_pad_mask,
                             tgt_mask=tgt_mask,
                             tgt_pad_mask=tgt_pad_mask)
                
                # pred = <get the prediction idx from the model>
                # assign the prediction idx to the next token of tgt_input
                pred_idx = get_index(pred[:, -1, :])
                tgt_input[:, i + 1] = pred_idx  # 將預測結果添加到 tgt_input
                
            for i in range (tgt.shape[0]):
                pred_str_list.append(i2c(tgt_input[i].tolist()))
                tgt_str_list.append(i2c(tgt[i].tolist()))
                input_str_list.append(i2c(src[i].tolist()))
                if logout:
                    print('='*30)
                    print(f'input: {input_str_list[-1]}')
                    print(f'pred: {pred_str_list[-1]}')
                    print(f'target: {tgt_str_list[-1]}')
            loss = ce_loss(pred[:, :-1, :].permute(0, 2, 1), tgt[:, 1:])
            losses.append(loss.item())
    print(f"test_acc: {metrics(pred_str_list, tgt_str_list):.2f}", f"test_loss: {sum(losses)/len(losses):.2f}", end=' | ')
    print(f"[pred: {pred_str_list[0]} target: {tgt_str_list[0]}]")

In [162]:
# encoder.pretrained_mode = False
model = TransformerAutoEncoder(embedding_num, embedding_dim, num_layers, num_heads, ff_dim, dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) #choose your optimizer

for epoch in range(1000):
    # train
    losses = []
    model.train()
    i_bar = tqdm(trainloader, unit='iter', desc=f'epoch_{epoch+1}')

    for src, tgt in i_bar:
        src, tgt = src.to(device), tgt.to(device)

        # generate the mask and padding mask
        src_pad_mask = gen_padding_mask(src, pad_idx=0)  # Generate the padding mask
        tgt_pad_mask = gen_padding_mask(tgt, pad_idx=0)  # Generate the padding mask
        tgt_mask = gen_mask(tgt).to(device)  # Generate the mask
        
        optimizer.zero_grad()
        
        pred = model(src, tgt,
                     src_pad_mask=src_pad_mask,
                     tgt_mask=tgt_mask,
                     tgt_pad_mask=tgt_pad_mask)

        loss = ce_loss(pred[:, :-1, :].permute(0, 2, 1), tgt[:, 1:])
        loss.backward()
        
        optimizer.step()
        losses.append(loss.item())
        i_bar.set_postfix_str(f"loss: {sum(losses)/len(losses):.3f}")
        
    # test
    model.eval()
    with torch.no_grad():
        validation(testloader, model, device)
    model.eval()
    with torch.no_grad():
        validation(valloader, model, device)

epoch_1:   1%|          | 4/404 [00:00<00:35, 11.24iter/s, loss: nan]


test_acc: 0.00 test_loss: nan | [pred:  target: appreciate]
test_acc: 0.00 test_loss: nan | [pred:  target: contented]


In [163]:
validation(testloader, model, device, logout=True)

input: apreciate
pred: 
target: appreciate
input: appeciate
pred: 
target: appreciate
input: apprciate
pred: 
target: appreciate
input: apprecate
pred: 
target: appreciate
input: apprecite
pred: 
target: appreciate
input: luve
pred: 
target: love
input: culd
pred: 
target: cold
input: heart
pred: 
target: heart
input: televiseon
pred: 
target: television
input: thone
pred: 
target: phone
input: phace
pred: 
target: phase
input: poam
pred: 
target: poem
input: tomorraw
pred: 
target: tomorrow
input: presishan
pred: 
target: precision
input: presishion
pred: 
target: precision
input: presisian
pred: 
target: precision
input: presistion
pred: 
target: precision
input: perver
pred: 
target: prefer
input: predgudice
pred: 
target: prejudice
input: predgudis
pred: 
target: prejudice
input: recievor
pred: 
target: receiver
input: reciover
pred: 
target: receiver
input: relieve
pred: 
target: relief
input: togather
pred: 
target: together
input: remuttance
pred: 
target: remittance
input: depo

In [164]:
validation(valloader, model, device, logout=True)

input: contenpted
pred: 
target: contented
input: begining
pred: 
target: beginning
input: problam
pred: 
target: problem
input: dirven
pred: 
target: driven
input: ecstacy
pred: 
target: ecstasy
input: juce
pred: 
target: juice
input: localy
pred: 
target: locally
input: compair
pred: 
target: compare
input: pronounciation
pred: 
target: pronunciation
input: transportibility
pred: 
target: transportability
input: miniscule
pred: 
target: minuscule
input: independant
pred: 
target: independent
input: aranged
pred: 
target: arranged
input: poartry
pred: 
target: poetry
input: leval
pred: 
target: level
input: basicaly
pred: 
target: basically
input: triangulaur
pred: 
target: triangular
input: unexpcted
pred: 
target: unexpected
input: stanerdizing
pred: 
target: standardizing
input: varable
pred: 
target: variable
input: neigbours
pred: 
target: neighbours
input: enxt
pred: 
target: next
input: powerfull
pred: 
target: powerful
input: practial
pred: 
target: practical
input: repatition