In [None]:
!pip install datasets
!pip install transformers==4.37.0
!pip install nltk

In [None]:
import torch
import pandas as pd
import numpy as np
import gc
import nltk
import copy
nltk.download('punkt')
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from nltk.translate.bleu_score import sentence_bleu

import warnings
warnings.filterwarnings("ignore")

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the chat checkpoint in bfloat16, plus its tokenizer. The tokenizer pads on the left so that
# batched prompts all end exactly where generation starts.
model_name = "Qwen/Qwen1.5-0.5B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

In [None]:
## Device placement and optimizer
# Use the third GPU ("cuda:2") as originally configured. When that device does not exist, fall back
# to the default GPU or the CPU so the remaining cells still run.
if torch.cuda.device_count() > 2:
    device = "cuda:2"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
print(model)
# AdamW over the parameters that require gradients only.
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-6)


In [None]:
# mainmodel.to(device)

In [None]:
## Loading the dataset
# UltraChat 200k from the Hugging Face Hub; this notebook works with the 'train_gen' split as a DataFrame.
dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
train_gen_split = dataset['train_gen']
df = train_gen_split.to_pandas()
display(df)

In [None]:
## Random Sampling
# Keep a small random subset so SPIN generation stays affordable. The fixed seed makes the subset,
# and therefore every downstream result, reproducible across kernel restarts.
subset_size = 256
sampling_seed = 42

# min() guards against a split smaller than the requested subset. drop=True avoids leaking the
# old index as an extra 'index' column.
df = df.sample(n=min(subset_size, len(df)), random_state=sampling_seed).reset_index(drop=True)
df = df[['prompt', 'prompt_id', 'messages']]
display(df)

In [None]:
## Data Preprocessing

## 1) Separating prompts and responses
# Keep only conversations that have at least a user turn and a reply. messages[1] (the first
# reply) becomes the reference answer for the prompt.
has_reply = df['messages'].apply(len) >= 2
df = df.loc[has_reply, ['prompt', 'messages']].copy()
df['answer'] = df['messages'].apply(lambda conversation: conversation[1]['content'])
df = df[['prompt', 'answer']]

## 2) Sorting according to prompt length to incorporate curriculum learning (shortest first)
df = df.assign(length_col=df['prompt'].apply(len)).sort_values(by='length_col', ascending=True)
df = df[['prompt', 'answer']]
display(df)

## 3) Removing garbage prompts with very small lengths and hence insufficient context,
##    then keeping a fixed-size window of the remaining (sorted) examples
n_shortest_to_skip = 10
n_examples = 128
df = df.iloc[n_shortest_to_skip:n_shortest_to_skip + n_examples].reset_index(drop=True)
display(df)


In [None]:
## Creating the dataloader class
class Customdataset(Dataset):
    """Map-style dataset over a DataFrame whose first two columns are (prompt, response).

    Columns are read by position, not by name. Each item is a ``(prompt, response)`` tuple,
    so a DataLoader batch unpacks as ``prompts, responses``.
    """

    def __init__(self, original_dataset):
        # Rows are read lazily by position from this frame.
        self.original_dataset = original_dataset

    def __len__(self):
        return self.original_dataset.shape[0]

    def __getitem__(self, index):
        frame = self.original_dataset
        return frame.iat[index, 0], frame.iat[index, 1]

In [None]:
## Initializing the dataloader
# Batches follow the shortest-prompt-first (curriculum) order of `df`, so shuffling stays off.
batch_size = 2
d_train = Customdataset(df)
dataloader = DataLoader(d_train, shuffle=False, batch_size=batch_size)

In [None]:
# Turn off the memory-efficient and the Flash scaled-dot-product-attention kernels (in that order),
# so PyTorch chooses among its remaining SDPA backends.
for _disable_sdp_backend in (torch.backends.cuda.enable_mem_efficient_sdp,
                             torch.backends.cuda.enable_flash_sdp):
    _disable_sdp_backend(False)

In [None]:
## Defining the custom tokenizer
def tokenize_and_pad(texts, tokenizer, max_length=1024):
    """Tokenize each text separately, then left-pad all encodings to a common length.

    Args:
        texts: sequence of strings (e.g. chat-templated prompts).
        tokenizer: Hugging Face-style tokenizer. It is called once per text with
            ``return_tensors="pt"`` and must map the ``'<|endoftext|>'`` token to an id.
        max_length: per-text truncation limit in tokens (default 1024).

    Returns:
        A list with one encoding per text. Every tensor in every encoding has shape ``(1, L)``,
        where ``L`` is the length of the longest (truncated) text. ``input_ids`` are left-padded
        with the ``'<|endoftext|>'`` id; every other field (e.g. ``attention_mask``) is
        left-padded with 0. An empty ``texts`` gives an empty list.
    """
    # A single tokenization pass suffices: every text is already <= the longest one, so
    # re-tokenizing with max_length=<longest> would produce identical ids.
    encodings = [
        tokenizer(text, return_tensors="pt", padding=False, truncation=True, max_length=max_length)
        for text in texts
    ]
    if not encodings:
        return encodings

    target_length = max(encoding["input_ids"].shape[1] for encoding in encodings)
    pad_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')

    for encoding in encodings:
        for key in list(encoding.keys()):
            padding_length = target_length - encoding[key].shape[1]
            if padding_length <= 0:
                continue
            # Only token ids get the pad token; masks and other per-token fields are padded with 0.
            pad_value = pad_token_id if key == "input_ids" else 0
            padding = torch.full((encoding[key].shape[0], padding_length), pad_value, dtype=encoding[key].dtype)
            # Left padding matches the tokenizer's padding_side="left": the real text ends at the
            # last column, right where generation continues.
            encoding[key] = torch.cat([padding, encoding[key]], dim=1)

    return encodings

In [None]:
## Using BLEU Score as the evaluation metric
def calculate_bleu_score(paragraph1, paragraph2):
    """BLEU-1..4 of ``paragraph2`` (candidate) against ``paragraph1`` (reference).

    Markup/special tokens of the form ``<...>`` (e.g. ``<|im_end|>``) are removed from both texts
    before NLTK word tokenization, even when they are glued to neighbouring words.

    Returns:
        Tuple ``(bleu_1, bleu_2, bleu_3, bleu_4)``. BLEU-n uses uniform weights over the first
        n n-gram orders, and all four use NLTK smoothing method 4.
    """
    import re

    # Match a whole "<...>" span containing no whitespace or nested angle brackets. This removes
    # template markers but keeps ordinary text such as "->" or "a < b".
    angle_token = re.compile(r"<[^<>\s]*>")
    reference = nltk.word_tokenize(angle_token.sub(" ", paragraph1))
    candidate = nltk.word_tokenize(angle_token.sub(" ", paragraph2))

    smoothing = nltk.translate.bleu_score.SmoothingFunction().method4
    bleu_1 = sentence_bleu([reference], candidate, weights=(1, 0, 0, 0), smoothing_function=smoothing)
    bleu_2 = sentence_bleu([reference], candidate, weights=(1 / 2, 1 / 2, 0, 0), smoothing_function=smoothing)
    # Exact thirds: weights of 0.33 each summed to 0.99 and skewed the geometric mean.
    bleu_3 = sentence_bleu([reference], candidate, weights=(1 / 3, 1 / 3, 1 / 3, 0), smoothing_function=smoothing)
    # NLTK's default weights are uniform over 1- to 4-grams.
    bleu_4 = sentence_bleu([reference], candidate, smoothing_function=smoothing)

    return bleu_1, bleu_2, bleu_3, bleu_4

In [None]:
## Initializing the optimizer and loading the model
# device = "cuda:2"
# model.to(device)
# print(model)
# optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-6)


In [None]:
## Defining the SPIN-finetuning loss
def compute_spin_loss(model_logits_gt, opponent_logits_gt, model_logits_syn, opponent_logits_syn,
                      ground_truth_ids, synthetic_response_ids, lambda_reg=0.1,
                      gt_mask=None, syn_mask=None):
    """SPIN logistic loss over a batch of (ground-truth, synthetic) sequence pairs.

    For each sequence, r = sum_t [log p_model(token_t) - log p_opponent(token_t)]. The loss is
    mean_b log(1 + exp(-lambda_reg * (r_gt[b] - r_syn[b]))).

    Args:
        model_logits_gt, opponent_logits_gt: (batch, seq_gt, vocab) causal-LM logits computed on
            ``ground_truth_ids``.
        model_logits_syn, opponent_logits_syn: (batch, seq_syn, vocab) causal-LM logits computed on
            ``synthetic_response_ids``.
        ground_truth_ids, synthetic_response_ids: (batch, seq) integer token ids.
        lambda_reg: scale applied to the log-ratio margin.
        gt_mask, syn_mask: optional (batch, seq) 0/1 masks over the input ids. Tokens whose mask
            is 0 (e.g. padding or prompt) are excluded from the sums. Default: every token counts.

    Returns:
        Scalar tensor: the loss averaged over the batch.
    """
    def _token_log_probs(logits, ids):
        # Causal LM: the logits at position t score token t+1, so drop the last logit and the first id.
        log_probs = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
        return torch.gather(log_probs, dim=2, index=ids[:, 1:].unsqueeze(-1)).squeeze(-1)

    def _sequence_log_ratio(model_logits, opponent_logits, ids, mask):
        # [batch, seq-1] per-token log-ratio, summed to [batch].
        ratio = _token_log_probs(model_logits, ids) - _token_log_probs(opponent_logits, ids)
        if mask is not None:
            ratio = ratio.masked_fill(~mask[:, 1:].bool(), 0.0)
        return ratio.sum(dim=1)

    sum_log_prob_ratio_gt = _sequence_log_ratio(model_logits_gt, opponent_logits_gt, ground_truth_ids, gt_mask)
    sum_log_prob_ratio_syn = _sequence_log_ratio(model_logits_syn, opponent_logits_syn, synthetic_response_ids, syn_mask)

    # [batch]
    combined_loss = lambda_reg * (sum_log_prob_ratio_gt - sum_log_prob_ratio_syn)

    # log(1 + exp(-x)) == softplus(-x). softplus stays finite for large |x|, whereas the explicit
    # form overflows to inf.
    return torch.nn.functional.softplus(-combined_loss).mean()

In [None]:
## SPIN self-play fine-tuning
# Each iteration freezes a snapshot of `model` as the opponent (the previous-iteration player).
# The opponent writes the synthetic answers and supplies the reference log-probabilities.
# `model` (the main player) is trained to prefer the ground-truth answer over the opponent's own.
num_iters = 3
max_new_tokens = 2048   # generation budget for each synthetic answer
max_seq_len = 4096      # truncation limit for one scored prompt + answer conversation
system_prompt = "You are a helpful assistant."


def _conversation_ids(prompt, answer):
    """Token ids, shape (1, L) on `device`, of the chat-templated system/user/assistant conversation."""
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": answer},
    ]
    text = tokenizer.apply_chat_template(conversation, tokenize=False)
    encoding = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_seq_len)
    return encoding["input_ids"].to(device)


for iteration in range(num_iters):
    print("Training Epoch" + str(iteration + 1) + "/" + str(num_iters))
    opponent = copy.deepcopy(model).eval()
    opponent.requires_grad_(False)
    model.train()
    total_loss = 0
    losses = []

    for step, (prompts, ground_truth) in enumerate(dataloader):
        print("Step No " + str(step))

        # 1) The opponent answers the prompts. Prompts are left-padded, so the attention mask is
        #    passed to generate.
        messages = [[{"role": "system", "content": system_prompt},
                     {"role": "user", "content": prompt}] for prompt in prompts]
        text = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages]
        tokenized_batches = tokenize_and_pad(text, tokenizer)
        prompt_ids = torch.cat([x['input_ids'] for x in tokenized_batches], dim=0).to(device)
        prompt_attention_mask = torch.cat([x['attention_mask'] for x in tokenized_batches], dim=0).to(device)
        with torch.no_grad():
            synthetic_response = opponent.generate(
                input_ids=prompt_ids, attention_mask=prompt_attention_mask, max_new_tokens=max_new_tokens
            )
        # All rows share the padded prompt length, so each row's new tokens start at the same column.
        synthetic_answers = tokenizer.batch_decode(
            synthetic_response[:, prompt_ids.shape[1]:], skip_special_tokens=True
        )

        # 2) Score each prompt + answer conversation unpadded, one sample at a time. Under causal
        #    attention, the shared prompt prefix (when it tokenizes identically) contributes the same
        #    log-ratio to the ground-truth and synthetic sums and cancels. No padding is ever scored.
        optimizer.zero_grad()
        batch_loss = 0.0
        for prompt, gt_answer, syn_answer in zip(prompts, ground_truth, synthetic_answers):
            ground_truth_ids = _conversation_ids(prompt, gt_answer)
            synthetic_response_ids = _conversation_ids(prompt, syn_answer)
            with torch.no_grad():
                opponent_logits_gt = opponent(input_ids=ground_truth_ids).logits
                opponent_logits_syn = opponent(input_ids=synthetic_response_ids).logits
            main_player_logits_gt = model(input_ids=ground_truth_ids).logits
            main_player_logits_syn = model(input_ids=synthetic_response_ids).logits

            loss = compute_spin_loss(
                main_player_logits_gt, opponent_logits_gt,
                main_player_logits_syn, opponent_logits_syn,
                ground_truth_ids, synthetic_response_ids, lambda_reg=0.1,
            ) / len(prompts)
            # Backward per sample: only one sample's graph is alive at a time, and the gradients
            # accumulate into the batch mean.
            loss.backward()
            batch_loss += loss.item()
        optimizer.step()

        total_loss += batch_loss
        losses.append(batch_loss)

    average_loss = total_loss / len(dataloader)
    print(f"Iteration {iteration + 1}/{num_iters}, Average Loss: {average_loss}")

    # Release this iteration's snapshot before the next one is taken.
    del opponent
    gc.collect()
    torch.cuda.empty_cache()







In [None]:
# Held-out evaluation data, read from a CSV file in the working directory.
test_csv_path = 'test.csv'
df_test = pd.read_csv(test_csv_path)

In [None]:
## Test dataloader
batch_size = 4
# Evaluate on the held-out frame read from test.csv; this cell previously wrapped the training frame `df`.
# Customdataset reads columns by position, so restrict to (prompt, answer) when test.csv has them.
# NOTE(review): otherwise this assumes test.csv's first two columns are prompt and answer — confirm.
eval_columns = ['prompt', 'answer']
df_test_eval = df_test[eval_columns] if set(eval_columns).issubset(df_test.columns) else df_test
d_test = Customdataset(df_test_eval)
test_dataloader = DataLoader(d_test, batch_size=batch_size, shuffle=False)

In [None]:
## Evaluation: answer every test prompt and score it against its reference with BLEU
score_list = []
model.eval()

for step, (prompts, ground_truth) in enumerate(test_dataloader):
    print("Step No " + str(step))
    messages = [[{"role": "system", "content": "You are a helpful assistant."},
                 {"role": "user", "content": prompt}] for prompt in prompts]
    text = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages]
    tokenized_batches = tokenize_and_pad(text, tokenizer)
    prompt_ids = torch.cat([x['input_ids'] for x in tokenized_batches], dim=0).to(device)
    prompt_attention_mask = torch.cat([x['attention_mask'] for x in tokenized_batches], dim=0).to(device)

    with torch.no_grad():
        # Prompts are left-padded, so generate needs the attention mask.
        synthetic_response = model.generate(
            input_ids=prompt_ids, attention_mask=prompt_attention_mask, max_new_tokens=2048
        )
    # All rows share the padded prompt length, so each row's generated tokens start at the same
    # column. This works for any batch size, including a short final batch.
    synthetic_response_ids = synthetic_response[:, prompt_ids.shape[1]:]
    generated_answers = tokenizer.batch_decode(synthetic_response_ids, skip_special_tokens=True)
    print(generated_answers[0])

    # The raw ground-truth answer is the BLEU reference and the generated answer is the candidate.
    for reference, candidate in zip(ground_truth, generated_answers):
        score_list.append(calculate_bleu_score(reference, candidate))

    torch.cuda.empty_cache()
    gc.collect()

avg_bleu_score = sum(score[0] for score in score_list) / len(score_list) if score_list else float("nan")
print(f"Average BLEU-1 Score is {avg_bleu_score}")
                

In [None]:
# Show the full decoded sequence (prompt plus generated continuation) of row 1 of the last `synthetic_response` batch.
sample_index = 1
print(tokenizer.decode(synthetic_response[sample_index]))