# Setup

In [None]:
%%capture
from google.colab import drive
import json

# Mount Google Drive and move into the experiments folder: result JSON files
# are written here by `save` and read back from the current directory by the
# visualization cell (`os.listdir()`).
drive.mount('/gdrive')
%cd /gdrive/My\ Drive/projects/mentored_decoding/experiments

def save(filename, data):
    """
    Write `data` to `filename` as a single JSON document.

    The file is opened in text mode with UTF-8 encoding and non-ASCII
    characters (e.g. French accents) are written verbatim rather than escaped.
    """
    with open(filename, mode="w", encoding="utf-8") as handle:
        handle.write(json.dumps(data, ensure_ascii=False))

# Install the Hugging Face libraries straight from their git default branches.
# NOTE(review): these are unpinned, so results may not be reproducible across runs.
!pip install -q -U git+https://github.com/huggingface/datasets.git
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U evaluate
# SentencePiece backend used by `T5Tokenizer` in the next cell.
!pip install sentencepiece

In [None]:
%%capture
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from datasets import load_dataset
import evaluate
import time
import json
import os
import types
import hashlib

# Device that input tensors are explicitly moved to. The models below are
# placed with `device_map="auto"` instead.
# NOTE(review): this assumes both placements resolve to the same device — confirm.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Target (verifying) model and draft (proposing) model checkpoints.
target_model_id = "t5-large"
draft_model_id = "t5-small"

# One tokenizer, loaded from the draft checkpoint, is shared by both models.
# This assumes the two checkpoints use the same vocabulary (expected within
# the T5 family — TODO confirm).
tokenizer = T5Tokenizer.from_pretrained(draft_model_id)

# Target model is loaded in float16; `.eval()` switches it to inference mode.
target_model = T5ForConditionalGeneration.from_pretrained(
    target_model_id,
    torch_dtype=torch.float16,
    device_map="auto"
).eval()

# Draft model is loaded without an explicit `torch_dtype` (library default).
draft_model = T5ForConditionalGeneration.from_pretrained(
    draft_model_id,
    device_map="auto"
).eval()

# Number of validation examples evaluated by the experiment loop.
NUM_SAMPLES = 1000
# Maximum generation length and source-text prefix, both read from the draft
# checkpoint's en->fr task-specific config.
GEN_LEN = draft_model.config.task_specific_params["translation_en_to_fr"]["max_length"]
PREFIX = draft_model.config.task_specific_params["translation_en_to_fr"]["prefix"]

# WMT15 fr-en validation split, streamed (`streaming=True`), plus the BLEU metric.
ds = load_dataset("wmt15", "fr-en", split="validation", streaming=True)
bleu = evaluate.load("bleu")

# Sampling methods



In [None]:
def greedy_decoding(q):
    """
    Pick the most probable token.

    Returns the index of the maximum of `q` along its last dimension.
    """
    best_index = q.argmax(dim=-1)
    return best_index

def multinomial_sampling(q):
    """
    Draw a single token index from the distribution `q`.

    Returns a tensor holding one sampled index (as produced by
    `torch.multinomial` with one sample).
    """
    sampled = q.multinomial(num_samples=1)
    return sampled

def speculative_sampling(x, p, q):
    """
    Accept or replace draft token `x` with the speculative-sampling rule.

    `p` is the draft distribution and `q` the target distribution. `x` is kept
    with probability min(1, q[x] / p[x]). Otherwise a replacement is drawn from
    the normalized positive part of `q - p`.

    Returns `x` itself when accepted, otherwise a one-element index tensor.
    """
    u = torch.rand(1)[0]
    accepted = u * p[x] < q[x]
    if accepted:
        return x

    residual = torch.clip(q - p, min=0)
    residual = residual / residual.sum()
    return torch.multinomial(residual, 1)

def mentored_sampling(x, p, q, target_dkl, tol=0.2):
    """
    Sample a token with mentored sampling.

    A relaxed variant of speculative sampling. The draft token `x` is accepted
    or replaced by a token drawn from a residual distribution `s`. The
    acceptance threshold is scaled by `alpha`, chosen by a 10-step bisection,
    so that the divergence estimate `dkl` lands in the band
    [target_dkl*(1-tol), target_dkl*(1+tol)] instead of being zero.

    Args:
        x: index of the candidate token proposed by the draft model.
        p: draft-model probability vector (`verify` passes one row of the
            draft probabilities).
        q: target-model probability vector, same shape as `p`.
        target_dkl: divergence budget the bisection aims for.
        tol: relative half-width of the accepted band around `target_dkl`.

    Returns:
        `x` itself when accepted, otherwise a one-element tensor holding the
        index sampled from `s`.
    """
    # A single uniform draw is shared by both acceptance tests below.
    random_number = torch.rand(1)[0]
    # Whatever the target KL divergence, we know that P(x accepted) ≥ min(1, q(x)/p(x))
    if random_number*p[x] < q[x]:
        return x

    # Sort tokens by the likelihood ratio q/p (ascending). The cumulative sum
    # of q*log(q/p) in that order ends at sum(q*log(q/p)) = KL(q || p).
    q_over_p = q/p
    sorted_q_over_p, argsort_q_over_p = torch.sort(q_over_p)
    sorted_q = q[argsort_q_over_p]
    cumsum_q_log_q_over_p = (sorted_q*torch.log(sorted_q_over_p)).cumsum(-1)

    # If the KL divergence between q and p is below the target KL divergence,
    # we can directly accept the draft token. The upper bound includes the
    # tolerance: a slightly higher KL divergence is accepted to reduce the
    # computational overhead.
    dkl = cumsum_q_log_q_over_p[-1]
    higher_bound = target_dkl*(1+tol)
    if dkl <= higher_bound:
        return x

    # Prefix sums (cumsum_*) and suffix sums (cumsum_last_*, via
    # flip -> cumsum -> flip) of p and q in ascending-ratio order.
    sorted_p = p[argsort_q_over_p]
    cumsum_p = sorted_p.cumsum(-1)
    cumsum_q = sorted_q.cumsum(-1)
    cumsum_last_p = torch.flip(torch.flip(sorted_p, [0]).cumsum(-1), [0])
    cumsum_last_q = torch.flip(torch.flip(sorted_q, [0]).cumsum(-1), [0])
    right = cumsum_last_q / sorted_q_over_p - cumsum_last_p

    # Lower edge of the tolerance band: the bisection stops as soon as dkl
    # falls inside [lower_bound, higher_bound], trading precision for fewer
    # iterations.
    lower_bound = target_dkl*(1-tol)

    # Bisection over alpha in (0, 1). The best_* defaults (alpha = beta = 1,
    # best_one_minus_R = 1 - sum(min(q, p))) make the rejection step below
    # reduce to plain speculative sampling.
    min_alpha, max_alpha = 0, 1
    best_alpha, best_beta, best_one_minus_R, best_dkl = 1, 1, 1 - (torch.clip(q/p, max=1)*p).sum(), 0
    for _ in range(10):
        alpha = (min_alpha + max_alpha)/2
        # n1: last sorted position whose ratio q/p is <= alpha (a 1-element
        # tensor). NOTE(review): the fallback 0 still indexes the first
        # element even though its ratio exceeds alpha — confirm intended.
        try:
            n1 = torch.nonzero(sorted_q_over_p <= alpha)[-1]
        except IndexError:
            n1 = 0
        one_minus_R = cumsum_p[n1] - cumsum_q[n1] / alpha
        # n2: first sorted position where right <= one_minus_R (-1 if none).
        try:
            n2 = torch.nonzero(right <= one_minus_R)[0]
        except IndexError:
            n2 = -1
        beta = cumsum_last_q[n2-1]/(one_minus_R + cumsum_last_p[n2-1])
        if one_minus_R < 0:
            dkl = cumsum_q_log_q_over_p[-1]
        else:
            # NOTE(review): `torch.math` resolves to the stdlib `math` module
            # exposed in torch's namespace; `math.log` would be more explicit.
            # The last term indexes `cumsum_last_q[n2]` while `beta` above uses
            # `cumsum_last_q[n2-1]` — confirm the index.
            dkl = (
                torch.math.log(alpha)*cumsum_q[n1]
                + cumsum_q_log_q_over_p[n2-1]
                - cumsum_q_log_q_over_p[n1]
                + torch.math.log(beta)*cumsum_last_q[n2]
            )

        # dkl below the band -> lower alpha (makes the acceptance test below
        # easier to pass); remember this as the best admissible candidate.
        # dkl above the band -> raise alpha. Inside the band -> stop.
        if dkl < lower_bound:
            max_alpha = alpha
            best_alpha, best_beta, best_one_minus_R, best_dkl = alpha, beta, one_minus_R, dkl
        elif dkl > higher_bound:
            min_alpha = alpha
        else:
            break
    else:
        # No iteration landed in the band: fall back to the last candidate
        # whose dkl was below it (or the speculative-sampling defaults).
        alpha, beta, one_minus_R, dkl = best_alpha, best_beta, best_one_minus_R, best_dkl

    # With α, we can now definitively accept or reject the token sampled from p.
    # NOTE: this reuses `random_number`, which at this point is known to have
    # failed the first test (random_number*p[x] >= q[x]).
    if alpha*random_number < q_over_p[x]:
        return x

    # If the candidate token is rejected, we sample from s (in sorted order)
    # and map the sampled position back to a vocabulary index. `s` is not
    # renormalized; torch.multinomial accepts unnormalized non-negative weights.
    # NOTE(review): one_minus_R can be negative (see the check above), which
    # would make the weights invalid — confirm this cannot happen here.
    s = torch.clip(sorted_q/beta-sorted_p, min=0)/one_minus_R
    return argsort_q_over_p[torch.multinomial(s, 1)]

# Implementation of multinomial sampling, speculative decoding and mentored decoding

In [None]:
def generate(
    llm,
    encoder_tokens,
    decoder_tokens,
    num_new_tokens,
    temperature=0.3,
    return_proba=False,
    past_key_values=None,
    encoder_last_hidden_state = None
):
    """
    Generate a sequence of tokens with multinomial sampling.

    Args:
        llm: encoder-decoder model, called with `use_cache=True`.
        encoder_tokens: encoder input ids. They are only fed to the model when
            no `past_key_values` are given.
        decoder_tokens: decoder ids that are NOT yet covered by
            `past_key_values` (the whole decoder prefix when there is no
            cache). All of them are fed on the first step, so the returned
            cache covers every position.
        num_new_tokens: maximum number of tokens to sample. Sampling stops
            early after the tokenizer's EOS token.
        temperature: softmax temperature applied to the logits.
        return_proba: if True, return the per-step distributions stacked
            along dim 0.
        past_key_values: decoder cache from a previous call, or None.
        encoder_last_hidden_state: encoder output matching `past_key_values`.

    Returns:
        A 5-tuple:
        - `decoder_tokens` extended with the sampled tokens;
        - the number of sampled tokens;
        - the updated past key values;
        - the encoder last hidden state;
        - the stacked probabilities if `return_proba`, otherwise the last
          step's distribution (None when nothing was sampled).
    """
    num_generated = 0
    proba = None
    q = None  # stays None if num_new_tokens == 0 (previously an UnboundLocalError)
    # Feed every uncached decoder token on the first step. Feeding only the
    # last one would leave holes in the KV cache when the caller passes
    # several uncached tokens.
    decoder_input_ids = decoder_tokens
    outputs = types.SimpleNamespace()
    outputs.past_key_values = past_key_values
    outputs.encoder_last_hidden_state = encoder_last_hidden_state

    for _ in range(num_new_tokens):
        with torch.no_grad():
            if outputs.past_key_values is None:
                # No cache yet: run the encoder as well.
                outputs = llm(
                    input_ids=encoder_tokens,
                    decoder_input_ids=decoder_input_ids,
                    use_cache=True
                )
            else:
                # Reuse the cached decoder states and encoder output.
                outputs = llm(
                    decoder_input_ids=decoder_input_ids,
                    past_key_values=outputs.past_key_values,
                    encoder_outputs=(outputs.encoder_last_hidden_state,),
                    use_cache=True
                )

        # Next-token distribution at the last fed position, shape (1, vocab).
        q = torch.softmax(outputs.logits[0, -1:, :].float()/temperature, -1)
        if return_proba:
            proba = q if proba is None else torch.concat((proba, q), axis=0)
        decoder_input_ids = multinomial_sampling(q[0]).reshape((1, 1))
        decoder_tokens = torch.concat((decoder_tokens, decoder_input_ids), axis=-1)
        num_generated += 1
        if decoder_input_ids[0, 0] == tokenizer.eos_token_id:
            break

    return (
        decoder_tokens,
        num_generated,
        outputs.past_key_values,
        outputs.encoder_last_hidden_state,
        proba if return_proba else q
    )

In [None]:
def truncate(past_key_values, n):
    """
    Keep only the first `n` positions of each layer's first two cache tensors.

    Each per-layer entry is rebuilt as a 4-tuple: entries 0 and 1 are sliced
    to `[:, :, :n, :]` and entries 2 and 3 are passed through unchanged.
    This assumes the legacy tuple cache layout (presumably self-attention
    key/value followed by cross-attention key/value for T5 — confirm against
    the installed transformers version).
    """
    trimmed_layers = []
    for layer_cache in past_key_values:
        self_key = layer_cache[0][:, :, :n, :]
        self_value = layer_cache[1][:, :, :n, :]
        trimmed_layers.append((self_key, self_value, layer_cache[2], layer_cache[3]))
    return tuple(trimmed_layers)

In [None]:
def verify(target_proba, draft_proba, draft_tokens, target_dkl):
    """
    Verify a sequence of draft tokens given the target and draft probabilities.

    The last `draft_proba.shape[0]` entries of `draft_tokens` are checked one
    by one:
    - `target_dkl == 0` uses `speculative_sampling`;
    - any other value uses `mentored_sampling` with that KL budget.

    Verification stops at the first replaced token or at EOS. If every draft
    token is accepted, one extra token is sampled from `target_proba[-1]`.

    Args:
        target_proba: target distributions, one row per draft token plus one
            for the position after the last draft token.
        draft_proba: draft distributions, one row per draft token.
        draft_tokens: (1, L) token ids ending with the draft tokens.
        target_dkl: KL budget (0 selects plain speculative sampling).

    Returns:
        A (1, k) tensor of selected token ids.
    """
    num_drafted = draft_proba.shape[0]
    # Start from an empty (1, 0) tensor with the dtype/device of the draft
    # tokens, so concatenation also works when nothing was drafted (the old
    # list initializer crashed in torch.concat in that case).
    selected_tokens = draft_tokens.new_empty((1, 0))

    for i in range(num_drafted):
        p, q = draft_proba[i, :], target_proba[i, :]
        candidate_token = draft_tokens[0, -num_drafted + i]
        if target_dkl == 0:
            selected_token = speculative_sampling(candidate_token, p, q)
        else:
            selected_token = mentored_sampling(candidate_token, p, q, target_dkl)
        selected_token = selected_token.reshape((1, 1))
        selected_tokens = torch.concat((selected_tokens, selected_token), axis=-1)

        if selected_token[0, 0] == tokenizer.eos_token_id:
            break
        # A replaced token invalidates every later draft token.
        if selected_token[0, 0] != candidate_token:
            break
    else:
        # All draft tokens accepted: sample a bonus token from the target's
        # distribution after the last draft token.
        selected_token = multinomial_sampling(target_proba[-1, :]).reshape((1, 1))
        selected_tokens = torch.concat((selected_tokens, selected_token), axis=-1)
    return selected_tokens

In [None]:
def draft_and_verify(
    target_llm,
    draft_llm,
    encoder_tokens,
    decoder_tokens,
    num_new_tokens,
    target_dkl=0,
    draft_length=5,
    temperature=0.3
):
    """
    Generate a sequence of tokens with speculative or mentored decoding.

    Each round works as follows:
    1. The draft model proposes up to `draft_length` tokens.
    2. The target model scores them in one forward pass.
    3. `verify` keeps the accepted prefix plus one resampled or bonus token.

    KV caches of both models are truncated so that only entries for tokens
    that were actually kept are reused.

    Args:
        target_llm: verifying (large) encoder-decoder model.
        draft_llm: proposing (small) encoder-decoder model.
        encoder_tokens: encoder input ids.
        decoder_tokens: initial decoder ids, shape (1, n).
        num_new_tokens: maximum number of tokens to add.
        target_dkl: KL budget passed to `verify` (0 = speculative sampling).
        draft_length: number of tokens drafted per round.
        temperature: softmax temperature used for BOTH the draft and the
            target distributions.

    Returns:
        A 3-tuple:
        - the decoder tokens, truncated to at most `num_new_tokens` new ones;
        - the number of new tokens;
        - a list with the number of tokens kept in each round.
    """
    num_tokens = decoder_tokens.shape[-1]
    initial_num_tokens = num_tokens
    max_num_tokens = num_tokens + num_new_tokens
    num_selected = []
    target_past_key_values = None
    draft_past_key_values = None
    draft_encoder_last_hidden_state = None

    while num_tokens < max_num_tokens:

        # After the first iteration, we need to provide the past key values that
        # are still valid
        if draft_past_key_values is not None:

            # For the target model, we keep all past key values corresponding to
            # the generated tokens except the last one
            target_past_key_values = truncate(target_outputs.past_key_values, num_tokens-1)

            # If all draft tokens except maybe the last one were verified, we
            # keep all the past key values of the draft model. If not we keep
            # those corresponding to all generated tokens except the last one
            if selected_tokens.shape[1] < draft_length:
                draft_past_key_values = truncate(draft_past_key_values, num_tokens-1)

            # δ (1 or 2) is the number of kept tokens whose past key values
            # have not been computed by the draft model yet
            delta = num_tokens - draft_past_key_values[0][0].shape[2]
            decoder_input_ids = decoder_tokens[:, -delta:]
            prefix = decoder_tokens[:, :-delta]
        else:
            decoder_input_ids = decoder_tokens
            prefix = None

        # We generate draft tokens. The same temperature as the target is
        # used so that the draft and target distributions are comparable
        # (previously the draft always used generate's default of 0.3).
        draft_tokens, num_drafted, draft_past_key_values, draft_encoder_last_hidden_state, draft_proba = generate(
            draft_llm,
            encoder_tokens,
            decoder_input_ids,
            draft_length,
            temperature=temperature,
            return_proba=True,
            past_key_values=draft_past_key_values,
            encoder_last_hidden_state=draft_encoder_last_hidden_state
        )
        if prefix is not None:
            draft_tokens = torch.concat((prefix, draft_tokens), axis=-1)

        # We compute the next token probabilities for draft tokens
        with torch.no_grad():
            if target_past_key_values is None:
                target_outputs = target_llm(
                    input_ids=encoder_tokens,
                    decoder_input_ids=draft_tokens,
                    use_cache=True
                )
            else:
                # Feed from the last kept token onward; earlier positions
                # are covered by the truncated cache.
                target_outputs = target_llm(
                    decoder_input_ids=draft_tokens[:, (num_tokens-1):],
                    encoder_outputs=(target_outputs.encoder_last_hidden_state,),
                    past_key_values=target_past_key_values,
                    use_cache=True
                )

        # One row per draft token plus one for the position after the last
        # draft token.
        target_proba = torch.softmax(
            target_outputs.logits[0, -num_drafted-1:, :].float()/temperature,
            -1
        )

        # Probabilities slightly increased for numerical stability purposes
        target_proba += 1e-12
        target_proba /= target_proba.sum(axis=-1, keepdim=True)
        draft_proba += 1e-12
        draft_proba /= draft_proba.sum(axis=-1, keepdim=True)

        # We use the rejection sampling scheme to verify tokens
        selected_tokens = verify(target_proba, draft_proba, draft_tokens, target_dkl)

        decoder_tokens = torch.concat((decoder_tokens, selected_tokens), axis=-1)
        num_tokens += selected_tokens.shape[1]
        num_selected.append(selected_tokens.shape[1])
        if decoder_tokens[0, -1] == tokenizer.eos_token_id:
            break

    return decoder_tokens[:, :max_num_tokens], min(max_num_tokens, num_tokens) - initial_num_tokens, num_selected

# Experiments

In [None]:
# Each config is (dkl, draft_length, model label, temperature).
# Baselines: plain sampling with the draft model alone and the target model alone.
configs = [
    (0, 0, model_name, temp)
    for model_name in ("draft", "target")
    for temp in (0.3,)
]
# Draft-and-verify runs. dkl == 0 is speculative sampling and dkl > 0 is
# mentored sampling; both are logged under the "speculative" label.
configs.extend(
    (kl_budget, length, "speculative", temp)
    for length in (3, 5, 7, 9, 11, 13, 15)
    for kl_budget in (0, 0.1, 0.15, 0.2)
    for temp in (0.3,)
)

In [None]:
def log_result(dkl, draft_length, model, temperature, example):
    """
    Run one (configuration, example) experiment and save its results as JSON.

    Decoding mode:
    - `draft_length == 0`: plain multinomial decoding with a single model
      (the draft model when `model == "draft"`, the target model otherwise);
    - otherwise: draft-and-verify decoding with KL budget `dkl`.

    The saved record holds the decoded prediction, the number of generated
    tokens, the wall-clock decoding delay and the BLEU fields. The file name
    encodes the configuration and the MD5 hash of the English source text.
    """
    source_text = example["translation"]["en"]
    digest = hashlib.md5(source_text.encode()).hexdigest()
    filename = f"{model}_t{temperature}_dkl{dkl}_dl{draft_length}_{digest}.json"

    tokenized = tokenizer(f"{PREFIX}{source_text}", return_tensors="pt")
    encoder_tokens = tokenized.input_ids.to(device)
    decoder_tokens = torch.zeros((1, 1), dtype=int).to(device)

    start = time.time()
    if draft_length != 0:
        tokens, num_new_tokens, _ = draft_and_verify(
            target_model,
            draft_model,
            encoder_tokens,
            decoder_tokens,
            GEN_LEN,
            target_dkl=dkl,
            draft_length=draft_length,
            temperature=temperature
        )
    else:
        single_model = draft_model if model == "draft" else target_model
        tokens, num_new_tokens, _, _, _ = generate(
            single_model,
            encoder_tokens,
            decoder_tokens,
            GEN_LEN,
            return_proba=False,
            temperature=temperature
        )
    delay = time.time() - start

    generated_ids = tokens[0, -num_new_tokens:]
    prediction = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    try:
        scores = bleu.compute(
            predictions=[prediction],
            references=[example["translation"]["fr"]]
        )
    except ZeroDivisionError:
        # BLEU could not be computed (presumably an empty prediction):
        # record a zero score with placeholder length statistics.
        scores = {
            'bleu': 0.0,
            'precisions': [0.0, 0.0, 0.0, 0.0],
            'brevity_penalty': 1.0,
            'length_ratio': 1.0,
            'translation_length': 1,
            'reference_length': 1
        }

    record = {"output": prediction, "num_tokens": num_new_tokens, "delay": delay}
    save(filename, {**record, **scores})

In [None]:
# Run every configuration on each of the first NUM_SAMPLES streamed examples;
# every call writes one JSON result file to the current directory.
for example in ds.take(NUM_SAMPLES):
    for config in configs:
        log_result(*config, example=example)

# Visualization of results

In [None]:
import os
import pandas as pd

# Collect every per-example result file written by `log_result` into one table.
# File names follow f"{model}_t{temperature}_dkl{dkl}_dl{draft_length}_{idx}.json".
data = {
    "model": [],
    "temperature": [],
    "dkl": [],
    "draft_length": [],
    "example_id": [],
    "bleu": [],
    "num_tokens": [],
    "delay": []
}

for f in os.listdir():
    if not f.endswith(".json"):
        continue
    model, temperature, dkl, draft_length, i = f[:-len(".json")].split("_")
    data["model"].append(model)
    data["temperature"].append(float(temperature[1:]))    # strip leading "t"
    data["dkl"].append(float(dkl[3:]))                    # strip leading "dkl"
    data["draft_length"].append(int(draft_length[2:]))    # strip leading "dl"
    data["example_id"].append(i)
    # Files are written as UTF-8 by `save`; read them back the same way.
    with open(f, encoding="utf-8") as fh:
        results = json.load(fh)
    for k in ["bleu", "num_tokens", "delay"]:
        data[k].append(results[k])

df = pd.DataFrame.from_dict(data)
group_keys = ["model", "temperature", "dkl", "draft_length"]
# Aggregate the numeric metrics only. `example_id` is a string column:
# pandas >= 2.0 raises TypeError when averaging it, and summing it would
# concatenate the ids.
by_config = df.groupby(group_keys)[["bleu", "num_tokens", "delay"]]
grouped = by_config.sum()
grouped["token_per_second"] = grouped["num_tokens"]/grouped["delay"]
grouped2 = by_config.mean()
grouped3 = df.groupby(group_keys).count()
grouped["bleu"] = grouped2["bleu"]          # mean BLEU per configuration
grouped["count"] = grouped3["delay"]        # number of examples per configuration
grouped