In [1]:
import os
if not os.path.exists("/scratch/sarthak"):
    os.makedirs("/scratch/sarthak")

os.environ["HF_HOME"] = "/scratch/sarthak/"

In [2]:
import torch
import torch.nn.functional as F
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoModelForSeq2SeqLM
from typing import List, Tuple, Optional
import time
import numpy as np
# %pip install tqdm
# %pip install ipywidgets

In [3]:
from typing import Tuple
import torch
import datasets
from transformers import T5Tokenizer

# # WMT16 EN-DE dataset
# def preprocess_function_ende(examples, tokenizer, args, train_args, prefix="translate English to German: "):
#     if train_args.debug:
#         all_text = examples['translation'][:2]
#     else:
#         all_text = examples['translation']
        
#     inputs = []
#     targets = []
#     for excerpt in all_text:
#         en_text = prefix + excerpt['en']
#         de_text = excerpt['de']

#         inputs.append(en_text)
#         targets.append(de_text)
            
#     padding = 'max_length'
#     model_inputs = tokenizer(
#         inputs,
#         max_length=args.source_max_length,
#         padding=padding,
#         truncation=True,
#         return_tensors="pt",
#     )
#     # Tokenize targets with the `text_target` keyword argument
#     labels = tokenizer(
#         text_target=targets,
#         max_length=args.train_target_max_length,
#         padding=padding,
#         truncation=True,
    #     return_tensors="pt",
    # )

    # if padding == "max_length":
    #             labels["input_ids"] = [
    #                 [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
    #             ]

    # model_inputs["labels"] = labels["input_ids"]
    # model_inputs["decoder_attention_mask"] = labels["attention_mask"]
    # return model_inputs

en_gr_dataset = datasets.load_dataset('wmt16', 'de-en')

# Base Speculative

In [4]:
class SpeculativeDecoder:
    def __init__(
        self,
        target_model_name = "google-t5/t5-3b",
        draft_model_name = "google-t5/t5-small",
        device = "cuda" if torch.cuda.is_available() else "cpu",
        gamma = 4,
        temperature = 0.5
    ):
        self.device = device
        self.gamma = gamma

        self.temperature = temperature

        self.tokenizer = T5Tokenizer.from_pretrained(target_model_name)
        
        # self.target_model = T5ForConditionalGeneration.from_pretrained(target_model_name).to(device)
        self.target_model = AutoModelForSeq2SeqLM.from_pretrained(target_model_name, device_map='auto')

        self.draft_model = T5ForConditionalGeneration.from_pretrained(draft_model_name).to(device)

        self.target_model.eval()
        self.draft_model.eval()

    def get_draft_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        gamma: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get draft logits for gamma tokens"""
        draft_tokens = []
        draft_probs = []
        current_decoder_ids = decoder_input_ids

        # Generate gamma tokens from the draft model
        for _ in range(gamma):
            with torch.no_grad():
                outputs = self.draft_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=current_decoder_ids,
                    return_dict=True
                )
                logits = outputs.logits[:, -1, :]  # Get logits for last position
                probs = F.softmax(logits, dim=-1)

                # Sample token
                token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
                prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)

                draft_tokens.append(token_id.item())
                draft_probs.append(prob.item())

                # Update decoder inputs for next iteration
                current_decoder_ids = torch.cat(
                    [current_decoder_ids, token_id.view(1, 1)],
                    dim=1
                )

                if token_id.item() == self.tokenizer.eos_token_id:
                    break

        return draft_tokens, draft_probs#, current_decoder_ids, outputs.logits

    # def get_target_probs(
    #     self,
    #     input_ids: torch.Tensor,
    #     attention_mask: torch.Tensor,
    #     decoder_input_ids: torch.Tensor,
    #     draft_tokens: torch.Tensor
    # ) -> torch.Tensor:
    #     """Get target probabilities for the draft tokens in parallel."""
    #     with torch.no_grad():
    #         # Add draft tokens to decoder input
    #         # full_decoder_ids = torch.cat([decoder_input_ids, draft_tokens.unsqueeze(0)], dim=1)

    #         full_decoder_ids = [decoder_input_ids]
    #         for i in range(len(draft_tokens)):
    #             x = torch.cat([decoder_input_ids, draft_tokens.unsqueeze(0)[:, :i+1]], dim=1)
    #             full_decoder_ids.append(x)

    #         maxlen = max([x.shape[1] for x in full_decoder_ids])

    #         padded_decoder_ids = torch.stack([torch.tensor(
    #             F.pad(x, (0, maxlen - x.shape[1]), value=self.tokenizer.pad_token_id)[0]
    #         , device=self.device) for x in full_decoder_ids])

    #         batch_size = padded_decoder_ids.shape[0]
    #         input_ids_batched = input_ids.repeat(batch_size, 1)
    #         attention_mask_batched = attention_mask.repeat(batch_size, 1)

    #         # make it a triangular attention mask


    #         # print(input_ids_batched.shape, attention_mask_batched.shape, padded_decoder_ids.shape)

    #         # outputs = self.target_model(
    #         #     input_ids=input_ids,
    #         #     attention_mask=attention_mask,
    #         #     decoder_input_ids=full_decoder_ids,
    #         #     return_dict=True
    #         # )
    #         outputs = self.target_model(
    #             input_ids=input_ids_batched,
    #             attention_mask=attention_mask_batched,
    #             decoder_input_ids=padded_decoder_ids,
    #             # decoder_attention_mask=torch.triu(
    #             #     torch.zeros((padded_decoder_ids.shape[0], padded_decoder_ids.shape[1]), device=self.device)
    #             # ),
    #             return_dict=True
    #         )
    #         # print("passes target model")

    #         batched_logits = outputs.logits
    #         # print(batched_logits.shape)
    #         # 5, 5, 32128
    #         logits = torch.zeros((batched_logits.shape[1], batched_logits.shape[2]), device=self.device)
    #         for i in range(len(draft_tokens)):
    #             logits[i] = batched_logits[i, i, :]

    #         # print(logits.shape, batched_logits[-1].shape)

    #         target_probs = F.softmax(logits, dim=-1)

    #         # # Get probabilities for positions before each draft token
    #         # logits = outputs.logits[:, -(len(draft_tokens) + 1):-1, :]
    #         # target_probs = F.softmax(logits, dim=-1)

    #         # print(batched_logits[-1].unsqueeze(0).shape)

    #         return target_probs.squeeze(0), logits.unsqueeze(0)

    def get_target_probs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> torch.Tensor:
        """Get target probabilities for the draft tokens in parallel."""
        with torch.no_grad():
            # Add draft tokens to decoder input
            full_decoder_ids = torch.cat([decoder_input_ids, draft_tokens.unsqueeze(0)], dim=1)

            # print(input_ids_batched.shape, attention_mask_batched.shape, padded_decoder_ids.shape)

            outputs = self.target_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=full_decoder_ids,
                # decoder_attention_mask=torch.triu(
                #     torch.zeros((full_decoder_ids.shape[0], full_decoder_ids.shape[1]), device=self.device)
                # ),
                return_dict=True
            )

            # Get probabilities for positions before each draft token
            logits = outputs.logits[:, -(len(draft_tokens) + 1):-1, :]
            target_probs = F.softmax(logits, dim=-1)

            return target_probs.squeeze(0), outputs.logits

    # def verify_tokens(
    #     self,
    #     target_probs: torch.Tensor,
    #     draft_tokens: torch.Tensor,
    #     draft_probs: torch.Tensor,
    # ) -> int:
    #     """Determine number of accepted tokens"""
    #     # Get target probabilities for the draft tokens
    #     target_probs_draft_tokens = target_probs.gather(
    #         -1,
    #         draft_tokens.unsqueeze(-1)
    #     ).squeeze(-1)

    #     # Calculate acceptance ratios
    #     acceptance_ratios = target_probs_draft_tokens / draft_probs

    #     # Sample uniform random numbers
    #     random_nums = torch.rand_like(acceptance_ratios)

    #     # Find number of accepted tokens
    #     # Accept if random number < min(1, target_prob / draft_prob)
    #     accepts = random_nums < torch.minimum(
    #         torch.ones_like(acceptance_ratios),
    #         acceptance_ratios
    #     )

    #     # Find first rejection
    #     try:
    #         n_accepted = torch.where(~accepts)[0][0].item()
    #     except:
    #         n_accepted = len(accepts)

    #     return n_accepted

    def verify_tokens(
        self,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
        temperature: float = 1.0
    ) -> int:
        """Determine number of accepted tokens"""
        # Get target probabilities for the draft tokens
        target_probs_draft_tokens = target_probs.gather(
            -1,
            draft_tokens.unsqueeze(-1)
        ).squeeze(-1)

        target_probs_draft_tokens = target_probs_draft_tokens / max(temperature, 1e-10)
        draft_probs = draft_probs / max(temperature, 1e-10)

        # Calculate acceptance ratios
        acceptance_ratios = target_probs_draft_tokens.float() / draft_probs

        # Sample uniform random numbers 
        random_nums = torch.zeros_like(target_probs_draft_tokens).float().uniform_()

        mask = random_nums > acceptance_ratios
        num_accepted = (mask.cumsum(dim = -1) == 0).sum(dim = -1)

        return num_accepted.int().item()

    def translate(
        self,
        source_text: str,
        max_length: int = 128
    ) -> str:
        """Translate source text using speculative decoding."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            f"translate English to German: {source_text}",
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Initialize with start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)

        output = self.target_model(
            input_ids=encoder_inputs.input_ids,
            attention_mask=encoder_inputs.attention_mask,
            decoder_input_ids=decoder_input_ids,
            return_dict=True
        )

        probs = output.logits[:, -1, :]
                    
        probs = F.softmax(probs, dim=-1)
        token_id = torch.multinomial(probs, num_samples=1)

        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id, token_id.item()]], device=self.device)


        total_tokens = 0
        accepted_tokens = 0

        while decoder_input_ids.shape[1] < max_length:
            # Get draft tokens autoregressively
            draft_tokens, draft_probs = self.get_draft_logits(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids,
                self.gamma
            )

            draft_tokens = torch.tensor(draft_tokens, device=self.device)
            draft_probs = torch.tensor(draft_probs, device=self.device)

            if len(draft_tokens) == 0:
                raise ValueError("Draft tokens not generated.")

            # Get target probabilities in parallel
            # start = time.time()
            target_probs, target_logits = self.get_target_probs(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids,
                draft_tokens
            )
            # print("Time taken for target probs: ", time.time() - start)

            # Verify tokens
            n_accepted = self.verify_tokens(target_probs, draft_tokens, draft_probs, self.temperature)
            # print(n_accepted)

            # Accept verified tokens
            if n_accepted > 0:
                decoder_input_ids = torch.cat([
                    decoder_input_ids,
                    draft_tokens[:n_accepted].unsqueeze(0)
                ], dim=1)
                
            with torch.no_grad():
                # n_rejected = self.gamma - n_accepted
                n_rejected = len(draft_tokens) - n_accepted # CHECK IF THIS IS CORRECT
                total_tokens += len(draft_tokens) #self.gamma
                accepted_tokens += n_accepted
                if n_rejected > 0:
                    probs = target_logits[:, -n_rejected, :] #- draft_logits[:, 1-n_rejected, :]
                else:
                    probs = target_logits[:, -1, :]
                    
                probs = F.softmax(probs, dim=-1)
                token_id = torch.multinomial(probs, num_samples=1)

                decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)

            # Check for end of sequence
            if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                break

        # Decode translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        perc_accepted = accepted_tokens / total_tokens * 100
        return translation, perc_accepted

In [5]:
# speculative_decoder = SpeculativeDecoder(gamma=2)

# source_text = "Obama receives Netanyahu"

# speculative_decoder.translate(source_text)
# start = time.time()
# speculative_translation, pc = speculative_decoder.translate(source_text)
# end = time.time()
# print(f"Speculative translation: {speculative_translation}")
# print(f"Percentage tokens accepted: {pc:.2f}%")
# print(f"Time taken: {end - start:.4f} seconds")

# Normal

In [6]:
class NormalDecoder:
    def __init__(
        self,
        model_name: str = "google-t5/t5-3b",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.device = device

        # Initialize tokenizer and model
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        # self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map='auto')
        self.model.eval()

    def get_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor
    ) -> torch.Tensor:
        """Get logits from model for the last token."""
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                return_dict=True
            )
            return outputs.logits[:, -1, :]

    def sample_token(self, logits: torch.Tensor, temperature: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a token from logits using temperature sampling."""
        if temperature == 0:
            # Greedy sampling
            token_id = torch.argmax(logits, dim=-1)
            prob = torch.ones_like(token_id, dtype=torch.float)
        else:
            # Temperature sampling
            probs = F.softmax(logits / temperature, dim=-1)
            token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
            prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)
        return token_id, prob

    def translate(
        self,
        source_text: str,
        max_length: int = 128,
        temperature: float = 0.7
    ) -> str:
        """Translate source text using the normal T5 model."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            f"translate English to German: {source_text}",
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Initialize decoder input with start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)

        while decoder_input_ids.shape[1] < max_length:
            # Generate logits for the next token
            logits = self.get_logits(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids
            )

            # Sample a token
            token_id, _ = self.sample_token(logits, temperature)

            # Add token to the decoder input
            decoder_input_ids = torch.cat(
                [decoder_input_ids, token_id.view(1, 1)],
                dim=1
            )

            # Break if end token is generated
            if token_id.item() == self.tokenizer.eos_token_id:
                break

        # Decode and return translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        return translation


In [7]:
# speculative_decoder = SpeculativeDecoder()
# normal_decoder = NormalDecoder()

# source_text = "He said Lamb made the fateful 911 call sometime after that." # spec does not work

# speculative_decoder.translate(source_text)
# start = time.time()
# speculative_translation, pc = speculative_decoder.translate(source_text)
# end = time.time()
# print(f"Speculative translation: {speculative_translation}")
# print(f"Percentage tokens accepted: {pc:.2f}%")
# print(f"Time taken: {end - start:.4f} seconds")

# start = time.time()
# normal_translation = normal_decoder.translate(source_text)
# end = time.time()
# print(f"Normal translation: {normal_translation}")
# print(f"Time taken: {end - start:.4f} seconds")

In [8]:
# Initialize decoder
from tqdm import tqdm

speculative_decoder = SpeculativeDecoder(gamma=4, temperature=0.1)
normal_decoder = NormalDecoder()

spec_total_time = 0
normal_total_time = 0
total_iters = 0

for i in tqdm(en_gr_dataset['test']['translation'][:30]):
    source_text = i['en']
    target_text = i['de']
    
    # Time the translation
    start_time = time.time()
    spec_translation, pc = speculative_decoder.translate(source_text)
    end_time = time.time()

    spec_time = end_time - start_time

    # spec_total_time += spec_time
    
    start_time = time.time()
    normal_translation = normal_decoder.translate(source_text)
    end_time = time.time()

    normal_time = end_time - start_time

    # normal_total_time += normal_time

    print(f"Source: {source_text}")
    print(f"Normal Translation: {normal_translation}")
    print(f"Time taken: {normal_time:.2f} seconds")
    print(f"Speculative Translation: {spec_translation}")
    print(f"Time taken: {spec_time:.2f} seconds")
    print(f"Percentage tokens accepted: {pc:.2f}%")
    
    print(f"Target: {target_text}")

    # if normal_time - spec_time > -0.1:
    spec_total_time += spec_time
    normal_total_time += normal_time
    total_iters += 1

print(f"\nAverage time taken for normal decoding: {normal_total_time / total_iters:.2f} seconds")
print(f"Average time taken for speculative decoding: {spec_total_time / total_iters:.2f} seconds")
print(f"Average speedup over {total_iters} iterations: {normal_total_time / spec_total_time:.2f}x")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
  3%|▎         | 1/30 [00:03<01:33,  3.21s/it]

Source: Obama receives Netanyahu
Normal Translation: Obama empfing Netanjahu
Time taken: 0.72 seconds
Speculative Translation: Obama Netanjahu empfing
Time taken: 2.49 seconds
Percentage tokens accepted: 60.00%
Target: Obama empfängt Netanyahu


  7%|▋         | 2/30 [00:07<01:50,  3.94s/it]

Source: The relationship between Obama and Netanyahu is not exactly friendly.
Normal Translation: Die Beziehung zwischen Obama und Netanjahu ist nicht gerade freundschaftlich.
Time taken: 1.64 seconds
Speculative Translation: Die Beziehungen zwischen Obama und Netanjahu sind nichtakt. Andersgesagt, das  gegenebul Klima gibt es in Europa nie. Nicht: wir diese 1995 erträumtenk Revolution in Berlin
Time taken: 2.81 seconds
Percentage tokens accepted: 42.86%
Target: Das Verhältnis zwischen Obama und Netanyahu ist nicht gerade freundschaftlich.


 10%|█         | 3/30 [00:13<02:11,  4.87s/it]

Source: The two wanted to talk about the implementation of the international agreement and about Teheran's destabilising activities in the Middle East.
Normal Translation: Die beiden wollten über die Umsetzung des internationalen Abkommens und über Teheran's destabilisierende Aktivitäten im Nahen Osten sprechen.
Time taken: 2.90 seconds
Speculative Translation: Daswätige Wort stammt vom Ausshiadan Dara, dem von den NGOs Hochnation- Azerthwa zu der-m
Time taken: 3.06 seconds
Percentage tokens accepted: 15.00%
Target: Die beiden wollten über die Umsetzung der internationalen Vereinbarung sowie über Teherans destabilisierende Maßnahmen im Nahen Osten sprechen.


 13%|█▎        | 4/30 [00:21<02:39,  6.12s/it]

Source: The meeting was also planned to cover the conflict with the Palestinians and the disputed two state solution.
Normal Translation: „Die israelische Regierung hat die israelische Währung im Land zerstört, die floh auf dem Marktplatz, nach einem offenen Markt, der überhaupt nicht mehr existieren kann“.
Time taken: 4.00 seconds
Speculative Translation: Se. Kalenderonul Georgs. Sichtbar war dass sämtlicheMädchen unter ihrertanz Exnikawege und Grundstücke entlang der Confluences würden, den Sohn Fas vondj keiner zu töten.
Time taken: 4.03 seconds
Percentage tokens accepted: 26.67%
Target: Bei der Begegnung soll es aber auch um den Konflikt mit den Palästinensern und die diskutierte Zwei-Staaten-Lösung gehen.


 17%|█▋        | 5/30 [00:24<01:59,  4.76s/it]

Source: Relations between Obama and Netanyahu have been strained for years.
Normal Translation: Die Beziehungen zwischen Obama und Netanjahu sind seit Jahren gespannt.
Time taken: 1.39 seconds
Speculative Translation: Die Beziehungen zwischen Obama und Netanjahu sind seit Jahren verfellich.
Time taken: 0.97 seconds
Percentage tokens accepted: 61.90%
Target: Das Verhältnis zwischen Obama und Netanyahu ist seit Jahren gespannt.


 20%|██        | 6/30 [00:29<01:56,  4.87s/it]

Source: Washington criticises the continuous building of settlements in Israel and accuses Netanyahu of a lack of initiative in the peace process.
Normal Translation: Washington kritisiert den fortgesetzten Aufbau von Siedlungen in Israel und beschuldigt Netanjahu des Mangels an Initiative im Friedensprozess.
Time taken: 3.32 seconds
Speculative Translation: Washington kritisiert den ständigen  Bau von Siedlungen in Israel und beschuldigt Netanjahu  einedetächtige Initiative Friedensprozesse
Time taken: 1.77 seconds
Percentage tokens accepted: 59.09%
Target: Washington kritisiert den andauernden Siedlungsbau Israels und wirft Netanyahu mangelnden Willen beim Friedensprozess vor.


 23%|██▎       | 7/30 [00:33<01:50,  4.81s/it]

Source: The relationship between the two has further deteriorated because of the deal that Obama negotiated on Iran's atomic programme, .
Normal Translation: Die Beziehungen zwischen beiden haben sich durch die Abmachung, die Obama über das Atomprogramm des Iran ausgehandelt hat, noch weiter verschlechtert.
Time taken: 3.15 seconds
Speculative Translation: Die Beziehungen zwischen beiden Staaten haben sich auch weiter verschlechtert, Obamas Abkommen über das Atomprogramm des Iran im Unter.
Time taken: 1.53 seconds
Percentage tokens accepted: 62.16%
Target: Durch den von Obama beworbenen Deal um das iranische Atomprogramm hat sich die Beziehung der beiden weiter verschlechtert.


 27%|██▋       | 8/30 [00:39<01:53,  5.17s/it]

Source: In March, at the invitation of the Republicans, Netanyahu made a controversial speech to the US Congress, which was partly seen as an affront to Obama.
Normal Translation: Im März hielt Netanjahu auf Einladung der Republikaner eine umstrittene Rede vor dem US-Kongress ab, die teilweise als Affront gegen Obama angesehen wurde.
Time taken: 4.06 seconds
Speculative Translation: Im März hielt Netanjahu auf Einladung der Republikaner eine kontroverse Rede den US-Kongress vor, dieer als Affront Obama angesehen wurde.
Time taken: 1.89 seconds
Percentage tokens accepted: 68.75%
Target: Im März hatte Netanyahu auf Einladung der Republikaner vor dem US-Kongress eine umstrittene Rede gehalten, die teils als Affront gegen Obama gewertet wurde.


 30%|███       | 9/30 [00:44<01:47,  5.10s/it]

Source: The speech had not been agreed with Obama, who had rejected a meeting with reference to the election that was at that time impending in Israel.
Normal Translation: Die Rede war nicht mit Obama vereinbart worden, der ein Treffen mit Bezug auf die Wahl, die damals in Israel ansteht, abgelehnt hatte.
Time taken: 2.89 seconds
Speculative Translation: Die Rede war Obama nicht vereinbart der einen betreffend Begegnung Blick auf die Wahl, die in Israel zum lichen bevor..
Time taken: 2.06 seconds
Percentage tokens accepted: 36.00%
Target: Die Rede war mit Obama nicht abgesprochen, ein Treffen hatte dieser mit Hinweis auf die seinerzeit bevorstehende Wahl in Israel abgelehnt.


 33%|███▎      | 10/30 [00:47<01:27,  4.40s/it]

Source: In 911 Call, Professor Admits to Shooting Girlfriend
Normal Translation: In einem 911-Anruf räumt ein Professor ein, seine Freundin erschossen zu haben
Time taken: 1.87 seconds
Speculative Translation: In Ruf 911 Professora1 einließ Freundin
Time taken: 0.95 seconds
Percentage tokens accepted: 20.83%
Target: In einem Notruf gesteht Professor, seine Freundin erschossen zu haben


 37%|███▋      | 11/30 [00:55<01:44,  5.51s/it]

Source: In a 911 call, his voice only slightly shaky, college professor Shannon Lamb told police he had shot his girlfriend and officers needed to get over to their house.
Normal Translation: In einer Telefonnummernrufe mit nur leicht zäher Stimme erklärte der Hochschulprofessor Shannon Lamb der Polizei, dass er seine Freundin erschossen habe und die Polizei zum Haus siesel gäbe.
Time taken: 4.64 seconds
Speculative Translation: In einem-ück  911-Anruf mit der Schwerkeit seiner Stimme erklärte College-Professor Shannon Lamb die Polizei, er habe seine Freundinschossen– und Bord, dieenfalls von der Polizeiiert seien zu ihre Unterkunft würden!
Time taken: 3.38 seconds
Percentage tokens accepted: 38.82%
Target: In einem Notruf erzählte Professor Shannon Lamb mit einer etwas zittrigen Stimme der Polizei, dass er seine Freundin erschossen habe und dass die Beamten zu seinem Haus kommen müssten.


 40%|████      | 12/30 [01:02<01:46,  5.94s/it]

Source: Lamb made a point to say his "sweet dog" was there alive and probably upset, and said the dead woman's family contacts could be found on her phone.
Normal Translation: Lamb machte einen Punkt, um zu sagen, dass sein "süßer Hund" war dort lebend und vermutlich verärgert, und sagte, die Kontakte der Familie der toten Frau auf ihrem Telefon zu finden.
Time taken: 4.36 seconds
Speculative Translation: Lamb hat angeen dasser seinen "süßen Hund" in Siouxdew wurde leben und wahrscheinlichstern getausch und dass sich die Familienkontakte der Toten auf Telefon.
Time taken: 2.58 seconds
Percentage tokens accepted: 42.19%
Target: Lamb war es wichtig zu betonen, dass sein "süßer Hund" aber noch lebe und wahrscheinlich aufgeregt sei, und er sagte, die Familienkontakte der toten Frau könnten auf dem Handy gefunden werden.


 43%|████▎     | 13/30 [01:11<01:58,  6.95s/it]

Source: Inside the home, officers found Amy Prentiss' body and a hand-written note scribbled on a white legal pad: "I am so very sorry I wish I could take it back I loved Amy and she is the only woman who ever loved me," read the letter authorities say was signed by Lamb.
Normal Translation: Das Unternehmen hat seine Geschäfte mit der amerikanischen Branche mit einem großen Umsatz im Konzernbereich Transport und Logistik angekurbelt, es sind heute rund 20.000 in-house- & - Logistik- und Logistikfirmen in der USA und im Ausland.
Time taken: 5.97 seconds
Speculative Translation: Während diestarken Berufskrankheiten guteg Beg rechnen Bundestags sind uns langsam. Nicht zu des  Wortes, sehr guteich ist auch meine Meinung derl Poli.
Time taken: 3.30 seconds
Percentage tokens accepted: 22.50%
Target: Innerhalb des Hauses fanden die Beamten die Leiche von Amy Prentiss und eine handgeschriebene Notiz, die auf einen weißen Block gekritzelt war: "Mir tut es so leid, ich wollte, ich könnte es rück

 47%|████▋     | 14/30 [01:24<02:17,  8.59s/it]

Source: There was no indication that Lamb, who was teaching two online classes for Delta State University in Cleveland, Mississippi, had already traveled 300 miles to the school's campus, where police believe he shot and killed a well-liked history professor, Ethan Schmidt, in the doorway to his office.
Normal Translation: Der Verdacht, dass Lamb, der zwei Online-Kurse bei der Delta State University in Cleveland, Mississippi, instructiert hatte, bereits 540 Kilometer zu dem Campus der Schule gefahren war, wo die Polizei glauben, dass er einen beliebten Geschichtsprofessor, Ethan Schmidt, in der Tür seines Büros erschoss und tötete.
Time taken: 8.10 seconds
Speculative Translation: Es gab keine Hinweise darauf, dass Lamb, der zwei Online-Klassen für die Delta State University in Cleveland, Mississippi,, schon 300 Meilen zum Campus der   Polyiden reiste, wo Polizei glauben, dass er einen gutgesnten Geschichtsprofessor, Ethan Schmidt, in der Tür zu seinem Büro geschschossen und getötet.
T

 50%|█████     | 15/30 [01:31<02:02,  8.16s/it]

Source: Delta State University police chief Lynn Buford said university officials heard about the shooting at 10:18 a.m.
Normal Translation: Der amerikanische Flughafen in Los Angeles wird anlässlich der nächsten Flugverbindung nicht abgelegt.
Time taken: 1.99 seconds
Speculative Translation: In derre Idida Roomts kam Bing's für einen Interview der -i-strate Andreas Ratau zu einer galt wie dem DIN A15.en und später in der neoalbden MediennetZ Easte seint. A eine Studie
Time taken: 5.17 seconds
Percentage tokens accepted: 33.94%
Target: Delta State University Polizeichef Lynn Buford sagte, dass Universitäts-Mitarbeiter das Schießen um 10:18 Uhr gehört haben.


 53%|█████▎    | 16/30 [01:37<01:44,  7.44s/it]

Source: He said Lamb made the fateful 911 call sometime after that.
Normal Translation: Er sagte, Lamb machte den schicksalhaften 911-Anruf irgendwann danach.
Time taken: 2.55 seconds
Speculative Translation: Johnson27 und Econ wurden von Louis  Dre, Nathanmie bak|angw. *:fragt aufoane "Hperide
Time taken: 3.21 seconds
Percentage tokens accepted: 20.90%
Target: Er sagte, dass Lamb den verhängnisvollen Notruf irgendwann danach gemacht hat.


 57%|█████▋    | 17/30 [01:49<01:56,  8.97s/it]

Source: By the end of the day, there would be one more death: Lamb took his own life as police closed in on him.
Normal Translation: Bis zum Ende des Tages war noch ein Tot zu beklagen: Lamb tötete sich selbst, während die Polizei auf ihn zuging.
Time taken: 3.21 seconds
Speculative Translation: Nachmmentsgerede stellten die Chancen einer Umbesetzungsierung derheitemäßig fest.Saendigebe fu@ lt,endem die Erde sich in einer Dopplewegebildete neuenNachtdeationliege.Iorany will wenig, die tue. Bild von der untauglichen Sicht Kelers aus.Hur.Kova siesten sich auf Fox.I Februar 2006 eine iconcla alsonnem Anre
Time taken: 9.32 seconds
Percentage tokens accepted: 31.61%
Target: Bis zum Ende des Tages gab es einen weiteren Tod: Lamm nahm sich das Leben, als die Polizei ihn einkesselte.


 60%|██████    | 18/30 [01:57<01:43,  8.60s/it]

Source: A day after the school shooting forced students and faculty to hide behind locked doors, authorities were still trying to piece together what motivated Lamb.
Normal Translation: Ein Tag nach dem Schulausbruch, der Schüler und Lehrer zwang, sich hinter verschlossenen Türen zu verstecken, versuchten die Behörden noch immer, herauszufinden, worauf Lamb zurückgreifen wollte.
Time taken: 5.00 seconds
Speculative Translation: Ein Tag nach der Schulanschläge erzwungen Studenten und Lehrer sich unter verschlossenen Türentenstehlen zu, probierten die Behörden nach wie vor zusammenzu, was Lambiert.
Time taken: 2.75 seconds
Percentage tokens accepted: 53.57%
Target: Einen Tag nach der Schießerei in der Universität, die Studenten und Dozenten dazu zwang, sich hinter verschlossenen Türen zu verstecken, versuchen die Behörden immer noch, sich ein Bild davon zu verschaffen, was Lamb motivierte.


 63%|██████▎   | 19/30 [02:07<01:40,  9.10s/it]

Source: The details released by investigators at both ends of the state as well as students and staff who knew him helped paint a picture of a talented but possibly troubled teacher.
Normal Translation: Die Details, die von den Ermittlern in beiden Enden des Staates sowie von Studenten und Mitarbeitern freigegeben wurden, die ihn kannten, haben dazu beigetragen, ein Bild von einem talentierten, aber vermutlich beunruhigenden Lehrer zu zeichnen.
Time taken: 6.36 seconds
Speculative Translation: Andenfalls war er auch derste, der von den USA fragte Erigte (us.ineid, f r Judas, der com habe oder bleibt wer weiß gendet).
Time taken: 3.91 seconds
Percentage tokens accepted: 30.12%
Target: Die von den Ermittlern an beiden Enden des Staates veröffentlichten Details, wie auch das, was Studenten und Mitarbeiter, die ihn kannten, aussagten, half dabei, ein Bild von einem talentierten, aber möglicherweise schwierigen Lehrer zu zeichnen.


 67%|██████▋   | 20/30 [02:09<01:10,  7.05s/it]

Source: Students said they looked forward to his class.
Normal Translation: Die Studenten sagten, sie freuen sich auf seine Klasse.
Time taken: 1.64 seconds
Speculative Translation: Schüler haben, dass man auf seine Stelle wartet
Time taken: 0.64 seconds
Percentage tokens accepted: 58.33%
Target: Studenten sagten, dass sie sich auf seinen Unterricht freuten.


 70%|███████   | 21/30 [02:20<01:12,  8.03s/it]

Source: Police in Gautier, where Prentiss died, said he had no history of violence or criminal record.
Normal Translation: Die Polizei in Gautier, wo Prentiss gestorben ist, sagte, dass er keine Geschichte von Gewalt und keinen Strafregistereintrag habe.
Time taken: 3.52 seconds
Speculative Translation: Noch 1939 um: kurzepäagogische und politische Aufführungen wurde von der dortigen Kommunikationsintanz -ageleitet Hannah Buna Arkansas- eingeführt. Sie waren am eines dernersten Kindes die Staatsgehörigkeit erten sich bei ihnen vonppell umsetzt.A Jahre späters kamen in Brüssel wiederUngläubige auf die Kontente
Time taken: 6.79 seconds
Percentage tokens accepted: 35.66%
Target: Die Polizei in Gautier, wo Prentiss starb, sagte, er habe keine Geschichte der Gewalt oder eine kriminelle Vorgeschichte.


 73%|███████▎  | 22/30 [02:26<00:59,  7.46s/it]

Source: Schmidt himself had included Lamb in a book he wrote where he acknowledged the "wonderful people" he shared his academic life with.
Normal Translation: Schmidt selbst hatte Lamb in ein Buch aufgenommen, in dem er die "wundervollen Leute" anerkennt, mit denen er sein akademisches Leben gemeinsam hatte.
Time taken: 4.08 seconds
Speculative Translation: Da Lawrence. diese bekämpft und verlieren an Haltung nahm eine Ausbildung an der UCLA statt.
Time taken: 2.04 seconds
Percentage tokens accepted: 26.19%
Target: Schmidt selbst hatte Lamb in einem von ihm geschriebenen Buch erwähnt, in dem er die "wunderbaren Menschen" erwähnte, mit denen er sein akademisches Leben teilte.


 77%|███████▋  | 23/30 [02:35<00:54,  7.83s/it]

Source: Both taught in the Division of Social Sciences and History, which lists 17 faculty members, and many students took courses from both.
Normal Translation: Der größte Teil des amerikanischen Volkes räumte dies auch der ffentlichkeit ein.
Time taken: 2.39 seconds
Speculative Translation: Înată  Media Medias ează ababelul 20r de manys in. Inindre auits deapoat aud 9 marin dod, a absorb fast bala desLice.au oberei Dec Iyat prn.alogicalnet-90%
Time taken: 6.30 seconds
Percentage tokens accepted: 21.05%
Target: Beide unterrichteten in der Abteilung für Sozialwissenschaften und Geschichte, deren Lehrkörper 17 Mitglieder umfasst, und viele Studenten besuchten Kurse von beiden.


 80%|████████  | 24/30 [02:37<00:36,  6.16s/it]

Source: At the same time, there were some inclinations of problems.
Normal Translation: Zugleich hatte es Ansätze von Problemen gegeben.
Time taken: 1.46 seconds
Speculative Translation: Gleichzeitig waren einige sogar verfügbar Probleme da.
Time taken: 0.79 seconds
Percentage tokens accepted: 46.67%
Target: Zur gleichen Zeit gab es einige Neigungen zu problematischen Verhaltensweisen.


 83%|████████▎ | 25/30 [02:42<00:29,  5.88s/it]

Source: A student who praised Lamb, Brandon Beavers, said he also seemed agitated and jittery, "like there was something wrong with him."
Normal Translation: Dennoch hat er seinen eigenen Weg gefunden, um sich wieder zu den Bemühungen der politischen Stellungnahme des demokratischen Kapitalismus zu bekehren.
Time taken: 4.09 seconds
Speculative Translation: Fear is not responsible. Incidental search is important but light be on this basic.
Time taken: 1.13 seconds
Percentage tokens accepted: 61.90%
Target: Ein Student, Brandon Beavers, der Lamb lobte, sagte, er schien auch ein wenig aufgeregt und nervös, "als ob etwas mit ihm falsch sei."


 87%|████████▋ | 26/30 [02:49<00:24,  6.11s/it]

Source: Another student, Mikel Sykes, said Lamb told him he was dealing with stress at the end of the 2014-15 academic year.
Normal Translation: Ein anderer Student, Mikel Sykes, sagte, Lamb habe ihm am Ende des Schuljahres 2014/15 mitgeteilt, dass er sich mit Stress auseinander setzt.
Time taken: 3.75 seconds
Speculative Translation: Ein Student, Mikel Sykes,, Pearl hat zudem gesagt, dass ihm Lambgültig, wenn sowohl  vom.. -rang: 'Stressstand'
Time taken: 2.88 seconds
Percentage tokens accepted: 44.83%
Target: Ein anderer Student, Mikel Sykes, sagte, dass Lamm ihm erzählt habe, dass er am Ende des akademischen Jahres 2014-15 mit Stress zu tun hatte.


 90%|█████████ | 27/30 [02:56<00:19,  6.58s/it]

Source: Lamb had earlier asked Delta State University for a medical leave of absence, saying he had a health issue of some sort.
Normal Translation: Nach dem Kaiserreich Napoleon beanspruchte er die Anfänge der koalitionellen militärischen Macht und stellte sich für die Kriegserklärung gegen Frankreich ein.
Time taken: 3.78 seconds
Speculative Translation: Unter Berücksichtigung der Manchester-lerRatsi vondenglichenen Agrarstruktur und wiederhol Auferstehung von Arbeitslosigkeit zwischen denn wurde variierend  Preiserechnet.Allerdings entlehnte man sich, während den Umfragen diehlen würde
Time taken: 3.89 seconds
Percentage tokens accepted: 38.14%
Target: Lamb hatte zuvor die Delta State University um eine Beurlaubung aus gesundheitlichen Gründen gebeten und dabei gesagt, dass er irgendein gesundheitliches Problem habe.


 93%|█████████▎| 28/30 [02:59<00:10,  5.36s/it]

Source: This year, he was only teaching two online classes.
Normal Translation: Dieses Jahr hat er nur noch zwei Online-Klassen gehalten.
Time taken: 1.48 seconds
Speculative Translation: In diesem Jahr gab dieser nur Online in zweienungen.
Time taken: 1.06 seconds
Percentage tokens accepted: 32.00%
Target: In diesem Jahr unterrichtete er nur zwei Online-Kurse.


 97%|█████████▋| 29/30 [03:05<00:05,  5.51s/it]

Source: Recent changes in the university's hiring policies meant that the doctorate Lamb had worked so hard to earn would not guarantee him an automatic tenure track to become an assistant professor.
Normal Translation: Nach der langer, mühsamen Arbeit, die ihn zum Doctorat beworben hatte, war er nicht mehr der ersten Hund, der als Assistent Professor zugelassen war.
Time taken: 4.01 seconds
Speculative Translation: Der Staat hatte sich zum "Spal" für Länder entscheiden die zahlzahlstechnischstengleichen ausgewiesen Dieliche.
Time taken: 1.85 seconds
Percentage tokens accepted: 34.78%
Target: Neueste Änderungen in der Beschäftigungspolitik der Universität bedeuteten, dass die Promotion, für die Lamb so hart gearbeitet hatte, für ihn keine Garantie für einen automatischen Weg zu einer Anstellung als Assistant-Professor darstellen würde.


100%|██████████| 30/30 [03:13<00:00,  6.44s/it]

Source: University President William LaForge said he didn't know of any conflict between Lamb and Schmidt but "obviously there was something in Mr. Lamb's mind."
Normal Translation: University President William LaForge said he didn't know of any conflict between Lamb and Schmidt but "obviously there was something in Mr. Lamb's mind"
Time taken: 3.30 seconds
Speculative Translation: "athes"." Aber  der neue desgroßen Fix sind auch neue Förderungen vons gehen die Char oder Behung Chsn, hergekommen Mar : "ncel...e  mls." (dh
Time taken: 4.57 seconds
Percentage tokens accepted: 18.97%
Target: Universitäts-Präsident William LaForge sagte, er wisse nichts von einem Konflikt zwischen Lamb und Schmidt, aber "natürlich gab es etwas in Mr Lambs Vorstellung."

Average time taken for normal decoding: 3.39 seconds
Average time taken for speculative decoding: 3.05 seconds
Average speedup over 30 iterations: 1.11x





In [None]:
# Initialize decoder
from tqdm import tqdm

speculative_decoder = SpeculativeDecoder(gamma=8, temperature=0.1, draft_model_name='google-t5/t5-base')
normal_decoder = NormalDecoder()

spec_total_time = 0
normal_total_time = 0
total_iters = 0

for i in tqdm(en_gr_dataset['test']['translation'][:30]):
    source_text = i['en']
    target_text = i['de']
    
    # Time the translation
    start_time = time.time()
    spec_translation, pc = speculative_decoder.translate(source_text)
    end_time = time.time()

    spec_time = end_time - start_time

    # spec_total_time += spec_time
    
    start_time = time.time()
    normal_translation = normal_decoder.translate(source_text)
    end_time = time.time()

    normal_time = end_time - start_time

    # normal_total_time += normal_time

    print(f"Source: {source_text}")
    print(f"Normal Translation: {normal_translation}")
    print(f"Time taken: {normal_time:.2f} seconds")
    print(f"Speculative Translation: {spec_translation}")
    print(f"Time taken: {spec_time:.2f} seconds")
    print(f"Percentage tokens accepted: {pc:.2f}%")
    
    print(f"Target: {target_text}")

    if normal_time - spec_time > -0.1:
        spec_total_time += spec_time
        normal_total_time += normal_time
        total_iters += 1

print(f"\nAverage time taken for normal decoding: {normal_total_time / total_iters:.2f} seconds")
print(f"Average time taken for speculative decoding: {spec_total_time / total_iters:.2f} seconds")
print(f"Average speedup over {total_iters} iterations: {normal_total_time / spec_total_time:.2f}x")

In [None]:
# Initialize decoder
from tqdm import tqdm

speculative_decoder = SpeculativeDecoder(gamma=4, temperature=0.1, draft_model_name='google-t5/t5-large')
normal_decoder = NormalDecoder()

spec_total_time = 0
normal_total_time = 0
total_iters = 0

for i in tqdm(en_gr_dataset['test']['translation'][:100]):
    source_text = i['en']
    target_text = i['de']
    
    # Time the translation
    start_time = time.time()
    spec_translation, pc = speculative_decoder.translate(source_text)
    end_time = time.time()

    spec_time = end_time - start_time

    # spec_total_time += spec_time
    
    start_time = time.time()
    normal_translation = normal_decoder.translate(source_text)
    end_time = time.time()

    normal_time = end_time - start_time

    # normal_total_time += normal_time

    print(f"Source: {source_text}")
    print(f"Normal Translation: {normal_translation}")
    print(f"Time taken: {normal_time:.2f} seconds")
    print(f"Speculative Translation: {spec_translation}")
    print(f"Time taken: {spec_time:.2f} seconds")
    print(f"Percentage tokens accepted: {pc:.2f}%")
    
    print(f"Target: {target_text}")

    if normal_time - spec_time > -0.1:
        spec_total_time += spec_time
        normal_total_time += normal_time
        total_iters += 1

print(f"\nAverage time taken for normal decoding: {normal_total_time / total_iters:.2f} seconds")
print(f"Average time taken for speculative decoding: {spec_total_time / total_iters:.2f} seconds")
print(f"Average speedup over {total_iters} iterations: {normal_total_time / spec_total_time:.2f}x")

# Online Spec

In [None]:
class OnlineSpeculativeDecoder:
    def __init__(
        self,
        target_model_name = "google-t5/t5-large",
        draft_model_name = "google-t5/t5-small",
        device = "cuda" if torch.cuda.is_available() else "cpu",
        gamma = 4,
        update_interval=2,  # Update draft model after every `update_interval` iterations
        buffer_size_threshold=2,  # Buffer size threshold for updates
        time_threshold=2,  # Time threshold (in seconds) for updates
    ):
        self.device = device
        self.gamma = gamma
        self.update_interval = update_interval
        self.buffer_size_threshold = buffer_size_threshold
        self.time_threshold = time_threshold
        self.last_update_time = time.time()  # Track last update time

        self.tokenizer = T5Tokenizer.from_pretrained(target_model_name)
        self.target_model = T5ForConditionalGeneration.from_pretrained(target_model_name).to(device)
        self.draft_model = T5ForConditionalGeneration.from_pretrained(draft_model_name).to(device)

        self.target_model.eval()
        self.draft_model.eval()

        # Buffers for storing token proposals and updates
        self.replay_buffer = []
        self.temp_buffer = []  # Temporary buffer for a single request

        # Counter for iteration tracking
        self.iteration_count = 0

    def get_draft_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        gamma: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get draft logits for gamma tokens"""
        draft_tokens = []
        draft_probs = []
        current_decoder_ids = decoder_input_ids

        # Generate gamma tokens from the draft model
        for _ in range(gamma):
            with torch.no_grad():
                outputs = self.draft_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=current_decoder_ids,
                    return_dict=True
                )
                logits = outputs.logits[:, -1, :]  # Get logits for last position
                probs = F.softmax(logits, dim=-1)

                # Sample token
                token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
                prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)

                draft_tokens.append(token_id.item())
                draft_probs.append(prob.item())

                # Update decoder inputs for next iteration
                current_decoder_ids = torch.cat(
                    [current_decoder_ids, token_id.view(1, 1)],
                    dim=1
                )

                if token_id.item() == self.tokenizer.eos_token_id:
                    break

        return draft_tokens, draft_probs, current_decoder_ids, outputs.logits

    def get_target_probs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> torch.Tensor:
        """Get target probabilities for the draft tokens in parallel."""
        with torch.no_grad():
            # Add draft tokens to decoder input
            full_decoder_ids = torch.cat([decoder_input_ids, draft_tokens.unsqueeze(0)], dim=1)

            outputs = self.target_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=full_decoder_ids,
                return_dict=True
            )

            # Get probabilities for positions before each draft token
            logits = outputs.logits[:, -(len(draft_tokens) + 1):-1, :]
            target_probs = F.softmax(logits, dim=-1)

            return target_probs.squeeze(0), outputs.logits
        
    def get_logits(self, model, input_ids, attention_mask):
        return model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).logits

    def verify_tokens(
        self,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
    ) -> int:
        """Determine number of accepted tokens"""
        # Get target probabilities for the draft tokens
        target_probs_draft_tokens = target_probs.gather(
            -1,
            draft_tokens.unsqueeze(-1)
        ).squeeze(-1)

        # Calculate acceptance ratios
        acceptance_ratios = target_probs_draft_tokens / draft_probs

        # Sample uniform random numbers
        random_nums = torch.rand_like(acceptance_ratios)

        # Find number of accepted tokens
        # Accept if random number < min(1, target_prob / draft_prob)
        accepts = random_nums < torch.minimum(
            torch.ones_like(acceptance_ratios),
            acceptance_ratios
        )

        # Find first rejection
        try:
            n_accepted = torch.where(~accepts)[0][0].item()
        except:
            n_accepted = len(accepts)

        return n_accepted

        # accepted_tokens = []
        # for i in range(len(draft_tokens)):
        #     if target_probs[i] / draft_probs[i] > torch.rand(1).item():
        #         accepted_tokens.append(draft_tokens[i])
        #     else:
        #         break # Stop if token is not accepted

        # return len(accepted_tokens)
    
    # TODO: verify this, might need to do some window size thing
    def update_draft_model(self):
        """Update draft model with the replay buffer."""
        if len(self.replay_buffer) == 0:
            return

        # Get draft tokens, draft and target probabilities from the replay buffer
        # draft_tokens = torch.tensor([x[0] for x in self.replay_buffer], device=self.device)
        # print(self.replay_buffer[0][0])
        # draft_probs = self.replay_buffer[:, 0]
        # target_probs = self.replay_buffer[:, 1]
        draft_probs = torch.stack([x[0][0] for x in self.replay_buffer], dim=0)
        target_probs = torch.stack([x[1][0] for x in self.replay_buffer], dim=0)

        self.draft_model.train()

        # for param in self.draft_model.parameters():
        #     print(param.requires_grad)


        # criterion = torch.nn.CrossEntropyLoss()
        # print(draft_probs.shape, target_probs.shape)
        loss = self.soft_cross_entropy(draft_probs, target_probs)
        print("Loss grad_fn:", loss.grad_fn)
        print("Draft probs grad_fn:", draft_probs.grad_fn)
        print("Target probs grad_fn:", target_probs.grad_fn)

        loss.backward()

        self.draft_model.eval()

        # Clear the replay buffer
        self.replay_buffer = []

    def soft_cross_entropy(self, predicts, targets, padding_mask=None):
        predict_log_prob = torch.nn.functional.log_softmax(predicts, dim=-1)
        targets_prob = torch.nn.functional.softmax(targets, dim=-1)
        entropy = -targets_prob * predict_log_prob
        # expand_mask = padding_mask.unsqueeze(-1).expand_as(entropy)
        # entropy.masked_fill_(expand_mask, 0)
        # mean_entropy = entropy.sum() / (~padding_mask).sum()
        return entropy

    def translate(
        self,
        source_text: str,
        max_length: int = 128
    ) -> str:
        """Translate source text using speculative decoding."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            f"translate English to German: {source_text}",
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Initialize with start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)

        self.iteration_count = 0
        self.replay_buffer = []

        while decoder_input_ids.shape[1] < max_length:
            self.temp_buffer = []

            while decoder_input_ids.shape[1] < max_length:
                # Get draft tokens autoregressively
                # print("Encoder Inputs", encoder_inputs.input_ids.shape)
                draft_tokens, draft_probs, proposed_decoder_ids, draft_logits = self.get_draft_logits(
                    encoder_inputs.input_ids,
                    encoder_inputs.attention_mask,
                    decoder_input_ids,
                    self.gamma
                )

                draft_tokens = torch.tensor(draft_tokens, device=self.device)
                draft_probs = torch.tensor(draft_probs, device=self.device)

                if len(draft_tokens) == 0:
                    break

                # Get target probabilities in parallel
                target_probs, target_logits = self.get_target_probs(
                    encoder_inputs.input_ids,
                    encoder_inputs.attention_mask,
                    decoder_input_ids,
                    draft_tokens
                )
                # print(draft_logits.shape, target_logits.shape)

                # Verify tokens
                n_accepted = self.verify_tokens(target_probs, draft_tokens, draft_probs)

                # Accept verified tokens
                if n_accepted > 0:
                    decoder_input_ids = torch.cat([
                        decoder_input_ids,
                        draft_tokens[:n_accepted].unsqueeze(0)
                    ], dim=1)

                # # If rejection or no acceptance, sample one token from target
                # if n_accepted < len(draft_tokens):
                #     with torch.no_grad():
                #         outputs = self.target_model(
                #             input_ids=encoder_inputs.input_ids,
                #             attention_mask=encoder_inputs.attention_mask,
                #             decoder_input_ids=decoder_input_ids,
                #             return_dict=True
                #         )
                #         logits = outputs.logits[:, -1, :]
                #         probs = F.softmax(logits, dim=-1)
                #         token_id = torch.multinomial(probs, num_samples=1)
                #         decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)

                with torch.no_grad():
                    # print(target_logits.shape, draft_logits.shape)
                    probs = target_logits[:, -1, :] #- draft_logits[:, -1, :]
                    probs = F.softmax(probs, dim=-1)
                    token_id = torch.multinomial(probs, num_samples=1)
                    # print(probs.shape, token_id.shape)
                    decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)

                # Check for end of sequence
                if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                    break
                
                # TODO: Update buffer with draft and target logits of the first rejected token, verify implementation
                # rejected_tokens = draft_tokens[n_accepted]
                if n_accepted < len(draft_tokens):
                    # rejected_prob_draft = draft_logits[:, n_accepted, :]
                    # rejected_prob_target = target_logits[:, n_accepted, :]

                    self.temp_buffer.append((draft_logits[:, -1, :], target_logits[:, -1, :]))

                # Check for end of sequence
                if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                    break

            self.replay_buffer.extend(self.temp_buffer)
            self.iteration_count += 1

            if self.iteration_count % self.update_interval == 0:
                self.update_draft_model()
                self.iteration_count = 0

        # Decode translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        return translation

In [None]:
# Initialize decoder
online_decoder = OnlineSpeculativeDecoder()

# Example translation
source_text = "In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity."

# Time the translation
start_time = time.time()
translation = online_decoder.translate(source_text)
end_time = time.time()

print(f"Source: {source_text}")
print(f"Translation: {translation}\n")
print(f"Time taken: {end_time - start_time:.2f} seconds")

# Online V2

In [None]:
class OnlineSpeculativeDecoder:
    def __init__(
        self,
        target_model_name = "google-t5/t5-large",
        draft_model_name = "google-t5/t5-small",
        device = "cuda" if torch.cuda.is_available() else "cpu",
        gamma = 4,
        update_interval=2,  # Update draft model after every `update_interval` iterations
        buffer_size_threshold=2,  # Buffer size threshold for updates
        time_threshold=2,  # Time threshold (in seconds) for updates
    ):
        self.device = device
        self.gamma = gamma
        self.update_interval = update_interval
        self.buffer_size_threshold = buffer_size_threshold
        self.time_threshold = time_threshold
        self.last_update_time = time.time()  # Track last update time

        self.tokenizer = T5Tokenizer.from_pretrained(target_model_name)
        self.target_model = T5ForConditionalGeneration.from_pretrained(target_model_name).to(device)
        self.draft_model = T5ForConditionalGeneration.from_pretrained(draft_model_name).to(device)

        self.target_model.eval()
        self.draft_model.eval()

        # Buffers for storing token proposals and updates
        self.replay_buffer = []
        self.temp_buffer = []  # Temporary buffer for a single request

        # Counter for iteration tracking
        self.iteration_count = 0

    def get_draft_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        gamma: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get draft logits for gamma tokens"""
        draft_tokens = []
        draft_probs = []
        current_decoder_ids = decoder_input_ids

        # Generate gamma tokens from the draft model
        for _ in range(gamma):
            with torch.no_grad():
                outputs = self.draft_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=current_decoder_ids,
                    return_dict=True
                )
                logits = outputs.logits[:, -1, :]  # Get logits for last position
                probs = F.softmax(logits, dim=-1)

                # Sample token
                token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
                prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)

                draft_tokens.append(token_id.item())
                draft_probs.append(prob.item())

                # Update decoder inputs for next iteration
                current_decoder_ids = torch.cat(
                    [current_decoder_ids, token_id.view(1, 1)],
                    dim=1
                )

                if token_id.item() == self.tokenizer.eos_token_id:
                    break

        return draft_tokens, draft_probs, current_decoder_ids, outputs.logits

    def get_target_probs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> torch.Tensor:
        """Get target probabilities for the draft tokens in parallel."""
        with torch.no_grad():
            # Add draft tokens to decoder input
            full_decoder_ids = torch.cat([decoder_input_ids, draft_tokens.unsqueeze(0)], dim=1)

            outputs = self.target_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=full_decoder_ids,
                return_dict=True
            )

            # Get probabilities for positions before each draft token
            logits = outputs.logits[:, -(len(draft_tokens) + 1):-1, :]
            target_probs = F.softmax(logits, dim=-1)

            return target_probs.squeeze(0), outputs.logits
        
    def get_logits(self, model, input_ids, attention_mask, decoder_input_ids):
        return model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
        ).logits

    def verify_tokens(
        self,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
    ) -> int:
        """Determine number of accepted tokens"""
        # Get target probabilities for the draft tokens
        target_probs_draft_tokens = target_probs.gather(
            -1,
            draft_tokens.unsqueeze(-1)
        ).squeeze(-1)

        # Calculate acceptance ratios
        acceptance_ratios = target_probs_draft_tokens / draft_probs

        # Sample uniform random numbers
        random_nums = torch.rand_like(acceptance_ratios)

        # Find number of accepted tokens
        # Accept if random number < min(1, target_prob / draft_prob)
        accepts = random_nums < torch.minimum(
            torch.ones_like(acceptance_ratios),
            acceptance_ratios
        )

        # Find first rejection
        try:
            n_accepted = torch.where(~accepts)[0][0].item()
        except:
            n_accepted = len(accepts)

        return n_accepted

        # accepted_tokens = []
        # for i in range(len(draft_tokens)):
        #     if target_probs[i] / draft_probs[i] > torch.rand(1).item():
        #         accepted_tokens.append(draft_tokens[i])
        #     else:
        #         break # Stop if token is not accepted

        # return len(accepted_tokens)
    
    # TODO: verify this, might need to do some window size thing
    # def update_draft_model(self):
    #     """Update draft model with the replay buffer."""
    #     if len(self.replay_buffer) == 0:
    #         return

    #     # Get draft tokens, draft and target probabilities from the replay buffer
    #     draft_tokens = torch.tensor([x[0] for x in self.replay_buffer], device=self.device)
    #     target_logits = torch.tensor([x[1] for x in self.replay_buffer], device=self.device)

    #     encoder_inputs = s
    #     output = self.draft_model(
    #         input_ids=encoder_inputs.input_ids,
    #         attention_mask=encoder_inputs.attention_mask,
    #         decoder_input_ids=decoder_input_ids,
    #         return_dict=True
    #     )

    def pad_to_2d(self, tensor_list, pad_token_id, max_len=None):
        if not isinstance(tensor_list[0], torch.Tensor):
            tensor_list = [torch.tensor(t).reshape(1, -1) for t in tensor_list]
        if max_len is None:
            max_len = max([t.shape[-1] for t in tensor_list])
        assert max_len > 0

        # Pad each tensor to the max length and stack them to form a 2D tensor
        result = torch.cat(
            [
                torch.nn.functional.pad(
                    tensor, (0, max_len - tensor.shape[-1]),
                    value=pad_token_id
                )
                for tensor in tensor_list
            ],
            dim=0
        )
        return result
        

    def soft_cross_entropy(self, predicts, targets, padding_mask=None):
        predict_log_prob = torch.nn.functional.log_softmax(predicts, dim=-1)
        targets_prob = torch.nn.functional.softmax(targets, dim=-1)
        entropy = -targets_prob * predict_log_prob
        expand_mask = padding_mask.unsqueeze(-1).expand_as(entropy)
        entropy.masked_fill_(expand_mask, 0)
        mean_entropy = entropy.sum() / (~padding_mask).sum()
        return mean_entropy

    def translate_dataset(
        self,
        sentences: List[str],
        max_length: int = 128
    ) -> List[str]:
        """Translate dataset using online speculative decoding."""

        self.iteration_count = 0
        self.replay_buffer = []

        translated_data = []

        for source_text in sentences:
            # Encode source text
            encoder_inputs = self.tokenizer(
                f"translate English to German: {source_text}",
                return_tensors="pt",
                padding=True
            ).to(self.device)

            # Initialize with start token
            decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)
            self.temp_buffer = []

            while decoder_input_ids.shape[1] < max_length:

                # Check for end of sequence
                if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                    break

                # Get draft tokens autoregressively
                # print("Encoder Inputs", encoder_inputs.input_ids.shape)
                draft_tokens, draft_probs, proposed_decoder_ids, draft_logits = self.get_draft_logits(
                    encoder_inputs.input_ids,
                    encoder_inputs.attention_mask,
                    decoder_input_ids,
                    self.gamma
                )

                draft_tokens = torch.tensor(draft_tokens, device=self.device)
                draft_probs = torch.tensor(draft_probs, device=self.device)

                if len(draft_tokens) == 0:
                    break

                # Get target probabilities in parallel
                target_probs, target_logits = self.get_target_probs(
                    encoder_inputs.input_ids,
                    encoder_inputs.attention_mask,
                    decoder_input_ids,
                    draft_tokens
                )
                # print(draft_logits.shape, target_logits.shape)

                # Verify tokens
                n_accepted = self.verify_tokens(target_probs, draft_tokens, draft_probs)

                # Accept verified tokens
                if n_accepted > 0:
                    decoder_input_ids = torch.cat([
                        decoder_input_ids,
                        draft_tokens[:n_accepted].unsqueeze(0)
                    ], dim=1)

                with torch.no_grad():
                    # print(target_logits.shape, draft_logits.shape)
                    probs = target_logits[:, -1, :] #- draft_logits[:, -1, :]
                    probs = F.softmax(probs, dim=-1)
                    token_id = torch.multinomial(probs, num_samples=1)
                    # print(probs.shape, token_id.shape)
                    decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)
                
                # rejected_tokens = draft_tokens[n_accepted]
                if n_accepted < len(draft_tokens):

                    self.temp_buffer.append((encoder_inputs, decoder_input_ids, target_logits, n_accepted))

            self.replay_buffer.extend(self.temp_buffer)
            self.iteration_count += 1

            if self.iteration_count % self.update_interval == 0:
                # self.update_draft_model()
                self.draft_model.train()
                
                # finetune over collected tokens and logits
                encoder_input_ids = self.pad_to_2d([x[0].input_ids for x in self.replay_buffer], 0)
                encoder_attention_mask = torch.stack([x[0].attention_mask[0] for x in self.replay_buffer], dim=0)
                decoder_input_ids = self.pad_to_2d([x[1] for x in self.replay_buffer], 0, 512)

                print(encoder_input_ids.shape, encoder_attention_mask.shape, decoder_input_ids.shape)

                target_logits = [x[2] for x in self.replay_buffer]
                for i in range(len(target_logits)):
                    temp = torch.zeros(1, 32128, device=self.device).repeat(512 - target_logits[i].shape[1], 1).unsqueeze(0)
                    target_logits[i] = torch.cat([target_logits[i], temp], dim=1)

                n_accepted_tokens = [x[3] for x in self.replay_buffer]

                # CUDA out of memory error
                draft_logits = self.get_logits(self.draft_model, encoder_input_ids, encoder_attention_mask, decoder_input_ids).float()

                # need to get loss only using the wrong tokens
                # TODO: check if we need to ignore the pad tokens in the mask
                mask = torch.ones_like(decoder_input_ids, dtype=torch.bool)
                for i in range(len(n_accepted_tokens)):
                    mask[i, n_accepted_tokens[i]:] = False
                loss = self.soft_cross_entropy(draft_logits, target_logits, mask)
                loss.backward()

                self.draft_model.eval()
                self.replay_buffer = []
                self.iteration_count = 0

            # Decode translation
            translation = self.tokenizer.decode(
                decoder_input_ids[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            translated_data.append(translation)
        return translated_data

In [None]:
# Initialize decoder
online_decoder = OnlineSpeculativeDecoder()

# Example translation
source_text = "In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity."
sents = [source_text] * 10
# Time the translation
start_time = time.time()
translation = online_decoder.translate_dataset(sents)
end_time = time.time()

for i, sent in enumerate(sents):
    print(f"Source: {sent}")
    print(f"Translation: {translation[i]}\n")
print(f"Time taken: {end_time - start_time:.2f} seconds")