In [1]:
import os
if not os.path.exists("/scratch/sanika"):
    os.makedirs("/scratch/sanika")

os.environ["HF_HOME"] = "/scratch/sanika/"

In [2]:
import torch
import torch.nn.functional as F
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoModelForSeq2SeqLM
from typing import List, Tuple, Optional
import time
import numpy as np
import datasets

In [3]:
# # WMT16 EN-DE dataset
# def preprocess_function_ende(examples, tokenizer, args, train_args, prefix="translate English to German: "):
#     if train_args.debug:
#         all_text = examples['translation'][:2]
#     else:
#         all_text = examples['translation']
        
#     inputs = []
#     targets = []
#     for excerpt in all_text:
#         en_text = prefix + excerpt['en']
#         de_text = excerpt['de']

#         inputs.append(en_text)
#         targets.append(de_text)
            
#     padding = 'max_length'
#     model_inputs = tokenizer(
#         inputs,
#         max_length=args.source_max_length,
#         padding=padding,
#         truncation=True,
#         return_tensors="pt",
#     )
#     # Tokenize targets with the `text_target` keyword argument
#     labels = tokenizer(
#         text_target=targets,
#         max_length=args.train_target_max_length,
#         padding=padding,
#         truncation=True,
    #     return_tensors="pt",
    # )

    # if padding == "max_length":
    #             labels["input_ids"] = [
    #                 [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
    #             ]

    # model_inputs["labels"] = labels["input_ids"]
    # model_inputs["decoder_attention_mask"] = labels["attention_mask"]
    # return model_inputs

en_gr_dataset = datasets.load_dataset('wmt16', 'de-en', split='validation')

# Base Speculative

In [4]:
class SpeculativeDecoder:
    def __init__(
        self,
        target_model_name = "google-t5/t5-3b",
        draft_model_name = "google-t5/t5-small",
        gamma = 4,
        temperature = 1.0
    ):
        self.gamma = gamma

        self.temperature = temperature

        self.tokenizer = T5Tokenizer.from_pretrained(target_model_name)
        self.target_model = AutoModelForSeq2SeqLM.from_pretrained(target_model_name, device_map='auto')

        self.draft_model = T5ForConditionalGeneration.from_pretrained(draft_model_name, device_map='auto')

        self.target_model.eval()
        self.draft_model.eval()

    def get_draft_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        gamma: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get draft logits for gamma tokens"""
        draft_tokens = []
        draft_probs = []
        current_decoder_ids = decoder_input_ids

        # Generate gamma tokens from the draft model
        for _ in range(gamma):
            with torch.no_grad():
                outputs = self.draft_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=current_decoder_ids,
                    return_dict=True,
                )
                logits = outputs.logits[:, -1, :]  # Get logits for last position
                probs = logits
                # probs = F.softmax(logits, dim=-1)

                token_id = torch.argmax(probs, dim=-1)
                prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)

                draft_tokens.append(token_id.item())
                draft_probs.append(prob.item())

                # Update decoder inputs for next iteration
                current_decoder_ids = torch.cat(
                    [current_decoder_ids, token_id.view(1, 1)],
                    dim=1
                )

                if token_id.item() == self.tokenizer.eos_token_id:
                    break

        return draft_tokens, draft_probs

    def get_target_probs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        draft_tokens: torch.Tensor,
    ) -> torch.Tensor:
        """Get target probabilities for the draft tokens in parallel."""
        with torch.no_grad():
            # Add draft tokens to decoder input
            full_decoder_ids = torch.cat([decoder_input_ids, draft_tokens.unsqueeze(0)], dim=1)
            # print(input_ids_batched.shape, attention_mask_batched.shape, padded_decoder_ids.shape)
            # decoder_mask = torch.triu(
            #     torch.ones((full_decoder_ids.shape[1], full_decoder_ids.shape[1] + 1))
            # )
            # decoder_mask = decoder_mask[-(len(draft_tokens) + 1):, :-1]
            # decoder_mask = 1 - decoder_mask
            
            # the shapes that we want to see are:
            # torch.Size([11, 12]) torch.Size([1, 12])
            # torch.Size([11, 12]) torch.Size([11, 12])
            # torch.Size([11, 32128])
            # torch.Size([11, 32128]) torch.Size([11, 32128])

            # What im getting
            # torch.Size([1, 12]) torch.Size([1, 12])
            # torch.Size([1, 12]) torch.Size([1, 12])
            # torch.Size([1, 11, 32128])


            decoder_mask = torch.ones(full_decoder_ids.shape[1])
            decoder_mask = decoder_mask.unsqueeze(0)


            # print(decoder_mask.shape, full_decoder_ids.shape)

            # conver to a batched input
            # input_ids = input_ids.repeat(len(draft_tokens) + 1, 1)
            # attention_mask = attention_mask.repeat(len(draft_tokens) + 1, 1)
            # full_decoder_ids = full_decoder_ids.repeat(len(draft_tokens) + 1, 1)

            # print(decoder_mask.shape, full_decoder_ids.shape)


            outputs = self.target_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=full_decoder_ids,
                decoder_attention_mask=decoder_mask,
                return_dict=True
            )

            
            # dim_0_indices = torch.arange(len(draft_tokens) + 1)
            # dim_1_indices = torch.arange(len(draft_tokens) + 1) + full_decoder_ids.shape[1] - 1 - len(draft_tokens)
            # logits = outputs.logits[dim_0_indices, dim_1_indices, :]
            logits = outputs.logits[:, -(len(draft_tokens) + 1):, :]
            logits = logits.squeeze(0)

            # print(logits.shape)
            
            # Get probabilities for positions before each draft token
            # logits = outputs.logits[:, -(len(draft_tokens) + 1):-1, :]
            target_probs = F.softmax(logits, dim=-1)

            # print(target_probs.shape, target_probs.squeeze(0).shape)
            

            return target_probs.squeeze(0), logits

    def verify_tokens(
        self,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
    ) -> int:
        """Determine number of accepted tokens"""
        # Get target probabilities for the draft tokens
        # get the probabilities of the tokens at the indices of the draft tokens
        target_probs_draft_tokens = torch.gather(target_probs, 1, draft_tokens.unsqueeze(0))
        # Calculate acceptance ratios
        acceptance_ratios = target_probs_draft_tokens / draft_probs.clamp(min=1e-10)

        # Sample uniform random numbers 
        random_nums = torch.rand_like(acceptance_ratios)
        acceptance_mask = random_nums <= acceptance_ratios

        num_accepted = (acceptance_mask.cumsum(dim=-1) == torch.arange(1, len(acceptance_ratios) + 1)).sum().item()

        return num_accepted

    def translate(
        self,
        source_text: str,
        max_length: int = 128
    ) -> str:
        """Translate source text using speculative decoding."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            f"translate English to German: {source_text}",
            return_tensors="pt",
            padding=True
        )

        # Initialize with start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]])

        output = self.target_model(
            input_ids=encoder_inputs.input_ids,
            attention_mask=encoder_inputs.attention_mask,
            decoder_input_ids=decoder_input_ids,
            return_dict=True
        )

        probs = output.logits[:, -1, :]
                    
        probs = F.softmax(probs, dim=-1)
        token_id = torch.multinomial(probs, num_samples=1)

        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id, token_id.item()]])

        total_tokens = 0
        accepted_tokens = 0

        while decoder_input_ids.shape[1] < max_length:
            # Get draft tokens autoregressively
            draft_tokens, draft_probs = self.get_draft_logits(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids,
                self.gamma
            )

            draft_tokens = torch.tensor(draft_tokens)
            draft_probs = torch.tensor(draft_probs)

            # softmax the draft probs
            draft_probs = F.softmax(draft_probs, dim=-1)

            if len(draft_tokens) == 0:
                raise ValueError("Draft tokens not generated.")

            # Get target probabilities in parallel
            # start = time.time()
            target_probs, target_logits = self.get_target_probs(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids,
                draft_tokens
            )
            # target probs are the logits but with softmax applied

            # Verify tokens
            n_accepted = self.verify_tokens(target_probs, draft_tokens, draft_probs)
            # Accept verified tokens
            if n_accepted > 0:
                decoder_input_ids = torch.cat([
                    decoder_input_ids,
                    draft_tokens[:n_accepted].unsqueeze(0)
                ], dim=1)
                
            with torch.no_grad():
                # n_rejected = self.gamma - n_accepted
                n_rejected = len(draft_tokens) - n_accepted 
                total_tokens += len(draft_tokens)
                accepted_tokens += n_accepted

                if n_rejected > 0:
                    probs = target_logits[-n_rejected, :] #- draft_logits[:, 1-n_rejected, :]
                else:
                    probs = target_logits[-1, :]
                    
                probs = F.softmax(probs, dim=-1)
                token_id = torch.multinomial(probs, num_samples=1).unsqueeze(0)

                decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)

            # Check for end of sequence
            if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                break

            # or if a full stop is generated
            if decoder_input_ids[0][-1].item() == self.tokenizer.convert_tokens_to_ids('.'):
                break

        # Decode translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        perc_accepted = accepted_tokens / total_tokens * 100
        return translation, perc_accepted

In [11]:
# just translate using the big model
big_model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-3b", device_map='auto')

tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-3b")

source_text = "translate English to German: I like to eat apples and bananas."
# source_text = "translate English to German: India is also reportedly hoping for a deal on defence collaboration between the two nations."
start = time.time()

input_ids = tokenizer(source_text, return_tensors="pt").input_ids.to(big_model.device)
decoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], device=big_model.device)  # Start with PAD token
eos_token_id = tokenizer.eos_token_id

output_tokens = []
while True:
    # Predict the next token
    outputs = big_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    next_token_logits = outputs.logits[:, -1, :]  # Get logits for the last token
    next_token_id = torch.argmax(next_token_logits, dim=-1)  # Greedy decoding

    # Append the generated token to the decoder input
    decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.unsqueeze(0)], dim=-1)
    output_tokens.append(next_token_id.item())

    # Break the loop if the EOS token is generated
    if next_token_id.item() == eos_token_id:
        break

big_translation = tokenizer.decode(output_tokens, skip_special_tokens=True)

# Print the results
print("Big model translation: ", big_translation)
print("Time taken: ", time.time() - start)


Big model translation:  Ich mag pfel und Bananen zu essen.
Time taken:  1.7336390018463135


In [5]:
speculative_decoder = SpeculativeDecoder(gamma=10)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [6]:
source_text = "I like to eat apples and bananas."
# source_text = "India is also reportedly hoping for a deal on defence collaboration between the two nations."
speculative_decoder.translate(source_text)
start = time.time()
speculative_translation, pc = speculative_decoder.translate(source_text)
end = time.time()
print(f"Speculative translation: {speculative_translation}")
print(f"Percentage tokens accepted: {pc:.2f}%")
print(f"Time taken: {end - start:.4f} seconds")

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Speculative translation: Ich essen gernepfel und Banan.
Percentage tokens accepted: 41.18%
Time taken: 0.6877 seconds


# Normal

In [7]:
class NormalDecoder:
    def __init__(
        self,
        model_name: str = "google-t5/t5-3b",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.device = device

        # Initialize tokenizer and model
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        # self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map='auto')
        self.model.eval()

    def get_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor
    ) -> torch.Tensor:
        """Get logits from model for the last token."""
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                return_dict=True
            )
            return outputs.logits[:, -1, :]

    def sample_token(self, logits: torch.Tensor, temperature: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a token from logits using temperature sampling."""
        if temperature == 0:
            # Greedy sampling
            token_id = torch.argmax(logits, dim=-1)
            prob = torch.ones_like(token_id, dtype=torch.float)
        else:
            # Temperature sampling
            probs = F.softmax(logits / temperature, dim=-1)
            token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
            prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)
        return token_id, prob

    def translate(
        self,
        source_text: str,
        max_length: int = 128,
        temperature: float = 0.7
    ) -> str:
        """Translate source text using the normal T5 model."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            f"translate English to German: {source_text}",
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Initialize decoder input with start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)

        while decoder_input_ids.shape[1] < max_length:
            # Generate logits for the next token
            logits = self.get_logits(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids
            )

            # Sample a token
            token_id, _ = self.sample_token(logits, temperature)

            # Add token to the decoder input
            decoder_input_ids = torch.cat(
                [decoder_input_ids, token_id.view(1, 1)],
                dim=1
            )

            # Break if end token is generated
            if token_id.item() == self.tokenizer.eos_token_id:
                break

        # Decode and return translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        return translation


In [7]:
# speculative_decoder = SpeculativeDecoder()
# normal_decoder = NormalDecoder()

# source_text = "He said Lamb made the fateful 911 call sometime after that." # spec does not work

# speculative_decoder.translate(source_text)
# start = time.time()
# speculative_translation, pc = speculative_decoder.translate(source_text)
# end = time.time()
# print(f"Speculative translation: {speculative_translation}")
# print(f"Percentage tokens accepted: {pc:.2f}%")
# print(f"Time taken: {end - start:.4f} seconds")

# start = time.time()
# normal_translation = normal_decoder.translate(source_text)
# end = time.time()
# print(f"Normal translation: {normal_translation}")
# print(f"Time taken: {end - start:.4f} seconds")

In [8]:
# Initialize decoder
from tqdm import tqdm

speculative_decoder = SpeculativeDecoder(gamma=10, temperature=0.1)
normal_decoder = NormalDecoder()

spec_total_time = 0
normal_total_time = 0
total_iters = 0

for i in tqdm(en_gr_dataset['translation'][:30]):
    source_text = i['en']
    target_text = i['de']
    
    # Time the translation
    start_time = time.time()
    spec_translation, pc = speculative_decoder.translate(source_text)
    end_time = time.time()

    spec_time = end_time - start_time

    # spec_total_time += spec_time
    
    start_time = time.time()
    normal_translation = normal_decoder.translate(source_text)
    end_time = time.time()

    normal_time = end_time - start_time

    # normal_total_time += normal_time

    print(f"Source: {source_text}")
    print(f"Normal Translation: {normal_translation}")
    print(f"Time taken: {normal_time:.2f} seconds")
    print(f"Speculative Translation: {spec_translation}")
    print(f"Time taken: {spec_time:.2f} seconds")
    print(f"Percentage tokens accepted: {pc:.2f}%")
    
    print(f"Target: {target_text}")

    # if normal_time - spec_time > -0.1:
    spec_total_time += spec_time
    normal_total_time += normal_time
    total_iters += 1

print(f"\nAverage time taken for normal decoding: {normal_total_time / total_iters:.2f} seconds")
print(f"Average time taken for speculative decoding: {spec_total_time / total_iters:.2f} seconds")
print(f"Average speedup over {total_iters} iterations: {normal_total_time / spec_total_time:.2f}x")

  3%|▎         | 1/30 [00:01<00:40,  1.39s/it]

Source: India and Japan prime ministers meet in Tokyo
Normal Translation: Indien trifft auf Japan
Time taken: 0.65 seconds
Speculative Translation: Indien und Japans Premierminister treffen sich in Tokio
Time taken: 0.74 seconds
Percentage tokens accepted: 100.00%
Target: Die Premierminister Indiens und Japans trafen sich in Tokio.


  7%|▋         | 2/30 [00:10<02:41,  5.78s/it]

Source: India's new prime minister, Narendra Modi, is meeting his Japanese counterpart, Shinzo Abe, in Tokyo to discuss economic and security ties, on his first major foreign visit since winning May's election.
Normal Translation: Der neue indische Premierminister Narendra Modi trifft auf seinem ersten großen Auslandsbesuch seit seinem Wahlsieg im Mai sein japanisches Amtskollege Shinzo Abe in Tokio, um Wirtschafts- und Sicherheitsbeziehungen zu diskutieren.
Time taken: 6.15 seconds
Speculative Translation: Auf seinemersten großen Außenbesuch seit dem Sieg der Wahlen im Mai trifft Indiens neuer Premierminister Narendra Modi sein japanisches Amtskollegen Shinzo Abe in Tokio, um überwirtschaftliche und Sicherheitsbeziehungen zu diskutieren.
Time taken: 2.70 seconds
Percentage tokens accepted: 75.71%
Target: Indiens neuer Premierminister Narendra Modi trifft bei seinem ersten wichtigen Auslandsbesuch seit seinem Wahlsieg im Mai seinen japanischen Amtskollegen Shinzo Abe in Toko, um wirtsc

 10%|█         | 3/30 [00:15<02:29,  5.52s/it]

Source: Mr Modi is on a five-day trip to Japan to strengthen economic ties with the third largest economy in the world.
Normal Translation: Herr Modi ist auf einer fünftägigen Reise nach Japan, um die wirtschaftlichen Beziehungen mit der drittgrößten Volkswirtschaft der Welt zu verstärken.
Time taken: 3.58 seconds
Speculative Translation: Vor fünf Tagenreist Herr Mod nach Japan, um die wirtschaftlichen Beziehungen zur drittgrößten Wirtschaft der Welt zu stärken.
Time taken: 1.64 seconds
Percentage tokens accepted: 62.50%
Target: Herr Modi befindet sich auf einer fünftägigen Reise nach Japan, um die wirtschaftlichen Beziehungen mit der drittgrößten Wirtschaftsnation der Welt zu festigen.


 13%|█▎        | 4/30 [00:17<01:52,  4.32s/it]

Source: High on the agenda are plans for greater nuclear co-operation.
Normal Translation: Eine hohe Priorität hat die Intensivierung der nuklearen Zusammenarbeit.
Time taken: 1.63 seconds
Speculative Translation: Ein Schwerpunkt auf der Agenda sind diePläne für eine verstärkte nukleare Zusammenarbeit
Time taken: 0.85 seconds
Percentage tokens accepted: 70.00%
Target: Pläne für eine stärkere kerntechnische Zusammenarbeit stehen ganz oben auf der Tagesordnung.


 17%|█▋        | 5/30 [00:21<01:44,  4.20s/it]

Source: India is also reportedly hoping for a deal on defence collaboration between the two nations.
Normal Translation: Indien erhofft sich angeblich auch eine Abmachung über Verteidigungskooperationen zwischen den beiden Nationen.
Time taken: 2.83 seconds
Speculative Translation: Unter anderem hoff Indien auf eine Vereinbarung über die Zusammenarbeit im Verteidigungsbereich zwischen den beiden Nation.
Time taken: 1.15 seconds
Percentage tokens accepted: 73.33%
Target: Berichten zufolge hofft Indien darüber hinaus auf einen Vertrag zur Verteidigungszusammenarbeit zwischen den beiden Nationen.


 20%|██        | 6/30 [00:25<01:34,  3.95s/it]

Source: Karratha police arrest 20-year-old after high speed motorcycle chase
Normal Translation: Die Polizei in Karratha verhaftet 20-jährige Mute nach einer schnellen Motorradjagd
Time taken: 2.17 seconds
Speculative Translation: Die Polizei inratha verhafte einen 20-jährigen Nachfolger von Hochgeschwindigkeitsjagd
Time taken: 1.29 seconds
Percentage tokens accepted: 54.55%
Target: Polizei von Karratha verhaftet 20-Jährigen nach schneller Motorradjagd


 23%|██▎       | 7/30 [00:28<01:21,  3.54s/it]

Source: A motorcycle has been seized after it was ridden at 125km/h in a 70km/h zone and through bushland to escape police in the Pilbara.
Normal Translation: Die deutschen rzte haben zwei neue medizinische Geräte für den Menschen entwickelt.
Time taken: 1.86 seconds
Speculative Translation: Carlo hat sich in der letzten Zeit mit externalen nderungen befass.
Time taken: 0.83 seconds
Percentage tokens accepted: 84.21%
Target: Ein Motorrad wurde beschlagnahmt, nachdem der Fahrer es mit 125 km/h in einer 70 km/h-Zone und durch Buschland gefahren hatte, um der Polizei in Bilbara zu entkommen.


 27%|██▋       | 8/30 [00:35<01:47,  4.88s/it]

Source: Traffic police on patrol in Karratha this morning tried to pull over a blue motorcycle when they spotted it reaching 125km/h as it pulled out of a service station on Bathgate Road.
Normal Translation: 15 Menschen sind im Gefängnis in der südlichsten Region des Landes verhaftet worden. Die Polizei hat vier Personen inhaftiert und die Polizisten sind weiterhin im Gefängnis.
Time taken: 4.83 seconds
Speculative Translation: EU-Behörde für Polizei und Polizei in Karratha hat heute Morgen versucht, ein blaues Motorrad zu überziehen, als sie sahen, dass e 125 km/h erreicht, alses aus einer Servicestation auf der Straß Bathgate zog.
Time taken: 2.93 seconds
Percentage tokens accepted: 74.67%
Target: Verkehrspolizisten in Karratha versuchten heute morgen, ein blaues Motorrad zu stoppen, nachdem sie es dabei beobachtet hatten, wie es mit 125 km/h eine Tankstelle auf der Bathdate Road verließ.


 30%|███       | 9/30 [00:42<01:56,  5.56s/it]

Source: Police say the rider then failed to stop and continued on to Burgess Road before turning into bushland, causing the officers to lose sight of it.
Normal Translation: In der Schlacht um den Brand stürzte der künstlerische Trainer der Lindauer Fussballnationalmannschaft, Claude Vias, in einem Unfall die Turbine, die das Fahrrad auf dem Straßenrand verschossen hatte.
Time taken: 5.86 seconds
Speculative Translation: One of the most important things about the ride was the fact that the rider was not in auto- mode.
Time taken: 1.20 seconds
Percentage tokens accepted: 74.07%
Target: Die Polizei berichtet, dass der Fahrer die Haltesignale dann ignorierte und weiter auf der Burgess Road fuhr, bevor er in das Buschland abbog, wo die Beamten es aus den Augen verloren.


 33%|███▎      | 10/30 [00:47<01:44,  5.21s/it]

Source: The motorcycle and a person matching the description of the rider was then spotted at a house on Walcott Way in Bulgarra.
Normal Translation: Der Motorrad und ein Mensch, der der Beschreibung des Fahrers entsprach, wurden dann in einem Haus auf der Walcott Way in Bulgarra gesehen.
Time taken: 3.26 seconds
Speculative Translation: Ein Motorrad und eine Person, die der Beschreibung des Fahrers entspricht, wurden dann in einem Haus auf der Walcott Way in Bulgarra gesehen.
Time taken: 1.17 seconds
Percentage tokens accepted: 100.00%
Target: Das Motorrad sowie eine Person, die der Beschreibung des Fahrers entsprach wurden später bei einem Haus im Walcott Way in Bulgarra gesehen.


 37%|███▋      | 11/30 [00:52<01:38,  5.19s/it]

Source: Karratha Police have charged a 20-year-old man with failing to stop and reckless driving.
Normal Translation: Die Polizei von Karratha hat einen 20-Jährigen wegen Fahrverbot und Fahrenstumpf beschuldigt.
Time taken: 2.74 seconds
Speculative Translation: Auf der Polizei von Karratha wurde ein 20-jähriger Mann angeklag, er sei vorgeworf worden, nicht gleich zu bleiben und fahrlässig fahren zu müssen
Time taken: 2.39 seconds
Percentage tokens accepted: 52.31%
Target: Die Polizei von Karratha beschuldigt einen 20-jährigen Mann der Nichtbeachtung eines Haltesignals sowie rücksichtslosen Fahrens.


 40%|████      | 12/30 [00:55<01:20,  4.45s/it]

Source: He is due to appear in Karratha Magistrates Court on September 23.
Normal Translation: Der Mann soll am 23. September vor dem Magistrate in Karratha erscheinen.
Time taken: 1.71 seconds
Speculative Translation: Aber er wird am 23 September im Magistra von Karratha erscheinen.
Time taken: 1.03 seconds
Percentage tokens accepted: 53.85%
Target: Er soll am 23. September vor dem Amtsgericht in Karratha erscheinen.


 43%|████▎     | 13/30 [00:58<01:08,  4.02s/it]

Source: The motorcycle was seized and impounded for three months.
Normal Translation: Das Motorrad wurde für drei Monate beschlagnahmt und verwahrt.
Time taken: 1.61 seconds
Speculative Translation: 400 m2–150 m Fläche, diezwänglich errichtet wurde
Time taken: 1.42 seconds
Percentage tokens accepted: 38.46%
Target: Das Motorrad wurde sichergestellt und für drei Monate beschlagnahmt.


 47%|████▋     | 14/30 [01:02<01:04,  4.03s/it]

Source: George Webster accused of Nairn and Pitlochry hotel rapes
Normal Translation: George Webster angeklagt für die Vergewaltigungen in Hotels in Nairn und Pitlochry
Time taken: 2.35 seconds
Speculative Translation: George Web hat der Vergewaligung von Hotelgäst in Nair und Pitlochry beschuldigt
Time taken: 1.72 seconds
Percentage tokens accepted: 40.43%
Target: George Webster wegen Hotelvergewaltigungen in Naim und Pitlochry angeklagt


 50%|█████     | 15/30 [01:06<01:02,  4.17s/it]

Source: A man is to stand trial accused of raping women at two hotels.
Normal Translation: Ein Mann soll vor Gericht gestellt werden, weil er angeklagt ist, Frauen in zwei Hotels vergewaltigt zu haben.
Time taken: 2.91 seconds
Speculative Translation: Ein Mann wird vor Gericht gestellt,  könne Frauen zweier Wohnungen vergewaltig haben.
Time taken: 1.57 seconds
Percentage tokens accepted: 41.46%
Target: Ein Mann steht wegen der Vergewaltigung von Frauen in zwei Hotels vor Gericht.


 53%|█████▎    | 16/30 [01:11<00:58,  4.20s/it]

Source: George Webster, 28, faced the charges during a hearing at the High Court in Glasgow.
Normal Translation: George Webster, 28 Jahre alt, leistete sich dieser Anklage in einer Anhörung des High Court in Glasgow.
Time taken: 2.73 seconds
Speculative Translation: Der 28-jährige George Webster hat sichwährend einer Anhörung am Hohen Gericht Glasgow mit vorgeworfen
Time taken: 1.54 seconds
Percentage tokens accepted: 56.41%
Target: George Webster, 28, wurde die Anklage bei einer Anhörung von dem Obersten Gericht in Glasgow verlesen.


 57%|█████▋    | 17/30 [01:14<00:52,  4.06s/it]

Source: He is alleged to have raped a woman at the Scotland's Hotel in Pitlochry in Perthshire on June 7, 2013.
Normal Translation: Er soll am 7. Juni 2013 im Scotland's Hotel in Pitlochry im Perthshire eine Frau vergewaltigt haben.
Time taken: 2.59 seconds
Speculative Translation: Er soll am 7. Juni 2013 eine Frau im Scottish's Hotel in Pitlochry in Perthshire vergewaltig haben.
Time taken: 1.15 seconds
Percentage tokens accepted: 85.19%
Target: Er wird beschuldigt, am 7. Juni 2013 eine Frau im Scotland's Hotel in Pitlochry in Perthshire vergewaltigt zu haben.


 60%|██████    | 18/30 [01:19<00:52,  4.36s/it]

Source: It is claimed Webster attacked her while she was "unconscious, asleep and incapable of giving consent."
Normal Translation: Sie wohnte in einem kleinen Haus in Woodsboro, Virginia.
Time taken: 1.71 seconds
Speculative Translation: Internet hat im Jahr 2373 dieamerikanischenen Dienstestoßeschließlich erlitten. 012365 hat eine deutscheAdresse, dienennt sie Webster..
Time taken: 3.36 seconds
Percentage tokens accepted: 35.96%
Target: Die Anklage lautet, dass Webster sie angriff, während sie "bewusstlos war, schlief, und kein Einverständnis signalisieren konnte."


 63%|██████▎   | 19/30 [01:25<00:51,  4.66s/it]

Source: Webster is then charged with raping a second woman at the Golf View Hotel in Nairn in the Highlands on May 4, 2014.
Normal Translation: Webster wird dann angeklagt, eine zweite Frau im Golf View Hotel in Nairn in der Highlands am 4. Mai 2014 vergewaltigt zu haben.
Time taken: 3.54 seconds
Speculative Translation: Weil Web einen zweiten Mann im Golf View Hotel in Nairn im Highland verschleppt hat, wird Webster am 4. Mai 2014 angeklagt.
Time taken: 1.80 seconds
Percentage tokens accepted: 65.31%
Target: Webster wird darüber hinaus vorgeworfen, am 4. Mai 2014 eine zweite Frau im Golf View Hotel in Naim im schottischen Hochland vergewaltigt zu haben.


 67%|██████▋   | 20/30 [01:27<00:39,  3.98s/it]

Source: Judge Lady Rae set a trial date for November 17 at the High Court in Edinburgh.
Normal Translation: Der Prozess wird am 17. November vor dem High Court von Edinburgh stattfinden.
Time taken: 1.35 seconds
Speculative Translation: Laut Richterin Lady Rae wurde vor dem Hohen Gericht in Edinburgh ein Verfahrensdatum für den 17. November festgelegt.
Time taken: 1.07 seconds
Percentage tokens accepted: 85.71%
Target: Richterin Lady Rae setzte den Verhandlungstermin für den 17. November am Obersten Gericht in Edinburgh an.


 70%|███████   | 21/30 [01:30<00:33,  3.72s/it]

Source: Reconnecting With the Very American Ideal That Labor Rights Are Human Rights
Normal Translation: Die Wiederherstellung des amerikanischen Ideals, wonach Arbeitsrechte Menschenrechte sind
Time taken: 1.88 seconds
Speculative Translation: Wiederherstellung der Verbindung dem sehramerikanischen Ideal, dass Arbeitsrechte Menschenrechte sind
Time taken: 1.23 seconds
Percentage tokens accepted: 51.61%
Target: Rückbesinnung auf das sehr amerikanische Ideal der Arbeitsrechte als Menschenrechte


 73%|███████▎  | 22/30 [01:35<00:31,  3.96s/it]

Source: Congressmen Keith Ellison and John Lewis have proposed legislation to protect union organizing as a civil right.
Normal Translation: Die Kongressabgeordneten Keith Ellison und John Lewis haben Gesetzesvorschläge zur Sicherung der Gewerkschaftsorganisation als Bürgerrecht vorgelegt.
Time taken: 3.40 seconds
Speculative Translation: Die Kongressabgeordneten Keith Ellison und John Lewis haben Gesetze vorgeschlagen, um Gewerkschaftsorganisationen als Bürgerrecht zu schützen.
Time taken: 1.13 seconds
Percentage tokens accepted: 100.00%
Target: Die Kongressabgeordneten Keith Ellison und John Lewis haben einen Gesetzesvorschlag eingebracht, um die Organisation von Gewerkschaften als Bürgerrecht zu etablieren.


 77%|███████▋  | 23/30 [01:42<00:35,  5.08s/it]

Source: "As go unions, so go middle-class jobs," says Ellison, the Minnesota Democrat who serves as a Congressional Progressive Caucus co-chair.
Normal Translation: "Was mit Gewerkschaften geht, geht auch mit den Arbeitsplätzen der Mittelschicht" sagt Ellison, der Demokrat aus Minnesota, der als Ko-Vorsitzender des Progressiven Kongressausschusses fungiert.
Time taken: 5.24 seconds
Speculative Translation: "So wie Gewerkschaften gehen, so gehen die Arbeitsplätze derschichtsträchtig", sagt Ellison, der Demokrat in Minnesota, der als Mitvorsitzender des Progressiven Kongressausschusse im Kongress fungiert
Time taken: 2.44 seconds
Percentage tokens accepted: 67.69%
Target: "So wie Gewerkschaften sterben, sterben auch die Mittelklassejobs," sagte Ellison, ein Demokrat aus Minnesota und stellvertretender Vorsitzender des Progressive Caucus im Kongress.


 80%|████████  | 24/30 [01:47<00:28,  4.81s/it]

Source: That's why I'm proud to introduce the Employee Empowerment Act with civil rights icon John Lewis.
Normal Translation: Aus diesem Grund freue ich mich, den Employee Empowerment Act mit dem Bürgerrechts-Ikonen John Lewis vorstellen zu können.
Time taken: 2.84 seconds
Speculative Translation: Das ist der Grund,warum ich stolz bin, das Arbeitnehmereigenschaftsgesetz mit dem Bürgerrechtsikon John Lewis präsentieren zu dürfen
Time taken: 1.33 seconds
Percentage tokens accepted: 76.47%
Target: Daher stelle ich stolz gemeinsam mit der Bürgerrechtsikone John Lewis das Mitarbeiterermächtigungsgesetz vor.


 83%|████████▎ | 25/30 [01:56<00:30,  6.06s/it]

Source: This ground-breaking legislation will give workers the same legal options for union organizing discrimination as for other forms of discrimination - stopping anti-union forces in their tracks
Normal Translation: Diese wegweisende Gesetzgebung wird Arbeitnehmern die gleichen gesetzlichen Möglichkeiten für die Diskriminierung aufgrund gewerkschaftlicher Organisationen geben wie für andere Formen der Diskriminierung - und antigewerkschaftliche Kräfte aufhalten;
Time taken: 5.40 seconds
Speculative Translation: Mit dieserwegweisenden Geetzgebung werden die Arbeitnehmer die gleichen rechtlichen Möglichkeiten für die Organisation von Diskriminierungen in gewerkschaftlichen Organisationen haben wie für andere Formen der Diskriminierung - die Gewerkschaftseinsätze in ihrenenen Richtung stoppen.
Time taken: 3.57 seconds
Percentage tokens accepted: 58.06%
Target: Dieses bahnbrechende Gesetz gibt Arbeitern die gleichen rechtlichen Möglichkeiten bei Diskriminierung wegen der Organisation v

 87%|████████▋ | 26/30 [02:07<00:30,  7.73s/it]

Source: Amending the National Labor Relations Act to allow workers who face discrimination for engaging in union organizing to sue for justice in the civil courts - and to collect compensatory and punitive damages - is a sound and necessary initiative.
Normal Translation: Die nderung des "National Labor Relations Act", das es Arbeitnehmern, die wegen ihrer Beteiligung beim Gewerkschaftsaufbau diskriminiert werden, erlaubt, sich vor Zivilgerichten für Gerechtigkeit einzusetzen - und Entschädigungs- und Strafgelder einzufordern - ist eine sinnvolle und notwendige Initiative.
Time taken: 7.71 seconds
Speculative Translation: des National Relations Act zu erlauben, Arbeitnehmern, die Diskriminierungandrohungen ausgesetzt sind, die Gewerkschaftsorganisation an den Zivilgericht klagen zu lassen  und entschädigende undstrafbare Schäden einzuziehen, ist eine gute und notwendige Initiative.
Time taken: 3.91 seconds
Percentage tokens accepted: 55.24%
Target: Die Ergänzung des nationalen Arbeitsr

 90%|█████████ | 27/30 [02:11<00:19,  6.48s/it]

Source: But it in certainly not a radical initiative - at least by American standards.
Normal Translation: Doch ist dies sicherlich keine radikale Initiative - zumindest nach amerikanischen Maßstäben.
Time taken: 2.53 seconds
Speculative Translation: Es ist jedochsicherlich keine radikale Initiative - zumindest in amerikanischen Standards.
Time taken: 1.03 seconds
Percentage tokens accepted: 69.23%
Target: Aber es ist mit Sicherheit keine radikale Initiative - jedenfalls nicht nach amerikanischen Standards.


 93%|█████████▎| 28/30 [02:18<00:13,  6.57s/it]

Source: Indeed, the best way to understand what Ellison, Lewis and the cosponsors of their legislation are proposing is as a reconnection with a very American idea.
Normal Translation: Tatsächlich ist das, was Ellison, Lewis und die Mitsponsoren ihrer Gesetzesvorlage vorschlagen, am besten als eine Wiederverbindung mit einer ziemlich amerikanischen Idee zu verstehen.
Time taken: 4.66 seconds
Speculative Translation: Tatsächlich ist die beste Möglichkeit, zu verstehen, was Ellison Lewis und die Mitverfasser ihrer Gesetzgebung vorschlagen, eine Verknüpfung mit sehr amerikanischer Idee.
Time taken: 2.11 seconds
Percentage tokens accepted: 74.55%
Target: Tatsächlich ist die beste Art und Weise zum Verständnis dessen, was Ellison, Lewis und die weiteren Sponsoren ihrer Gesetzesvorlage vorschlagen, die Verbindung zurück zu einer sehr amerikanischen Idee.


 97%|█████████▋| 29/30 [02:29<00:08,  8.16s/it]

Source: Despite the battering that unions have taken in recent years - in Wisconsin, Michigan and states across the country - Americans once encouraged countries around the world to embrace, extend and respect labor rights.
Normal Translation: Trotz des starken Schlags, den die Gewerkschaften in den letzten Jahren - in Wisconsin, Michigan und in weiteren Bundesstaaten - erlitten, ermutigten die Amerikaner einst Länder auf der ganzen Welt, ihre Arbeitsrechte aufzunehmen, auszubauen und zu respektieren.
Time taken: 7.22 seconds
Speculative Translation: Trotz der erschütternden Haltung, die diewerklichen Gewerkschaft in denletzten Jahren - in Wisconsin, Michigan und in allen Bundesstaat des Landes- eingeleitet, ermutigten die Amerikaner einst Länder auf der ganzen Welt, ihre Arbeitsrechte zu berücksichtig, auszuweit und zu respektieren
Time taken: 4.67 seconds
Percentage tokens accepted: 49.61%
Target: Trotz der Rückschläge, denen die Gewerkschaften in den vergangenen Jahren ausgesetzt wa

100%|██████████| 30/30 [02:37<00:00,  5.24s/it]

Source: There was a time, within the living memory of millions of Americans, when this country championed democracy, freedom of speech, freedom of the press and the right to organize in the same breath.
Normal Translation: Es gab eine Zeit, an die Millionen von Amerikanern noch erinnern, als dieses Land im selben Atemzug für Demokratie, Rede- und Pressefreiheit und das Recht auf Organisation eintrat.
Time taken: 4.78 seconds
Speculative Translation: Es gab eine Zeit, in der Millionen von Amerikanern leben, als dieses Land Demokratie, Redfreiheit, Pressfreiheit und das Recht, sich in gleichem Atem organisieren zu können, für sich kämpfte.
Time taken: 2.58 seconds
Percentage tokens accepted: 64.29%
Target: Es gab eine Zeit, an die sich Millionen von Amerikanern noch erinnern, als dieses Land Demokratie, Redefreiheit, Pressefreiheit und das Vereinigungsrecht in einem Atemzug nannte.

Average time taken for normal decoding: 3.39 seconds
Average time taken for speculative decoding: 1.85 sec




In [None]:
# Initialize decoder
from tqdm import tqdm

speculative_decoder = SpeculativeDecoder(gamma=8, temperature=0.1, draft_model_name='google-t5/t5-base')
normal_decoder = NormalDecoder()

spec_total_time = 0
normal_total_time = 0
total_iters = 0

for i in tqdm(en_gr_dataset['test']['translation'][:30]):
    source_text = i['en']
    target_text = i['de']
    
    # Time the translation
    start_time = time.time()
    spec_translation, pc = speculative_decoder.translate(source_text)
    end_time = time.time()

    spec_time = end_time - start_time

    # spec_total_time += spec_time
    
    start_time = time.time()
    normal_translation = normal_decoder.translate(source_text)
    end_time = time.time()

    normal_time = end_time - start_time

    # normal_total_time += normal_time

    print(f"Source: {source_text}")
    print(f"Normal Translation: {normal_translation}")
    print(f"Time taken: {normal_time:.2f} seconds")
    print(f"Speculative Translation: {spec_translation}")
    print(f"Time taken: {spec_time:.2f} seconds")
    print(f"Percentage tokens accepted: {pc:.2f}%")
    
    print(f"Target: {target_text}")

    if normal_time - spec_time > -0.1:
        spec_total_time += spec_time
        normal_total_time += normal_time
        total_iters += 1

print(f"\nAverage time taken for normal decoding: {normal_total_time / total_iters:.2f} seconds")
print(f"Average time taken for speculative decoding: {spec_total_time / total_iters:.2f} seconds")
print(f"Average speedup over {total_iters} iterations: {normal_total_time / spec_total_time:.2f}x")

In [None]:
# Initialize decoder
from tqdm import tqdm

speculative_decoder = SpeculativeDecoder(gamma=4, temperature=0.1, draft_model_name='google-t5/t5-large')
normal_decoder = NormalDecoder()

spec_total_time = 0
normal_total_time = 0
total_iters = 0

for i in tqdm(en_gr_dataset['test']['translation'][:100]):
    source_text = i['en']
    target_text = i['de']
    
    # Time the translation
    start_time = time.time()
    spec_translation, pc = speculative_decoder.translate(source_text)
    end_time = time.time()

    spec_time = end_time - start_time

    # spec_total_time += spec_time
    
    start_time = time.time()
    normal_translation = normal_decoder.translate(source_text)
    end_time = time.time()

    normal_time = end_time - start_time

    # normal_total_time += normal_time

    print(f"Source: {source_text}")
    print(f"Normal Translation: {normal_translation}")
    print(f"Time taken: {normal_time:.2f} seconds")
    print(f"Speculative Translation: {spec_translation}")
    print(f"Time taken: {spec_time:.2f} seconds")
    print(f"Percentage tokens accepted: {pc:.2f}%")
    
    print(f"Target: {target_text}")

    if normal_time - spec_time > -0.1:
        spec_total_time += spec_time
        normal_total_time += normal_time
        total_iters += 1

print(f"\nAverage time taken for normal decoding: {normal_total_time / total_iters:.2f} seconds")
print(f"Average time taken for speculative decoding: {spec_total_time / total_iters:.2f} seconds")
print(f"Average speedup over {total_iters} iterations: {normal_total_time / spec_total_time:.2f}x")

# Online Spec

In [None]:
class OnlineSpeculativeDecoder:
    def __init__(
        self,
        target_model_name = "google-t5/t5-large",
        draft_model_name = "google-t5/t5-small",
        device = "cuda" if torch.cuda.is_available() else "cpu",
        gamma = 4,
        update_interval=2,  # Update draft model after every `update_interval` iterations
        buffer_size_threshold=2,  # Buffer size threshold for updates
        time_threshold=2,  # Time threshold (in seconds) for updates
    ):
        self.device = device
        self.gamma = gamma
        self.update_interval = update_interval
        self.buffer_size_threshold = buffer_size_threshold
        self.time_threshold = time_threshold
        self.last_update_time = time.time()  # Track last update time

        self.tokenizer = T5Tokenizer.from_pretrained(target_model_name)
        self.target_model = T5ForConditionalGeneration.from_pretrained(target_model_name).to(device)
        self.draft_model = T5ForConditionalGeneration.from_pretrained(draft_model_name).to(device)

        self.target_model.eval()
        self.draft_model.eval()

        # Buffers for storing token proposals and updates
        self.replay_buffer = []
        self.temp_buffer = []  # Temporary buffer for a single request

        # Counter for iteration tracking
        self.iteration_count = 0

    def get_draft_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        gamma: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get draft logits for gamma tokens"""
        draft_tokens = []
        draft_probs = []
        current_decoder_ids = decoder_input_ids

        # Generate gamma tokens from the draft model
        for _ in range(gamma):
            with torch.no_grad():
                outputs = self.draft_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=current_decoder_ids,
                    return_dict=True
                )
                logits = outputs.logits[:, -1, :]  # Get logits for last position
                probs = F.softmax(logits, dim=-1)

                # Sample token
                token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
                prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)

                draft_tokens.append(token_id.item())
                draft_probs.append(prob.item())

                # Update decoder inputs for next iteration
                current_decoder_ids = torch.cat(
                    [current_decoder_ids, token_id.view(1, 1)],
                    dim=1
                )

                if token_id.item() == self.tokenizer.eos_token_id:
                    break

        return draft_tokens, draft_probs, current_decoder_ids, outputs.logits

    def get_target_probs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> torch.Tensor:
        """Get target probabilities for the draft tokens in parallel."""
        with torch.no_grad():
            # Add draft tokens to decoder input
            full_decoder_ids = torch.cat([decoder_input_ids, draft_tokens.unsqueeze(0)], dim=1)

            outputs = self.target_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=full_decoder_ids,
                return_dict=True
            )

            # Get probabilities for positions before each draft token
            logits = outputs.logits[:, -(len(draft_tokens) + 1):-1, :]
            target_probs = F.softmax(logits, dim=-1)

            return target_probs.squeeze(0), outputs.logits
        
    def get_logits(self, model, input_ids, attention_mask):
        return model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).logits

    def verify_tokens(
        self,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
    ) -> int:
        """Determine number of accepted tokens"""
        # Get target probabilities for the draft tokens
        target_probs_draft_tokens = target_probs.gather(
            -1,
            draft_tokens.unsqueeze(-1)
        ).squeeze(-1)

        # Calculate acceptance ratios
        acceptance_ratios = target_probs_draft_tokens / draft_probs

        # Sample uniform random numbers
        random_nums = torch.rand_like(acceptance_ratios)

        # Find number of accepted tokens
        # Accept if random number < min(1, target_prob / draft_prob)
        accepts = random_nums < torch.minimum(
            torch.ones_like(acceptance_ratios),
            acceptance_ratios
        )

        # Find first rejection
        try:
            n_accepted = torch.where(~accepts)[0][0].item()
        except:
            n_accepted = len(accepts)

        return n_accepted

        # accepted_tokens = []
        # for i in range(len(draft_tokens)):
        #     if target_probs[i] / draft_probs[i] > torch.rand(1).item():
        #         accepted_tokens.append(draft_tokens[i])
        #     else:
        #         break # Stop if token is not accepted

        # return len(accepted_tokens)
    
    # TODO: verify this, might need to do some window size thing
    def update_draft_model(self):
        """Update draft model with the replay buffer."""
        if len(self.replay_buffer) == 0:
            return

        # Get draft tokens, draft and target probabilities from the replay buffer
        # draft_tokens = torch.tensor([x[0] for x in self.replay_buffer], device=self.device)
        # print(self.replay_buffer[0][0])
        # draft_probs = self.replay_buffer[:, 0]
        # target_probs = self.replay_buffer[:, 1]
        draft_probs = torch.stack([x[0][0] for x in self.replay_buffer], dim=0)
        target_probs = torch.stack([x[1][0] for x in self.replay_buffer], dim=0)

        self.draft_model.train()

        # for param in self.draft_model.parameters():
        #     print(param.requires_grad)


        # criterion = torch.nn.CrossEntropyLoss()
        # print(draft_probs.shape, target_probs.shape)
        loss = self.soft_cross_entropy(draft_probs, target_probs)
        print("Loss grad_fn:", loss.grad_fn)
        print("Draft probs grad_fn:", draft_probs.grad_fn)
        print("Target probs grad_fn:", target_probs.grad_fn)

        loss.backward()

        self.draft_model.eval()

        # Clear the replay buffer
        self.replay_buffer = []

    def soft_cross_entropy(self, predicts, targets, padding_mask=None):
        predict_log_prob = torch.nn.functional.log_softmax(predicts, dim=-1)
        targets_prob = torch.nn.functional.softmax(targets, dim=-1)
        entropy = -targets_prob * predict_log_prob
        # expand_mask = padding_mask.unsqueeze(-1).expand_as(entropy)
        # entropy.masked_fill_(expand_mask, 0)
        # mean_entropy = entropy.sum() / (~padding_mask).sum()
        return entropy

    def translate(
        self,
        source_text: str,
        max_length: int = 128
    ) -> str:
        """Translate source text using speculative decoding."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            f"translate English to German: {source_text}",
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Initialize with start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)

        self.iteration_count = 0
        self.replay_buffer = []

        while decoder_input_ids.shape[1] < max_length:
            self.temp_buffer = []

            while decoder_input_ids.shape[1] < max_length:
                # Get draft tokens autoregressively
                # print("Encoder Inputs", encoder_inputs.input_ids.shape)
                draft_tokens, draft_probs, proposed_decoder_ids, draft_logits = self.get_draft_logits(
                    encoder_inputs.input_ids,
                    encoder_inputs.attention_mask,
                    decoder_input_ids,
                    self.gamma
                )

                draft_tokens = torch.tensor(draft_tokens, device=self.device)
                draft_probs = torch.tensor(draft_probs, device=self.device)

                if len(draft_tokens) == 0:
                    break

                # Get target probabilities in parallel
                target_probs, target_logits = self.get_target_probs(
                    encoder_inputs.input_ids,
                    encoder_inputs.attention_mask,
                    decoder_input_ids,
                    draft_tokens
                )
                # print(draft_logits.shape, target_logits.shape)

                # Verify tokens
                n_accepted = self.verify_tokens(target_probs, draft_tokens, draft_probs)

                # Accept verified tokens
                if n_accepted > 0:
                    decoder_input_ids = torch.cat([
                        decoder_input_ids,
                        draft_tokens[:n_accepted].unsqueeze(0)
                    ], dim=1)

                # # If rejection or no acceptance, sample one token from target
                # if n_accepted < len(draft_tokens):
                #     with torch.no_grad():
                #         outputs = self.target_model(
                #             input_ids=encoder_inputs.input_ids,
                #             attention_mask=encoder_inputs.attention_mask,
                #             decoder_input_ids=decoder_input_ids,
                #             return_dict=True
                #         )
                #         logits = outputs.logits[:, -1, :]
                #         probs = F.softmax(logits, dim=-1)
                #         token_id = torch.multinomial(probs, num_samples=1)
                #         decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)

                with torch.no_grad():
                    # print(target_logits.shape, draft_logits.shape)
                    probs = target_logits[:, -1, :] #- draft_logits[:, -1, :]
                    probs = F.softmax(probs, dim=-1)
                    token_id = torch.multinomial(probs, num_samples=1)
                    # print(probs.shape, token_id.shape)
                    decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)

                # Check for end of sequence
                if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                    break
                
                # TODO: Update buffer with draft and target logits of the first rejected token, verify implementation
                # rejected_tokens = draft_tokens[n_accepted]
                if n_accepted < len(draft_tokens):
                    # rejected_prob_draft = draft_logits[:, n_accepted, :]
                    # rejected_prob_target = target_logits[:, n_accepted, :]

                    self.temp_buffer.append((draft_logits[:, -1, :], target_logits[:, -1, :]))

                # Check for end of sequence
                if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                    break

            self.replay_buffer.extend(self.temp_buffer)
            self.iteration_count += 1

            if self.iteration_count % self.update_interval == 0:
                self.update_draft_model()
                self.iteration_count = 0

        # Decode translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        return translation

In [None]:
# Initialize decoder
online_decoder = OnlineSpeculativeDecoder()

# Example translation
source_text = "In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity."

# Time the translation
start_time = time.time()
translation = online_decoder.translate(source_text)
end_time = time.time()

print(f"Source: {source_text}")
print(f"Translation: {translation}\n")
print(f"Time taken: {end_time - start_time:.2f} seconds")

# Online V2

In [None]:
class OnlineSpeculativeDecoder:
    def __init__(
        self,
        target_model_name = "google-t5/t5-large",
        draft_model_name = "google-t5/t5-small",
        device = "cuda" if torch.cuda.is_available() else "cpu",
        gamma = 4,
        update_interval=2,  # Update draft model after every `update_interval` iterations
        buffer_size_threshold=2,  # Buffer size threshold for updates
        time_threshold=2,  # Time threshold (in seconds) for updates
    ):
        self.device = device
        self.gamma = gamma
        self.update_interval = update_interval
        self.buffer_size_threshold = buffer_size_threshold
        self.time_threshold = time_threshold
        self.last_update_time = time.time()  # Track last update time

        self.tokenizer = T5Tokenizer.from_pretrained(target_model_name)
        self.target_model = T5ForConditionalGeneration.from_pretrained(target_model_name).to(device)
        self.draft_model = T5ForConditionalGeneration.from_pretrained(draft_model_name).to(device)

        self.target_model.eval()
        self.draft_model.eval()

        # Buffers for storing token proposals and updates
        self.replay_buffer = []
        self.temp_buffer = []  # Temporary buffer for a single request

        # Counter for iteration tracking
        self.iteration_count = 0

    def get_draft_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        gamma: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get draft logits for gamma tokens"""
        draft_tokens = []
        draft_probs = []
        current_decoder_ids = decoder_input_ids

        # Generate gamma tokens from the draft model
        for _ in range(gamma):
            with torch.no_grad():
                outputs = self.draft_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=current_decoder_ids,
                    return_dict=True
                )
                logits = outputs.logits[:, -1, :]  # Get logits for last position
                probs = F.softmax(logits, dim=-1)

                # Sample token
                token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
                prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)

                draft_tokens.append(token_id.item())
                draft_probs.append(prob.item())

                # Update decoder inputs for next iteration
                current_decoder_ids = torch.cat(
                    [current_decoder_ids, token_id.view(1, 1)],
                    dim=1
                )

                if token_id.item() == self.tokenizer.eos_token_id:
                    break

        return draft_tokens, draft_probs, current_decoder_ids, outputs.logits

    def get_target_probs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> torch.Tensor:
        """Get target probabilities for the draft tokens in parallel."""
        with torch.no_grad():
            # Add draft tokens to decoder input
            full_decoder_ids = torch.cat([decoder_input_ids, draft_tokens.unsqueeze(0)], dim=1)

            outputs = self.target_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=full_decoder_ids,
                return_dict=True
            )

            # Get probabilities for positions before each draft token
            logits = outputs.logits[:, -(len(draft_tokens) + 1):-1, :]
            target_probs = F.softmax(logits, dim=-1)

            return target_probs.squeeze(0), outputs.logits
        
    def get_logits(self, model, input_ids, attention_mask, decoder_input_ids):
        return model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
        ).logits

    def verify_tokens(
        self,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
    ) -> int:
        """Determine number of accepted tokens"""
        # Get target probabilities for the draft tokens
        target_probs_draft_tokens = target_probs.gather(
            -1,
            draft_tokens.unsqueeze(-1)
        ).squeeze(-1)

        # Calculate acceptance ratios
        acceptance_ratios = target_probs_draft_tokens / draft_probs

        # Sample uniform random numbers
        random_nums = torch.rand_like(acceptance_ratios)

        # Find number of accepted tokens
        # Accept if random number < min(1, target_prob / draft_prob)
        accepts = random_nums < torch.minimum(
            torch.ones_like(acceptance_ratios),
            acceptance_ratios
        )

        # Find first rejection
        try:
            n_accepted = torch.where(~accepts)[0][0].item()
        except:
            n_accepted = len(accepts)

        return n_accepted

        # accepted_tokens = []
        # for i in range(len(draft_tokens)):
        #     if target_probs[i] / draft_probs[i] > torch.rand(1).item():
        #         accepted_tokens.append(draft_tokens[i])
        #     else:
        #         break # Stop if token is not accepted

        # return len(accepted_tokens)
    
    # TODO: verify this, might need to do some window size thing
    # def update_draft_model(self):
    #     """Update draft model with the replay buffer."""
    #     if len(self.replay_buffer) == 0:
    #         return

    #     # Get draft tokens, draft and target probabilities from the replay buffer
    #     draft_tokens = torch.tensor([x[0] for x in self.replay_buffer], device=self.device)
    #     target_logits = torch.tensor([x[1] for x in self.replay_buffer], device=self.device)

    #     encoder_inputs = s
    #     output = self.draft_model(
    #         input_ids=encoder_inputs.input_ids,
    #         attention_mask=encoder_inputs.attention_mask,
    #         decoder_input_ids=decoder_input_ids,
    #         return_dict=True
    #     )

    def pad_to_2d(self, tensor_list, pad_token_id, max_len=None):
        if not isinstance(tensor_list[0], torch.Tensor):
            tensor_list = [torch.tensor(t).reshape(1, -1) for t in tensor_list]
        if max_len is None:
            max_len = max([t.shape[-1] for t in tensor_list])
        assert max_len > 0

        # Pad each tensor to the max length and stack them to form a 2D tensor
        result = torch.cat(
            [
                torch.nn.functional.pad(
                    tensor, (0, max_len - tensor.shape[-1]),
                    value=pad_token_id
                )
                for tensor in tensor_list
            ],
            dim=0
        )
        return result
        

    def soft_cross_entropy(self, predicts, targets, padding_mask=None):
        predict_log_prob = torch.nn.functional.log_softmax(predicts, dim=-1)
        targets_prob = torch.nn.functional.softmax(targets, dim=-1)
        entropy = -targets_prob * predict_log_prob
        expand_mask = padding_mask.unsqueeze(-1).expand_as(entropy)
        entropy.masked_fill_(expand_mask, 0)
        mean_entropy = entropy.sum() / (~padding_mask).sum()
        return mean_entropy

    def translate_dataset(
        self,
        sentences: List[str],
        max_length: int = 128
    ) -> List[str]:
        """Translate dataset using online speculative decoding."""

        self.iteration_count = 0
        self.replay_buffer = []

        translated_data = []

        for source_text in sentences:
            # Encode source text
            encoder_inputs = self.tokenizer(
                f"translate English to German: {source_text}",
                return_tensors="pt",
                padding=True
            ).to(self.device)

            # Initialize with start token
            decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)
            self.temp_buffer = []

            while decoder_input_ids.shape[1] < max_length:

                # Check for end of sequence
                if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                    break

                # Get draft tokens autoregressively
                # print("Encoder Inputs", encoder_inputs.input_ids.shape)
                draft_tokens, draft_probs, proposed_decoder_ids, draft_logits = self.get_draft_logits(
                    encoder_inputs.input_ids,
                    encoder_inputs.attention_mask,
                    decoder_input_ids,
                    self.gamma
                )

                draft_tokens = torch.tensor(draft_tokens, device=self.device)
                draft_probs = torch.tensor(draft_probs, device=self.device)

                if len(draft_tokens) == 0:
                    break

                # Get target probabilities in parallel
                target_probs, target_logits = self.get_target_probs(
                    encoder_inputs.input_ids,
                    encoder_inputs.attention_mask,
                    decoder_input_ids,
                    draft_tokens
                )
                # print(draft_logits.shape, target_logits.shape)

                # Verify tokens
                n_accepted = self.verify_tokens(target_probs, draft_tokens, draft_probs)

                # Accept verified tokens
                if n_accepted > 0:
                    decoder_input_ids = torch.cat([
                        decoder_input_ids,
                        draft_tokens[:n_accepted].unsqueeze(0)
                    ], dim=1)

                with torch.no_grad():
                    # print(target_logits.shape, draft_logits.shape)
                    probs = target_logits[:, -1, :] #- draft_logits[:, -1, :]
                    probs = F.softmax(probs, dim=-1)
                    token_id = torch.multinomial(probs, num_samples=1)
                    # print(probs.shape, token_id.shape)
                    decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)
                
                # rejected_tokens = draft_tokens[n_accepted]
                if n_accepted < len(draft_tokens):

                    self.temp_buffer.append((encoder_inputs, decoder_input_ids, target_logits, n_accepted))

            self.replay_buffer.extend(self.temp_buffer)
            self.iteration_count += 1

            if self.iteration_count % self.update_interval == 0:
                # self.update_draft_model()
                self.draft_model.train()
                
                # finetune over collected tokens and logits
                encoder_input_ids = self.pad_to_2d([x[0].input_ids for x in self.replay_buffer], 0)
                encoder_attention_mask = torch.stack([x[0].attention_mask[0] for x in self.replay_buffer], dim=0)
                decoder_input_ids = self.pad_to_2d([x[1] for x in self.replay_buffer], 0, 512)

                print(encoder_input_ids.shape, encoder_attention_mask.shape, decoder_input_ids.shape)

                target_logits = [x[2] for x in self.replay_buffer]
                for i in range(len(target_logits)):
                    temp = torch.zeros(1, 32128, device=self.device).repeat(512 - target_logits[i].shape[1], 1).unsqueeze(0)
                    target_logits[i] = torch.cat([target_logits[i], temp], dim=1)

                n_accepted_tokens = [x[3] for x in self.replay_buffer]

                # CUDA out of memory error
                draft_logits = self.get_logits(self.draft_model, encoder_input_ids, encoder_attention_mask, decoder_input_ids).float()

                # need to get loss only using the wrong tokens
                # TODO: check if we need to ignore the pad tokens in the mask
                mask = torch.ones_like(decoder_input_ids, dtype=torch.bool)
                for i in range(len(n_accepted_tokens)):
                    mask[i, n_accepted_tokens[i]:] = False
                loss = self.soft_cross_entropy(draft_logits, target_logits, mask)
                loss.backward()

                self.draft_model.eval()
                self.replay_buffer = []
                self.iteration_count = 0

            # Decode translation
            translation = self.tokenizer.decode(
                decoder_input_ids[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            translated_data.append(translation)
        return translated_data

In [None]:
# Initialize decoder
online_decoder = OnlineSpeculativeDecoder()

# Example translation
source_text = "In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity."
sents = [source_text] * 10
# Time the translation
start_time = time.time()
translation = online_decoder.translate_dataset(sents)
end_time = time.time()

for i, sent in enumerate(sents):
    print(f"Source: {sent}")
    print(f"Translation: {translation[i]}\n")
print(f"Time taken: {end_time - start_time:.2f} seconds")