## Setup

In [None]:
import torch
import torch.nn as nn
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
    EarlyStoppingCallback
)
from datasets import load_dataset
from transformers import AutoTokenizer
import numpy as np
import evaluate

# Create an HF-compatible PyTorch Model

Read ["Customizing models"](https://huggingface.co/docs/transformers/custom_models) or search online for more info on HF custom models.

Remember: this time we are not using CNNs/UNets, but RNNs on text data.

In [None]:
# @title Define the config class


class RNNConfig(PretrainedConfig):
    """
    Hyper-parameters for `TextClassificationRawRNN`.

    Subclassing `PretrainedConfig` is what lets the model be persisted and restored
    with HF's `save_pretrained` / `from_pretrained`.

    Args:
        vocab_size: Number of token IDs the embedding table must cover. The default
            (30522) matches the DistilBERT uncased tokenizer; change it if you use a
            different tokenizer.
        embedding_dim: Size of each token embedding vector x_t.
        hidden_dim: Size of the RNN hidden ("memory") vector h_t.
        n_classes: Number of output classes. The default (6) matches the Kaggle emotion
            dataset used in this notebook (AG News, used previously, has 4 classes).
        dropout: Dropout probability applied to the pooled hidden state before the
            classifier head.
        **kwargs: Forwarded unchanged to `PretrainedConfig`.
    """

    model_type = "custom_raw_rnn"

    def __init__(
        self,
        vocab_size=30522,  # DistilBERT uncased vocab size (change to match your tokenizer)
        embedding_dim=64,
        hidden_dim=128,    # Size of the "memory" vector h_t
        n_classes=6,       # Kaggle emotion dataset has 6 classes (AG News has 4)
        dropout=0.2,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes
        self.dropout = dropout
        super().__init__(**kwargs)

In [None]:
# @title Define a recurrent layer and text classification model


# To separate the complexity of the sequence classification task from the complexity of
# using recurrence in our network, let's put the recurrence in one class (`VanillaRNNLayer`)
# and the text classification model in another (`TextClassificationRawRNN`).


# The "layer" we will use inside our model. This class handles the recurrence (the time loop)
# so we don't need to worry about it in the model class.
class VanillaRNNLayer(nn.Module):
    """
    A manual implementation of an Elman RNN.

    We avoid nn.RNN to show you the internal loop:

        h_t = tanh(W_ih @ x_t + b_ih + W_hh @ h_{t-1} + b_hh)

    Args:
        input_size: Size of each input vector x_t (the embedding dimension).
        hidden_size: Size of the hidden ("memory") vector h_t.
        dropout: Currently unused by this layer. It is accepted so callers can keep
            passing it; see the disabled alternative for `h2h` below.
    """

    def __init__(self, input_size, hidden_size, dropout):
        super().__init__()
        self.hidden_size = hidden_size

        # We need two linear transformations:
        # 1. Input-to-Hidden (processes the current word x_t)
        self.i2h = nn.Linear(input_size, hidden_size)

        # 2. Hidden-to-Hidden (processes the previous memory h_{t-1})
        self.h2h = nn.Linear(hidden_size, hidden_size)
        # Disabled alternative that would actually use `dropout` on the recurrent path:
        # nn.Sequential(
        #     nn.Linear(hidden_size, hidden_size),
        #     nn.Dropout(dropout)
        # )

        # The activation function (tanh is standard for vanilla RNNs)
        self.activation = nn.Tanh()

    def forward(self, x, attention_mask=None):
        """
        Run the recurrence over every time step.

        Args:
            x: Tensor of shape [Batch_Size, Seq_Len, Input_Size].
            attention_mask: Optional [Batch_Size, Seq_Len] tensor, 1 for real tokens and
                0 for padding. At padded steps the previous hidden state is carried
                forward unchanged, so padding never overwrites the memory.

        Returns:
            Tensor of shape [Batch_Size, Seq_Len, Hidden_Size] holding h_t for every step
            (the same layout as PyTorch's built-in recurrent layers).
        """
        batch_size, seq_len, _ = x.size()

        # Nothing to iterate over: return an empty history instead of failing in torch.stack.
        if seq_len == 0:
            return x.new_zeros(batch_size, 0, self.hidden_size)

        # Initialize h_0 with zeros (the "placeholder" when we don't have a previous state).
        # QUESTION: Why not initialize with random noise like autoencoder(decoder)/GAN(generator)/diffusion inferences?
        # new_zeros matches x's device AND dtype. torch.zeros(...) would always be float32,
        # which breaks h2h when the model runs in float64 / float16 / bfloat16.
        h_t = x.new_zeros(batch_size, self.hidden_size)  # [Batch_Size, Hidden_Size]

        # The input-to-hidden term does not depend on h_{t-1}, so compute it for all time
        # steps in one batched matmul. Only the hidden-to-hidden term needs the loop.
        i2h_all = self.i2h(x)  # [Batch_Size, Seq_Len, Hidden_Size]

        hidden_states_history = []

        # --- THE TEACHING MOMENT: The Time Loop ---
        for t in range(seq_len):
            # 1. Input contribution for the current step (precomputed above)
            i2h_out = i2h_all[:, t, :]

            # 2. Memory contribution, then the update: New Memory = tanh(Input + Old Memory)
            h2h_out = self.h2h(h_t)
            next_h_t = self.activation(i2h_out + h2h_out)

            # 3. Padding positions (mask == 0) keep the previous state unchanged
            if attention_mask is not None:
                mask_t = attention_mask[:, t].unsqueeze(1).type_as(next_h_t)
                h_t = mask_t * next_h_t + (1.0 - mask_t) * h_t
            else:
                h_t = next_h_t

            # 4. Save state
            hidden_states_history.append(h_t)

        # Stack the history into [Batch_Size, Seq_Len, Hidden_Size]
        return torch.stack(hidden_states_history, dim=1)


# The "model" we will actually use. This class handles the specifics of the text classification task
class TextClassificationRawRNN(PreTrainedModel):
    """
    Text classifier: token embeddings -> hand-written `VanillaRNNLayer` -> linear head.

    The hidden state at the last real (non-padding) token summarizes the sequence and is
    projected to `config.n_classes` logits. `forward` returns a dict with "loss" (None
    when no labels are given) and "logits".
    """

    config_class = RNNConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Layer 1: Embeddings - converts token IDs (integers) to vectors.
        # nn.Embedding is basically a one-hot encoding followed by a nn.Linear layer, but done more efficiently.
        # Learn more about this layer: https://discuss.pytorch.org/t/how-does-nn-embedding-work/88518
        # padding_idx keeps the pad token's vector at zero. Use the config's pad_token_id when one is
        # set, otherwise fall back to 0 (the previous hard-coded value).
        pad_token_id = getattr(config, "pad_token_id", None)
        self.embedding = nn.Embedding(
            config.vocab_size,
            config.embedding_dim,
            padding_idx=pad_token_id if pad_token_id is not None else 0,
        )

        # Layer 2: Our custom "raw" RNN - does the actual recurrent part of recurrent neural networks.
        # Later, we will be able to replace this layer with some nn.Module layers that also handle
        # recurrence under the hood, so we won't have to manually deal with it every time.
        self.rnn_block = VanillaRNNLayer(
            input_size=config.embedding_dim,
            hidden_size=config.hidden_dim,
            dropout=config.dropout
        )

        self.dropout = nn.Dropout(config.dropout)

        # Layer 3: Classifier head - projects the hidden state into a y_hat prediction
        self.classifier = nn.Linear(config.hidden_dim, config.n_classes)

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # HF compatibility workarounds necessary in newer `transformers` versions
        self.all_tied_weights_keys = []
        self._tied_weights_keys = []
        self.post_init()

    # More HF compatibility workarounds needed in recent `transformers` versions
    @property
    def dummy_inputs(self):
        # Helps HF infer input shapes if needed
        return {"input_ids": torch.tensor([[0, 1]]), "attention_mask": torch.tensor([[1, 1]])}

    def _check_and_adjust_experts_implementation(self, experts_implementation):
        # This bypasses a check for the model source file. Without this method, you get an error during model instantiation
        return experts_implementation

    @staticmethod
    def _last_real_token_states(rnn_output, attention_mask):
        """Return rnn_output[b, last non-padding position of row b] for every row b."""
        batch_size, seq_len = rnn_output.shape[:2]
        positions = torch.arange(seq_len, device=rnn_output.device)
        is_real = attention_mask.to(rnn_output.device) > 0
        # Each real token contributes its position and padding contributes 0, so the row-wise
        # argmax is the LAST real token. Unlike `mask.sum(1) - 1`, this is also correct for
        # left-padded batches, and an all-padding row yields 0 instead of wrapping to -1.
        last_idx = (is_real.long() * positions).argmax(dim=1)
        batch_idx = torch.arange(batch_size, device=rnn_output.device)
        return rnn_output[batch_idx, last_idx]

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """
        Classify a batch of token sequences.

        Args:
            input_ids: [Batch, Seq] token IDs (required).
            attention_mask: Optional [Batch, Seq] mask, 1 for real tokens and 0 for padding.
            labels: Optional [Batch] class indices; when given, the cross-entropy loss is computed.
            **kwargs: Accepted and ignored (extra inputs a tokenizer/collator may add).

        Returns:
            dict with "loss" (scalar tensor, or None without labels) and "logits" [Batch, n_classes].

        Raises:
            ValueError: if `input_ids` is not provided.
        """
        if input_ids is None:
            raise ValueError("TextClassificationRawRNN.forward requires `input_ids`.")

        # 1. Get embeddings
        x = self.embedding(input_ids)  # [Batch, Seq, Emb]

        # 2. Run our custom RNN loop
        rnn_output = self.rnn_block(x, attention_mask=attention_mask)  # [Batch, Seq, Hidden]

        # 3. Smart pooling (handling padding)
        # The batch has a fixed sequence length but the sentences do not, so shorter ones are
        # padded. We must read the hidden state at the last REAL token, not at padding.
        if attention_mask is not None:
            last_hidden_states = self._last_real_token_states(rnn_output, attention_mask)
        else:
            # Fallback if no mask is provided: every position is treated as real
            last_hidden_states = rnn_output[:, -1, :]

        last_hidden_states = self.dropout(last_hidden_states)

        # 4. Project to classes (classification head projecting a hidden state into class logits)
        logits = self.classifier(last_hidden_states)

        # 5. Calculate loss (as HF requires when labels are passed)
        loss = None
        if labels is not None:
            loss = self.loss_fn(logits, labels)

        return {"loss": loss, "logits": logits}

# Setup Training

In [None]:
# @title Define Data Pipeline (Emotion CSV)

from pathlib import Path

# Use the DistilBERT tokenizer just for the vocab mapping (text -> integer token IDs);
# the RNN model learns its own embeddings from scratch.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# Padding is done per batch by the collator (tokenize_function below uses padding=False).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Resolve project root whether the notebook runs from the workspace root or from the notebooks/ folder
project_root = Path.cwd()
if not (project_root / "data").exists() and (project_root.parent / "data").exists():
    project_root = project_root.parent

train_csv = project_root / "data" / "kaggle_emotion_classification" / "train_large.csv"
if not train_csv.exists():
    raise FileNotFoundError(f"Arquivo não encontrado: {train_csv}")

# Load the labeled data
dataset = load_dataset("csv", data_files={"train": str(train_csv)})["train"]

# Guarantee labels are contiguous [0..N-1] for CrossEntropyLoss.
# The label vocabulary is built from the FULL dataset, before splitting: building it from the
# train split alone could miss a rare label that only lands in the validation split, and
# label2id[...] would then raise KeyError inside .map().
label_values = sorted(set(dataset["label"]))
label2id = {label: idx for idx, label in enumerate(label_values)}
id2label = {idx: str(label) for label, idx in label2id.items()}

if label_values != list(range(len(label_values))):
    dataset = dataset.map(lambda x: {"label": label2id[x["label"]]})

# Train/validation split. Same seed and same number of rows as before, so the same rows land in each split.
dataset = dataset.train_test_split(test_size=0.2, seed=42)

num_labels = len(label_values)
print(f"Labels detectados: {label_values} -> total de classes: {num_labels}")


# Preprocessing
def tokenize_function(examples):
    """Tokenize a batch of texts, truncating to 128 tokens; padding is left to the collator."""
    return tokenizer(examples["text"], truncation=True, padding=False, max_length=128)


# Remove non-model columns (like raw text/id) to avoid collator issues
columns_to_remove = [col for col in dataset["train"].column_names if col != "label"]
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=columns_to_remove)

In [None]:
# @title Training Setup

# Setup model (number of classes inferred from dataset)
config = RNNConfig(
    vocab_size=tokenizer.vocab_size,
    embedding_dim=64,
    hidden_dim=128,
    n_classes=num_labels,
    dropout=0.2
)
model = TextClassificationRawRNN(config)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model moved to: {device}")

# Setup metrics
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Turn the (logits, labels) pair from an evaluation run into an accuracy dict."""
    predictions, labels = eval_pred
    # If a model returns several prediction tensors, the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)


# Setup training loop and args.
# NOTE: the output_dir name is a leftover from the AG News version of this notebook; the
# "Save trained model" cell writes into a subfolder of it, so rename both together.
training_args = TrainingArguments(
    output_dir="./raw_rnn_agnews",
    learning_rate=5e-4,
    per_device_train_batch_size=128,  # previously 64
    # Evaluate in batches as large as the training ones: the RNN steps through time in
    # Python, so fewer, bigger batches make each per-epoch evaluation much faster.
    per_device_eval_batch_size=128,
    num_train_epochs=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,      # restore the best-accuracy checkpoint after training
    metric_for_best_model="accuracy",
    greater_is_better=True,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    weight_decay=0.01,
    max_grad_norm=1.0,
    logging_steps=50,
    remove_unused_columns=True,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

# Quick sanity check to catch label-range/device errors before full training
# sample_batch = data_collator([tokenized_datasets["train"][0], tokenized_datasets["train"][1]])
# sample_batch = {k: v.to(device) for k, v in sample_batch.items()}
# with torch.no_grad():
#     _ = model(**sample_batch)
# print("Sanity check OK: forward pass with a real batch worked.")

In [None]:
# @title Check device (CPU or GPU)

# Report the PyTorch build and whether a CUDA GPU is visible to it.
_cuda_available = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_available}")
if not _cuda_available:
    print("Running on CPU")
else:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")

# Same selection rule as the training setup: prefer the GPU, otherwise fall back to CPU.
device = torch.device("cuda" if _cuda_available else "cpu")
print(f"\nModel will run on: {device}")

In [None]:
# @title Start training

# Run the training loop configured in the "Training Setup" cell. With eval_strategy and
# save_strategy set to "epoch" and load_best_model_at_end=True, the model is evaluated and
# checkpointed every epoch, and the best-accuracy checkpoint is kept as `trainer.model`.
trainer.train()

In [None]:
# @title Save trained model

# Final artifacts go in a subfolder of the training output_dir.
save_directory = "./raw_rnn_agnews/final_checkpoint"

# 1. Save the model weights and config, so the model can be rebuilt later with
#    TextClassificationRawRNN.from_pretrained(save_directory).
trainer.save_model(save_directory)

# 2. Save the tokenizer next to the model, so reloading gives exactly the vocab mapping the
#    model was trained with.
#    NOTE(review): recent `transformers` versions may already write the Trainer's
#    processing_class during save_model(); saving it explicitly is harmless and makes the
#    intent obvious. If you ever train your own tokenizer, this is how you persist it.
tokenizer.save_pretrained(save_directory)

# Run Inference and Calculate Metrics

In [None]:
# @title Load the model checkpoint

# If you are still in the same session that you trained the model in, this is unnecessary. But this cell
# lets you load a checkpoint saved days/weeks ago and still get the same inferences.

# Notice we must use our custom class `TextClassificationRawRNN`, not AutoModel
# (nothing visible in this notebook registers the "custom_raw_rnn" model_type with the Auto* classes).
loaded_model = TextClassificationRawRNN.from_pretrained(save_directory)
# Reload the tokenizer saved alongside the model so token IDs match those used in training.
loaded_tokenizer = AutoTokenizer.from_pretrained(save_directory)

In [None]:
# @title Get test set metrics

# We create a temporary Trainer just for evaluation (it handles batching/collating for us).
# You can reuse the one from training, if still in the same session.
#
# NOTE: tokenized_datasets["test"] is the 20% split made by `train_test_split`. It was also the
# eval_dataset during training, where load_best_model_at_end=True kept the checkpoint that scored
# best on it. These numbers are therefore an optimistic *validation* estimate, not an unbiased
# test-set score. Hold out a separate split if you need the latter.
eval_trainer = Trainer(
    model=loaded_model,
    args=TrainingArguments(
        output_dir="./eval_output",  # dummy args: nothing is trained here
        report_to="none",
        # Larger eval batches: the RNN loops over time steps in Python, so fewer batches
        # make the evaluation much faster without changing the resulting metrics.
        per_device_eval_batch_size=128,
    ),
    eval_dataset=tokenized_datasets["test"],
    processing_class=loaded_tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

metrics = eval_trainer.evaluate()
print(f"\nTest Set Accuracy: {metrics['eval_accuracy']:.2%}")
print(f"Test Set Loss:     {metrics['eval_loss']:.4f}")

In [None]:
# @title Run inference on new data


# Kaggle Emotion Classification labels (the dataset used in this notebook).
# NOTE(review): assumes the CSV's integer labels follow this order; if the data pipeline had
# to remap labels through label2id, verify this mapping still matches.
# (The earlier AG News version used: 0=World, 1=Sports, 2=Business, 3=Sci/Tech)
label_map = {0: "Sadness", 1: "Joy", 2: "Love", 3: "Anger", 4: "Fear", 5: "Surprise"}


def predict_news_topic(text):
    """
    Classify one string with `loaded_model` and return the predicted label name.

    The function name is a leftover from the AG News version of this notebook; it now
    returns an emotion from `label_map`. A class id missing from `label_map` is returned
    as its string form instead of raising KeyError.
    """
    # Inference mode: make sure dropout is disabled.
    loaded_model.eval()

    # Tokenize the same way as the training data (truncate to 128 tokens, no padding).
    inputs = loaded_tokenizer(text, return_tensors="pt", truncation=True, max_length=128, padding=False)

    # Move inputs to the same device as the model
    inputs = {k: v.to(loaded_model.device) for k, v in inputs.items()}

    # Forward pass without tracking gradients
    with torch.no_grad():
        logits = loaded_model(**inputs)["logits"]
    predicted_class_id = torch.argmax(logits, dim=-1).item()

    return label_map.get(predicted_class_id, str(predicted_class_id))


# --- Demo ---
print("\n🤖 INFERENCE DEMO:")
sample_text = "Nvidia shares jumped today after announcing a new AI chip architecture."
prediction = predict_news_topic(sample_text)

print(f"📝 Text: '{sample_text}'")
print(f"🏷️ Predicted Topic: {prediction}")