In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.insert(0,'../')

In [101]:
from mllib.new_bert import *

# Bert Example run

In [102]:
import string
import random
from sklearn.model_selection import train_test_split


def random_examples(n_examples, n_largest):
    """Yield random (source, target) string pairs for the reversal task.

    Each source is a string of lowercase ASCII letters whose length is drawn
    uniformly from ``1..n_largest`` (inclusive). The target is ``':'``
    followed by the source reversed.

    Args:
        n_examples: number of pairs to generate.
        n_largest: largest allowed source length (inclusive).

    Yields:
        tuple[str, str]: ``(x, y)`` with ``y == ':' + x[::-1]``.
    """
    letters = string.ascii_lowercase
    for _ in range(n_examples):
        length = random.choice(range(1, n_largest + 1))
        # Separate throwaway name so the inner generator no longer shadows the
        # outer loop variable.
        x = ''.join(random.choice(letters) for _ in range(length))
        # The leading ':' marks the start of the target sequence.
        yield x, ':' + x[::-1]
        
# Generate 10,000 (source, target) pairs with sources of at most 10 characters,
# then hold out 33% of them as the test split (the split itself is seeded).
data = [[src, trg] for src, trg in random_examples(10000, 10)]
raw_data = dict(zip(('train', 'test'),
                    train_test_split(data, test_size=0.33, random_state=42)))

In [161]:
from torchtext.experimental.datasets.translation import *
from torchtext.data.datasets_utils import _wrap_datasets
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split


PAD_IDX = 1

def generate_batch(data_batch):
    """Collate ``(src, trg)`` tensor pairs into two padded batch tensors.

    Args:
        data_batch: iterable of ``(src, trg)`` pairs of sequence tensors.

    Returns:
        tuple: ``(src_batch, trg_batch)``. Each is the ``pad_sequence`` stack
        of the corresponding sequences, filled with ``PAD_IDX`` up to the
        longest sequence in the batch. ``batch_first`` is left at its default,
        so the sequence dimension comes first.
    """
    # Materialise once so the batch is iterated a single time.
    pairs = [(src, trg) for src, trg in data_batch]

    src_batch = pad_sequence([src for src, _ in pairs], padding_value=PAD_IDX)
    trg_batch = pad_sequence([trg for _, trg in pairs], padding_value=PAD_IDX)
    return src_batch, trg_batch

def ReversedString(data, tokenizer, split_=('train','test')):
    """Build one ``TranslationDataset`` per requested split of the reversal data.

    Args:
        data: mapping from split name to a list of ``[source, target]`` pairs.
            It must contain ``'train'``, which is the only split used to build
            the vocabularies.
        tokenizer: ``(src_tokenizer, trg_tokenizer)`` pair of callables.
        split_: names of the splits to materialise, in output order.

    Returns:
        Whatever ``_wrap_datasets`` produces for the tuple of datasets and
        ``split_``.
    """
    src_tokenizer, trg_tokenizer = tokenizer
    src_tokenize = sequential_transforms(src_tokenizer)
    trg_tokenize = sequential_transforms(trg_tokenizer)

    # Vocabularies come from the training split only.
    # index=0 / index=1 presumably select the source / target column of each pair.
    src_vocab = build_vocab(data['train'], src_tokenize, index=0)
    trg_vocab = build_vocab(data['train'], trg_tokenize, index=1)

    def to_indices(tokenize):
        # Both sides are numericalised with the *target* vocabulary, so source
        # and target share one index space (src_vocab is only stored on the dataset).
        return sequential_transforms(tokenize,
                                     vocab_func(trg_vocab),
                                     totensor(dtype=torch.long))

    datasets = tuple(
        TranslationDataset(data[key],
                           (src_vocab, trg_vocab),
                           (to_indices(src_tokenize), to_indices(trg_tokenize)))
        for key in split_
    )
    return _wrap_datasets(datasets, split_)

# The trick here is to tie the vocabulary between src and trg (both sides are indexed with the target vocab) to make learning faster

In [162]:
# Character-level tokenisation for both sides: list("abc") -> ['a', 'b', 'c'].
# Alternative kept for reference (split tokenizer):
#tokenizer = get_tokenizer(tokenizer=None), get_tokenizer(tokenizer=None) # split tokenizer
tokenizer = (list, list)

ds = ReversedString(data=raw_data, tokenizer=tokenizer, split_=('train', 'test'))



100%|██████████| 6700/6700 [00:00<00:00, 428963.64lines/s]


100%|██████████| 6700/6700 [00:00<00:00, 362815.01lines/s]


In [163]:
# Inspect the first example of the first returned split: a (source, target) pair of index tensors.
ds[0][0]

(tensor([20,  5, 21, 10]), tensor([ 2, 10, 21,  5, 20]))

In [182]:
import torch
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.optim.sgd import SGD

from warmup_scheduler import GradualWarmupScheduler

class LitTransformer(pl.LightningModule):
    """Lightning module that trains an encoder-decoder model to reverse strings.

    Data generation, dataset construction, model creation, the optimisation
    setup and the train/val/test dataloaders all live on this module.

    Args:
        learning_rate: learning rate handed to Adam.
        batch_size: batch size used by every dataloader.
        num_workers: worker count used by every dataloader.
    """

    def __init__(self, learning_rate=0.001, batch_size=4, num_workers=0):
        super().__init__()
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Ignore exactly the index the collate function pads with, instead of
        # repeating the literal.
        self.loss_crit = LabelSmoothingLoss2(ignore_value=PAD_IDX, label_smoothing=0.1)
        self.save_hyperparameters()

    def make_src_mask(self, src):
        """Return a boolean mask that is False at padding positions.

        For ``src`` of shape (N, src_len) the result is (N, 1, 1, src_len).
        """
        src_mask = (src != PAD_IDX).unsqueeze(1).unsqueeze(2)
        return src_mask

    def make_trg_mask(self, trg):
        """Return a lower-triangular (causal) mask of shape (N, 1, trg_len, trg_len)."""
        N, trg_len = trg.shape
        # Build the mask on the same device as the input so it can be combined
        # with tensors that live on an accelerator.
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        return causal.expand(N, 1, trg_len, trg_len)

    def forward(self, src, trg):
        """Run the wrapped model on ``src``/``trg`` with freshly built masks."""
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        return self.model(src, src_mask, trg, trg_mask)

    def prepare_data(self):
        """Generate the random string pairs and split them into train/test."""
        data = [[x, y] for x, y in random_examples(10000, 10)]
        self.raw_data = {}
        self.raw_data['train'], self.raw_data['test'] = train_test_split(
            data, test_size=0.33, random_state=42)

    def setup(self, stage=None):
        """Build datasets, vocabularies, the model and the loss computation.

        The work is done only once: Lightning may call ``setup`` again for
        another stage, and rebuilding the model would discard trained weights.
        """
        if getattr(self, 'model', None) is not None:
            return

        # Use the data generated by prepare_data on this module rather than a
        # notebook-level global. Generate it here if prepare_data has not run.
        splits = getattr(self, 'raw_data', None)
        if splits is None:
            self.prepare_data()
            splits = self.raw_data

        tokenizer = (list, list)
        full_train, reversed_test = ReversedString(data=splits, tokenizer=tokenizer)

        # save the vocab
        self.src_vocab, self.trg_vocab = full_train.get_vocab()

        # The model is sized on the target vocab for both sides; src_vocab is
        # deliberately not used because source and target share one index space.
        self.model = make_model(len(self.trg_vocab), len(self.trg_vocab),
                                N=4, d_model=128, d_ff=128, h=4, dropout=0.2)

        self.criterion = SimpleLossCompute(self.model.generator, self.loss_crit, None)

        # 80/20 train/val split of the training data.
        n = len(full_train)
        p = int(0.8 * n)
        train_split, reversed_val = random_split(full_train, [p, n - p])

        # train_ds stays the full dataset object because callbacks read its
        # .transforms / .get_vocab(). The train loader iterates train_split,
        # so validation examples are not trained on.
        self.train_ds = full_train
        self.train_split = train_split
        self.test_ds = reversed_test
        self.val_ds = reversed_val

    def configure_optimizers(self):
        """Return Adam plus a warm-up scheduler that hands over to ``StepLR``."""
        optim = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # scheduler_warmup is chained with scheduler_steplr
        scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
        scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5,
                                                  after_scheduler=scheduler_steplr)

        return [optim], [scheduler_warmup]

    def _compute_loss(self, batch):
        """Shared train/val forward pass returning the loss for one batch."""
        src, trg = batch
        # The collate function yields (seq_len, batch); the model is fed (batch, seq_len).
        src = src.permute(1, 0)
        trg = trg.permute(1, 0)

        # Teacher forcing: feed the target without its last token and score
        # against the target shifted left by one.
        out = self(src, trg[:, :-1])
        return self.criterion(out, trg[:, 1:])

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss under ``'loss'``."""
        loss = self._compute_loss(batch)
        self.log('loss', loss)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss under ``'val_loss'`` only."""
        loss = self._compute_loss(batch)
        self.log('val_loss', loss)
        return {'val_loss': loss}

    def train_dataloader(self):
        """Dataloader over the training portion of the train/val split."""
        return DataLoader(self.train_split, self.batch_size,
                          collate_fn=generate_batch, num_workers=self.num_workers)

    def val_dataloader(self):
        """Dataloader over the validation portion of the train/val split."""
        return DataLoader(self.val_ds, self.batch_size,
                          collate_fn=generate_batch, num_workers=self.num_workers)

    def test_dataloader(self):
        """Dataloader over the held-out test split."""
        return DataLoader(self.test_ds, self.batch_size,
                          collate_fn=generate_batch, num_workers=self.num_workers)
        

In [183]:
class LogHistogramCallback(pl.Callback):
    """Periodically log a histogram of every parameter's gradient.

    Args:
        patience: log whenever ``trainer.global_step`` is a multiple of this value.
    """

    def __init__(self, patience=25):
        super().__init__()
        self.patience = patience

    def on_after_backward(self, trainer, pl_module):
        """Write gradient histograms to the trainer's logger after backward."""
        # Nothing to write to when the trainer runs without a logger.
        if trainer.logger is None:
            return
        if trainer.global_step % self.patience != 0:
            return

        writer = trainer.logger.experiment
        for name, param in pl_module.named_parameters():
            # Parameters that took no part in the backward pass have no gradient.
            if param.grad is None:
                continue
            writer.add_histogram(tag=name, values=param.grad,
                                 global_step=trainer.global_step)

In [184]:
class ModelTestCallback(pl.Callback):
    """Greedy-decode one fixed test string at the end of every training epoch.

    Args:
        max_len: maximum number of tokens to generate.
        test: the source string to decode.
    """

    def __init__(self, max_len=10, test ='puneet'):
        super().__init__()
        self.max_len = max_len
        self.test_sentence = test

    def on_fit_start(self, trainer, pl_module):
        """Cache the transforms and vocabularies of the module's training dataset."""
        self.transforms = pl_module.train_ds.transforms
        self.vocabs = pl_module.train_ds.get_vocab()

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        """Decode ``self.test_sentence`` greedily, print and log the result.

        Returns:
            list[str]: the generated target tokens, without the ':' start token.
        """
        self.trg_vocab = self.vocabs[1]
        # Use the module's own device rather than a global `device` name.
        device = pl_module.device

        # Apply the source transform to the text and add a batch dimension.
        src_tensor = self.transforms[0](self.test_sentence).unsqueeze(0).to(device)
        src_mask = pl_module.make_src_mask(src_tensor)       # (1, 1, 1, src_len)

        # Seed decoding with the ':' start token, numericalised by the target
        # transform instead of a hard-coded index.
        trg_indices = self.transforms[1](":").tolist()

        # Decode in eval mode so dropout is off, then restore the previous mode.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                enc_src = pl_module.model.encode(src_tensor, src_mask)

                for _ in range(self.max_len):
                    trg_tensor = torch.LongTensor(trg_indices).unsqueeze(0).to(device)
                    trg_mask = pl_module.make_trg_mask(trg_tensor).to(device)

                    output = pl_module.model.decode(enc_src, src_mask, trg_tensor, trg_mask)
                    output = pl_module.model.generator(output)

                    # Greedy choice: most likely token at the last position.
                    pred_token = output.argmax(2)[:,-1].item()
                    trg_indices.append(pred_token)

                    if pred_token == PAD_IDX:
                        break
        finally:
            pl_module.train(was_training)

        trg_tokens = [self.trg_vocab.itos[i] for i in trg_indices]
        decode_string = ''.join(trg_tokens)
        print('decoded input : {} -> output: {} '.format(self.test_sentence, decode_string))
        if trainer.logger is not None:
            trainer.logger.experiment.add_text('decodes', decode_string, trainer.current_epoch)

        return trg_tokens[1:]
        
        

# Run Training

In [185]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# TensorBoard run directory: tb_logs/bert
logger = TensorBoardLogger('tb_logs', name='bert')
model = LitTransformer()

# NOTE(review): auto_lr_find presumably only takes effect through trainer.tune(),
# which this notebook never calls. Confirm it is still wanted.
trainer = Trainer(
    fast_dev_run=False,
    progress_bar_refresh_rate=5,
    max_epochs=10,
    enable_pl_optimizer=False,
    callbacks=[
        ModelTestCallback(test='puneet'),   # decode a sample string each epoch
        LogHistogramCallback(),             # gradient histograms
        ModelCheckpoint(dirpath='.checkpoints/', monitor='val_loss'),
    ],
    logger=logger,
    auto_lr_find=True,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [186]:
# Train `model` with the trainer (callbacks, logger, max_epochs) configured above.
trainer.fit(model)


100%|██████████| 6700/6700 [00:00<00:00, 385749.30lines/s]

100%|██████████| 6700/6700 [00:00<00:00, 371992.97lines/s]

  | Name      | Type                | Params
--------------------------------------------------
0 | loss_crit | LabelSmoothingLoss2 | 0     
1 | model     | EncoderDecoder      | 1.1 M 
--------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.294     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

In [178]:
# Run the decoding callback by hand against the trainer's current model to spot-check one decode.
# NOTE(review): this needs a fitted `trainer`; the cell's execution count (178) predates the
# cell that builds the trainer (185), so re-run it after training on a fresh kernel.
mt = ModelTestCallback()
mt.on_fit_start(trainer, trainer.model)
mt.on_train_epoch_end(trainer,trainer.model, outputs=None)

decoded input : puneet -> output: :teenupuppp 


['t', 'e', 'e', 'n', 'u', 'p', 'u', 'p', 'p', 'p']