In [1]:
#!pip install datasets --upgrade #huggingface_hub --quiet

In [2]:
import os
from datasets import load_dataset, DatasetDict
from huggingface_hub import notebook_login
from transformers import AutoTokenizer, PreTrainedTokenizerFast
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import math  

In [3]:
# Load a tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [4]:
eos = tokenizer.eos_token

In [5]:
if tokenizer.mask_token is None:
    tokenizer.add_special_tokens({'mask_token': '[MASK]'})

In [6]:
def mytokenizer(dataset_split, tokenizer, chunk_size=1024):
    '''megatext = ""
    for item in tqdm(dataset_split, desc="Concatenating texts"):
        megatext += item['text'] + tokenizer.eos_token

    # Tokenize all at once (no special tokens)
    tokenized = tokenizer(megatext, return_tensors=None, add_special_tokens=False)
    input_ids = tokenized['input_ids']'''

    # Collect all token IDs into a flat list
    input_ids = []

    for item in tqdm(dataset_split, desc="Tokenizing and concatenating"):
        # Tokenize individual text without adding special tokens
        tokenized = tokenizer(item["text"], add_special_tokens=False)["input_ids"]

        # Append the tokenized text followed by the EOS token
        input_ids.extend(tokenized + [tokenizer.eos_token_id])

    # Chunk and pad
    chunks = []
    for i in range(0, len(input_ids), chunk_size):
        chunk = input_ids[i:i+chunk_size]
        if len(chunk) < chunk_size:
            chunk += [0] * (chunk_size - len(chunk))
        chunks.append(chunk)  # Just list of input_ids now

    return torch.tensor(chunks, dtype=torch.long)  # shape: (num_chunks, chunk_size)

In [7]:
tiny_wikipedia = load_dataset('joaomonteirof/biss_tiny_wikipedia')

In [8]:
tokenized_tiny_wikipedia = {
    split: mytokenizer(tiny_wikipedia[split], tokenizer)
    for split in tiny_wikipedia
}

Tokenizing and concatenating:   0%|          | 0/10000 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1342 > 1024). Running this sequence through the model will result in indexing errors
Tokenizing and concatenating: 100%|██████████| 10000/10000 [00:09<00:00, 1084.86it/s]
Tokenizing and concatenating: 100%|██████████| 1000/1000 [00:00<00:00, 1028.64it/s]
Tokenizing and concatenating: 100%|██████████| 1000/1000 [00:00<00:00, 1052.86it/s]


In [9]:
class BasicSelfAttn(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, X, mask=None):
        batch_size, seq_len, _ = X.shape

        Q = self.q_proj(X)
        K = self.k_proj(X)
        V = self.v_proj(X)

        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attention_weights = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            attention_weights = attention_weights.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(attention_weights, dim=-1)
        output = torch.matmul(attention_weights, V)

        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out_proj(output)

In [10]:
def generate_attention_mask(seq_len, causal=False, device=None):
    if causal:
        mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).unsqueeze(0).unsqueeze(0)
    else:
        mask = torch.ones(1, 1, seq_len, seq_len, device=device)
    return mask  # shape: (1, 1, seq_len, seq_len)

In [11]:
def mask_input_tokens(inputs, tokenizer, mlm_probability=0.15):
    labels = inputs.clone()
    probability_matrix = torch.full(labels.shape, mlm_probability)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% -> [MASK]
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% -> random token
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(
        len(tokenizer), labels.shape, dtype=torch.long, device=inputs.device
    )
    inputs[indices_random] = random_words[indices_random]

    # 10% remain unchanged (implicitly)

    return inputs, labels


In [12]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, is_decoder=False):
        super().__init__()
        self.self_attn = BasicSelfAttn(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )
        self.norm3 = nn.LayerNorm(embed_dim)
        self.is_decoder = is_decoder
        if is_decoder:
            self.cross_attn = BasicSelfAttn(embed_dim, num_heads)

    def forward(self, X, context=None):
        seq_len = X.size(1)
        device = X.device
        mask = generate_attention_mask(seq_len, causal=self.is_decoder, device=device)

        X = X + self.self_attn(self.norm1(X), mask=mask)

        if self.is_decoder and context is not None:
            # Cross-attention for decoder
            context_len = context.size(1)
            cross_mask = generate_attention_mask(context_len, causal=False, device=device)
            X = X + self.cross_attn(self.norm2(X), mask=cross_mask)

        X = X + self.ff(self.norm3(X))
        return X

In [13]:
class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, embed_dim, num_heads):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, is_decoder=False)
            for _ in range(num_layers)
        ])

    def forward(self, X):
        for layer in self.layers:
            X = layer(X)
        return X

class TransformerDecoder(nn.Module):
    def __init__(self, num_layers, embed_dim, num_heads):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, is_decoder=True)
            for _ in range(num_layers)
        ])

    def forward(self, X):
        for layer in self.layers:
            X = layer(X)
        return X


In [14]:
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=10000):
        super().__init__()
        self.embed_dim = embed_dim

        position = torch.arange(0, max_len).unsqueeze(1)           # shape: (max_len, 1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-torch.log(torch.tensor(10000.0)) / embed_dim))

        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)  # even index
        pe[:, 1::2] = torch.cos(position * div_term)  # odd index

        pe = pe.unsqueeze(0)  # shape: (1, max_len, embed_dim)
        self.register_buffer('pe', pe)

    def forward(self, X):
        # X shape: (batch_size, seq_len, embed_dim)
        seq_len = X.size(1)
        return X + self.pe[:, :seq_len, :]

In [15]:
class Trainer:
    def __init__(self, model, tokenizer, dataloader, optimizer, device, task):
        self.model = model
        self.tokenizer = tokenizer
        self.dataloader = dataloader
        self.optimizer = optimizer
        self.device = device
        self.task = task  # 'next_token' or 'mlm'
        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        for batch in self.dataloader:
            batch = batch.to(self.device)  # shape: (B, T)

            if self.task == 'next_token':
                inputs = batch[:, :-1]
                targets = batch[:, 1:]
            elif self.task == 'mlm':
                inputs, targets = mask_input_tokens(batch.clone(), self.tokenizer)
            else:
                raise ValueError("Invalid task type")

            outputs = self.model(inputs)
            logits = outputs.view(-1, outputs.size(-1))
            targets = targets.reshape(-1)

            loss = self.criterion(logits, targets)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / len(self.dataloader)

    def evaluate(self, dataloader):
        self.model.eval()
        total_loss = 0
        total_correct = 0
        total_count = 0

        with torch.no_grad():
            for batch in dataloader:
                batch = batch.to(self.device)

                if self.task == 'next_token':
                    inputs  = batch[:, :-1]
                    targets = batch[:, 1:]
                else:  # 'mlm'
                    inputs, targets = mask_input_tokens(batch.clone(), self.tokenizer)

                outputs = self.model(inputs)
                logits  = outputs.view(-1, outputs.size(-1))
                targets = targets.reshape(-1)

                loss = self.criterion(logits, targets)
                total_loss += loss.item()

                # Compute accuracy: only where targets != -100
                preds = logits.argmax(dim=-1)
                mask = (targets != -100)
                correct = (preds[mask] == targets[mask]).sum().item()
                count = mask.sum().item()
                total_correct += correct
                total_count += count

        avg_loss = total_loss / len(dataloader)
        #perplexity = math.exp(avg_loss)
        accuracy = total_correct / total_count if total_count > 0 else 0.0
        return avg_loss, accuracy

In [16]:
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, is_decoder=False):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = SinusoidalPositionalEncoding(embed_dim)
        if is_decoder:
            self.transformer = TransformerDecoder(num_layers, embed_dim, num_heads)
        else:
            self.transformer = TransformerEncoder(num_layers, embed_dim, num_heads)
        self.output_projection = nn.Linear(embed_dim, vocab_size)
        
    def forward(self, input_ids, context_ids=None):
        x = self.token_embedding(input_ids)           # (batch_size, seq_len, embed_dim)
        x = self.pos_encoding(x)
        if context_ids is not None:
            ctxt = self.token_embedding(context_ids)
            ctxt = self.pos_encoding(ctxt)
            x = self.transformer(x, ctxt)
        else:
            x = self.transformer(x)
        logits = self.output_projection(x)            # (batch_size, seq_len, vocab_size)
        return logits


In [17]:
class ChunkDataset(Dataset):
    def __init__(self, chunks_tensor: torch.Tensor):
        """
        chunks_tensor: torch.LongTensor of shape (num_chunks, chunk_size)
        """
        self.chunks = chunks_tensor

    def __len__(self):
        return self.chunks.size(0)

    def __getitem__(self, idx):
        return self.chunks[idx]


In [18]:
# 2. Create DataLoaders for train / validation / test splits
# Assume `tokenized_tiny_wikipedia` is already in memory as a dict:
# {
#   "train":    torch.LongTensor of shape (num_train_chunks, chunk_size),
#   "validation": torch.LongTensor of shape (num_val_chunks,   chunk_size),
#   "test":     torch.LongTensor of shape (num_test_chunks,  chunk_size)
# }

train_dataset = ChunkDataset(tokenized_tiny_wikipedia["train"])
val_dataset   = ChunkDataset(tokenized_tiny_wikipedia["validation"])
test_dataset  = ChunkDataset(tokenized_tiny_wikipedia["test"])

BATCH_SIZE = 8

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False)

In [19]:
# 3. Instantiate two TransformerModel variants:
#    - One for next-token (decoder-only, causal)
#    - One for MLM (encoder-only, bidirectional)

import torch.nn as nn

VOCAB_SIZE  = len(tokenizer)
EMBED_DIM   = 512
NUM_HEADS   = 8
NUM_LAYERS  = 6

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3a. Next-token model (decoder-only)
model_nt = TransformerModel(
    vocab_size   = VOCAB_SIZE,
    embed_dim    = EMBED_DIM,
    num_heads    = NUM_HEADS,
    num_layers   = NUM_LAYERS,
    is_decoder   = True
).to(device)

# 3b. MLM model (encoder-only)
model_mlm = TransformerModel(
    vocab_size   = VOCAB_SIZE,
    embed_dim    = EMBED_DIM,
    num_heads    = NUM_HEADS,
    num_layers   = NUM_LAYERS,
    is_decoder   = False
).to(device)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [20]:
# 4. Create one optimizer per model
optimizer_nt  = torch.optim.Adam(model_nt.parameters(),  lr=5e-5)
optimizer_mlm = torch.optim.Adam(model_mlm.parameters(), lr=5e-5)

In [21]:
NUM_EPOCHS = 5

In [22]:
# --- Next-Token Training ---
trainer_nt = Trainer(
    model     = model_nt,
    tokenizer = tokenizer,
    dataloader= train_loader,
    optimizer = optimizer_nt,
    device    = device,
    task      = "next_token"
)

print("[Next-Token prediction]")

for epoch in range(NUM_EPOCHS):
    train_loss = trainer_nt.train_epoch()
    val_loss, val_acc = trainer_nt.evaluate(val_loader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} — train loss: {train_loss:.4f}, "
          f"validation loss: {val_loss:.4f}, validation accuracy: {val_acc:.2f}")

# Evaluate final test perplexity
test_loss_nt, test_acc_nt = trainer_nt.evaluate(test_loader)
print(f"Test loss: {test_loss_nt:.4f}, test accuracy: {test_acc_nt:.2f}")


[Next-Token prediction]


Epoch 1/5 — train loss: 7.2834, validation loss: 6.8241, validation accuracy: 0.14
Epoch 2/5 — train loss: 6.4749, validation loss: 6.4510, validation accuracy: 0.16
Epoch 3/5 — train loss: 6.1325, validation loss: 6.2431, validation accuracy: 0.17
Epoch 4/5 — train loss: 5.9005, validation loss: 6.0946, validation accuracy: 0.18
Epoch 5/5 — train loss: 5.7158, validation loss: 5.9867, validation accuracy: 0.19
Test loss: 5.9705, test accuracy: 0.19


In [23]:
# --- MLM Training ---
trainer_mlm = Trainer(
    model     = model_mlm,
    tokenizer = tokenizer,
    dataloader= train_loader,
    optimizer = optimizer_mlm,
    device    = device,
    task      = "mlm"
)

print("[Masked Language Model]")

for epoch in range(NUM_EPOCHS):
    train_loss = trainer_mlm.train_epoch()
    val_loss, val_acc = trainer_mlm.evaluate(val_loader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} — train loss: {train_loss:.4f}, "
          f"validation loss: {val_loss:.4f}, validation accuracy: {val_acc:.4f}")

# Evaluate final test MLM accuracy
test_loss_mlm, test_acc_mlm = trainer_mlm.evaluate(test_loader)
print(f"Test loss: {test_loss_mlm:.4f}, test accuracy: {test_acc_mlm:.4f}")

[Masked Language Model]
Epoch 1/5 — train loss: 7.7403, validation loss: 7.5098, validation accuracy: 0.0976
Epoch 2/5 — train loss: 7.3718, validation loss: 7.3330, validation accuracy: 0.1133
Epoch 3/5 — train loss: 7.2110, validation loss: 7.2171, validation accuracy: 0.1221
Epoch 4/5 — train loss: 7.0914, validation loss: 7.1399, validation accuracy: 0.1244
Epoch 5/5 — train loss: 6.9620, validation loss: 7.0372, validation accuracy: 0.1348
Test loss: 7.0061, test accuracy: 0.1362
