In [None]:
# September 2025
# Train multilingual classification neural network
# Violeta Berdejo-Espinola

# pytorch dataset
# pytorch model
# pytorch training loop

In [None]:
%pip install torch transformers sys matplotlib numpy sklearn pytorch-ignite

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
import pickle
import numpy as np

from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, recall_score, precision_score, f1_score

import matplotlib.pyplot as plt
from collections import Counter
from tqdm import tqdm

import warnings
# NOTE(review): this silences every warning for the rest of the notebook,
# including sklearn's undefined-metric warnings raised by the metric calls below.
warnings.filterwarnings('ignore')

import sys
import platform

# Report the runtime so results can be tied to concrete library versions.
print('system version:', sys.version)
print('pytorch version:', torch.__version__)
# np.version is the version *module*; the version string lives in np.__version__.
print('numpy version:', np.__version__)

print(f'mac processor: {platform.mac_ver()}')
# is_available() is what actually decides the device below; is_built() only says
# that this torch build was compiled with MPS support.
print(f'mps is available: {torch.backends.mps.is_available()}')
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"using device: {device}")

# function to load data

In [None]:
# Locations of the pickled inputs: texts (x) and their labels (y).
filepath_x, filepath_y = (
    "../data/for_analysis/eng_x.pickle",
    "../data/for_analysis/eng_y.pickle",
)

In [None]:
def load_data(path_x, path_y, limit=500):
    """
    Load pickled texts and labels and return the first `limit` of each.

    Args:
        path_x (str): Path to the pickle file holding the input texts.
        path_y (str): Path to the pickle file holding the labels.
        limit (int | None): Number of leading samples to keep. Defaults to 500
            (the original hard-coded subset size); pass None to keep everything.

    Returns:
        tuple: (texts, labels), each sliced to at most `limit` items.

    Raises:
        ValueError: If the two files do not hold the same number of items,
            since texts and labels are paired by position.
    """
    # SECURITY: pickle.load can execute arbitrary code while unpickling.
    # Only point this at files produced by a trusted pipeline.
    with open(path_x, 'rb') as x_file:
        texts = pickle.load(x_file)
    with open(path_y, 'rb') as y_file:
        labels = pickle.load(y_file)

    if len(texts) != len(labels):
        raise ValueError(
            f"texts and labels are misaligned: {len(texts)} texts vs {len(labels)} labels"
        )

    # Slicing with None keeps the whole sequence, so limit=None means "no cap".
    return texts[:limit], labels[:limit]
        
# Use the path variables defined above instead of repeating the literals,
# so changing the dataset only requires editing one cell.
texts, labels = load_data(filepath_x, filepath_y)

# Show a short preview rather than dumping every label into the cell output.
print(f"Loaded {len(labels)} labels; first 10: {list(labels[:10])}")

In [None]:
# # function to load multilingual data

# file paths

# filepaths_x = [
#     "../data/for_analysis/eng_x.pickle",
#     # "../data/for_analysis/jap_x.pickle"
#     # "../data/for_analysis/spa_x.pickle"
# ]

# filepaths_y = [
#     "../data/for_analysis/eng_y.pickle",
#     # "../data/for_analysis/jap_y.pickle"
#     # "../data/for_analysis/spa_y.pickle"
# ]

# # def load_data(filepaths_x, filepaths_y):
# #     """Load and combine multilingual data"""
# #     all_texts = []
# #     all_labels = []
    
# #     for fp_x, fp_y in zip(filepaths_x, filepaths_y):
# #         with open(fp_x, 'rb') as f:
# #             texts = pickle.load(f)
# #         with open(fp_y, 'rb') as f:
# #             labels = pickle.load(f)
            
# #         all_texts.extend(texts)
# #         all_labels.extend(labels)
    
# #     return all_texts[:40], all_labels[:40]


# # texts, labels = load_data(filepaths_x, filepaths_y)

# # print(labels)

# encode data

In [None]:
# Load the Hugging Face tokenizer and encoder for the multilingual sentence-embedding checkpoint.
# Author's notes (not verified here, confirm against the model card):
#   - the checkpoint's default max sequence length is 128 (~80 words)
#   - it maps text into a 768-dimensional vector space
#   - going through AutoTokenizer/AutoModel directly lets us pass a longer max_length ourselves
ENCODER_CHECKPOINT = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

tokenizer = AutoTokenizer.from_pretrained(ENCODER_CHECKPOINT)
hf_model = AutoModel.from_pretrained(ENCODER_CHECKPOINT).to(device)

# function to encode text

def encode_texts(texts, tokenizer, model, device, max_length=512):
    """
    Embed a batch of strings with attention-masked mean pooling.

    Each text is tokenized (padded to the longest item in the batch, truncated
    to `max_length`), run through `model` in eval mode without gradients, and
    the token vectors of the last hidden state are averaged over the positions
    the attention mask marks as real tokens.

    Args:
        texts (list[str]): Input texts.
        tokenizer: Hugging Face tokenizer, called with `return_tensors="pt"`.
        model: Encoder whose output exposes `last_hidden_state`.
        device (torch.device): Device the tokenized batch is moved to
            (e.g. "mps" or "cpu"); the model is expected to live there already.
        max_length (int): Maximum sequence length passed to the tokenizer.

    Returns:
        torch.Tensor: One pooled embedding per input text, shape (batch, hidden_size).
    """
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    # Inference only: switch to eval mode and skip autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state  # (batch, seq_len, hidden_size)

    # Zero out padding positions, then divide by the number of real tokens.
    mask = encoded["attention_mask"].unsqueeze(-1).float()  # (batch, seq_len, 1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)          # guards against division by zero
    summed = (hidden_states * mask).sum(dim=1)              # (batch, hidden_size)

    return summed / token_counts

# dataset class

In [None]:
# create datasets class

class TextDataset(Dataset):
    """
    Dataset that yields (embedding, label) pairs for raw texts.

    Embeddings are produced on demand by `encode_texts` with the supplied
    tokenizer/encoder. Because the encoder is run in eval mode without
    gradients, the embedding of a given text does not change between epochs,
    so by default each embedding is computed once and cached.

    Args:
        texts: Sequence of input texts (any iterable; copied into a list).
        labels: Sequence of integer class labels, aligned with `texts`.
        tokenizer: Hugging Face tokenizer passed through to `encode_texts`.
        model: Encoder passed through to `encode_texts`.
        device (torch.device): Device passed through to `encode_texts`.
        max_length (int): Maximum sequence length for tokenization.
        cache_embeddings (bool): Keep computed embeddings in memory and reuse
            them on later accesses. Set to False to recompute every time.

    Raises:
        ValueError: If `texts` and `labels` have different lengths.
    """

    def __init__(self, texts, labels, tokenizer, model, device, max_length=512,
                 cache_embeddings=True):
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, got {len(texts)} and {len(labels)}"
            )
        # Copy into plain lists so __getitem__ is always positional, even if the
        # caller passes a container with label-based indexing.
        self.texts = list(texts)
        self.labels = list(labels)
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
        self.max_length = max_length
        # idx -> embedding tensor; None disables caching.
        self._cache = {} if cache_embeddings else None

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        embedding = self._cache.get(idx) if self._cache is not None else None

        if embedding is None:
            text = str(self.texts[idx])
            # Encode as a batch of one, then drop the batch dimension.
            embedding = encode_texts(
                [text], self.tokenizer, self.model, self.device, self.max_length
            ).squeeze(0)
            if self._cache is not None:
                self._cache[idx] = embedding

        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return embedding, label

# data imbalance

In [None]:
# dealing with imbalance in loss function

def compute_loss_weights(labels):
    """
    Build per-class weights for a weighted loss from the label frequencies.

    Uses scikit-learn's 'balanced' class-weight scheme over the distinct labels
    found in `labels` and returns the weights as a float tensor on the
    module-level `device`. The weights follow the order of `np.unique(labels)`.
    """
    present_classes = np.unique(labels)
    balanced = compute_class_weight('balanced', classes=present_classes, y=labels)

    weight_tensor = torch.as_tensor(balanced, dtype=torch.float32)
    return weight_tensor.to(device)

In [None]:
# analyze and visualize class distribution

def analyze_class_distribution(labels):
    """
    Print how many samples each class has and how imbalanced the classes are.

    For every class (in sorted order) the absolute count and its share of all
    labels are printed, followed by the ratio between the largest and the
    smallest class.

    Args:
        labels: Sequence of hashable class labels.

    Returns:
        collections.Counter: Mapping from class label to sample count.
    """
    tally = Counter(labels)
    n_total = len(labels)

    print("Class distribution:")
    for class_id in sorted(tally):
        count = tally[class_id]
        print(f"Class {class_id}: {count} samples ({count/n_total*100:.2f}%)")

    # Largest class size relative to the smallest one.
    sizes = tally.values()
    imbalance_ratio = max(sizes) / min(sizes)
    print(f"Imbalance ratio: {imbalance_ratio:.2f}:1")

    return tally

In [None]:
# dataloader stratified batch sampler

from torch.utils.data import Sampler
from collections import defaultdict

class StratifiedBatchSampler(Sampler):
    """
    Batch sampler that creates stratified mini-batches where each batch
    maintains approximately the same class distribution as the dataset.

    Every full batch takes a fixed number of indices from each class
    (`samples_per_class`). The number of full batches is limited by the class
    that runs out first; whatever is left over is emitted as one final batch
    (capped at `batch_size`) if it holds at least `batch_size // 2` samples,
    and is dropped otherwise. `len(sampler)` matches the number of batches
    `__iter__` yields.
    """

    def __init__(self, labels, batch_size, shuffle=True):
        """
        Args:
            labels (list or array): Class labels for all samples
            batch_size (int): Size of mini-batches
            shuffle (bool): Whether to shuffle within each class, within each
                batch, and the order of batches (uses `np.random`)

        Raises:
            ValueError: If `labels` is empty, `batch_size` is not positive, or
                `batch_size` is smaller than the number of classes (every
                class needs at least one slot per batch).
        """
        self.labels = np.array(labels)
        self.batch_size = batch_size
        self.shuffle = shuffle

        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if len(self.labels) == 0:
            raise ValueError("labels must not be empty")

        # Group indices by class
        grouped = defaultdict(list)
        for idx, label in enumerate(self.labels):
            grouped[label].append(idx)

        self.class_indices = {k: np.array(v) for k, v in grouped.items()}
        self.classes = list(self.class_indices.keys())

        if batch_size < len(self.classes):
            raise ValueError(
                f"batch_size ({batch_size}) must be at least the number of classes "
                f"({len(self.classes)}) so every class gets a slot in each batch"
            )

        # Calculate samples per class per batch
        self.class_counts = {cls: len(indices) for cls, indices in self.class_indices.items()}
        total_samples = sum(self.class_counts.values())

        # Proportional representation in each batch, at least one slot per class
        self.samples_per_class = {
            cls: max(1, int(batch_size * count / total_samples))
            for cls, count in self.class_counts.items()
        }

        # Adjust so the per-class slots add up to exactly batch_size
        diff = batch_size - sum(self.samples_per_class.values())
        if diff > 0:
            # Rounding down left free slots: give them to the largest class
            largest_class = max(self.class_counts, key=self.class_counts.get)
            self.samples_per_class[largest_class] += diff
        while diff < 0:
            # The one-slot minimum over-allocated: take slots back one at a time
            # from whichever class currently holds the most. Because
            # batch_size >= number of classes, that class always has more than
            # one slot, so no allocation ever reaches zero.
            donor = max(self.samples_per_class, key=self.samples_per_class.get)
            self.samples_per_class[donor] -= 1
            diff += 1

        print(f"Samples per class per batch: {self.samples_per_class}")

    def _num_full_batches(self):
        """Number of batches for which every class can still fill its slots."""
        return min(
            len(indices) // self.samples_per_class[cls]
            for cls, indices in self.class_indices.items()
        )

    def _has_partial_batch(self, full_batches):
        """Whether the leftovers after `full_batches` are emitted as a final batch."""
        leftover = sum(
            len(indices) - full_batches * self.samples_per_class[cls]
            for cls, indices in self.class_indices.items()
        )
        # Never emit an empty batch (possible when batch_size // 2 == 0).
        return leftover > 0 and leftover >= self.batch_size // 2

    def __iter__(self):
        # Shuffle indices within each class if required
        if self.shuffle:
            class_indices = {
                cls: np.random.permutation(indices).tolist()
                for cls, indices in self.class_indices.items()
            }
        else:
            class_indices = {cls: indices.tolist() for cls, indices in self.class_indices.items()}

        full_batches = self._num_full_batches()
        batches = []

        for batch_idx in range(full_batches):
            batch = []
            for cls in self.classes:
                start_idx = batch_idx * self.samples_per_class[cls]
                end_idx = start_idx + self.samples_per_class[cls]
                batch.extend(class_indices[cls][start_idx:end_idx])

            # Shuffle within batch so classes are not grouped together
            if self.shuffle:
                np.random.shuffle(batch)

            batches.append(batch)

        # Final partial batch from whatever the full batches did not consume
        if self._has_partial_batch(full_batches):
            remaining = []
            for cls in self.classes:
                start_idx = full_batches * self.samples_per_class[cls]
                remaining.extend(class_indices[cls][start_idx:])
            if self.shuffle:
                np.random.shuffle(remaining)
            batches.append(remaining[:self.batch_size])

        # Shuffle batch order
        if self.shuffle:
            np.random.shuffle(batches)

        yield from batches

    def __len__(self):
        # Must agree with __iter__: DataLoader reports this as len(loader),
        # which progress bars and step counts rely on.
        full_batches = self._num_full_batches()
        return full_batches + (1 if self._has_partial_batch(full_batches) else 0)

In [None]:
# # weighted random sampler --> use instead of mini batches

# def weighted_sampler(labels):
#     """Assign a weight inversely proportional to its class frequency"""
    
#     # 1. Create mapping of each class to a weight equal to 1 / count
#     class_counts = Counter(labels)
#     class_weights = {cls: 1.0/count for cls, count in class_counts.items()}
#     print(f"Class weights: {class_weights}")
    
#     # 2. Replace class with that weight.
#     sample_weights = [class_weights[label] for label in labels]
#     print(len(sample_weights))
    
#     # 3. Sampler object
#     sampler = WeightedRandomSampler(
#         weights=sample_weights,
#         num_samples=len(sample_weights),
#         replacement=True # samples can be picked multiple times per epoch
#     )
    
#     return sampler

# classifier class

In [None]:
# create classifier class

class MLPClassifier(nn.Module):
    """
    Multi-layer perceptron for text classification.

    The network is a stack of `Linear -> ReLU -> Dropout` groups, one per entry
    in `hidden_dims`, followed by a final `Linear` layer producing the logits.

    Args:
        input_dim (int): Size of the feature vector of each sample. Required;
            the default of None only exists to keep the original signature.
        hidden_dims (sequence of int): Width of each hidden layer, in order.
        num_classes (int): Number of output classes.
        dropout (float): Fraction of activations dropped after each hidden
            layer during training.

    Raises:
        ValueError: If `input_dim` is not provided.

    Forward returns: a tensor with `num_classes` logits per sample.
    """

    def __init__(self, input_dim=None, hidden_dims=(512, 512, 512), num_classes=2, dropout=0.2):
        super().__init__()

        if input_dim is None:
            raise ValueError("input_dim is required: pass the embedding size of the inputs")

        layers = []
        prev_dim = input_dim

        # Hidden layers
        for hidden_dim in hidden_dims:
            layers.extend([
                # Fully connected layer: y = xW^T + b, mapping prev_dim -> hidden_dim
                nn.Linear(prev_dim, hidden_dim),
                # nn.BatchNorm1d(hidden_dim) was considered here and left disabled
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim

        # Output layer
        layers.append(nn.Linear(prev_dim, num_classes))

        # Container that runs the layers in sequence
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        return self.classifier(x)

# train eval function

In [None]:
# train function

def train_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one optimisation pass over `dataloader` and report training metrics.

    Args:
        model: Classifier to train (put into train mode here).
        dataloader: Iterable yielding (embeddings, labels) batches.
        criterion: Loss function applied to (logits, labels).
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        tuple: (all_preds, all_labels, avg_loss, recall, precision, f1, accuracy)
        where the first two are per-sample lists collected over the epoch,
        `avg_loss` is the mean batch loss, `accuracy` is a percentage, and
        recall/precision/f1 come from scikit-learn with its default arguments.
    """
    model.train()

    running_loss = 0
    n_correct = 0
    n_seen = 0
    all_preds, all_labels = [], []

    for embeddings, labels in tqdm(dataloader, desc="Training"):
        embeddings = embeddings.to(device)
        labels = labels.to(device)

        # Clear gradients from the previous step, then forward / loss / backward / update
        optimizer.zero_grad()
        logits = model(embeddings)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest logit
        predicted = logits.detach().max(dim=1).indices

        running_loss += loss.item()
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        # Keep per-sample results for the epoch-level metrics
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    avg_loss = running_loss / len(dataloader)
    accuracy_tr = 100 * n_correct / n_seen

    recall_tr = recall_score(all_labels, all_preds)
    f1_tr = f1_score(all_labels, all_preds)
    precision_tr = precision_score(all_labels, all_preds)

    return (all_preds, all_labels,
            avg_loss, recall_tr, precision_tr, f1_tr, accuracy_tr)

# eval function 

def validate_epoch(model, dataloader, criterion, device):
    """
    Evaluate `model` on every batch of `dataloader` without updating it.

    Args:
        model: Classifier to evaluate (put into eval mode here).
        dataloader: Iterable yielding (embeddings, labels) batches.
        criterion: Loss function applied to (logits, labels).
        device: Device the batches are moved to.

    Returns:
        tuple: (all_preds, all_labels, avg_loss, recall, precision, f1, accuracy)
        in exactly that order, where the first two are per-sample lists,
        `avg_loss` is the mean batch loss, `accuracy` is a percentage, and
        recall/precision/f1 come from scikit-learn with its default arguments.
    """
    model.eval()

    running_loss = 0
    n_correct = 0
    n_seen = 0
    all_preds, all_labels = [], []

    # No gradients are needed for evaluation
    with torch.no_grad():
        for embeddings, labels in tqdm(dataloader, desc="Validation"):
            embeddings = embeddings.to(device)
            labels = labels.to(device)

            logits = model(embeddings)
            loss = criterion(logits, labels)

            # Predicted class = index of the largest logit
            predicted = logits.detach().max(dim=1).indices

            running_loss += loss.item()
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

            # Keep per-sample results for the epoch-level metrics
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_loss = running_loss / len(dataloader)
    accuracy_val = 100 * n_correct / n_seen

    recall_val = recall_score(all_labels, all_preds)
    f1_val = f1_score(all_labels, all_preds)
    precision_val = precision_score(all_labels, all_preds)

    return (all_preds, all_labels,
            avg_loss, recall_val, precision_val, f1_val, accuracy_val)

# load data, split, create datasets and dataloaders

In [None]:
# load data

texts, labels = load_data(filepath_x, filepath_y)

print(f"Total documents: {len(texts)}")

class_distribution = analyze_class_distribution(labels)

# split data       %% run it with different seeds %%
# 30% of the data is held out, and 20% of that hold-out becomes the test set,
# i.e. roughly 70% train / 24% validation / 6% test, stratified by label.

seed = 42

X_train, X_temp, y_train, y_temp = train_test_split(
    texts, labels, test_size=0.3, random_state=seed, stratify=labels
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.2, random_state=seed, stratify=y_temp
)

print(f"Train: {len(X_train)}, \
      Val: {len(X_val)}, \
      Test: {len(X_test)}")

# loss weights: derived from the training labels only, so the validation and
# test label distributions do not leak into the training objective

weights = compute_loss_weights(y_train)

# datasets

train_dataset = TextDataset(X_train, y_train, tokenizer, hf_model, device)
val_dataset = TextDataset(X_val, y_val, tokenizer, hf_model, device)
test_dataset = TextDataset(X_test, y_test, tokenizer, hf_model, device)

# weighted sampler

# train_sampler, class_weights = weighted_sampler(y_train)

batch_size = 32

train_batch_sampler = StratifiedBatchSampler(
    labels=y_train,
    batch_size=batch_size,
    shuffle=True
)

# dataloaders
# batch_sampler already defines batching and shuffling for the training set,
# so batch_size/shuffle are only passed for validation and test

train_loader = DataLoader(train_dataset, batch_sampler=train_batch_sampler, num_workers=0)  # batch_sampler
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# get embedding dimension from one encoded sample

sample_embedding, _ = train_dataset[0]
embedding_dim = sample_embedding.shape[0]

print(f"Embedding dimension: {embedding_dim}")

# training loop

In [None]:
# training loop

def train_loop():
    """
    Train the MLP classifier with early stopping and evaluate it on the test set.

    Reads the module-level `embedding_dim`, `weights`, `device`, `train_loader`,
    `val_loader`, `test_loader` and `hf_model`. After every epoch the model is
    validated; the checkpoint with the best validation F1 is written to
    'best_model.pth', and training stops once F1 has not improved for
    `patience` consecutive epochs. The best checkpoint is then reloaded and
    scored on the test set.

    Returns:
        tuple: (model, hf_model, train_loader, val_loader, test_loader,
        train_losses, val_losses, train_acc, val_acc, test_labels, test_preds).
        `train_losses`/`val_losses` are per-epoch lists; `train_acc`/`val_acc`
        are the accuracies of the last epoch that ran; `test_labels` holds the
        true test labels and `test_preds` the model's predictions.
    """

    # Initialize model
    model = MLPClassifier(input_dim=embedding_dim).to(device)

    # Initialize loss function (class-weighted to counter imbalance)
    criterion = nn.CrossEntropyLoss(weight=weights)

    num_epochs = 25
    patience = 5  # stop if no improvement after 5 epochs
    patience_counter = 0
    # Start below any reachable F1 so the first epoch always writes a checkpoint;
    # otherwise a run whose F1 stays at 0 would reach torch.load with no (or a stale) file.
    best_val_f1 = float('-inf')

    # Initialize optimizer
    optimizer = optim.AdamW(model.parameters(), lr=0.01, weight_decay=0.05)
    # scheduler.step() is called once per epoch below, so the schedule length is
    # counted in epochs. Sizing it in batches would leave the LR almost constant.
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_epochs)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    train_losses, val_losses = [], []

    print("\nTraining...")

    for epoch in range(num_epochs):

        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # 1. Train
        # Both epoch functions return (preds, labels, loss, recall, precision, f1, accuracy).
        (train_pred, train_label, train_loss,
         tr_recall, tr_precision, tr_f1, train_acc) = train_epoch(model, train_loader, criterion, optimizer, device)

        # 2. Validate
        (val_pred, val_label, val_loss,
         v_recall, v_precision, v_f1, val_acc) = validate_epoch(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Train Recall: {tr_recall:.4f}")
        print(f"Train F1: {tr_f1:.4f}")
        print(f"Train Precision: {tr_precision:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"Val Recall: {v_recall:.4f}")
        print(f"Val F1: {v_f1:.4f}")
        print(f"Val Precision: {v_precision:.4f}")

        # 3. Update learning rate
        scheduler.step()

        # 4. Save metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # 5. Save best model based on F1 score and reset patience counter
        if v_f1 > best_val_f1:
            best_val_f1 = v_f1
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"New best validation F1: {best_val_f1:.4f}")
            patience_counter = 0

        else:

            patience_counter += 1
            print(f"No improvement. Early stopping counter: {patience_counter}/{patience}")

            if patience_counter >= patience:
                print("Early stopping activated")
                break

    # 6. Load best model
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    print("\nEvaluating on test set...")

    # 7. Test (same return order as above: preds first, then true labels)
    (test_preds, test_labels, test_loss,
     test_recall, test_precision, test_f1, test_acc) = validate_epoch(model, test_loader, criterion, device)

    print(f"\nTest Results:")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.2f}%")
    print(f"Test Recall: {test_recall:.4f}")
    print(f"Test F1: {test_f1:.4f}")
    print(f"Test Precision: {test_precision:.4f}")

    print("\nClassification Report:")
    print(classification_report(test_labels, test_preds, labels=[0,1], target_names=['Class 0', 'Class 1'], zero_division=0))

    cm = confusion_matrix(test_labels, test_preds)
    print("\nConfusion Matrix Test Set:")
    print(cm)

    # NOTE(review): train_acc / val_acc are last-epoch scalars, not per-epoch
    # histories like the loss lists; confirm this is what callers want.
    return model, hf_model, train_loader, val_loader, test_loader, train_losses, val_losses, train_acc, val_acc, test_labels, test_preds

In [None]:
# seed and run
torch.manual_seed(42)
trained_model, hf_model, train_loader, val_loader, test_loader, train_losses, val_losses, train_acc, val_acc, test_labels, test_preds = train_loop()