In [1]:
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5EncoderModel, AutoModelForCausalLM
import torch

### Model Architecture and configuration

In [2]:
model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained('google/byt5-base')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

tokenizer = AutoTokenizer.from_pretrained('google/byt5-base')



In [3]:
print(model)

T5ForConditionalGeneration(
  (shared): Embedding(384, 1536)
  (encoder): T5Stack(
    (embed_tokens): Embedding(384, 1536)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1536, out_features=768, bias=False)
              (k): Linear(in_features=1536, out_features=768, bias=False)
              (v): Linear(in_features=1536, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=1536, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=1536, out_features=3968, bias=False)
              (wi_1): Linear(in_features=1536, out_features=3968, bias=False)
              (

In [4]:
print(model.config)

T5Config {
  "_name_or_path": "google/byt5-base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 3968,
  "d_kv": 64,
  "d_model": 1536,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "gradient_checkpointing": false,
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "num_decoder_layers": 6,
  "num_heads": 12,
  "num_layers": 18,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "tokenizer_class": "ByT5Tokenizer",
  "transformers_version": "4.37.2",
  "use_cache": true,
  "vocab_size": 384
}


In [5]:
print(tokenizer)

ByT5Tokenizer(name_or_path='google/byt5-base', vocab_size=256, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id

In [6]:
print(model.encoder)

T5Stack(
  (embed_tokens): Embedding(384, 1536)
  (block): ModuleList(
    (0): T5Block(
      (layer): ModuleList(
        (0): T5LayerSelfAttention(
          (SelfAttention): T5Attention(
            (q): Linear(in_features=1536, out_features=768, bias=False)
            (k): Linear(in_features=1536, out_features=768, bias=False)
            (v): Linear(in_features=1536, out_features=768, bias=False)
            (o): Linear(in_features=768, out_features=1536, bias=False)
            (relative_attention_bias): Embedding(32, 12)
          )
          (layer_norm): T5LayerNorm()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): T5LayerFF(
          (DenseReluDense): T5DenseGatedActDense(
            (wi_0): Linear(in_features=1536, out_features=3968, bias=False)
            (wi_1): Linear(in_features=1536, out_features=3968, bias=False)
            (wo): Linear(in_features=3968, out_features=1536, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)


In [7]:
model_encoder = T5EncoderModel.from_pretrained('google/byt5-base')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_encoder.to(device)

T5EncoderModel(
  (shared): Embedding(384, 1536)
  (encoder): T5Stack(
    (embed_tokens): Embedding(384, 1536)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1536, out_features=768, bias=False)
              (k): Linear(in_features=1536, out_features=768, bias=False)
              (v): Linear(in_features=1536, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=1536, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=1536, out_features=3968, bias=False)
              (wi_1): Linear(in_features=1536, out_features=3968, bias=False)
              (wo): Linear(

In [8]:
print(model_encoder)

T5EncoderModel(
  (shared): Embedding(384, 1536)
  (encoder): T5Stack(
    (embed_tokens): Embedding(384, 1536)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1536, out_features=768, bias=False)
              (k): Linear(in_features=1536, out_features=768, bias=False)
              (v): Linear(in_features=1536, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=1536, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=1536, out_features=3968, bias=False)
              (wi_1): Linear(in_features=1536, out_features=3968, bias=False)
              (wo): Linear(

In [9]:
print(model_encoder.config)

T5Config {
  "_name_or_path": "google/byt5-base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 3968,
  "d_kv": 64,
  "d_model": 1536,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "gradient_checkpointing": false,
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "num_decoder_layers": 6,
  "num_heads": 12,
  "num_layers": 18,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "tokenizer_class": "ByT5Tokenizer",
  "transformers_version": "4.37.2",
  "use_cache": true,
  "vocab_size": 384
}


### Check forward pass and hidden states

In [10]:
# first option
input_ids_first = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + 3  # add 3 for special tokens

# second option
from transformers import AutoTokenizer

model_inputs = tokenizer(["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt")
input_ids_second = model_inputs.input_ids

In [11]:
print('input_ids_first:', input_ids_first, 'shape:', input_ids_first.shape)
print('input_ids_second:', input_ids_second, 'shape:', input_ids_second.shape)

input_ids_first: tensor([[ 79, 108, 105, 104,  35, 108, 118,  35, 111, 108, 110, 104,  35, 100,
          35, 101, 114, 123,  35, 114, 105,  35, 102, 107, 114, 102, 114, 111,
         100, 119, 104, 118,  49]]) shape: torch.Size([1, 33])
input_ids_second: tensor([[ 79, 108, 105, 104,  35, 108, 118,  35, 111, 108, 110, 104,  35, 100,
          35, 101, 114, 123,  35, 114, 105,  35, 102, 107, 114, 102, 114, 111,
         100, 119, 104, 118,  49,   1],
        [ 87, 114, 103, 100, 124,  35, 108, 118,  35,  80, 114, 113, 103, 100,
         124,  49,   1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0]]) shape: torch.Size([2, 34])


In [13]:
# inference
input_ids = tokenizer("summarize: this dog is brown, friendly and nice", return_tensors="pt").input_ids  # Batch size 1
input_ids = input_ids.to(device)
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))



dogs are brown, fr


In [14]:
# Input sequence of characters
input_text = "This is a sequence of characters."

# Tokenize the input text
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

print('input_ids:', input_ids)
print('input_ids shape:', input_ids.shape)

# Pass the input through the encoder to get the hidden states
with torch.no_grad():  # No need to calculate gradients
    encoder_outputs = model.encoder(input_ids=input_ids)

print('encoder_outputs:', encoder_outputs)

# Extract the last hidden state (shape: [batch_size, sequence_length, hidden_size])
last_hidden_state = encoder_outputs.last_hidden_state

# last_hidden_state can now be used as input to another model
print("Last hidden state:", last_hidden_state)
print("Last hidden state shape:", last_hidden_state.shape)

input_ids: tensor([[ 87, 107, 108, 118,  35, 108, 118,  35, 100,  35, 118, 104, 116, 120,
         104, 113, 102, 104,  35, 114, 105,  35, 102, 107, 100, 117, 100, 102,
         119, 104, 117, 118,  49,   1]], device='cuda:0')
input_ids shape: torch.Size([1, 34])
encoder_outputs: BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[-3.1714e-02, -3.4912e-02,  2.7887e-02,  ..., -2.5939e-02,
          -8.8362e-04, -1.3063e-02],
         [-7.7431e-02,  4.4391e-02, -3.4005e-02,  ..., -6.0181e-03,
          -8.0622e-04, -4.7700e-03],
         [-5.6287e-02,  4.2060e-02, -1.3474e-02,  ..., -6.8940e-03,
          -8.7906e-04,  5.3880e-02],
         ...,
         [ 2.3268e-03, -4.1163e-03,  4.8101e-04,  ..., -1.0362e-02,
          -6.8147e-04, -1.0306e-02],
         [-1.3536e-03, -1.2030e-02, -1.6738e-02,  ...,  4.5562e-02,
          -7.5500e-04,  2.7465e-02],
         [-2.2792e-04, -2.1033e-03, -4.3789e-03,  ...,  3.7608e-04,
          -5.9590e-05, -3.9681e-03]]], device='cuda:

In [17]:
# Pass the input through the encoder to get the hidden states
with torch.no_grad():  # No need to calculate gradients
    encoder_outputs_encoder_model = model_encoder(input_ids=input_ids)

print('encoder_outputs:', encoder_outputs_encoder_model)
print('encoder outputs keys:', ' '.join(encoder_outputs_encoder_model.keys()))

# Extract the last hidden state (shape: [batch_size, sequence_length, hidden_size])
last_hidden_state_encoder_model = encoder_outputs_encoder_model.last_hidden_state

# last_hidden_state can now be used as input to another model
print("Last hidden state:", last_hidden_state_encoder_model)
print("Last hidden state shape:", last_hidden_state_encoder_model.shape)

encoder_outputs: BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[-3.1714e-02, -3.4912e-02,  2.7887e-02,  ..., -2.5939e-02,
          -8.8362e-04, -1.3063e-02],
         [-7.7431e-02,  4.4391e-02, -3.4005e-02,  ..., -6.0181e-03,
          -8.0622e-04, -4.7700e-03],
         [-5.6287e-02,  4.2060e-02, -1.3474e-02,  ..., -6.8940e-03,
          -8.7906e-04,  5.3880e-02],
         ...,
         [ 2.3268e-03, -4.1163e-03,  4.8101e-04,  ..., -1.0362e-02,
          -6.8147e-04, -1.0306e-02],
         [-1.3536e-03, -1.2030e-02, -1.6738e-02,  ...,  4.5562e-02,
          -7.5500e-04,  2.7465e-02],
         [-2.2792e-04, -2.1033e-03, -4.3789e-03,  ...,  3.7608e-04,
          -5.9590e-05, -3.9681e-03]]], device='cuda:0'), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None)
encoder outputs keys: last_hidden_state
Last hidden state: tensor([[[-3.1714e-02, -3.4912e-02,  2.7887e-02,  ..., -2.5939e-02,
          -8.8362e-04, -1.3063e-02],
         [-7.

In [18]:
# check if the hidden states are the same
torch.allclose(last_hidden_state, last_hidden_state_encoder_model, atol=1e-6)

True

## Use encoder + LM head for token classification

In [7]:
import torch
import torch.nn as nn
from transformers import T5EncoderModel, AutoTokenizer
from torch.utils.data import DataLoader, Dataset

In [8]:
class TextDataset(Dataset):
    def __init__(self, file_path, tokenizer, block_size):
        """
        file_path: Path to the text file (e.g., 'train.txt')
        tokenizer: The ByT5 tokenizer
        block_size: The length of the input sequences
        """
        self.tokenizer = tokenizer
        self.block_size = block_size

        with open(file_path, 'r', encoding="utf8") as f:
            text = f.read()

        tokens = self.tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.squeeze(0)
        self.data = tokens

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        # Input sequence
        x = self.data[idx:idx + self.block_size]
        # Target is the next character in the sequence
        y = self.data[idx + self.block_size]
        return x, y

In [9]:
class ByT5ForNextCharPrediction(nn.Module):
    def __init__(self):
        super(ByT5ForNextCharPrediction, self).__init__()
        self.encoder = T5EncoderModel.from_pretrained('google/byt5-base')
        self.head = nn.Linear(self.encoder.config.d_model, self.encoder.config.vocab_size)  # Linear layer on top

    def forward(self, input_ids, attention_mask=None):
        encoder_outputs = self.encoder(input_ids, attention_mask=attention_mask)
        sequence_output = encoder_outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]
        # We only care about the output of the last token in the sequence
        last_hidden_state = sequence_output[:, -1, :]  # [batch_size, hidden_dim]
        logits = self.head(last_hidden_state)  # [batch_size, vocab_size]
        return logits

In [10]:
def train(model_, dataloader, optimizer_, criterion_, device_, epoch):
    model_.train()
    
    total_loss = 0
    total_correct = 0
    total_predictions = 0

    for batch_idx, (input_ids, targets) in enumerate(dataloader):
        input_ids = input_ids.to(device_)
        targets = targets.to(device_)

        optimizer_.zero_grad()
        logits = model_(input_ids)

        # Compute the loss
        loss = criterion_(logits, targets)
        loss.backward()
        optimizer_.step()

        total_loss += loss.item()
        
        # Compute the accuracy
        predictions = torch.argmax(logits, dim=-1)
        correct = (predictions == targets).sum().item()
        total_correct += correct
        total_predictions += len(targets.view(-1))
        
        print(total_correct)
        print(total_predictions)
        

        # Logging
        if (batch_idx + 1) % 10 == 0:
            avg_loss = total_loss / (batch_idx + 1)
            print(f'Epoch [{epoch}], Step [{batch_idx +1}/{len(dataloader)}], Loss: {avg_loss:.4f}, Accuracy: {total_correct / total_predictions:.4f}')

    avg_epoch_loss = total_loss / len(dataloader)
    print(f'====> Epoch: {epoch} Average loss: {avg_epoch_loss:.4f}')
    return avg_epoch_loss

In [11]:
def evaluate(model_, dataloader, criterion_, device_):
    model_.eval()

    total_loss = 0
    total_correct = 0
    total_predictions = 0

    with torch.no_grad():
        for input_ids, targets in dataloader:

            # Move data to device        
            input_ids = input_ids.to(device_)
            targets = targets.to(device_)

            # Forward pass
            outputs = model_(input_ids)

            loss = criterion_(outputs.logits.view(-1, outputs.logits.size(-1)), targets.view(-1))
            total_loss += loss.item()

            # Accuracy
            _, predictions = torch.max(outputs.logits, -1)
            total_correct += (predictions == targets).sum().item()
            total_predictions += len(targets.view(-1))

    avg_loss = total_loss / len(dataloader)
    print(f'====> Evaluation Average loss: {avg_loss:.4f}, Accuracy: {total_correct / total_predictions:.4f}')
    return avg_loss

In [12]:
tokenizer = AutoTokenizer.from_pretrained('google/byt5-base')

block_size = 128
batch_size = 8
learning_rate = 0.001
epochs = 10

# Load datasets
train_dataset = TextDataset('./data/sentences/train.txt', tokenizer, block_size)
valid_dataset = TextDataset('./data/sentences/valid.txt', tokenizer, block_size)
test_dataset  = TextDataset('./data/sentences/test.txt', tokenizer, block_size)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize model
model = ByT5ForNextCharPrediction()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Training loop
best_valid_loss = float('inf')
for epoch in range(1, epochs + 1):
    print(f"\nEpoch {epoch}/{epochs}")
    train_loss = train(model, train_loader, optimizer, criterion, device, epoch)
    valid_loss = evaluate(model, valid_loader, criterion, device)
    print(f"Train Loss: {train_loss:.4f}, Validation Loss: {valid_loss:.4f}")
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'byt5_next_char_prediction.pt')

# Test the model
test_loss = evaluate(model, test_loader, criterion, device)
print(f"\nTest Loss: {test_loss:.4f}")


Epoch 1/10
0
8
0
16
1
24
1
32
1
40
1
48


KeyboardInterrupt: 