In [None]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments,Trainer, AutoModelForSequenceClassification
import torch
import pandas as pd
import evaluate
from collections import Counter
from peft import LoraConfig, get_peft_model, TaskType, PeftModel, PeftConfig
from transformers import pipeline
import numpy as np
from IPython.display import Image, display


In [2]:
dataset_name = 'knkarthick/dialogsum'
dataset = load_dataset(dataset_name)
path_to_data = '../data/'
split = 'test'
print(dataset[split][0]['dialogue'])
print("Summary " , dataset['train'][0]['summary'])


#Person1#: Ms. Dawson, I need you to take a dictation for me.
#Person2#: Yes, sir...
#Person1#: This should go out as an intra-office memorandum to all employees by this afternoon. Are you ready?
#Person2#: Yes, sir. Go ahead.
#Person1#: Attention all staff... Effective immediately, all office communications are restricted to email correspondence and official memos. The use of Instant Message programs by employees during working hours is strictly prohibited.
#Person2#: Sir, does this apply to intra-office communications only? Or will it also restrict external communications?
#Person1#: It should apply to all communications, not only in this office between employees, but also any outside communications.
#Person2#: But sir, many employees use Instant Messaging to communicate with their clients.
#Person1#: They will just have to change their communication methods. I don't want any - one using Instant Messaging in this office. It wastes too much time! Now, please continue with the memo. Wh

In [3]:
model_name = 'google/flan-t5-base'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=True)

# Zero Shot Inference 

In [None]:
example_indicies = [0,1]
for i,index in enumerate(example_indicies):
    dialoge = dataset[split][index]['dialogue']
    summary = dataset[split][index]['summary']
    
    prompt = f""" 
Dialogue: 
{dialoge} 
What was going on? 
"""
    inputs  = tokenizer(prompt, return_tensors='pt')
    output = tokenizer.decode(model.generate(inputs['input_ids'],max_new_tokens=50)[0],skip_special_tokens=True)
    print("\n Dialoge: \n" , prompt)
    print("\n Summary: \n" , summary)
    print("\n Output: \n"  , output)
    print("-----------------------------------------------")

# One Shot Learning

In [None]:
example_indicies = [0,1]
for i,index in enumerate(example_indicies):
    dialoge = dataset[split][index]['dialogue']
    summary = dataset[split][index]['summary']
    
    prompt = f""" 
Dialogue: 
{dialoge} 
What was going on? 
{summary}
"""
    prompt = f""" 
Dialogue: 
{dialoge} 
What was going on? 
"""
    inputs  = tokenizer(prompt, return_tensors='pt')
    output = tokenizer.decode(model.generate(inputs['input_ids'],max_new_tokens=50)[0],skip_special_tokens=True)
    print("\n Dialoge: \n" , prompt)
    print("\n Output: \n"  , output)
    print("-----------------------------------------------")

# Few shot learning

In [None]:
example_indicies = [323,9]
prompt = ''
for i,index in enumerate(example_indicies):
    dialoge = dataset[split][index]['dialogue']
    summary = dataset[split][index]['summary']
    
    prompt += f""" 
Dialogue: 
{dialoge} 
What was going on? 
{summary}
"""
    
index = 1
dialoge = dataset[split][index]['dialogue']
summary = dataset[split][index]['summary']

prompt += f""" 
Dialogue: 
{dialoge} 
What was going on? 
"""

inputs  = tokenizer(prompt, return_tensors='pt')
output = tokenizer.decode(model.generate(inputs['input_ids'],max_new_tokens=50)[0],skip_special_tokens=True,do_sample=True, tempreture=0.5)
print("\n Dialoge: \n" , prompt)
print("\n Output: \n"  , output)
print("-----------------------------------------------")

# Helper Functions

In [40]:
def create_prompt(example_indicies, target_index=1):
    prompt = ''
    
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary:'
    
    for i,index in enumerate(example_indicies):
        dialoge = dataset[split][index]['dialogue']
        summary = dataset[split][index]['summary']
        prompt += start_prompt+dialoge+end_prompt+summary
    
    index = target_index
    dialoge = dataset[split][index]['dialogue']
    ground_truth = dataset[split][index]['summary']
    prompt = start_prompt+dialoge+end_prompt
    
    return prompt, ground_truth

def gen_output(prompt, original_model, fine_tuned_model):
    inputs  = tokenizer(prompt, return_tensors='pt')
    outputO = tokenizer.decode(original_model.generate(inputs['input_ids']  )[0],skip_special_tokens=True,do_sample=True, tempreture=0.5)
    outputF = tokenizer.decode(fine_tuned_model.generate(input_ids=inputs['input_ids'])[0],skip_special_tokens=True,do_sample=True, tempreture=0.5)

    return ' '.join(outputO.split('Dialogue:')[-1:]), ' '.join(outputF.split('Dialogue:')[-1:])
    
def tokenizer_function(example):
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary:'
    prompt = [start_prompt+dialogue+end_prompt for dialogue in example['dialogue']]
    example['input_ids'] = tokenizer(prompt, padding='max_length', truncation=True, return_tensors="pt").input_ids
    example['labels'] = tokenizer(example['summary'], padding='max_length', truncation=True, return_tensors="pt").input_ids
    
    return example
    

# Fully Instruction Fine Tuning

In [None]:
tokenized_dataset = dataset.map(tokenizer_function, batched=True)
tokenized_datasets = tokenized_dataset.remove_columns(['id','topic','dialogue','summary'])
tokenized_datasets = tokenized_datasets.filter(lambda example, index : index % 2 == 0, with_indices = True)


In [None]:
# model_f = AutoModelForSeq2SeqLM.from_pretrained(path_to_data + 'Flan_T5_Full_Fine_Tuned')
model_o = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base')

training_args = TrainingArguments(output_dir = path_to_data+'summary_data/',
                                 learning_rate=1e-5, 
                                 num_train_epochs=5, 
                                 weight_decay=0.01,
                                 logging_steps=1,
                                 max_steps=25,
                                 auto_find_batch_size=True)

trainer = Trainer(model = model_o, 
                 args = training_args, 
                 train_dataset= tokenized_datasets['test'])

trainer.train()
trainer.save_model('../data/Flan_T5_Full_Fine_Tuned' )


# Evaluation 

In [None]:
model_f = AutoModelForSeq2SeqLM.from_pretrained(path_to_data + 'Flan_T5_Full_Fine_Tuned')
model_o = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base')

tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base',use_fast=True)

In [None]:
print("original" , sum([p.numel()/1e6 for p in model_o.parameters() if p.requires_grad]) , 'M trainable parameters')


In [None]:
orgs = []
fids = []
gros = []
for i in range(5,15):
    prompt,gr_ = create_prompt([1,2,3], i)
    org_ , fid_ =  gen_output(prompt, model_o, model_f)
    orgs.append(org_)
    fids.append(fid_)
    gros.append(gr_)

zipped_summaries = list(zip(gros,orgs,fids))
df = pd.DataFrame(zipped_summaries,columns=['G','O','F'])


In [None]:
rouge = evaluate.load('rouge')

original_model_results = rouge.compute(
             predictions=orgs,
             references=gros,
             use_aggregator=True,
             use_stemmer=True)

fined_model_results = rouge.compute(
             predictions=fids,
             references=gros,
             use_aggregator=True,
             use_stemmer=True)

print("Original" , original_model_results)
print("Fined" , fined_model_results)

# Evaluation Function 

In [None]:
def compare_counter(c1,c2):
    '''
    number of matches between two counters
    '''
    num_matches = 0
    for k,v in c1.items():
        if k in c2: 
            v2 = c2[k]
            num_matches += min(v2,v)
    return num_matches

def ngram(text, n):
    '''
    ngram
    '''
    ngrams = []
    for idx in range(len(text.split())-n+1):
        ngram_ = []
        for i in range(n):
            ngram_.append(text.split()[idx+i])
        ngrams.append(' '.join(ngram_))
    return ngrams

def rouge_n(gros,tars,n):
    '''
    rouge n-gram
    '''
    rouge_n_recall_o = 0
    rouge_n_precision_o = 0
    for idx in range(len(gros)):
        
        g,o = ngram(gros[idx],n),ngram(tars[idx],n)
        cg = Counter(g)
        co = Counter(o)
        n_matches_o  = compare_counter(cg,co)
        
        rouge_n_recall_o    += (n_matches_o)/len(g)
        rouge_n_precision_o += (n_matches_o)/len(o)
        
    rouge_n_f1_o        = 2*rouge_n_recall_o*rouge_n_precision_o/(rouge_n_precision_o+rouge_n_recall_o+0.0001)
        
    return rouge_n_recall_o/len(gros), rouge_n_precision_o/len(gros), rouge_n_f1_o/len(gros)


def LCS(text1, text2):
    '''
    Longest common subsequence between two list of words
    '''
    w1 = text1.split()
    w2 = text2.split()
    
    l,r = 0,0
    dp = {}
    
    def dfs(l,r):
        
        if l>=len(w1) or r >= len(w2):
            return 0
        
        if (l,r) in dp:
            return dp[(l,r)]
        
        if w1[l] == w2[r]:
            dp[(l,r)] = 1 + dfs(l+1,r+1)
        else: 
            dp[(l,r)] = max(dfs(l,r+1),dfs(l+1,r))
        
        return dp[(l,r)]
    
    return dfs(0,0)
    
def rouge_L(gros,tars):
    '''
    rouge LCS
    '''
    rougeL_f1_o = 0
    for idx in range(len(gros)):
        
        g,o = (gros[idx]),(tars[idx])
        lcs_matches  = LCS(g,o)
        
        rougeL_recall_o    = (lcs_matches)/len(g.split())
        rougeL_precision_o = (lcs_matches)/len(o.split())
        
        rougeL_f1_o        += 2*rougeL_recall_o*rougeL_precision_o/(rougeL_precision_o+rougeL_recall_o+0.0001)
        
    return rougeL_recall_o/len(gros),rougeL_precision_o/len(gros), rougeL_f1_o/len(gros)

def bleu(gros, tars, n):
    '''
    avg precision across range of ngrams
    '''
    pr_total = 0
    for i in range(1,n+1):
        _,pr_i,_ = rouge_n(gros,tars,i)
        pr_total += pr_i 
    return pr_total/n


In [None]:
# unit tests

text1 = 'it is cold outside'
text2 = 'it is too cold in the outside'

assert LCS(text1,text2) == 4 , "LCS is not correct"


text1 = ['it is dark']
text2 = ['dark is it']

assert rouge_n(text1,text2,1)[-1] >= 0.99 , "rouge_1 score don't match"
assert rouge_n(text1,text2,2)[-1] <= 0.49 , "rouge_2 score don't match"
assert rouge_L(text1,text2)[-1] <= 0.49 , "rouge_L score don't match"
assert bleu(text1,text2,2) <= 0.5 , "bleu score don't match"
assert bleu(text1,text2,2) ==  (rouge_n(text1,text2,1)[1] + rouge_n(text1,text2,2)[1])/2, "bleu is not the average of first two-gram prcisions"



text1 = ['it is cold']
text2 = ['it is very cold']

assert rouge_n(text1,text2,1)[-1] >= 0.8 , "rouge_1 score don't match"
assert rouge_n(text1,text2,2)[-1] >= 0.35 , "rouge_2 score don't match"
assert rouge_L(text1,text2)[-1] >= 0.8 , "rouge_L score don't match"
assert bleu(text1,text2,2) <= 0.6 , "bleu score don't match"
assert bleu(text1,text2,2) ==  (rouge_n(text1,text2,1)[1] + rouge_n(text1,text2,2)[1])/2, "bleu is not the average of first two-gram prcisions"



# Parameter Efficient Fine-Tuning LoRA

Low rank adaptation 

In [41]:
tokenized_dataset = dataset.map(tokenizer_function, batched=True)
tokenized_datasets = tokenized_dataset.remove_columns(['id','topic','dialogue','summary'])
tokenized_datasets = tokenized_datasets.filter(lambda example, index : index % 2 == 0, with_indices = True)


Map:   0%|          | 0/8017 [00:00<?, ? examples/s]

Map:   0%|          | 0/2005 [00:00<?, ? examples/s]

Filter:   0%|          | 0/8017 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2005 [00:00<?, ? examples/s]

In [36]:
lora_config = LoraConfig(r = 32, 
                        lora_alpha = 32, 
                        target_modules=['q','v'],
                        lora_dropout = 0.05,
                        bias = "none",
                        task_type=TaskType.SEQ_2_SEQ_LM)

peft_model = get_peft_model(model_o, lora_config)


In [80]:
print("peft" ,sum([p.numel()/1e6 for p in peft_model.parameters()]) , 'M parameters')

print("peft" ,sum([p.numel()/1e6 for p in peft_model.parameters() if p.requires_grad]) , 'M trainable parameters')


peft 251.11679999999868 M parameters
peft 3.5389440000000074 M trainable parameters


In [82]:
peft_training_args = TrainingArguments(output_dir = path_to_data+'summary_data/',
                                 learning_rate=1e-3, 
                                 num_train_epochs=10, 
                                 auto_find_batch_size=True,
                                 weight_decay=0.01,
                                 logging_steps=1,
                                 max_steps=20)

trainer = Trainer(model = peft_model, 
                 args = peft_training_args, 
                 train_dataset= tokenized_datasets['train'],
                 eval_dataset = tokenized_datasets['test'])

trainer.train()
trainer.save_model('../data/Flan_T5_Lora_Fine_Tuned' )

trainer.model.save_pretrained('../data/Flan_T5_Lora_Torch')



max_steps is given, it will override any value given in num_train_epochs
The following columns in the training set don't have a corresponding argument in `PeftModelForSeq2SeqLM.forward` and have been ignored: query. If query are not expected by `PeftModelForSeq2SeqLM.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 4009
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 20
  Number of trainable parameters = 3538944


Step,Training Loss
1,49.0441
2,47.5256
3,41.4148
4,37.0442
5,34.0453
6,31.0716
7,27.7987
8,26.5222
9,24.1546
10,22.6077




Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to ../data/Flan_T5_Lora_Fine_Tuned
Trainer.model is not a `PreTrainedModel`, only saving its state dict.


### inference Merge Peft with base

In [83]:
model_o = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base')

tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base',use_fast=True)

# Merging 
peft_model_path = '../data/Flan_T5_Lora_Torch'
peft_model = PeftModel.from_pretrained(model_o,peft_model_path, 
                                       torch_dype = torch.bfloat16,
                                       is_trainable = False)


loading configuration file config.json from cache at /Users/raminanushiravani/.cache/huggingface/hub/models--google--flan-t5-base/snapshots/7bcac572ce56db69c1ea7c8af255c5d7c9672fc2/config.json
Model config T5Config {
  "_name_or_path": "google/flan-t5-base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      

### Evaluation LoRA

In [85]:
model_o = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base')


orgs = []
pfids = []
gros = []
for i in range(5,15):
    prompt,gr_ = create_prompt([1,2,3], i)
    org_ , fid_ =  gen_output(prompt, model_o, peft_model)
    orgs.append(org_)
    pfids.append(fid_)
    gros.append(gr_)

zipped_summaries = list(zip(gros,orgs,pfids))
df = pd.DataFrame(zipped_summaries,columns=['G','O','PEF'])


loading configuration file config.json from cache at /Users/raminanushiravani/.cache/huggingface/hub/models--google--flan-t5-base/snapshots/7bcac572ce56db69c1ea7c8af255c5d7c9672fc2/config.json
Model config T5Config {
  "_name_or_path": "google/flan-t5-base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      

In [86]:
rouge = evaluate.load('rouge')

original_model_results = rouge.compute(
             predictions=orgs,
             references=gros,
             use_aggregator=True,
             use_stemmer=True)

fined_model_results = rouge.compute(
             predictions=pfids,
             references=gros,
             use_aggregator=True,
             use_stemmer=True)

print("Original" , original_model_results)
print("PEFT" , fined_model_results)

Original {'rouge1': 0.30313868934458577, 'rouge2': 0.14182656053623793, 'rougeL': 0.2807120517713355, 'rougeLsum': 0.2804762084944087}
PEFT {'rouge1': 0.2382748040217824, 'rouge2': 0.0923680693756653, 'rougeL': 0.21115572374122304, 'rougeLsum': 0.207305349349413}


# Fine Tune with RLHF

In [None]:
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

In [None]:
def build_dataset(model_name,dataset_name,minLen,maxLen):
    dataset = load_dataset(dataset_name,split='train')
    dataset = dataset.filter(lambda x: len(x['dialogue'])>minLen and len(x['dialogue'])<=maxLen)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    def tokenize(sample):
        prompt = f"""
        Summarize the following conversation.
        {sample['dialogue']}
        
        Summary:
        """
        
        sample['input_ids'] = tokenizer(prompt, return_tensors='pt').input_ids.to("mps")
        sample['query'] = tokenizer.decode(sample['input_ids'][0])
        return sample 
    
    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type='torch')
    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)
    return dataset_splits

In [31]:
dataset_name = 'knkarthick/dialogsum'
model_name = 'google/flan-t5-base'
dataset = build_dataset(model_name,dataset_name,minLen=200,maxLen=1000)

Map:   0%|          | 0/10022 [00:00<?, ? examples/s]

In [32]:

model_o = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base')

lora_config = LoraConfig(r = 32, 
                        lora_alpha = 32, 
                        target_modules=['q','v'],
                        lora_dropout = 0.05,
                        bias = "none",
                        task_type=TaskType.SEQ_2_SEQ_LM)

peft_model_path = '../data/Flan_T5_Lora_Torch'
peft_model = PeftModel.from_pretrained(model_o,
                                       peft_model_path, 
                                       lora_config = lora_config,
                                       torch_dype = torch.bfloat16,
                                       is_trainable = True)


# Proximal Policy Optimization Model 

In [None]:
display(Image(filename='../snapshots/ppo.png'))

In [None]:
# ppo model
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,
                                                               torch_dtyoe = torch.bfloat16, 
                                                               is_trainable = True
                                                              ,device_map = 'auto')
# reference model 
ref_model = create_reference_model(ppo_model)

# Later we use KL divergence to compare the output from ppo model to the reference model 

print(ppo_model.v_head)

In [None]:
toxic_model_name = 'facebook/roberta-hate-speech-dynabench-r4-target'
toxic_tokenizer = AutoTokenizer.from_pretrained(toxic_model_name,device_map = 'auto')
toxic_model = AutoModelForSequenceClassification.from_pretrained(toxic_model_name, device_map = 'auto')


In [None]:
device = 0 if torch.cuda.is_available() else 'cpu'
sentiment_pipe = pipeline('sentiment-analysis',model=toxic_model_name, device = device)

reward_logits_kwargs = {"top_k" : None, 
                        "function_to_apply": "none",
                        "batch_size" : 16}

reward_probs_kwargs = {"top_k" : None, 
                        "function_to_apply": "softmax",
                        "batch_size" : 16}

toxic_eval  = evaluate.load("toxicity", toxic_model_name, model_type='measurement',toxic_label= "hate")

def evaluate_toxic(model, toxic_eval, tokenizer, dataset, num_samples):
    max_new_tokens = 100
    toxics = []
    input_texts = []
    for i, sample in enumerate(dataset):
        input_text = sample["query"]
        if i > num_samples: 
            break
        input_ids = tokenizer(input_text, return_tensors = 'pt', padding = True).input_ids
        gen_config = GenerationConfig(max_new_tokens = max_new_tokens, 
                                     top_k=0.0,top_p=1.0, do_sample=True)
        response_token_ids = model.generate(input_ids=input_ids , 
                                           generation_config = gen_config)
        generated_text = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)
        toxic_score = toxic_eval.compute(predictions = [(input_text + " "+ generated_text)])
        toxics.extend(toxic_score['toxicity'])
        
    mean = np.mean(toxics)
    std = np.std(toxics)
    
    return mean, std
        

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map = 'auto')
mean_b_toxic, std_b_toxic = evaluate_toxic(ref_model, toxic_eval=toxic_eval, 
                                           tokenizer=tokenizer, dataset = dataset['test'],
                                           num_samples = 10)
print(f'toxic [mean,std] before detox [{mean_b_toxic},{std_b_toxic}]')



In [None]:
lr = 1.4e-5
max_ppo_epochs = 1
mini_batch_size = 4,
batch_size = 16

config = PPOConfig(
    model_name = model_name,
    learning_rate = lr,
    ppo_epochs =max_ppo_epochs, 
    mini_batch_size=mini_batch_size,
    batch_size=batch_size)

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

ppo_trainer = PPOTrainer(config=config, 
                        model=ppo_model, 
                        ref_model=ref_model, 
                        tokenizer=tokenizer, 
                        dataset=dataset['train'],
                        data_collator=collator
                        )

In [19]:
output_min_len = 100
output_max_len = 400
output_len_sampler = LengthSampler(output_min_len, output_max_len)

generation_kwargs = {
                    "min_length":5,
                    "top_k": 0.0,
                    "top_p":1.0,
                    "do_sample":True}

max_ppo_steps = 10

for step,batch in enumerate(ppo_trainer.dataloader):
    if step >= max_ppo_steps: 
        break
    prompt_tensors = batch['input_ids']
    summary_tensors = []
    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_len_sampler()
        generation_kwargs['max_new_tokens'] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        summary_tensors.append(summary.squeeze()[-max_new_tokens:])
    
    batch['response'] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]
    query_response_pair = [q+r for q,r in zip(batch['query'],batch['response'])]
    rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)
    
    reward_tensors = [torch.tensor(reward[not_hate_index]['score']) for reward in rewards]
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)
    

TypeError: Operation 'abs_out_mps()' does not support input type 'int64' in MPS backend.

In [None]:
def kl_divergence(p, q):
     return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p)))