In [None]:
# Download the pytorch/xla environment-setup helper and run it to install the
# nightly torch_xla build plus the libomp5 / libopenblas-dev apt packages.
# NOTE(review): `--version nightly` is unpinned, so re-running this cell may pull
# a different torch/torch_xla build; consider pinning a release for reproducibility.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

import pandas as pd
from sklearn import metrics
from sklearn.model_selection import train_test_split
import transformers
from transformers import AdamW, T5Tokenizer, T5ForConditionalGeneration

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

import warnings
warnings.filterwarnings("ignore")

In [None]:
class config:
    """Hyper-parameters, file paths and the shared tokenizer used throughout the notebook."""

    # Token budgets: linearised table (encoder input) and target sentence (decoder output).
    MAX_LEN_I = 448
    MAX_LEN_O = 224

    # Per-process batch sizes (each XLA replica builds its own DataLoader) and epoch count.
    TRAIN_BATCH_SIZE = 16
    VALID_BATCH_SIZE = 8
    EPOCHS = 13

    # Output checkpoint and input CSV locations.
    MODEL_PATH = "T5-base-TPU.pth"
    TRAINING_FILE = '../input/table-to-text-generation-dataset-google-totto/totto_data/tablesWithTag.csv'

    # NOTE(review): T5 uses a case-sensitive SentencePiece model; confirm whether
    # `do_lower_case` has any effect with the installed transformers version.
    TOKENIZER = transformers.T5Tokenizer.from_pretrained(
        't5-base',
        do_lower_case=True,
    )

In [None]:
# Register the table-structure markers that tottoDataset substitutes for the raw
# ToTTo tags. T5's tokenizer already uses '<pad>' for padding and '</s>' as its
# end-of-sequence token (the marker tottoDataset appends to inputs and targets),
# and T5 has no BOS token, so pad/bos/eos are deliberately NOT overridden here:
# replacing eos_token with a brand-new '<eos>' made encode_plus append an extra,
# untrained '<eos>' after the manual '</s>' on every sequence.
special_tokens_dict = {
    'additional_special_tokens': [
        '<PAGESTART>', '<PAGEEND>',
        '<SECTIONSTART>', '<SECTIONEND>',
        '<TABLESTART>', '<TABLEEND>',
        '<CELLSTART>', '<CELLEND>',
        '<COLHEADERSTART>', '<COLHEADEREND>',
        '<ROWHEADERSTART>', '<ROWHEADEREND>',
    ]
}

# Number of tokens actually added to the vocabulary (kept for inspection).
num_added_toks = config.TOKENIZER.add_special_tokens(special_tokens_dict)

In [None]:
# Hold out 10% of the rows for validation. A fixed seed keeps the split (and so
# the checkpoint selected on validation loss) reproducible across kernel restarts.
SPLIT_SEED = 42
df = pd.read_csv(config.TRAINING_FILE)
assert {'sentence', 'table'}.issubset(df.columns), "expected 'sentence' and 'table' columns"
train_df, val_df = train_test_split(df, test_size=0.1, random_state=SPLIT_SEED)
# tottoDataset addresses rows as 0..n-1, so replace the shuffled index with a fresh RangeIndex.
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

In [None]:
class tottoDataset(Dataset):
    """Map-style dataset that turns ToTTo rows into T5 encoder/decoder tensors.

    The encoder input is ``df['table']`` with ``'</s>'`` appended and its
    HTML-like tags rewritten into the upper-case special tokens registered on
    the tokenizer. The target is ``df['sentence']`` with ``'</s>'`` appended.

    Args:
        df: DataFrame with ``'table'`` and ``'sentence'`` string columns. Rows
            are addressed by position, so the index need not be a RangeIndex.
        tokenizer: object exposing ``encode_plus`` and ``pad_token_id``.
        max_len_input: padded/truncated encoder length. ``None`` (the default)
            means ``config.MAX_LEN_I``.
        max_len_output: padded/truncated target length. ``None`` (the default)
            means ``config.MAX_LEN_O``.
    """

    # (raw ToTTo tag, special token) pairs, applied in order to the table text.
    _TAG_REPLACEMENTS = (
        ("<page_title>", "<PAGESTART>"), ("</page_title>", "<PAGEEND>"),
        ("<section_title>", "<SECTIONSTART>"), ("</section_title>", "<SECTIONEND>"),
        ("<table>", "<TABLESTART>"), ("</table>", "<TABLEEND>"),
        ("<cell>", "<CELLSTART>"), ("</cell>", "<CELLEND>"),
        ("<col_header>", "<COLHEADERSTART>"), ("</col_header>", "<COLHEADEREND>"),
        ("<row_header>", "<ROWHEADERSTART>"), ("</row_header>", "<ROWHEADEREND>"),
    )

    def __init__(self, df, tokenizer, max_len_input=None, max_len_output=None):
        # Plain lists give positional indexing; indexing a Series with [idx] is
        # label-based and fails when the frame was not re-indexed.
        self.sentence = df['sentence'].tolist()
        self.table = df['table'].tolist()
        self.tokenizer = tokenizer
        self.max_len_input = max_len_input
        self.max_len_output = max_len_output

    def __len__(self):
        return len(self.sentence)

    def _linearise_table(self, table):
        """Append the end-of-sequence marker and swap the raw tags for special tokens."""
        text = table + '</s>'
        for tag, token in self._TAG_REPLACEMENTS:
            text = text.replace(tag, token)
        return text

    def __getitem__(self, idx):
        # Resolve lengths lazily so the defaults track the current `config` values.
        max_len_input = config.MAX_LEN_I if self.max_len_input is None else self.max_len_input
        max_len_output = config.MAX_LEN_O if self.max_len_output is None else self.max_len_output

        inp = self._linearise_table(self.table[idx])
        out = self.sentence[idx] + '</s>'
        inp_tokens = self.tokenizer.encode_plus(inp, padding="max_length", max_length=max_len_input, truncation=True)
        out_tokens = self.tokenizer.encode_plus(out, padding="max_length", max_length=max_len_output, truncation=True)

        pad_id = self.tokenizer.pad_token_id
        # Padding positions become -100, PyTorch cross-entropy's default ignore_index,
        # so they do not contribute to the loss.
        labels = [-100 if tok == pad_id else tok for tok in out_tokens.input_ids]

        return {
            "table_text": inp,
            "sentence": out,
            "input_ids": torch.tensor(inp_tokens.input_ids, dtype=torch.long),
            "input_attention_mask": torch.tensor(inp_tokens.attention_mask, dtype=torch.long),
            # NOTE: these are the *unshifted* target ids. train_fn/eval_fn pass only
            # `labels` and let T5 derive shifted decoder inputs, so do not feed this
            # key together with `labels`.
            "decoder_input_ids": torch.tensor(out_tokens.input_ids, dtype=torch.long),
            "decoder_attention_mask": torch.tensor(out_tokens.attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

In [None]:
# `!export ...` runs in a throw-away subshell, so the variable never reached this
# Python process. Setting it on os.environ instead makes it visible here and in the
# worker processes forked later by xmp.spawn. It must be in place before any XLA
# tensor is created. NOTE(review): this now actually enables bf16 — confirm that is wanted.
os.environ["XLA_USE_BF16"] = "1"

In [None]:
def train_fn(dataloader, model, optimizer, device, scheduler, epoch, num_epoch, num_steps):
    """Run one training epoch on the current XLA replica.

    Args:
        dataloader: per-device iterator of dicts produced by ``tottoDataset``
            (reads the ``input_ids``, ``input_attention_mask`` and ``labels`` keys).
        model: seq2seq model whose forward returns an object with ``.loss``
            when ``labels`` are supplied.
        optimizer: optimizer stepped through ``xm.optimizer_step``.
        device: XLA device the batch tensors are moved to.
        scheduler: optional LR scheduler stepped once per batch; ``None`` skips it.
        epoch: zero-based epoch index (logging only).
        num_epoch: total number of epochs (logging only).
        num_steps: batch count shown in the progress log.
    """
    model.train()
    for step, batch in enumerate(dataloader):
        input_ids = batch["input_ids"].to(device)
        # Without the mask, the encoder attends over all padded positions.
        attention_mask = batch["input_attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Clear the previous step's gradients; otherwise they accumulate across
        # every batch of the run.
        optimizer.zero_grad()

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()

        xm.optimizer_step(optimizer)

        if scheduler is not None:
            scheduler.step()

        if step % 800 == 0:
            # .item() forces a device->host sync, so it is only done on logging steps;
            # master_print keeps the 8 replicas from printing the same line.
            xm.master_print(f"Epoch: {epoch+1}/{num_epoch} Batch {step+1}/{num_steps} Loss:{loss.item()} Time Taken:{time.asctime()}\n")

In [None]:
def eval_fn(dataloader, model, device, size):
    """Return the mean validation loss per batch, aggregated across all XLA replicas.

    Args:
        dataloader: per-device iterator of dicts produced by ``tottoDataset``
            (reads ``input_ids``, ``input_attention_mask`` and ``labels``).
        model: seq2seq model whose forward returns an object with ``.loss``.
        device: XLA device the batch tensors are moved to.
        size: kept for call compatibility and no longer used. The old code divided
            a sum of per-batch *mean* losses by the full dataset size, although
            each replica only sees its own shard.

    Returns:
        float: total of per-batch losses over all replicas divided by the total
        number of batches. The value is identical on every replica, so callers can
        branch on it safely around collective operations such as ``xm.save``.
        ``inf`` if no batch was seen.
    """
    model.eval()
    loss_sum = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["input_attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            # Accumulate as a host-side Python float rather than an on-device tensor,
            # which may be bfloat16 when XLA_USE_BF16 is set.
            loss_sum += outputs.loss.item()
            num_batches += 1

    # Combine (sum, count) from every replica so all processes report and return
    # the same global value; weighting by count also handles uneven shards.
    total_sum, total_batches = xm.mesh_reduce(
        "eval_fn_val_loss",
        (loss_sum, num_batches),
        lambda parts: (sum(p[0] for p in parts), sum(p[1] for p in parts)),
    )
    mean_loss = total_sum / total_batches if total_batches else float("inf")
    xm.master_print(f"Val Loss:{mean_loss}")
    return mean_loss

In [None]:
# Pretrained T5, created once in the notebook process; the forked workers reuse it.
model = T5ForConditionalGeneration.from_pretrained('t5-base', return_dict=True)
# Match the vocabulary to the tokenizer after the special tokens were added.
# Resize through the top-level model, not model.encoder / model.decoder: the
# top-level call resizes the shared embedding used by both stacks and keeps the
# LM head tied to it. Resizing each stack separately gave the encoder and decoder
# two independent copies and left `model.shared` / `lm_head` stale.
model.resize_token_embeddings(len(config.TOKENIZER))

In [None]:
def _make_distributed_loader(dataset, batch_size, shuffle, drop_last):
    """Build a DataLoader that iterates only this replica's shard of ``dataset``."""
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=shuffle,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=drop_last,
        num_workers=4,
    )


def _optimizer_param_groups(model, weight_decay=0.001):
    """Split ``model`` parameters into a weight-decayed group and a no-decay group."""
    # T5 names its norm weights `...layer_norm.weight` (e.g. `final_layer_norm.weight`),
    # which the BERT-style `LayerNorm` patterns never matched; both are kept.
    no_decay = ("bias", "LayerNorm.bias", "LayerNorm.weight", "layer_norm.weight")
    named_params = list(model.named_parameters())
    return [
        {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
        {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]


def _run():
    """Train and validate the notebook-level ``model`` on this XLA replica.

    Called once per process by ``_mp_fn``. Reads the notebook-level ``model``,
    ``train_df``, ``val_df`` and ``config``. After each epoch it saves the
    state dict to ``config.MODEL_PATH`` whenever validation loss improves.
    """
    train_dataset = tottoDataset(train_df, config.TOKENIZER)
    valid_dataset = tottoDataset(val_df, config.TOKENIZER)

    train_data_loader = _make_distributed_loader(
        train_dataset, config.TRAIN_BATCH_SIZE, shuffle=True, drop_last=True)
    valid_data_loader = _make_distributed_loader(
        valid_dataset, config.VALID_BATCH_SIZE, shuffle=False, drop_last=False)

    device = xm.xla_device()
    model.to(device)

    # Learning rate scaled linearly with the number of replicas.
    lr = 0.4 * 1e-5 * xm.xrt_world_size()
    optimizer = AdamW(_optimizer_param_groups(model), lr=lr)

    # Batches this replica processes per epoch (what train_fn's progress log counts).
    steps_per_epoch = len(train_data_loader)

    best_loss = float("inf")

    for epoch in range(config.EPOCHS):
        # Without set_epoch, DistributedSampler repeats the same shuffle every epoch.
        train_data_loader.sampler.set_epoch(epoch)
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_fn(para_loader.per_device_loader(device), model, optimizer, device,
                 scheduler=None, epoch=epoch, num_epoch=config.EPOCHS, num_steps=steps_per_epoch)

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        loss = eval_fn(para_loader.per_device_loader(device), model, device, len(valid_dataset))

        # xm.save performs a cross-replica rendezvous, so every process must take
        # the same branch below. Average the validation loss across replicas
        # first; this is a no-op if eval_fn already returns a global value.
        loss = xm.mesh_reduce("run_val_loss", loss, lambda values: sum(values) / len(values))
        if loss < best_loss:
            xm.save(model.state_dict(), config.MODEL_PATH)
            best_loss = loss

In [None]:
def _mp_fn(rank, flags):
    """Per-process entry point handed to ``xmp.spawn``; ``rank`` and ``flags`` go unused."""
    # Make torch.FloatTensor the default tensor type in this worker, then train.
    torch.set_default_tensor_type('torch.FloatTensor')
    _run()


# start_method='fork' lets each of the 8 workers reuse the notebook-level
# objects (model, dataframes, config) created in earlier cells.
FLAGS = {}
xmp.spawn(
    _mp_fn,
    args=(FLAGS,),
    nprocs=8,
    start_method='fork',
)