# Test - Train Model

In [2]:
# Import packages
from datasets import Dataset
import pandas as pd
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from transformers import CamembertTokenizer, CamembertForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback, CamembertConfig, CamembertModel
import torch
import torch.nn as nn

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, precision_recall_fscore_support
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


### Helper functions

In [3]:
def preprocess(data):
    """Tokenize one example and attach its multi-label target vector.

    Written to be used as the per-example function of ``Dataset.map``.

    Args:
        data: a single example (mapping) with a ``'text'`` entry and one entry
            per label column named in the module-level list ``lables``.

    Returns:
        The tokenizer output for ``data['text']`` (truncated, padded to
        ``max_length``) with an additional ``"labels"`` entry holding the
        example's label values in the order given by ``lables``.
    """
    # `tokenizer` and `lables` (sic) are module-level names assigned in later
    # cells; both must exist by the time this function is called.
    tokens = tokenizer(data['text'], truncation=True, padding="max_length")
    # Put the labels into the returned mapping rather than writing them into
    # the input example, so the new column does not rely on in-place mutation
    # of the object handed over by the caller.
    tokens["labels"] = [data[col] for col in lables]
    return tokens

### Load datasets

In [4]:
# Load classified dataset
data = pd.read_excel("../classified_titles_4.xlsx")

# Discard rows that have neither a title nor a lead. This must happen before
# the concatenation below: fillna('') makes 'text' non-null for every row, so
# a dropna on 'text' afterwards could never remove anything.
data = data.dropna(subset=["Title", "Lead_posts"], how="all")

# Model input: title and lead joined into a single string.
data['text'] = data["Title"].fillna('') + '. ' + data["Lead_posts"].fillna('')

# Fresh 0..n-1 index so that later positional and label-based lookups agree.
data = data.reset_index(drop=True)
data.head()

Unnamed: 0,Title,Lead_posts,Link,Source,triggers,Date,Death,police,Suisse,All,text
0,Une intervention de police tourne au drame à L...,Un Sri Lankais abattu. Légitime défense invoquée.,Une intervention de police tourne au drame à L...,Letemps,"{drame: mort, intervention police: police, lég...",14 avril 2004,1,1,1,1,Une intervention de police tourne au drame à L...
1,La mort du requérant gambien en prison met en ...,"A la suite d’une erreur sur la personne, un re...",La mort du requérant gambien en prison met en ...,Letemps,"{mort: mort, cellule: police, prison: police}",2017-10-31 00:00:00,1,1,1,1,La mort du requérant gambien en prison met en ...
2,Un policier genevois tire sur des voleurs en f...,"Au centre-ville, un agent de police fait feu s...",https://www.letemps.ch/suisse/un-policier-gene...,Letemps,"{faire feu: mort, un mort: mort, agent de poli...",2000-01-12 00:00:00,1,1,1,1,Un policier genevois tire sur des voleurs en f...
3,La mort de Mike Ben Peter revient hanter la ju...,"Acquittés en première instance, les six polici...",https://www.letemps.ch/suisse/vaud/la-mort-de-...,Letemps,"{mort: mort, policiers: police, homicide: mort}",2024-07-01 00:00:00,1,1,1,1,La mort de Mike Ben Peter revient hanter la ju...
4,Une jeune femme meurt au poste de police,La personne née en 2003 a été retrouvée inanim...,Une jeune femme meurt dans un poste de police ...,TdG,"{meurt: mort, poste de police: police, inanimé...",22.02.2024,1,1,1,1,Une jeune femme meurt au poste de police. La p...


In [5]:
# Keep only the model input and the three target columns.
data = data[['text', 'Death', 'police', 'All']]

# X: the text column; y: label matrix with one column per target, in the
# order Death, police, All. Both are the inputs of the stratified split.
X = data.loc[:, 'text']
y = data[['Death', 'police', 'All']].to_numpy()

### Split the train-test dataset

In [6]:
# One shuffle split holding out 20% of the rows; the fixed random_state keeps
# the partition the same across runs.
msss = MultilabelStratifiedShuffleSplit(
    n_splits=1,
    test_size=0.2,
    random_state=42,
)

In [7]:
# The splitter was built with n_splits=1, so take its single split directly
# instead of looping over it.
train_index, test_index = next(msss.split(X, y))

# The splitter yields positional indices. X is a pandas Series, so select with
# .iloc to make the lookup positional regardless of the Series' index labels;
# y is a NumPy array and is indexed positionally already.
X_train, X_test = X.iloc[train_index], X.iloc[test_index]
y_train, y_test = y[train_index], y[test_index]

In [8]:
# Rows of each partition, selected by position.
train_df = data.iloc[train_index]
test_df = data.iloc[test_index]

# reset_index(drop=True) gives each frame a default RangeIndex before the
# conversion, so no index column ("__index_level_0__") ends up in the
# resulting Dataset and nothing has to be dropped beforehand.
train_dataset = Dataset.from_pandas(train_df.reset_index(drop=True))
test_dataset = Dataset.from_pandas(test_df.reset_index(drop=True))

### Assign different weights to the labels

In [9]:
class CamemBERT_adjustedWeight(CamembertForSequenceClassification):
    """CamemBERT sequence classifier trained with a per-label weighted BCE loss.

    The loss is ``nn.BCEWithLogitsLoss`` with ``pos_weight`` set to
    ``label_weights`` (one weight per output label).

    Args:
        config: model configuration, forwarded to the parent constructor.
        label_weights: optional per-label positive-class weights (tensor or
            anything ``torch.as_tensor`` accepts). When omitted, an unweighted
            BCE loss is used, which also allows the class to be constructed
            without the training-time weights (e.g. for inference).
    """

    def __init__(self, config, label_weights=None):
        super().__init__(config)
        if label_weights is not None:
            label_weights = torch.as_tensor(label_weights, dtype=torch.float)
        # Kept as a plain attribute (not a parameter/buffer); it is moved to
        # the logits' device on every forward pass.
        self.label_weights = label_weights

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute logits and, when ``labels`` is given, the weighted BCE loss.

        Args:
            input_ids: token ids passed to the encoder.
            attention_mask: attention mask passed to the encoder.
            labels: optional multi-hot targets with the same shape as the
                logits; converted to the logits' dtype before the loss.
            **kwargs: accepted for call compatibility and ignored.

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is provided,
            otherwise ``{"logits": logits}``.
        """
        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # First element of the encoder output is fed to the classification head.
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        if labels is None:
            return {"logits": logits}

        pos_weight = None
        if self.label_weights is not None:
            # Match device and dtype of the logits so the loss also works when
            # the model does not run in float32 on the CPU.
            pos_weight = self.label_weights.to(device=logits.device, dtype=logits.dtype)

        loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        loss = loss_fct(logits, labels.to(logits.dtype))
        return {"loss": loss, "logits": logits}


In [10]:
# Per-label positive-class weights for the loss.
# The order of `lables` (sic) must match the order of the entries in
# `label_weights`, because the label vector of each example is built by
# iterating over `lables`.
lables = ['Death', 'police', 'All']

n_rows = len(data)

# Share of rows positive for the combined 'All' label; the two single-label
# shares below are expressed relative to it.
share_all = int((data['All'] == 1).sum()) / n_rows

weight_death = int((data['Death'] == 1).sum()) / n_rows / share_all
weight_police = int((data['police'] == 1).sum()) / n_rows / share_all
# The combined label is the reference, hence a base weight of 1 ...
weight_police_death = 1

# ... which is then tripled in the final weight vector.
# NOTE(review): the shares are computed on the full `data` frame (train and
# test rows together), and more frequent labels receive larger weights here —
# confirm both are intended.
label_weights = torch.tensor([weight_death, weight_police, 3 * weight_police_death])

### Initialize model

In [11]:
# Tokenizer, configuration and model all start from the same pretrained checkpoint.
checkpoint = "camembert-base"

tokenizer = CamembertTokenizer.from_pretrained(checkpoint)

# Three output labels, declared as a multi-label classification problem.
config = CamembertConfig.from_pretrained(
    checkpoint,
    num_labels=3,
    problem_type="multi_label_classification",
)

# label_weights is the extra constructor argument of CamemBERT_adjustedWeight.
model = CamemBERT_adjustedWeight.from_pretrained(
    checkpoint,
    config=config,
    label_weights=label_weights,
)


Some weights of CamemBERT_adjustedWeight were not initialized from the model checkpoint at camembert-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Training configuration
training_args = TrainingArguments(
    # Output locations for checkpoints and logs.
    output_dir="shelfens/results_frenchtitles",
    logging_dir="shelfens/logs",
    logging_steps=10,
    # Upper bound on the number of epochs.
    num_train_epochs=10,
    # Evaluate and save once per epoch (the two strategies are kept identical).
    # NOTE(review): newer transformers releases name the first argument
    # `eval_strategy` — confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # At the end of training, reload the checkpoint with the lowest eval_loss.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)



In [13]:
# Early stopping with a patience of 2 evaluations.
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=2,
)

In [14]:
# Apply preprocess to every example of both splits; the raw 'text' column is
# dropped from the mapped datasets.
tokenized_train, tokenized_test = (
    split.map(preprocess, remove_columns=["text"])
    for split in (train_dataset, test_dataset)
)

Map: 100%|██████████| 897/897 [00:01<00:00, 512.18 examples/s]
Map: 100%|██████████| 225/225 [00:00<00:00, 714.95 examples/s]


In [None]:
# Assemble the Trainer from the model, the training configuration, the
# tokenized splits and the early-stopping callback.
# NOTE(review): eval_dataset is the same tokenized_test split that is later
# passed to trainer.predict for the reported metrics, so best-model selection
# and final evaluation share their data — confirm this is acceptable.
trainer = Trainer(
    args=training_args,
    model=model,
    callbacks=[early_stopping_callback],
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
)

: 

In [None]:
# Run fine-tuning with the Trainer built above. The call is left as the cell's
# last expression, so the notebook displays its return value.
trainer.train()

In [None]:
# Write the model and the tokenizer into the same directory.
# NOTE(review): the output recorded for this cell lists paths under
# 'shelfens/'; the directory used here has no such prefix — confirm which
# location is intended.
save_dir = "camembert_multiple_classifier"
trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)

('shelfens/camembert_multiple_classifier\\tokenizer_config.json',
 'shelfens/camembert_multiple_classifier\\special_tokens_map.json',
 'shelfens/camembert_multiple_classifier\\sentencepiece.bpe.model',
 'shelfens/camembert_multiple_classifier\\added_tokens.json')

### Evaluation of model

In [None]:
# Run the trained model over the tokenized test split.
predictions = trainer.predict(tokenized_test)

# Logits -> per-label probabilities (sigmoid) -> 0/1 decisions at a 0.5 cut-off.
y_logits = predictions.predictions
y_probs = torch.sigmoid(torch.tensor(y_logits)).numpy()
y_pred = np.array((y_probs >= 0.5).astype(int))

# Reference labels returned alongside the predictions, as a NumPy array.
y_true = np.array(predictions.label_ids)

In [None]:
# Display names of the output columns, in the same order as the columns of
# y_true / y_pred.
labels = ["Death", "Police", "All"]

# One confusion matrix per label, each saved as its own PNG.
for i, label in enumerate(labels):
    # Fix the class set to {0, 1} so the matrix is always 2x2, even when a
    # column of y_true / y_pred happens to contain a single class only.
    cm = confusion_matrix(y_true[:, i], y_pred[:, i], labels=[0, 1])
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap="Blues")
    # Work on the display's own axes/figure instead of pyplot's "current"
    # ones, so title, file and close always refer to this plot.
    disp.ax_.set_title(f"Confusion Matrix for {label}")
    disp.figure_.savefig(f"shelfens/conf_matrix_{label}.png", dpi=300)
    plt.close(disp.figure_)

In [None]:
# Per-label precision, recall and F1 (average=None keeps one value per label).
prec, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None)

# Rows: labels; columns: the three metrics.
metrics_df = pd.DataFrame(
    {"Precision": prec, "Recall": recall, "F1-score": f1},
    index=["Death", "Police", "All"],
)

# Heatmap of the transposed table (metrics as rows, labels as columns),
# saved to disk and shown inline.
fig, ax = plt.subplots(figsize=(8, 4))
sns.heatmap(metrics_df.T, annot=True, fmt=".2f", cmap="YlGnBu", ax=ax)
ax.set_title("Precision, Recall, F1 by Class")
ax.set_xlabel("Class")
ax.set_ylabel("Metric")
fig.tight_layout()
fig.savefig("shelfens/metrics_heatmap.png", dpi=300)
plt.show()