In [42]:
import os
import time
import shutil
import random
import re
from typing import Tuple
from argparse import Namespace

import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer
from sklearn.metrics import accuracy_score
from tqdm import tqdm

# Set seeds
seed = 1111
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True  # Ensure reproducibility


In [74]:
# Configuration 
tokenizer = BertTokenizer.from_pretrained('dccuchile/bert-base-spanish-wwm-uncased')
args = Namespace(
    emb_size=200,
    num_layers=5,
    n_heads=5,
    head_size=50,
    vocab_size=tokenizer.vocab_size,
    max_seq_len=30,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    batch_size=16,
    lr=3e-3,
    num_epochs=100,
    patience=10,
    lr_patience=10,
    lr_factor=0.5,
    savedir='model'
)
os.makedirs(args.savedir, exist_ok=True)


### Data 

In [75]:
# Load and preprocess data
def load_data(file_path):
    try:
        return pd.read_csv(file_path, sep='\r\n', engine='python', header=None).loc[:, 0].values.tolist()
    except Exception as e:
        print(f"Error loading data: {e}")
        return []

def preprocess_tweet(tweet):
    tweet = re.sub(r'http\S+', '', tweet)
    tweet = re.sub(r'@\S+', '', tweet)
    tweet = re.sub(r'#\S+', '', tweet)
    tweet = tweet.lower()
    tweet = re.sub(r'\W', ' ', tweet)
    tweet = re.sub(r'\s+', ' ', tweet).strip()
    return tweet

X_train = [preprocess_tweet(tweet) for tweet in load_data('./data_mex20/mex20_train.txt')]
X_val = [preprocess_tweet(tweet) for tweet in load_data('./data_mex20/mex20_val.txt')]
y_train = np.array(load_data('./data_mex20/mex20_train_labels.txt')).reshape(-1)
y_val = np.array(load_data('./data_mex20/mex20_val_labels.txt')).reshape(-1)


In [76]:
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.encodings = tokenizer(texts, add_special_tokens=True, return_tensors='pt',
                                   truncation=True, max_length=max_length, padding='max_length',
                                   return_attention_mask=True)
        self.labels = torch.tensor(labels, dtype=torch.float)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)


# Create datasets    
train_dataset = TextDataset(X_train, y_train, tokenizer, args.max_seq_len)
val_dataset = TextDataset(X_val, y_val, tokenizer, args.max_seq_len)

# Create dataloaders 
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size)


### Model 

In [77]:
class Attention(nn.Module):
    def __init__(self, emb_size, head_size):
        super().__init__()
        self.key = nn.Linear(emb_size, head_size, bias=False)
        self.query = nn.Linear(emb_size, head_size, bias=False)
        self.value = nn.Linear(emb_size, head_size, bias=False)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, mask=None):
        k, q, v = self.key(x), self.query(x), self.value(x)
        scores = torch.matmul(q, k.transpose(-2, -1)) * (k.shape[-1] ** -0.5)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(attn, v)

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, head_size, emb_size):
        super().__init__()
        self.heads = nn.ModuleList([Attention(emb_size, head_size) for _ in range(n_heads)])
        self.proj = nn.Linear(n_heads * head_size, emb_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, mask=None):
        x = torch.cat([h(x, mask) for h in self.heads], dim=-1)
        return self.dropout(self.proj(x))

class FeedForward(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_size, 4 * emb_size),
            nn.ReLU(),
            nn.Linear(4 * emb_size, emb_size),
            nn.Dropout(0.2)
        )

    def forward(self, x):
        return self.net(x)

class TransformerBlock(nn.Module):
    def __init__(self, emb_size, n_heads, head_size):
        super().__init__()
        self.mha = MultiHeadAttention(n_heads, head_size, emb_size)
        self.ln1 = nn.LayerNorm(emb_size)
        self.ln2 = nn.LayerNorm(emb_size)
        self.ff = FeedForward(emb_size)

    def forward(self, x, mask=None):
        x = x + self.mha(self.ln1(x), mask)
        return x + self.ff(self.ln2(x))

class Transformer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.emb = nn.Embedding(args.vocab_size, args.emb_size)
        self.pos = nn.Embedding(args.max_seq_len, args.emb_size)
        self.blocks = nn.ModuleList([TransformerBlock(args.emb_size, args.n_heads, args.head_size) for _ in range(args.num_layers)])
        self.ln_f = nn.LayerNorm(args.emb_size)
        self.lm_head = nn.Linear(args.emb_size * args.max_seq_len, 1)
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, mask=None):
        B, T = idx.shape
        x = self.emb(idx) + self.pos(torch.arange(T, device=self.args.device))
        for block in self.blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        x = x.view(B, -1)
        return self.lm_head(x)



In [49]:
# item = next(iter(train_loader))
# ids = item['input_ids']
# mask = item['attention_mask']
# labels = item['labels']
# item

In [78]:
def get_preds(raw_logit):
    return torch.sigmoid(raw_logit)

def model_eval(model, data, device, loss_fn):
    model.eval()
    val_loss, correct = 0, 0
    with torch.no_grad():
        for item in data:
            ids, mask, labels = item['input_ids'].to(device), item['attention_mask'].to(device), item['labels'].to(device)
            outputs = get_preds(model(ids, mask))
            loss = loss_fn(outputs.view(-1), labels)
            val_loss += loss.item()
            preds = (outputs.view(-1) > 0.5).float()
            correct += (preds == labels).sum().item()
    val_loss /= len(data.dataset)
    accuracy = correct / len(data.dataset)
    return val_loss, accuracy

def save_checkpoint(state, is_best, checkpoint_path, filename="checkpoint.pt"):
    filename = os.path.join(checkpoint_path, filename)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(checkpoint_path, "model_best.pt"))

model = Transformer(args).to(args.device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=args.lr_patience, factor=args.lr_factor)

start_time = time.time()
best_metric = 0
n_no_improve = 0
train_loss_history, train_metric_history = [], []
val_loss_history, val_metric_history = [], []

for epoch in range(args.num_epochs):
    model.train()
    train_loss_epoch, correct = 0, 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.num_epochs}")
    for item in loop:
        ids, mask, labels = item['input_ids'].to(args.device), item['attention_mask'].to(args.device), item['labels'].to(args.device)
        optimizer.zero_grad()
        outputs = get_preds(model(ids, mask))
        loss = criterion(outputs.view(-1), labels)
        loss.backward()
        optimizer.step()
        train_loss_epoch += loss.item()
        preds = (outputs.view(-1) > 0.5).float()
        correct += (preds == labels).sum().item()
        loop.set_postfix(train_loss=train_loss_epoch/len(train_loader), train_accuracy=correct/len(train_loader.dataset))

    train_loss = train_loss_epoch / len(train_loader)
    train_accuracy = correct / len(train_loader.dataset)
    train_loss_history.append(train_loss)
    train_metric_history.append(train_accuracy)

    val_loss, val_accuracy = model_eval(model, val_loader, args.device, criterion)
    val_loss_history.append(val_loss)
    val_metric_history.append(val_accuracy)

    scheduler.step(val_loss)

    is_improvement = val_accuracy > best_metric
    if is_improvement:
        best_metric = val_accuracy
        n_no_improve = 0
    else:
        n_no_improve += 1

    save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 
                     'scheduler': scheduler.state_dict(), 'best_metric': best_metric}, is_improvement, args.savedir)

    if n_no_improve >= args.patience:
        print("No improvement. Breaking out of loop.")
        break

    print(f'Epoch [{epoch+1}/{args.num_epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}, Time: {time.time() - start_time:.2f}s')

print(f"Total Training Time: {time.time() - start_time:.2f} seconds")


Epoch 1/100: 100%|██████████| 330/330 [01:46<00:00,  3.10it/s, train_accuracy=0.698, train_loss=0.726]   


Epoch [1/100], Train Loss: 0.7263, Train Acc: 0.6980, Val Loss: 0.0413, Val Acc: 0.7240, Time: 108.96s


Epoch 2/100: 100%|██████████| 330/330 [02:29<00:00,  2.21it/s, train_accuracy=0.816, train_loss=0.456]  


Epoch [2/100], Train Loss: 0.4563, Train Acc: 0.8158, Val Loss: 0.0364, Val Acc: 0.7530, Time: 260.67s


Epoch 3/100: 100%|██████████| 330/330 [02:47<00:00,  1.96it/s, train_accuracy=0.873, train_loss=0.306]  


Epoch [3/100], Train Loss: 0.3064, Train Acc: 0.8727, Val Loss: 0.0275, Val Acc: 0.8058, Time: 431.11s


Epoch 4/100: 100%|██████████| 330/330 [02:45<00:00,  1.99it/s, train_accuracy=0.907, train_loss=0.238]   


Epoch [4/100], Train Loss: 0.2379, Train Acc: 0.9072, Val Loss: 0.0290, Val Acc: 0.8058, Time: 599.27s


Epoch 5/100: 100%|██████████| 330/330 [02:56<00:00,  1.87it/s, train_accuracy=0.923, train_loss=0.183]   


Epoch [5/100], Train Loss: 0.1830, Train Acc: 0.9233, Val Loss: 0.0361, Val Acc: 0.7853, Time: 778.49s


Epoch 6/100: 100%|██████████| 330/330 [02:42<00:00,  2.03it/s, train_accuracy=0.948, train_loss=0.134]   


Epoch [6/100], Train Loss: 0.1342, Train Acc: 0.9483, Val Loss: 0.0420, Val Acc: 0.7888, Time: 943.43s


Epoch 7/100: 100%|██████████| 330/330 [03:04<00:00,  1.79it/s, train_accuracy=0.951, train_loss=0.116]   


Epoch [7/100], Train Loss: 0.1165, Train Acc: 0.9509, Val Loss: 0.0425, Val Acc: 0.7922, Time: 1130.13s


Epoch 8/100: 100%|██████████| 330/330 [02:43<00:00,  2.02it/s, train_accuracy=0.959, train_loss=0.109]   


Epoch [8/100], Train Loss: 0.1092, Train Acc: 0.9591, Val Loss: 0.0437, Val Acc: 0.7888, Time: 1295.65s


Epoch 9/100: 100%|██████████| 330/330 [02:45<00:00,  1.99it/s, train_accuracy=0.968, train_loss=0.0865]  


Epoch [9/100], Train Loss: 0.0865, Train Acc: 0.9676, Val Loss: 0.0487, Val Acc: 0.7530, Time: 1463.66s


Epoch 10/100: 100%|██████████| 330/330 [02:37<00:00,  2.10it/s, train_accuracy=0.973, train_loss=0.0714]  


Epoch [10/100], Train Loss: 0.0714, Train Acc: 0.9735, Val Loss: 0.0570, Val Acc: 0.7649, Time: 1623.57s


Epoch 11/100: 100%|██████████| 330/330 [02:39<00:00,  2.07it/s, train_accuracy=0.978, train_loss=0.059]   


Epoch [11/100], Train Loss: 0.0590, Train Acc: 0.9782, Val Loss: 0.0602, Val Acc: 0.7615, Time: 1785.20s


Epoch 12/100: 100%|██████████| 330/330 [02:31<00:00,  2.18it/s, train_accuracy=0.973, train_loss=0.079]   


Epoch [12/100], Train Loss: 0.0790, Train Acc: 0.9729, Val Loss: 0.0538, Val Acc: 0.7683, Time: 1939.99s


Epoch 13/100: 100%|██████████| 330/330 [25:12<00:00,  4.58s/it, train_accuracy=0.974, train_loss=0.0681]    


No improvement. Breaking out of loop.
Total Training Time: 3454.96 seconds
