# Entrenamiento de Modelo de Clasificación de Toxicidad
Este notebook entrena un modelo de clasificación binaria de texto (tóxico vs. no tóxico), basado en `bert-base-multilingual-cased`, con un dataset sintético generado a partir de intervenciones parlamentarias.

In [None]:
# 📦 Cargar librerías necesarias
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from transformers import DataCollatorWithPadding
import torch
from datasets import Dataset

In [None]:
# 📂 Load the dataset
df = pd.read_csv("../data/processed/dataset_diverso_10000.csv")

# Only rows without text or label are unusable; a NaN in any other column
# (if the CSV has extra columns) should not throw a training example away.
df = df.dropna(subset=["text", "label"])

# Binary target: 1 = "toxico", 0 = any other label. Labels are normalised
# first (surrounding whitespace, case, accents) so that variants such as
# "Tóxico " are not silently encoded as 0.
etiquetas_normalizadas = (
    df["label"]
    .str.strip()
    .str.lower()
    .str.normalize("NFKD")                 # "ó" -> "o" + combining accent
    .str.encode("ascii", errors="ignore")  # drop the combining accents
    .str.decode("ascii")
)
df["label"] = (etiquetas_normalizadas == "toxico").astype(int)

# Fail loudly instead of training on a single class (e.g. when the label
# spelling in the CSV does not match "toxico" at all).
if df["label"].nunique() != 2:
    raise ValueError(
        "Se esperaban dos clases (toxico / no toxico) tras codificar 'label'; "
        f"distribución obtenida: {df['label'].value_counts().to_dict()}"
    )

df.head()

In [None]:
# 🔀 Split the dataset (80% train / 20% test).
# Stratifying on the label keeps the toxic / non-toxic ratio identical in both
# splits, so the per-class metrics of the evaluation cell are comparable.
train_df, test_df = train_test_split(
    df, test_size=0.2, random_state=42, stratify=df["label"]
)
# train_test_split shuffles rows, so the pandas index is no longer a
# RangeIndex; preserve_index=False stops `datasets` from storing it as an
# extra "__index_level_0__" column.
train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
test_dataset = Dataset.from_pandas(test_df, preserve_index=False)

In [None]:
# 🧠 Prepare the tokenizer (the classification model is created in a later cell)
# Pretrained checkpoint name; the tokenizer must match the model that is fine-tuned.
modelo_base = "bert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(modelo_base)

def tokenize_function(example):
    """Tokenize one batch of rows for ``Dataset.map(..., batched=True)``.

    ``example["text"]`` holds the texts of the whole batch. Sequences longer
    than the tokenizer's maximum length are truncated; no padding is requested
    here.
    """
    textos_del_lote = example["text"]
    codificacion = tokenizer(textos_del_lote, truncation=True)
    return codificacion

# Apply the tokenizer batch-wise to both splits (train first, then test).
tokenized_train, tokenized_test = (
    particion.map(tokenize_function, batched=True)
    for particion in (train_dataset, test_dataset)
)

In [None]:
# 🏗️ Build the model and the Hugging Face `Trainer` used by the next cells.
# The previous version of this cell had several problems:
#   - it looped over `train_dataloader`, which is never defined (NameError);
#   - it imported `transformers.AdamW`, which is deprecated in favour of
#     `torch.optim.AdamW` and removed in newer releases;
#   - it never moved batches to a device;
#   - it never created the `trainer` that the training and evaluation cells call.
# Trainer takes care of batching, device placement, the AdamW optimizer and
# the training loop.
model = AutoModelForSequenceClassification.from_pretrained(
    modelo_base,  # same checkpoint as the tokenizer, not a repeated literal
    num_labels=2,
    # Matches the encoding of the loading cell: 1 = "toxico", 0 = anything else.
    id2label={0: "no_toxico", 1: "toxico"},
    label2id={"no_toxico": 0, "toxico": 1},
)

# tokenize_function only truncates, so pad each batch dynamically to its
# longest sequence; the collator also renames "label" to "labels".
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

training_args = TrainingArguments(
    output_dir="../models/toxicidad_bert",
    num_train_epochs=3,              # same number of epochs as the old loop
    learning_rate=2e-5,              # same learning rate as the old optimizer
    per_device_train_batch_size=16,  # NOTE: arbitrary choice; tune to GPU memory
    per_device_eval_batch_size=32,
    save_strategy="no",              # the old loop saved no checkpoints either
    report_to="none",                # no external experiment trackers
    seed=42,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    data_collator=data_collator,
)


In [None]:
# 🚀 Train
# Runs fine-tuning through the `trainer` object. This cell assumes an earlier
# cell has built it (presumably a transformers `Trainer`, which is imported at
# the top of the notebook). As the last expression of the cell, the returned
# value is displayed by Jupyter as the cell output.
trainer.train()

In [None]:
# 📊 Evaluate on the held-out test split
preds = trainer.predict(tokenized_test)
# Gold labels of the test split, in the same order as the predictions.
y_true = preds.label_ids
# Predicted class = column with the highest score in each row of predictions.
y_pred = preds.predictions.argmax(axis=1)
print(classification_report(y_true=y_true, y_pred=y_pred))