In [None]:
import pandas as pd

In [None]:
# Load the tagged-table CSV, keep the first 45000 rows, then shuffle them.
df=pd.read_csv("../input/table-to-text-generation-dataset-google-totto/totto_data/tablesWithTag.csv")
df=df[:45000]
# A fixed random_state makes the shuffle (and therefore the positional
# train/validation split taken from `df` later on) identical across kernel restarts.
# reset_index() keeps the pre-shuffle row position as an extra column.
df=df.sample(frac=1, random_state=42).reset_index()

In [None]:
# Fixed lengths passed as `max_length` when tokenizing:
# MAXLENI for the input table text, MAXLENO for the output sentence.
MAXLENI, MAXLENO = 400, 200

In [None]:
from transformers import T5Tokenizer
from transformers import T5ForConditionalGeneration
from transformers import AdamW, WarmUp, get_linear_schedule_with_warmup
from tqdm.notebook import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import os
import time
import copy
import numpy
import matplotlib.pyplot as plt

In [None]:
# Load the pretrained "t5-base" tokenizer and the matching seq2seq model.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained(
    't5-base',
    return_dict=True,
)

In [None]:
# Register the upper-case table markers (the ones the dataset class rewrites the raw
# table tags into) as special tokens on the tokenizer.
# NOTE(review): 'eos_token' is overridden to '<eos>' here while the dataset class
# appends a literal '</s>' to every text -- confirm this combination is intended.
special_tokens_dict = {'pad_token': '<pad>', 'bos_token': '<bos>', 'eos_token': '<eos>', 
                       'additional_special_tokens': ['<PAGESTART>', '<PAGEEND>', '<SECTIONSTART>', '<SECTIONEND>',
                                                     '<TABLESTART>','<TABLEEND>','<CELLSTART>','<CELLEND>','<COLHEADERSTART>',
                                                     '<COLHEADEREND>','<ROWHEADERSTART>','<ROWHEADEREND>']}

num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

print('We have added', num_added_toks, 'tokens')
# Resize once on the top-level model rather than separately on model.encoder and
# model.decoder: the two separate calls give the encoder and the decoder their own
# new embedding matrices, while the top-level call resizes the model's input
# embeddings as a whole (see the Transformers `resize_token_embeddings` documentation).
model.resize_token_embeddings(len(tokenizer))

In [None]:
class tottodataset(Dataset):
  """Dataset that turns (table, sentence) rows of a DataFrame into model tensors.

  Args:
    df: DataFrame with string columns 'table' and 'sentence'.
    tokenizer: object exposing `encode_plus(...)` (whose result has `input_ids`
      and `attention_mask`) and a `pad_token_id` attribute.
    max_input_len: optional padded/truncated length for the table text; when
      None the module-level MAXLENI is used.
    max_output_len: optional padded/truncated length for the sentence; when
      None the module-level MAXLENO is used.

  Each item is a dict with the processed texts ('table_text', 'sentence') and
  long tensors 'input_ids', 'input_attention_mask', 'decoder_input_ids',
  'decoder_attention_mask' and 'labels'.
  """

  # Markup tag in the raw table text -> special token it is rewritten to.
  _TAG_MAP = (
      ("<page_title>", "<PAGESTART>"), ("</page_title>", "<PAGEEND>"),
      ("<section_title>", "<SECTIONSTART>"), ("</section_title>", "<SECTIONEND>"),
      ("<table>", "<TABLESTART>"), ("</table>", "<TABLEEND>"),
      ("<cell>", "<CELLSTART>"), ("</cell>", "<CELLEND>"),
      ("<col_header>", "<COLHEADERSTART>"), ("</col_header>", "<COLHEADEREND>"),
      ("<row_header>", "<ROWHEADERSTART>"), ("</row_header>", "<ROWHEADEREND>"),
  )

  def __init__(self,df,tokenizer,max_input_len=None,max_output_len=None):
    self.sentence=df['sentence']
    self.table=df['table']
    self.tokenizer=tokenizer
    self.max_input_len=max_input_len
    self.max_output_len=max_output_len

  def __len__(self):
    """Number of rows in the wrapped frame."""
    return len(self.sentence)

  def __getitem__(self,idx):
    """Return the tokenized example at position `idx`."""
    # Fall back to the notebook-level constants unless overridden at construction.
    max_in=MAXLENI if self.max_input_len is None else self.max_input_len
    max_out=MAXLENO if self.max_output_len is None else self.max_output_len

    # Positional access (.iloc): `idx` is a position in 0..len-1, so this also
    # works when the frame's index labels are not 0..n-1 (e.g. a slice that was
    # not re-indexed), where label-based `series[idx]` raises KeyError or picks
    # a different row.
    inp=self.table.iloc[idx]+'</s>'
    for raw_tag, special_token in self._TAG_MAP:
      inp=inp.replace(raw_tag, special_token)
    out=self.sentence.iloc[idx]+'</s>'

    inp_tokens=self.tokenizer.encode_plus(inp, padding="max_length", max_length=max_in, truncation=True)
    out_tokens=self.tokenizer.encode_plus(out, padding="max_length", max_length=max_out, truncation=True)
    inp_id=inp_tokens.input_ids
    out_id=out_tokens.input_ids
    inp_mask=inp_tokens.attention_mask
    out_mask=out_tokens.attention_mask
    # Padding positions become -100 in the labels.
    # presumably so the loss ignores them (-100 is the usual ignore index) -- TODO confirm
    pad_id=self.tokenizer.pad_token_id
    labels=[-100 if x==pad_id else x for x in out_id]

    return {
        "table_text":inp,
        "sentence":out,
        "input_ids":torch.tensor(inp_id, dtype=torch.long),
        "input_attention_mask":torch.tensor(inp_mask, dtype=torch.long),
        # NOTE(review): these are the target ids as tokenized (not shifted).
        "decoder_input_ids":torch.tensor(out_id, dtype=torch.long),
        "decoder_attention_mask":torch.tensor(out_mask, dtype=torch.long),
        "labels":torch.tensor(labels, dtype=torch.long)
    }



In [None]:
# Positional split of the shuffled frame: first 41000 rows for training, the rest for validation.
train_df = df.iloc[:41000]
# Re-index the validation slice from 0; the previous index is kept as a column.
val_df = df.iloc[41000:].reset_index()

In [None]:
# Wrap each split in the dataset class, then in a DataLoader.
train_dataset = tottodataset(df=train_df, tokenizer=tokenizer)
val_dataset = tottodataset(df=val_df, tokenizer=tokenizer)

# Both loaders use the same settings: batches of 4, two worker processes,
# and no per-epoch reshuffling.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=False, num_workers=2)

val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

In [None]:
# Loaders keyed by phase name; train_fn looks them up as dataloaders[mode].
dataloaders = dict(train=train_dataloader, eval=val_dataloader)

In [None]:
# Number of examples per phase; train_fn divides the accumulated loss by these.
dataset_sizes = {
    phase: len(phase_dataset)
    for phase, phase_dataset in (('train', train_dataset), ('eval', val_dataset))
}

In [None]:
def train_fn(model,optimizer,scheduler,num_epochs=5):
    """Train and validate `model`, returning it with the best-validation weights loaded.

    Every epoch runs a 'train' pass followed by an 'eval' pass over the
    module-level `dataloaders` dict. Batches are moved to the module-level
    `device`, and the accumulated loss is divided by `dataset_sizes[mode]`.

    Args:
        model: module called as `model(input_ids=..., attention_mask=..., labels=...)`
            whose output has the loss at index 0.
        optimizer: optimizer over the model parameters; stepped once per training batch.
        scheduler: learning-rate scheduler; stepped once per epoch after the train pass.
        num_epochs: number of train/eval epochs to run.

    Returns:
        The same `model` object, with the weights from the epoch that had the
        lowest eval loss loaded into it.
    """
    since=time.time()
    best_wts=copy.deepcopy(model.state_dict())
    best_loss=float('inf')
    for epoch in range(num_epochs):
        print(f'Epoch:{epoch}/{num_epochs}')
        print('-'*10)

        for mode in ['train','eval']:
            is_train = mode=='train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss=0.0

            for data in tqdm(dataloaders[mode]):
                input_ids = data["input_ids"].to(device, dtype=torch.long)
                # Pass the padding mask produced by the dataset so padded input
                # positions are not attended to.
                attention_mask = data["input_attention_mask"].to(device, dtype=torch.long)
                labels = data['labels'].to(device, dtype=torch.long)

                optimizer.zero_grad()
                with torch.set_grad_enabled(is_train):
                    outputs=model(
                                input_ids = input_ids,
                                attention_mask = attention_mask,
                                labels = labels
                            )
                    loss=outputs[0]

                    if is_train:
                        loss.backward()
                        optimizer.step()
                # `loss` is a single value for the whole batch; weight it by the
                # batch size so that dividing by the number of examples below
                # yields a per-example average rather than one scaled down by
                # the batch size.
                running_loss += loss.item()*input_ids.size(0)

            if is_train:
                scheduler.step()

            epoch_loss=running_loss/dataset_sizes[mode]

            print('{} Loss: {:.4f} '.format(
                mode, epoch_loss))

            # Snapshot the weights whenever validation improves.
            if mode=='eval' and epoch_loss<best_loss:
                best_wts=copy.deepcopy(model.state_dict())
                best_loss=epoch_loss

            print()

    # Summary and best-weight restore happen once, after the last epoch.
    # Restoring inside the epoch loop would reset the model to the best
    # checkpoint after every epoch.
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f}'.format(best_loss))

    model.load_state_dict(best_wts)
    return model
        

In [None]:
# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Move the model to the selected device before building the optimizer over its parameters.
model = model.to(device)
optimizer = AdamW(params=model.parameters(), lr=1e-4)
# StepLR multiplies the learning rate by `gamma` every `step_size` scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=2,
    gamma=0.1,
)

In [None]:
# Shell escape: run the `nvidia-smi` command and show its output in the cell.
!nvidia-smi

In [None]:
# Run 7 epochs of training/validation. train_fn returns the model object itself,
# so `history` refers to the trained model rather than to a log of losses.
history = train_fn(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=7,
)

In [None]:
# Serialize the whole model object to disk.
# NOTE(review): this pickles the full module rather than its state_dict, and the
# file name contains ':' -- confirm both are acceptable for where this is loaded.
torch.save(obj=model, f="T5Epoch:7")