# BitNetTransformer

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn
from torchinfo import summary
from bitnet import BitNetTransformer

from app.utils import clear_memory

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

### Bertimbau Tokenizer & Embeddings

In [7]:
from transformers import pipeline
from transformers import AutoTokenizer, AutoModel

huggingface_model = "neuralmind/bert-base-portuguese-cased"
tokenizer = AutoTokenizer.from_pretrained(huggingface_model)
model = AutoModel.from_pretrained(huggingface_model)
feature_extractor = pipeline('feature-extraction', model=model, tokenizer=tokenizer)
embeddings = model.get_input_embeddings()

In [8]:
print(tokenizer.vocab_size)
print(embeddings.embedding_dim)

29794
768


Comparando diferentes maneiras de transformar texto para embbedings

In [9]:
sentence = 'O advogado apresentou recurso para o juíz'

# a. contextualized embeddings c/ feature_extractor: texto tokenizado automaticamente, passa para embedding, encoder e attention layers antes
sentence_embeddings1 = feature_extractor(sentence)

with torch.inference_mode():
    sentence_tokenized = tokenizer(sentence, return_tensors="pt")
    sentence_input_ids = sentence_tokenized["input_ids"]
    sentence_attention_mask = sentence_tokenized["attention_mask"]

    # b. non-contextualized embeddings com embedding layer do bertimbau: usa embbeding layer 
    # antes de passar na rede transformer, ou seja, sem passar pelo encoder e attention layers
    sentence_embeddings2 = embeddings(sentence_input_ids)

    # c. a mesma coisa da primeira abordagem só que mais explicita, sem as facilidades do pipeline
    sentence_embeddings3 = model(input_ids=sentence_input_ids, 
                                attention_mask=sentence_attention_mask
                                ).last_hidden_state
    

In [10]:
print("Batch sizes:")
print(len(sentence_embeddings1))
print(len(sentence_embeddings2))
print(len(sentence_embeddings3))

print("Sequence Lengths:")
print(len(sentence_embeddings1[0]))
print(len(sentence_embeddings2[0]))
print(len(sentence_embeddings3[0]))

print("BERT Hidden Size:")
print(len(sentence_embeddings1[0][0]))
print(len(sentence_embeddings2[0][0]))
print(len(sentence_embeddings3[0][0]))

print("Embeddings Types:")
print(type(sentence_embeddings1[0][0]))
print(type(sentence_embeddings2[0][0]))
print(type(sentence_embeddings3[0][0]))

print("Embeddings Values:")
print(torch.tensor(sentence_embeddings1[0][0][:5]))
print(sentence_embeddings2[0][0][:5])
print(sentence_embeddings3[0][0][:5])


Batch sizes:
1
1
1
Sequence Lengths:
10
10
10
BERT Hidden Size:
768
768
768
Embeddings Types:
<class 'list'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
Embeddings Values:
tensor([ 0.6715, -0.3474, -0.1421, -0.0567,  0.8103])
tensor([-0.0004,  0.0014,  0.0251,  0.0272, -0.0022])
tensor([ 0.6715, -0.3474, -0.1421, -0.0567,  0.8103])


Vamos testar os modelos na sentença 'Olá Mundo!', mas antes temos que preparar o input corretamente para os modelos.

In [11]:
texto = 'Olá Mundo!'

Usando tamanho do vocabulário igual ao do tokenizer ou o tamanho dos embeddings?

In [12]:
print(tokenizer.vocab_size)
print(embeddings.embedding_dim)

29794
768


In [13]:
VOCAB_SIZE = 29794
D_MODEL = 512

Convertendo o texto p/ token_ids (entradas das 2 redes)

In [14]:
input_ids = tokenizer(texto, 
                      return_tensors="pt"
                      ).input_ids

input_ids = input_ids.to(device)
print(input_ids.shape)
print(input_ids[0,:5])

torch.Size([1, 6])
tensor([  101,  1651, 22303,  3327,   106], device='cuda:0')


### Wikipedia Dataset and Dataloader

In [15]:
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from datasets import load_dataset

In [16]:
# Load the dataset
train_dataset = load_dataset("Luciano/lener_br_text_to_lm", split="train[:10%]")  # Use 100% for a smaller dataset
test_dataset = load_dataset("Luciano/lener_br_text_to_lm", split="test[:10%]")  # Use 100% for a smaller dataset

print("Número de documentos no dataset de treino:", len(train_dataset))
print("Número de documentos no dataset de teste:", len(test_dataset))

Analisando quantidade de tokens por sequência

In [None]:
from collections import Counter
import pandas as pd

counter = Counter()

def input_ids_extractor(examples):
    tokens = tokenizer(examples['text'], truncation=False, return_tensors="pt")
    return {"input_ids": tokens["input_ids"]}

df = pd.DataFrame(train_dataset)
n_tokens = []
for x in train_dataset.map(input_ids_extractor, remove_columns=train_dataset.column_names)["input_ids"]:
    n_tokens.append(len(x[0]))
    counter.update(x[0])
    
df["n_tokens"] = n_tokens
df

Unnamed: 0,text,n_tokens
0,Seria o mesmo que dizer que o trabalhador tem ...,83
1,O autor sustenta que a lei é formal e material...,226
2,Esse juízo decorre do fato de que o exame de c...,73
3,"Apesar , de o próprio responsável apresentar c...",24
4,"Quando de sua assunção à direção do STM , esta...",67
...,...,...
8311,"No ponto , convém salientar que o Supremo Trib...",154
8312,"Em relação ao subitem 9.2.1 , o GAP/BR informa...",53
8313,4 .,4
8314,O agravante limitou-se a reprisar os argumento...,42


In [None]:
df["n_tokens"].describe()

count    8316.000000
mean       48.716450
std        53.066798
min         2.000000
25%        15.000000
50%        35.000000
75%        64.000000
max      1009.000000
Name: n_tokens, dtype: float64

Listando tokens mais frequentes

In [None]:
counter.most_common(10)

[(117, 18875),
 (119, 17256),
 (125, 10438),
 (101, 8316),
 (102, 8316),
 (123, 6859),
 (171, 5852),
 (180, 5018),
 (118, 4854),
 (146, 4847)]

In [14]:
def preprocess_function(examples):
    tokens = tokenizer(examples['text'], truncation=True, padding="max_length", max_length=D_MODEL, return_tensors="pt")
    return {
        "input_ids": tokens["input_ids"],
        "attention_mask": tokens["attention_mask"]
    }

train_tokenized_dataset = train_dataset.map(preprocess_function, 
                                            batched=True, 
                                            remove_columns=train_dataset.column_names)
test_tokenized_dataset = test_dataset.map(preprocess_function, 
                                          batched=True, 
                                          remove_columns=train_dataset.column_names)

In [None]:
BATCH_SIZE = 8

In [15]:
import os

NUM_WORKERS = os.cpu_count()

# TODO verificar se realmente precisa disso
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Create DataLoader
train_dataloader = DataLoader(train_tokenized_dataset, 
                              batch_size=BATCH_SIZE, 
                              shuffle=True, 
                              collate_fn=data_collator
                              num_workers=NUM_WORKERS)

test_dataloader = DataLoader(test_tokenized_dataset, 
                             batch_size=BATCH_SIZE, 
                             shuffle=True, 
                             collate_fn=data_collator,
                             num_workers=NUM_WORKERS)

In [16]:

print("Número de batches no conjunto de treinamento:", len(train_dataloader))
print("Número de batches no conjunto de testes:", len(test_dataloader))

Número de batches no conjunto de treinamento: 8316
Número de batches no conjunto de testes: 2079


In [17]:
sample_batch = next(iter(train_dataloader))
print(sample_batch.keys())
print(len(sample_batch["input_ids"][0]))
print(len(sample_batch["attention_mask"][0]))

dict_keys(['input_ids', 'attention_mask'])
768
768


### bitnet.BitnetTransformer

Come será visto no resultado do summary, o BitNetTransformet já tem uma camada de embeddings incorporada na arquitetura do modelo.

In [18]:
bitnet = BitNetTransformer(
    num_tokens=VOCAB_SIZE,  # Number of unique tokens in the input
    dim=D_MODEL,  # Dimension of the input and output embeddings
    depth=6,  # Number of transformer layers
    # heads=8,  # Number of attention heads
    # ff_mult=4,  # Multiplier for the hidden dimension in the feed-forward network
).to(device)

summary(model= bitnet,
        col_names=["num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                       Param #              Trainable
BitNetTransformer (BitNetTransformer)         --                   True
├─Embedding (emb)                             22,881,792           True
├─Transformer (transformer)                   --                   True
│    └─ModuleList (layers)                    --                   True
│    │    └─BitMGQA (0)                       1,477,632            True
│    │    └─BitMGQA (1)                       1,477,632            True
│    │    └─BitMGQA (2)                       1,477,632            True
│    │    └─BitMGQA (3)                       1,477,632            True
│    │    └─BitMGQA (4)                       1,477,632            True
│    │    └─BitMGQA (5)                       1,477,632            True
│    └─ModuleList (ffn_layers)                --                   True
│    │    └─BitFeedForward (0)                4,728,576            True
│    │    └─BitFeedForward (1)                4,728,576    

In [19]:
clear_memory()

In [20]:

bitnet.eval()
with torch.inference_mode():
    bitnet_logits = bitnet(input_ids)

print(bitnet_logits.shape)
print(bitnet_logits[0,:5])

torch.Size([1, 6, 29794])
tensor([[-0.6329, -0.3154, -0.3709,  ...,  0.3420, -0.8325,  1.3835],
        [-0.8646,  0.9436, -0.8437,  ..., -0.0316,  0.7270, -0.6726],
        [ 0.0547, -0.7013, -1.6044,  ...,  0.8291,  0.0767,  0.8498],
        [-0.6529, -0.3598,  0.0294,  ...,  0.3700,  0.9106, -0.1380],
        [-0.2384, -0.4411,  0.1621,  ..., -0.1336, -0.2379, -0.4866]],
       device='cuda:0')


### torch.nn.Transformer

In [21]:
# estou usando a implementaçao da bitnet pq no torch '2.2.0+cu121' não tem implementado ainda
from bitnet.bit_transformer import RMSNorm

class BaseTransformerModel(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        num_tokens: int,
        heads=8,
        ff_mult=4,
    ):
        super().__init__()
        self.emb = nn.Embedding(num_tokens, dim)
        
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=dim, nhead=heads),
            num_layers=1
        )

        self.to_logits = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(self, x):
        # embeddings sem positional encoding somados
        x = self.emb(x)
        x = self.decoder(tgt=x, memory=x)
        return self.to_logits(x)

In [22]:
baseline = BaseTransformerModel(
    num_tokens=VOCAB_SIZE,  # Number of unique tokens in the input
    dim=D_MODEL,  # Dimension of the input and output embeddings
    depth=6,  # Number of transformer layers
    heads=8,  # Number of attention heads
    # ff_mult=4,  # Multiplier for the hidden dimension in the feed-forward network
    ).to(device)

# c/ parâmetro input_size aloca 200MB de memória e não desaloca depois
summary(model= baseline,
        #input_size=(src_size, tgt_size),
        #col_names=["input_size", "output_size", "num_params", "trainable"],
        col_names=["num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
        )

Layer (type (var_name))                                                Param #              Trainable
BaseTransformerModel (BaseTransformerModel)                            --                   True
├─Embedding (emb)                                                      22,881,792           True
├─TransformerDecoder (decoder)                                         --                   True
│    └─ModuleList (layers)                                             --                   True
│    │    └─TransformerDecoderLayer (0)                                7,877,888            True
├─Sequential (to_logits)                                               --                   True
│    └─RMSNorm (0)                                                     768                  True
│    └─Linear (1)                                                      22,911,586           True
Total params: 53,672,034
Trainable params: 53,672,034
Non-trainable params: 0

In [23]:
clear_memory()

In [24]:
baseline.eval()
with torch.inference_mode():
    base_logits = baseline(input_ids)

print(base_logits.shape)
print(base_logits[0,:5])

torch.Size([1, 6, 29794])
tensor([[ 0.4504,  0.6340, -0.4628,  ...,  0.4732,  0.2900,  0.2842],
        [-0.1217, -0.6543, -0.0230,  ..., -0.1015, -0.2650,  0.7520],
        [-0.1499,  0.5043,  0.7887,  ..., -0.4949, -1.0004, -0.7903],
        [ 0.0235,  0.2744, -0.2596,  ..., -0.1396, -0.4983,  0.3388],
        [ 0.4763,  0.4011,  0.1525,  ..., -0.7573, -0.1147,  0.2858]],
       device='cuda:0')


In [25]:
clear_memory()

### Training Loop

In [26]:

# Set up optimizer and learning rate scheduler

num_epochs = 1
num_batches = len(train_dataloader)
num_training_steps = num_epochs * len(train_dataloader)

def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          criteria: torch.nn.functional.cross_entropy,
          optimizer: torch.optim.Optimizer,
          epochs: int,
          device: str = device):
    
    results = {"train_loss": [], "test_loss": []}

    for epoch in range(epochs):

        epoch_train_loss = 0
        epoch_test_loss = 0

        # Training loop
        model.train()
        for batch, data in enumerate(train_dataloader):
            optimizer.zero_grad()
            # pode pegar as attention masks também
            input_ids = data["input_ids"].to(device)

            # Forward pass
            outputs = model(input_ids)  

            # Calculate the loss
            loss = criteria(outputs.view(-1, VOCAB_SIZE), input_ids.view(-1))
            epoch_train_loss += loss.item()
            
            # Backward pass
            loss.backward()

            # Update weights
            optimizer.step()  

            if batch % 100 == 0:
                # printando número de batches
                print(f"Batch: {batch}/{num_batches}, Loss: {loss.item():.4f}")

        # Test loop
        model.eval()
        with torch.inference_mode():
            for test_batch, test_data in enumerate(test_dataloader):
                test_input_ids = test_data["input_ids"].to(device)

                test_outputs = model(test_input_ids)
                test_loss = criteria(test_outputs.view(-1, VOCAB_SIZE), test_input_ids.view(-1))
                epoch_test_loss += test_loss.item()
                
                if batch % 100 == 0:
                    # printando número de batches de teste
                    print(f"Batch: {test_batch}/{len(test_dataloader)}, Test Loss: {test_loss.item():.4f}")

        # printando loss após o final de uma época de treinamento        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

        results["train_loss"].append(epoch_train_loss)
        results["test_loss"].append(epoch_test_loss)

    return results
    
    

Batch 0/8316, Loss: 10.5015
Batch 100/8316, Loss: 4.7975
Batch 200/8316, Loss: 0.9008
Batch 300/8316, Loss: 0.2838
Batch 400/8316, Loss: 0.1509
Batch 500/8316, Loss: 0.0516
Batch 600/8316, Loss: 0.5582
Batch 700/8316, Loss: 0.7851
Batch 800/8316, Loss: 0.3335
Batch 900/8316, Loss: 0.1326
Batch 1000/8316, Loss: 0.7498
Batch 1100/8316, Loss: 0.1131
Batch 1200/8316, Loss: 0.0802
Batch 1300/8316, Loss: 0.0719
Batch 1400/8316, Loss: 0.0393
Batch 1500/8316, Loss: 0.0505
Batch 1600/8316, Loss: 0.3757
Batch 1700/8316, Loss: 0.2915
Batch 1800/8316, Loss: 0.4189
Batch 1900/8316, Loss: 0.0825
Batch 2000/8316, Loss: 0.3687
Batch 2100/8316, Loss: 0.2147
Batch 2200/8316, Loss: 0.0406
Batch 2300/8316, Loss: 0.1786
Batch 2400/8316, Loss: 0.0612
Batch 2500/8316, Loss: 0.1522
Batch 2600/8316, Loss: 0.0802
Batch 2700/8316, Loss: 0.1176
Batch 2800/8316, Loss: 0.0330


KeyboardInterrupt: 

In [None]:
EPOCHS = 1
DEFAULT_LOSS = torch.nn.functional.cross_entropy

In [None]:

optimizer = torch.optim.AdamW(baseline.parameters(), lr=5e-5)

baseline_results = train(model = baseline,
                         train_dataloader=train_dataloader,
                         test_dataloader=test_dataloader,
                         criteria=DEFAULT_LOSS,
                         optimizer=optimizer,
                         epochs = EPOCHS,
                         device = device)

In [None]:
optimizer = torch.optim.AdamW(bitnet.parameters(), lr=5e-5)

bitnet_results = train(model = bitnet,
                       train_dataloader = train_dataloader,
                       test_dataloader = test_dataloader,
                       criteria = DEFAULT_LOSS,
                       optimizer = optimizer,
                       epochs = EPOCHS,
                       device = device)

In [27]:
# Save the model
torch.save(bitnet, "./models/bitnet_transformer.pt")

### Geração de Texto

In [28]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer

def generate_text(model, tokenizer, prompt, max_length=50, temperature=1.0):
    """
    Generates text using an autoregressive approach with a Transformer model.
    
    Args:
        model (nn.Module): The autoregressive Transformer model.
        tokenizer (AutoTokenizer): The tokenizer for encoding and decoding text.
        prompt (str): The initial text prompt to start the generation.
        max_length (int): Maximum length of the generated text.
        temperature (float): Temperature value for sampling (higher = more randomness).
    
    Returns:
        str: The generated text.
    """
    # Encode the input prompt into token IDs
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    # Move input_ids to the same device as the model (CPU/GPU)
    input_ids = input_ids.to(next(model.parameters()).device)

    # Initialize a list to store generated tokens
    generated_ids = input_ids.clone()

    # Set the model to evaluation mode
    model.eval()

    # Use a loop to generate tokens one by one
    with torch.no_grad():
        for _ in range(max_length):
            # Get the model's logits for the next token
            output = model(generated_ids)

            logits = output[:, -1, :]  # Take the logits of the last token

            # Apply temperature scaling to logits if temperature is specified
            logits = logits / temperature

            # Use softmax to get probabilities and sample the next token id
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

            # Append the predicted token id to the generated_ids
            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

            # If the model generates the EOS (end-of-sequence) token, stop early
            if next_token_id.item() == tokenizer.eos_token_id:
                break

    # Decode the generated token IDs back into a string
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    return generated_text


In [29]:
# Example prompt and text generation
prompt = "Tinha uma pedra no meio do caminho"

bitnet_generated_text = generate_text(model=bitnet, 
                               tokenizer=tokenizer, 
                               prompt=prompt, 
                               max_length=20, 
                               temperature=0.8)


baseline_generated_text = generate_text(model=baseline, 
                               tokenizer=tokenizer, 
                               prompt=prompt, 
                               max_length=20, 
                               temperature=0.8)


print(bitnet_generated_text)
print(baseline_generated_text)

Tinha uma pedra no meio do caminho
Tinha uma pedra no meio do caminho Gent袋 divindades razões Daí branc Superliga Sinfônica obs gravadora周 sódio começou vídeos contratar Foram nestas defendem큔 juntou
