<a href="https://colab.research.google.com/github/suryansh29/DL-project/blob/main/Notebook2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers
!pip install datasets
!pip install trl
!pip install evaluate



In [None]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login; the zero-shot evaluation cell below sets
# push_to_hub=True, which presumably needs an authenticated session.
notebook_login()

In [None]:
from datasets import load_dataset, ClassLabel, concatenate_datasets
import torch

# MultiNLI from the Hub; the train, validation_matched and validation_mismatched
# splits are all used further down.
mnli = load_dataset("nyu-mll/multi_nli")

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, OPTForSequenceClassification, TrainingArguments, Trainer

# Tokenizer used by every preprocessing function in this notebook.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# NOTE(review): this `model` is replaced by `model = classification_model` in the next
# cell before anything uses it, so loading it here only costs time and memory.
model = AutoModelForSequenceClassification.from_pretrained("facebook/opt-350m", num_labels=len(tokenizer))


In [None]:
from transformers import AutoTokenizer, OPTForSequenceClassification, OPTForCausalLM, TrainingArguments, Trainer

auto_tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# One "class" per vocabulary entry, so the classifier output can act as a next-token
# distribution. 50272 is presumably OPT-350m's embedding/vocab size (larger than
# len(tokenizer)); confirm against model.config.vocab_size.
classification_model = OPTForSequenceClassification.from_pretrained("facebook/opt-350m", num_labels=50272)
# model = AutoModelForSequenceClassification.from_pretrained("facebook/opt-350m", num_labels=2)

# Loaded only to borrow its pretrained language-modelling head.
causal_model = OPTForCausalLM.from_pretrained("facebook/opt-350m")

# Use the pretrained LM head as the classification head (the same module object is
# shared), so argmax over the "class" logits is the predicted next token id; the dataset
# labels are the token ids of ' Yes' / ' No'.
classification_model.score = causal_model.lm_head
model = classification_model

In [None]:
# Prompt template "<premise> <hypothesis> ?", followed by TARGET_PREFIX; the model's
# prediction at the end of the prompt is compared with the ' Yes' / ' No' token ids.
# NOTE(review): the target tokens ('ĠYes' / 'ĠNo') already carry their leading space, so
# this trailing " " likely tokenises as a separate space token, and the model is then
# asked to predict a second space-prefixed word; worth trying TARGET_PREFIX="".
PATTERN="{text1} {text2} ?"
TARGET_PREFIX=" "

In [None]:
# Verbaliser: label index 0 -> id of 'ĠYes', label index 1 -> id of 'ĠNo' (after the MNLI
# remap below, 1 is contradiction). 'Ġ' is the byte-level BPE marker for a leading space.
target_tokens_ids = tokenizer.convert_tokens_to_ids(['ĠYes', 'ĠNo'])


In [None]:
from datasets import ClassLabel

# Keep a two-way task: drop label 1 and the -1 placeholder label (presumably neutral and
# unlabelled examples; confirm against the dataset card).
mnli = mnli.filter(lambda example: example["label"] != 1 and example["label"] != -1)

def change_label(example):
    """Collapse the label space by mapping contradiction (label 2) onto label 1.

    The example dict is updated in place and returned; any other label value is
    left as it was.
    """
    if example["label"] == 2:
        example["label"] = 1
    return example
# Labels are now 0 (entailment-side, ' Yes') and 1 (contradiction, ' No').
mnli = mnli.map(change_label)





In [None]:
mnli  # NOTE(review): not the cell's last expression, so nothing is displayed; likely a leftover
# Widen the label feature so it can hold vocabulary token ids: the preprocessing
# functions below replace the 0/1 labels with the ids of ' Yes' / ' No'.
features = mnli["train"].features.copy()

# features["label"] = ClassLabel(num_classes=len(tokenizer))
# 50272 matches num_labels of the vocabulary-sized classification head.
features["label"] = ClassLabel(num_classes=50272)

mnli = mnli.cast(features)  # overwrite old features

In [None]:
def preprocess_function_without_context(examples):
  """Build zero-shot prompts for a batch of NLI pairs and tokenize them.

  Every prompt is PATTERN filled with (premise, hypothesis) plus TARGET_PREFIX. Each
  example's label index is swapped for the matching verbaliser token id from
  ``target_tokens_ids``. The token strings and decoded text of every encoded prompt
  are kept alongside the tokenizer outputs for inspection.
  """
  hypotheses = examples['hypothesis']
  prompts = []
  for position, premise in enumerate(examples['premise']):
    filled = PATTERN.format(text1=premise, text2=hypotheses[position])
    prompts.append(filled + TARGET_PREFIX)

  encoded = tokenizer(prompts, padding="max_length", truncation=True)

  encoded_ids = encoded["input_ids"]
  # Human-readable views of what the model actually receives.
  encoded["input_tokens"] = [tokenizer.convert_ids_to_tokens(ids) for ids in encoded_ids]
  encoded["input_text"] = [tokenizer.decode(ids) for ids in encoded_ids]

  encoded["label"] = [target_tokens_ids[label_index] for label_index in examples["label"]]

  return encoded


In [None]:
# Zero-shot prompts for every MNLI split; labels become ' Yes' / ' No' token ids.
tokenized_mnli_without_context = mnli.map(preprocess_function_without_context, batched=True)

In [None]:
import numpy as np

def create_few_shot_examples(examples, num_shots=3, exclude_idx=None, rng=None):
  """Build an in-context prefix of solved demonstrations sampled from the current batch.

  Args:
    examples: batched columns 'premise', 'hypothesis' and 'label', where 'label' holds
      verbaliser token ids (as prepared by preprocess_function_with_few_shot_context).
    num_shots: number of demonstrations. Capped at the number of available candidates,
      so a small final ``map`` batch no longer makes the sampling raise.
    exclude_idx: batch index of the query being prompted. It is never used as a
      demonstration; otherwise the prompt could contain the query's own answer.
    rng: optional object with a numpy-style ``choice`` method (for example
      ``np.random.default_rng(seed)``) for reproducible prompts. Defaults to the global
      ``np.random`` state, as before.

  Returns:
    The demonstrations concatenated, each as ``<pattern><TARGET_PREFIX><answer>``
    followed by a blank line.
  """
  rng = np.random if rng is None else rng
  separate_shots_by = "\n\n"
  batch_size = len(examples['premise'])
  pool_size = batch_size if exclude_idx is None else batch_size - 1
  indices = rng.choice(pool_size, size=min(num_shots, pool_size), replace=False)
  if exclude_idx is not None:
    # Sampled from range(batch_size - 1): shift indices at or past the query up by one,
    # so the query is skipped and all other examples stay equally likely.
    indices = [idx + 1 if idx >= exclude_idx else idx for idx in indices]

  context = ""
  for idx in indices:
    formated_sample = PATTERN.format(
        text1=examples['premise'][idx],
        text2=examples['hypothesis'][idx]
    )
    # decode() yields the answer's surface text (' Yes'). convert_ids_to_tokens returned the
    # raw byte-level BPE symbol ('ĠYes'), which put a literal 'Ġ' into the prompt text.
    # strip() drops the leading space because TARGET_PREFIX already separates the answer.
    verbalized_label = tokenizer.decode([examples['label'][idx]]).strip()
    context += f"{formated_sample}{TARGET_PREFIX}{verbalized_label}{separate_shots_by}"
  return context



def preprocess_function_with_few_shot_context(examples):
  """Build 3-shot prompts for a batch of NLI pairs and tokenize them.

  Each prompt is three demonstrations drawn from the same batch (never the query
  itself), then PATTERN filled with the query pair, then TARGET_PREFIX. Label indices
  are swapped for the matching verbaliser token ids from ``target_tokens_ids``. The token
  strings and decoded text of each encoded prompt are kept for inspection.
  """
  # Map label indices to token ids without mutating the incoming batch.
  label_ids = [target_tokens_ids[label_index] for label_index in examples["label"]]
  demo_source = {
      'premise': examples['premise'],
      'hypothesis': examples['hypothesis'],
      'label': label_ids,
  }
  pattern_examples = [
          create_few_shot_examples(demo_source, 3, exclude_idx=idx) +
          PATTERN.format(
              text1=examples['premise'][idx],
              text2=examples['hypothesis'][idx]) + TARGET_PREFIX
          for idx in range(len(examples['premise']))
      ]
  result = tokenizer(pattern_examples, padding="max_length", truncation=True)

  # Get tokens
  result["input_tokens"] = [tokenizer.convert_ids_to_tokens(
      ids) for ids in result["input_ids"]]

  # Decode input
  result["input_text"] = [tokenizer.decode(
      ids) for ids in result["input_ids"]]

  result["label"] = label_ids

  return result


In [None]:
# 3-shot prompts for each MNLI split; demonstrations are drawn from the same map batch as
# each query. The train split supplies the teacher prompts for context distillation below.
tokenized_mnli_validation_matched_with_context = mnli['validation_matched'].map(preprocess_function_with_few_shot_context, batched=True)
tokenized_mnli_validation_mismatched_with_context = mnli['validation_mismatched'].map(preprocess_function_with_few_shot_context, batched=True)
tokenized_mnli_training_with_context = mnli['train'].map(preprocess_function_with_few_shot_context, batched=True)

In [None]:
# NOTE(review): the result is not assigned. This recomputes what
# tokenized_mnli_without_context["train"] already holds and only displays it.
mnli['train'].map(preprocess_function_without_context, batched=True)

In [None]:
import evaluate

# Accuracy metric used by compute_metrics.
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred, metric=None):
  """Accuracy of the arg-max "class" (the predicted token id) against the label token ids.

  Args:
    eval_pred: ``(predictions, labels)`` as passed by ``Trainer``. Predictions are the
      logits (class axis last), or a tuple whose first element is those logits.
    metric: object exposing ``compute(predictions=..., references=...)``. Defaults to the
      notebook's ``accuracy`` metric.

  Returns:
    The dict produced by ``metric.compute``, e.g. ``{"accuracy": 0.83}``.
  """
  predictions, labels = eval_pred
  if metric is None:
    metric = accuracy
  # Some models return extra outputs next to the logits; keep only the logits.
  if isinstance(predictions, tuple):
    predictions = predictions[0]
  # Class/vocabulary axis is last; for 2-D logits this is identical to axis=1.
  predictions = np.argmax(predictions, axis=-1)
  return metric.compute(predictions=predictions, references=labels)

In [None]:
# HANS, the out-of-domain NLI set; its 'train' and 'validation' splits are used below.
hans = load_dataset("hans")


In [None]:
 # change features to reflect the new labels
features = hans["train"].features.copy()
features["label"] = ClassLabel(num_classes=len(tokenizer))
hans = hans.cast(features)  # overwrite old features

In [None]:
# Zero-shot prompts for every HANS split; labels become ' Yes' / ' No' token ids.
tokenized_hans_without_context = hans.map(preprocess_function_without_context, batched=True)

In [None]:
from datasets import concatenate_datasets

# In-domain evaluation set: matched + mismatched MNLI validation, zero-shot prompts.
validation_matched_dataset_without_context = tokenized_mnli_without_context["validation_matched"]
validation_mismatched_dataset_without_context = tokenized_mnli_without_context["validation_mismatched"]

# Concatenate the two datasets
combined_validation_dataset_without_context = concatenate_datasets([validation_matched_dataset_without_context, validation_mismatched_dataset_without_context])

# Now combined_validation_dataset contains both validation_matched and validation_mismatched datasets

# Optional shuffle. No seed is given, so the order differs between runs; accuracy does
# not depend on example order.
combined_validation_dataset_without_context = combined_validation_dataset_without_context.shuffle()


# Same combination for the 3-shot (with-context) prompts.
combined_validation_dataset_with_context = concatenate_datasets([tokenized_mnli_validation_matched_with_context, tokenized_mnli_validation_mismatched_with_context])

# Optional, unseeded shuffle (see above).
combined_validation_dataset_with_context = combined_validation_dataset_with_context.shuffle()

**Zero-Shot Inference: In-Domain (MNLI)**

In [None]:
# Zero-shot, in-domain: evaluate the LM-head classifier on the combined MNLI validation
# set without any training. Only trainer.evaluate() is called, so the training-related
# arguments (learning rate, epochs, save / best-model settings) have no effect here.
# NOTE(review): newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy`; confirm against the installed version.
training_args = TrainingArguments(
    output_dir="zero_shot_mnli_validation",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    eval_dataset=combined_validation_dataset_without_context,
    tokenizer=tokenizer,
    # data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.evaluate()

In [None]:
# Debug: show which vocabulary token id 23248 is (presumably a prediction seen in the eval output).
tokenizer.convert_ids_to_tokens([23248])

**Zero-Shot Inference: Out-of-Domain (HANS)**

In [None]:
# Zero-shot, out-of-domain: evaluate on the tokenized HANS validation split. The raw
# hans['validation'] has only text columns and 0/1 labels, so the model got no input_ids
# and the labels were not the ' Yes' / ' No' token ids that compute_metrics compares to.
trainer = Trainer(
    model=model,
    eval_dataset=tokenized_hans_without_context['validation'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.evaluate()

**Few-Shot Inference (ICL): Out-of-Domain (HANS)**



In [None]:
# 3-shot prompts for HANS validation; demonstrations come from the same HANS map batch.
tokenized_hans_validation_with_context = hans['validation'].map(preprocess_function_with_few_shot_context, batched=True)



# Few-shot (in-context learning), out-of-domain evaluation; no training involved.
trainer = Trainer(
    model=model,
    eval_dataset=tokenized_hans_validation_with_context,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.evaluate()

**Few-Shot Inference (ICL): In-Domain (MNLI)**


In [None]:
# Few-shot (in-context learning), in-domain evaluation on the combined MNLI validation set.
trainer = Trainer(
    model=model,
    eval_dataset=combined_validation_dataset_with_context,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.evaluate()

**Fine-tuning**

In [None]:
model_to_ft = OPTForSequenceClassification.from_pretrained("facebook/opt-350m", num_labels=50272)
# Start the student from the same vocabulary head as the evaluated `model`, whose `score`
# is the pretrained LM head. The `score` loaded above is presumably newly initialised,
# because the checkpoint is a causal-LM checkpoint. Copy the weights rather than share
# the module, so training the student cannot change `model` / `causal_model`.
with torch.no_grad():
    model_to_ft.score.weight.copy_(causal_model.lm_head.weight)

**KL-divergence-based context distillation**

In [None]:
def calculate_kl_divergence(logits_with_context, logits_to_align, top_k=50):
    """KL loss that pulls the no-context (student) model towards the with-context teacher.

    Computes KL(teacher || student), restricted to the teacher's ``top_k`` most likely
    vocabulary entries at every position. Both distributions are renormalised over that
    subset, so the value is a proper, non-negative divergence.

    Args:
        logits_with_context: teacher logits, shape (..., vocab_size). Treated as a fixed
            target (detached, so no gradient flows into it).
        logits_to_align: student logits with the same shape.
        top_k: number of teacher-ranked entries to compare (capped at vocab_size).

    Returns:
        Scalar tensor: the divergence summed over the selected entries (and any middle
        dimensions) and divided by the size of the first (batch) dimension.
    """
    import torch
    import torch.nn.functional as F  # was used without ever being imported

    teacher_logits = logits_with_context.detach()
    # topk on the logits gives the same ranking as topk on their log-softmax.
    top_k = min(top_k, teacher_logits.size(-1))
    _, top_indices = torch.topk(teacher_logits, k=top_k, dim=-1)

    # Renormalise over the selected subset. Gathering full-vocabulary log-probs gave partial
    # distributions, for which the "KL" could become negative.
    teacher_log_probs = F.log_softmax(torch.gather(teacher_logits, -1, top_indices), dim=-1)
    student_log_probs = F.log_softmax(torch.gather(logits_to_align, -1, top_indices), dim=-1)

    # F.kl_div(input, target) computes KL(target || input), where `input` is the log-probs of
    # the distribution being optimised (the student). It has no `dim` argument; passing one
    # raised a TypeError.
    return F.kl_div(student_log_probs, teacher_log_probs, reduction='batchmean', log_target=True)

In [None]:
max_length = 180

# Pad (on the right, with the tokenizer's pad id) or truncate every few-shot prompt to
# max_length tokens so the prompts can be stacked into one tensor. Truncation keeps the
# LAST max_length tokens, because the prompt ends with the query pair whose answer the
# teacher must predict. Cutting from the end (ids[:max_length]) dropped that query
# whenever the demonstrations alone filled the budget.
input_ids = [
    ids + [tokenizer.pad_token_id] * (max_length - len(ids)) if len(ids) < max_length else ids[-max_length:]
    for ids in tokenized_mnli_training_with_context['input_ids']
]


In [None]:
# Teacher logits for the with-context prompts, computed once before distillation:
# - no_grad: the teacher is not trained, and keeping the autograd graph for every
#   prompt would exhaust memory;
# - mini-batches of TEACHER_BATCH_SIZE instead of one forward pass over the whole split;
# - inputs are created on the model's device (the Trainer runs above may have moved
#   `model` to the GPU) and results are moved back to the CPU.
# NOTE(review): this still stores (#training prompts x 50272) floats, likely tens of GB
# for the full train split; consider a subset or keeping only the top-k logits.
TEACHER_BATCH_SIZE = 16
model.eval()
with torch.no_grad():
    logits_with_context = torch.cat([
        model(input_ids=torch.tensor(input_ids[start:start + TEACHER_BATCH_SIZE], device=model.device)).logits.cpu()
        for start in range(0, len(input_ids), TEACHER_BATCH_SIZE)
    ])

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, OPTForSequenceClassification, TrainingArguments, Trainer

class CustomTrainer(Trainer):
  """Trainer whose loss is the KL divergence from the teacher (with-context) logits.

  The teacher logits come from the notebook-global ``logits_with_context``. The model
  being trained sees only the batch's ``input_ids`` / ``attention_mask``, with no
  few-shot context.

  NOTE(review): ``logits_with_context`` has one row per training prompt, while
  ``compute_loss`` receives one (shuffled) batch at a time, so the rows are not aligned.
  A shape mismatch is reported explicitly. A real fix needs the teacher logits (or the
  with-context input ids) to travel with each batch.
  """

  def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    # **kwargs absorbs extra arguments newer Trainer versions pass (e.g. num_items_in_batch).
    outputs = model(input_ids=inputs['input_ids'], attention_mask=inputs.get('attention_mask'))
    # The model returns a ModelOutput; the KL helper needs the logits tensor itself.
    logits_to_align = outputs.logits
    teacher_logits = logits_with_context
    if teacher_logits.shape != logits_to_align.shape:
      raise ValueError(
          f"teacher logits {tuple(teacher_logits.shape)} do not match batch logits "
          f"{tuple(logits_to_align.shape)}; teacher logits must be provided per batch")
    loss = calculate_kl_divergence(teacher_logits.to(logits_to_align.device), logits_to_align)
    return (loss, outputs) if return_outputs else loss

In [None]:
from peft import LoraConfig, TaskType, get_peft_model

# LoRA adapters (rank 8) on the vocabulary-sized sequence-classification student.
peft_config = LoraConfig(task_type= TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)


peft_model = get_peft_model(model_to_ft, peft_config)
peft_model.print_trainable_parameters()


from transformers import DataCollatorWithPadding
# Pads each batch dynamically to its longest sequence.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

peft_training_args = TrainingArguments(
    # NOTE(review): the student is OPT-350m; "mt0-large" in this path looks copied from another run.
    output_dir="suryansh/dl-project/mt0-large-lora-context-distillation",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    learning_rate=5e-5,
    fp16=True,
    weight_decay=0.01,
    # Evaluate and checkpoint once per epoch. eval_steps / save_steps would be ignored
    # under the "epoch" strategies, so they are no longer set.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="wandb"
)

context_distillation_trainer = CustomTrainer(
    model=peft_model,
    args=peft_training_args,
    # Trainer needs a single Dataset; the whole DatasetDict (all splits) was passed before.
    train_dataset=tokenized_mnli_without_context["train"],
    eval_dataset=combined_validation_dataset_without_context,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

context_distillation_trainer.train()

In [None]:
# Evaluate the distilled LoRA model on the combined MNLI validation set (its eval_dataset).
context_distillation_trainer.evaluate()


For each setting, calculate in-domain accuracy (on MNLI) and out-of-domain accuracy (on HANS):

1) Try out accuracy for zero-shot inference without fine-tuning
2) Try out accuracy for few-shot inference (in-context learning) without fine-tuning
3) Try out LoRA, QLoRA, and other PEFT-based fine-tuning
4) KL-divergence-loss-based fine-tuning using LoRA/QLoRA (PEFT)