In [11]:
import torch
import torch.nn.functional as F
from transformers import T5ForConditionalGeneration, T5Tokenizer
from typing import List, Tuple, Optional
import time

In [12]:
class SpeculativeDecoder:
    def __init__(
        self,
        target_model_name = "t5-large",
        draft_model_name = "t5-small",
        device = "cuda" if torch.cuda.is_available() else "cpu",
        gamma = 4
    ):
        self.device = device
        self.gamma = gamma

        self.tokenizer = T5Tokenizer.from_pretrained(target_model_name)
        self.target_model = T5ForConditionalGeneration.from_pretrained(target_model_name).to(device)
        self.draft_model = T5ForConditionalGeneration.from_pretrained(draft_model_name).to(device)

        self.target_model.eval()
        self.draft_model.eval()

    def get_draft_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        gamma: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get draft logits for gamma tokens"""
        draft_tokens = []
        draft_probs = []
        current_decoder_ids = decoder_input_ids

        # Generate gamma tokens from the draft model
        for _ in range(gamma):
            with torch.no_grad():
                outputs = self.draft_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=current_decoder_ids,
                    return_dict=True
                )
                logits = outputs.logits[:, -1, :]  # Get logits for last position
                probs = F.softmax(logits, dim=-1)

                # Sample token
                token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
                prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)

                draft_tokens.append(token_id.item())
                draft_probs.append(prob.item())

                # Update decoder inputs for next iteration
                current_decoder_ids = torch.cat(
                    [current_decoder_ids, token_id.view(1, 1)],
                    dim=1
                )

                if token_id.item() == self.tokenizer.eos_token_id:
                    break

        return draft_tokens, draft_probs, current_decoder_ids

    def get_target_probs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> torch.Tensor:
        """Get target probabilities for the draft tokens in parallel."""
        with torch.no_grad():
            # Add draft tokens to decoder input
            full_decoder_ids = torch.cat([decoder_input_ids, draft_tokens.unsqueeze(0)], dim=1)

            outputs = self.target_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=full_decoder_ids,
                return_dict=True
            )

            # Get probabilities for positions before each draft token
            logits = outputs.logits[:, -(len(draft_tokens) + 1):-1, :]
            target_probs = F.softmax(logits, dim=-1)

            return target_probs.squeeze(0)

    def verify_tokens(
        self,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
    ) -> int:
        """Determine number of accepted tokens"""
        # Get target probabilities for the draft tokens
        target_probs_draft_tokens = target_probs.gather(
            -1,
            draft_tokens.unsqueeze(-1)
        ).squeeze(-1)

        # Calculate acceptance ratios
        acceptance_ratios = target_probs_draft_tokens / draft_probs

        # Sample uniform random numbers
        random_nums = torch.rand_like(acceptance_ratios)

        # Find number of accepted tokens
        # Accept if random number < min(1, target_prob / draft_prob)
        accepts = random_nums < torch.minimum(
            torch.ones_like(acceptance_ratios),
            acceptance_ratios
        )

        # Find first rejection
        try:
            n_accepted = torch.where(~accepts)[0][0].item()
        except:
            n_accepted = len(accepts)

        return n_accepted

    def translate(
        self,
        source_text: str,
        max_length: int = 128
    ) -> str:
        """Translate source text using speculative decoding."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            f"translate English to German: {source_text}",
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Initialize with start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)

        while decoder_input_ids.shape[1] < max_length:
            # Get draft tokens autoregressively
            draft_tokens, draft_probs, proposed_decoder_ids = self.get_draft_logits(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids,
                self.gamma
            )

            draft_tokens = torch.tensor(draft_tokens, device=self.device)
            draft_probs = torch.tensor(draft_probs, device=self.device)

            if len(draft_tokens) == 0:
                break

            # Get target probabilities in parallel
            target_probs = self.get_target_probs(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids,
                draft_tokens
            )

            # Verify tokens
            n_accepted = self.verify_tokens(target_probs, draft_tokens, draft_probs)

            # Accept verified tokens
            if n_accepted > 0:
                decoder_input_ids = torch.cat([
                    decoder_input_ids,
                    draft_tokens[:n_accepted].unsqueeze(0)
                ], dim=1)

            # If rejection or no acceptance, sample one token from target
            if n_accepted < len(draft_tokens):
                with torch.no_grad():
                    outputs = self.target_model(
                        input_ids=encoder_inputs.input_ids,
                        attention_mask=encoder_inputs.attention_mask,
                        decoder_input_ids=decoder_input_ids,
                        return_dict=True
                    )
                    logits = outputs.logits[:, -1, :]
                    probs = F.softmax(logits, dim=-1)
                    token_id = torch.multinomial(probs, num_samples=1)
                    decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)

            # Check for end of sequence
            if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                break

        # Decode translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        return translation

In [13]:
# Initialize decoder
decoder = SpeculativeDecoder()

# Example translation
source_text = "In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity."

# Time the translation
start_time = time.time()
translation = decoder.translate(source_text)
end_time = time.time()

print(f"Source: {source_text}")
print(f"Translation: {translation}")
print(f"Time taken: {end_time - start_time:.2f} seconds")

Source: In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity.
Translation: In einer Welt, in der die Technik in nie da gewesenem Tempo weiter wächst, müssen sich die Einzelnen und Organisationen rasch an die raschen Entwicklungen in den Bereichen künstlicher Intelligenz, Maschinelle Lernen und Automatisierung wenden, um zu sichern, dass ethische Erwägungen, ökologische Nachhaltigkeit und gerechter Zugang zu Ressourcen Priorität haben, um damals eine Zukunft zu schaffen, von der die Menschheit insgesamt profitieren kann.
Time taken: 4.80 seconds


In [14]:
class NormalDecoder:
    def __init__(
        self,
        model_name: str = "t5-large",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.device = device

        # Initialize tokenizer and model
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
        self.model.eval()

    def get_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor
    ) -> torch.Tensor:
        """Get logits from model for the last token."""
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                return_dict=True
            )
            return outputs.logits[:, -1, :]

    def sample_token(self, logits: torch.Tensor, temperature: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a token from logits using temperature sampling."""
        if temperature == 0:
            # Greedy sampling
            token_id = torch.argmax(logits, dim=-1)
            prob = torch.ones_like(token_id, dtype=torch.float)
        else:
            # Temperature sampling
            probs = F.softmax(logits / temperature, dim=-1)
            token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
            prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)
        return token_id, prob

    def translate(
        self,
        source_text: str,
        max_length: int = 128,
        temperature: float = 0.7
    ) -> str:
        """Translate source text using the normal T5 model."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            f"translate English to German: {source_text}",
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Initialize decoder input with start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)

        while decoder_input_ids.shape[1] < max_length:
            # Generate logits for the next token
            logits = self.get_logits(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids
            )

            # Sample a token
            token_id, _ = self.sample_token(logits, temperature)

            # Add token to the decoder input
            decoder_input_ids = torch.cat(
                [decoder_input_ids, token_id.view(1, 1)],
                dim=1
            )

            # Break if end token is generated
            if token_id.item() == self.tokenizer.eos_token_id:
                break

        # Decode and return translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        return translation


In [15]:
# Initialize decoder
base_decoder = NormalDecoder()

# Example translation
source_text = "In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity."

# Time the translation
start_time = time.time()
translation = base_decoder.translate(source_text)
end_time = time.time()

print(f"Source: {source_text}")
print(f"Translation: {translation}")
print(f"Time taken: {end_time - start_time:.2f} seconds")

Source: In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity.
Translation: In einer Welt, in der sich die Technologie mit einem beispiellosen Tempo weiterentwickelt, müssen sich Individuen und Organisationen rasch an die schnellen Fortschritte in den Bereichen künstliche Intelligenz, maschinelles Lernen und Automatisierung anpassen, um sicherzustellen, dass ethische Überlegungen, Umweltverträglichkeit und gerechter Zugang zu Ressourcen Vorrang haben, um eine Zukunft zu schaffen, von der die gesamte Menschheit profitiert
Time taken: 7.74 seconds
