# Simplistic T5 model with no fancy tricks

The previous example of Chinese-English machine translation has the following problems: 

<ul>
    <li>Dataset is trash</li>
    <li>Includes too many tricks (scheduler, parameter freezing, callback, metrics) that I cannot handle</li>
</ul>

Now write a T5 Chinese-English translator with better data and no fancy trick. 

In [16]:
import pandas as pd
from tokenizers import SentencePieceBPETokenizer
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    T5Model, 
    T5ForConditionalGeneration, 
    AdamW,
)
import pytorch_lightning as pl
import time
from datetime import datetime
import textwrap

device = torch.device(
    'cuda:0' if torch.cuda.is_available() else 'cpu'
)
print(f'device = {device}')

device = cuda:0


## Load data

The entire data is too large to load directly into memory. For now, only load the first `nLine` lines.  

Learn to handle big data with PyTorch dataloader if needed. 

In [2]:
%%time
enFile = open('./en-zh/UNv1.0.en-zh.en', 'r', encoding = 'utf-8')
zhFile = open('./en-zh/UNv1.0.en-zh.zh', 'r', encoding = 'utf-8')

nLine = 10000

dataMatrix = []

for i in range(nLine): 
    zhLine = zhFile.readline().strip()
    enLine = enFile.readline().strip()
    dataMatrix.append([zhLine, enLine])
    
df_UN = pd.DataFrame(dataMatrix, columns = ['zh', 'en']).sample(frac=1).reset_index(drop=True) # Shuffle the data
df_UN

# Notice: The run time of appending rows in DataFrame is notoriously long

Wall time: 19 ms


Unnamed: 0,zh,en
0,10. 会议上人们反复强调，发展中国家在采用空间技术方面的问题不是技术本身的问题，因为技术是...,It was stressed repeatedly during the meeting ...
1,冰冻面制图和监测工作经常进行，重点是斯瓦巴德周围及冰层边缘优先区域。,Ice mapping and monitoring have been performed...
2,该《法案》的基本目的是把毛利族人的土地保留在与有关土地相关的传统族裔后代的手中。,The general theme of the Act was the retention...
3,使轨道实验室准备就绪，以供使用。,● Readying the orbital laboratory for use.
4,A．国际减灾十年的作用和,A. The role of the International Decade for Na...
...,...,...
9995,现阶段值得注意的是，对未来的分析受到超出仅仅对这一问题进行数学描述的诸多因素的影响，如：,It is worth noting at this stage that the anal...
9996,"然而,委员会认为关于种族歧视问题的教育可以作为关于普遍性歧视问题的教育的一部分,其中包括因其...","Nevertheless, the Committee believes that teac..."
9997,(b) 1987－1989年：在军方采取打击恐怖主义活动行动的情况下恐怖主义暴力活动增加，政...,(b) 1987-1989: Terrorist violence increased in...
9998,1987年5月21日,Sixth report 21 May 1987


## Tokenization and PyTorch `Dataset`

We first instantiate SentencePiece tokenizers and train them on our data. 

<b style="color:red;">Warning!</b> For some reason I can no longer find the API for `SentencePieceBPETokenizer`. Did huggingface deprecate the old version tokenizer? 

In [3]:
# Need to store all texts in file before training tokenizer
pathAllZh = './en-zh/allZh.txt'
pathAllEn = './en-zh/allEn.txt'

zhTextsUN = df_UN['zh'].tolist()
enTextsUN = df_UN['en'].tolist()

with open(pathAllZh, 'w', encoding = 'utf-8') as file:
    for line in zhTextsUN:
        file.write(line + '\n')
    file.close()
    
with open(pathAllEn, 'w', encoding = 'utf-8') as file: 
    for line in enTextsUN:
        file.write(line + '\n')
    file.close()

In [4]:
# Instantiate and train tokenizers 
zhTokenizer = SentencePieceBPETokenizer()
zhTokenizer.train([pathAllZh], vocab_size = 500000, special_tokens = ['<s>', '<pad>', '</s>', '<unk>', '<mask>'])

enTokenizer = SentencePieceBPETokenizer()
enTokenizer.train([pathAllEn], vocab_size = 500000, special_tokens = ['<s>', '<pad>', '</s>', '<unk>', '<mask>'])

For more details about tokenizer, see `Bo-Eng-Machine-Transation/warm_up_Chinese_English/01_practice_ch_en_tranlation.ipynb`. 

Now define PyTorch `DataLoader`. 

In [5]:
class MyDataset(Dataset): 
    def __init__(self, zhTexts, enTexts, zhTokenizer, enTokenizer, zhMaxLen, enMaxLen): 
        super().__init__()
        self.zhTexts = zhTexts 
        self.enTexts = enTexts
        self.zhTokenizer = zhTokenizer
        self.enTokenizer = enTokenizer 
        
        # Enable padding and truncation
        self.zhTokenizer.enable_padding(length = zhMaxLen)
        self.zhTokenizer.enable_truncation(max_length = zhMaxLen)
        self.enTokenizer.enable_padding(length = enMaxLen)
        self.enTokenizer.enable_truncation(max_length = enMaxLen)
        
    '''
    Return the size of dataset
    '''
    def __len__(self):
        return len(self.zhTexts)
    
    '''
    -- The routine for querying one data entry 
    -- The index of must be specified as an argument
    -- Return a dictionary 
    '''
    def __getitem__(self, idx): 
        # Apply tokenizer 
        zhOutputs = self.zhTokenizer.encode(self.zhTexts[idx])
        enOutputs = self.enTokenizer.encode(self.enTexts[idx])
        
        # Get numerical tokens
        zhEncoding = zhOutputs.ids
        enEncoding = enOutputs.ids
        
        # Get attention mask 
        zhMask = zhOutputs.attention_mask
        enMask = enOutputs.attention_mask
        
        return {
            'source_ids': torch.tensor(zhEncoding), 
            'source_mask': torch.tensor(zhMask), 
            'target_ids': torch.tensor(enEncoding), 
            'target_mask': torch.tensor(enMask)
        }

## Define model class

Use Pytorch-lighning

In [6]:
class T5FineTuner(pl.LightningModule): 
    ''' Part 1: Define the architecture of model in init '''
    def __init__(self, hparams):
        super(T5FineTuner, self).__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(
            hparams['pretrainedModelName'], 
            return_dict = True    # I set return_dict true so that outputs  are presented as dictionaries
        )
        self.zhTokenizer = hparams['zhTokenizer']
        self.enTokenizer = hparams['enTokenizer']
        self.hparams = hparams
        
        
    ''' Part 2: Define the forward propagation '''
    def forward(self, input_ids, attention_mask = None, decoder_input_ids = None, decoder_attention_mask = None, labels = None):  
        return self.model(
            input_ids, 
            attention_mask = attention_mask, 
            decoder_input_ids = decoder_input_ids, 
            decoder_attention_mask = decoder_attention_mask, 
            labels = labels
        )
    
    
    ''' Part 3: Configure optimizer and scheduler '''
    def configure_optimizers(self): 
        optimizer = AdamW(self.parameters(), lr = self.hparams['learning_rate'])
        return optimizer
    
    
    ''' Part 4.1: Training logic '''
    def training_step(self, batch, batch_idx):         
        loss = self._step(batch)
        self.log('train_loss', loss)
        return loss 
    
    
    def _step(self, batch): 
        labels = batch['target_ids'] 
        # labels[labels[:, ] == 0] = -100    # Change the pad id from 0 to -100, but I do not know why the example chooses to do so. I will comment it out for now
        
        outputs = self(
            input_ids = batch['source_ids'], 
            attention_mask = batch['source_mask'], 
            labels = labels, 
            decoder_attention_mask = batch['target_mask']
        )
        
        return outputs.loss

    
    ''' Part 4.2: Validation logic '''
    def validation_step(self, batch, batch_idx):  
        loss = self._step(batch)
        self.log('val_loss', loss)
        
        
    ''' Part 4.3: Test logic '''
    def test_step(self, batch, batch_idx): 
        loss = self._step(batch)
        self.log('test_loss', loss)
    
    
    ''' Part 5: Data loaders '''
    def _get_dataloader(self, start_idx, end_idx): 
        dataset = MyDataset(
            zhTexts = zhTextsUN[start_idx:end_idx], 
            enTexts = enTextsUN[start_idx:end_idx], 
            zhTokenizer = self.hparams['zhTokenizer'], 
            enTokenizer = self.hparams['enTokenizer'], 
            zhMaxLen = self.hparams['max_input_len'], 
            enMaxLen = self.hparams['max_output_len']
        )
        
        return DataLoader(dataset, batch_size = hparams['batch_size'])
    
    
    def train_dataloader(self): 
        start_idx = 0
        end_idx = int(self.hparams['train_percentage'] * len(zhTextsUN))
        return self._get_dataloader(start_idx, end_idx)
        
    
    def val_dataloader(self): 
        start_idx = int(self.hparams['train_percentage'] * len(zhTextsUN))
        end_idx = int((self.hparams['train_percentage'] + self.hparams['val_percentage']) * len(zhTextsUN))
        return self._get_dataloader(start_idx, end_idx)
    
    
    def test_dataloader(self): 
        start_idx = int((self.hparams['train_percentage'] + self.hparams['val_percentage']) * len(zhTextsUN))
        end_idx = len(zhTextsUN)
        return self._get_dataloader(start_idx, end_idx)

In [7]:
hparams = {
    'zhTokenizer': zhTokenizer,
    'enTokenizer': enTokenizer,
    'pretrainedModelName': 't5-small', 
    'train_percentage': 0.85, 
    'val_percentage': 0.13, 
    'learning_rate': 3e-4, 
    'max_input_len': 100, 
    'max_output_len': 100, 
    'batch_size': 8, 
    'num_train_epochs': 2, 
    'num_gpu': 1
}

## Training and testing

In [8]:
torch.cuda.empty_cache()

train_params = dict(
    gpus = hparams['num_gpu'], 
    max_epochs = hparams['num_train_epochs'], 
    progress_bar_refresh_rate = 20, 
)

model = T5FineTuner(hparams)

trainer = pl.Trainer(**train_params)

trainer.fit(model)

# Save model for later use
now = datetime.now()
trainer.save_checkpoint('t5simple_' + now.strftime("%Y-%d-%m-%Y--%H=%M=%S") + '.ckpt')

trainer.test()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 60 M  


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [12]:
# Load a previously saved model

torch.cuda.empty_cache()

modelLoaded = T5FineTuner.load_from_checkpoint(checkpoint_path='t5simple_2020-06-12-2020--14=17=05.ckpt').to(device)

In [25]:
# Testting without help of `pl.LightningModule`
start_idx = int((hparams['train_percentage'] + hparams['val_percentage']) * len(zhTextsUN))
end_idx = len(zhTextsUN)

testset = MyDataset(
    zhTexts = zhTextsUN[start_idx:end_idx], 
    enTexts = enTextsUN[start_idx:end_idx], 
    zhTokenizer = hparams['zhTokenizer'], 
    enTokenizer = hparams['enTokenizer'], 
    zhMaxLen = hparams['max_input_len'], 
    enMaxLen = hparams['max_output_len']
)

test_dataloader = DataLoader(testset, batch_size = hparams['batch_size'])
testit = iter(test_dataloader)

# Take one batch from testset 
batch = next(testit)

# Generate target ids
outs = model.model.generate(
    batch['source_ids'].cuda(), 
    attention_mask = batch['source_mask'].cuda(), 
    use_cache = True, 
    decoder_attention_mask = batch['target_mask'].cuda(), 
    max_length = hparams['max_output_len'], 
    num_beams = 2, 
    repetition_penalty = 2.5, 
    length_penalty = 1.0, 
    early_stopping = True
)

pred_texts = [enTokenizer.decode(ids) for ids in outs.tolist()]
source_texts = [zhTokenizer.decode(ids) for ids in batch['source_ids'].tolist()]
target_texts = [enTokenizer.decode(ids) for ids in batch['target_ids'].tolist()]

for i in range(len(pred_texts)): 
    lines = textwrap.wrap("Chinese Text:\n%s\n" % source_texts[i], width=100)
    print("\n".join(lines))
    print("\nActual translation: %s" % target_texts[i])
    print("\nPredicted translation: %s" % pred_texts[i])
    print('=' * 50 + '\n')

Chinese Text: 116. ETS-VI是一颗2级卫星,装有一台二元推进远地点发动机,并新增了以下装置:一台用于控制南北轨道的离子发动机;高精度态控制系统;轻质结构体;轻质太阳电;卫星运载舱
部分有防高温和控制温度的系统,以便确保运转良好。

Actual translation: ETS-VI is a satellite in the 2-tonne class satellite with a bipropellant apogee engine and the following additional features: an ion engine for controlling the north-south orbit; a high-precision attitude control system; a light structural body; a light solar battery paddle; and a system of high heat prevention and heat control in the satellite bus part to ensure an excellent performance.

Predicted translation: The Committee of the State party in the use of the% with the, to be a that space. was and the Space for are on its by the from of the+ as is the protection of the United Nations an rights or it development at the report of theate were theS been),; have hadH handicraf provided6 and theU members of the situation of the satellite of the national women. data
 In the( of the indigenous$" C information3 per cent of the

Chinese Text: 计额比参与

In [26]:
# %tensorboard --logdir lightning_logs/