# Distillation Contrastive Decoding: Improving LLMs Reasoning with Contrastive Decoding and Distillation

TL;DR: Distillation Contrastive Decoding (DCD) is a decoding approach that improves language model reasoning by building on traditional contrastive decoding and Chain-of-Thought (CoT) prompting. It contrasts the logits of an expert model (which answers correctly) with those of an amateur model (which answers incorrectly). Notably, the expert and the amateur are the same model.

<p align="center"><img src="https://github.com/pphuc25/distil-cd/blob/main/assets/figure1-method.jpg?raw=true" width="800"></p>

**Resources**:
<!-- - Read our paper on [arXiv](#) for a deep dive into our methodology. -->
- Explore the codebase and contribute on GitHub: [distil-cd](https://github.com/pphuc25/distil-cd/tree/main).
<!-- - Join the conversation on Twitter: [Twitter Discussion](#). -->

> **Access Requirement**: The following demonstration employs the [Gemma 2b model by Google](https://huggingface.co/google/gemma-2b), which requires authorized access. Please ensure you have the necessary permissions on Hugging Face to interact with the model before proceeding.

## Log in to Hugging Face

In [1]:
!huggingface-cli login



    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Token: 
Add token as git credential? (Y/n) Y
Token is valid (permission: write).
Cannot authenticate through 

## Environment Setup

1. Clone our repo with Git.
2. Install the customized transformers package (which supports our new decoding method).
3. Install the other requirements with pip (upgrade transformers to the latest version to support the newest models).

In [2]:
!git clone https://github.com/pphuc25/distil-cd.git
!cd distil-cd && pip install -e .
!pip install transformers -U

Cloning into 'distil-cd'...
remote: Enumerating objects: 610, done.
remote: Counting objects: 100% (610/610), done.
remote: Compressing objects: 100% (131/131), done.
remote: Total 610 (delta 489), reused 592 (delta 472), pack-reused 0
Receiving objects: 100% (610/610), 1.39 MiB | 21.27 MiB/s, done.
Resolving deltas: 100% (489/489), done.
Obtaining file:///content/distil-cd
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Installing backend dependencies ... done
  Preparing editable metadata (pyproject.toml) ... done
Building wheels for collected packages: dcd
  Building editable for dcd (pyproject.toml) ... done
  Created wheel for dcd: filename=dcd-0.0.1-0.editable-py3-none-any.whl size=1195 sha256=5d2b3c2cfac4a2a491ea828737d352705c93ef9e20464232302e56eb50e41d51
  Stored in directory: /tmp/pip-ephe

### Register the Decoding Method

**Troubleshooting Note**: If you encounter any issues executing this setup cell, please restart the runtime/kernel. This can resolve initial setup conflicts. To do so in Google Colab, go to the menu bar and select `Runtime` > `Restart session`, and then run this cell again.

In [1]:
# Register the DCD method
from dcd import dcd_pipeline_registry
dcd_pipeline_registry()

## Getting started

In this Google Colab example, we use the [Gemma 2b model by Google](https://huggingface.co/google/gemma-2b) to demonstrate our method. The example compares standard greedy decoding with our Distillation Contrastive Decoding (DCD) method when dropout is applied to the amateur model. We aim to give a clear, comparative illustration of the efficacy of DCD in enhancing model reasoning.
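As a rough illustration of what "dropout on the amateur" means, the sketch below applies standard inverted dropout to a plain list of activations. The `dropout` helper here is hypothetical and written only for this example; it is not the `dcd` package's implementation, which is not shown in this notebook.

```python
import random


def dropout(values, rate, rng):
    """Inverted dropout (illustrative helper, not part of the dcd package).

    Each value is zeroed with probability `rate`; survivors are scaled by
    1 / (1 - rate) so the expected value of each entry is unchanged.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    keep_prob = 1.0 - rate
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]


rng = random.Random(0)  # fixed seed so the run is repeatable
activations = [1.0] * 10
noisy = dropout(activations, rate=0.2, rng=rng)
print(noisy)
```

With a non-zero rate the amateur sees a noisier version of the same computation, which is one way to obtain a weaker model without loading a second set of weights.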

### Import Libraries and Load the Model

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from dcd import set_stop_words, create_prompt, create_prompt_student
import torch

In [3]:
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")  # tokenizers are not placed on a device

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Set up configs and question

To use DCD, you must set `beam_size` to 1, because DCD is a variant of greedy decoding.

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

beam_size = 1
max_length = 250

alpha_coef = 0.1
beta_coef = 0.7
dropout_rate = 0.2

type_prompt = 4  # The synthetic demonstration prompt for arithmetic problems

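# "\\end{code}" escapes the backslash explicitly; the bare "\e" form is an invalid escape sequence in Python
stopping_criteria = set_stop_words(tokenizer=tokenizer, stop_words=["Q:", "\\end{code}", "</s>", "Wrong explanation:"])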

generation_config = GenerationConfig(
    do_sample=False,
    num_beams=beam_size,
    pad_token_id=0,
    eos_token_id=0,
)

class Args:
    def __init__(self) -> None:
        self.prompt_file = 'gsm8k'
        self.data_name = "gsm8k"
        self.cot_flag = True
        self.direct_answer_trigger_for_fewshot = 'The answer is'

args_prompt = Args()

Added stop word:  Q: with the ids [235292]
Added stop word:  \end{code} with the ids [615, 235282, 2564, 235270]
Added stop word:  </s> with the ids [235256, 235313]
Added stop word:  Wrong explanation: with the ids [15844, 235292]
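
The internals of `set_stop_words` are not shown in this notebook. A minimal sketch of the idea, assuming the criterion halts generation once the tail of the generated token ids matches one of the id sequences printed above, could look like this. The helper name `ends_with_stop_sequence` is hypothetical.

```python
# Token-id sequences copied from the cell output above.
STOP_SEQUENCES = [
    [235292],                     # "Q:"
    [615, 235282, 2564, 235270],  # "\end{code}"
    [235256, 235313],             # "</s>"
    [15844, 235292],              # "Wrong explanation:"
]


def ends_with_stop_sequence(generated_ids, stop_sequences=STOP_SEQUENCES):
    """Return True if the generated ids end with any of the stop sequences."""
    ids = list(generated_ids)
    for stop in stop_sequences:
        n = len(stop)
        if n and len(ids) >= n and ids[-n:] == list(stop):
            return True
    return False


print(ends_with_stop_sequence([10, 20, 235256, 235313]))  # tail matches "</s>"
print(ends_with_stop_sequence([10, 20, 30]))              # no stop sequence at the tail
```

A check of this kind would be evaluated after every generated token, so generation stops as soon as the model starts a new `Q:` block or a `Wrong explanation:` section.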


In [5]:
question = "Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?"
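
For reference when reading the model outputs, the arithmetic of this question can be worked through directly. The sketch assumes the 150% increase is measured against the $80,000 purchase price; the variable names are only for illustration.

```python
purchase_price = 80_000
repair_cost = 50_000

# Assumption: "increased the value by 150%" is relative to the purchase price.
value_increase = purchase_price * 150 // 100
final_value = purchase_price + value_increase

total_cost = purchase_price + repair_cost
profit = final_value - total_cost

print(f"Value increase: {value_increase}")  # 120000
print(f"Final value:    {final_value}")     # 200000
print(f"Profit:         {profit}")          # 70000
```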

### Greedy Version

In [6]:
question_formated = "Q: " + question + "\n" + "A:"
inputs = tokenizer(create_prompt(args_prompt, data_name=args_prompt.data_name) + question_formated, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)

In [7]:
inputs_args_greedy = dict(
    generation_config=generation_config,
    return_dict_in_generate=True,
    output_scores=True,
    max_new_tokens=max_length,
    stopping_criteria=stopping_criteria,
    min_tokens_to_keep=2 if beam_size > 1 else 1,
    dropout_rate=dropout_rate
)

In [8]:
output_sequences = model.generate(
    input_ids=input_ids,
    **inputs_args_greedy)

s_greedy = output_sequences.sequences[0]
output_greedy = tokenizer.decode(s_greedy, skip_special_tokens=True)

output_formated_greedy = output_greedy.split("A: ")[-1].replace("\n\nQ:", "")
print(f"Output of greedy: {output_formated_greedy}")

Output of greedy: Josh bought the house for $80,000. Then he put in $50,000 in repairs. So the value of the house increased by 150%. 150% of $80,000 is 150/100 * 80,000 = 120,000. So the value of the house increased by 120,000. 80,000 + 120,000 = 200,000. The answer is 200,000.


### Distillation Contrastive Decoding with Dropout Version

In [9]:
question_formated = "Q: " + question + "\n" + "A:"
inputs = tokenizer(create_prompt(args_prompt, data_name=args_prompt.data_name) + question_formated, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)

# Input ids and attention mask for amateur model
inputs_student = tokenizer(create_prompt_student(args_prompt, type=type_prompt, data_name=args_prompt.data_name) + question_formated, return_tensors="pt")
input_ids_student = inputs_student["input_ids"].to(device)
attention_mask_student = inputs_student["attention_mask"].to(device)

In [10]:
inputs_args_dcd = dict(
    generation_config=generation_config,
    return_dict_in_generate=True,
    output_scores=True,
    max_new_tokens=max_length,
    stopping_criteria=stopping_criteria,

    # DCD parameters (dropout variant)
    alpha_coef=alpha_coef,
    beta_coef=beta_coef,
    min_tokens_to_keep=2 if beam_size > 1 else 1,
    teacher_student=True,
    dropout_rate=dropout_rate,

    # Setting attention mask for amateur model
    model_kwargs_student = dict(
        attention_mask=attention_mask_student
    )
)
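
The exact scoring rule behind `alpha_coef` and `beta_coef` is not shown in this notebook. The sketch below assumes the formulation commonly used for contrastive decoding in reasoning tasks: `alpha` sets a plausibility cutoff relative to the expert's most likely token, and `beta` weights the penalty from the amateur. The function names are hypothetical and the logits are toy values.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def contrastive_scores(expert_logits, amateur_logits, alpha=0.1, beta=0.7):
    """Assumed scoring rule: (1 + beta) * log p_expert - beta * log p_amateur.

    Tokens whose expert probability is below alpha * max expert probability
    are masked out with -inf (the plausibility constraint).
    """
    expert_logp = log_softmax(expert_logits)
    amateur_logp = log_softmax(amateur_logits)
    cutoff = math.log(alpha) + max(expert_logp)
    return [
        (1 + beta) * e - beta * a if e >= cutoff else float("-inf")
        for e, a in zip(expert_logp, amateur_logp)
    ]


def pick_token(expert_logits, amateur_logits, alpha=0.1, beta=0.7):
    """Greedy choice over the contrastive scores."""
    scores = contrastive_scores(expert_logits, amateur_logits, alpha, beta)
    return max(range(len(scores)), key=scores.__getitem__)


# Toy vocabulary of three tokens.
expert = [2.0, 1.9, -5.0]   # expert slightly prefers token 0
amateur = [2.0, 0.0, 3.0]   # amateur also likes token 0, and strongly likes token 2
print(pick_token(expert, amateur))
```

In this toy case the expert alone would pick token 0, but the amateur also rates token 0 highly, so the contrast favors token 1. Token 2 is excluded by the plausibility cutoff even though the amateur penalty alone would not rule it out.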

In [11]:
output_sequences = model.generate(
    input_ids=input_ids,
    input_ids_student=input_ids_student,
    **inputs_args_dcd)

s_dcd = output_sequences.sequences[0]
output_dcd = tokenizer.decode(s_dcd, skip_special_tokens=True)
output_formated_dcd = output_dcd.split("A: ")[-1].replace("\n\nQ:", "")
print(f"Output of DCD: {output_formated_dcd}")

Output of DCD: Josh started with 80,000 dollars. He spent 50,000 dollars on repairs. So he had 80,000 - 50,000 = 30,000 dollars left. Then the house was sold for 30,000 dollars more. 30,000 + 30,000 is 60,000. The answer is 60,000.
