In [1]:
from transformers import OPTForCausalLM, AutoTokenizer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [224]:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

teacher_model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
student_model = OPTForCausalLM.from_pretrained("facebook/opt-350m")


In [225]:
from datasets import load_dataset

dataset = load_dataset("nyu-mll/glue", "mnli", split='validation_matched')

In [226]:
def label_to_text(label):
    return ["entailment", "neutral", "contradiction"][label]

In [227]:
context_examples = [{key: value[i] for key, value in dataset[1:3].items()} for i in range(2)]
query_example = {key: value for key, value in dataset[3].items()}


In [228]:
context_examples

[{'premise': 'This site includes a list of all award winners and a searchable database of Government Executive articles.',
  'hypothesis': 'The Government Executive articles housed on the website are not able to be searched.',
  'label': 2,
  'idx': 1},
 {'premise': "uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him",
  'hypothesis': 'I like him for the most part, but would still enjoy seeing someone beat him.',
  'label': 0,
  'idx': 2}]

In [229]:
query_example

{'premise': "yeah i i think my favorite restaurant is always been the one closest  you know the closest as long as it's it meets the minimum criteria you know of good food",
 'hypothesis': 'My favorite restaurants are always at least a hundred miles away from my house. ',
 'label': 2,
 'idx': 3}

In [236]:
def s_create_prompt(premise, hypothesis, label=None):
    prompt = f"Answer if it is entailment or Contradiction. Premise: {premise}, Hypothesis: {hypothesis}, Label:entailment or Contradiction"
    return prompt

In [237]:
def t_create_prompt(premise, hypothesis, label=None):
    prompt = f"Premise: {premise}, Hypothesis: {hypothesis}"
    if label is not None:
        prompt += f", Label: {label_to_text(label)}"
    return prompt

In [238]:
def create_extended_prompt(context_examples, query_example):
    context_prompt = ""
    for example in context_examples:
        premise = example['premise']
        hypothesis = example['hypothesis']
        label = example['label']
        context_prompt += t_create_prompt(premise, hypothesis, label) + ". "
    
    query_premise = query_example['premise']
    query_hypothesis = query_example['hypothesis']
    query_prompt = t_create_prompt(query_premise, query_hypothesis) 
    return context_prompt + query_prompt

In [239]:
teacher_prompt = create_extended_prompt(context_examples, query_example)
student_prompt = s_create_prompt(query_example['premise'], query_example['hypothesis'])

teacher_inputs = tokenizer(teacher_prompt, return_tensors="pt")
student_inputs = tokenizer(student_prompt, return_tensors="pt")

teacher_prompt

"Premise: This site includes a list of all award winners and a searchable database of Government Executive articles., Hypothesis: The Government Executive articles housed on the website are not able to be searched., Label: contradiction. Premise: uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him, Hypothesis: I like him for the most part, but would still enjoy seeing someone beat him., Label: entailment. Premise: yeah i i think my favorite restaurant is always been the one closest  you know the closest as long as it's it meets the minimum criteria you know of good food, Hypothesis: My favorite restaurants are always at least a hundred miles away from my house. "

In [256]:
# teacher_outputs = teacher_model.generate(**teacher_inputs, max_length=teacher_inputs['input_ids'].shape[-1] + 2)
# t_output_text = tokenizer.decode(teacher_outputs[0], skip_special_tokens=True)
# teacher_predicted_label = t_output_text.split("Label:")[-1].strip().split('.')[0].strip()
# teacher_predicted_label

teacher_outputs = student_model(**student_inputs, labels=student_inputs['input_ids'])
t_logits = teacher_outputs.logits
t_probs = torch.nn.functional.softmax(t_logits, dim=-1)

In [258]:
student_outputs = student_model(**student_inputs, labels=student_inputs['input_ids'])
student_logits = student_outputs.logits
student_probs = torch.nn.functional.softmax(student_logits, dim=-1)

# student_outputs = student_model.generate(**student_inputs, max_length=student_inputs['input_ids'].shape[-1] + 10)
# s_output_text = tokenizer.decode(student_outputs[0], skip_special_tokens=True)
# student_predicted_label = s_output_text.split("Label:")[-1].strip().split('.')[0].strip()
# print(student_predicted_label)

In [259]:
student_model.train()

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 512, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 1024)
      (project_out): Linear(in_features=1024, out_features=512, bias=False)
      (project_in): Linear(in_features=512, out_features=1024, bias=False)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=409

In [254]:
# label_index = {'entailment': 0, 'entailment or Contradiction': 1, 'contradiction': 2}
# teacher_label_idx = torch.tensor([label_index[teacher_predicted_label]], dtype=torch.long)
# student_label_idx = torch.tensor([label_index[student_predicted_label]], dtype=torch.long)

# teacher_probs = torch.nn.functional.one_hot(teacher_label_idx, num_classes=3).float()
# student_probs = torch.nn.functional.one_hot(student_label_idx, num_classes=3).float()
# student_probs = torch.nn.functional.softmax(student_probs, dim=-1)
# teacher_probs = torch.nn.functional.softmax(teacher_probs, dim=-1)

# print(student_probs)

# optimizer = torch.optim.AdamW(student_model.parameters(), lr=0.001)

# total_loss = 0
# criterion = torch.nn.KLDivLoss(reduction='batchmean')
# loss = criterion(student_probs.log(), teacher_probs)
# loss

tensor([[0.2119, 0.5761, 0.2119]])


tensor(0.3642)

In [260]:
criterion = torch.nn.KLDivLoss(reduction='batchmean')
loss = criterion(student_probs.log(), t_probs)
loss    

tensor(10.9139, grad_fn=<DivBackward0>)

In [261]:
total_loss = 0
student_model.zero_grad()
loss.backward()
optimizer = torch.optim.AdamW(student_model.parameters(), lr=0.001)
optimizer.step()
total_loss += loss.item()

print(f"1, Average Loss: {total_loss/len(query_example)}")

1, Average Loss: 2.728483200073242
