In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
import xgboost as xgb
import networkx as nx
import pandas as pd
import os
from tqdm import tqdm
from multiprocessing import Pool
import pickle

pd.set_option('display.max_columns', None)

# 1. Chargement et préparation

df = pd.read_csv("data/test2.csv", encoding="utf-8")
df[df["j1"] == df["j2"]].index
df.drop(index=df[df["j1"] == df["j2"]].index,inplace=True)
def fonction_un_nom(df):
    # Paramètres Elo
    starting_elo = 1500
    K = 32  # Facteur de sensibilité

    # Nettoyer et normaliser la colonne 'surface'
    df['surface'] = df['surface'].str.strip().str.upper()

    # Mapping pour les surfaces
    surface_mapping = {
        "DUR": 1,
        "TERRE BATTUE": 2,
        "DUR (INDOOR)": 3,
        "GAZON": 4
    }
    df['surface_encoded'] = df['surface'].map(surface_mapping)

    # Dictionnaire global d'Elo
    elo_ratings = {}
    # Dictionnaires pour l'Elo par surface : 
    # Pour chaque surface, on initialise un dictionnaire pour stocker les Elo des joueurs.
    elo_surface_ratings = {surf: {} for surf in surface_mapping.keys()}

    # Listes pour stocker les valeurs d'Elo AVANT mise à jour pour chaque match
    elo_j1 = []
    elo_j2 = []
    elo_j1_surface = []
    elo_j2_surface = []
    gain_j1 = []
    gain_j2 = []

    # Fonctions d'accès aux ratings (en cas d'absence, retourne starting_elo)
    def get_elo(player):
        return elo_ratings.get(player, starting_elo)
    
    def get_elo_surface(player, surf):
        return elo_surface_ratings[surf].get(player, starting_elo)

    # Parcours du DataFrame match par match
    for idx, row in df.iterrows():
        # Récupération des identifiants et de la surface
        player1 = row["j1"]
        player2 = row["j2"]
        winner = row["winner"]
        surf = row["surface"]  # par exemple "DUR", "TERRE BATTUE", etc.

        # ---------------------------
        # Mise à jour de l'Elo global
        # ---------------------------
        current_R1 = get_elo(player1)
        current_R2 = get_elo(player2)
        # Stocker les ratings avant mise à jour pour ce match
        elo_j1.append(current_R1)
        elo_j2.append(current_R2)
        # Calcul des scores attendus
        E1 = 1 / (1 + 10 ** ((current_R2 - current_R1) / 400))
        E2 = 1 / (1 + 10 ** ((current_R1 - current_R2) / 400))
        # Scores réels : 1 pour la victoire, 0 pour la défaite
        S1 = 1 if winner == player1 else 0
        S2 = 1 if winner == player2 else 0
        gain1 = K * (S1 - E1)
        gain2 = K * (S2 - E2)
        gain_j1.append(gain1)
        gain_j2.append(gain2)

        # Mise à jour globale
        new_R1 = current_R1 + K * (S1 - E1)
        new_R2 = current_R2 + K * (S2 - E2)
        # Actualiser le dictionnaire
        elo_ratings[player1] = new_R1
        elo_ratings[player2] = new_R2

        # ---------------------------
        # Mise à jour de l'Elo par surface
        # ---------------------------
        current_R1_surf = get_elo_surface(player1, surf)
        current_R2_surf = get_elo_surface(player2, surf)
        elo_j1_surface.append(current_R1_surf)
        elo_j2_surface.append(current_R2_surf)
        # Calcul des scores attendus pour la surface
        E1_surf = 1 / (1 + 10 ** ((current_R2_surf - current_R1_surf) / 400))
        E2_surf = 1 / (1 + 10 ** ((current_R1_surf - current_R2_surf) / 400))
        # Mise à jour par surface
        new_R1_surf = current_R1_surf + K * (S1 - E1_surf)
        new_R2_surf = current_R2_surf + K * (S2 - E2_surf)
        # Actualiser le dictionnaire pour la surface correspondante
        elo_surface_ratings[surf][player1] = new_R1_surf
        elo_surface_ratings[surf][player2] = new_R2_surf

    # Ajout des colonnes d'Elo au DataFrame
    df["elo_j1"] = elo_j1
    df["elo_j2"] = elo_j2
    df["elo_j1_surface"] = elo_j1_surface
    df["elo_j2_surface"] = elo_j2_surface
    df["gain_j1"] = gain_j1
    df["gain_j2"] = gain_j2

    return df


#############################################
# 2. Préparation des données
#############################################
df = fonction_un_nom(df)
df.drop(columns=[col for col in df.columns if "Unnamed:" in col], inplace=True)
df['date'] = pd.to_datetime(df['date'], errors='coerce')
df = df.sort_values('date').reset_index(drop=True)
df['target'] = (df['winner'] == df['j1']).astype(int)

# 2. Encodage joueurs
all_players = pd.concat([df["j1"], df["j2"]]).unique()
player_encoder = {player: idx for idx, player in enumerate(all_players)}
df["j1_enc"] = df["j1"].map(player_encoder)
df["j2_enc"] = df["j2"].map(player_encoder)

df




Unnamed: 0,href,j1,j2,time,score_j1,score_j2,date,tour,surface,Doubles_fautes_j1,%_1er_Service_j1,Jeux_de_Serv._Gagnés_j1,Doubles_fautes_j2,%_1er_Service_j2,Jeux_de_Serv._Gagnés_j2,rank1,rank2,age1,age2,point1,point2,tournament,%_1er_Service_j1_perc,%_1er_Service_j2_perc,winner,elo_j1,elo_j2,surface_encoded,tour_encoded,elo_j1_surface,elo_j2_surface,gain_j1,gain_j2,target,j1_enc,j2_enc
0,https://www.flashscore.fr/match/tennis/f5tUgyP...,Gambill Jan Michael,Kratochvil Michel,,2,0.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,901.0,70.0,26.116996,23.0,0.0,580.0,auckland,,,Gambill Jan Michael,1500.000000,1500.000000,1,1,1500.000000,1500.000000,16.000000,-16.000000,1,0,53
1,https://www.flashscore.fr/match/tennis/IyDIHaY...,Chela Juan Ignacio,Costa Albert,,2,0.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,23.0,8.0,23.000000,27.0,1240.0,2090.0,sydney,,,Chela Juan Ignacio,1500.000000,1500.000000,1,1,1500.000000,1500.000000,16.000000,-16.000000,1,1,69
2,https://www.flashscore.fr/match/tennis/d0Z6liB...,Srichaphan Paradorn,Kucera Karol,,2,0.0,2003-01-06,FINALE,DUR,,,,,,,14.0,75.0,23.000000,28.0,1701.0,528.0,chennai,,,Srichaphan Paradorn,1516.000000,1500.000000,1,7,1516.000000,1500.000000,15.263693,-15.263693,1,2,55
3,https://www.flashscore.fr/match/tennis/dp3hBcB...,Ferrero Juan Carlos,Crabb Jaymon,,2,1.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,4.0,217.0,22.000000,24.0,2740.0,144.0,sydney,,,Ferrero Juan Carlos,1500.000000,1500.000000,1,1,1500.000000,1500.000000,16.000000,-16.000000,1,3,41
4,https://www.flashscore.fr/match/tennis/zy2dAHQ...,Fish Mardy,Krajicek Richard,,2,0.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,83.0,92.0,21.000000,31.0,481.0,425.0,sydney,,,Fish Mardy,1500.000000,1500.000000,1,1,1500.000000,1500.000000,16.000000,-16.000000,1,4,52
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
73190,https://www.flashscore.fr/match/tennis/phWkfwl...,Fritz Taylor,Walton Adam,95.0,2,0.0,2025-03-26,1/8 DE FINALE,DUR,,,,,,,901.0,89.0,26.116996,25.0,0.0,666.0,miami,,,Fritz Taylor,1963.742798,1619.617605,1,4,1891.432835,1588.156022,3.878984,-3.878984,1,1585,2137
73191,https://www.flashscore.fr/match/tennis/Qqez08K...,Machac Tomas,Mensik Jakub,,-,,2025-03-26,1/8 DE FINALE,DUR,,,,,,,21.0,54.0,24.000000,19.0,2310.0,1042.0,miami,,,Mensik Jakub,1971.470888,1863.987126,1,4,1922.496290,1809.830750,-20.797722,20.797722,0,1929,2020
73192,https://www.flashscore.fr/match/tennis/hpjLmie...,Zverev Alexander,Arthur Fils,120.0,1,2.0,2025-03-26,1/8 DE FINALE,DUR,0.0,69%,(11/14),2.0,68%,(12/14),2.0,18.0,27.000000,20.0,7945.0,2480.0,miami,69.0,68.0,Arthur Fils,1987.291008,1863.996674,1,4,1969.080642,1787.065875,-21.450993,21.450993,0,1322,1993
73193,https://www.flashscore.fr/match/tennis/xIyRBlU...,Cerundolo Francisco,Dimitrov Grigor,170.0,1,2.0,2025-03-26,QUARTS DE FINALE,DUR,1.0,63%,(14/17),3.0,61%,(15/17),24.0,15.0,26.000000,33.0,1925.0,2745.0,miami,63.0,61.0,Dimitrov Grigor,1877.160675,1945.924840,1,3,1753.871053,1898.571384,-12.874004,12.874004,0,1861,587


In [2]:
import pandas as pd
import numpy as np

def random_swap_df(df):
    df = df.copy()
    
    swap_mask = np.random.rand(len(df)) > 0.5

    # Texte et encodages
    df.loc[swap_mask, ['j1', 'j2']] = df.loc[swap_mask, ['j2', 'j1']].values
    df.loc[swap_mask, ['j1_enc', 'j2_enc']] = df.loc[swap_mask, ['j2_enc', 'j1_enc']].values

    # Paires explicites de features à échanger
    feature_pairs = [
        ('elo_j1', 'elo_j2'),
        ('elo_j1_surface', 'elo_j2_surface'),
        ('rank1', 'rank2'),
        ('point1', 'point2'),
        ('age1', 'age2'),
        ('score_j1','score_j2'),
        ('gain_j1','gain_j2'),
    ]

    for f1, f2 in feature_pairs:
        df.loc[swap_mask, [f1, f2]] = df.loc[swap_mask, [f2, f1]].values

    # Inverser la target
    df.loc[swap_mask, 'target'] = 1 - df.loc[swap_mask, 'target']

    return df


df = random_swap_df(df)
df['target'].value_counts()

  df.loc[swap_mask, [f1, f2]] = df.loc[swap_mask, [f2, f1]].values


target
1    36744
0    36451
Name: count, dtype: int64

In [3]:
all_tournois = df["tournament"].unique()
tourn_encoder = {tournois: idx for idx, tournois in enumerate(all_tournois)}

df["tournament_enc"] = df["tournament"].map(tourn_encoder)
print(df)

df['diff_rank'] = df['rank1'] - df['rank2']
df['diff_elo']  = df['elo_j1'] - df['elo_j2']
df['diff_age']  = df['age1'] - df['age2']
df['diff_points'] = df['point1'] - df['point2']

df['diff_elo_surf']  = df['elo_j1_surface'] - df['elo_j2_surface']
# Vous pouvez aussi ajouter d'autres ratios ou interactions
df['ratio_rank'] = df['rank1'] / (df['rank2'] + 1e-6)  # évite la division par zéro


# On suppose que df est déjà chargé et préparé avec les colonnes de base :
# date, j1, j2, target (1 si j1 gagne, 0 sinon), etc.

# Liste des fenêtres que l'on souhaite pour calculer le rolling winrate
import pandas as pd
import numpy as np

# Liste des fenêtres des rolling winrates
windows = [3, 5, 10, 25, 50]

def compute_history_features_with_streak(df, windows):
    # Préparer des listes pour stocker les features pour chaque match
    j1_winrate_features = {f'j1_winrate_{w}m': [] for w in windows}
    j2_winrate_features = {f'j2_winrate_{w}m': [] for w in windows}
    
    j1_nb_prev = []  # nombre de matchs joués par j1 avant le match courant
    j2_nb_prev = []  # idem pour j2
    
    j1_streak = []  # streak de j1, positif en cas de win streak, négatif en cas de loss streak
    j2_streak = []  # idem pour j2

    # Dictionnaire pour stocker l'historique des résultats pour chaque joueur
    # Les résultats seront stockés sous forme d'une liste de 1 (victoire) et 0 (défaite)
    history = {}

    # Fonction auxiliaire pour calculer le streak à partir d'une liste d'historique
    def get_streak(hist):
        if not hist:
            return 0
        # On parcourt à l'envers et on compte les matchs consécutifs identiques au dernier résultat
        streak = 0
        last_result = hist[-1]
        for r in reversed(hist):
            if r == last_result:
                streak += 1
            else:
                break
        return streak if last_result == 1 else -streak

    # Parcourir le DataFrame trié par date (assurez-vous que la colonne "date" est bien de type datetime)
    for idx, row in df.iterrows():
        # Identifiants des joueurs pour le match courant
        player1 = row["j1"]
        player2 = row["j2"]

        # Récupérer l'historique existant pour chaque joueur ou initialiser une liste vide
        hist1 = history.get(player1, [])
        hist2 = history.get(player2, [])

        # Stocker le nombre de matchs précédents
        nb_prev1 = len(hist1)
        nb_prev2 = len(hist2)
        j1_nb_prev.append(nb_prev1)
        j2_nb_prev.append(nb_prev2)

        # Calculer les rolling winrates pour chaque fenêtre pour j1 et j2
        for w in windows:
            # Pour j1
            if nb_prev1 > 0:
                recent_matches = hist1[-w:]  # les w derniers matchs ou moins
                winrate1 = np.mean(recent_matches)
            else:
                winrate1 = 0.5  # valeur neutre par défaut
            j1_winrate_features[f'j1_winrate_{w}m'].append(winrate1)

            # Pour j2
            if nb_prev2 > 0:
                recent_matches = hist2[-w:]
                winrate2 = np.mean(recent_matches)
            else:
                winrate2 = 0.5
            j2_winrate_features[f'j2_winrate_{w}m'].append(winrate2)

        # Calculer le streak actuel pour chaque joueur avant le match courant
        current_streak1 = get_streak(hist1)
        current_streak2 = get_streak(hist2)
        j1_streak.append(current_streak1)
        j2_streak.append(current_streak2)
        
        # Mettre à jour l'historique : pour j1, le résultat est directement dans "target"
        result_j1 = row["target"]  # 1 si j1 gagne, 0 sinon
        # Pour j2, le résultat est complémentaire
        result_j2 = 1 - row["target"]

        # Mise à jour des historiques
        history[player1] = hist1 + [result_j1]
        history[player2] = hist2 + [result_j2]

    # Ajouter les features calculées dans le DataFrame
    for w in windows:
        df[f'j1_winrate_{w}m'] = j1_winrate_features[f'j1_winrate_{w}m']
        df[f'j2_winrate_{w}m'] = j2_winrate_features[f'j2_winrate_{w}m']

    df['j1_nb_prev_matches'] = j1_nb_prev
    df['j2_nb_prev_matches'] = j2_nb_prev
    df['j1_streak'] = j1_streak
    df['j2_streak'] = j2_streak

    return df

# Exemple d'utilisation

# Conversion de la colonne "date" en datetime et tri par date
df['date'] = pd.to_datetime(df['date'])
df = df.sort_values('date').reset_index(drop=True)

# Appliquer la fonction pour ajouter les features de forme (winrate, nb_prev_matches et streak)
df = compute_history_features_with_streak(df, windows)

# Affichage des colonnes d'intérêt pour vérifier

def compute_h2h_features(df):
    h2h_counts = {}
    h2h_winrate_j1 = []
    h2h_total_prev = []  # Liste pour le nombre de matchs précédents pour le matchup

    # Parcourir chaque match dans l'ordre chronologique
    # (Assurez-vous que df est trié par date avant d'appeler cette fonction)
    for idx, row in df.iterrows():
        p1, p2 = row["j1"], row["j2"]
        # Créer une clé unique et ordonnée pour la confrontation
        matchup = tuple(sorted([p1, p2]))
        
        # Initialisation si le matchup n'existe pas encore
        if matchup not in h2h_counts:
            h2h_counts[matchup] = {"total": 0, p1: 0, p2: 0}

        # Récupérer le nombre de matchs déjà joués entre ces deux joueurs
        total = h2h_counts[matchup]["total"]
        # Stocker ce total avant la mise à jour pour ce match
        h2h_total_prev.append(total)
        
        # Récupérer le nombre de victoires de j1 dans ces confrontations
        p1_wins = h2h_counts[matchup][p1]
        # Calculer le winrate précédent de j1 (avant ce match)
        winrate = p1_wins / total if total > 0 else np.nan
        h2h_winrate_j1.append(winrate)

        # Mise à jour après le match courant
        winner = row["winner"]
        h2h_counts[matchup]["total"] += 1
        if winner not in h2h_counts[matchup]:
            raise ValueError(f"Le nom '{winner}' du winner n'est pas dans le matchup {matchup}")
        h2h_counts[matchup][winner] += 1

    # Ajout des features dans le DataFrame
    df["h2h_winrate_j1"] = h2h_winrate_j1
    df["h2h_total_prev"] = h2h_total_prev

    return df



df = compute_h2h_features(df)
df["score_j1"].replace('-',"0",inplace=True)
df["score_j2"].replace('-',"0",inplace=True)
df["score_j1"] = pd.to_numeric(df["score_j1"], errors="coerce")
df["score_j2"] = pd.to_numeric(df["score_j2"], errors="coerce")
df["set_diff"] = df["score_j1"] - df["score_j2"]
df["h2h_winrate_j1"].fillna(0.5,inplace=True)
def best_of(df):
    best = []
    for idx, row in df.iterrows():
        s1, s2 = float(row["score_j1"]), float(row["score_j2"])
        if max(s1,s2) == 2:
            best.append(3)
        elif max(s1,s2) == 3:
            best.append(5)
        elif max(s1,s2) == 4:
            best.append(7)
        else:
            best.append(-1)
    df["best_of"] = best
    return df


df = best_of(df)
df.drop(index=df[df["best_of"] ==-1].index, inplace=True)
from sklearn.preprocessing import StandardScaler, MinMaxScaler

df = df.copy()

# Colonnes par paires
paired_cols = [
    ('age1', 'age2'),
    ('point1', 'point2'),
    ('elo_j1', 'elo_j2'),
    ('elo_j1_surface', 'elo_j2_surface'),
    ('rank1', 'rank2'),
    ('gain_j1','gain_j2'),
    ("j1_nb_prev_matches","j2_nb_prev_matches")
]

# Normalisation paire par paire
for col1, col2 in paired_cols:
    stacked = np.hstack([df[col1].values.reshape(-1, 1), df[col2].values.reshape(-1, 1)])
    scaler = StandardScaler()
    scaled = scaler.fit_transform(stacked)
    df[col1] = scaled[:, 0]
    df[col2] = scaled[:, 1]

# Autres colonnes solo à standardiser
solo_cols = ["diff_rank", "diff_points", "diff_elo", "diff_elo_surf"]
df[solo_cols] = StandardScaler().fit_transform(df[solo_cols])
print(len(df))
df




                                                    href                   j1  \
0      https://www.flashscore.fr/match/tennis/f5tUgyP...  Gambill Jan Michael   
1      https://www.flashscore.fr/match/tennis/IyDIHaY...         Costa Albert   
2      https://www.flashscore.fr/match/tennis/d0Z6liB...         Kucera Karol   
3      https://www.flashscore.fr/match/tennis/dp3hBcB...         Crabb Jaymon   
4      https://www.flashscore.fr/match/tennis/zy2dAHQ...     Krajicek Richard   
...                                                  ...                  ...   
73190  https://www.flashscore.fr/match/tennis/phWkfwl...         Fritz Taylor   
73191  https://www.flashscore.fr/match/tennis/Qqez08K...         Machac Tomas   
73192  https://www.flashscore.fr/match/tennis/hpjLmie...          Arthur Fils   
73193  https://www.flashscore.fr/match/tennis/xIyRBlU...      Dimitrov Grigor   
73194  https://www.flashscore.fr/match/tennis/COXgdWg...          Arthur Fils   

                        j2 

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df["score_j1"].replace('-',"0",inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df["score_j2"].replace('-',"0",inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values alw

71810


Unnamed: 0,href,j1,j2,time,score_j1,score_j2,date,tour,surface,Doubles_fautes_j1,%_1er_Service_j1,Jeux_de_Serv._Gagnés_j1,Doubles_fautes_j2,%_1er_Service_j2,Jeux_de_Serv._Gagnés_j2,rank1,rank2,age1,age2,point1,point2,tournament,%_1er_Service_j1_perc,%_1er_Service_j2_perc,winner,elo_j1,elo_j2,surface_encoded,tour_encoded,elo_j1_surface,elo_j2_surface,gain_j1,gain_j2,target,j1_enc,j2_enc,tournament_enc,diff_rank,diff_elo,diff_age,diff_points,diff_elo_surf,ratio_rank,j1_winrate_3m,j2_winrate_3m,j1_winrate_5m,j2_winrate_5m,j1_winrate_10m,j2_winrate_10m,j1_winrate_25m,j2_winrate_25m,j1_winrate_50m,j2_winrate_50m,j1_nb_prev_matches,j2_nb_prev_matches,j1_streak,j2_streak,h2h_winrate_j1,h2h_total_prev,set_diff,best_of
0,https://www.flashscore.fr/match/tennis/f5tUgyP...,Gambill Jan Michael,Kratochvil Michel,,2.0,0.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,2.662166,-0.417827,-0.008331,-0.820058,-0.679134,-0.320768,auckland,,,Gambill Jan Michael,-0.993688,-0.996527,1,1,-0.814132,-0.817801,1.077755,-1.077755,1,0,53,0,2.330427,-0.000528,3.116996,-0.279478,-0.000848,12.871428,0.500000,0.500000,0.5,0.5,0.5,0.5,0.50,0.50,0.50,0.500000,-0.939541,-0.945756,0,0,0.500,0,2.0,3
1,https://www.flashscore.fr/match/tennis/CtyZhev...,Ginepri Robby,Calleri Agustin,,2.0,0.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,-0.303085,-0.492598,-1.595599,-0.039476,-0.435856,-0.228089,auckland,,,Ginepri Robby,-0.993688,-0.996527,1,1,-0.814132,-0.817801,1.077755,-1.077755,1,20,106,0,0.140803,-0.000528,-6.000000,-0.161884,-0.000848,2.040000,0.500000,0.500000,0.5,0.5,0.5,0.5,0.50,0.50,0.50,0.500000,-0.939541,-0.945756,0,0,0.500,0,2.0,3
2,https://www.flashscore.fr/match/tennis/29lkCwd...,Enqvist Thomas,Ferreira Wayne,,0.0,2.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,-0.510912,-0.526245,0.480280,1.261495,-0.198163,-0.138499,sydney,,,Ferreira Wayne,-0.993688,-0.996527,1,1,-0.814132,-0.817801,-1.086330,1.086330,0,93,21,1,0.008694,-0.000528,-3.000000,-0.046234,-0.000848,1.121951,0.500000,0.500000,0.5,0.5,0.5,0.5,0.50,0.50,0.50,0.500000,-0.939541,-0.945756,0,0,0.500,0,-2.0,3
3,https://www.flashscore.fr/match/tennis/OIkoDJt...,Durek Raphael,Davydenko Nikolay,,0.0,2.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,2.662166,-0.451474,-0.008331,-1.340446,-0.679134,-0.285550,sydney,,,Davydenko Nikolay,-0.993688,-0.996527,1,1,-0.814132,-0.817801,-1.086330,1.086330,0,2255,22,1,2.355724,-0.000528,5.116996,-0.307176,-0.000848,14.770492,0.500000,0.500000,0.5,0.5,0.5,0.5,0.50,0.50,0.50,0.500000,-0.939541,-0.945756,0,0,0.500,0,-2.0,3
4,https://www.flashscore.fr/match/tennis/MirMecf...,Coria Guillermo,Sanchez Munoz David,,2.0,1.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,-0.514623,2.688935,-1.595599,-0.009034,-0.196301,-0.679127,auckland,,,Coria Guillermo,-0.993688,-0.996527,1,1,-0.814132,-0.817801,1.077755,-1.077755,1,23,86,0,-2.411416,-0.000528,-6.116996,0.380409,-0.000848,0.049945,0.500000,0.500000,0.5,0.5,0.5,0.5,0.50,0.50,0.50,0.500000,-0.939541,-0.945756,0,0,0.500,0,1.0,3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
73189,https://www.flashscore.fr/match/tennis/OhGugwd...,Djokovic Novak,Musetti Lorenzo,83.0,2.0,0.0,2025-03-25,1/8 DE FINALE,DUR,3.0,70%,(3/8),1.0,70%,(7/8),-0.663071,-0.619710,2.815645,-0.820058,1.716411,0.958203,miami,70.0,70.0,Djokovic Novak,2.795122,1.245263,1,4,3.462285,0.945084,0.372728,-0.372728,1,380,1872,19,-0.036279,1.492695,14.000000,0.590329,2.329908,0.312500,0.666667,0.666667,0.6,0.6,0.7,0.6,0.80,0.60,0.76,0.680000,5.240170,0.312173,2,2,0.875,8,2.0,3
73190,https://www.flashscore.fr/match/tennis/hpjLmie...,Arthur Fils,Zverev Alexander,120.0,2.0,1.0,2025-03-26,1/8 DE FINALE,DUR,0.0,69%,(11/14),2.0,68%,(12/14),-0.614826,-0.672050,-1.595599,0.220719,0.859972,4.229774,miami,69.0,68.0,Arthur Fils,1.099756,1.815579,1,4,1.100981,2.329247,1.446393,-1.446393,1,1993,1322,19,0.039613,-0.681612,-7.000000,-2.653224,-1.122418,8.999996,0.666667,0.666667,0.8,0.6,0.6,0.6,0.68,0.72,0.60,0.720000,-0.313118,2.281760,2,2,0.250,4,1.0,3
73191,https://www.flashscore.fr/match/tennis/xIyRBlU...,Dimitrov Grigor,Cerundolo Francisco,170.0,2.0,1.0,2025-03-26,QUARTS DE FINALE,DUR,1.0,63%,(14/17),3.0,61%,(15/17),-0.625959,-0.589801,1.777705,-0.039476,1.024433,0.510254,miami,63.0,61.0,Dimitrov Grigor,1.570947,1.180028,1,3,1.844871,0.885412,0.866351,-0.866351,1,587,1861,19,-0.030657,0.379329,7.000000,0.400818,0.890791,0.625000,1.000000,1.000000,0.8,0.8,0.5,0.8,0.64,0.68,0.66,0.620000,2.833911,0.116718,3,3,1.000,1,1.0,3
73192,https://www.flashscore.fr/match/tennis/phWkfwl...,Fritz Taylor,Walton Adam,95.0,2.0,0.0,2025-03-26,1/8 DE FINALE,DUR,,,,,,,2.662166,-0.346794,-0.008331,-0.299670,-0.679134,-0.267632,miami,,,Fritz Taylor,1.673423,-0.306226,1,4,1.797247,-0.226365,0.258039,-0.258039,1,1585,2137,19,2.277021,1.900437,1.116996,-0.321268,1.867935,10.123595,0.666667,0.666667,0.8,0.6,0.6,0.7,0.64,0.44,0.70,0.466667,1.436888,-0.720231,2,2,0.500,0,2.0,3


In [4]:
import torch
from torch_geometric.data import TemporalData
from torch_geometric.loader import TemporalDataLoader
cols_to_drop = [
    "surface", "href", "j1", "j2", "time", "score_j1", "score_j2", "tour",
    "Doubles_fautes_j1", "%_1er_Service_j1", "Jeux_de_Serv._Gagnés_j1",
    "Doubles_fautes_j2", "%_1er_Service_j2", "Jeux_de_Serv._Gagnés_j2",
    "tournament", "%_1er_Service_j1_perc", "%_1er_Service_j2_perc", "winner"
]

# option 1 – get a new DataFrame
df_clean = df.drop(columns=cols_to_drop)
df["timestamp"] = pd.to_datetime(df["date"]).astype(int) // 10**9
print(len(df[df["j1"]==df["j2"]]))
import pandas as pd
import numpy as np

# 1. Supposons que votre DataFrame s’appelle df et que son index
#    correspond bien à l’ordre temporel des matchs.

# 2. Construire un DataFrame “plat” listant pour chaque joueur
#    l’index de chaque match où il est j1 ou j2.
records = []
for role in ['j1_enc', 'j2_enc']:
    records.extend(zip(df[role].values, df.index.values))
pm_df = pd.DataFrame(records, columns=['player', 'match_idx'])

# 3. Pour chaque joueur, trier ses index de match et calculer
#    la différence minimale entre deux index consécutifs.
def min_idx_gap(idxs):
    idxs = np.sort(idxs)
    if len(idxs) < 2:
        return np.nan
    return np.min(np.diff(idxs))

min_gap_df = (
    pm_df
    .groupby('player')['match_idx']
    .apply(min_idx_gap)
    .reset_index(name='min_index_gap')
    .sort_values('min_index_gap')
)

# 4. Afficher le résultat
print(min_gap_df.head(10))   # par exemple, les 10 joueurs avec gap minimal le plus petit
print("Gap minimal global :", min_gap_df['min_index_gap'].min())


# 3. Définir les colonnes utilisées comme features (msg)
features_cols = [
    "rank1", "rank2", "age1", "age2", "point1", "point2",
    "elo_j1", "elo_j2", "surface_encoded", "tour_encoded",
    "elo_j1_surface", "elo_j2_surface", "tournament_enc",
    "diff_rank", "diff_elo", "diff_age", "diff_points", "diff_elo_surf",
    "ratio_rank",
    "j1_winrate_3m", "j2_winrate_3m", "j1_winrate_5m", "j2_winrate_5m",
    "j1_winrate_10m", "j2_winrate_10m", "j1_winrate_25m", "j2_winrate_25m",
    "j1_winrate_50m", "j2_winrate_50m",
    "j1_nb_prev_matches", "j2_nb_prev_matches",
    "j1_streak", "j2_streak",
    "h2h_winrate_j1", "h2h_total_prev", "best_of"
]
# 1) Définissez un mapping entre les noms “1” et les noms “2”
swap_map = {
    'rank1': 'rank2', 'rank2': 'rank1',
    "j1_enc" : "j2_enc","j2_enc" : "j1_enc",
    'age1': 'age2',   'age2': 'age1',
    'point1': 'point2','point2': 'point1',
    'elo_j1': 'elo_j2', 'elo_j2': 'elo_j1',
    'elo_j1_surface': 'elo_j2_surface','elo_j2_surface': 'elo_j1_surface',
    'j1_winrate_3m':'j2_winrate_3m','j2_winrate_3m':'j1_winrate_3m',
    'j1_winrate_5m':'j2_winrate_5m','j2_winrate_5m':'j1_winrate_5m',
    'j1_winrate_10m':'j2_winrate_10m','j2_winrate_10m':'j1_winrate_10m',
    'j1_winrate_25m':'j2_winrate_25m','j2_winrate_25m':'j1_winrate_25m',
    'j1_winrate_50m':'j2_winrate_50m','j2_winrate_50m':'j1_winrate_50m',
    'j1_nb_prev_matches':'j2_nb_prev_matches','j2_nb_prev_matches':'j1_nb_prev_matches',
    'j1_streak':'j2_streak','j2_streak':'j1_streak',
    'gain_j1':'gain_j2','gain_j2':'gain_j1',
}

# 2) Renommez vos colonnes d'après ce mapping
df_swapped = df.rename(columns=swap_map).copy()
df_swapped['target'] = 1 - df_swapped['target']
df_swapped['h2h_winrate_j1'] = 1- df_swapped['h2h_winrate_j1']
df_swapped["diff_elo"] = df_swapped["diff_elo"]*(-1)
df_swapped["diff_age"] = df_swapped["diff_age"]*(-1)
df_swapped["diff_points"] = df_swapped["diff_points"]*(-1)
df_swapped["diff_elo_surf"] = df_swapped["diff_elo_surf"]*(-1)
df_swapped["diff_rank"] = df_swapped["diff_rank"]*(-1)
df_swapped["set_diff"] = df_swapped["set_diff"]*(-1)
df_swapped["ratio_rank"] = 1/df_swapped["ratio_rank"]
df_swapped

df_aug = pd.concat([df, df_swapped], ignore_index=True)
df_aug = df_aug.sort_values('date').reset_index(drop=True)


df_aug


0
      player  min_index_gap
28        28            1.0
29        29            1.0
62        62            1.0
64        64            1.0
69        69            1.0
71        71            1.0
2100    2101            1.0
7          7            1.0
8          8            1.0
25        25            1.0
Gap minimal global : 1.0


Unnamed: 0,href,j1,j2,time,score_j1,score_j2,date,tour,surface,Doubles_fautes_j1,%_1er_Service_j1,Jeux_de_Serv._Gagnés_j1,Doubles_fautes_j2,%_1er_Service_j2,Jeux_de_Serv._Gagnés_j2,rank1,rank2,age1,age2,point1,point2,tournament,%_1er_Service_j1_perc,%_1er_Service_j2_perc,winner,elo_j1,elo_j2,surface_encoded,tour_encoded,elo_j1_surface,elo_j2_surface,gain_j1,gain_j2,target,j1_enc,j2_enc,tournament_enc,diff_rank,diff_elo,diff_age,diff_points,diff_elo_surf,ratio_rank,j1_winrate_3m,j2_winrate_3m,j1_winrate_5m,j2_winrate_5m,j1_winrate_10m,j2_winrate_10m,j1_winrate_25m,j2_winrate_25m,j1_winrate_50m,j2_winrate_50m,j1_nb_prev_matches,j2_nb_prev_matches,j1_streak,j2_streak,h2h_winrate_j1,h2h_total_prev,set_diff,best_of,timestamp
0,https://www.flashscore.fr/match/tennis/f5tUgyP...,Gambill Jan Michael,Kratochvil Michel,,2.0,0.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,2.662166,-0.417827,-0.008331,-0.820058,-0.679134,-0.320768,auckland,,,Gambill Jan Michael,-0.993688,-0.996527,1,1,-0.814132,-0.817801,1.077755,-1.077755,1,0,53,0,2.330427,-0.000528,3.116996,-0.279478,-0.000848,12.871428,0.500000,0.500000,0.5,0.5,0.5,0.5,0.50,0.50,0.500000,0.50,-0.939541,-0.945756,0,0,0.50,0,2.0,3,1041811200
1,https://www.flashscore.fr/match/tennis/29lkCwd...,Enqvist Thomas,Ferreira Wayne,,0.0,2.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,-0.526245,-0.510912,1.261495,0.480280,-0.138499,-0.198163,sydney,,,Ferreira Wayne,-0.996527,-0.993688,1,1,-0.817801,-0.814132,1.086330,-1.086330,1,21,93,1,-0.008694,0.000528,3.000000,0.046234,0.000848,0.891304,0.500000,0.500000,0.5,0.5,0.5,0.5,0.50,0.50,0.500000,0.50,-0.945756,-0.939541,0,0,0.50,0,2.0,3,1041811200
2,https://www.flashscore.fr/match/tennis/OIkoDJt...,Durek Raphael,Davydenko Nikolay,,0.0,2.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,-0.451474,2.662166,-1.340446,-0.008331,-0.285550,-0.679134,sydney,,,Davydenko Nikolay,-0.996527,-0.993688,1,1,-0.817801,-0.814132,1.086330,-1.086330,1,22,2255,1,-2.355724,0.000528,-5.116996,0.307176,0.000848,0.067703,0.500000,0.500000,0.5,0.5,0.5,0.5,0.50,0.50,0.500000,0.50,-0.945756,-0.939541,0,0,0.50,0,2.0,3,1041811200
3,https://www.flashscore.fr/match/tennis/MirMecf...,Coria Guillermo,Sanchez Munoz David,,2.0,1.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,2.688935,-0.514623,-0.009034,-1.595599,-0.679127,-0.196301,auckland,,,Coria Guillermo,-0.996527,-0.993688,1,1,-0.817801,-0.814132,-1.077755,1.077755,0,86,23,0,2.411416,0.000528,6.116996,-0.380409,0.000848,20.022222,0.500000,0.500000,0.5,0.5,0.5,0.5,0.50,0.50,0.500000,0.50,-0.945756,-0.939541,0,0,0.50,0,-1.0,3,1041811200
4,https://www.flashscore.fr/match/tennis/8lvIdwu...,Sa Andre,Carlsen Kenneth,,1.0,2.0,2003-01-06,1/16 DE FINALE,DUR,,,,,,,-0.440258,-0.440399,0.741107,-0.298175,-0.301614,-0.302425,auckland,,,Carlsen Kenneth,-0.996527,-0.993688,1,1,-0.817801,-0.814132,1.086330,-1.086330,1,24,175,0,0.002549,0.000528,4.000000,-0.000415,0.000848,0.984615,0.500000,0.500000,0.5,0.5,0.5,0.5,0.50,0.50,0.500000,0.50,-0.945756,-0.939541,0,0,0.50,0,1.0,3,1041811200
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
143615,https://www.flashscore.fr/match/tennis/hpjLmie...,Arthur Fils,Zverev Alexander,120.0,2.0,1.0,2025-03-26,1/8 DE FINALE,DUR,0.0,69%,(11/14),2.0,68%,(12/14),-0.614826,-0.672050,-1.595599,0.220719,0.859972,4.229774,miami,69.0,68.0,Arthur Fils,1.099756,1.815579,1,4,1.100981,2.329247,1.446393,-1.446393,1,1993,1322,19,0.039613,-0.681612,-7.000000,-2.653224,-1.122418,8.999996,0.666667,0.666667,0.8,0.6,0.6,0.6,0.68,0.72,0.600000,0.72,-0.313118,2.281760,2,2,0.25,4,1.0,3,1742947200
143616,https://www.flashscore.fr/match/tennis/phWkfwl...,Fritz Taylor,Walton Adam,95.0,2.0,0.0,2025-03-26,1/8 DE FINALE,DUR,,,,,,,-0.346794,2.662166,-0.299670,-0.008331,-0.267632,-0.679134,miami,,,Fritz Taylor,-0.306226,1.673423,1,4,-0.226365,1.797247,-0.258039,0.258039,0,2137,1585,19,-2.277021,-1.900437,-1.116996,0.321268,-1.867935,0.098779,0.666667,0.666667,0.6,0.8,0.7,0.6,0.44,0.64,0.466667,0.70,-0.720231,1.436888,2,2,0.50,0,-2.0,3,1742947200
143617,https://www.flashscore.fr/match/tennis/xIyRBlU...,Dimitrov Grigor,Cerundolo Francisco,170.0,2.0,1.0,2025-03-26,QUARTS DE FINALE,DUR,1.0,63%,(14/17),3.0,61%,(15/17),-0.589801,-0.625959,-0.039476,1.777705,0.510254,1.024433,miami,63.0,61.0,Dimitrov Grigor,1.180028,1.570947,1,3,0.885412,1.844871,-0.866351,0.866351,0,1861,587,19,0.030657,-0.379329,-7.000000,-0.400818,-0.890791,1.600000,1.000000,1.000000,0.8,0.8,0.8,0.5,0.68,0.64,0.620000,0.66,0.116718,2.833911,3,3,0.00,1,-1.0,3,1742947200
143618,https://www.flashscore.fr/match/tennis/COXgdWg...,Arthur Fils,Mensik Jakub,77.0,0.0,2.0,2025-03-27,QUARTS DE FINALE,DUR,3.0,68%,(6/10),2.0,66%,(8/9),-0.614826,-0.477644,-1.595599,-1.860835,0.859972,-0.035317,miami,68.0,66.0,Mensik Jakub,1.223127,1.224026,1,3,1.259031,1.401826,-1.088394,1.088394,0,1993,2020,19,-0.106549,0.003134,1.000000,0.701120,-0.124629,0.333333,1.000000,1.000000,0.8,0.8,0.7,0.6,0.68,0.68,0.620000,0.64,-0.308147,-0.589927,3,4,0.50,0,-2.0,3,1743033600


In [5]:

# 4. Split temporel
train_df = df[df["date"] < "2024-01-01"].copy()
test_df  = df[df["date"] >= "2024-01-01"].copy()

# 5. Création des TemporalData
def build_temporal_data(df_subset):
    return TemporalData(
        src = torch.tensor(df_subset["j1_enc"].values, dtype=torch.long),
        dst = torch.tensor(df_subset["j2_enc"].values, dtype=torch.long),
        t   = torch.tensor(df_subset["timestamp"].values, dtype=torch.long),
        msg = torch.tensor(df_subset[features_cols].values, dtype=torch.float),
        y   = torch.tensor(df_subset["target"].values, dtype=torch.float),
        set_diff  = torch.tensor(df_subset["set_diff"].values,  dtype=torch.float),
        elo_gain  = torch.tensor(df_subset["gain_j1"].values,  dtype=torch.float),
    )

train_data = build_temporal_data(train_df)
test_data  = build_temporal_data(test_df)

for data in (train_data, test_data):
    data.src = data.src.long()    # src en ints longs
    data.dst = data.dst.long()    # dst en ints longs
    data.t   = data.t.long()     # timestamps en floats
    data.msg = data.msg.float()   # features en floats
    data.y   = data.y.float()     # labels en floats
    data.set_diff = data.set_diff.float()
    data.elo_gain = data.elo_gain.float()
# 4) Déplacer t et msg sur GPU
device = "cuda"
for data in (train_data, test_data):
    data.t   = data.t.to(device)
    data.msg = data.msg.to(device)
# 6. DataLoaders pour entraînement
train_loader = TemporalDataLoader(train_data, batch_size=16, neg_sampling_ratio=0)
test_loader  = TemporalDataLoader(test_data, batch_size=16, neg_sampling_ratio=0)

In [6]:
for batch in train_loader:
    print(batch.src.shape, batch.msg.shape, batch.y.shape)
    break


torch.Size([16]) torch.Size([16, 36]) torch.Size([16])


In [20]:
import torch
import torch.nn as nn
from torch.nn import Linear
from sklearn.metrics import average_precision_score, roc_auc_score

from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
    LastNeighborLoader,
)
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.data import TemporalData

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paramètres
embedding_dim = memory_dim = time_dim = 1024
learning_rate = 0.00006


# === Modules ===

class MultiLayerTimeAwareGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 msg_dim, time_enc, num_layers=3, heads=2, dropout=0.1):
        """
        - in_channels   : dim. entrée (memory_dim)
        - hidden_channels : dim. intermédiaire de chaque tête (out_channels/2 si heads=2)
        - out_channels  : dim. de sortie finale
        - msg_dim       : dim. des features d'arête
        - time_enc      : instance du TimeEncoder de TGNMemory
        - num_layers    : nombre de blocs TransformerConv à empiler
        """
        super().__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels

        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_c  = in_channels if i == 0 else hidden_channels * heads
            out_c = out_channels // heads if i == num_layers-1 else hidden_channels
            self.convs.append(
                TransformerConv(
                    in_c, out_c, heads=heads,
                    dropout=dropout, edge_dim=edge_dim
                )
            )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, last_update, edge_index, t, msg):
        # calcule rel_t et rel_t_enc une seule fois
        rel_t = last_update[edge_index[0]] - t
        rel_t_enc = self.time_enc(rel_t.to(x.dtype))
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)

        # propagation à travers les couches empilées
        for conv in self.convs:
            x_new = conv(x, edge_index, edge_attr)
            x = (x + self.dropout(x_new)).relu()   # skip + dropout + relu

        return x


class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_dst = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, 1)

    def forward(self, z_src, z_dst):
        h = self.lin_src(z_src) + self.lin_dst(z_dst)
        h = h.relu()
        return self.lin_final(h)

# === Instanciations ===
import torch
import torch.nn as nn

import torch.nn.functional as F

class WinPredictorMLP(nn.Module):
    def __init__(self,
                 in_channels: int,
                 hidden_channels_list: list[int] = [256, 128],
                 dropout: float = 0.3,
                 residual: bool = False):
        """
        - in_channels          : dim des embeddings z_src et z_dst
        - hidden_channels_list : liste des dims cachées, ex [256, 128]
        - dropout              : probabilité de dropout entre chaque bloc
        - residual             : si True, ajoute une skip‑connection sur chaque bloc
        """
        super().__init__()
        layers = []
        dims = [5 * in_channels] + hidden_channels_list + [1]
        self.residual = residual

        for i in range(len(dims) - 2):
            in_dim, out_dim = dims[i], dims[i+1]
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.LayerNorm(out_dim))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Dropout(dropout))
        # dernière couche sans activation

        self.net = nn.Sequential(*layers)
        self.head_win = nn.Sequential(
            nn.Linear(dims[-2], dims[-1])
        )
        # Head 2: regression du diff de points (ou de jeux)
        self.head_margin = nn.Sequential(
            nn.Linear(dims[-2], dims[-1])
        )
        self.head_elo = nn.Sequential(
            nn.Linear(dims[-2], dims[-1])
        )


    def forward(self, z_src, z_dst):
        h_diff = z_src - z_dst
        h_abs  = torch.abs(z_src - z_dst)
        h_mult = z_src * z_dst

        h = torch.cat([z_src, z_dst, h_diff, h_abs, h_mult], dim=-1)

        h = self.net(h)
        logit_win = self.head_win(h)
        margin    = self.head_margin(h)  
        elo       = self.head_elo(h) 
        return logit_win,margin,elo

        


num_nodes = max(df[["j1_enc", "j2_enc"]].max()) + 1
msg_dim = train_data.msg.size(-1)

memory = TGNMemory(
    num_nodes=num_nodes,
    raw_msg_dim=msg_dim,
    memory_dim=memory_dim,
    time_dim=time_dim,
    message_module=IdentityMessage(msg_dim, memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

gnn = MultiLayerTimeAwareGNN(memory_dim,32, embedding_dim, msg_dim, memory.time_enc,3,32,0.3).to(device)
link_pred = WinPredictorMLP(
    in_channels=embedding_dim,
    hidden_channels_list=[512,128,32],  # deux couches cachées
    dropout=0.5,
    residual=False                   # ou True si vous voulez
).to(device)


optimizer = torch.optim.AdamW(
    list(memory.parameters()) + list(gnn.parameters()) + list(link_pred.parameters()),
    lr=learning_rate,weight_decay=1e-4
)
criterion = torch.nn.BCEWithLogitsLoss()

# === Loaders ===


neighbor_loader = LastNeighborLoader(num_nodes=num_nodes, size=20, device=device)

assoc = torch.empty(num_nodes, dtype=torch.long, device=device)

# === Train ===

def train():
    memory.train()
    gnn.train()
    link_pred.train()
    memory.reset_state()
    neighbor_loader.reset_state()

    total_loss = 0
    # on wrappe train_loader avec tqdm
    for batch in tqdm(train_loader, desc="Training", unit="batch"):
        optimizer.zero_grad()
        batch = batch.to(device)

        # 1. Construire le sous‑graphe voisin
        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # 2. Récupérer la mémoire et l'embedding
        z, last_update = memory(n_id)

        # 3. PAS batch.t/msg ! mais train_data.t/msg
        t_e = train_data.t[e_id].to(device)
        msg_e = train_data.msg[e_id].to(device)
        z = gnn(z, last_update, edge_index, t_e, msg_e)

        # 4. Link prediction
        logit_win,margin,elo = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
     
        y_win = batch.y.view(-1,1)                    # {0,1}
        y_setdiff = batch.set_diff.view(-1,1)
        y_elo_gain = batch.elo_gain.view(-1,1)
        # 5. Loss + update mémoire
        loss_win = F.binary_cross_entropy_with_logits(logit_win, y_win)
        loss_setdiff = F.mse_loss(margin, y_setdiff)
        loss_elo = F.mse_loss(elo, y_elo_gain)
        loss = loss_win + 0.5*loss_setdiff + 0.5*loss_elo
        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

        # 6. Backprop
        loss.backward()
        optimizer.step()
        memory.detach()
        total_loss += loss.item() * batch.num_events

    return total_loss / train_data.num_events

@torch.no_grad()

def evaluate(loader,thresholds):
    memory.eval(); gnn.eval(); link_pred.eval()
    total_loss = 0.0
    total_events = 0
    aps = []
    all_preds = []
    all_trues = []

    neighbor_loader.reset_state()
    for batch in loader:
        batch = batch.to(device)
        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        z, last_update = memory(n_id)
        t_e   = test_data.t[e_id].to(device)
        msg_e = test_data.msg[e_id].to(device)
        z = gnn(z, last_update, edge_index, t_e, msg_e)
        
        logit_win,margin,elo = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        y_win = batch.y.view(-1,1)                    # {0,1}
        y_setdiff = batch.set_diff.view(-1,1)
        y_elo_gain = batch.elo_gain.view(-1,1)
        # 5. Loss + update mémoire
        loss_win = F.binary_cross_entropy_with_logits(logit_win, y_win)
        loss_setdiff = F.mse_loss(margin, y_setdiff)
        loss_elo = F.mse_loss(elo, y_elo_gain)
        loss = loss_win + 0.5*loss_setdiff + 0.5*loss_elo
        total_loss += loss.item() * batch.num_events
        total_events += batch.num_events

        y_pred = logit_win.sigmoid().cpu()
        y_true = batch.y.view(-1,1).cpu()
        aps.append(average_precision_score(y_true, y_pred))
        
        all_preds.append(y_pred)
        all_trues.append(y_true)
        
        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

    avg_loss = total_loss / total_events
    avg_ap   = float(torch.tensor(aps).mean())
    all_preds = torch.cat(all_preds).numpy()
    all_trues = torch.cat(all_trues).numpy()

    # calcul précision conditionnelle
    prec_at = {}
    for thr in thresholds:
        mask = all_preds > thr
        if mask.sum() > 0:
            prec = all_trues[mask].sum() / mask.sum()
        else:
            prec = float('nan')
        prec_at[f'Prec@{thr}'] = prec
    return avg_ap, avg_loss,prec_at

# === Entraînement ===
train_losses = []
train_aps  = []
val_losses = []
val_metrics = []
val_tresh_06 =  []
val_tresh_065 =  []
val_tresh_07 =  []
val_tresh_075 =  []
val_tresh_08 =  []
threshold = [0.6,0.65,0.7,0.75,0.8]
num_epochs = 50
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
for epoch in range(1, num_epochs + 1):
    loss = train()
    train_losses.append(loss)
    val_ap, val_loss,prec_at = evaluate(test_loader,threshold)
    val_metrics.append(val_ap)
    val_losses.append(val_loss)
    val_tresh_06.append(prec_at[f'Prec@{threshold[0]}'])
    val_tresh_065.append(prec_at[f'Prec@{threshold[1]}'])
    val_tresh_07.append(prec_at[f'Prec@{threshold[2]}'])
    val_tresh_075.append(prec_at[f'Prec@{threshold[3]}'])
    val_tresh_08.append(prec_at[f'Prec@{threshold[4]}'])
    print(f"[Epoch {epoch:02d}] Loss: {loss:.4f} | Test loss : {val_loss:.4f} | Test AP: {val_ap:.4f} | Precision {threshold[0]} : {prec_at[f'Prec@{threshold[0]}']:.4f} | Precision {threshold[1]} : {prec_at[f'Prec@{threshold[1]}']:.4f} | Precision {threshold[2]} : {prec_at[f'Prec@{threshold[2]}']:.4f} | Precision {threshold[3]} : {prec_at[f'Prec@{threshold[3]}']:.4f} | Precision {threshold[4]} : {prec_at[f'Prec@{threshold[4]}']:.4f}")
    scheduler.step()


import matplotlib.pyplot as plt
epochs = list(range(1, len(train_losses) + 1))
plt.figure()
plt.plot(epochs, train_losses, label="Train Loss")
plt.plot(epochs, val_losses, label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Courbe de loss")
plt.legend()
plt.tight_layout()
plt.savefig("loss.png", dpi=300)
plt.close()

# Accuracy / AP
plt.figure()
plt.plot(epochs, val_metrics,   label="Validation AP")
plt.plot(epochs, val_tresh_06,   label="Validation AP 60%")
plt.plot(epochs, val_tresh_065,   label="Validation AP 65%")
plt.plot(epochs, val_tresh_07,   label="Validation AP 70%")
plt.plot(epochs, val_tresh_075,   label="Validation AP 75%")
plt.plot(epochs, val_tresh_08,   label="Validation AP 80%")
plt.xlabel("Epoch")
plt.ylabel("Average Precision")
plt.title("Courbe d'AP")
plt.legend()
plt.tight_layout()
plt.savefig("acc.png", dpi=300)
plt.close()

  return disable_fn(*args, **kwargs)
Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.61batch/s]


[Epoch 01] Loss: 2.8100 | Test loss : 2.6034 | Test AP: 0.7187 | Precision 0.6 : 0.6623 | Precision 0.65 : 0.6845 | Precision 0.7 : 0.7428 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:50<00:00, 37.99batch/s]


[Epoch 02] Loss: 2.6185 | Test loss : 2.5427 | Test AP: 0.7334 | Precision 0.6 : 0.6819 | Precision 0.65 : 0.7181 | Precision 0.7 : 0.7677 | Precision 0.75 : 0.8279 | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:50<00:00, 37.93batch/s]


[Epoch 03] Loss: 2.5892 | Test loss : 2.5387 | Test AP: 0.7328 | Precision 0.6 : 0.7144 | Precision 0.65 : 0.7388 | Precision 0.7 : 0.7574 | Precision 0.75 : 0.8136 | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.45batch/s]


[Epoch 04] Loss: 2.5639 | Test loss : 2.5298 | Test AP: 0.7340 | Precision 0.6 : 0.7070 | Precision 0.65 : 0.7366 | Precision 0.7 : 0.7563 | Precision 0.75 : 0.7942 | Precision 0.8 : 0.9149


Training: 100%|██████████| 4210/4210 [01:51<00:00, 37.70batch/s]


[Epoch 05] Loss: 2.5538 | Test loss : 2.5231 | Test AP: 0.7363 | Precision 0.6 : 0.7008 | Precision 0.65 : 0.7393 | Precision 0.7 : 0.7636 | Precision 0.75 : 0.8023 | Precision 0.8 : 0.8654


Training: 100%|██████████| 4210/4210 [01:54<00:00, 36.85batch/s]


[Epoch 06] Loss: 2.5279 | Test loss : 2.5236 | Test AP: 0.7388 | Precision 0.6 : 0.6953 | Precision 0.65 : 0.7270 | Precision 0.7 : 0.7489 | Precision 0.75 : 0.7758 | Precision 0.8 : 0.8444


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.00batch/s]


[Epoch 07] Loss: 2.5259 | Test loss : 2.5216 | Test AP: 0.7365 | Precision 0.6 : 0.7060 | Precision 0.65 : 0.7260 | Precision 0.7 : 0.7573 | Precision 0.75 : 0.7783 | Precision 0.8 : 0.8415


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.03batch/s]


[Epoch 08] Loss: 2.5178 | Test loss : 2.5108 | Test AP: 0.7376 | Precision 0.6 : 0.7063 | Precision 0.65 : 0.7287 | Precision 0.7 : 0.7683 | Precision 0.75 : 0.7873 | Precision 0.8 : 0.8425


Training: 100%|██████████| 4210/4210 [01:52<00:00, 37.36batch/s]


[Epoch 09] Loss: 2.5088 | Test loss : 2.5211 | Test AP: 0.7362 | Precision 0.6 : 0.7063 | Precision 0.65 : 0.7286 | Precision 0.7 : 0.7589 | Precision 0.75 : 0.7715 | Precision 0.8 : 0.8254


Training: 100%|██████████| 4210/4210 [01:53<00:00, 37.20batch/s]


[Epoch 10] Loss: 2.5084 | Test loss : 2.5267 | Test AP: 0.7365 | Precision 0.6 : 0.7072 | Precision 0.65 : 0.7343 | Precision 0.7 : 0.7545 | Precision 0.75 : 0.7920 | Precision 0.8 : 0.8355


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.27batch/s]


[Epoch 11] Loss: 2.5053 | Test loss : 2.5236 | Test AP: 0.7370 | Precision 0.6 : 0.7050 | Precision 0.65 : 0.7291 | Precision 0.7 : 0.7558 | Precision 0.75 : 0.7852 | Precision 0.8 : 0.8253


Training: 100%|██████████| 4210/4210 [01:51<00:00, 37.84batch/s]


[Epoch 12] Loss: 2.4963 | Test loss : 2.5188 | Test AP: 0.7387 | Precision 0.6 : 0.7113 | Precision 0.65 : 0.7368 | Precision 0.7 : 0.7692 | Precision 0.75 : 0.7857 | Precision 0.8 : 0.8275


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.08batch/s]


[Epoch 13] Loss: 2.4959 | Test loss : 2.5180 | Test AP: 0.7351 | Precision 0.6 : 0.7049 | Precision 0.65 : 0.7295 | Precision 0.7 : 0.7537 | Precision 0.75 : 0.7802 | Precision 0.8 : 0.8081


Training: 100%|██████████| 4210/4210 [01:52<00:00, 37.29batch/s]


[Epoch 14] Loss: 2.4950 | Test loss : 2.5155 | Test AP: 0.7398 | Precision 0.6 : 0.7005 | Precision 0.65 : 0.7303 | Precision 0.7 : 0.7605 | Precision 0.75 : 0.7816 | Precision 0.8 : 0.8411


Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.41batch/s]


[Epoch 15] Loss: 2.4902 | Test loss : 2.5146 | Test AP: 0.7409 | Precision 0.6 : 0.6980 | Precision 0.65 : 0.7302 | Precision 0.7 : 0.7478 | Precision 0.75 : 0.7816 | Precision 0.8 : 0.8124


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.04batch/s]


[Epoch 16] Loss: 2.4876 | Test loss : 2.5180 | Test AP: 0.7357 | Precision 0.6 : 0.6975 | Precision 0.65 : 0.7340 | Precision 0.7 : 0.7639 | Precision 0.75 : 0.7790 | Precision 0.8 : 0.8125


Training: 100%|██████████| 4210/4210 [01:48<00:00, 38.75batch/s]


[Epoch 17] Loss: 2.4789 | Test loss : 2.5214 | Test AP: 0.7389 | Precision 0.6 : 0.7003 | Precision 0.65 : 0.7269 | Precision 0.7 : 0.7453 | Precision 0.75 : 0.7723 | Precision 0.8 : 0.8012


Training: 100%|██████████| 4210/4210 [01:51<00:00, 37.92batch/s]


[Epoch 18] Loss: 2.4804 | Test loss : 2.5156 | Test AP: 0.7356 | Precision 0.6 : 0.7042 | Precision 0.65 : 0.7284 | Precision 0.7 : 0.7569 | Precision 0.75 : 0.7735 | Precision 0.8 : 0.8115


Training: 100%|██████████| 4210/4210 [01:52<00:00, 37.35batch/s]


[Epoch 19] Loss: 2.4687 | Test loss : 2.5211 | Test AP: 0.7352 | Precision 0.6 : 0.6984 | Precision 0.65 : 0.7262 | Precision 0.7 : 0.7514 | Precision 0.75 : 0.7725 | Precision 0.8 : 0.8190


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.12batch/s]


[Epoch 20] Loss: 2.4712 | Test loss : 2.5134 | Test AP: 0.7345 | Precision 0.6 : 0.7065 | Precision 0.65 : 0.7328 | Precision 0.7 : 0.7600 | Precision 0.75 : 0.7714 | Precision 0.8 : 0.8009


Training: 100%|██████████| 4210/4210 [01:51<00:00, 37.67batch/s]


[Epoch 21] Loss: 2.4711 | Test loss : 2.5189 | Test AP: 0.7360 | Precision 0.6 : 0.7006 | Precision 0.65 : 0.7284 | Precision 0.7 : 0.7551 | Precision 0.75 : 0.7692 | Precision 0.8 : 0.7951


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.16batch/s]


[Epoch 22] Loss: 2.4655 | Test loss : 2.5187 | Test AP: 0.7353 | Precision 0.6 : 0.6982 | Precision 0.65 : 0.7242 | Precision 0.7 : 0.7519 | Precision 0.75 : 0.7820 | Precision 0.8 : 0.8129


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.07batch/s]


[Epoch 23] Loss: 2.4658 | Test loss : 2.5185 | Test AP: 0.7340 | Precision 0.6 : 0.7020 | Precision 0.65 : 0.7296 | Precision 0.7 : 0.7500 | Precision 0.75 : 0.7697 | Precision 0.8 : 0.8096


Training: 100%|██████████| 4210/4210 [01:53<00:00, 37.16batch/s]


[Epoch 24] Loss: 2.4619 | Test loss : 2.5189 | Test AP: 0.7352 | Precision 0.6 : 0.7034 | Precision 0.65 : 0.7265 | Precision 0.7 : 0.7542 | Precision 0.75 : 0.7787 | Precision 0.8 : 0.8037


Training: 100%|██████████| 4210/4210 [01:48<00:00, 38.80batch/s]


[Epoch 25] Loss: 2.4605 | Test loss : 2.5137 | Test AP: 0.7367 | Precision 0.6 : 0.7009 | Precision 0.65 : 0.7283 | Precision 0.7 : 0.7580 | Precision 0.75 : 0.7760 | Precision 0.8 : 0.7984


Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.41batch/s]


[Epoch 26] Loss: 2.4593 | Test loss : 2.5164 | Test AP: 0.7366 | Precision 0.6 : 0.7013 | Precision 0.65 : 0.7242 | Precision 0.7 : 0.7520 | Precision 0.75 : 0.7755 | Precision 0.8 : 0.7952


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.00batch/s]


[Epoch 27] Loss: 2.4551 | Test loss : 2.5256 | Test AP: 0.7339 | Precision 0.6 : 0.6971 | Precision 0.65 : 0.7160 | Precision 0.7 : 0.7454 | Precision 0.75 : 0.7738 | Precision 0.8 : 0.7962


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.02batch/s]


[Epoch 28] Loss: 2.4539 | Test loss : 2.5217 | Test AP: 0.7321 | Precision 0.6 : 0.7003 | Precision 0.65 : 0.7217 | Precision 0.7 : 0.7581 | Precision 0.75 : 0.7615 | Precision 0.8 : 0.8051


Training: 100%|██████████| 4210/4210 [01:56<00:00, 36.21batch/s]


[Epoch 29] Loss: 2.4612 | Test loss : 2.5179 | Test AP: 0.7357 | Precision 0.6 : 0.6995 | Precision 0.65 : 0.7236 | Precision 0.7 : 0.7449 | Precision 0.75 : 0.7711 | Precision 0.8 : 0.7996


Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.31batch/s]


[Epoch 30] Loss: 2.4554 | Test loss : 2.5188 | Test AP: 0.7351 | Precision 0.6 : 0.6995 | Precision 0.65 : 0.7176 | Precision 0.7 : 0.7409 | Precision 0.75 : 0.7691 | Precision 0.8 : 0.8168


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.01batch/s]


[Epoch 31] Loss: 2.4548 | Test loss : 2.5238 | Test AP: 0.7304 | Precision 0.6 : 0.7023 | Precision 0.65 : 0.7241 | Precision 0.7 : 0.7486 | Precision 0.75 : 0.7736 | Precision 0.8 : 0.8020


Training: 100%|██████████| 4210/4210 [01:51<00:00, 37.87batch/s]


[Epoch 32] Loss: 2.4525 | Test loss : 2.5163 | Test AP: 0.7326 | Precision 0.6 : 0.7033 | Precision 0.65 : 0.7194 | Precision 0.7 : 0.7511 | Precision 0.75 : 0.7767 | Precision 0.8 : 0.8016


Training: 100%|██████████| 4210/4210 [01:48<00:00, 38.74batch/s]


[Epoch 33] Loss: 2.4510 | Test loss : 2.5228 | Test AP: 0.7346 | Precision 0.6 : 0.7005 | Precision 0.65 : 0.7181 | Precision 0.7 : 0.7500 | Precision 0.75 : 0.7651 | Precision 0.8 : 0.8031


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.22batch/s]


[Epoch 34] Loss: 2.4502 | Test loss : 2.5233 | Test AP: 0.7323 | Precision 0.6 : 0.7014 | Precision 0.65 : 0.7192 | Precision 0.7 : 0.7497 | Precision 0.75 : 0.7663 | Precision 0.8 : 0.7950


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.11batch/s]


[Epoch 35] Loss: 2.4493 | Test loss : 2.5191 | Test AP: 0.7349 | Precision 0.6 : 0.7078 | Precision 0.65 : 0.7191 | Precision 0.7 : 0.7492 | Precision 0.75 : 0.7680 | Precision 0.8 : 0.7947


Training: 100%|██████████| 4210/4210 [01:48<00:00, 38.65batch/s]


[Epoch 36] Loss: 2.4527 | Test loss : 2.5174 | Test AP: 0.7335 | Precision 0.6 : 0.7016 | Precision 0.65 : 0.7177 | Precision 0.7 : 0.7473 | Precision 0.75 : 0.7697 | Precision 0.8 : 0.8000


Training: 100%|██████████| 4210/4210 [01:54<00:00, 36.70batch/s]


[Epoch 37] Loss: 2.4522 | Test loss : 2.5292 | Test AP: 0.7337 | Precision 0.6 : 0.6907 | Precision 0.65 : 0.7185 | Precision 0.7 : 0.7457 | Precision 0.75 : 0.7760 | Precision 0.8 : 0.8008


Training: 100%|██████████| 4210/4210 [01:50<00:00, 37.97batch/s]


[Epoch 38] Loss: 2.4461 | Test loss : 2.5271 | Test AP: 0.7338 | Precision 0.6 : 0.6949 | Precision 0.65 : 0.7114 | Precision 0.7 : 0.7523 | Precision 0.75 : 0.7642 | Precision 0.8 : 0.7934


Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.42batch/s]


[Epoch 39] Loss: 2.4460 | Test loss : 2.5247 | Test AP: 0.7304 | Precision 0.6 : 0.6969 | Precision 0.65 : 0.7215 | Precision 0.7 : 0.7524 | Precision 0.75 : 0.7692 | Precision 0.8 : 0.7900


Training: 100%|██████████| 4210/4210 [01:50<00:00, 37.97batch/s]


[Epoch 40] Loss: 2.4503 | Test loss : 2.5222 | Test AP: 0.7351 | Precision 0.6 : 0.6955 | Precision 0.65 : 0.7186 | Precision 0.7 : 0.7409 | Precision 0.75 : 0.7690 | Precision 0.8 : 0.7969


Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.29batch/s]


[Epoch 41] Loss: 2.4491 | Test loss : 2.5207 | Test AP: 0.7342 | Precision 0.6 : 0.6961 | Precision 0.65 : 0.7192 | Precision 0.7 : 0.7472 | Precision 0.75 : 0.7645 | Precision 0.8 : 0.7985


Training: 100%|██████████| 4210/4210 [01:52<00:00, 37.38batch/s]


[Epoch 42] Loss: 2.4514 | Test loss : 2.5256 | Test AP: 0.7341 | Precision 0.6 : 0.6981 | Precision 0.65 : 0.7197 | Precision 0.7 : 0.7481 | Precision 0.75 : 0.7669 | Precision 0.8 : 0.8047


Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.32batch/s]


[Epoch 43] Loss: 2.4484 | Test loss : 2.5294 | Test AP: 0.7315 | Precision 0.6 : 0.6969 | Precision 0.65 : 0.7151 | Precision 0.7 : 0.7426 | Precision 0.75 : 0.7658 | Precision 0.8 : 0.8038


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.03batch/s]


[Epoch 44] Loss: 2.4465 | Test loss : 2.5191 | Test AP: 0.7374 | Precision 0.6 : 0.7001 | Precision 0.65 : 0.7175 | Precision 0.7 : 0.7549 | Precision 0.75 : 0.7706 | Precision 0.8 : 0.8036


Training: 100%|██████████| 4210/4210 [01:50<00:00, 38.19batch/s]


[Epoch 45] Loss: 2.4503 | Test loss : 2.5238 | Test AP: 0.7330 | Precision 0.6 : 0.6987 | Precision 0.65 : 0.7190 | Precision 0.7 : 0.7398 | Precision 0.75 : 0.7639 | Precision 0.8 : 0.7988


Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.31batch/s]


[Epoch 46] Loss: 2.4435 | Test loss : 2.5217 | Test AP: 0.7354 | Precision 0.6 : 0.6949 | Precision 0.65 : 0.7242 | Precision 0.7 : 0.7495 | Precision 0.75 : 0.7637 | Precision 0.8 : 0.7892


Training: 100%|██████████| 4210/4210 [01:53<00:00, 36.93batch/s]


[Epoch 47] Loss: 2.4495 | Test loss : 2.5261 | Test AP: 0.7302 | Precision 0.6 : 0.6973 | Precision 0.65 : 0.7207 | Precision 0.7 : 0.7404 | Precision 0.75 : 0.7581 | Precision 0.8 : 0.7988


Training: 100%|██████████| 4210/4210 [01:50<00:00, 37.93batch/s]


[Epoch 48] Loss: 2.4485 | Test loss : 2.5227 | Test AP: 0.7315 | Precision 0.6 : 0.6936 | Precision 0.65 : 0.7150 | Precision 0.7 : 0.7459 | Precision 0.75 : 0.7685 | Precision 0.8 : 0.8012


Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.39batch/s]


[Epoch 49] Loss: 2.4486 | Test loss : 2.5184 | Test AP: 0.7344 | Precision 0.6 : 0.6942 | Precision 0.65 : 0.7175 | Precision 0.7 : 0.7524 | Precision 0.75 : 0.7690 | Precision 0.8 : 0.7962


Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.28batch/s]


[Epoch 50] Loss: 2.4498 | Test loss : 2.5216 | Test AP: 0.7330 | Precision 0.6 : 0.6944 | Precision 0.65 : 0.7169 | Precision 0.7 : 0.7519 | Precision 0.75 : 0.7699 | Precision 0.8 : 0.7973


In [7]:
import torch
import torch.nn as nn
from torch.nn import Linear
from sklearn.metrics import average_precision_score, roc_auc_score

from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
    LastNeighborLoader,
)
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.data import TemporalData

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paramètres
embedding_dim = memory_dim = time_dim = 1024
learning_rate = 0.00006


# === Modules ===

class MultiLayerTimeAwareGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 msg_dim, time_enc, num_layers=3, heads=2, dropout=0.1):
        """
        - in_channels   : dim. entrée (memory_dim)
        - hidden_channels : dim. intermédiaire de chaque tête (out_channels/2 si heads=2)
        - out_channels  : dim. de sortie finale
        - msg_dim       : dim. des features d'arête
        - time_enc      : instance du TimeEncoder de TGNMemory
        - num_layers    : nombre de blocs TransformerConv à empiler
        """
        super().__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels

        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_c  = in_channels if i == 0 else hidden_channels * heads
            out_c = out_channels // heads if i == num_layers-1 else hidden_channels
            self.convs.append(
                TransformerConv(
                    in_c, out_c, heads=heads,
                    dropout=dropout, edge_dim=edge_dim
                )
            )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, last_update, edge_index, t, msg):
        # calcule rel_t et rel_t_enc une seule fois
        rel_t = last_update[edge_index[0]] - t
        rel_t_enc = self.time_enc(rel_t.to(x.dtype))
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)

        # propagation à travers les couches empilées
        for conv in self.convs:
            x_new = conv(x, edge_index, edge_attr)
            x = (x + self.dropout(x_new)).relu()   # skip + dropout + relu

        return x


class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_dst = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, 1)

    def forward(self, z_src, z_dst):
        h = self.lin_src(z_src) + self.lin_dst(z_dst)
        h = h.relu()
        return self.lin_final(h)

# === Instanciations ===
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiTaskPredictor(nn.Module):
    def __init__(self, embed_dim, hidden_dims=[512,256], dropout=0.3):
        super().__init__()
        # trunk MLP partagé
        layers = []
        dims = [2*embed_dim] + hidden_dims
        for i in range(len(dims)-1):
            layers += [
                nn.Linear(dims[i], dims[i+1]),
                nn.LayerNorm(dims[i+1]),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ]
        self.trunk = nn.Sequential(*layers)

        # tête 1 : victoire binaire
        self.head_win     = nn.Linear(hidden_dims[-1], 1)
        # tête 2 : diff de sets (régression)
        self.head_setdiff = nn.Linear(hidden_dims[-1], 1)
        # tête 3 : gain d'elo (régression)
        self.head_elogain = nn.Linear(hidden_dims[-1], 1)

    def forward(self, z_src, z_dst):
        h = torch.cat([z_src, z_dst], dim=-1)
        h = self.trunk(h)
        logit_win     = self.head_win(h).view(-1)      # (batch,)
        pred_setdiff  = self.head_setdiff(h).view(-1)  # (batch,)
        pred_elogain  = self.head_elogain(h).view(-1)  # (batch,)
        return logit_win, pred_setdiff, pred_elogain



num_nodes = max(df[["j1_enc", "j2_enc"]].max()) + 1
msg_dim = train_data.msg.size(-1)

memory = TGNMemory(
    num_nodes=num_nodes,
    raw_msg_dim=msg_dim,
    memory_dim=memory_dim,
    time_dim=time_dim,
    message_module=IdentityMessage(msg_dim, memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

gnn = MultiLayerTimeAwareGNN(memory_dim,32, embedding_dim, msg_dim, memory.time_enc,3,32,0.3).to(device)
predictor = MultiTaskPredictor(
    embed_dim=embedding_dim,
    hidden_dims=[512,256],
    dropout=0.3
).to(device)



optimizer = torch.optim.AdamW(
    list(memory.parameters()) + list(gnn.parameters()) + list(predictor.parameters()),
    lr=learning_rate,weight_decay=1e-4
)
criterion = torch.nn.BCEWithLogitsLoss()

# === Loaders ===


neighbor_loader = LastNeighborLoader(num_nodes=num_nodes, size=20, device=device)

assoc = torch.empty(num_nodes, dtype=torch.long, device=device)

# === Train ===

def train():
    memory.train()
    gnn.train()
    predictor.train()
    memory.reset_state()
    neighbor_loader.reset_state()

    total_loss = 0
    # on wrappe train_loader avec tqdm
    for batch in tqdm(train_loader, desc="Training", unit="batch"):
        optimizer.zero_grad()
        batch = batch.to(device)

        # 1. Construire le sous‑graphe voisin
        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # 2. Récupérer la mémoire et l'embedding
        z, last_update = memory(n_id)

        # 3. PAS batch.t/msg ! mais train_data.t/msg
        t_e = train_data.t[e_id].to(device)
        msg_e = train_data.msg[e_id].to(device)
        z = gnn(z, last_update, edge_index, t_e, msg_e)

        # 4. Link prediction
        logit_win, setdiff_pred, elogain_pred = predictor(z[assoc[batch.src]], z[assoc[batch.dst]])
        y_win       = batch.y.view(-1)                    # {0,1}
        y_setdiff   = batch.set_diff.view(-1).to(device)  # int
        y_elogain   = batch.elo_gain.view(-1).to(device)  # float

        # 5. Loss + update mémoire
        loss_win      = F.binary_cross_entropy_with_logits(logit_win, y_win)
        loss_setdiff  = F.mse_loss(setdiff_pred, y_setdiff)
        loss_elogain  = F.mse_loss(elogain_pred, y_elogain)
        loss = loss_win + 0.5*loss_setdiff + 0.5*loss_elogain
        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

        # 6. Backprop
        loss.backward()
        optimizer.step()
        memory.detach()
        total_loss += loss.item() * batch.num_events

    return total_loss / train_data.num_events

@torch.no_grad()

def evaluate(loader,thresholds):
    memory.eval(); gnn.eval(); predictor.eval()
    total_loss = 0.0
    total_events = 0
    aps = []
    all_preds = []
    all_trues = []

    neighbor_loader.reset_state()
    for batch in loader:
        batch = batch.to(device)
        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        z, last_update = memory(n_id)
        t_e   = test_data.t[e_id].to(device)
        msg_e = test_data.msg[e_id].to(device)
        z = gnn(z, last_update, edge_index, t_e, msg_e)

        logit_win, setdiff_pred, elogain_pred = predictor(z[assoc[batch.src]], z[assoc[batch.dst]])
        y_win       = batch.y.view(-1)                    # {0,1}
        y_setdiff   = batch.set_diff.view(-1).to(device)  # int
        y_elogain   = batch.elo_gain.view(-1).to(device)  # float

        # 5. Loss + update mémoire
        loss_win      = F.binary_cross_entropy_with_logits(logit_win, y_win)
        loss_setdiff  = F.mse_loss(setdiff_pred, y_setdiff)
        loss_elogain  = F.mse_loss(elogain_pred, y_elogain)
        loss = loss_win + 0.5*loss_setdiff + 0.5*loss_elogain
        total_loss += loss.item() * batch.num_events
        total_events += batch.num_events

        y_pred = y_win.sigmoid().cpu()
        y_true = batch.y.view(-1,1).cpu()
        aps.append(average_precision_score(y_true, y_pred))
        
        all_preds.append(y_pred)
        all_trues.append(y_true)
        
        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

    avg_loss = total_loss / total_events
    avg_ap   = float(torch.tensor(aps).mean())
    all_preds = torch.cat(all_preds).numpy()
    all_trues = torch.cat(all_trues).numpy()

    # calcul précision conditionnelle
    prec_at = {}
    for thr in thresholds:
        mask = all_preds > thr
        if mask.sum() > 0:
            prec = all_trues[mask].sum() / mask.sum()
        else:
            prec = float('nan')
        prec_at[f'Prec@{thr}'] = prec
    return avg_ap, avg_loss,prec_at

# === Entraînement ===
train_losses = []
train_aps  = []
val_losses = []
val_metrics = []
val_tresh_05 =  []
val_tresh_055 =  []
val_tresh_06 =  []
val_tresh_065 =  []
val_tresh_07 =  []
val_tresh_075 =  []
val_tresh_08 =  []
threshold = [0.5,0.55,0.6,0.65,0.7,0.75,0.8]
num_epochs = 100
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
for epoch in range(1, num_epochs + 1):
    loss = train()
    train_losses.append(loss)
    val_ap, val_loss,prec_at = evaluate(test_loader,threshold)
    val_metrics.append(val_ap)
    val_losses.append(val_loss)
    val_tresh_05.append(prec_at[f'Prec@{threshold[0]}'])
    val_tresh_055.append(prec_at[f'Prec@{threshold[1]}'])
    val_tresh_06.append(prec_at[f'Prec@{threshold[2]}'])
    val_tresh_065.append(prec_at[f'Prec@{threshold[3]}'])
    val_tresh_07.append(prec_at[f'Prec@{threshold[4]}'])
    val_tresh_075.append(prec_at[f'Prec@{threshold[5]}'])
    val_tresh_08.append(prec_at[f'Prec@{threshold[6]}'])
    print(f"[Epoch {epoch:02d}] Loss: {loss:.4f} | Test loss : {val_loss:.4f} | Test AP: {val_ap:.4f} | Precision {threshold[0]} : {prec_at[f'Prec@{threshold[0]}']:.4f} | Precision {threshold[1]} : {prec_at[f'Prec@{threshold[1]}']:.4f} | Precision {threshold[2]} : {prec_at[f'Prec@{threshold[2]}']:.4f} | Precision {threshold[3]} : {prec_at[f'Prec@{threshold[3]}']:.4f} | Precision {threshold[4]} : {prec_at[f'Prec@{threshold[4]}']:.4f} | Precision {threshold[5]} : {prec_at[f'Prec@{threshold[5]}']:.4f} | Precision {threshold[6]} : {prec_at[f'Prec@{threshold[6]}']:.4f}")
    scheduler.step()


import matplotlib.pyplot as plt
epochs = list(range(1, len(train_losses) + 1))
plt.figure()
plt.plot(epochs, train_losses, label="Train Loss")
plt.plot(epochs, val_losses, label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Courbe de loss")
plt.legend()
plt.tight_layout()
plt.savefig("loss.png", dpi=300)
plt.close()

# Accuracy / AP
plt.figure()
plt.plot(epochs, val_metrics,   label="Validation AP")
plt.plot(epochs, val_tresh_06,   label="Validation AP 60%")
plt.plot(epochs, val_tresh_065,   label="Validation AP 65%")
plt.plot(epochs, val_tresh_07,   label="Validation AP 70%")
plt.plot(epochs, val_tresh_075,   label="Validation AP 75%")
plt.plot(epochs, val_tresh_08,   label="Validation AP 80%")
plt.xlabel("Epoch")
plt.ylabel("Average Precision")
plt.title("Courbe d'AP")
plt.legend()
plt.tight_layout()
plt.savefig("acc.png", dpi=300)
plt.close()

  return disable_fn(*args, **kwargs)
Training: 100%|██████████| 4210/4210 [01:48<00:00, 38.74batch/s]


[Epoch 01] Loss: 2.6016 | Test loss : 2.5350 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:47<00:00, 39.23batch/s]


[Epoch 02] Loss: 2.5139 | Test loss : 2.5340 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:48<00:00, 38.83batch/s]


[Epoch 03] Loss: 2.5014 | Test loss : 2.5201 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:47<00:00, 39.15batch/s]


[Epoch 04] Loss: 2.4986 | Test loss : 2.5132 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:47<00:00, 39.16batch/s]


[Epoch 05] Loss: 2.4901 | Test loss : 2.5161 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:47<00:00, 39.31batch/s]


[Epoch 06] Loss: 2.4700 | Test loss : 2.5032 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.54batch/s]


[Epoch 07] Loss: 2.4653 | Test loss : 2.5043 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:44<00:00, 40.15batch/s]


[Epoch 08] Loss: 2.4636 | Test loss : 2.5049 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:48<00:00, 38.68batch/s]


[Epoch 09] Loss: 2.4590 | Test loss : 2.5033 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:47<00:00, 39.18batch/s]


[Epoch 10] Loss: 2.4584 | Test loss : 2.5060 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:46<00:00, 39.40batch/s]


[Epoch 11] Loss: 2.4471 | Test loss : 2.5053 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:47<00:00, 39.15batch/s]


[Epoch 12] Loss: 2.4435 | Test loss : 2.5116 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:46<00:00, 39.38batch/s]


[Epoch 13] Loss: 2.4406 | Test loss : 2.5038 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:46<00:00, 39.44batch/s]


[Epoch 14] Loss: 2.4425 | Test loss : 2.4984 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:46<00:00, 39.52batch/s]


[Epoch 15] Loss: 2.4366 | Test loss : 2.5075 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:49<00:00, 38.53batch/s]


[Epoch 16] Loss: 2.4323 | Test loss : 2.5047 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:47<00:00, 39.10batch/s]


[Epoch 17] Loss: 2.4297 | Test loss : 2.5073 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:46<00:00, 39.66batch/s]


[Epoch 18] Loss: 2.4260 | Test loss : 2.5046 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:46<00:00, 39.41batch/s]


[Epoch 19] Loss: 2.4242 | Test loss : 2.5037 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:46<00:00, 39.50batch/s]


[Epoch 20] Loss: 2.4207 | Test loss : 2.5006 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:46<00:00, 39.71batch/s]


[Epoch 21] Loss: 2.4181 | Test loss : 2.5030 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training: 100%|██████████| 4210/4210 [01:48<00:00, 38.66batch/s]


[Epoch 22] Loss: 2.4155 | Test loss : 2.5105 | Test AP: 1.0000 | Precision 0.5 : 1.0000 | Precision 0.55 : 1.0000 | Precision 0.6 : 1.0000 | Precision 0.65 : 1.0000 | Precision 0.7 : 1.0000 | Precision 0.75 : nan | Precision 0.8 : nan


Training:  16%|█▌        | 657/4210 [00:18<01:39, 35.88batch/s]


KeyboardInterrupt: 