In [225]:
from transformers import OPTForCausalLM, AutoTokenizer
import torch

In [226]:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

teacher_model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
student_model = OPTForCausalLM.from_pretrained("facebook/opt-350m")


In [227]:
from datasets import load_dataset

dataset = load_dataset("nyu-mll/glue", "mnli", split='validation_matched')

In [228]:
def label_to_text(label):
    return ["entailment", "neutral", "contradiction"][label]

In [229]:
context_examples = [{key: value[i] for key, value in dataset[5:7].items()} for i in range(2)]
query_example = {key: value for key, value in dataset[7].items()}


In [230]:
context_examples

[{'premise': "well that would be a help i wish they would do that here we have got so little landfill space left that we're going to run out before the end of this decade and it's really going to be",
  'hypothesis': 'We have plenty of space in the landfill.',
  'label': 2,
  'idx': 5},
 {'premise': 'yeah i know and i did that all through college and it worked too',
  'hypothesis': 'I did that all through college but it never worked ',
  'label': 2,
  'idx': 6}]

In [231]:
query_example

{'premise': "Calcutta seems to be the only other production center having any pretensions to artistic creativity at all, but ironically you're actually more likely to see the works of Satyajit Ray or Mrinal Sen shown in Europe or North America than in India itself.",
 'hypothesis': "Most of Mrinal Sen's work can be found in European collections.",
 'label': 1,
 'idx': 7}

In [232]:
def s_create_prompt(premise, hypothesis, label=None):
    prompt = f"Label if this is entailment or contradiction.\nPremise: {premise},\nHypothesis: {hypothesis},\nLabel:"
    return prompt

In [233]:
def t_create_prompt(premise, hypothesis, label=None):
    prompt = f"\nPremise: {premise}, \nHypothesis: {hypothesis}"
    if label is not None:
        prompt += f",\nLabel: {label_to_text(label)}"
    if label is None:
        prompt += f",\nLabel:"
    return prompt

In [234]:
def create_extended_prompt(context_examples, query_example):
    context_prompt = ""
    for example in context_examples:
        premise = example['premise']
        hypothesis = example['hypothesis']
        label = example['label']
        context_prompt += t_create_prompt(premise, hypothesis, label) + "\n"
    
    query_premise = query_example['premise']
    query_hypothesis = query_example['hypothesis']
    query_prompt = t_create_prompt(query_premise, query_hypothesis) 
    return context_prompt + query_prompt

In [235]:
teacher_prompt = create_extended_prompt(context_examples, query_example)
student_prompt = s_create_prompt(query_example['premise'], query_example['hypothesis'])

teacher_inputs = tokenizer(teacher_prompt, return_tensors="pt")
student_inputs = tokenizer(student_prompt, return_tensors="pt")

print(teacher_prompt)


Premise: well that would be a help i wish they would do that here we have got so little landfill space left that we're going to run out before the end of this decade and it's really going to be, 
Hypothesis: We have plenty of space in the landfill.,
Label: contradiction

Premise: yeah i know and i did that all through college and it worked too, 
Hypothesis: I did that all through college but it never worked ,
Label: contradiction

Premise: Calcutta seems to be the only other production center having any pretensions to artistic creativity at all, but ironically you're actually more likely to see the works of Satyajit Ray or Mrinal Sen shown in Europe or North America than in India itself., 
Hypothesis: Most of Mrinal Sen's work can be found in European collections.,
Label:


In [236]:
teacher_outputs = teacher_model.generate(**teacher_inputs, max_length=teacher_inputs['input_ids'].shape[-1] + 2)
t_output_text = tokenizer.decode(teacher_outputs[0], skip_special_tokens=True)
print(t_output_text)
teacher_predicted_label = t_output_text.split("Label:")[-1].strip().split('.')[0].strip()



Premise: well that would be a help i wish they would do that here we have got so little landfill space left that we're going to run out before the end of this decade and it's really going to be, 
Hypothesis: We have plenty of space in the landfill.,
Label: contradiction

Premise: yeah i know and i did that all through college and it worked too, 
Hypothesis: I did that all through college but it never worked,
Label: contradiction

Premise: Calcutta seems to be the only other production center having any pretensions to artistic creativity at all, but ironically you're actually more likely to see the works of Satyajit Ray or Mrinal Sen shown in Europe or North America than in India itself., 
Hypothesis: Most of Mrinal Sen's work can be found in European collections.,
Label: contradiction



In [237]:
teacher_outputs_1 = teacher_model.generate(**teacher_inputs, max_length=teacher_inputs['input_ids'].shape[-1] + 1, output_scores=True, return_dict_in_generate=True)
teacher_probs_1 = torch.nn.functional.softmax(teacher_outputs_1.scores[0], dim=-1)
print(tokenizer.decode(teacher_outputs_1[0][0][-1], skip_special_tokens=True))
print(teacher_outputs_1[0].shape, teacher_outputs_1.scores[0].shape)

argmax_index = torch.argmax(teacher_outputs_1.scores[0])
print(teacher_probs_1.shape)


 contradiction
torch.Size([1, 189]) torch.Size([1, 50272])
torch.Size([1, 50272])


In [238]:
student_model.train()

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 512, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 1024)
      (project_out): Linear(in_features=1024, out_features=512, bias=False)
      (project_in): Linear(in_features=512, out_features=1024, bias=False)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=409

In [239]:
# student_outputs = student_model.generate(**student_inputs, max_length=student_inputs['input_ids'].shape[-1] + 2)
# s_output_text = tokenizer.decode(student_outputs[0], skip_special_tokens=True)
# student_predicted_label = s_output_text.split("Label:")[-1].strip().split('.')[0].strip()
# print(s_output_text)

# student_outputs = student_model.generate(**student_inputs, max_length=student_inputs['input_ids'].shape[-1] + 1, output_scores=True, return_dict_in_generate=True)
# student_probs_1 = torch.nn.functional.softmax(student_outputs.scores[0], dim=-1)

# print(tokenizer.decode(student_outputs[0][0][-1], skip_special_tokens=True))
# print(student_outputs[0].shape, student_outputs.scores[0].shape)

# s_argmax_index = torch.argmax(student_outputs.scores[0])
# print(student_probs_1.shape, s_argmax_index)
# student_probs_1.requires_grad

student_logits = student_model(**student_inputs).logits 
student_probs_1 = torch.nn.functional.softmax(student_logits[:,-1,:], dim=-1)
student_probs_1.shape

torch.Size([1, 50272])

In [240]:
optimizer = torch.optim.AdamW(student_model.parameters(), lr=0.001)

kl_divergence = torch.nn.functional.kl_div(student_probs_1.log(), teacher_probs_1, reduction='batchmean')

print(kl_divergence)
kl_divergence.requires_grad

tensor(11.3630, grad_fn=<DivBackward0>)


True

In [241]:
total_loss = 0

optimizer.zero_grad()
kl_divergence.backward()
optimizer.step()

total_loss += kl_divergence.item()

print(f"1, Average Loss: {total_loss/len(query_example)}")

1, Average Loss: 2.8407492637634277
