In [1]:
import torch
import torch.nn.functional as F
from transformers import T5ForConditionalGeneration, T5Tokenizer
from typing import List, Tuple, Optional
import time

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class SpeculativeDecoder:
    def __init__(
        self,
        target_model_name = "google-t5/t5-large",
        draft_model_name = "google-t5/t5-small",
        device = "cuda" if torch.cuda.is_available() else "cpu",
        gamma = 4
    ):
        self.device = device
        self.gamma = gamma

        self.tokenizer = T5Tokenizer.from_pretrained(target_model_name)
        self.target_model = T5ForConditionalGeneration.from_pretrained(target_model_name).to(device)
        self.draft_model = T5ForConditionalGeneration.from_pretrained(draft_model_name).to(device)

        self.target_model.eval()
        self.draft_model.eval()

    def get_draft_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        gamma: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get draft logits for gamma tokens"""
        draft_tokens = []
        draft_probs = []
        current_decoder_ids = decoder_input_ids

        # Generate gamma tokens from the draft model
        for _ in range(gamma):
            with torch.no_grad():
                outputs = self.draft_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=current_decoder_ids,
                    return_dict=True
                )
                logits = outputs.logits[:, -1, :]  # Get logits for last position
                probs = F.softmax(logits, dim=-1)

                # Sample token
                token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
                prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)

                draft_tokens.append(token_id.item())
                draft_probs.append(prob.item())

                # Update decoder inputs for next iteration
                current_decoder_ids = torch.cat(
                    [current_decoder_ids, token_id.view(1, 1)],
                    dim=1
                )

                if token_id.item() == self.tokenizer.eos_token_id:
                    break

        return draft_tokens, draft_probs, current_decoder_ids, logits

    def get_target_probs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> torch.Tensor:
        """Get target probabilities for the draft tokens in parallel."""
        with torch.no_grad():
            # Add draft tokens to decoder input
            full_decoder_ids = torch.cat([decoder_input_ids, draft_tokens.unsqueeze(0)], dim=1)

            outputs = self.target_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=full_decoder_ids,
                return_dict=True
            )

            # Get probabilities for positions before each draft token
            logits = outputs.logits[:, -(len(draft_tokens) + 1):-1, :]
            target_probs = F.softmax(logits, dim=-1)

            return target_probs.squeeze(0), outputs.logits[:, -1, :]

    def verify_tokens(
        self,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
    ) -> int:
        """Determine number of accepted tokens"""
        # Get target probabilities for the draft tokens
        target_probs_draft_tokens = target_probs.gather(
            -1,
            draft_tokens.unsqueeze(-1)
        ).squeeze(-1)

        # Calculate acceptance ratios
        acceptance_ratios = target_probs_draft_tokens / draft_probs

        # Sample uniform random numbers
        random_nums = torch.rand_like(acceptance_ratios)

        # Find number of accepted tokens
        # Accept if random number < min(1, target_prob / draft_prob)
        accepts = random_nums < torch.minimum(
            torch.ones_like(acceptance_ratios),
            acceptance_ratios
        )

        # Find first rejection
        try:
            n_accepted = torch.where(~accepts)[0][0].item()
        except:
            n_accepted = len(accepts)

        return n_accepted

    def translate(
        self,
        source_text: str,
        max_length: int = 128
    ) -> str:
        """Translate source text using speculative decoding."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            f"translate English to German: {source_text}",
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Initialize with start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)

        while decoder_input_ids.shape[1] < max_length:
            # Get draft tokens autoregressively
            draft_tokens, draft_probs, proposed_decoder_ids, draft_logits = self.get_draft_logits(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids,
                self.gamma
            )

            draft_tokens = torch.tensor(draft_tokens, device=self.device)
            draft_probs = torch.tensor(draft_probs, device=self.device)

            if len(draft_tokens) == 0:
                break

            # Get target probabilities in parallel
            target_probs, target_logits = self.get_target_probs(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids,
                draft_tokens
            )

            # Verify tokens
            # print(target_probs.shape, draft_probs.shape)
            n_accepted = self.verify_tokens(target_probs, draft_tokens, draft_probs)

            # Accept verified tokens
            if n_accepted > 0:
                decoder_input_ids = torch.cat([
                    decoder_input_ids,
                    draft_tokens[:n_accepted].unsqueeze(0)
                ], dim=1)

            # If rejection or no acceptance, sample one token from target
            # if n_accepted < len(draft_tokens):
                # with torch.no_grad():
                #     outputs = self.target_model(
                #         input_ids=encoder_inputs.input_ids,
                #         attention_mask=encoder_inputs.attention_mask,
                #         decoder_input_ids=decoder_input_ids,
                #         return_dict=True
                #     )
                #     logits = outputs.logits[:, -1, :]
                #     probs = F.softmax(logits, dim=-1)
                #     token_id = torch.multinomial(probs, num_samples=1)
                #     decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)
                # print(target_probs.shape, draft_probs.shape)
                
            with torch.no_grad():
                # print(target_logits.shape, draft_logits.shape)
                probs = target_logits #- draft_logits
                probs = F.softmax(probs, dim=-1)
                token_id = torch.multinomial(probs, num_samples=1)
                # print(probs.shape, token_id.shape)
                decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)

            # Check for end of sequence
            if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                break

        # Decode translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        return translation

In [5]:
# Initialize decoder
decoder = SpeculativeDecoder()

# Example translation
source_text = "In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity."

# Time the translation
start_time = time.time()
translation = decoder.translate(source_text)
end_time = time.time()

print(f"Source: {source_text}")
print(f"Translation: {translation}\n")
print(f"Time taken: {end_time - start_time:.2f} seconds")

Source: In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity.
Translation: In einer Welt, in der sich die Technik in einem beispiellosen Tempo müssen sich Menschen und Organisationen schnell den rasanten Entwicklungen im Bereich künstlicher Intelligenz, unds und Automatisierung ein anpassen,eischer, um dass Überlegungen um die den Menschen jeweiligen Lebensstandarddenblicksetzte und an  Ressourcen gerechter an

Time taken: 3.98 seconds


In [8]:
class NormalDecoder:
    def __init__(
        self,
        model_name: str = "google-t5/t5-large",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.device = device

        # Initialize tokenizer and model
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
        self.model.eval()

    def get_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor
    ) -> torch.Tensor:
        """Get logits from model for the last token."""
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                return_dict=True
            )
            return outputs.logits[:, -1, :]

    def sample_token(self, logits: torch.Tensor, temperature: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a token from logits using temperature sampling."""
        if temperature == 0:
            # Greedy sampling
            token_id = torch.argmax(logits, dim=-1)
            prob = torch.ones_like(token_id, dtype=torch.float)
        else:
            # Temperature sampling
            probs = F.softmax(logits / temperature, dim=-1)
            token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
            prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)
        return token_id, prob

    def translate(
        self,
        source_text: str,
        max_length: int = 128,
        temperature: float = 0.7
    ) -> str:
        """Translate source text using the normal T5 model."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            f"translate English to German: {source_text}",
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Initialize decoder input with start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)

        while decoder_input_ids.shape[1] < max_length:
            # Generate logits for the next token
            logits = self.get_logits(
                encoder_inputs.input_ids,
                encoder_inputs.attention_mask,
                decoder_input_ids
            )

            # Sample a token
            token_id, _ = self.sample_token(logits, temperature)

            # Add token to the decoder input
            decoder_input_ids = torch.cat(
                [decoder_input_ids, token_id.view(1, 1)],
                dim=1
            )

            # Break if end token is generated
            if token_id.item() == self.tokenizer.eos_token_id:
                break

        # Decode and return translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        return translation


In [9]:
# Initialize decoder
base_decoder = NormalDecoder()

# Example translation
source_text = "In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity."

# Time the translation
start_time = time.time()
translation = base_decoder.translate(source_text)
end_time = time.time()

print(f"Source: {source_text}")
print(f"Translation: {translation}\n")
print(f"Time taken: {end_time - start_time:.2f} seconds")

Source: In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity.
Translation: In einer Welt, in der die Technologie in beispielloser Geschwindigkeit eine Entwicklung durchläuft, müssen sich Menschen und Organisationen schnell auf die rapiden Fortschritte bei der künstlichen Intelligenz, beim maschinellen Lernen und bei der Automatisierung einstellen, indem ethische Überlegungen, ökologische Nachhaltigkeit und gleichberechtigter Zugang zu Ressourcen Priorität erhalten, um eine Zukunft zu schaffen, von der die gesamte Menschheit profitiert.

Time taken: 5.54 seconds


In [None]:
from scipy.spatial.distance import jensenshannon as JSD

class OnlineSpeculativeDecoder:
    def __init__(
        self,
        target_model_name = "google-t5/t5-large",
        draft_model_name = "google-t5/t5-small",
        device = "cuda" if torch.cuda.is_available() else "cpu",
        gamma = 4,
        update_interval=2,  # Update draft model after every `update_interval` iterations
        buffer_size_threshold=2,  # Buffer size threshold for updates
        time_threshold=2,  # Time threshold (in seconds) for updates
    ):
        self.device = device
        self.gamma = gamma
        self.update_interval = update_interval
        self.buffer_size_threshold = buffer_size_threshold
        self.time_threshold = time_threshold
        self.last_update_time = time.time()  # Track last update time

        self.tokenizer = T5Tokenizer.from_pretrained(target_model_name)
        self.target_model = T5ForConditionalGeneration.from_pretrained(target_model_name).to(device)
        self.draft_model = T5ForConditionalGeneration.from_pretrained(draft_model_name).to(device)

        self.target_model.eval()
        self.draft_model.eval()

        # Buffers for storing token proposals and updates
        self.replay_buffer = []
        self.temp_buffer = []  # Temporary buffer for a single request

        # Counter for iteration tracking
        self.iteration_count = 0

    def get_draft_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        gamma: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get draft logits for gamma tokens"""
        draft_tokens = []
        draft_probs = []
        current_decoder_ids = decoder_input_ids

        # Generate gamma tokens from the draft model
        for _ in range(gamma):
            with torch.no_grad():
                outputs = self.draft_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=current_decoder_ids,
                    return_dict=True
                )
                logits = outputs.logits[:, -1, :]  # Get logits for last position
                probs = F.softmax(logits, dim=-1)

                # Sample token
                token_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
                prob = probs.gather(-1, token_id.unsqueeze(-1)).squeeze(-1)

                draft_tokens.append(token_id.item())
                draft_probs.append(prob.item())

                # Update decoder inputs for next iteration
                current_decoder_ids = torch.cat(
                    [current_decoder_ids, token_id.view(1, 1)],
                    dim=1
                )

                if token_id.item() == self.tokenizer.eos_token_id:
                    break

        return draft_tokens, draft_probs, current_decoder_ids, outputs.logits

    def get_target_probs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> torch.Tensor:
        """Get target probabilities for the draft tokens in parallel."""
        with torch.no_grad():
            # Add draft tokens to decoder input
            full_decoder_ids = torch.cat([decoder_input_ids, draft_tokens.unsqueeze(0)], dim=1)

            outputs = self.target_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=full_decoder_ids,
                return_dict=True
            )

            # Get probabilities for positions before each draft token
            logits = outputs.logits[:, -(len(draft_tokens) + 1):-1, :]
            target_probs = F.softmax(logits, dim=-1)

            return target_probs.squeeze(0), outputs.logits

    def verify_tokens(
        self,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor,
    ) -> int:
        """Determine number of accepted tokens"""
        # Get target probabilities for the draft tokens
        target_probs_draft_tokens = target_probs.gather(
            -1,
            draft_tokens.unsqueeze(-1)
        ).squeeze(-1)

        # Calculate acceptance ratios
        acceptance_ratios = target_probs_draft_tokens / draft_probs

        # Sample uniform random numbers
        random_nums = torch.rand_like(acceptance_ratios)

        # Find number of accepted tokens
        # Accept if random number < min(1, target_prob / draft_prob)
        accepts = random_nums < torch.minimum(
            torch.ones_like(acceptance_ratios),
            acceptance_ratios
        )

        # Find first rejection
        try:
            n_accepted = torch.where(~accepts)[0][0].item()
        except:
            n_accepted = len(accepts)

        return n_accepted

        # accepted_tokens = []
        # for i in range(len(draft_tokens)):
        #     if target_probs[i] / draft_probs[i] > torch.rand(1).item():
        #         accepted_tokens.append(draft_tokens[i])
        #     else:
        #         break # Stop if token is not accepted

        # return len(accepted_tokens)
    
    # TODO: verify this, might need to do some window size thing
    def update_draft_model(self):
        """Update draft model with the replay buffer."""
        if len(self.replay_buffer) == 0:
            return

        # Get draft tokens, draft and target probabilities from the replay buffer
        # draft_tokens = torch.tensor([x[0] for x in self.replay_buffer], device=self.device)
        # print(self.replay_buffer[0][0])
        # draft_probs = self.replay_buffer[:, 0]
        # target_probs = self.replay_buffer[:, 1]
        draft_probs = torch.stack([x[0][0] for x in self.replay_buffer], dim=0)
        target_probs = torch.stack([x[1][0] for x in self.replay_buffer], dim=0)

        # Calculate loss and update draft model
        print("Updating")
        # loss = JSD between draft and target probs
        loss = JSD(
            F.softmax(draft_probs.cpu(), dim=-1),
            F.softmax(target_probs.cpu(), dim=-1)
        )

        self.draft_model.zero_grad()
        
        # TODO: figure out how tf to finetune
        loss.backward()
        self.draft_model.optimizer.step()

        # Reset the replay buffer
        self.replay_buffer = []

    def translate(
        self,
        source_text: str,
        max_length: int = 128
    ) -> str:
        """Translate source text using speculative decoding."""
        # Encode source text
        encoder_inputs = self.tokenizer(
            f"translate English to German: {source_text}",
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Initialize with start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], device=self.device)

        self.iteration_count = 0
        self.replay_buffer = []

        while decoder_input_ids.shape[1] < max_length:
            self.temp_buffer = []

            while decoder_input_ids.shape[1] < max_length:
                # Get draft tokens autoregressively
                draft_tokens, draft_probs, proposed_decoder_ids, draft_logits = self.get_draft_logits(
                    encoder_inputs.input_ids,
                    encoder_inputs.attention_mask,
                    decoder_input_ids,
                    self.gamma
                )

                draft_tokens = torch.tensor(draft_tokens, device=self.device)
                draft_probs = torch.tensor(draft_probs, device=self.device)

                if len(draft_tokens) == 0:
                    break

                # Get target probabilities in parallel
                target_probs, target_logits = self.get_target_probs(
                    encoder_inputs.input_ids,
                    encoder_inputs.attention_mask,
                    decoder_input_ids,
                    draft_tokens
                )

                # Verify tokens
                n_accepted = self.verify_tokens(target_probs, draft_tokens, draft_probs)

                # Accept verified tokens
                if n_accepted > 0:
                    decoder_input_ids = torch.cat([
                        decoder_input_ids,
                        draft_tokens[:n_accepted].unsqueeze(0)
                    ], dim=1)

                # # If rejection or no acceptance, sample one token from target
                # if n_accepted < len(draft_tokens):
                #     with torch.no_grad():
                #         outputs = self.target_model(
                #             input_ids=encoder_inputs.input_ids,
                #             attention_mask=encoder_inputs.attention_mask,
                #             decoder_input_ids=decoder_input_ids,
                #             return_dict=True
                #         )
                #         logits = outputs.logits[:, -1, :]
                #         probs = F.softmax(logits, dim=-1)
                #         token_id = torch.multinomial(probs, num_samples=1)
                #         decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)

                with torch.no_grad():
                    # print(target_logits.shape, draft_logits.shape)
                    probs = target_logits[:, -1, :] #- draft_logits[:, -1, :]
                    probs = F.softmax(probs, dim=-1)
                    token_id = torch.multinomial(probs, num_samples=1)
                    # print(probs.shape, token_id.shape)
                    decoder_input_ids = torch.cat([decoder_input_ids, token_id], dim=1)

                # Check for end of sequence
                if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                    break
                
                # TODO: Update buffer with draft and target logits of the first rejected token, verify implementation
                # rejected_tokens = draft_tokens[n_accepted]
                if n_accepted < len(draft_tokens):
                    rejected_prob_draft = draft_logits[:, n_accepted, :]
                    rejected_prob_target = target_logits[:, n_accepted, :]

                    self.temp_buffer.append((rejected_prob_draft, rejected_prob_target))

                # Check for end of sequence
                if decoder_input_ids[0][-1].item() == self.tokenizer.eos_token_id:
                    break

            self.replay_buffer.extend(self.temp_buffer)
            self.iteration_count += 1

            if self.iteration_count % self.update_interval == 0:
                self.update_draft_model()
                self.iteration_count = 0

        # Decode translation
        translation = self.tokenizer.decode(
            decoder_input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        return translation

In [9]:
# Initialize decoder
online_decoder = OnlineSpeculativeDecoder()

# Example translation
source_text = "In a world where technology evolves at an unprecedented pace, individuals and organizations must adapt quickly to the rapid advancements in artificial intelligence, machine learning, and automation, ensuring that ethical considerations, environmental sustainability, and equitable access to resources are prioritized to create a future that benefits all of humanity."

# Time the translation
start_time = time.time()
translation = online_decoder.translate(source_text)
end_time = time.time()

print(f"Source: {source_text}")
print(f"Translation: {translation}\n")
print(f"Time taken: {end_time - start_time:.2f} seconds")

Updating


AttributeError: 'numpy.ndarray' object has no attribute 'backward'