In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from dcd import dcd_pipeline_registry, set_stop_words
from dcd import create_prompt, create_prompt_student
import torch

In [2]:
# Runs once, before the model is loaded. Presumably this registers the DCD decoding
# path so that `model.generate` accepts the extra keyword arguments used further down
# (`input_ids_student`, `alpha_coef`, `beta_coef`, `teacher_student`, `dropout_rate`)
# -- TODO confirm against the `dcd` package.
dcd_pipeline_registry()

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Single source of truth for the checkpoint so the model and tokenizer cannot drift apart.
MODEL_NAME = "mistralai/Mistral-7B-v0.1"

# `device_map="auto"` lets the loader decide where the model weights are placed.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
# The tokenizer holds no weights, so the `device_map` argument is dropped here.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# --- Decoding settings ---
beam_size = 1     # forwarded to GenerationConfig as `num_beams`
max_length = 250  # forwarded to `model.generate` as `max_new_tokens`

# --- DCD hyper-parameters (forwarded to `model.generate` as keyword arguments) ---
alpha_coef = 0.1
beta_coef = 0.8
dropout_rate = 0.2

type_prompt = 4  # The synthetic demonstration prompt for arithmetic problems

# The LaTeX-style marker is written as a raw string: "\e" is not a valid escape
# sequence, so the plain literal triggers an invalid-escape warning on recent
# Python versions. The runtime value (backslash + "end{code}") is unchanged.
stopping_criteria = set_stop_words(
    tokenizer=tokenizer,
    stop_words=["Q:", r"\end{code}", "</s>", "Wrong explanation:"],
)

# NOTE(review): id 0 is used for both padding and EOS, while the stop-word log
# reports `</s>` as id 2. Termination therefore appears to rely on
# `stopping_criteria` rather than on `eos_token_id` -- confirm this is intentional.
generation_config = GenerationConfig(
    do_sample=False,
    num_beams=beam_size,
    pad_token_id=0,
    eos_token_id=0,
)

class Args:
    """Lightweight argument holder handed to `create_prompt` / `create_prompt_student`.

    Attributes set on construction:
        prompt_file: "gsm8k"
        data_name: "gsm8k" (also passed explicitly as `data_name=` to the prompt builders)
        cot_flag: True
        direct_answer_trigger_for_fewshot: "The answer is"

    How the prompt builders interpret each field is defined in the `dcd` package
    and is not visible from this notebook.
    """

    def __init__(self) -> None:
        settings = {
            "prompt_file": 'gsm8k',
            "data_name": "gsm8k",
            "cot_flag": True,
            "direct_answer_trigger_for_fewshot": 'The answer is',
        }
        # Assign in declaration order so the instance __dict__ layout is stable.
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)


args_prompt = Args()

Added stop word:  Q: with the ids [28824, 28747]
Added stop word:  \end{code} with the ids [28756, 416, 28751, 1409, 28752]
Added stop word:  </s> with the ids [2]
Added stop word:  Wrong explanation: with the ids [17055, 566, 13268, 28747]


### Greedy

In [6]:
# The single arithmetic word problem used for the greedy baseline.
question = (
    "Toulouse has twice as many sheep as Charleston. "
    "Charleston has 4 times as many sheep as Seattle. "
    "How many sheep do Toulouse, Charleston, and Seattle have together if Seattle has 20 sheep?"
)
# Wrap the question in the "Q: ... / A:" layout and leave the answer slot open.
question_formated = f"Q: {question}\nA:"

# Prepend the prompt returned by `create_prompt`, then tokenize to PyTorch tensors.
inputs = tokenizer(
    create_prompt(args_prompt, data_name=args_prompt.data_name) + question_formated,
    return_tensors="pt",
)
input_ids = inputs["input_ids"].to(device)

In [7]:
# Keyword arguments for the baseline `model.generate` call in the next cell.
inputs_args_greedy = {
    "generation_config": generation_config,
    "return_dict_in_generate": True,
    "output_scores": True,
    "max_new_tokens": max_length,
    "stopping_criteria": stopping_criteria,
    # 1 when decoding with a single beam, 2 as soon as more than one beam is used.
    "min_tokens_to_keep": 1 if beam_size <= 1 else 2,
    "dropout_rate": dropout_rate,
}

In [8]:
# Baseline run: only the prompt built with `create_prompt` is supplied.
output_sequences = model.generate(input_ids=input_ids, **inputs_args_greedy)

# A single prompt was submitted, so the first returned sequence is the one to decode.
s_greedy = output_sequences.sequences[0]
output_greedy = tokenizer.decode(s_greedy, skip_special_tokens=True)

# Keep the text after the last "A: " marker and remove any "\n\nQ:" stop word left in it.
output_formated_greedy = output_greedy.rpartition("A: ")[-1].replace("\n\nQ:", "")
print(f"Output of greedy: {output_formated_greedy}")

Output of greedy: Toulouse has twice as many sheep as Charleston. So if Charleston has 4 times as many sheep as Seattle, then Toulouse has 4 times as many sheep as Seattle. So Toulouse, Charleston, and Seattle have 20 + 4 + 4 = 28 sheep. The answer is 28.


### Distillation Contrastive Decoding

In [9]:
# Same question and prompt as in the Greedy section, rebuilt here so that this
# section defines all of its own inputs.
question = (
    "Toulouse has twice as many sheep as Charleston. "
    "Charleston has 4 times as many sheep as Seattle. "
    "How many sheep do Toulouse, Charleston, and Seattle have together if Seattle has 20 sheep?"
)
question_formated = f"Q: {question}\nA:"

# Prompt from `create_prompt`, followed by the formatted question.
inputs = tokenizer(
    create_prompt(args_prompt, data_name=args_prompt.data_name) + question_formated,
    return_tensors="pt",
)
input_ids = inputs["input_ids"].to(device)

# Prompt from `create_prompt_student` (variant selected by `type_prompt`), same question.
inputs_student = tokenizer(
    create_prompt_student(args_prompt, type=type_prompt, data_name=args_prompt.data_name)
    + question_formated,
    return_tensors="pt",
)
input_ids_student = inputs_student["input_ids"].to(device)

In [10]:
# Keyword arguments for the DCD `model.generate` call in the next cell.
# The variable keeps the name `inputs_args_greedy` (it overwrites the baseline
# dict) because the following cell reads it under that name.
inputs_args_greedy = {
    "generation_config": generation_config,
    "return_dict_in_generate": True,
    "output_scores": True,
    "max_new_tokens": max_length,
    "stopping_criteria": stopping_criteria,
    "alpha_coef": alpha_coef,
    "beta_coef": beta_coef,
    # 1 when decoding with a single beam, 2 as soon as more than one beam is used.
    "min_tokens_to_keep": 1 if beam_size <= 1 else 2,
    "teacher_student": True,
    "dropout_rate": dropout_rate,
}

In [11]:
# Contrastive run: same `generate` entry point, now also given the student prompt ids.
output_sequences = model.generate(
    input_ids=input_ids, input_ids_student=input_ids_student, **inputs_args_greedy
)

# A single prompt was submitted, so the first returned sequence is the one to decode.
s_dcd = output_sequences.sequences[0]
output_dcd = tokenizer.decode(s_dcd, skip_special_tokens=True)

# Keep the text after the last "A: " marker and remove any "\n\nQ:" stop word left in it.
output_formated_dcd = output_dcd.rpartition("A: ")[-1].replace("\n\nQ:", "")
print(f"Output of DCD: {output_formated_dcd}")

Output of DCD: Toulouse has twice as many sheep as Charleston. If Charleston has 4 times as many as Seattle, then Toulouse has 2 * 4 = 8 times as many as Seattle. 8 times as many as Seattle is 8 x 20 = 160. The answer is 160.
