In [None]:
!pip install transformers
!pip install datasets
!pip install sentencepiece
!pip install ipywidgets
#!pip install wandb

Import libraries

In [None]:
import json
from datasets import load_dataset, load_metric, load_from_disk
import pandas as pd
from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer
from transformers import Adafactor
import torch
from torch import nn
import torch.nn.functional as F
import wandb

Dataclass for preprocessing and creating train and test csv files

Hugging Face models and tokenizer

In [None]:
max_length = 384  # presumably the padded sequence length of the processed dataset — TODO confirm
batch_size = 2

pretrained_model = 't5-base'
#pretrained_model = 'google/t5-v1_1-base'

tokenizer = T5Tokenizer.from_pretrained(pretrained_model)
qr_model = T5ForConditionalGeneration.from_pretrained(pretrained_model)
rc_model = T5ForConditionalGeneration.from_pretrained(pretrained_model)

# Hidden size of the RC model, read from its config so it stays correct when
# `pretrained_model` is switched (768 for t5-base).
dim = rc_model.config.d_model

Load the pre-tokenized dataset

In [None]:
# Columns the training code reads; set_format makes indexing return them as torch tensors.
tensor_columns = [
    'ctx_input_ids', 'rwrt_input_ids', 'psg_input_ids',
    'ans_input_ids', 'ctx_attention_mask', 'rwrt_attention_mask',
    'psg_attention_mask',
]
dataset = load_from_disk('/storage/qrecc/processed')
dataset.set_format(type='torch', columns=tensor_columns)

Train and test dataloaders

In [None]:
# One DataLoader per split, both with the same batch size and default (unshuffled) order.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset[split_name], batch_size=batch_size)
    for split_name in ('train', 'test')
)

Forward pass, and a helper that rolls each row (or column) of a tensor by its own shift

In [None]:
def roll_by_gather(mat, dim, shifts: torch.LongTensor):
    """Roll each row (dim=1) or each column (dim=0) of a 2-D tensor by its own amount.

    Unlike ``torch.roll``, which applies one shift to the whole tensor, this
    uses ``torch.gather`` so every row/column can move by a different amount.
    Elements wrap around, as with ``torch.roll``.

    Args:
        mat: 2-D tensor of shape (n_rows, n_cols).
        dim: 0 to roll each column downward, 1 to roll each row to the right.
        shifts: integer shifts that broadcast against the index grid, i.e.
            shape (n_rows, 1) for dim=1 or (1, n_cols) for dim=0.

    Returns:
        Tensor with the same shape, dtype and device as ``mat``.

    Raises:
        ValueError: if ``mat`` is not 2-D or ``dim`` is not 0 or 1.
    """
    if mat.dim() != 2:
        raise ValueError(f'roll_by_gather expects a 2-D tensor, got shape {tuple(mat.shape)}')
    if dim not in (0, 1):
        raise ValueError(f'dim must be 0 or 1, got {dim}')

    n_rows, n_cols = mat.shape
    size = mat.shape[dim]
    # Build indices on mat's own device rather than a global `device`, so the
    # helper works for CPU and GPU tensors alike.
    shifts = torch.as_tensor(shifts, device=mat.device)
    positions = torch.arange(size, device=mat.device)
    positions = positions.view(-1, 1) if dim == 0 else positions.view(1, -1)
    # Output position i reads from input position (i - shift) mod size.
    index = ((positions - shifts) % size).expand(n_rows, n_cols)
    return torch.gather(mat, dim, index)


def forward(batch):
    """Compute the QR (rewrite) loss and the RC (answer) loss for one batch.

    1. ``qr_model`` reads the context+question and is teacher-forced on the
       gold rewrite, giving ``qr_loss`` and per-position rewrite logits.
    2. A hard (straight-through) Gumbel-softmax sample is drawn from those
       logits, restricted to the first ``act_vocab_size`` entries.
    3. The RC input is the sampled rewrite (at positions where the gold
       rewrite's attention mask is 1) followed by as many passage tokens as
       fit in the remaining positions.
    4. ``rc_model`` reads that input and is teacher-forced on the answer,
       giving ``rc_loss``.

    The sampled rewrite is passed to ``rc_model`` as embeddings
    (one-hot sample @ embedding matrix) rather than token ids. Both give the
    same embedding values, but an integer id lookup has no gradient, whereas
    the matrix product lets ``rc_loss`` back-propagate through the
    straight-through sample into ``qr_model``.

    Uses notebook globals ``device``, ``tokenizer``, ``qr_model``,
    ``rc_model``, ``act_vocab_size`` and ``roll_by_gather``.

    Args:
        batch: dict produced by the DataLoader over the processed dataset; each
            entry is presumably a (batch, seq_len) tensor, with the rewrite and
            passage tensors sharing the same seq_len — TODO confirm.

    Returns:
        Tuple ``(qr_loss, rc_loss)`` of scalar tensors.
    """
    pad_id = tokenizer.pad_token_id

    # --- Query rewriting -------------------------------------------------
    ctx_input = batch['ctx_input_ids'].to(device)
    ctx_attention = batch['ctx_attention_mask'].to(device)
    rwrt_attention = batch['rwrt_attention_mask'].to(device)
    # Label positions set to -100 are ignored by the loss. masked_fill returns
    # a new tensor, so the caller's batch is left untouched.
    rwrt_ids = batch['rwrt_input_ids']
    rwrt_labels = rwrt_ids.masked_fill(rwrt_ids == pad_id, -100).to(device)

    qr_output = qr_model(input_ids=ctx_input, attention_mask=ctx_attention, labels=rwrt_labels)
    qr_loss = qr_output.loss

    # Restrict to the tokenizer vocabulary BEFORE sampling: slicing a hard
    # sample afterwards leaves an all-zero row whenever an id beyond the
    # tokenizer vocabulary is drawn.
    logits = qr_output.logits[..., :act_vocab_size]
    gumbel_output = F.gumbel_softmax(logits, tau=1, hard=True)
    del qr_output  # drop references to QR outputs that are no longer needed

    # --- Build the RC input ----------------------------------------------
    rwrt_mask = rwrt_attention.bool()
    # Token ids of the sample, zeroed where the gold rewrite is padding.
    sampled_ids = gumbel_output.argmax(dim=-1) * rwrt_attention

    # Passage shifted right by one, with id 1 written at position 0 as a separator.
    psg_input = torch.roll(batch['psg_input_ids'].to(device), 1, 1)
    psg_input[:, 0] = 1

    # fliplr + invert (for a 0/1 mask) turns the rewrite's trailing padding
    # into a prefix mask whose length is each row's free space — assumes
    # right-padded rewrites.
    free_prefix = 1 - torch.fliplr(rwrt_attention)
    extr_psg = free_prefix * psg_input
    # Move the kept passage prefix right by the rewrite length so it starts
    # immediately after the rewrite.
    shifts = (rwrt_attention == 1).sum(dim=1, keepdim=True)
    trunc_psg = roll_by_gather(extr_psg, 1, shifts)

    rc_input = sampled_ids + trunc_psg
    # Id 0 marks positions holding neither a rewrite nor a passage token.
    rc_attention = (rc_input != 0).long()

    # --- Reading comprehension -------------------------------------------
    embedding_layer = rc_model.get_input_embeddings()
    # one-hot @ E equals E[sampled id] in value but keeps the gradient path to qr_model.
    rwrt_embeds = gumbel_output @ embedding_layer.weight[:act_vocab_size]
    psg_embeds = embedding_layer(trunc_psg)
    rc_embeds = torch.where(rwrt_mask.unsqueeze(-1), rwrt_embeds, psg_embeds)

    ans_ids = batch['ans_input_ids']
    ans_labels = ans_ids.masked_fill(ans_ids == pad_id, -100).to(device)

    rc_loss = rc_model(inputs_embeds=rc_embeds, attention_mask=rc_attention, labels=ans_labels).loss

    return qr_loss, rc_loss



In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Both models must live on `device`: forward() moves every input tensor there.
qr_model.to(device)
rc_model.to(device)

# Load fine-tuned QR weights. map_location keeps the load working even if the
# checkpoint was saved from a different device.
qr_model.load_state_dict(torch.load('/storage/qrecc/models/qr/qr_gen3.pth', map_location=device))
#rc_model.load_state_dict(torch.load('/storage/qrecc/models/rc/rc_gen3.pth', map_location=device))

qr_model.train()
#rc_model.train()

num_epochs = 2
learning_rate = 1e-5
weight_decay = 0.0

# Only qr_model's parameters are given to the optimizer, so rc_model is not updated.
# `optimizer` is kept as an alias of `optim`.
optim = optimizer = Adafactor(
    #list(qr_model.parameters())+list(rc_model.parameters()),
    qr_model.parameters(),
    lr=learning_rate,
    eps=(1e-30, 1e-3),
    clip_threshold=1.0,
    decay_rate=-0.8,
    beta1=None,
    weight_decay=weight_decay,
    relative_step=False,
    scale_parameter=False,
    warmup_init=False
)

# Built from the values actually used above so logged hyper-parameters cannot drift.
config = {
    "learning_rate": learning_rate,
    "epochs": num_epochs,
    "batch_size": batch_size,
    "weight_decay": weight_decay,
    "temperature": 1,
}

#wandb.init(project="e2e-gradients", entity="suicune", reinit=True, config=config)

#wandb.watch(qr_model, log="all", log_freq=100)

# vocabulary size
act_vocab_size = len(tokenizer.get_vocab())
# dummy vocab to get vocab ids after gumbel softmax
dummy_vocab = torch.arange(act_vocab_size).long()

Validation and Training loop

In [None]:

def valid_loss():
    """Print the mean QR and RC losses over ``test_loader``.

    Runs under ``torch.no_grad()`` so no autograd graph is built during
    evaluation. The caller is responsible for switching the models to eval mode.
    """
    qr_total = 0.0
    rc_total = 0.0
    n_batches = 0

    with torch.no_grad():
        for batch in test_loader:
            qr_loss, rc_loss = forward(batch)
            qr_total += qr_loss.item()
            rc_total += rc_loss.item()
            n_batches += 1

    if n_batches == 0:
        print('Valid loss : no validation batches')
        return
    print('Valid loss : {}, {}'.format(qr_total / n_batches, rc_total / n_batches))


# Set to a small integer (e.g. 1) for a quick smoke test; None trains on every batch.
max_train_batches = None
log_every = 100

for epoch in range(1, num_epochs + 1):
    qr_epoch_loss = 0.0
    rc_epoch_loss = 0.0
    n_batches = 0

    for idx, batch in enumerate(train_loader, start=1):
        qr_loss, rc_loss = forward(batch)
        qr_epoch_loss += qr_loss.item()
        rc_epoch_loss += rc_loss.item()
        n_batches = idx

        # Only the RC loss is back-propagated (the combined QR+RC loss is disabled).
        #total_loss = sum([qr_loss, rc_loss])
        rc_loss.backward()

        if idx % log_every == 0:
            print('epoch {}, batch {}'.format(epoch, idx))
            # Print per-parameter gradient norms rather than whole gradient
            # tensors, which flood the notebook output. None = no gradient.
            for name, param in rc_model.named_parameters():
                if param.requires_grad:
                    grad_norm = None if param.grad is None else param.grad.norm().item()
                    print(name, grad_norm)
            #wandb.log(...)

        optim.step()
        optim.zero_grad()
        # optim only owns qr_model's parameters; clear rc_model's gradients too
        # so they do not accumulate across steps.
        rc_model.zero_grad()

        if max_train_batches is not None and idx >= max_train_batches:
            break

    # Average over the batches actually processed (fewer than len(train_loader)
    # when max_train_batches is set); max() guards an empty loader.
    denom = max(n_batches, 1)
    print('Train loss : {}, {}'.format(qr_epoch_loss / denom, rc_epoch_loss / denom))

    qr_model.eval()
    rc_model.eval()
    valid_loss()
    print('\n')
    qr_model.train()
    rc_model.train()

    # Checkpoint numbering presumably continues from the gen-3 QR checkpoint loaded earlier (hence +3).
    torch.save(qr_model.state_dict(), '/storage/qrecc/models/e2e/qr' + str(epoch + 3) + '.pth')
    torch.save(rc_model.state_dict(), '/storage/qrecc/models/e2e/rc' + str(epoch + 3) + '.pth')

Train loss : 0.3970739206526509, 0.45609857336574516
Valid loss : 0.4757325287272927, 0.5390424355815783

In [None]:
# Scratch check: a leaf float tensor built from a Python list and tracked by autograd.
saved_weights = [0.1, 0.2, 0.3, 0.25]
loaded_weights = torch.tensor(saved_weights, requires_grad=True)
loaded_weights