<a href="https://colab.research.google.com/github/ryyhan/RandomCodes/blob/main/HybridEmbeddingsModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from transformers import T5Tokenizer, T5EncoderModel
import numpy as np
from typing import List, Union

class HybridLongT5Embeddings:
    def __init__(self, model_name: str = "google/long-t5-local-base", max_length: int = 4096):
        """
        Initialize the hybrid LongT5 embeddings model.

        Args:
            model_name (str): Pretrained LongT5 model name
            max_length (int): Maximum sequence length
        """
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5EncoderModel.from_pretrained(model_name)
        self.max_length = max_length
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def _mean_pooling(self, model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Perform mean pooling on token embeddings.

        Args:
            model_output: Model encoder outputs
            attention_mask: Attention mask for tokens

        Returns:
            Pooled embeddings
        """
        token_embeddings = model_output
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _attention_weighted_pooling(self, model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Perform attention-weighted pooling on token embeddings.

        Args:
            model_output: Model encoder outputs
            attention_mask: Attention mask for tokens

        Returns:
            Attention-weighted embeddings
        """
        token_embeddings = model_output
        attention_weights = torch.softmax(token_embeddings.norm(dim=-1), dim=-1)
        weighted_embeddings = token_embeddings * attention_weights.unsqueeze(-1)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(weighted_embeddings.size()).float()
        return torch.sum(weighted_embeddings * input_mask_expanded, 1)

    def generate_embeddings(self, texts: Union[str, List[str]], pooling: str = "hybrid") -> np.ndarray:
        """
        Generate hybrid embeddings for input texts.

        Args:
            texts: Single text string or list of texts
            pooling: Pooling strategy ("mean", "attention", "hybrid")

        Returns:
            Embeddings as numpy array
        """
        if isinstance(texts, str):
            texts = [texts]

        # Tokenize input
        encoded_input = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        ).to(self.device)

        # Generate embeddings
        with torch.no_grad():
            outputs = self.model(**encoded_input)
            hidden_states = outputs.last_hidden_state

        # Apply pooling
        if pooling == "mean":
            embeddings = self._mean_pooling(hidden_states, encoded_input.attention_mask)
        elif pooling == "attention":
            embeddings = self._attention_weighted_pooling(hidden_states, encoded_input.attention_mask)
        elif pooling == "hybrid":
            mean_pooled = self._mean_pooling(hidden_states, encoded_input.attention_mask)
            attn_pooled = self._attention_weighted_pooling(hidden_states, encoded_input.attention_mask)
            embeddings = (mean_pooled + attn_pooled) / 2
        else:
            raise ValueError("Pooling must be 'mean', 'attention', or 'hybrid'")

        return embeddings.cpu().numpy()

    def batch_generate(self, texts: List[str], batch_size: int = 32, pooling: str = "hybrid") -> np.ndarray:
        """
        Generate embeddings for large datasets in batches.

        Args:
            texts: List of input texts
            batch_size: Number of texts per batch
            pooling: Pooling strategy

        Returns:
            Embeddings as numpy array
        """
        all_embeddings = []

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            batch_embeddings = self.generate_embeddings(batch_texts, pooling=pooling)
            all_embeddings.append(batch_embeddings)

        return np.vstack(all_embeddings)

# Example usage
if __name__ == "__main__":
    # Initialize the model
    embedder = HybridLongT5Embeddings()

    # Single text example
    text = "This is a sample text for generating embeddings using LongT5."
    embeddings = embedder.generate_embeddings(text)
    print(f"Single text embedding shape: {embeddings.shape}")

    # Multiple texts example
    texts = [
        "First sample text for embedding generation.",
        "Second sample text with different content.",
        "Third text to demonstrate batch processing."
    ]

    # Generate embeddings with different pooling strategies
    mean_embeddings = embedder.batch_generate(texts, pooling="mean")
    attn_embeddings = embedder.batch_generate(texts, pooling="attention")
    hybrid_embeddings = embedder.batch_generate(texts, pooling="hybrid")

    print(f"Mean pooling embeddings shape: {mean_embeddings.shape}")
    print(f"Attention pooling embeddings shape: {attn_embeddings.shape}")
    print(f"Hybrid embeddings shape: {hybrid_embeddings.shape}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/811 [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
You are using a model of type longt5 to instantiate a model of type t5. This is not supported for all configurations of models and can yield errors.


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

Some weights of T5EncoderModel were not initialized from the model checkpoint at google/long-t5-local-base and are newly initialized: ['encoder.block.0.layer.0.SelfAttention.k.weight', 'encoder.block.0.layer.0.SelfAttention.o.weight', 'encoder.block.0.layer.0.SelfAttention.q.weight', 'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'encoder.block.0.layer.0.SelfAttention.v.weight', 'encoder.block.1.layer.0.SelfAttention.k.weight', 'encoder.block.1.layer.0.SelfAttention.o.weight', 'encoder.block.1.layer.0.SelfAttention.q.weight', 'encoder.block.1.layer.0.SelfAttention.v.weight', 'encoder.block.10.layer.0.SelfAttention.k.weight', 'encoder.block.10.layer.0.SelfAttention.o.weight', 'encoder.block.10.layer.0.SelfAttention.q.weight', 'encoder.block.10.layer.0.SelfAttention.v.weight', 'encoder.block.11.layer.0.SelfAttention.k.weight', 'encoder.block.11.layer.0.SelfAttention.o.weight', 'encoder.block.11.layer.0.SelfAttention.q.weight', 'encoder.block.11.layer.0.SelfAttent

Single text embedding shape: (1, 768)
Mean pooling embeddings shape: (3, 768)
Attention pooling embeddings shape: (3, 768)
Hybrid embeddings shape: (3, 768)


In [2]:
print(mean_embeddings)

[[-0.02130399  0.04053642  0.05826725 ...  0.06866038 -0.04904569
   0.06126637]
 [-0.04293535 -0.00552416  0.00314978 ...  0.0454332  -0.070364
  -0.02484155]
 [-0.04646979  0.02111381  0.05731771 ...  0.0811644  -0.01929428
   0.08375346]]
