### Nous allons ici faire quelques applications d'augmentation des données textuelles en suivant un tutoriel sur medium 

In [None]:
!pip install nlpaug
!pip install sacremoses

In [1]:
import nlpaug.augmenter.word as naw 


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## Un exemple de texte qu'on veut reformuler 
text = "The quick brown fox jumps over a lazy dog"

In [4]:
import os, nltk

# 1) Dossier local et inscriptible pour les données NLTK
NLTK_DIR = os.path.expanduser("~/nltk_data")
os.makedirs(NLTK_DIR, exist_ok=True)
if NLTK_DIR not in nltk.data.path:
    nltk.data.path.append(NLTK_DIR)

# 2) Téléchargements nécessaires
# - wordnet + omw-1.4 : pour les synonymes
# - averaged_perceptron_tagger_eng : POS tagger (NLTK 3.8+)
# - averaged_perceptron_tagger : par compatibilité descendante (certaines libs le demandent encore)
# - punkt : tokenisation de base (utile selon les tokenizers)
for pkg in ["wordnet", "omw-1.4", "averaged_perceptron_tagger_eng",
            "averaged_perceptron_tagger", "punkt"]:
    try:
        nltk.download(pkg, download_dir=NLTK_DIR, quiet=True)
    except Exception as e:
        print(f"NLTK download failed for {pkg}: {e}")


#### Synonym Replacement 

In [5]:
syn_aug = naw.synonym.SynonymAug(aug_src="wordnet")
synonym_text = syn_aug.augment(text)
print("Synonym Text: ", synonym_text)

Synonym Text:  ['The quick robert brown fox bound over a lazy wienerwurst']


#### Random Substitution

In [6]:
sub_aug = naw.random.RandomWordAug(action='substitute')
substituted_text = sub_aug.augment(text)
print("Substituted Text: ", substituted_text)

Substituted Text:  ['_ quick brown _ jumps over a lazy _']


### Random Deletion

In [7]:
del_aug = naw.random.RandomWordAug(action='delete')
deletion_text = del_aug.augment(text)
print("Deletion Text: ", deletion_text)

Deletion Text:  ['The jumps over a lazy dog']


### Random Swap

In [8]:
swap_aug = naw.random.RandomWordAug(action='swap')
swap_text = swap_aug.augment(text)
print("Swap Text: ", swap_text)

Swap Text:  ['Quick the brown jumps fox over lazy a dog']


### Back Translation

Translate original text to other language (like french) and convert back to english language

In [9]:
back_trans_aug = naw.back_translation.BackTranslationAug()
back_trans_text = back_trans_aug.augment(text)
print("Back Translated Text: ", back_trans_text)

Back Translated Text:  ['The speedy brown fox jumps over a lazy dog']


### Nous allons appliquer la Rétrotraduction pour former notre premier jeu de données augmenté 
Nous allons appliquer cela sur les données de texte brute ensuite on fera encore le nettoyage, nous allons appliquer l'augmentation uniquement pour les articles de type VS qui est sous représenté

In [23]:
import os, nltk

# Dossier local pour les données NLTK (avec droits d’écriture)
NLTK_DIR = os.path.expanduser("~/nltk_data")
os.makedirs(NLTK_DIR, exist_ok=True)
if NLTK_DIR not in nltk.data.path:
    nltk.data.path.append(NLTK_DIR)

# Paquets requis (NLTK 3.8+)
for pkg in [
    "punkt",                         # tokeniseur historique
    "punkt_tab",                     # depuis NLTK 3.8
    "stopwords",
    "wordnet", "omw-1.4",
    "averaged_perceptron_tagger_eng",
    "averaged_perceptron_tagger",    # compat descente
]:
    try:
        nltk.download(pkg, download_dir=NLTK_DIR, quiet=True)
    except Exception as e:
        print(f"NLTK download failed for {pkg}: {e}")


In [None]:
import pandas as pd
import nlpaug.augmenter.word as naw
from typing import Optional, List
import torch, re
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from string import punctuation
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"


# ==== Init NLTK stuff ====
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()
custom_stopwords = {'et', 'al'}

# ==== Utils ====
def coerce_to_str(x) -> str:
    if isinstance(x, list):
        x = " ".join(map(str, x))
    elif pd.isna(x):
        x = ""
    return str(x).strip()

def safe_augment(aug, text: str, n: int = 1) -> List[str]:
    """Retourne une liste de n paraphrases (peut être < n si le modèle échoue)."""
    if not text:
        return []
    try:
        out = aug.augment(text, n=n)  # peut renvoyer str ou list[str]
        if isinstance(out, str):
            out = [out]
        return [t.strip() for t in out if isinstance(t, str) and t.strip()]
    except Exception:
        return []

def nettoyer_texte_tokens(texte: str) -> List[str]:
    tokens = word_tokenize(texte)
    tokens_nettoyes = []
    for token in tokens:
        token = token.lower()
        token = re.sub(r'\s+', '', token)
        token = re.sub(r'[^a-zàâçéèêëîïôûùüÿñæœ]', '', token)
        if token and token not in stop_words and token not in punctuation and token not in custom_stopwords:
            token = lemmatizer.lemmatize(token)
            tokens_nettoyes.append(token)
    return tokens_nettoyes

def tokens_to_text(tokens: List[str]) -> str:
    return " ".join(tokens)

# ==== Device & models ====
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device choisi :", device)

from_model, to_model = (
    ("facebook/wmt19-en-de", "facebook/wmt19-de-en") if device == "cuda"
    else ("Helsinki-NLP/opus-mt-en-de", "Helsinki-NLP/opus-mt-de-en")
)

back_trans_aug = naw.BackTranslationAug(
    from_model_name=from_model,
    to_model_name=to_model,
    device=device,
    batch_size=32 if device == "cuda" else 8,
    max_length=256
)

# ==== Load data ====
df = pd.read_csv("./data/data_final_phase2_private.csv")
df["text"] = df["text"].apply(coerce_to_str)

# (Re)crée text_clean / token_clean si manquants
if "token_clean" not in df.columns:
    df["token_clean"] = df["text"].apply(nettoyer_texte_tokens)
else:
    mask = df["token_clean"].isna()
    df.loc[mask, "token_clean"] = df.loc[mask, "text"].apply(nettoyer_texte_tokens)

if "text_clean" not in df.columns:
    df["text_clean"] = df["token_clean"].apply(tokens_to_text)
else:
    mask = df["text_clean"].isna()
    df.loc[mask, "text_clean"] = df.loc[mask, "token_clean"].apply(tokens_to_text)

# ==== Filtrer uniquement la classe VS ====
# (Robuste si tes labels sont 'VS'/'NVS' en str; adapte si 1/0)
mask_vs = df["type_article"].astype(str).str.upper().eq("VS")
df_vs = df[mask_vs].copy()

# === Paramètre: nombre d'augmentations par article VS ===
n_aug_per_sample = 1   # mets 2, 3… pour plus de paraphrases par texte VS

aug_rows = []
for _, row in df_vs.iterrows():
    raw = row["text"]
    aug_texts = safe_augment(back_trans_aug, raw[:2000], n=n_aug_per_sample)
    for aug_text in aug_texts:
        tokens_bt = nettoyer_texte_tokens(aug_text)
        aug_rows.append({
            "text_src": raw,
            "text_final": aug_text,
            "source": "bt",
            "token_clean": tokens_bt,
            "text_clean": tokens_to_text(tokens_bt),
            "type_article": row["type_article"],
            "thematique": row.get("thematique", "")
        })

aug_df = pd.DataFrame(aug_rows)

# ==== Bloc original (toutes les classes) ====
train_original = df.copy()
train_original["text_src"] = train_original["text"]
train_original["text_final"] = train_original["text"]
train_original["source"] = "orig"

# ==== Harmonisation & concat ====
cols = ["text_final", "text_clean", "token_clean", "source", "type_article", "thematique", "text_src"]

def ensure_list(x):
    if isinstance(x, list):
        return x
    if isinstance(x, str):
        return x.split()
    return []

train_original = train_original.reindex(columns=cols, fill_value="")
train_original["token_clean"] = train_original["token_clean"].apply(ensure_list)
train_original["text_clean"] = train_original["text_clean"].astype(str)

train_aug = aug_df.reindex(columns=cols, fill_value="")

train_data = pd.concat([train_original, train_aug], ignore_index=True)

print("Nb lignes original :", len(df))
print("Nb VS augmentées   :", len(aug_df))
print("Taille finale       :", train_data.shape)


#### Essayons ici de rattraper le nombre d'article de NVS 

In [None]:
import pandas as pd
import nlpaug.augmenter.word as naw
from typing import Optional, List
import torch, re
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from string import punctuation
import os
import math

os.environ["TOKENIZERS_PARALLELISM"] = "false"


# ==== Init NLTK stuff ====
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()
custom_stopwords = {'et', 'al'}

# ==== Utils ====
def coerce_to_str(x) -> str:
    if isinstance(x, list):
        x = " ".join(map(str, x))
    elif pd.isna(x):
        x = ""
    return str(x).strip()

def safe_augment(aug, text: str, n: int = 1) -> List[str]:
    """Retourne une liste de n paraphrases (peut être < n si le modèle échoue)."""
    if not text:
        return []
    try:
        out = aug.augment(text, n=n)  # peut renvoyer str ou list[str]
        if isinstance(out, str):
            out = [out]
        return [t.strip() for t in out if isinstance(t, str) and t.strip()]
    except Exception:
        return []

def nettoyer_texte_tokens(texte: str) -> List[str]:
    tokens = word_tokenize(texte)
    tokens_nettoyes = []
    for token in tokens:
        token = token.lower()
        token = re.sub(r'\s+', '', token)
        token = re.sub(r'[^a-zàâçéèêëîïôûùüÿñæœ]', '', token)
        if token and token not in stop_words and token not in punctuation and token not in custom_stopwords:
            token = lemmatizer.lemmatize(token)
            tokens_nettoyes.append(token)
    return tokens_nettoyes

def tokens_to_text(tokens: List[str]) -> str:
    return " ".join(tokens)

# ==== Device & models ====
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device choisi :", device)

from_model, to_model = (
    ("facebook/wmt19-en-de", "facebook/wmt19-de-en") if device == "cuda"
    else ("Helsinki-NLP/opus-mt-en-de", "Helsinki-NLP/opus-mt-de-en")
)

back_trans_aug = naw.BackTranslationAug(
    from_model_name=from_model,
    to_model_name=to_model,
    device=device,
    batch_size=32 if device == "cuda" else 8,
    max_length=256
)

# ==== Load data ====
df = pd.read_csv("./data/data_final_phase2_private.csv")
df["text"] = df["text"].apply(coerce_to_str)

# (Re)crée text_clean / token_clean si manquants
if "token_clean" not in df.columns:
    df["token_clean"] = df["text"].apply(nettoyer_texte_tokens)
else:
    mask = df["token_clean"].isna()
    df.loc[mask, "token_clean"] = df.loc[mask, "text"].apply(nettoyer_texte_tokens)

if "text_clean" not in df.columns:
    df["text_clean"] = df["token_clean"].apply(tokens_to_text)
else:
    mask = df["text_clean"].isna()
    df.loc[mask, "text_clean"] = df.loc[mask, "token_clean"].apply(tokens_to_text)

# ==== Filtrer uniquement la classe VS ====

mask_vs = df["type_article"].astype(str).str.upper().eq("VS")
df_vs = df[mask_vs].copy()



# --- Comptage ---
count_vs  = (df["type_article"].astype(str).str.upper() == "VS").sum()
count_nvs = (df["type_article"].astype(str).str.upper() == "NVS").sum()
print("NVS:", count_nvs, " VS:", count_vs)

target = count_nvs
needed = max(0, target - count_vs)
if needed == 0:
    print("Pas besoin d'augmentation: VS est déjà ≥ NVS.")

# --- Génération ---
aug_rows = []
if needed > 0:
    factor = math.ceil(target / count_vs)  # nb total de versions par article (original compris)
    print(f"Chaque article VS doit produire à peu près {factor} versions (dont l'original).")

    for _, row in df_vs.iterrows():
        raw = row["text"]
        aug_texts = safe_augment(back_trans_aug, raw[:2000], n=max(1, factor-1))
        # construire les lignes
        for aug_text in aug_texts:
            tokens_bt = nettoyer_texte_tokens(aug_text)
            aug_rows.append({
                "text_src": raw,
                "text_final": aug_text,
                "source": "bt",
                "token_clean": tokens_bt,
                "text_clean": tokens_to_text(tokens_bt),
                "type_article": row["type_article"],
                "thematique": row.get("thematique", "")
            })

# DataFrame des augmentées
aug_df = pd.DataFrame(aug_rows)

# --- Déduplication robuste ---
# 1) enlever les lignes vides/NaN
aug_df = aug_df[aug_df["text_final"].astype(str).str.strip().ne("")].dropna(subset=["text_final"])

# 2) dédupliquer par paraphrase (et étiquette) pour éviter redites exactes
aug_df = aug_df.drop_duplicates(subset=["text_final", "type_article"], keep="first")

# 3) si trop de lignes, échantillonner pour viser exactement "needed"
if len(aug_df) > needed:
    # Utilisons random_sate pour controler l'aléa 
    aug_df = aug_df.sample(n=needed, random_state=42).reset_index(drop=True)
elif len(aug_df) < needed:
    print(f"Seulement {len(aug_df)} paraphrases uniques générées, < needed={needed}.")
  

print("Articles VS générés (uniques) :", len(aug_df))
print("Total VS (original + aug) :", count_vs + len(aug_df))
print("Total NVS :", count_nvs)

# --- Bloc original & concat finale ---
train_original = df.copy()
train_original["text_src"]   = train_original["text"]
train_original["text_final"] = train_original["text"]
train_original["source"]     = "orig"

cols = ["text_final", "text_clean", "token_clean", "source", "type_article", "thematique", "text_src"]

def ensure_list(x):
    if isinstance(x, list): return x
    if isinstance(x, str):  return x.split()
    return []

train_original = train_original.reindex(columns=cols, fill_value="")
train_original["token_clean"] = train_original["token_clean"].apply(ensure_list)
train_original["text_clean"]  = train_original["text_clean"].astype(str)

train_aug = aug_df.reindex(columns=cols, fill_value="")

train_data = pd.concat([train_original, train_aug], ignore_index=True)
print("Nb lignes original :", len(df))
print("Nb VS augmentées   :", len(aug_df))
print("Taille finale       :", train_data.shape)


Device choisi : cuda
NVS: 2241  VS: 249
Chaque article VS doit produire à peu près 9 versions (dont l'original).


In [35]:
train_data

Unnamed: 0,text_final,text_clean,token_clean,source,type_article,thematique,text_src
0,Microbial Community Composition Associated wit...,microbial community composition associated pot...,"[microbial, community, composition, associated...",orig,VS,SV,Microbial Community Composition Associated wit...
1,Plant Pathogenic and Endophytic Colletotrichum...,plant pathogenic endophytic colletotrichum fru...,"[plant, pathogenic, endophytic, colletotrichum...",orig,VS,SV,Plant Pathogenic and Endophytic Colletotrichum...
2,Lethal Bronzing: What you should know about th...,lethal bronzing know disease turn palm tree br...,"[lethal, bronzing, know, disease, turn, palm, ...",orig,VS,SV,Lethal Bronzing: What you should know about th...
3,Leaffooted Bug Damage in Almond Orchards Leaff...,leaffooted bug damage almond orchard leaffoote...,"[leaffooted, bug, damage, almond, orchard, lea...",orig,VS,SV,Leaffooted Bug Damage in Almond Orchards Leaff...
4,Kebbi govt battles mysterious disease affectin...,kebbi govt battle mysterious disease affecting...,"[kebbi, govt, battle, mysterious, disease, aff...",orig,VS,SV,Kebbi govt battles mysterious disease affectin...
...,...,...,...,...,...,...,...
2734,Mystery Seed Packages Appearing Once Again in ...,mystery seed package appearing alabama mystery...,"[mystery, seed, package, appearing, alabama, m...",bt,VS,SV,Mystery Seed Packages Appearing Once Again in ...
2735,ACES: Mystery seed packages appeared again in ...,ace mystery seed package appeared alabama ace ...,"[ace, mystery, seed, package, appeared, alabam...",bt,VS,SV,ACES: Mystery seed packages appearing once aga...
2736,"Farmers blame the plague: 150,000 per bag Farm...",farmer blame plague per bag farmer blame plagu...,"[farmer, blame, plague, per, bag, farmer, blam...",bt,VS,SV,Farmers Blame Unknown Pest As Pepper Hits ₦150...
2737,Sharp drop in yield due to mysterious fungal i...,sharp drop yield due mysterious fungal infecti...,"[sharp, drop, yield, due, mysterious, fungal, ...",bt,VS,SV,Sharp decline in yield as mysterious fungal in...


In [36]:
train_data["text_final"].iloc[2735] == train_data["text_src"].iloc[2735]

False

In [37]:
train_data["type_article"].value_counts()

type_article
NVS    2241
VS      498
Name: count, dtype: int64

In [38]:
df["type_article"].value_counts()

type_article
NVS    2241
VS      249
Name: count, dtype: int64

#### Les textes augmentées par la rétro traduction ont une source bt 

In [39]:
## Sauvegardons les données augmentées dans un fichier csv 

train_data.to_csv("./data/data_augmented_back_traduction.csv", index= False)

###  Nous allons reprendre la classification avec le fine-tuning de SBERT que nous avons fait à la phase 3 

In [40]:
data = pd.read_csv("./data/data_augmented_back_traduction.csv")

In [42]:
# =========================
# 0) Imports
# =========================
import os, json, joblib
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    precision_recall_curve, average_precision_score,
    f1_score, precision_score, recall_score, accuracy_score
)

from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    Trainer, TrainingArguments, TrainerCallback
)

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# GPU / précision (Ada -> bf16 recommandé)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
use_fp16 = torch.cuda.is_available() and not use_bf16

try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass


# =========================
# 1) Données + indices sauvegardés
# =========================
idx_train = idx_train_all
idx_val   = idx_val_all
idx_test  = idx_test_all

X_all = data["text_final"].astype(str).to_numpy() # Nous lançons l'entrainnement sur les données brutes 
y_all = data["type_article"].to_numpy()

# Encodage labels
label_encoder = LabelEncoder()
y_enc = label_encoder.fit_transform(y_all)
num_classes = len(label_encoder.classes_)
print("Classes:", list(label_encoder.classes_))

# Splits
X_train, X_val, X_test = X_all[idx_train], X_all[idx_val], X_all[idx_test]
y_train, y_val, y_test = y_enc[idx_train], y_enc[idx_val], y_enc[idx_test]

# =========================
# 2) Tokenizer & modèle
# =========================
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_classes,
    torch_dtype=(torch.bfloat16 if use_bf16 else None),  # bf16
).to(device)


# =========================
# 3) Dataset PyTorch (robuste)
# =========================
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=256):
        if isinstance(texts, np.ndarray):
            texts = texts.tolist()
        self.texts = [
            "" if (t is None or (isinstance(t, float) and np.isnan(t))) else str(t)
            for t in texts
        ]
        self.encodings = tokenizer(self.texts, truncation=True, padding=True, max_length=max_length)
        self.labels = labels.tolist() if isinstance(labels, np.ndarray) else list(labels)
        assert len(self.texts) == len(self.labels)

    def __len__(self): 
        return len(self.labels)

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

train_ds = TextDataset(X_train, y_train, tokenizer)
val_ds   = TextDataset(X_val,   y_val,   tokenizer)
test_ds  = TextDataset(X_test,  y_test,  tokenizer)

# =========================
# 4) Metrics + Callback TrainEval
# =========================
def compute_metrics(eval_pred):
    # Compat: EvalPrediction (HF récent) ou tuple (anciens)
    if hasattr(eval_pred, "predictions"):
        logits = eval_pred.predictions
        labels = eval_pred.label_ids
    else:
        logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1_macro": f1_score(labels, preds, average="macro"),
        "precision_macro": precision_score(labels, preds, average="macro", zero_division=0),
        "recall_macro": recall_score(labels, preds, average="macro", zero_division=0),
    }

from transformers import TrainerCallback
import math, os

class ValEvalAndEarlyStopCallback(TrainerCallback):
    """
    - Évalue sur le set de validation à la fin de chaque epoch
    - Loggue les métriques 'eval_*' dans log_history
    - Early stopping custom basé sur training_args.metric_for_best_model et greater_is_better
    - Sauvegarde le meilleur modèle dans output_dir/best_checkpoint
    """
    def __init__(self, trainer, patience=3):
        self.trainer = trainer
        self.patience = patience
        self.best_metric = None
        self.best_epoch = None
        self.bad_epochs = 0
        self.best_dir = os.path.join(str(trainer.args.output_dir), "best_checkpoint")
        os.makedirs(self.best_dir, exist_ok=True)

        # lit la config d'arrêt
        self.metric_name = trainer.args.metric_for_best_model or "eval_loss"
        self.greater_is_better = bool(trainer.args.greater_is_better)

    def _is_better(self, current, best):
        if best is None:
            return True
        return (current > best) if self.greater_is_better else (current < best)

    def on_epoch_end(self, args, state, control, **kwargs):
        # 1) Évaluer VAL
        try:
            metrics = self.trainer.evaluate(
                eval_dataset=self.trainer.eval_dataset,
                metric_key_prefix="eval"
            )
        except TypeError:
            metrics = self.trainer.evaluate(eval_dataset=self.trainer.eval_dataset)

        # 2) Ajouter l'epoch et logger explicitement
        if state.epoch is not None:
            metrics["epoch"] = float(state.epoch)
        self.trainer.log({k: float(v) for k, v in metrics.items()
                          if isinstance(v, (int, float, np.floating))})

        # 3) Early stopping custom
        current = metrics.get(self.metric_name, None)
        if current is None or (isinstance(current, float) and math.isnan(current)):
            # rien à faire si la métrique n'est pas là
            return control

        if self._is_better(current, self.best_metric):
            # Amélioration -> reset patience + save best
            self.best_metric = current
            self.best_epoch = int(round(state.epoch)) if state.epoch is not None else None
            self.bad_epochs = 0
            # Sauvegarder le meilleur modèle
            self.trainer.save_model(self.best_dir)
            if hasattr(self.trainer, "tokenizer") and self.trainer.tokenizer is not None:
                self.trainer.tokenizer.save_pretrained(self.best_dir)
            control.should_save = True
        else:
            # Pas d'amélioration
            self.bad_epochs += 1
            if self.bad_epochs >= self.patience:
                # Stopper l'entraînement
                control.should_training_stop = True

        return control
    

# Pour avoir aussi les retours sur le jeux d'entrainement
class TrainEvalCallback(TrainerCallback):
    def __init__(self, trainer, train_dataset):
        self.trainer = trainer
        self.train_dataset = train_dataset

    def on_epoch_end(self, args, state, control, **kwargs):
        # Évalue sur TRAIN et logue train_*
        try:
            metrics = self.trainer.evaluate(
                eval_dataset=self.train_dataset,
                metric_key_prefix="train"  # => train_loss, train_accuracy, ...
            )
        except TypeError:
            metrics = self.trainer.evaluate(eval_dataset=self.train_dataset)
            # Rétro-compat: préfixer manuellement si besoin
            metrics = {f"train_{k}": v for k, v in metrics.items()}

        if state.epoch is not None:
            metrics["epoch"] = float(state.epoch)

        # pousse dans log_history (même avec report_to='none')
        self.trainer.log({k: float(v) for k, v in metrics.items() if isinstance(v, (int, float, np.floating))})
        return control


# =========================
# 5) TrainingArguments + Trainer + EarlyStopping
# =========================
training_args = TrainingArguments(
    output_dir="./results_SBERT",
    num_train_epochs=20,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,

    # on laisse les callbacks gérer eval/save/log par epoch
    logging_dir="./logs_SBERT",
    report_to="none",
    save_total_limit=3,
    seed=SEED,

    # GPU-friendly
    dataloader_pin_memory=True,
    dataloader_num_workers=4,       # qu'on peut ajuster
    gradient_checkpointing=True,    # réduit la VRAM
    bf16=use_bf16,                  # Ada -> True
    fp16=use_fp16,                  # fallback si pas de bf16
    gradient_accumulation_steps=1,  

    # IMPORTANT: on ne dépend pas du best interne du Trainer
    load_best_model_at_end=False,

    logging_steps=100,
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer, 
    callbacks=[]  # <= pas de EarlyStoppingCallback natif
)

# Callback VAL + early stop custom
val_es_cb = ValEvalAndEarlyStopCallback(trainer, patience=3)
trainer.add_callback(val_es_cb)

# Callback TRAIN pour avoir train_* dans les courbes
train_cb = TrainEvalCallback(trainer, train_ds)
trainer.add_callback(train_cb)



# =========================
# 6) Entraînement
# =========================
print(f"Device: {device} | CUDA={torch.cuda.is_available()} | bf16={use_bf16} | fp16={use_fp16}")
trainer.train()

# =========================
# 7) Historique pour tracés (robuste aux clés manquantes)
# =========================
log_hist = trainer.state.log_history
hist_df = pd.DataFrame(log_hist)

# Crée toutes les colonnes attendues si absentes pour éviter KeyError
needed_cols = ["epoch","eval_loss","eval_accuracy","eval_f1_macro",
          "train_loss","train_accuracy","train_f1_macro",
          "train_precision_macro","train_recall_macro"]
for c in needed_cols:
    if c not in hist_df.columns:
        hist_df[c] = np.nan

# Ne garder que les lignes avec epoch défini et forcer epoch en int
hist_df = hist_df[pd.to_numeric(hist_df["epoch"], errors="coerce").notna()].copy()
hist_df["epoch"] = hist_df["epoch"].astype(float).round().astype(int)

# Convertir toutes les métriques numériques en float (coerce -> NaN si non num)
metric_cols = [c for c in needed_cols if c != "epoch"]
for c in metric_cols:
    hist_df[c] = pd.to_numeric(hist_df[c], errors="coerce")

# Fonction utilitaire: dernière valeur non-NaN dans un groupe
def last_not_nan(s: pd.Series):
    s = s.dropna()
    return s.iloc[-1] if len(s) else np.nan

# Agrégation par epoch: on prend la DERNIÈRE valeur non-NaN de chaque métrique
curves = (
    hist_df
    .sort_values(["epoch"])  # important: on veut la "dernière" dans l'ordre
    .groupby("epoch", as_index=False)
    .agg({col: last_not_nan for col in metric_cols})
    .sort_values("epoch")
    .reset_index(drop=True)
)

print("\nAperçu des courbes (par epoch):\n", curves.head(10))
# ============================================================================

# Exemple de tracés "safe"
def _safe_plot(x, y, label):
    s = pd.Series(y)
    if s.notna().any():
        plt.plot(x, y, marker="o", label=label)

plt.figure(figsize=(7,5))
_safe_plot(curves["epoch"], curves["train_loss"], "Train Loss")
_safe_plot(curves["epoch"], curves["eval_loss"],  "Val Loss")
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Courbe d'apprentissage (Loss)")
plt.legend(); plt.grid(True); plt.tight_layout(); plt.show()

plt.figure(figsize=(7,5))
_safe_plot(curves["epoch"], curves["train_accuracy"], "Train Accuracy")
_safe_plot(curves["epoch"], curves["eval_accuracy"],  "Val Accuracy")
plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.title("Accuracy — Train vs Val")
plt.legend(); plt.grid(True); plt.tight_layout(); plt.show()

plt.figure(figsize=(7,5))
_safe_plot(curves["epoch"], curves["train_f1_macro"], "Train F1 (macro)")
_safe_plot(curves["epoch"], curves["eval_f1_macro"],  "Val F1 (macro)")
plt.xlabel("Epoch"); plt.ylabel("F1 (macro)"); plt.title("F1 macro — Train vs Val")
plt.legend(); plt.grid(True); plt.tight_layout(); plt.show()

# =========================
# 9) Évaluation finale sur TEST
# =========================
# Recharger le meilleur modèle pour la phase TEST
best_dir = os.path.join(training_args.output_dir, "best_checkpoint")
if os.path.isdir(best_dir):
    model = AutoModelForSequenceClassification.from_pretrained(best_dir).to(model.device)
    trainer.model = model

preds_test = trainer.predict(test_ds)
y_pred = np.argmax(preds_test.predictions, axis=1)
y_scores = torch.softmax(torch.tensor(preds_test.predictions), dim=1).numpy()

print("\n=== Rapport de classification (TEST) ===")
print(classification_report(y_test, y_pred, target_names=label_encoder.classes_))

cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_encoder.classes_)
disp.plot(cmap="Blues")
plt.xticks(rotation=45)
plt.title("Matrice de confusion — TEST")
plt.tight_layout()
plt.show()

# Courbe PR pour "VS" si présent
classes = list(label_encoder.classes_)
if "VS" in classes:
    vs_idx = classes.index("VS")
    y_true_bin = (y_test == vs_idx).astype(int)
    y_prob_vs = y_scores[:, vs_idx]
    precision, recall, _ = precision_recall_curve(y_true_bin, y_prob_vs)
    auc_pr = average_precision_score(y_true_bin, y_prob_vs)
    plt.figure(figsize=(7,5))
    plt.plot(recall, precision, label=f"AUC-PR = {auc_pr:.3f}")
    plt.xlabel("Recall"); plt.ylabel("Precision")
    plt.title("Precision–Recall (classe VS) — TEST")
    plt.legend(); plt.grid(True); plt.tight_layout(); plt.show()
else:
    print("La classe 'VS' n'est pas présente dans les labels. PR-curve sautée.")

# =========================
# 10) Sauvegardes
# =========================
save_dir = "results_SBERT"
os.makedirs(save_dir, exist_ok=True)

trainer.model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
joblib.dump(label_encoder, os.path.join(save_dir, "label_encoder.joblib"))

curves.to_csv(os.path.join(save_dir, "learning_curves.csv"), index=False)
with open(os.path.join(save_dir, "log_history.json"), "w", encoding="utf-8") as f:
    json.dump(log_hist, f, ensure_ascii=False, indent=2)

print(f"\nBest model saved to: {best_dir} (metric {training_args.metric_for_best_model} "
      f"with greater_is_better={training_args.greater_is_better})")


Classes: ['NVS', 'VS']


IndexError: index 3715 is out of bounds for axis 0 with size 2739