# Complexity Control Experiments

Main notebook for generating text with controllable complexity from a base model (currently BART). List of things to try:

- ```GeDi```: Discriminator-guided generation. The discriminator is trained on arXiv vs. science-news text.
- ```DExperts```: Using pretrained models as an expert and an anti-expert. The two models are pretrained on arXiv text and on science-news text, respectively.
- ```Rerankers```: An SVM trained on the same dataset as above (arXiv and science news).
- ```PPLM```: Not included here; run separately, based on the Hugging Face example code (https://github.com/huggingface/transformers/tree/main/examples/research_projects/pplm).


In [None]:
from transformers import (
     AutoTokenizer,
     AutoModelForSeq2SeqLM,
     AutoConfig, 
     LogitsProcessorList,
     MinLengthLogitsProcessor,
     TopKLogitsWarper,
     NoBadWordsLogitsProcessor,
     TemperatureLogitsWarper,
     BeamSearchScorer,
     LogitsProcessor,
     NoRepeatNGramLogitsProcessor,
     LogitsProcessorList,
)

import seaborn as sn
import matplotlib.pyplot as plt
import torch
from torch.nn import functional as F

from importlib import reload  
from joblib import dump, load

import os
import sys
from importlib import reload  
import pandas as pd
import numpy as np
import csv
import datasets
import json
import random
import uuid
import spacy 
import scispacy
from tqdm import tqdm
from spacy.pipeline import Sentencizer
import re
from typing import Iterable, List
from transformers import BertTokenizer, glue_convert_examples_to_features, BertForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

random.seed(42)  # seeds Python's `random` module only; torch/numpy are not seeded in this cell

# NOTE(review): presumably the directory holding the project helpers `utils` / `jargon_utils`.
# This must run before the two imports below.
sys.path.append('/lib')

import utils
import jargon_utils

# Machine-specific locations; fill these in before running the notebook.
MODEL_DIR = ''
RESOURCE_DIR = ''
GEDI_PATH = ''
DATA_DIR = ''

# Make the GeDi code (and other resources) importable before the GeDi imports below.
sys.path.append(GEDI_PATH)
sys.path.append(RESOURCE_DIR)

import GeDi
from GeDi import generate_GeDi


# Data

For now, just use the MedQuAD–Wikipedia dataset (which the base model is fine-tuned on), but we might also want to try the scientific definitions or the paired Simple/normal Wikipedia data.

In [None]:
##### Dataset reloading
# Stratified (k-means) MedQuAD + Wikipedia splits, one CSV per split.
train_csv = '{}/model_data/medquad_wikipedia_k_means_with_sd_train.csv'.format(DATA_DIR)
dev_csv = '{}/model_data/medquad_wikipedia_k_means_with_sd_dev.csv'.format(DATA_DIR)
test_csv = '{}/model_data/medquad_wikipedia_k_means_with_sd_test.csv'.format(DATA_DIR)
df_wiki_medq_strat_train = pd.read_csv(train_csv)
df_wiki_medq_strat_dev = pd.read_csv(dev_csv)
df_wiki_medq_strat_test = pd.read_csv(test_csv)

#### Alias used by the generation cells (point it at a slice to test on fewer rows)
df_test = df_wiki_medq_strat_test

#### Split the test set into the questions picked for human evaluation and all the others
selected_questions = pd.read_csv('{}/model_data/human_eval_test_questions.txt'.format(DATA_DIR))
in_human_eval = df_test['question'].isin(selected_questions['question'])
df_test_human_eval = df_test[in_human_eval]
df_test_not_human_eval = df_test[~in_human_eval]


In [None]:
# GPU selection. GPU_NUMBER is the physical GPU id(s) to expose, in PCI bus order.
# Leave it empty to keep the process's default GPU visibility.
GPU_NUMBER = ''
CUDA_DEVICE = '0'
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
if GPU_NUMBER:
    # Must be set before CUDA is initialised. An empty string here would hide every GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = GPU_NUMBER

# Fall back to CPU when no GPU is visible, so the later `.to(DEVICE)` calls do not fail.
DEVICE = 'cuda:{}'.format(CUDA_DEVICE) if torch.cuda.is_available() else 'cpu'

# Generation Types 



In [None]:

# This is similar to the function from the definition modelling notebook, but simplified to ignore gpt2 and extended with ranking scores
def generate_answers(df, model, tokenizer, model_type,
                     num_return_sequences=1,
                     num_beams=5,
                     max_length=64,
                     min_length=8,
                     early_stopping=True,
                     temperature=1,
                     do_sample=True,
                     top_k=50,
                     top_p=0.9,
                     max_input_length=1024,
                     no_repeat_ngram_size=3,
                     device=None):
    """Generate answers for every row of ``df`` with a seq2seq model's ``generate``.

    Each row's ``q_s2orc_doc`` text is tokenized (truncated to ``max_input_length``
    tokens), moved to ``device`` and decoded with the given generation settings.
    Every returned sequence is decoded (special tokens skipped, whitespace stripped)
    and stored with its sequence score.

    Args:
        df: DataFrame with ``q_s2orc_doc``, ``question``, ``category`` and
            ``first_sentence`` columns.
        model: model exposing a Hugging Face style ``generate`` method.
        tokenizer: matching tokenizer; its ``bos_token_id`` / ``eos_token_id`` are
            passed to ``generate``.
        model_type: label stored in the ``model-type`` column.
        num_return_sequences ... no_repeat_ngram_size: forwarded to ``generate``
            (``max_input_length`` is used for input truncation).
        device: device the tokenized inputs are moved to.

    Returns:
        DataFrame with one row per generated answer and the columns ``response``,
        ``scores``, ``model-type``, ``question``, ``category``, ``first_sentence``.
        ``scores`` is NaN when ``generate`` reports no sequence scores (non-beam
        decoding). An empty ``df`` gives an empty DataFrame with those columns.
    """
    answer_df_lists = []
    for _, r in tqdm(df.iterrows(), total=len(df)):
        row = r.to_dict()

        ### make input
        inputs = tokenizer(row['q_s2orc_doc'], return_tensors='pt', max_length=max_input_length, truncation=True)

        # NOTE(review): decoding starts from tokenizer.bos_token_id rather than
        # model.config.decoder_start_token_id; confirm this matches how the model was fine-tuned.
        outputs = model.generate(**inputs.to(device),
                                 decoder_start_token_id=tokenizer.bos_token_id,
                                 num_return_sequences=num_return_sequences,
                                 num_beams=num_beams,
                                 min_length=min_length,
                                 max_length=max_length,
                                 early_stopping=early_stopping,
                                 temperature=temperature,
                                 do_sample=do_sample,
                                 top_k=top_k,
                                 top_p=top_p,
                                 eos_token_id=tokenizer.eos_token_id,
                                 no_repeat_ngram_size=no_repeat_ngram_size,
                                 output_scores=True,
                                 return_dict_in_generate=True)

        # save all the answers with their associated scores in a df
        answers = [tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in outputs.sequences]
        # `sequences_scores` is absent/None for non-beam decoding; fall back to NaN then.
        seq_scores = getattr(outputs, 'sequences_scores', None)
        scores = seq_scores.tolist() if seq_scores is not None else [float('nan')] * len(answers)
        df_answers = pd.DataFrame(zip(answers, scores), columns=['response', 'scores'])
        df_answers['model-type'] = model_type

        # save information about the question
        df_answers['question'] = row['question']
        df_answers['category'] = row['category']
        df_answers['first_sentence'] = row['first_sentence']

        answer_df_lists.append(df_answers)

    if not answer_df_lists:
        # pd.concat raises on an empty list, so return an empty frame with the usual columns.
        return pd.DataFrame(columns=['response', 'scores', 'model-type', 'question', 'category', 'first_sentence'])
    return pd.concat(answer_df_lists)


# Reranker

Taking a slightly different approach from last time: we generate with the base model and then score the answers separately here.

Scoring is grouped by question, because the SVM decision function probably depends on that grouping (and the BART score may as well).

In [None]:
#### Helpful functions

def get_svm_score(answers, clf, response_col):
    """Attach SVM ranker decision values to a set of answers.

    Jargon features are computed for the ``response_col`` text with
    ``jargon_utils.make_jargon_features``. ``clf.decision_function`` is then applied
    to five of those feature columns, and the result is stored as ``predictions``.

    Returns:
        The featurized DataFrame returned by ``make_jargon_features``, with the
        extra ``predictions`` column.
    """
    featurized = jargon_utils.make_jargon_features(answers, response_col)
    svm_inputs = featurized[['avl_occ', 'function_words_prop', 'te_oov', 'response_gpt_ppl_score', 'word_count']]
    featurized['predictions'] = clf.decision_function(svm_inputs)
    return featurized


def get_bert_logits(answers, bert_tokenizer, bert_ranker_model, prefix='', device=None):
    """Score every response with a BERT sequence classifier and store its logits.

    Adds these columns to ``answers`` (a DataFrame with a ``response`` column):
    ``input_ids`` (the tokenizer output per response), ``{prefix}bert_logits`` (the
    logit vector per response), and ``{prefix}bert_logits_news`` /
    ``{prefix}bert_logits_academic`` (its first and second entries).

    NOTE(review): assumes a two-label classifier whose label 0 is "news" and label 1
    is "academic"; confirm against the ranker's training labels.

    Returns:
        The same ``answers`` DataFrame, modified in place.
    """
    # Truncate to the tokenizer's model_max_length so long responses do not overflow BERT.
    answers['input_ids'] = [bert_tokenizer(a, return_tensors="pt", truncation=True) for a in answers['response']]
    # Inference only: no_grad avoids building an autograd graph for every response.
    with torch.no_grad():
        answers[prefix + 'bert_logits'] = [
            bert_ranker_model(**inputs.to(device)).logits.cpu().numpy()[0]
            for inputs in answers['input_ids']
        ]
    answers[[prefix + 'bert_logits_news', prefix + 'bert_logits_academic']] = answers[prefix + 'bert_logits'].apply(pd.Series)

    return answers


In [None]:
#### Load the generator and the rankers

### Base model: tokenizer and seq2seq model come from the same checkpoint directory
base_ckpt_dir = "{}/bart_medq_wiki_s2orc/default".format(MODEL_DIR)
bart_tokenizer = AutoTokenizer.from_pretrained(base_ckpt_dir)
bart_model = AutoModelForSeq2SeqLM.from_pretrained(base_ckpt_dir).to(DEVICE)
_ = bart_model.eval()  # inference mode (disables dropout)


#### SVM ranker, stored with joblib
ranker_path = ''
feature_cols = ['avl_occ', 'function_words_prop', 'te_oov', 'response_gpt_ppl_score', 'word_count']
svm_file = '{}/svm_train_arxiv.joblib'.format(ranker_path)
arxiv_sci_train_clf = load(svm_file)



In [None]:
### Generating some answers based on the base model

### Because of CUDA issues, generate in 20 rounds of 5 sequences per question (100 answers each)
all_bart_test_answers = []
for _ in range(20):
    # generate_answers has no decoder-prefix option, so only its supported keyword arguments are passed.
    bart_test_answers = generate_answers(df=df_test, num_return_sequences=5, model=bart_model, tokenizer=bart_tokenizer, model_type='bart-base', device=DEVICE)
    all_bart_test_answers.append(bart_test_answers)
    

In [None]:
# Merge all generation rounds; later cells read the combined frame from `bart_test_answers`.
bart_test_answers = pd.concat(objs=all_bart_test_answers)

In [None]:
#### Scoring the answers
# Score each question's answers together with the SVM ranker,
# then drop the group keys to get a flat 0..n-1 index.
bart_test_answers_svm = (
    bart_test_answers
    .groupby(['question'])
    .apply(lambda group: get_svm_score(group, arxiv_sci_train_clf, response_col='response'))
    .reset_index(drop=True)
)


In [None]:
### Move the base model back to the CPU
# NOTE(review): the DExperts cells below still pass `bart_model` (now on CPU),
# while the expert/anti-expert models are loaded onto DEVICE.
bart_model = bart_model.cpu()


In [None]:
##### Save each result set to its own file
bart_test_answers_svm.to_csv(path_or_buf='{}/bart_test_answers_svm.csv'.format(DATA_DIR))


# DExperts



In [None]:
#### Expert / anti-expert models: bart-large pretrained on each of the two datasets (not fine-tuned on our task),
### as opposed to the DAPT models, which are pretrained on top of the fine-tuned BART.

article_ckpt_dir = "{}/bart_medq_wiki_s2orc/articles".format(MODEL_DIR)
bart_article_model = AutoModelForSeq2SeqLM.from_pretrained(article_ckpt_dir).to(DEVICE)
_ = bart_article_model.eval()

journal_ckpt_dir = "{}/bart_medq_wiki_s2orc/journals".format(MODEL_DIR)
bart_journal_model = AutoModelForSeq2SeqLM.from_pretrained(journal_ckpt_dir).to(DEVICE)
_ = bart_journal_model.eval()


In [None]:
# mostly taken from https://huggingface.co/transformers/_modules/transformers/generation_utils.html#GenerationMixin.generate
# for encoder-decoder networks like BART
def setup_s2s_generation(context, s2s_model, s2s_tokenizer, max_length=50):
    """Prepare the initial state for a hand-written decoding loop over ``context``.

    Tokenizes ``context`` (truncated to 1024 tokens) and runs the model's encoder once
    through ``_prepare_encoder_decoder_kwargs_for_generation``. It then builds the
    initial decoder input ids (the decoder start token) and the length bookkeeping
    from ``_init_sequence_length_for_generation``.

    NOTE(review): these are private ``GenerationMixin`` helpers; their names and
    signatures may differ between transformers versions, so pin the version used.

    Returns:
        ``(input_ids, model_kwargs, sequence_lengths, unfinished_sequences)``:
        ``input_ids`` are the decoder start ids, and ``model_kwargs`` holds the
        encoder outputs.
    """
    inputs = s2s_tokenizer(context, return_tensors="pt", max_length=1024, truncation=True)
    # The tokenizer returns CPU tensors; the encoder needs them on the model's own device.
    encoder_input_ids = inputs['input_ids'].to(s2s_model.device)

    model_kwargs = s2s_model._prepare_encoder_decoder_kwargs_for_generation(encoder_input_ids, {})

    # Decoder start ids are created on the same device as encoder_input_ids.
    input_ids = s2s_model._prepare_decoder_input_ids_for_generation(
        encoder_input_ids, decoder_start_token_id=s2s_model.config.decoder_start_token_id, bos_token_id=s2s_model.config.bos_token_id
    )

    sequence_lengths, unfinished_sequences, _ = s2s_model._init_sequence_length_for_generation(
        input_ids, max_length
    )

    return input_ids, model_kwargs, sequence_lengths, unfinished_sequences


def _init_expert_kwargs(context, s2s_model, tokenizer, num_beams, decoder_input_ids):
    """Encode ``context`` with ``s2s_model``'s own encoder and expand the kwargs to ``num_beams``."""
    model_device = s2s_model.device
    encoder_input_ids = tokenizer(context, return_tensors="pt", max_length=1024, truncation=True)['input_ids'].to(model_device)
    kwargs = s2s_model._prepare_encoder_decoder_kwargs_for_generation(encoder_input_ids, {})
    _, kwargs = s2s_model._expand_inputs_for_generation(
        decoder_input_ids.to(model_device), expand_size=num_beams,
        is_encoder_decoder=s2s_model.config.is_encoder_decoder, **kwargs
    )
    return kwargs


def _expert_step(s2s_model, input_ids, kwargs, target_device):
    """Run one decoding step of ``s2s_model`` using its own encoder outputs and cache.

    Returns the last-position logits (moved to ``target_device``) and the updated kwargs.
    """
    model_inputs = s2s_model.prepare_inputs_for_generation(input_ids.to(s2s_model.device), **kwargs)
    output = s2s_model(**model_inputs)
    kwargs = s2s_model._update_model_kwargs_for_generation(
        output, kwargs, is_encoder_decoder=s2s_model.config.is_encoder_decoder
    )
    return output.logits[:, -1, :].to(target_device), kwargs


def _reorder_past(s2s_model, kwargs, beam_idx):
    """Reorder any cached decoder states in ``kwargs`` to follow the beams chosen by ``beam_idx``."""
    # Depending on the transformers version, the cache may be stored under 'past' or 'past_key_values'.
    for key in ('past', 'past_key_values'):
        if kwargs.get(key) is not None:
            kwargs[key] = s2s_model._reorder_cache(kwargs[key], beam_idx.to(s2s_model.device))
    return kwargs


@torch.no_grad()
def generate_s2s_DEexperts(context,
                           model,
                           expert_model,
                           anti_model,
                           tokenizer,
                           top_k=50,
                           top_p=0.9,
                           num_beams=5,
                           max_length=64,
                           min_length=8,
                           no_repeat_ngram_size=3,
                           expert_alpha = 2.0,
                           early_stopping=True,
                           temperature=1.0,
                           do_sample=True, # not used, assumed
                           device=None,
                           num_return_sequences=1):
    """Beam-sample responses to ``context`` from ``model``, steered DExperts-style.

    At each step:
      1. The base model's next-token log-probabilities go through the logits
         processors (min length, no-repeat n-grams).
      2. The running beam scores are added, and the result is warped
         (temperature / top-k / top-p).
      3. ``expert_alpha * (expert_logits - anti_logits)`` is added.
      4. ``2 * num_beams`` candidates are sampled from the softmax over
         (beams x vocab) and passed to a ``BeamSearchScorer``.

    The expert and anti-expert each run their own encoder on ``context`` and keep
    their own decoder cache. Every model's cache is reordered after each beam
    selection.

    NOTE(review): relies on private ``GenerationMixin`` helpers whose signatures
    may differ between transformers versions.

    Args:
        context: source text to condition on.
        model / expert_model / anti_model: seq2seq models sharing ``tokenizer``'s vocabulary.
        expert_alpha: weight of the expert/anti-expert logit difference.
        do_sample: unused; candidates are always sampled.
        device: device for the beam scorer's bookkeeping; defaults to the decoder inputs' device.
        Other arguments: beam-search and warping settings.

    Returns:
        ``(sequences, model_kwargs, sequence_scores)``: the scorer's finalized
        sequences (``num_return_sequences`` rows), the base model's final kwargs,
        and the sequence scores (a tensor).
    """
    # The greedy/sampling length bookkeeping from setup is not used: beam search stops via the scorer.
    input_ids, model_kwargs, _, _ = setup_s2s_generation(context, model, tokenizer, max_length)

    if device is None:
        device = input_ids.device

    # make the beam search scorer
    beam_scorer = BeamSearchScorer(
                 batch_size=1,
                 max_length=max_length,
                 num_beams=num_beams,
                 device=device,
                 length_penalty=model.config.length_penalty,
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences)

    logits_processors = model._get_logits_processor(
            repetition_penalty=None,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=None,
            min_length=min_length,
            eos_token_id=model.config.eos_token_id,
            prefix_allowed_tokens_fn=None,
            num_beams=num_beams,
            num_beam_groups=1,
            diversity_penalty=None,
        )

    logits_warper = model._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
        )

    eos_token_id = model.config.eos_token_id
    pad_token_id = model.config.pad_token_id

    # The expert and anti-expert need their own encoder outputs and caches; the base
    # model's states come from different weights.
    expert_kwargs = _init_expert_kwargs(context, expert_model, tokenizer, num_beams, input_ids)
    anti_kwargs = _init_expert_kwargs(context, anti_model, tokenizer, num_beams, input_ids)

    # interleave with `num_beams`
    input_ids, model_kwargs = model._expand_inputs_for_generation(
        input_ids, expand_size=num_beams, is_encoder_decoder=model.config.is_encoder_decoder, **model_kwargs
    )

    # initialise beam search state
    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    _, cur_len = input_ids.shape

    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores = beam_scores.view((batch_size * num_beams,))

    while cur_len < max_length:
        # run forward of the base model
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        output = model(**model_inputs)

        next_token_logits = output.logits[:, -1, :]

        # hack from transformers: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
        # cannot be generated both before and after the `F.log_softmax` operation.
        next_token_logits = model.adjust_logits_during_generation(
            next_token_logits, cur_len=cur_len, max_length=max_length
        )

        next_token_scores = F.log_softmax(next_token_logits, dim=-1)

        # pre-process the log-probabilities, add the running beam scores, then warp
        next_token_scores = logits_processors(input_ids, next_token_scores)
        next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # get the expert and anti-expert logits, each from its own state
        expert_logits, expert_kwargs = _expert_step(expert_model, input_ids, expert_kwargs, next_token_scores.device)
        anti_logits, anti_kwargs = _expert_step(anti_model, input_ids, anti_kwargs, next_token_scores.device)

        next_token_scores = next_token_scores + (expert_alpha * (expert_logits - anti_logits))

        # reshape for beam search
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        probs = F.softmax(next_token_scores, dim=-1)

        next_tokens = torch.multinomial(probs, num_samples=(2 * num_beams))
        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
        next_tokens = torch.gather(next_tokens, -1, _indices)

        next_indices = next_tokens // vocab_size
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

        cur_len += 1

        # update kwargs, then keep every model's cache aligned with the beams just selected
        model_kwargs = model._update_model_kwargs_for_generation(
                    output, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
        )
        model_kwargs = _reorder_past(model, model_kwargs, beam_idx)
        expert_kwargs = _reorder_past(expert_model, expert_kwargs, beam_idx)
        anti_kwargs = _reorder_past(anti_model, anti_kwargs, beam_idx)

        # stop once the scorer has enough finished hypotheses for every batch entry
        if beam_scorer.is_done:
            break

    sequence_outputs = beam_scorer.finalize(
        input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
    )

    return sequence_outputs["sequences"], model_kwargs, sequence_outputs["sequence_scores"]

In [None]:
def _collect_dexperts_answers(df, expert_model, anti_model, model_type):
    """Run ``generate_s2s_DEexperts`` on every row of ``df``; return one answer DataFrame per row.

    Uses ``bart_model`` / ``bart_tokenizer`` as the base generator with fixed decoding
    settings (10 beams, 10 returned sequences, top-k 50, top-p 0.9, alpha 1.0).
    """
    answer_df_lists = []
    for _, r in tqdm(df.iterrows(), total=len(df)):
        row = r.to_dict()

        ### make input
        output_ids, _, seq_scores = generate_s2s_DEexperts(context=row['q_s2orc_doc'],
                                                           model=bart_model,
                                                           expert_model=expert_model,
                                                           anti_model=anti_model,
                                                           tokenizer=bart_tokenizer,
                                                           top_k=50,
                                                           top_p=.9,
                                                           num_beams=10,
                                                           max_length=64,
                                                           min_length=8,
                                                           no_repeat_ngram_size=3,
                                                           expert_alpha=1.0,
                                                           num_return_sequences=10
                                                           )

        # save all the answers with their associated scores in a df
        answers = [bart_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in output_ids]
        # .tolist() yields plain floats. Zipping the tensor itself would store 0-d tensors,
        # which to_csv writes out as "tensor(...)" strings.
        df_answers = pd.DataFrame(zip(answers, seq_scores.tolist()), columns=['response', 'scores'])
        df_answers['model-type'] = model_type

        # save information about the question
        df_answers['question'] = row['question']
        df_answers['category'] = row['category']
        df_answers['first_sentence'] = row['first_sentence']

        answer_df_lists.append(df_answers)
    return answer_df_lists


###### Journal: journal model as expert, article model as anti-expert
journal_answer_df_lists = _collect_dexperts_answers(df_test_not_human_eval, bart_journal_model, bart_article_model, 'DExpert-journal')

###### News: article model as expert, journal model as anti-expert
article_answer_df_lists = _collect_dexperts_answers(df_test_not_human_eval, bart_article_model, bart_journal_model, 'DExpert-article')

DExperts_journal_test_answers = pd.concat(journal_answer_df_lists)
DExperts_article_test_answers = pd.concat(article_answer_df_lists)


In [None]:
# Write both DExperts answer sets to model_answers/ (journal first, then article).
DExperts_journal_test_answers.to_csv(path_or_buf='{}/model_answers/DExperts_journal_test_answers.csv'.format(DATA_DIR))
DExperts_article_test_answers.to_csv(path_or_buf='{}/model_answers/DExperts_article_test_answers.csv'.format(DATA_DIR))


# GeDi


This has to be run twice, once for each argument set (journal and news).


In [None]:
# for namespacing the args
class Bunch(object):
    def __init__(self, adict):
        self.__dict__.update(adict)
           
args = {'model_type': 'bart',
        'gedi_model_type': 'bart',
        'gen_model_name_or_path': '{}/bart_medq_wiki_s2orc/default/'.format(MODEL_DIR),
        'gedi_model_name_or_path': '/homes/gws/taugust/Projects/ARK/sci_comm/resources/GeDi/topic_GeDi_journal_news_bart_medq_wiki',
        'fp16': True, 'load_in_half_prec': False,
        'config_name': '', 'tokenizer_name': '',
        'cache_dir': '', 'do_lower_case': False,
        'no_cuda': False, 'gen_length': 64,
        'stop_token': None, 'temperature': 1.0,
        'disc_weight': 30.0, 'filter_p': 0.8, 
        'class_bias': None, 'target_p': 0.8,
        'do_sample': True, 'repetition_penalty': 1.2,
        'rep_penalty_scale': 10.0, 'penalize_cond': True,
        'k': 50, 'p': 0.9, 'gen_type': 'gedi',
        'mode': 'topic', 'secondary_code': 'journal', 
        'gpt3_api_key': None, 
        'prompt': None,
        'num_return_sequences':10,
        'device':DEVICE}

args = Bunch(args_template)

In [None]:

answer_df_lists_gedi_rows = []

# NOTE(review): `tokenizer`, `model` and `gedi_model` are not defined in the cells shown here;
# the generator and the GeDi discriminator must be loaded before running this cell.
for _, r in tqdm(df_test.iterrows(), total=len(df_test)):
    row = r.to_dict()

    args.prompt = row['q_s2orc_doc']
    args.term_start = None

    returned_answers = generate_GeDi.generate_GeDi_simple(args, tokenizer, model, gedi_model)

    # handle all the returned sequences
    for answer, output in returned_answers:
        new_row = row.copy()
        # Remove the <s>/</s> markers by pattern. str.strip('</s>') would strip any of the
        # characters '<', '/', 's', '>' from both ends, e.g. turning "diseases" into "disease".
        new_row['response'] = re.sub(r'</?s>', '', answer).strip()
        new_row['output'] = output[1].cpu().numpy()[0]
        answer_df_lists_gedi_rows.append(new_row)


### make the df -- currently for journal
bart_gedi_journal_test_answers = pd.DataFrame(answer_df_lists_gedi_rows)


    

In [None]:
### Save the GeDi answers (switch the variable and file name to news for the news run)
bart_gedi_journal_test_answers.to_csv(path_or_buf='{}/model_answers/bart_gedi_journal_test_answers.csv'.format(DATA_DIR))


# Combining

Combine all the answers into one df for evaluation


In [None]:

#### Keep the top 10 answers per question here; the other methods keep all of theirs

### SVM reranking: utils.get_top's default order for "journal", ascending order for "news"
answers_by_question = bart_test_answers_svm.groupby(['question'])
bart_test_answers_svm_journal = answers_by_question.apply(
    lambda g: utils.get_top(g, col='predictions', n=10)
).reset_index(drop=True)
bart_test_answers_svm_news = answers_by_question.apply(
    lambda g: utils.get_top(g, col='predictions', ascending=True, n=10)
).reset_index(drop=True)



In [None]:
### Tag every answer set with the method that produced it

### Rerankers
bart_test_answers_svm_journal['model-type'] = 'svm-rerank-journal'
bart_test_answers_svm_news['model-type'] = 'svm-rerank-news'

### GeDi (the news frame comes from the second GeDi run)
bart_gedi_news_test_answers['model-type'] = 'gedi-news'
bart_gedi_journal_test_answers['model-type'] = 'gedi-journal'

### DExperts: the article-expert run is reported as "news".
# This aliases the same DataFrame, so its 'DExpert-article' labels are overwritten below.
DExperts_news_test_answers = DExperts_article_test_answers
DExperts_journal_test_answers['model-type'] = 'DExpert-journal'
DExperts_news_test_answers['model-type'] = 'DExpert-news'

### PPLM (run outside this notebook; these frames must already be loaded)
bart_pplm_journal_test_answers['model-type'] = 'PPLM-journal'
bart_pplm_news_test_answers['model-type'] = 'PPLM-news'

In [None]:
# Columns shared by every answer set; only these are kept in the combined frame.
common_columns = ['response', 'model-type', 'question', 'first_sentence', 'category']

# NOTE(review): the BERT-reranker, journals/articles and PPLM frames are not created in the
# cells shown here and must be loaded before this cell runs.
dfs = [
    bart_test_answers_bert_journal, bart_test_answers_bert_news,
    bart_test_answers_svm_journal, bart_test_answers_svm_news,
    bart_gedi_journal_test_answers, bart_gedi_news_test_answers,
    DExperts_journal_test_answers, DExperts_news_test_answers,
    bart_journals_test_answers, bart_articles_test_answers,
    bart_pplm_journal_test_answers, bart_pplm_news_test_answers,
]

df = pd.concat([answer_set[common_columns] for answer_set in dfs])

In [None]:

# Persist the combined responses (index included) for the complexity measures below.
df.to_csv(path_or_buf='{}/model_answers/test_complexity_responses.csv'.format(DATA_DIR))

# Complexity Measures

Includes Flesch–Kincaid (FK), computed over all definitions concatenated together.

In [None]:
from readability import Readability


In [None]:
# The file was written with to_csv's default index column. Read it back as the index
# (instead of an extra "Unnamed: 0" column), then renumber, because the concatenated
# answer sets repeat index values.
df = pd.read_csv('{}/model_answers/test_complexity_responses.csv'.format(DATA_DIR), index_col=0).reset_index(drop=True)

In [None]:
# The reloaded CSV only has the common columns, so 'word_count' does not exist until
# make_jargon_features computes it (it is one of the SVM feature columns).
# Drop missing/blank responses first (empty strings are read back from the CSV as NaN).
df = df[df['response'].fillna('').astype(str).str.strip() != '']
df = jargon_utils.make_jargon_features(df, text_col='response')
df = df[df['word_count'] > 0]

## Flesch Kincaid 

In [None]:

### function for getting the overall Flesch Kincaid score by concating all definitions in a df together 
def get_overall_fk(df, def_col='response'):
    all_defs = df[def_col].str.cat(sep=' ') 
    all_defs_readability = Readability(all_defs)
    return all_defs_readability.flesch_kincaid()


df.groupby(['model-type']).apply(lambda g: get_overall_fk(g))