In [1]:
# Clone the BLT repository into the current working directory.
!git clone https://github.com/sathishkumar67/Byte-Latent-Transformer.git
# Move the repository contents up into /kaggle/working/ so its packages
# (e.g. `BLT`, imported in a later cell) sit in that directory.
!mv /kaggle/working/Byte-Latent-Transformer/* /kaggle/working/
# Upgrade pip. The %pip magic is used instead of !pip so the command targets
# the environment of the running kernel rather than whichever pip is on PATH.
%pip install --upgrade pip
# NOTE: no PyTorch install step is run in this cell; the torch that is already
# available in the environment is what the later cells import.
# Optional: install the repository's requirements (left disabled).
# %pip install -r requirements.txt

Cloning into 'Byte-Latent-Transformer'...
remote: Enumerating objects: 124, done.
remote: Counting objects: 100% (124/124), done.
remote: Compressing objects: 100% (89/89), done.
remote: Total 124 (delta 68), reused 80 (delta 31), pack-reused 0 (from 0)
Receiving objects: 100% (124/124), 32.94 KiB | 4.71 MiB/s, done.
Resolving deltas: 100% (68/68), done.
Collecting pip
  Downloading pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.2-py3-none-any.whl (1.8 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 35.9 MB/s eta 0:00:00
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.2


In [6]:
from __future__ import annotations
from dataclasses import dataclass
import inspect
import torch
import torch.nn as nn
import torch.nn.functional as F
from BLT.norms import RMSNorm
from BLT.mlp import MLPwithSwiGLU
from BLT.attention import MultiHeadLatentAttentionWithGQAFused
from typing import Optional, Tuple



@dataclass
class EntropyConfig:
    """Hyperparameter bundle for the entropy model's transformer blocks.

    Every field has a default, so ``EntropyConfig()`` is a valid configuration.
    Field order is part of the positional constructor signature and must not
    be rearranged.
    """

    # ---- Attention -------------------------------------------------------
    # The fields from hidden_size through attn_bias are forwarded as keyword
    # arguments to MultiHeadLatentAttentionWithGQAFused by EntropyBlock.
    hidden_size: int = 512
    num_heads: int = 8
    n_kv_heads: Optional[int] = None  # None is passed through to the attention module as-is
    kv_lora_rank: int = 512
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    qk_nope_head_dim: int = 128
    max_position_embeddings: int = 2048
    rope_base: int = 10_000
    attn_dropout: float = 0.0
    attn_bias: bool = False
    # NOTE(review): use_cache and is_causal are not forwarded to the attention
    # module by EntropyBlock -- confirm where (or whether) they are consumed.
    use_cache: bool = False
    is_causal: bool = True

    # ---- MLP -------------------------------------------------------------
    # None is handed to MLPwithSwiGLU unchanged; the author's note says it then
    # becomes 4 * hidden_size -- TODO confirm against MLPwithSwiGLU.
    mlp_hidden_dim: Optional[int] = None
    mlp_dropout: float = 0.0
    mlp_bias: bool = False

    # ---- RMSNorm ---------------------------------------------------------
    rmsnorm_eps: float = 1e-8

    # ---- Rotary position embedding ---------------------------------------
    # NOTE(review): these mirror max_position_embeddings / rope_base above and
    # are not read by EntropyBlock -- confirm which component uses them.
    rotary_max_position_embeddings: int = 2048
    rotary_base: int = 10_000

In [None]:
class EntropyBlock(nn.Module):
    """
    One pre-norm transformer block: latent attention followed by a SwiGLU MLP.

    Data flow of a forward pass:
      1. RMSNorm (``rmsnorm_1``) on the input
      2. ``MultiHeadLatentAttentionWithGQAFused`` on the normalized input
      3. Residual add of the attention output onto the un-normalized input
      4. RMSNorm (``rmsnorm_2``) on the result of step 3
      5. ``MLPwithSwiGLU`` on the normalized tensor
      6. Residual add of the MLP output onto the result of step 3

    Args:
        config (EntropyConfig): Source of every hyperparameter used to build the submodules.

    Attributes:
        config (EntropyConfig): The configuration object passed at construction.
        attention (MultiHeadLatentAttentionWithGQAFused): Attention submodule.
        mlp (MLPwithSwiGLU): Feed-forward submodule.
        rmsnorm_1 (RMSNorm): Normalization applied before attention.
        rmsnorm_2 (RMSNorm): Normalization applied before the MLP.
    """

    def __init__(self, config: EntropyConfig) -> None:
        """
        Build the attention, MLP and the two RMSNorm layers from ``config``.

        Submodules are created in the order attention -> mlp -> rmsnorm_1 -> rmsnorm_2,
        which fixes the order in which they are registered on the module.
        """
        super().__init__()
        self.config = config
        width = config.hidden_size

        # Attention: every argument is taken verbatim from the config.
        self.attention = MultiHeadLatentAttentionWithGQAFused(
            hidden_size=width,
            num_heads=config.num_heads,
            n_kv_heads=config.n_kv_heads,
            kv_lora_rank=config.kv_lora_rank,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            qk_nope_head_dim=config.qk_nope_head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_base=config.rope_base,
            attn_dropout=config.attn_dropout,
            attn_bias=config.attn_bias,
        )

        # Feed-forward network; hidden_dim may be None and is passed through unchanged.
        self.mlp = MLPwithSwiGLU(
            dim=width,
            hidden_dim=config.mlp_hidden_dim,
            mlp_dropout=config.mlp_dropout,
            mlp_bias=config.mlp_bias,
        )

        # Two independent norms sharing the same width and epsilon.
        self.rmsnorm_1 = RMSNorm(dim=width, eps=config.rmsnorm_eps)
        self.rmsnorm_2 = RMSNorm(dim=width, eps=config.rmsnorm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> torch.Tensor:
        """
        Run the block: norm -> attention -> residual, then norm -> MLP -> residual.

        Args:
            hidden_states (torch.Tensor): Block input; expected shape
                [batch_size, seq_len, hidden_size].
            attention_mask (Optional[torch.Tensor]): Forwarded unchanged to the attention module.
            past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Forwarded unchanged
                to the attention module.

        Returns:
            torch.Tensor: Sum of the input and both sub-layer outputs. The attention
            module's return value is added directly to ``hidden_states``, so it is
            assumed to be a tensor of a compatible shape.
        """
        # Attention sub-layer: the residual path carries the un-normalized input.
        attn_out = self.attention(
            self.rmsnorm_1(hidden_states),
            attention_mask=attention_mask,
            past_key_value=past_key_value,
        )
        after_attention = hidden_states + attn_out

        # MLP sub-layer: same pattern, applied to the attention result.
        mlp_out = self.mlp(self.rmsnorm_2(after_attention))
        return after_attention + mlp_out

In [None]:
class EntropyModel(nn.Module):
    def __init__(self, config: EntropyConfig):