In [1]:
# clone the BLT repository
!git clone https://github.com/sathishkumar67/Byte-Latent-Transformer.git
# move the files to the current directory
!mv /kaggle/working/Byte-Latent-Transformer/* /kaggle/working/
# upgrade pip
!pip install --upgrade pip
# install latest version pytorch
!pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
# install the required packages
!pip install -r requirements.txt

Cloning into 'Byte-Latent-Transformer'...
remote: Enumerating objects: 191, done.[K
remote: Counting objects: 100% (191/191), done.[K
remote: Compressing objects: 100% (134/134), done.[K
remote: Total 191 (delta 109), reused 127 (delta 52), pack-reused 0 (from 0)[K
Receiving objects: 100% (191/191), 80.03 KiB | 5.72 MiB/s, done.
Resolving deltas: 100% (109/109), done.
Collecting pip
  Downloading pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.2-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m31.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.2
Looking in indexes: https://download.pytorch.org/whl/cu126
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading https://download.pytorc

In [2]:
import os
import torch
import torch.nn as nn
import numpy as np
import lightning as L
from lightning.pytorch import Trainer
from BLT.entropy import EntropyModel, EntropyConfig
from BLT.dataset import TokenDataset
from BLT.utils import clear_directory
from huggingface_hub import hf_hub_download

In [3]:
# clear_directory(os.getcwd())

In [4]:
# Download the model checkpoint. hf_hub_download returns the local path of
# the downloaded file.
ckpt_path = hf_hub_download(repo_id="pt-sk/BLT_Entropy_Checkpoints",
                            filename="entropy_ckpt_4.ckpt",
                            repo_type="model",
                            local_dir="/kaggle/working/")

# Download the tokenized text shard and keep the path it was saved to.
tokens_path = hf_hub_download(repo_id="pt-sk/Text_Bytes_Tokens",
                              filename="wikipedia_512_pretraining/tokenized_text5.npy",
                              repo_type="dataset",
                              local_dir="/kaggle/working/")

# Load the tokenized text from the path reported by the download instead of
# repeating the hard-coded location.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects from
# a downloaded file, which can execute code. If the array has a plain numeric
# dtype, switch to allow_pickle=False - confirm against the dataset first.
tokens = np.load(tokens_path, allow_pickle=True)

entropy_ckpt_4.ckpt:   0%|          | 0.00/447M [00:00<?, ?B/s]

wikipedia_512_pretraining/tokenized_text(…):   0%|          | 0.00/146M [00:00<?, ?B/s]

In [5]:
# Build the entropy model from a default-constructed EntropyConfig
config = EntropyConfig()
model = EntropyModel(config)

# Total number of elements across all model parameters, reported in millions
num_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f"Number of parameters in the model: {num_params/1e6}M")

def configure_optimizer(model: nn.Module,
                        lr: float = 0.0001,
                        weight_decay: float = 0.1,
                        betas: tuple = (0.9, 0.999),
                        eps: float = 1e-8) -> torch.optim.Optimizer:
    """Create an AdamW optimizer with weight decay applied only to >=2-D tensors.

    Args:
        model: Module whose trainable parameters (``requires_grad=True``) are optimized.
        lr: Learning rate.
        weight_decay: Decay applied to parameters with two or more dimensions;
            parameters with fewer dimensions always get a decay of 0.0.
        betas: AdamW ``betas`` coefficients.
        eps: AdamW ``eps`` term.

    Returns:
        A ``torch.optim.AdamW`` instance with two parameter groups, in the order
        (decayed, not decayed).
    """
    # Only parameters that require gradients are handed to the optimizer.
    trainable = [p for p in model.parameters() if p.requires_grad]

    # Any parameter that is at least 2-D is weight decayed, everything else
    # (1-D and 0-D tensors such as biases and norm scales) is not.
    decay_params = [p for p in trainable if p.dim() >= 2]
    nodecay_params = [p for p in trainable if p.dim() < 2]

    # Request the fused implementation only when every trainable parameter lives
    # on a CUDA device; otherwise let PyTorch pick its default implementation
    # instead of unconditionally forcing fused=True.
    use_fused = bool(trainable) and all(p.is_cuda for p in trainable)
    extra_args = {'fused': True} if use_fused else {}

    return torch.optim.AdamW([{'params': decay_params, 'weight_decay': weight_decay},
                              {'params': nodecay_params, 'weight_decay': 0.0}],
                             lr=lr,
                             betas=betas,
                             eps=eps,
                             **extra_args)

# Build a module-level optimizer for the freshly created model
optimizer: torch.optim.Optimizer = configure_optimizer(model=model)

Number of parameters in the model: 35.800064M


In [6]:
# EntropyWrapper: PyTorch Lightning wrapper for the EntropyModel
class EntropyWrapper(L.LightningModule):
    """LightningModule that trains an ``EntropyModel`` with automatic optimization.

    The wrapper only computes and logs the training loss; the backward pass,
    gradient zeroing, gradient accumulation and the optimizer step are left to
    the Lightning ``Trainer``.
    """

    def __init__(self, config: EntropyConfig, model: EntropyModel) -> None:
        """Store the configuration and the model to be trained.

        Args:
            config: Configuration object kept on the module as ``self.config``.
            model: Model called in ``training_step`` and optimized by the
                optimizer returned from ``configure_optimizers``.
        """
        super().__init__()
        self.config = config
        self.model = model
        # No optimizer is created here: the Trainer calls configure_optimizers()
        # itself, so building one eagerly would only produce an unused duplicate.

    def training_step(self, batch, batch_idx):
        """Compute and log the loss for one batch.

        Args:
            batch: Pair ``(inputs, targets)`` as produced by the dataloader.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss returned by the model, which Lightning backpropagates.
        """
        # Gradients must not be zeroed manually here. Under automatic
        # optimization Lightning zeroes them at the start of each accumulation
        # window; an extra zero_grad() per batch would discard the gradients
        # accumulated from earlier batches when accumulate_grad_batches > 1.
        # Train mode is likewise managed by the Trainer during fit.
        inputs, targets = batch
        # The model returns a pair whose second element is the loss.
        _, loss = self.model(inputs, targets)
        self.log("Train_Loss", loss, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Return the optimizer built by ``configure_optimizer`` for the wrapped model."""
        return configure_optimizer(self.model)

In [7]:
# Build the token dataset (block_size=4096) and the DataLoader that feeds it
# to the trainer: shuffled batches of 6, one worker per CPU reported by
# os.cpu_count(), with batches pinned for transfer to the CUDA device.
dataset = TokenDataset(input_ids=tokens, block_size=4096)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=6,
    shuffle=True,
    num_workers=os.cpu_count(),
    prefetch_factor=2,
    pin_memory=True,
    pin_memory_device='cuda',
)

# Restore the Lightning wrapper from the downloaded checkpoint, passing the
# config and model instances that its constructor expects.
model_wrapper = EntropyWrapper.load_from_checkpoint(
    checkpoint_path="/kaggle/working/entropy_ckpt_4.ckpt",
    config=config,
    model=model,
)

In [8]:
# Single-device CUDA trainer: one epoch, gradients accumulated over
# 8 batches, gradient_clip_val set to 1.0.
trainer = Trainer(
    accelerator="cuda",
    devices=1,
    max_epochs=1,
    accumulate_grad_batches=8,
    gradient_clip_val=1.0,
)

# Run training on the wrapped model with the training dataloader
trainer.fit(model=model_wrapper, train_dataloaders=dataloader)

INFO: 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
2025-09-17 10:02:03.642662: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1758103323.829785      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1758103323.886570      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name  | Type         | Params | M

Training: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
