In [None]:
# imports
import transformers
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler, Dataset
import pandas as pd
import numpy as np
import json
import os
from os.path import exists

import torch.nn.functional as F
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

import math
import random
import re
import argparse

class LitModel(pl.LightningModule):
  """Lightning wrapper that fine-tunes a seq2seq model with a cross-entropy LM loss.

  Args:
    learning_rate: learning rate handed to the Adam optimizer.
    tokenizer: tokenizer providing ``pad_token_id`` and ``decode``.
    model: the wrapped model; it must expose ``get_encoder()``, ``generate()``
      and (when ``hparams.freeze_embeds`` is set) ``model.shared``,
      ``model.encoder`` and ``model.decoder``.
    hparams: namespace with the boolean flags ``freeze_encoder`` and
      ``freeze_embeds``.
  """

  # Instantiate the model
  def __init__(self, learning_rate, tokenizer, model, hparams):
    super().__init__()
    self.tokenizer = tokenizer
    self.model = model
    self.learning_rate = learning_rate
    # NOTE(review): direct assignment to ``hparams`` may be rejected by newer
    # Lightning releases -- confirm against the pinned version.
    self.hparams = hparams

    if self.hparams.freeze_encoder:
      freeze_params(self.model.get_encoder())

    if self.hparams.freeze_embeds:
      self.freeze_embeds()

  def freeze_embeds(self):
    ''' Freeze the shared, token and positional embeddings of the encoder and decoder; adapted from finetune.py '''
    freeze_params(self.model.model.shared)
    for d in [self.model.model.encoder, self.model.model.decoder]:
      freeze_params(d.embed_positions)
      freeze_params(d.embed_tokens)

  # Do a forward pass through the model
  def forward(self, input_ids, **kwargs):
    return self.model(input_ids, **kwargs)

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
    return optimizer

  def _compute_loss(self, batch):
    ''' Shared by the training and validation steps.
        batch is indexed as (source ids, source attention mask, target ids). '''
    src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]
    # Use the tokenizer owned by this module rather than a module-level global
    pad_token_id = self.tokenizer.pad_token_id
    # Shift the decoder tokens right (but NOT the tgt_ids)
    decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

    # Run the model and get the logits
    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    lm_logits = outputs[0]
    # Padding positions are excluded from the loss
    ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
    # Calculate the loss on the un-shifted tokens
    return ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

  def training_step(self, batch, batch_idx):
    return {'loss': self._compute_loss(batch)}

  def validation_step(self, batch, batch_idx):
    return {'loss': self._compute_loss(batch)}

  # Method that generates text using the BartForConditionalGeneration's generate() method
  def generate_text(self, text, eval_beams, early_stopping = False, max_len = 500):
    ''' Generate text with beam search.
        Args: text - mapping with "input_ids" and "attention_mask"
              eval_beams - number of beams
              early_stopping - forwarded to generate()
              max_len - maximum generated length (the minimum is fixed at 50)
        Returns: list of decoded strings, one per generated sequence
    '''
    generated_ids = self.model.generate(
        text["input_ids"],
        attention_mask=text["attention_mask"],
        use_cache=True,
        # NOTE(review): decoding starts from the pad token, whereas training
        # feeds the wrapped last non-pad token first -- confirm this is intended.
        decoder_start_token_id = self.tokenizer.pad_token_id,
        num_beams= eval_beams,
        max_length = max_len,
        min_length = 50,
        early_stopping = early_stopping
    )
    return [self.tokenizer.decode(w, skip_special_tokens=True, clean_up_tokenization_spaces=True) for w in generated_ids]


def freeze_params(model):
  ''' Function that takes a model as input (or part of a model) and freezes the layers for faster training
      adapted from finetune.py '''
  for layer in model.parameters():
    layer.requires_grad = False

In [None]:
# Create a dataloading module as per the PyTorch Lightning Docs
class SummaryDataModule(pl.LightningDataModule):
  """Loads a CSV with 'source' and 'target' columns and serves train/val/test loaders.

  Args:
    tokenizer: tokenizer forwarded to ``encode_sentences``.
    data_file: path of the CSV file to read.
    batch_size: batch size used by all three dataloaders.
    num_examples: only the first ``num_examples`` rows of the file are used.
  """

  def __init__(self, tokenizer, data_file, batch_size, num_examples = 20000):
    super().__init__()
    self.tokenizer = tokenizer
    self.data_file = data_file
    self.batch_size = batch_size
    self.num_examples = num_examples
    # Track progress so setup() can run more than once without re-encoding
    self._splits_loaded = False
    self._splits_encoded = False

  def _load_and_split(self):
    ''' Read the CSV, shuffle it and cut it into 60/20/20 train/validation/test frames '''
    self.data = pd.read_csv(self.data_file)[:self.num_examples]
    shuffled = self.data.sample(frac=1)
    n_rows = len(self.data)
    train_end, val_end = int(.6 * n_rows), int(.8 * n_rows)
    self.train = shuffled.iloc[:train_end]
    self.validate = shuffled.iloc[train_end:val_end]
    self.test = shuffled.iloc[val_end:]
    self._splits_loaded = True
    self._splits_encoded = False

  # Loads and splits the data into training, validation and test sets with a 60/20/20 split
  def prepare_data(self):
    self._load_and_split()

  # encode the sentences using the tokenizer
  def setup(self, stage = None):
    # A repeated call must not tokenize again: the splits are already dicts of
    # tensors and no longer have 'source'/'target' columns.
    if self._splits_encoded:
      return
    # prepare_data() may not have run on this process, so load here when needed.
    # NOTE(review): the shuffle is unseeded, so separate processes would draw
    # different splits -- confirm whether a fixed seed is wanted.
    if not self._splits_loaded:
      self._load_and_split()
    self.train = encode_sentences(self.tokenizer, self.train['source'], self.train['target'])
    self.validate = encode_sentences(self.tokenizer, self.validate['source'], self.validate['target'])
    self.test = encode_sentences(self.tokenizer, self.test['source'], self.test['target'])
    self._splits_encoded = True

  # Load the training, validation and test sets in Pytorch Dataset objects
  def train_dataloader(self):
    dataset = TensorDataset(self.train['input_ids'], self.train['attention_mask'], self.train['labels'])
    train_data = DataLoader(dataset, sampler = RandomSampler(dataset), batch_size = self.batch_size)
    return train_data

  def val_dataloader(self):
    dataset = TensorDataset(self.validate['input_ids'], self.validate['attention_mask'], self.validate['labels'])
    val_data = DataLoader(dataset, batch_size = self.batch_size)
    return val_data

  def test_dataloader(self):
    dataset = TensorDataset(self.test['input_ids'], self.test['attention_mask'], self.test['labels'])
    test_data = DataLoader(dataset, batch_size = self.batch_size)
    return test_data


In [None]:
# Flags consumed by the model wrapper: freeze the encoder, freeze the
# embeddings, and the beam count used for evaluation
hparams = argparse.Namespace(
    freeze_encoder=True,
    freeze_embeds=True,
    eval_beams=4,
)


def shift_tokens_right(input_ids, pad_token_id):
  """ Build decoder inputs by rotating each row one position to the right.

      The last non-pad token of every row (usually <eos>) is moved to column 0
      and the remaining tokens follow, dropping the final column.
      input_ids is indexed as a 2-D tensor (rows are sequences); it is not modified.
      Same behavior as the helper in modeling_bart.py.
  """
  # Column of the last non-pad token in each row, kept as an (n_rows, 1) index
  last_token_col = input_ids.ne(pad_token_id).sum(dim=1, keepdim=True) - 1
  wrapped_token = input_ids.gather(1, last_token_col)
  # New first column followed by everything except the original last column
  return torch.cat([wrapped_token, input_ids[:, :-1]], dim=1)

def encode_sentences(tokenizer, source_sentences, target_sentences, max_length=300, pad_to_max_length=True, return_tensors="pt"):
  ''' Function that tokenizes the (lower-cased) source and target sentences
      Args: tokenizer - callable tokenizer (the BART tokenizer in this file)
            source_sentences / target_sentences - iterables of strings
            max_length - sequences are truncated (and, by default, padded) to this length
            pad_to_max_length - pad every sequence to max_length; when False no padding is
                                requested, so all sentences must already tokenize to the
                                same length for the concatenation below to succeed
            return_tensors - forwarded to the tokenizer; the results are concatenated
                             with torch.cat, so torch tensors are expected
      Returns: Dictionary with keys: input_ids, attention_mask (both from the sources)
               and labels (token ids of the targets, NOT shifted)
  '''
  # The tokenizer API expects False (not None) to disable padding
  padding = "max_length" if pad_to_max_length else False

  def _encode(sentence):
    return tokenizer(
        sentence.lower(),
        max_length=max_length,
        padding=padding,
        truncation=True,
        return_tensors=return_tensors,
        add_prefix_space = True
    )

  source_encodings = [_encode(sentence) for sentence in source_sentences]
  target_encodings = [_encode(sentence) for sentence in target_sentences]

  # Each encoding holds one row; stack the rows along the batch dimension
  batch = {
      "input_ids": torch.cat([enc['input_ids'] for enc in source_encodings], dim = 0),
      "attention_mask": torch.cat([enc['attention_mask'] for enc in source_encodings], dim = 0),
      "labels": torch.cat([enc['input_ids'] for enc in target_encodings], dim = 0),
  }

  return batch


# Load the model
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW, BartConfig

# Fetch the pre-trained 'facebook/bart-base' tokenizer and weights.
# The tokenizer is created with prefix spaces disabled and explicit <s> / </s> markers.
tokenizer = BartTokenizer.from_pretrained(
    'facebook/bart-base',
    add_prefix_space=False,
    bos_token="<s>",
    eos_token="</s>",
)
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

def find_best_epoch(ckpt_folder):
    """
    Find the most advanced checkpoint file in a checkpoint folder.

    Only ``*.ckpt`` files are considered. Progress is read from the last
    ``step=<int>`` in the filename, falling back to ``epoch=<int>`` when no step
    is present; files with neither are skipped instead of aborting the scan.

    :param ckpt_folder: dir where the checkpoints are being saved.
    :return: Filename (not the full path) of the checkpoint with the highest number.
    :raises ValueError: if the folder holds no recognisable checkpoint file.
    """
    best_name = None
    best_progress = None
    # Sorted so that ties resolve the same way on every platform
    for filename in sorted(os.listdir(ckpt_folder)):
        if not filename.endswith('.ckpt'):
            continue
        numbers = re.findall(r'step=(\d+)', filename) or re.findall(r'epoch=(\d+)', filename)
        if not numbers:
            continue
        progress = int(numbers[-1])
        if best_progress is None or progress > best_progress:
            best_name, best_progress = filename, progress
    if best_name is None:
        raise ValueError('no checkpoint files found in {}'.format(ckpt_folder))
    return best_name

def generate_queries(seed_query, num_queries, model_, noise_percent = 0.25, multiple_queries = False, max_query_history = 3):
  ''' Function that generates one query from a prior query
      Args: seed_query - a prior query (lower-cased before tokenization)
            num_queries - currently unused; kept for interface compatibility
            model_ - model used to generate; must provide tokenizer and generate_text()
            noise_percent - currently unused
            multiple_queries - currently unused
            max_query_history - currently unused
      Returns the first generated query as a stripped string
      Side effects: moves model_ to the CPU, switches it to eval mode and prints
                    the tokenized prompt and the prediction
  '''
  # Put the model on eval mode
  model_.to(torch.device('cpu'))
  model_.eval()
  # Tokenize with the model's own tokenizer and with the same prefix-space
  # setting that encode_sentences applies to the training data
  prompt_tokens = model_.tokenizer(
      seed_query.lower(),
      max_length = 300,
      return_tensors = "pt",
      truncation = True,
      add_prefix_space = True
  )
  print(prompt_tokens)
  query = model_.generate_text(prompt_tokens, eval_beams = 4)
  prediction = query[0].strip()
  print('pred',prediction)
  return prediction


In [None]:
def main():
    """Fine-tune one model per user CSV and write the predicted queries to JSON.

    Every ``*.csv`` file in ./mytestdata/full is one user with the columns
    "source" (prior query) and "target" (query to predict). Rows before the
    70% mark are handed to the data module for training; the remaining rows
    are used for prediction. For each predicted row the list
    [target, prediction, source, row index] is stored under the user's name in
    ./suggestions.json.
    """
    import copy  # local import: only needed to clone the base model per user

    torch.cuda.empty_cache()
    suggestions = {}
    # each file is a user, 2 columns: source, target
    # "source" column is prior query, "target" column is predicting target
    # each row represents a query
    for filename in os.listdir("./mytestdata/full"):
        if not filename.endswith(".csv"):
            continue
        user = filename.replace('.csv','')
        suggestions[user] = []
        # split data to train / test
        ratio = .7
        print('--------------',user,ratio,'----------')
        csv_path = './mytestdata/full/'+filename
        # Read the file once and reuse the frame for the split and the predictions
        full_df = pd.read_csv(csv_path)
        if len(full_df) == 0:
            # Nothing to train on or predict for this user
            continue
        allindex = full_df.index.values
        splitindex = allindex[int(len(allindex)*ratio)]

        # Load the data into the model for training
        summary_data = SummaryDataModule(tokenizer, csv_path,
                                         batch_size = 14, num_examples = splitindex)

        # Each user gets a private copy of the pre-trained weights; sharing the
        # module-level instance would carry one user's fine-tuning into the next
        model = LitModel(learning_rate = 2e-5, tokenizer = tokenizer,
                         model = copy.deepcopy(bart_model), hparams = hparams)

        ckpt_dir = './checkpoint_files_2/'+user+'_full'
        # Look for an earlier checkpoint BEFORE building the callback.
        # NOTE(review): some Lightning versions may create the directory when
        # the callback is constructed -- confirm against the pinned version.
        resume_path = None
        if exists(ckpt_dir):
            try:
                best_epoch = find_best_epoch(ckpt_dir)
            except ValueError:
                # Directory exists but holds no usable checkpoint: finetune from scratch
                best_epoch = None
            if best_epoch is not None:
                resume_path = ckpt_dir+'/'+best_epoch
                print(resume_path)

        # The same checkpoint callback is used whether resuming or starting fresh
        checkpoint = ModelCheckpoint(ckpt_dir)
        trainer = pl.Trainer(gpus = 4,
                             max_epochs = 20,
                             min_epochs = 20,
                             auto_lr_find = False,
                             checkpoint_callback = checkpoint,
                             resume_from_checkpoint = resume_path,
                             progress_bar_refresh_rate = 10)

        # Fit the instantiated model to the data
        trainer.fit(model, summary_data)

        pred_df = full_df[splitindex:]
        for index, row in pred_df.iterrows():
            pred_target = generate_queries(seed_query = row['source'], num_queries = 2, model_ = model,
                                   noise_percent = 0.25, multiple_queries = True, max_query_history = 2)
            # int() keeps the row index JSON-serializable
            suggestions[user].append([row['target'],pred_target,row['source'],int(index)])

    with open('./suggestions.json', 'w') as outfile:
        json.dump(suggestions, outfile)

# Run the pipeline only when executed as a script or notebook cell, so that
# importing this module does not start training
if __name__ == "__main__":
    main()