In [None]:
import torch
import json
import matplotlib.pyplot as plt

from transformers import OPTForCausalLM, AutoTokenizer
from datasets import load_dataset



In [None]:
def json_pretty(json_string):
    print(json.dumps(json_string,indent=4))
    return None

def label_to_text(label):
    return ["entailment", "neutral", "contradiction"][label]

def teacher_create_prompt(premise, hypothesis, label=None):
    prompt = f"\nPremise: {premise}, \nHypothesis: {hypothesis}"
    if label is not None:
        prompt += f",\nLabel: {label_to_text(label)}"
    if label is None:
        prompt += f",\nLabel:"
    return prompt

def student_create_prompt(premise, hypothesis, label=None):
    prompt = f"Label if this is entailment or contradiction.\nPremise: {premise},\nHypothesis: {hypothesis},\nLabel:"
    return prompt

def create_extended_prompt(context_examples, query_example):
    context_prompt = ""
    for example in context_examples:
        premise = example['premise']
        hypothesis = example['hypothesis']
        label = example['label']
        context_prompt += teacher_create_prompt(premise, hypothesis, label) + "\n"
    
    query_premise = query_example['premise']
    query_hypothesis = query_example['hypothesis']
    query_prompt = teacher_create_prompt(query_premise, query_hypothesis) 
    return context_prompt + query_prompt

### Tokenizer + Models

In [None]:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") # tokenizer
teacher_model = OPTForCausalLM.from_pretrained("facebook/opt-350m") # teacher model
student_model = OPTForCausalLM.from_pretrained("facebook/opt-350m") # student model

### Datasets 

In [None]:
dataset = load_dataset("nyu-mll/glue", "mnli", split='train')

### Data Example

In [None]:
context_examples = [{key: value[i] for key, value in dataset[5:7].items()} for i in range(2)]
query_example = {key: value for key, value in dataset[7].items()}

In [None]:
json_pretty(context_examples)

In [None]:
json_pretty(query_example)

### Prompt Example

In [None]:
teacher_prompt = create_extended_prompt(context_examples, query_example)
student_prompt = student_create_prompt(query_example['premise'], query_example['hypothesis'])

teacher_inputs = tokenizer(teacher_prompt, return_tensors="pt")
student_inputs = tokenizer(student_prompt, return_tensors="pt")

print(teacher_prompt)

In [None]:
teacher_outputs_1 = teacher_model.generate(**teacher_inputs, max_length=teacher_inputs['input_ids'].shape[-1] + 1, output_scores=True, return_dict_in_generate=True)
teacher_probs_1 = torch.nn.functional.softmax(teacher_outputs_1.scores[0], dim=-1)
print(tokenizer.decode(teacher_outputs_1[0][0][-1], skip_special_tokens=True))
print(teacher_outputs_1[0].shape, teacher_outputs_1.scores[0].shape)

argmax_index = torch.argmax(teacher_outputs_1.scores[0])
print(teacher_probs_1.shape)


### Student model training

In [None]:
student_model.train()

In [None]:
student_outputs = student_model.generate(**student_inputs, max_length=student_inputs['input_ids'].shape[-1] + 2)
s_output_text = tokenizer.decode(student_outputs[0], skip_special_tokens=True)
predicted_label = s_output_text.split("Label:")[-1].strip().split('.')[0].strip()
print(s_output_text)

In [None]:
# student_outputs = student_model.generate(**student_inputs, max_length=student_inputs['input_ids'].shape[-1] + 2)
# s_output_text = tokenizer.decode(student_outputs[0], skip_special_tokens=True)
# student_predicted_label = s_output_text.split("Label:")[-1].strip().split('.')[0].strip()
# print(s_output_text)

# student_outputs = student_model.generate(**student_inputs, max_length=student_inputs['input_ids'].shape[-1] + 1, output_scores=True, return_dict_in_generate=True)
# student_probs_1 = torch.nn.functional.softmax(student_outputs.scores[0], dim=-1)

# print(tokenizer.decode(student_outputs[0][0][-1], skip_special_tokens=True))
# print(student_outputs[0].shape, student_outputs.scores[0].shape)

# s_argmax_index = torch.argmax(student_outputs.scores[0])
# print(student_probs_1.shape, s_argmax_index)
# student_probs_1.requires_grad

student_logits = student_model(**student_inputs).logits 
student_probs_1 = torch.nn.functional.softmax(student_logits[:,-1,:], dim=-1)
student_probs_1.shape

In [None]:
optimizer = torch.optim.AdamW(student_model.parameters(), lr=0.001)

kl_divergence = torch.nn.functional.kl_div(student_probs_1.log(), teacher_probs_1, reduction='batchmean')

print(kl_divergence)
kl_divergence.requires_grad

In [None]:
total_loss = 0

optimizer.zero_grad()
kl_divergence.backward()
optimizer.step()

total_loss += kl_divergence.item()

print(f"1, Average Loss: {total_loss/len(query_example)}")

In [None]:
student_outputs = student_model.generate(**student_inputs, max_length=student_inputs['input_ids'].shape[-1] + 1)
s_output_text = tokenizer.decode(student_outputs[0], skip_special_tokens=True)
predicted_label = s_output_text.split("Label:")[-1].strip().split('.')[0].strip()
print(s_output_text)
print("Label only:" , predicted_label)

In [None]:
def create_batches(dataset, context_size=2, batch_size=3):
    batch = []
    for index in range(0, len(dataset), context_size + 1):
        if index + batch_size > len(dataset):
            break  # Avoid going out of index range, discard last incomplete batch
        
        # Create batch: 2 context + 1 query
        context = dataset[index:index + context_size]
        query = dataset[index + context_size]

        # Collect and format data for the batch
        batch_data = {
            'context': [{key: value[i] for key, value in context.items()} for i in range(2)],
            'query': {key: value for key, value in query.items()}
        }
        batch.append(batch_data)
    
    return batch

# Subset of only 300 rows.
subset_dataset = dataset.select(range(300))
batches = create_batches(subset_dataset)

In [None]:
# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Assuming `losses` is the list of epoch losses returned from the `train` function
def plot_losses(losses):
    plt.figure(figsize=(10, 5))
    plt.plot(losses, marker='o', linestyle='-', color='b')
    plt.title('Training Loss Per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.grid(True)
    plt.show()


In [None]:
import torch

def train(data, epochs=10, batch_size=16, device='cpu'):
    student_model.to(device)
    teacher_model.to(device)
    
    student_model.train()
    optimizer = torch.optim.AdamW(student_model.parameters(), lr=0.001)
    epoch_losses = []


    for epoch in range(epochs):
        total_loss = 0
        num_batches = len(data) // batch_size 
        
        for i in range(num_batches):
            batch_loss = 0
            for j in range(batch_size):
                index = i * batch_size + j
                teacher_prompt = create_extended_prompt(data[index]['context'], data[index]['query'])
                student_prompt = student_create_prompt(data[index]['query']['premise'], data[index]['query']['hypothesis'])
                
                teacher_inputs = tokenizer(teacher_prompt, return_tensors="pt").to(device)
                student_inputs = tokenizer(student_prompt, return_tensors="pt").to(device)
                
                teacher_outputs = teacher_model.generate(**teacher_inputs, max_length=teacher_inputs['input_ids'].shape[-1] + 1, output_scores=True, return_dict_in_generate=True)
                teacher_probs = torch.nn.functional.softmax(teacher_outputs.scores[0], dim=-1)
                
                student_logits = student_model(**student_inputs).logits
                student_probs = torch.nn.functional.softmax(student_logits[:, -1, :], dim=-1)
                
                kl_divergence = torch.nn.functional.kl_div(student_probs.log(), teacher_probs, reduction='batchmean')
                
                optimizer.zero_grad()
                kl_divergence.backward()
                optimizer.step()

                batch_loss += kl_divergence.item()

            # Average loss for the batch
            batch_loss /= batch_size
            total_loss += batch_loss
            # print(f"Epoch {epoch + 1}, Batch {i + 1}, Average Loss: {batch_loss}")


        # Average loss for the epoch
        epoch_loss = total_loss / num_batches
        epoch_losses.append(epoch_loss)

        print(f"Epoch {epoch + 1}, Total Loss: {epoch_loss}")
        
    print(f"Total loss : {total_loss/epochs}")
    plot_losses(epoch_losses)
    torch.save(student_model, 'student.pth')


train(batches, device=device)


In [None]:
student_logits = student_logits.to(device)
student_inputs = student_inputs.to(device)
student_outputs = student_model.generate(**student_inputs, max_length=student_inputs['input_ids'].shape[-1] + 1).to(device)
s_output_text = tokenizer.decode(student_outputs[0], skip_special_tokens=True)
predicted_label = s_output_text.split("Label:")[-1].strip().split('.')[0].strip()
print(s_output_text)
print("Label only:" , predicted_label)

In [None]:
validation_dataset = load_dataset("nyu-mll/glue", "mnli", split='validation_matched')

In [None]:
def create_student_token(data, tokenizer, device = 'cpu'):
    tokens = []
    for i in range(len(data)):
        prompt = student_create_prompt(data[i]['premise'], data[i]['hypothesis'])
        tokens.append(tokenizer(prompt, return_tensors="pt").to(device))
    return tokens

tokens = create_student_token(validation_dataset, tokenizer, device)

In [40]:
def predict(model, source, target=None, device='cpu'):
    predict = []
    for token in source:
        output = model.generate(**token, max_length=token['input_ids'].shape[-1] + 1).to(device)
        decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
        predicted_label = decoded_output.split("Label:")[-1].strip().split('.')[0].strip()
        predict.append(predicted_label)
        print(predicted_label)
    return predict
   
predict(student_model, tokens, device=device)  


contradiction
neutral
neutral
contradiction
contradiction
contradiction
neutral
neutral
neutral
contradiction
contradiction
neutral
neutral
neutral
neutral
neutral
neutral
neutral
neutral
neutral
contradiction
neutral
neutral
neutral
neutral
neutral
neutral
neutral
neutral
neutral
contradiction
neutral
contradiction
neutral
contradiction
neutral
neutral
neutral
contradiction
neutral
neutral
neutral
contradiction
contradiction
contradiction
neutral
neutral
neutral
contradiction
contradiction
neutral
neutral
neutral
neutral
contradiction
contradiction
neutral
neutral
neutral
neutral
neutral
neutral
neutral
neutral
contradiction
neutral
neutral
neutral
neutral
contradiction
neutral
neutral
neutral
neutral
contradiction
neutral
neutral
neutral
neutral
neutral
neutral
neutral
contradiction
neutral
neutral
neutral
neutral
neutral
contradiction
contradiction
contradiction
neutral
neutral
neutral
neutral
neutral
neutral
neutral
neutral
neutral
neutral
neutral
neutral
contradiction
neutral
cont