In [None]:
%pip install transformers
%pip install datasets
%pip install torch
%pip install nltk
%pip install wandb
%pip install "ipywidgets>=7.0,<8.0"

In [1]:
from datasets import load_dataset
import random
import torch
from torch.utils.data import Dataset,DataLoader
import numpy as np
from torch import nn
from torch.nn.functional import pad
from collections import Counter
from tqdm.notebook import tqdm
import re
from collections import Counter
import wandb

In [2]:
def seed(random_seed=42):
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(random_seed)
        torch.cuda.manual_seed_all(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False   

In [3]:
mt_data = load_dataset("mt_eng_vietnamese","iwslt2015-en-vi");mt_data

Found cached dataset mt_eng_vietnamese (/home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 133318
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 1269
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 1269
    })
})

In [4]:
mt_data['train'][random.randint(0,len(mt_data['train']))]

{'translation': {'en': 'Bacteria are incredibly multi-drug-resistant right now , and that &apos;s because all of the antibiotics that we use kill bacteria .',
  'vi': 'Vi khuẩn bây giờ có khả năng đề kháng rất nhiều loài thuốc và đó là vì tất cả các loại kháng sinh chúng ta sử dụng giết chết vi khuẩn .'}}

In [5]:
def clean(batch):
    en = batch['translation']['en']
    vi = batch['translation']['vi']
    
    en = re.sub(r'&.*;','',en)
    en = re.sub(r'([0-9,\.]{2,})(?![a-z\-])','',en)
    en = re.sub(r'[-\.]{2,}[0-9\.]*','',en)
    en = ' '.join(list(filter(lambda x: len(x),en.split())))
    batch['translation']['en'] = en
    
    vi = re.sub(r'&.*;','',vi)
    vi = re.sub(r'[-\.]{2,}[0-9\.]*','',vi)
    vi = re.sub(r'([0-9,\.]{2,})(?![a-z\-])','',vi)
    vi = ' '.join(list(filter(lambda x: len(x),vi.split())))
    batch['translation']['vi'] = vi
    return batch

In [6]:
mt_data = mt_data.map(clean)

Loading cached processed dataset at /home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71/cache-0ec0d62847218b8a.arrow
Loading cached processed dataset at /home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71/cache-a0b2f1e553bc8cc1.arrow
Loading cached processed dataset at /home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71/cache-a52147ac6173bebf.arrow


In [7]:
mt_data = mt_data.filter(lambda x: len(x['translation']['en'].strip()) and len(x['translation']['vi'].strip()) and len(x['translation']['en'].strip().split()) < 30 and len(x['translation']['vi'].strip().split()) < 30)

Loading cached processed dataset at /home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71/cache-f499bd2143ab12bd.arrow
Loading cached processed dataset at /home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71/cache-a7cc12d5d234db64.arrow
Loading cached processed dataset at /home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71/cache-3a4a2326351661ae.arrow


In [8]:
PAD = '<pad>'
SOS = '<sos>'
EOS = '<eos>'
UNK = '<unk>'

In [9]:
def get_vocab(batch):
    all_word_en = {k:v for k,v in Counter(' '.join([b['en'].lower() for b in batch['translation']]).split()).items() if v > 3}
    all_word_vi = {k:v for k,v in Counter(' '.join([b['vi'].lower() for b in batch['translation']]).split()).items() if v > 3}
    
    return {'vocab_en':[all_word_en],'vocab_vi':[all_word_vi]}

In [10]:
vocabs = mt_data.map(get_vocab,batched=True,batch_size=-1,remove_columns=mt_data.column_names['train'])

Loading cached processed dataset at /home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71/cache-50f13a4819ebb190.arrow
Loading cached processed dataset at /home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71/cache-9bcfd1c796158201.arrow
Loading cached processed dataset at /home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71/cache-251682ed2a91767c.arrow


In [11]:
vocabs

DatasetDict({
    train: Dataset({
        features: ['vocab_en', 'vocab_vi'],
        num_rows: 1
    })
    validation: Dataset({
        features: ['vocab_en', 'vocab_vi'],
        num_rows: 1
    })
    test: Dataset({
        features: ['vocab_en', 'vocab_vi'],
        num_rows: 1
    })
})

In [12]:
vocab_en = list(set(vocabs['train']['vocab_en'][0])|set(vocabs['test']['vocab_en'][0]) |set(vocabs['validation']['vocab_en'][0]))
vocab_en = {v:k for k,v in enumerate([PAD,SOS,EOS,UNK] + sorted(vocab_en))}
rev_vocab_en = {v:k for k,v in vocab_en.items()}

In [13]:
vocab_vi = list(set(vocabs['train']['vocab_vi'][0])|set(vocabs['test']['vocab_vi'][0]) |set(vocabs['validation']['vocab_vi'][0]))
vocab_vi = {v:k for k,v in enumerate([PAD,SOS,EOS,UNK] + sorted(vocab_vi))}
rev_vocab_vi = {v:k for k,v in vocab_vi.items()}

In [14]:
print(len(vocab_vi))
print(len(vocab_en))

4964
11306


In [15]:
class Encoder(nn.Module):
    def __init__(self, vocab,emb_size,hid_size,dropout=0.5,pad=PAD):
        super(Encoder,self).__init__()
        self.emb = nn.Embedding(len(vocab),emb_size,padding_idx=vocab[pad])
        self.gru = nn.GRU(input_size=emb_size,hidden_size=hid_size,bidirectional=True)
        self.dropout= nn.Dropout(dropout)
        self.fc = nn.Linear(hid_size*2,hid_size)
    def forward(self,x):
        x = self.dropout(self.emb(x))
        outputs,hidden = self.gru(x)
        hidden = torch.cat((hidden[-2],hidden[-1]),dim=-1)
        
        return outputs, torch.tanh(self.fc(hidden))

In [16]:
class Attention(nn.Module):
    def __init__(self,hid_size):
        super(Attention,self).__init__()
        self.attn = nn.Linear(hid_size*3,hid_size)
        self.v    = nn.Linear(hid_size,1,bias=False)
    
    def forward(self,hidden,encoder_outputs):
        src_len = encoder_outputs.shape[0]
        hidden  = hidden.unsqueeze(1).expand(-1,src_len,-1)
        
        energy = torch.tanh(self.attn(torch.cat((hidden,encoder_outputs.permute(1,0,2)),dim=-1)))
        
        attn = self.v(energy).squeeze(-1)
        
        return torch.softmax(attn,dim=-1)
        

In [17]:
class Decoder(nn.Module):
    def __init__(self,vocab,emb_size,hid_size,attention,dropout=0.5,pad=PAD):
        super(Decoder,self).__init__()
        self.emb = nn.Embedding(len(vocab),emb_size,padding_idx=vocab[pad])
        self.gru = nn.GRU(input_size=hid_size*2 + emb_size,hidden_size=hid_size)
        self.fc = nn.Linear(hid_size*3 + emb_size, len(vocab))
        self.dropout = nn.Dropout(dropout)
        self.attention = attention
        
    def forward(self,x,hidden,encoder_outputs):
        x = x.unsqueeze(0)
        emb = self.dropout(self.emb(x))
        attn = self.attention(hidden,encoder_outputs).unsqueeze(1)
        encoder_outputs = encoder_outputs.permute(1,0,2)
        
        weighted = torch.bmm(attn,encoder_outputs).permute(1,0,2)
        gru_input = torch.cat((emb,weighted),dim=2)
        output, hidden = self.gru(gru_input,hidden.unsqueeze(0))
        assert (output == hidden).all(),(output,hidden)
        
        emb = emb.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        
        logits = self.fc(torch.cat((output,weighted,emb),dim=1))
        
        return logits,hidden.squeeze(0)

In [18]:
class MTModel(nn.Module):
    def __init__(self,in_vocab,out_vocab,emb_size=300,hid_size=256,dropout=0.5,sos=SOS,eos=EOS,pad=PAD):
        super(MTModel,self).__init__()
        self.in_vocab = in_vocab
        self.out_vocab  = out_vocab
        self.rev_out_vocab = {v:k for k,v in out_vocab.items()}
        self.sos = sos
        self.eos = eos
        self.pad = pad
        self.encoder = Encoder(in_vocab,emb_size,hid_size,dropout,pad)
        self.decoder = Decoder(out_vocab,emb_size,hid_size,Attention(hid_size),dropout,pad)
        
    def forward(self,src,trg,teacher_forcing_ratio=0.5):
        batch_size = trg.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = len(self.out_vocab)
        
        outputs = torch.zeros(trg_len,batch_size,trg_vocab_size).to(src.device)
        
        encoder_outputs, hidden = self.encoder(src)
        input = trg[0,:]
        
        for t in range(1,trg_len):
            logits, hidden = self.decoder(input,hidden,encoder_outputs)
            outputs[t] = logits
            
            teacher_force = random.random() < teacher_forcing_ratio
            
            input = trg[t,:] if teacher_force else logits.argmax(-1)
        
        return outputs
    
    def infer(self,src):
        if len(src.shape) == 1:
            src = src.unsqueeze(1)

        batch_size = src.shape[1]

        encoder_outputs, hidden = self.encoder(src)
        
        input = (torch.LongTensor(batch_size * [self.out_vocab[self.sos]])).to(src.device)
        max_len = 50
        cur_len = 1
        outputs = [input]
        while not (outputs[-1] == self.out_vocab[self.eos]).all() and cur_len < max_len:
            logits,hidden = self.decoder(input,hidden,encoder_outputs)
            
            outputs.append(logits.argmax(-1))
            input = outputs[-1]
            cur_len += 1
        
        res = [[] for _ in range(batch_size)]
        for i in range(len(outputs)):
            for j in range(batch_size):
                res[j].append(self.rev_out_vocab[outputs[i][j].item()])
        return res
        
                        

In [19]:
class TextDataset(Dataset):
    def __init__(self,src,trg,in_pad,out_pad):
        self.src = src
        self.trg = trg
        self.in_pad = in_pad
        self.out_pad = out_pad

    def __getitem__(self, index):
        return self.src[index],self.trg[index]
    
    def __len__(self):
        return len(self.src)
    
    def pad(self,inputs,PAD):
        def pad_data(x, length, PAD):
            x_padded = pad(
                x, (0, length - x.shape[0]), mode="constant", value=PAD
            )
            return x_padded

        max_len = max((len(x) for x in inputs))
        padded = torch.stack([pad_data(torch.LongTensor(x), max_len, PAD) for x in inputs])

        return padded

    
    def collate_fn(self,batch):
        src = []
        trg = []
        
        for s,t in batch:
            src.append(s)
            trg.append(t)
        return self.pad(src,self.in_pad), self.pad(trg,self.out_pad)

In [20]:
def text2id(batch):
    batch['input_ids_vi'] = [[vocab_vi[SOS]] + [vocab_vi.get(word,vocab_vi[UNK]) for word in b['vi'].lower().split()] + [vocab_vi[EOS]] for b in batch['translation']]    
    batch['input_ids_en'] = [[vocab_en[SOS]] + [vocab_en.get(word,vocab_en[UNK]) for word in b['en'].lower().split()] + [vocab_en[EOS]] for b in batch['translation']]
    return batch

In [21]:
mt_ids = mt_data.map(text2id,batch_size=8,batched=True,remove_columns=mt_data.column_names['train'])

Loading cached processed dataset at /home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71/cache-9f6d24d5cac99944.arrow
Loading cached processed dataset at /home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71/cache-ca7977cf65596fa3.arrow
Loading cached processed dataset at /home/aimenext/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71/cache-b96d1bb261464c8e.arrow


In [22]:
' '.join([rev_vocab_vi[i] for i in mt_ids['validation']['input_ids_vi'][25]])

'<sos> những cô gái này đã rất may mắn . <eos>'

In [23]:
def bleu(candidate,reference):
    overlap = 0
    can_counter = Counter(candidate)
    ref_counter = Counter(reference)
    for can in candidate:
        overlap += min(can_counter.get(can),ref_counter.get(can,0))
    
    return min(1,np.exp(1-len(reference)/len(candidate))) * overlap / sum(can_counter.values())

from nltk.translate.bleu_score import sentence_bleu

references = [['my', 'lol', 'correct', 'sentence']]
candidates = ['my', 'phucking','sentence']
print(sentence_bleu(references, candidates,weights=(1,)))
print(bleu(candidates,references[0]))

0.47768754038252614
0.4776875403825262


In [24]:
test_dataset = TextDataset(mt_ids['test']['input_ids_en'],mt_ids['test']['input_ids_vi'],vocab_en['<pad>'],vocab_vi['<pad>'])
test_dataloader = DataLoader(test_dataset,batch_size=16,collate_fn=test_dataset.collate_fn)

In [25]:
train_dataset = TextDataset(mt_ids['train']['input_ids_en'],mt_ids['train']['input_ids_vi'],vocab_en['<pad>'],vocab_vi['<pad>'])
train_dataloader = DataLoader(train_dataset,batch_size=8,collate_fn=train_dataset.collate_fn,shuffle=True)

In [67]:
wandb.init(project='my-simple-nmt')

In [68]:
def train(device,model,num_steps,train_dataloader,val_dataloader,optimizer,criterion,metric_fn,step=0,checkpoint_step=5000):
    outer_bar = tqdm(total=num_steps,desc='Training', position=0)
    outer_bar.n = step
    outer_bar.last_print_n = step 
    outer_bar.refresh()
    epoch = 1
    while True:
        inner_bar = tqdm(total=len(train_dataloader),desc=f'Epoch {epoch}',position=1)
        model.train()
        
        for X, y in train_dataloader:
            X,y = X.to(device).permute(1,0),y.to(device).permute(1,0)
            outputs = model(X,y)
            
            loss = criterion(outputs[1:].contiguous().view(-1,outputs.shape[-1]),y[1:].contiguous().view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            step += 1
            outer_bar.update(1)
            inner_bar.update(1)
            wandb.log({'trainning loss':loss.item()})
                        
            if step % checkpoint_step == 0:
                outer_bar.write(f'Training loss at step {step}: {loss.item():.6}')
                torch.save({"model":model.state_dict(),"optimizer":optimizer.state_dict(),"step":step},f'checkpoints/checkpoint-{step}.pth')            
                model.eval()
                val_loss = []
                with torch.no_grad():
                    for X, y in val_dataloader:
                        X,y = X.to(device).permute(1,0),y.to(device).permute(1,0)
                        outputs = model(X,y)

                        val_loss.append(criterion(outputs[1:].contiguous().view(-1,outputs.shape[-1]),y[1:].contiguous().view(-1)).item())
                
                model.train()
                outer_bar.write(f'Val loss at step {step}: {np.mean(val_loss):.6}')
                wandb.log({'val loss':np.mean(val_loss)})
            if step >= num_steps:
                return
                
        epoch += 1    

In [27]:
seed()

In [27]:
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [28]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
model = MTModel(vocab_en,vocab_vi).to(device)
print(f'The model has {count_params(model):,} trainable params')
optimizer = torch.optim.Adam(model.parameters(),0.0009)
criterion = nn.CrossEntropyLoss(ignore_index=vocab_vi[PAD])
num_steps = 200000
metric_fn = lambda y_hats,ys: [bleu(y_hat,y) for y_hat,y in zip(y_hats,ys)]

The model has 12,194,812 trainable params


In [72]:
train(device,model,num_steps,train_dataloader,test_dataloader,optimizer,criterion,metric_fn)

Training:   0%|          | 0/200000 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/12030 [00:00<?, ?it/s]

Training loss at step 5000: 4.67544
Val loss at step 5000: 4.4716
Training loss at step 10000: 4.29089
Val loss at step 10000: 4.29298


Epoch 2:   0%|          | 0/12030 [00:00<?, ?it/s]

Training loss at step 15000: 4.74001
Val loss at step 15000: 4.25563
Training loss at step 20000: 4.7032
Val loss at step 20000: 4.11833


Epoch 3:   0%|          | 0/12030 [00:00<?, ?it/s]

Training loss at step 25000: 4.27266
Val loss at step 25000: 4.12298
Training loss at step 30000: 4.25208
Val loss at step 30000: 4.08961
Training loss at step 35000: 4.24699
Val loss at step 35000: 4.08566


Epoch 4:   0%|          | 0/12030 [00:00<?, ?it/s]

Training loss at step 40000: 3.82618
Val loss at step 40000: 4.1183
Training loss at step 45000: 4.2499
Val loss at step 45000: 4.00975


Epoch 5:   0%|          | 0/12030 [00:00<?, ?it/s]

Training loss at step 50000: 4.71677
Val loss at step 50000: 4.08016
Training loss at step 55000: 3.66346
Val loss at step 55000: 4.14163
Training loss at step 60000: 5.24408
Val loss at step 60000: 4.14183


Epoch 6:   0%|          | 0/12030 [00:00<?, ?it/s]

Training loss at step 65000: 3.09483
Val loss at step 65000: 4.04568
Training loss at step 70000: 3.98008
Val loss at step 70000: 4.00273


KeyboardInterrupt: 

In [73]:
wandb.finish()

VBox(children=(Label(value='0.007 MB of 0.007 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
trainning loss,▇██▇▅▃▄▄▄▂▃▅▃▅▃▂▂▇▂▄▅▆▅▄▃▃▄▃▃▅▁▂▁▃▃▆▄▁▃▂
val loss,█▅▅▃▃▂▂▃▁▂▃▃▂▁

0,1
trainning loss,3.95042
val loss,4.00273


In [45]:
for i,(X,y) in enumerate(test_dataloader):
    print(model(X.to(device).permute(1,0),y.to(device).permute(1,0)).argmax(-1).permute(1,0))
    print(y)
    if i == 1:
        break

tensor([[   0, 1865, 4182,  795, 2865,   10, 4182, 2644, 3422,    3, 3949, 2163,
         4824, 2961, 4261, 4036, 3901, 3901, 1322, 4378, 4182, 3865, 1597,  303,
            2],
        [   0, 4182, 4770, 3413, 4287, 1591, 4458, 4824, 2961, 4182,   12,    2,
         4182,   12,    2,   12,    2,   12,   12,   12,   12,   12,    2, 4182,
           12]], device='cuda:1')
tensor([[   1, 1865, 4182,  795, 2865,   10, 4182, 2644, 3422,    3, 3949, 2163,
         4824, 2961, 4261, 2843, 4036, 3901, 1322, 4378, 4182, 3865, 1597,  303,
            2],
        [   1, 4182, 4770, 3413, 4287, 1591, 4458, 4824, 2961, 4182,   12,    2,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0]])
tensor([[   0, 1261, 4777,  894, 4182, 1931, 2640,   10, 4378,  372, 3830, 4182,
         3840,  623, 4285, 3167,  664, 4784,   12,    2,    2,   12,   12,    2,
            2,    2,    2,   12],
        [   0, 2827, 4383, 2575, 2680,  894, 2953,   10, 2544, 4182, 2329

In [29]:
checkpoint = torch.load('checkpoints/checkpoint-200000.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
step = checkpoint['step']
model.eval();

In [None]:
from ipywidgets import interact


@interact
def translate(text="You are so bad"):
    input_ids = [vocab_en[SOS]] +  [vocab_en.get(word,vocab_en[UNK]) for word in text.lower().split()]+ [vocab_en[EOS]]
    display(' '.join(model.infer(torch.LongTensor(input_ids).to(device))[0][1:-1]))