In [1]:
import os
if not os.path.exists("/scratch/sarthak"):
    os.makedirs("/scratch/sarthak")

os.environ["HF_HOME"] = "/scratch/sarthak/"

In [2]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Tuple, Optional
import time
import numpy as np
import datasets
from tqdm import tqdm

In [3]:
en_gr_dataset = datasets.load_dataset('wmt16', 'de-en', split='validation')

# Speculative Model

In [4]:
class SpeculativeDecoder:
    def __init__(
        self,
        target_model_name = "meta-llama/Llama-3.2-3B",
        draft_model_name = "meta-llama/Llama-3.2-1B",
        gamma = 4,
        temperature = 1.0
    ):
        self.gamma = gamma

        self.temperature = temperature

        self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.target_model = AutoModelForCausalLM.from_pretrained(target_model_name, device_map='auto')

        self.draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, device_map='auto')

        self.target_model.eval()
        self.draft_model.eval()

    def get_draft_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        gamma: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get draft logits for gamma tokens"""
        draft_tokens = []
        draft_probs = []
        current_decoder_ids = decoder_input_ids

        # Generate gamma tokens from the draft model
        for _ in range(gamma):
            with torch.no_grad():
                outputs = self.draft_model(
                    input_ids=current_decoder_ids,
                    attention_mask=attention_mask,
                    # decoder_input_ids=current_decoder_ids,
                    return_dict=True,
                )
                logits = outputs.logits[:, -1, :]  # Get logits for last position
                probs = logits
                # probs = F.softmax(logits, dim=-1)

                token_id = torch.argmax(probs, dim=-1)
                prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)

                draft_tokens.append(token_id.item())
                draft_probs.append(prob.item())

                # Update decoder inputs for next iteration
                current_decoder_ids = torch.cat(
                    [current_decoder_ids, token_id.view(1, 1)],
                    dim=1
                )

                if token_id.item() == self.tokenizer.eos_token_id:
                    break

        return draft_tokens, draft_probs, outputs.logits.squeeze(0)

    def get_target_probs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        draft_tokens: torch.Tensor,
    ) -> torch.Tensor:
        """Get target probabilities for the draft tokens in parallel."""
        with torch.no_grad():
            # Add draft tokens to decoder input
            full_decoder_ids = torch.cat([decoder_input_ids, draft_tokens.unsqueeze(0)], dim=1)
            # print(input_ids_batched.shape, attention_mask_batched.shape, padded_decoder_ids.shape)
            # decoder_mask = torch.triu(
            #     torch.ones((full_decoder_ids.shape[1], full_decoder_ids.shape[1] + 1))
            # )
            # decoder_mask = decoder_mask[-(len(draft_tokens) + 1):, :-1]
            # decoder_mask = 1 - decoder_mask
            
            # the shapes that we want to see are:
            # torch.Size([11, 12]) torch.Size([1, 12])
            # torch.Size([11, 12]) torch.Size([11, 12])
            # torch.Size([11, 32128])
            # torch.Size([11, 32128]) torch.Size([11, 32128])

            # What im getting
            # torch.Size([1, 12]) torch.Size([1, 12])
            # torch.Size([1, 12]) torch.Size([1, 12])
            # torch.Size([1, 11, 32128])


            decoder_mask = torch.ones(full_decoder_ids.shape[1])
            decoder_mask = decoder_mask.unsqueeze(0)


            # print(decoder_mask.shape, full_decoder_ids.shape)

            # conver to a batched input
            # input_ids = input_ids.repeat(len(draft_tokens) + 1, 1)
            # attention_mask = attention_mask.repeat(len(draft_tokens) + 1, 1)
            # full_decoder_ids = full_decoder_ids.repeat(len(draft_tokens) + 1, 1)

            # print(decoder_mask.shape, full_decoder_ids.shape)


            # outputs = self.target_model(
            #     input_ids=input_ids,
            #     attention_mask=attention_mask,
            #     decoder_input_ids=full_decoder_ids,
            #     decoder_attention_mask=decoder_mask,
            #     return_dict=True
            # )
            outputs = self.target_model(
                input_ids=full_decoder_ids,
                attention_mask=decoder_mask,
                return_dict=True
            )

            
            # dim_0_indices = torch.arange(len(draft_tokens) + 1)
            # dim_1_indices = torch.arange(len(draft_tokens) + 1) + full_decoder_ids.shape[1] - 1 - len(draft_tokens)
            # logits = outputs.logits[dim_0_indices, dim_1_indices, :]
            logits = outputs.logits[:, -(len(draft_tokens) + 1):, :]
            logits = logits.squeeze(0)

            # print(logits.shape)
            
            # Get probabilities for positions before each draft token
            # logits = outputs.logits[:, -(len(draft_tokens) + 1):-1, :]
            target_probs = F.softmax(logits, dim=-1)

            # print(target_probs.shape, target_probs.squeeze(0).shape)
            

            return target_probs.squeeze(0), logits

    def verify_tokens(
        self,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
    ) -> int:
        """Determine number of accepted tokens"""
        # Get target probabilities for the draft tokens
        # get the probabilities of the tokens at the indices of the draft tokens
        target_probs_draft_tokens = torch.gather(target_probs, 1, draft_tokens.unsqueeze(0))
        # Calculate acceptance ratios
        acceptance_ratios = target_probs_draft_tokens / draft_probs.clamp(min=1e-10)

        # Sample uniform random numbers 
        random_nums = torch.rand_like(acceptance_ratios)
        acceptance_mask = random_nums <= acceptance_ratios

        num_accepted = (acceptance_mask.cumsum(dim=-1) == torch.arange(1, len(acceptance_ratios) + 1)).sum().item()

        return num_accepted

    def translate(
        self,
        source_text: str,
        max_length: int = 128
    ) -> str:
        """Generate from source text using speculative decoding."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            source_text,
            return_tensors="pt",
            padding=True
        )

        # Initialize with start token
        decoder_input_ids = torch.tensor([[id for id in encoder_inputs.input_ids[0][:5]]])

        # output = self.target_model(
        #     input_ids=encoder_inputs.input_ids,
        #     attention_mask=encoder_inputs.attention_mask,
        #     decoder_input_ids=decoder_input_ids,
        #     return_dict=True
        # )

        # probs = output.logits[:, -1, :]
                    
        # probs = F.softmax(probs / (self.temperature + 1e-13), dim=-1)
        # token_id = torch.multinomial(probs, num_samples=1)

        # decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id, token_id.item()]])

        total_tokens = 0
        accepted_tokens = 0

        while decoder_input_ids.shape[1] < max_length:
            # Get draft tokens autoregressively
            draft_tokens, draft_probs, draft_logits = self.get_draft_logits(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids,
                self.gamma
            )

            draft_tokens = torch.tensor(draft_tokens)
            draft_probs = torch.tensor(draft_probs)

            # softmax the draft probs
            draft_probs = F.softmax(draft_probs / (self.temperature + 1e-13), dim=-1)

            if len(draft_tokens) == 0:
                raise ValueError("Draft tokens not generated.")

            # Get target probabilities in parallel
            # start = time.time()
            target_probs, target_logits = self.get_target_probs(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids,
                draft_tokens
            )
            # target probs are the logits but with softmax applied

            # Verify tokens
            n_accepted = self.verify_tokens(target_probs, draft_tokens, draft_probs)
            # Accept verified tokens
            if n_accepted > 0:
                decoder_input_ids = torch.cat([
                    decoder_input_ids,
                    draft_tokens[:n_accepted].unsqueeze(0)
                ], dim=1)
                
            with torch.no_grad():
                # n_rejected = self.gamma - n_accepted
                n_rejected = len(draft_tokens) - n_accepted 
                total_tokens += len(draft_tokens)
                accepted_tokens += n_accepted

                if n_rejected > 0:
                    probs = target_logits[-n_rejected, :] #- draft_logits[1-n_rejected, :]
                else:
                    probs = target_logits[-1, :]
                    
                probs = F.softmax(probs / (self.temperature + 1e-13), dim=-1)
                # probs /= max(self.temperature, 1e-13)
                # probs_max = torch.where(probs > 0, probs, torch.zeros_like(probs))
                # probs_max_sum = torch.sum(probs_max)
                # probs = probs_max / max(probs_max_sum, 1e-13)
                
                token_id = torch.multinomial(probs, num_samples=1).unsqueeze(0)

                decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)

            # Check for end of sequence
            if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                break

            # # or if a full stop is generated
            # if decoder_input_ids[0][-1].item() == self.tokenizer.convert_tokens_to_ids('.'):
            #     break

        # Decode translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        perc_accepted = accepted_tokens / total_tokens * 100
        return translation, perc_accepted

In [5]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [6]:
# SpeculativeDecoder(target_model_name="google-t5/t5-3b", draft_model_name="google-t5/t5-small")
# SpeculativeDecoder(target_model_name="google-t5/t5-large", draft_model_name="google-t5/t5-base")
# decoder = SpeculativeDecoder()
# decoder.translate("Hi, how are you?", max_length=128)

In [7]:
# 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

# Only Large Model

In [8]:
class NormalDecoder:
    def __init__(
        self,
        model_name: str = "meta-llama/Llama-3.2-3B",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.device = device

        # Initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')

        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model.eval()

    def get_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor
    ) -> torch.Tensor:
        """Get logits from model for the last token."""
        with torch.no_grad():
            outputs = self.model(
                input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                # decoder_input_ids=decoder_input_ids,
                return_dict=True
            )
            return outputs.logits[:, -1, :]

    def sample_token(self, logits: torch.Tensor, temperature: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a token from logits using temperature sampling."""
        if temperature == 0:
            # Greedy sampling
            token_id = torch.argmax(logits, dim=-1)
            prob = torch.ones_like(token_id, dtype=torch.float)
        else:
            # Temperature sampling
            probs = F.softmax(logits / temperature, dim=-1)
            token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
            prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)
        return token_id, prob

    def translate(
        self,
        source_text: str,
        max_length: int = 128,
        temperature: float = 0.7
    ) -> str:
        """Translate source text using the normal T5 model."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            source_text,
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Initialize decoder input with start token
        # decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)
        decoder_input_ids = torch.tensor([[id for id in encoder_inputs.input_ids[0][:5]]]).to(self.device)

        while decoder_input_ids.shape[1] < max_length:
            # Generate logits for the next token
            logits = self.get_logits(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids
            )

            # Sample a token
            token_id, _ = self.sample_token(logits, temperature)

            # Add token to the decoder input
            decoder_input_ids = torch.cat(
                [decoder_input_ids, token_id.view(1, 1)],
                dim=1
            )

            # Break if end token is generated
            if token_id.item() == self.tokenizer.eos_token_id:
                break

        # Decode and return translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        return translation


In [9]:
decoder = SpeculativeDecoder()
begin = time.time()
res, perc = decoder.translate("Hi, how are you?", max_length=128)
spec_time = time.time() - begin

print(f"Speculative Generation: {res}")
print(f"Accepted Tokens: {perc:.2f}%")
print(f"Speculative Decoding Time: {spec_time:.2f}s")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Speculative Generation: Hi, how are you? I am Ozan, a Turkish guy who loves to travel and meet new people. I am a student and i currently live in Istanbul. It is my dream to visit all countries in the world meet new people and experience new cultures. I am flexible and open minded. YOu can contact me skype:anatolozan91 or:ozan.anatolozan@live.com.:
Accepted Tokens: 72.22%
Speculative Decoding Time: 8.43s


In [10]:
normal_decoder = NormalDecoder()
begin = time.time()
res_normal = normal_decoder.translate("Hi, how are you?", max_length=128)
normal_time = time.time() - begin

print(f"Normal Generation: {res_normal}")
print(f"Normal Decoding Time: {normal_time:.2f}s")
print(f"Speedup: {normal_time / spec_time:.2f}x")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Normal Generation: Hi, how are you? It's been a while since I posted here. I've had a busy summer. I've been cooking all the time, so I've been a bit busy. I've been taking loads of photos, but it's hard to find time to write. Now that I've got a bit of free time, I'm going to share with you some of the dishes I've been cooking. I'm going to start with this one. It's a simple dish, but it's really good. I've been making it for years, and it's always a hit. I've been cooking it for my
Normal Decoding Time: 10.69s
Speedup: 1.27x


# Experiments

## Varying gamma values for T5-3b(target) + T5-small(draft)

In [7]:
def vary_gamma(gamma=3):

    speculative_decoder = SpeculativeDecoder(gamma=gamma)
    normal_decoder = NormalDecoder()

    spec_total_time = 0
    normal_total_time = 0
    total_iters = 0
    total_pc = 0

    f = open("res/t53b-spec-gamma-{}.txt".format(gamma), "w")

    for i in tqdm(en_gr_dataset['translation'][:30]):
        source_text = i['en']
        target_text = i['de']
        
        # Time the translation
        start_time = time.time()
        spec_translation, pc = speculative_decoder.translate(source_text)
        end_time = time.time()

        spec_time = end_time - start_time

        # spec_total_time += spec_time
        
        start_time = time.time()
        normal_translation = normal_decoder.translate(source_text)
        end_time = time.time()

        normal_time = end_time - start_time

        # normal_total_time += normal_time

        # print(f"Source: {source_text}")
        # print(f"Normal Translation: {normal_translation}")
        # print(f"Time taken: {normal_time:.2f} seconds")
        # print(f"Speculative Translation: {spec_translation}")
        # print(f"Time taken: {spec_time:.2f} seconds")
        # print(f"Percentage tokens accepted: {pc:.2f}%")
        # print(f"Target: {target_text}")
        f.write(f"Source: {source_text}\n")
        f.write(f"Normal Translation: {normal_translation}\n")
        f.write(f"Time taken: {normal_time:.2f} seconds\n")
        f.write(f"Speculative Translation: {spec_translation}\n")
        f.write(f"Time taken: {spec_time:.2f} seconds\n")
        f.write(f"Percentage tokens accepted: {pc:.2f}%\n")
        f.write(f"Target: {target_text}\n")
        f.write("\n")

        spec_total_time += spec_time
        normal_total_time += normal_time
        total_pc += pc
        total_iters += 1

    print(f"\nAverage time taken for normal decoding: {normal_total_time / total_iters:.2f} seconds")
    print(f"Average time taken for speculative decoding: {spec_total_time / total_iters:.2f} seconds")
    print(f"Average percentage of tokens accepted: {total_pc / total_iters:.2f}%")
    print(f"Average speedup over {total_iters} iterations: {normal_total_time / spec_total_time:.2f}x")

    f.write(f"\nAverage time taken for normal decoding: {normal_total_time / total_iters:.2f} seconds\n")
    f.write(f"Average time taken for speculative decoding: {spec_total_time / total_iters:.2f} seconds\n")
    f.write(f"Average percentage of tokens accepted: {total_pc / total_iters:.2f}%\n")
    f.write(f"Average speedup over {total_iters} iterations: {normal_total_time / spec_total_time:.2f}x\n")
    f.write("\n")
    f.close()

In [8]:
vary_gamma(3)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
100%|██████████| 30/30 [02:46<00:00,  5.54s/it]


Average time taken for normal decoding: 3.17 seconds
Average time taken for speculative decoding: 2.37 seconds
Average percentage of tokens accepted: 78.73%
Average speedup over 30 iterations: 1.34x





In [8]:
vary_gamma(5)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
100%|██████████| 30/30 [02:31<00:00,  5.04s/it]


Average time taken for normal decoding: 3.20 seconds
Average time taken for speculative decoding: 1.84 seconds
Average percentage of tokens accepted: 78.10%
Average speedup over 30 iterations: 1.74x





In [9]:
vary_gamma(7)

100%|██████████| 30/30 [02:30<00:00,  5.03s/it]


Average time taken for normal decoding: 3.22 seconds
Average time taken for speculative decoding: 1.81 seconds
Average percentage of tokens accepted: 65.90%
Average speedup over 30 iterations: 1.78x





In [10]:
vary_gamma(9)

100%|██████████| 30/30 [02:52<00:00,  5.76s/it]


Average time taken for normal decoding: 3.59 seconds
Average time taken for speculative decoding: 2.17 seconds
Average percentage of tokens accepted: 59.11%
Average speedup over 30 iterations: 1.66x





## Varying models values for optimal gamma value

In [8]:
def vary_models(target_model_name = "google/t5-3b", draft_model_name = "google/t5-small", gamma=7, verbose=False):

    speculative_decoder = SpeculativeDecoder(target_model_name=target_model_name, draft_model_name=draft_model_name, gamma=gamma)
    normal_decoder = NormalDecoder(model_name=target_model_name)

    spec_total_time = 0
    normal_total_time = 0
    total_iters = 0
    total_pc = 0

    target_name = target_model_name.split("/")[1]
    draft_name = draft_model_name.split("/")[1]

    f = open(f"res/spec-{target_name}-{draft_name}.txt", "w")

    for i in tqdm(en_gr_dataset['translation'][:30]):
        source_text = i['en']
        target_text = i['de']
        
        # Time the translation
        start_time = time.time()
        spec_translation, pc = speculative_decoder.translate(source_text)
        end_time = time.time()

        spec_time = end_time - start_time

        # spec_total_time += spec_time
        
        start_time = time.time()
        normal_translation = normal_decoder.translate(source_text)
        end_time = time.time()

        normal_time = end_time - start_time

        # normal_total_time += normal_time
        if verbose:
            print(f"Source: {source_text}")
            print(f"Normal Translation: {normal_translation}")
            print(f"Time taken: {normal_time:.2f} seconds")
            print(f"Speculative Translation: {spec_translation}")
            print(f"Time taken: {spec_time:.2f} seconds")
            print(f"Percentage tokens accepted: {pc:.2f}%")
            print(f"Target: {target_text}")
        f.write(f"Source: {source_text}\n")
        f.write(f"Normal Translation: {normal_translation}\n")
        f.write(f"Time taken: {normal_time:.2f} seconds\n")
        f.write(f"Speculative Translation: {spec_translation}\n")
        f.write(f"Time taken: {spec_time:.2f} seconds\n")
        f.write(f"Percentage tokens accepted: {pc:.2f}%\n")
        f.write(f"Target: {target_text}\n")
        f.write("\n")

        spec_total_time += spec_time
        normal_total_time += normal_time
        total_pc += pc
        total_iters += 1

    print(f"\nAverage time taken for normal decoding: {normal_total_time / total_iters:.2f} seconds")
    print(f"Average time taken for speculative decoding: {spec_total_time / total_iters:.2f} seconds")
    print(f"Average percentage of tokens accepted: {total_pc / total_iters:.2f}%")
    print(f"Average speedup over {total_iters} iterations: {normal_total_time / spec_total_time:.2f}x")

    f.write(f"\nAverage time taken for normal decoding: {normal_total_time / total_iters:.2f} seconds\n")
    f.write(f"Average time taken for speculative decoding: {spec_total_time / total_iters:.2f} seconds\n")
    f.write(f"Average percentage of tokens accepted: {total_pc / total_iters:.2f}%\n")
    f.write(f"Average speedup over {total_iters} iterations: {normal_total_time / spec_total_time:.2f}x\n")
    f.write("\n")
    f.close()

In [10]:
vary_models("google-t5/t5-3b", "google-t5/t5-small")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
100%|██████████| 30/30 [02:29<00:00,  4.99s/it]


Average time taken for normal decoding: 3.14 seconds
Average time taken for speculative decoding: 1.85 seconds
Average percentage of tokens accepted: 71.06%
Average speedup over 30 iterations: 1.69x





In [13]:
vary_models("google-t5/t5-3b", "google-t5/t5-base", gamma=9)

100%|██████████| 30/30 [03:06<00:00,  6.23s/it]


Average time taken for normal decoding: 3.46 seconds
Average time taken for speculative decoding: 2.76 seconds
Average percentage of tokens accepted: 62.88%
Average speedup over 30 iterations: 1.25x





In [None]:
vary_models("google-t5/t5-3b", "google-t5/t5-large", gamma=7, verbose=True)

In [9]:
vary_models("google-t5/t5-large", "google-t5/t5-small")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
100%|██████████| 30/30 [02:26<00:00,  4.87s/it]


Average time taken for normal decoding: 3.04 seconds
Average time taken for speculative decoding: 1.83 seconds
Average percentage of tokens accepted: 70.80%
Average speedup over 30 iterations: 1.66x





In [12]:
vary_models("google-t5/t5-large", "google-t5/t5-base", gamma=9)

100%|██████████| 30/30 [03:05<00:00,  6.19s/it]


Average time taken for normal decoding: 3.06 seconds
Average time taken for speculative decoding: 3.13 seconds
Average percentage of tokens accepted: 65.75%
Average speedup over 30 iterations: 0.98x





In [11]:
vary_models("google-t5/t5-base", "google-t5/t5-small")

100%|██████████| 30/30 [01:35<00:00,  3.17s/it]


Average time taken for normal decoding: 1.60 seconds
Average time taken for speculative decoding: 1.57 seconds
Average percentage of tokens accepted: 71.93%
Average speedup over 30 iterations: 1.02x





## Varying temperature values

In [7]:
def vary_temp(target_model_name = "google-t5/t5-3b", draft_model_name = "google-t5/t5-small", temperature=0.7):

    speculative_decoder = SpeculativeDecoder(target_model_name=target_model_name, draft_model_name=draft_model_name, gamma=7, temperature=temperature)
    normal_decoder = NormalDecoder(model_name=target_model_name)

    spec_total_time = 0
    normal_total_time = 0
    total_iters = 0
    total_pc = 0

    f = open(f"res/spec-temp-{temperature}.txt", "w")

    for i in tqdm(en_gr_dataset['translation'][:30]):
        source_text = i['en']
        target_text = i['de']
        
        # Time the translation
        start_time = time.time()
        spec_translation, pc = speculative_decoder.translate(source_text)
        end_time = time.time()

        spec_time = end_time - start_time

        # spec_total_time += spec_time
        
        start_time = time.time()
        normal_translation = normal_decoder.translate(source_text)
        end_time = time.time()

        normal_time = end_time - start_time

        # normal_total_time += normal_time

        # print(f"Source: {source_text}")
        # print(f"Normal Translation: {normal_translation}")
        # print(f"Time taken: {normal_time:.2f} seconds")
        # print(f"Speculative Translation: {spec_translation}")
        # print(f"Time taken: {spec_time:.2f} seconds")
        # print(f"Percentage tokens accepted: {pc:.2f}%")
        # print(f"Target: {target_text}")
        f.write(f"Source: {source_text}\n")
        f.write(f"Normal Translation: {normal_translation}\n")
        f.write(f"Time taken: {normal_time:.2f} seconds\n")
        f.write(f"Speculative Translation: {spec_translation}\n")
        f.write(f"Time taken: {spec_time:.2f} seconds\n")
        f.write(f"Percentage tokens accepted: {pc:.2f}%\n")
        f.write(f"Target: {target_text}\n")
        f.write("\n")

        spec_total_time += spec_time
        normal_total_time += normal_time
        total_pc += pc
        total_iters += 1

    print(f"\nAverage time taken for normal decoding: {normal_total_time / total_iters:.2f} seconds")
    print(f"Average time taken for speculative decoding: {spec_total_time / total_iters:.2f} seconds")
    print(f"Average percentage of tokens accepted: {total_pc / total_iters:.2f}%")
    print(f"Average speedup over {total_iters} iterations: {normal_total_time / spec_total_time:.2f}x")

    f.write(f"\nAverage time taken for normal decoding: {normal_total_time / total_iters:.2f} seconds\n")
    f.write(f"Average time taken for speculative decoding: {spec_total_time / total_iters:.2f} seconds\n")
    f.write(f"Average percentage of tokens accepted: {total_pc / total_iters:.2f}%\n")
    f.write(f"Average speedup over {total_iters} iterations: {normal_total_time / spec_total_time:.2f}x\n")
    f.write("\n")
    f.close()

In [11]:
vary_temp(temperature=0.0)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
100%|██████████| 30/30 [03:22<00:00,  6.76s/it]


Average time taken for normal decoding: 3.19 seconds
Average time taken for speculative decoding: 3.57 seconds
Average percentage of tokens accepted: 17.77%
Average speedup over 30 iterations: 0.90x





In [8]:
vary_temp(temperature=0.3)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
100%|██████████| 30/30 [02:53<00:00,  5.79s/it]


Average time taken for normal decoding: 3.18 seconds
Average time taken for speculative decoding: 2.60 seconds
Average percentage of tokens accepted: 35.19%
Average speedup over 30 iterations: 1.22x





In [8]:
vary_temp(temperature=0.7)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
100%|██████████| 30/30 [02:40<00:00,  5.37s/it]


Average time taken for normal decoding: 3.41 seconds
Average time taken for speculative decoding: 1.96 seconds
Average percentage of tokens accepted: 64.49%
Average speedup over 30 iterations: 1.74x





In [8]:
vary_temp(temperature=1.0)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
100%|██████████| 30/30 [03:05<00:00,  6.19s/it]


Average time taken for normal decoding: 3.88 seconds
Average time taken for speculative decoding: 2.31 seconds
Average percentage of tokens accepted: 67.20%
Average speedup over 30 iterations: 1.68x



