# Create the Final Dataset

**INPUT**: "./data/0cleanDataset.csv"

**OUTPUT**: "./data/1finalDataset.csv"

This notebook will take the clean and randomized dataset and calculate stadistics for each player (ELO rankings, BreakPoints, etc.).

First, of course, we need to import some libraries.

In [1]:
%matplotlib inline
import pandas as pd
from tqdm import tqdm
pd.set_option('display.max_columns', None)

In [2]:
import random
import numpy as np
import torch
def seed_everything(seed: int = 41):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # si multi-GPU

    # Pour forcer la reproductibilité sur CUDA (moins perf mais stable)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Exemple :
seed_everything(41)


In [3]:
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import json

class History:
    def __init__(self, outdir="runs/curves", hparams=None):
        """
        hparams: dict optionnel avec les hyperparamètres 
                 (ex: {"lr":4e-4,"weight_decay":1e-5,"memory_dim":128,...})
        """
        self.outdir = Path(outdir)
        self.outdir.mkdir(parents=True, exist_ok=True)
        self.rows = []
        self.prec_at_rows = []
        self.detail_rows = []
        self.hparams = hparams if hparams is not None else {}

    def log_epoch(self, epoch,
                  train_loss, train_ap, train_prec,
                  val_loss,   val_ap,   val_prec,
                  prec_at=None):
        self.rows.append({
            "epoch": epoch,
            "train_loss": float(train_loss),
            "train_ap": float(train_ap),
            "train_prec@0.5": float(train_prec),
            "val_loss": float(val_loss),
            "val_ap": float(val_ap),
            "val_prec@0.5": float(val_prec),
        })
        if prec_at is not None:
            self.prec_at_rows.append(
                {"epoch": epoch, **{f"@{k}": float(v) for k, v in prec_at.items()}}
            )

    def save_tables(self):
        pd.DataFrame(self.rows).to_csv(self.outdir / "metrics_history.csv", index=False)
        if self.prec_at_rows:
            pd.DataFrame(self.prec_at_rows).to_csv(self.outdir / "precision_at_history.csv", index=False)
        if self.detail_rows:
            pd.DataFrame(self.detail_rows).to_csv(self.outdir / "predicted_dates.csv", index=False)
        # 💾 Sauvegarde aussi les hyperparamètres dans un JSON
        if self.hparams:
            with open(self.outdir / "hparams.json","w") as f:
                json.dump(self.hparams,f,indent=2)

    def _plot_and_save(self, x, y, ylabel, fname):
        plt.figure()
        plt.plot(x, y)
        plt.xlabel("epoch")
        plt.ylabel(ylabel)
        plt.grid(True, linestyle="--", linewidth=0.5)
        plt.tight_layout()
        plt.savefig(self.outdir / fname, dpi=200)
        plt.close()

    def save_plots(self):
        df = pd.DataFrame(self.rows)
        x = df["epoch"].values
        self._plot_and_save(x, df["train_loss"].values, "train_loss", "curve_train_loss.png")
        self._plot_and_save(x, df["val_loss"].values,   "val_loss",   "curve_val_loss.png")
        self._plot_and_save(x, df["train_ap"].values,   "train_AP",   "curve_train_ap.png")
        self._plot_and_save(x, df["val_ap"].values,     "val_AP",     "curve_val_ap.png")
        self._plot_and_save(x, df["train_prec@0.5"].values, "train_precision@0.5", "curve_train_prec.png")
        self._plot_and_save(x, df["val_prec@0.5"].values,   "val_precision@0.5",   "curve_val_prec.png")

    def save_all(self):
        self.save_tables()
        self.save_plots()


In [4]:
clean_data = pd.read_csv("./data/0cleanDataset.csv")

In [5]:
clean_data 

Unnamed: 0,tourney_id,tourney_name,surface,draw_size,tourney_level,tourney_date,match_num,p1_id,p1_seed,p1_entry,p1_name,p1_hand,p1_ht,p1_ioc,p1_age,p2_id,p2_seed,p2_entry,p2_name,p2_hand,p2_ht,p2_ioc,p2_age,score,best_of,round,minutes,p1_ace,p1_df,p1_svpt,p1_1stIn,p1_1stWon,p1_2ndWon,p1_SvGms,p1_bpSaved,p1_bpFaced,p2_ace,p2_df,p2_svpt,p2_1stIn,p2_1stWon,p2_2ndWon,p2_SvGms,p2_bpSaved,p2_bpFaced,p1_rank,p1_rank_points,p2_rank,p2_rank_points,RESULT
0,1991-301,Auckland,Hard,32,A,19910107,1,101142,1.0,,Emilio Sanchez,R,180.0,ESP,25.6,101746,,,Renzo Furlan,R,175.0,ITA,20.6,6-4 6-1,3,R32,63.0,1.0,0.0,53.0,37.0,30.0,7.0,9.0,5.0,6.0,3.0,0.0,46.0,30.0,17.0,7.0,8.0,2.0,6.0,9.0,1487.0,78.0,459.0,1
1,1991-301,Auckland,Hard,32,A,19910107,2,101613,,Q,Malivai Washington,R,180.0,USA,21.5,100587,,WC,Steve Guy,R,188.0,NZL,31.8,6-3 6-2,3,R32,72.0,5.0,1.0,56.0,25.0,17.0,20.0,9.0,1.0,2.0,4.0,7.0,56.0,30.0,22.0,6.0,8.0,7.0,11.0,94.0,371.0,220.0,114.0,1
2,1991-301,Auckland,Hard,32,A,19910107,3,101601,,WC,Brett Steven,R,185.0,NZL,21.6,101179,,,Jean Philippe Fleurian,R,185.0,FRA,25.3,2-6 6-1 6-2,3,R32,101.0,1.0,3.0,68.0,43.0,24.0,14.0,11.0,4.0,8.0,2.0,4.0,80.0,55.0,35.0,16.0,12.0,2.0,4.0,212.0,116.0,77.0,468.0,0
3,1991-301,Auckland,Hard,32,A,19910107,4,101117,,,Eric Jelen,R,180.0,GER,25.8,101332,8.0,,Gilad Bloom,L,173.0,ISR,23.8,6-3 1-6 6-4,3,R32,108.0,0.0,1.0,82.0,55.0,35.0,14.0,13.0,6.0,10.0,3.0,2.0,96.0,61.0,38.0,15.0,13.0,8.0,12.0,65.0,502.0,72.0,483.0,1
4,1991-301,Auckland,Hard,32,A,19910107,5,101901,,Q,Chuck Adams,R,185.0,USA,19.7,101735,3.0,,Richard Fromberg,R,196.0,AUS,20.6,6-3 6-4,3,R32,65.0,4.0,4.0,65.0,46.0,34.0,12.0,10.0,2.0,2.0,1.0,3.0,49.0,25.0,21.0,12.0,9.0,4.0,6.0,190.0,142.0,28.0,876.0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95370,2024-M-DC-2024-WG2-PO-TOG-INA-01,Davis Cup WG2 PO: TOG vs INA,Hard,4,D,20240203,4,133933,,,Thomas Yaka Kofi Setodji,R,183.0,TOG,28.2,207134,,,Fitriadi M Rifqi,R,175.0,INA,25.0,6-2 6-3,3,RR,87.0,2.0,1.0,50.0,35.0,19.0,5.0,8.0,5.0,9.0,1.0,2.0,52.0,32.0,25.0,11.0,9.0,2.0,3.0,819.0,24.0,569.0,64.0,0
95371,2024-M-DC-2024-WG2-PO-TUN-CRC-01,Davis Cup WG2 PO: TUN vs CRC,Hard,4,D,20240202,1,132374,,,Jesse Flores,R,188.0,CRC,28.8,121411,,,Moez Echargui,R,178.0,TUN,31.0,6-2 6-3,3,RR,65.0,3.0,2.0,49.0,30.0,17.0,9.0,8.0,2.0,5.0,4.0,2.0,56.0,35.0,30.0,11.0,9.0,4.0,4.0,900.0,18.0,279.0,205.0,0
95372,2024-M-DC-2024-WG2-PO-URU-MDA-01,Davis Cup WG2 PO: URU vs MDA,Clay,4,D,20240203,1,208364,,,Franco Roncadelli,L,185.0,URU,23.9,209943,,,Ilya Snitari,R,188.0,MDA,21.8,4-6 6-1 6-1,3,RR,137.0,1.0,2.0,74.0,50.0,31.0,13.0,12.0,5.0,9.0,0.0,5.0,92.0,52.0,24.0,18.0,12.0,8.0,16.0,616.0,55.0,740.0,34.0,1
95373,2024-M-DC-2024-WG2-PO-URU-MDA-01,Davis Cup WG2 PO: URU vs MDA,Clay,4,D,20240203,4,208364,,,Franco Roncadelli,L,185.0,URU,23.9,105430,,,Radu Albot,R,175.0,MDA,34.2,6-3 6-1,3,RR,95.0,1.0,1.0,82.0,66.0,33.0,6.0,8.0,6.0,11.0,2.0,1.0,50.0,30.0,19.0,14.0,8.0,3.0,4.0,616.0,55.0,136.0,489.0,0


In [6]:
clean_data.columns

Index(['tourney_id', 'tourney_name', 'surface', 'draw_size', 'tourney_level',
       'tourney_date', 'match_num', 'p1_id', 'p1_seed', 'p1_entry', 'p1_name',
       'p1_hand', 'p1_ht', 'p1_ioc', 'p1_age', 'p2_id', 'p2_seed', 'p2_entry',
       'p2_name', 'p2_hand', 'p2_ht', 'p2_ioc', 'p2_age', 'score', 'best_of',
       'round', 'minutes', 'p1_ace', 'p1_df', 'p1_svpt', 'p1_1stIn',
       'p1_1stWon', 'p1_2ndWon', 'p1_SvGms', 'p1_bpSaved', 'p1_bpFaced',
       'p2_ace', 'p2_df', 'p2_svpt', 'p2_1stIn', 'p2_1stWon', 'p2_2ndWon',
       'p2_SvGms', 'p2_bpSaved', 'p2_bpFaced', 'p1_rank', 'p1_rank_points',
       'p2_rank', 'p2_rank_points', 'RESULT'],
      dtype='object')

I want to use the same functions/API for calculating the stadistics of players as when we predict. Therefore, this will be slighly more complicated. 

If you want to see an easier way, my last code (DataAnalysis.ipynb) showed an easier way to calculate the stats (for example all_data_filtered["winner_rank_points"] - all_data_filtered["loser_rank_points"]).

In [7]:
from utils.updateStats import getStats, updateStats, createStats

final_dataset = []
prev_stats = createStats()

# Iterate through each row in clean_data
for index, row in tqdm(clean_data.iterrows(), total=len(clean_data)):
    player1 = {
        "ID": row["p1_id"],
        "ATP_POINTS": row["p1_rank_points"],
        "ATP_RANK": row["p1_rank"],
        "AGE": row["p1_age"],
        "HEIGHT": row["p1_ht"],
    }

    player2 = {
        "ID": row["p2_id"],
        "ATP_POINTS": row["p2_rank_points"],
        "ATP_RANK": row["p2_rank"],
        "AGE": row["p2_age"],
        "HEIGHT": row["p2_ht"],
    }

    match = {
        "BEST_OF": row["best_of"],
        "DRAW_SIZE": row["draw_size"],
        "SURFACE": row["surface"],
        "TOURNEY_LEVEL": row["tourney_level"],
        "DATE" : row["tourney_date"]


    }

    ########## GET STATS ##########
    # Call getStatsPlayers function
    output = getStats(player1, player2, match, prev_stats)

    # Append sorted stats to final dataset
    match_data = dict(sorted(output.items()))
    match_data["p1_id"] = row["p1_id"]
    match_data["p2_id"] = row["p2_id"]
    match_data["RESULT"] = row.RESULT
    match_data["SCORE"] = row.score
    
    final_dataset.append(match_data)

    ########## UPDATE STATS ##########
    prev_stats = updateStats(row, prev_stats)


# Convert final dataset to DataFrame
final_dataset = pd.DataFrame(final_dataset)

100%|██████████| 95375/95375 [00:49<00:00, 1919.03it/s]


In [8]:
final_dataset = final_dataset
final_dataset

Unnamed: 0,AGE_DIFF,ATP_POINTS_DIFF,ATP_RANK_DIFF,BEST_OF,DRAW_SIZE,ELO_DIFF,ELO_GRAD_LAST_100_DIFF,ELO_GRAD_LAST_10_DIFF,ELO_GRAD_LAST_200_DIFF,ELO_GRAD_LAST_25_DIFF,ELO_GRAD_LAST_3_DIFF,ELO_GRAD_LAST_50_DIFF,ELO_GRAD_LAST_5_DIFF,ELO_SURFACE_DIFF,H2H_DIFF,H2H_SURFACE_DIFF,HEIGHT_DIFF,N_GAMES_DIFF,P_1ST_IN_LAST_100_DIFF,P_1ST_IN_LAST_10_DIFF,P_1ST_IN_LAST_200_DIFF,P_1ST_IN_LAST_25_DIFF,P_1ST_IN_LAST_3_DIFF,P_1ST_IN_LAST_50_DIFF,P_1ST_IN_LAST_5_DIFF,P_1ST_WON_LAST_100_DIFF,P_1ST_WON_LAST_10_DIFF,P_1ST_WON_LAST_200_DIFF,P_1ST_WON_LAST_25_DIFF,P_1ST_WON_LAST_3_DIFF,P_1ST_WON_LAST_50_DIFF,P_1ST_WON_LAST_5_DIFF,P_2ND_WON_LAST_100_DIFF,P_2ND_WON_LAST_10_DIFF,P_2ND_WON_LAST_200_DIFF,P_2ND_WON_LAST_25_DIFF,P_2ND_WON_LAST_3_DIFF,P_2ND_WON_LAST_50_DIFF,P_2ND_WON_LAST_5_DIFF,P_ACE_LAST_100_DIFF,P_ACE_LAST_10_DIFF,P_ACE_LAST_200_DIFF,P_ACE_LAST_25_DIFF,P_ACE_LAST_3_DIFF,P_ACE_LAST_50_DIFF,P_ACE_LAST_5_DIFF,P_BP_SAVED_LAST_100_DIFF,P_BP_SAVED_LAST_10_DIFF,P_BP_SAVED_LAST_200_DIFF,P_BP_SAVED_LAST_25_DIFF,P_BP_SAVED_LAST_3_DIFF,P_BP_SAVED_LAST_50_DIFF,P_BP_SAVED_LAST_5_DIFF,P_DF_LAST_100_DIFF,P_DF_LAST_10_DIFF,P_DF_LAST_200_DIFF,P_DF_LAST_25_DIFF,P_DF_LAST_3_DIFF,P_DF_LAST_50_DIFF,P_DF_LAST_5_DIFF,TOURNEY_LEVEL,WIN_LAST_100_DIFF,WIN_LAST_10_DIFF,WIN_LAST_200_DIFF,WIN_LAST_25_DIFF,WIN_LAST_3_DIFF,WIN_LAST_50_DIFF,WIN_LAST_5_DIFF,date,p1_elo,p2_elo,p1_id,p2_id,RESULT,SCORE
0,5.0,1028.0,-69.0,3,32,0.000000,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,5.0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,A,0,0,0,0,0,0,0,19910107,1500.000000,1500.000000,101142,101746,1,6-4 6-1
1,-10.3,257.0,-126.0,3,32,0.000000,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,-8.0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,A,0,0,0,0,0,0,0,19910107,1500.000000,1500.000000,101613,100587,1,6-3 6-2
2,-3.7,-352.0,135.0,3,32,0.000000,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,0.0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,A,0,0,0,0,0,0,0,19910107,1500.000000,1500.000000,101601,101179,0,2-6 6-1 6-2
3,2.0,19.0,-7.0,3,32,0.000000,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,7.0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,A,0,0,0,0,0,0,0,19910107,1500.000000,1500.000000,101117,101332,1,6-3 1-6 6-4
4,-0.9,-734.0,162.0,3,32,0.000000,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,-11.0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,A,0,0,0,0,0,0,0,19910107,1500.000000,1500.000000,101901,101735,1,6-3 6-4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95370,3.2,-40.0,250.0,3,4,0.000000,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,8.0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,D,0,0,0,0,0,0,0,20240203,1500.000000,1500.000000,133933,207134,0,6-2 6-3
95371,-2.2,-187.0,621.0,3,4,11.415089,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,10.0,-1,-53.820988,-53.820988,-53.820988,-53.820988,-53.820988,-53.820988,-53.820988,-69.954545,-69.954545,-69.954545,-69.954545,-69.954545,-69.954545,-69.954545,-40.040541,-40.040541,-40.040541,-40.040541,-40.040541,-40.040541,-40.040541,-3.203704,-3.203704,-3.203704,-3.203704,-3.203704,-3.203704,-3.203704,-49.500000,-49.500000,-49.500000,-49.500000,-49.500000,-49.500000,-49.500000,-3.203704,-3.203704,-3.203704,-3.203704,-3.203704,-3.203704,-3.203704,D,0,0,0,0,0,0,0,20240202,1500.000000,1488.584911,132374,121411,0,6-2 6-3
95372,2.1,21.0,-124.0,3,4,21.484266,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,22.190451,0,0,-3.0,2,72.147059,72.147059,72.147059,72.147059,72.147059,72.147059,72.147059,63.450893,63.450893,63.450893,63.450893,63.450893,63.450893,63.450893,47.714286,47.714286,47.714286,47.714286,47.714286,47.714286,47.714286,0.713235,0.713235,0.713235,0.713235,0.713235,0.713235,0.713235,70.333333,70.333333,70.333333,70.333333,70.333333,70.333333,70.333333,2.477941,2.477941,2.477941,2.477941,2.477941,2.477941,2.477941,D,0,0,0,0,0,0,0,20240203,1521.484266,1500.000000,208364,209943,1,4-6 6-1 6-1
95373,-10.3,-434.0,480.0,3,4,21.029457,0.0,0.0,0.0,0.0,2.305865e-16,0.0,0.0,58.946102,0,0,10.0,-226,8.322480,6.413177,7.644898,7.970157,7.202387,7.201900,8.811394,-2.903634,-3.002937,-2.913438,-1.397046,-0.017734,-3.707832,-0.920725,2.842961,0.235478,0.350060,2.964308,-0.613553,1.651779,-0.863322,-2.280146,-2.584554,-2.241952,-1.703725,-0.876012,-2.251091,-0.513701,10.982150,14.552512,11.355782,9.835787,-5.698653,8.524708,6.877104,-1.190927,-1.091163,-0.910355,-0.749971,-1.297126,-0.634542,-1.540796,D,0,0,0,0,2,0,0,20240203,1532.743171,1511.713713,208364,105430,0,6-3 6-1


In [9]:
import re
SET_RE = re.compile(r'(\d+)\s*[-–—]\s*(\d+)(?:\s*\(([^)]*)\))?')
def parse_score_str(s: str):
    """'6-4 7-6(7)' -> [(6,4,7), ...]. Tolère () vides, 10-8, RET, etc."""
    if not isinstance(s, str):
        return []
    s = s.strip().replace('\xa0', ' ')  # NBSP -> espace
    out = []
    for a, b, tb in SET_RE.findall(s):
        # tb peut être '', '5', '10-8'... on prend le 1er entier s'il existe
        tb_val = None
        if tb:
            m = re.search(r'\d+', tb)
            if m:
                tb_val = int(m.group())
        out.append((int(a), int(b), tb_val))
    return out
def summarize_sets(sets, flip=False):
    if flip:
        sets = [(b, a, tb) for (a, b, tb) in sets]

    n = len(sets)
    if n == 0:
        return dict(n_sets=0, sets_src=0, sets_dst=0,
                    games_src_total=0, games_dst_total=0,
                    games_diff_total=0, sets_diff=0,
                    bagels_src=0, bagels_dst=0,
                    tiebreaks=0, super_tb_flag=0,
                    super_tb_src_pts=0, super_tb_dst_pts=0,
                    straight_sets=0, closeness=0.5, valid=0)

    # Détecter un super TB en dernier "set" (points >=10, écart >=2)
    is_super_tb = (sets[-1][0] >= 10 or sets[-1][1] >= 10) and abs(sets[-1][0] - sets[-1][1]) >= 2

    # Séparer les "vrais" sets des points de super TB
    normal_sets = sets[:-1] if is_super_tb else sets
    super_tb_src_pts = sets[-1][0] if is_super_tb else 0
    super_tb_dst_pts = sets[-1][1] if is_super_tb else 0

    # Comptes sur les sets "normaux"
    sets_src_norm = sum(1 for a, b, _ in normal_sets if a > b)
    sets_dst_norm = len(normal_sets) - sets_src_norm

    games_src_total = sum(a for a, _, _ in normal_sets)   # on exclut le super TB
    games_dst_total = sum(b for _, b, _ in normal_sets)
    games_diff_total = games_src_total - games_dst_total

    # Ajouter le "set" décisif si super TB
    sets_src = sets_src_norm + (1 if is_super_tb and super_tb_src_pts > super_tb_dst_pts else 0)
    sets_dst = sets_dst_norm + (1 if is_super_tb and super_tb_dst_pts > super_tb_src_pts else 0)

    sets_diff = sets_src - sets_dst
    bagels_src = sum(1 for a, b, _ in normal_sets if a == 6 and b == 0)
    bagels_dst = sum(1 for a, b, _ in normal_sets if b == 6 and a == 0)

    # Tiebreaks "classiques" (7-6/6-7 ou parenthèses)
    tiebreaks = sum(1 for a, b, tb in normal_sets if (max(a, b) == 7 and abs(a - b) == 1) or (tb is not None))

    total_games = games_src_total + games_dst_total
    closeness = 1.0 - (abs(games_diff_total) / total_games) if total_games > 0 else 0.5

    return dict(
        n_sets=n,
        sets_src=sets_src, sets_dst=sets_dst,
        games_src_total=games_src_total, games_dst_total=games_dst_total,
        games_diff_total=games_diff_total, sets_diff=sets_diff,
        bagels_src=bagels_src, bagels_dst=bagels_dst,
        tiebreaks=tiebreaks,
        super_tb_flag=int(is_super_tb),
        super_tb_src_pts=super_tb_src_pts,
        super_tb_dst_pts=super_tb_dst_pts,
        straight_sets=int(sets_dst == 0 and len(normal_sets) >= 2),  # "2 sets secs" hors super TB
        closeness=closeness,
        valid=1
    )
def make_score_features(df: pd.DataFrame,
                        score_col: str = "SCORE",
                        result_col: str | None = "RESULT",
                        assume_left_is_winner: bool = True,
                        prefix: str = "score_") -> pd.DataFrame:
    """
    Crée un DataFrame de features à partir de df[score_col] et le renvoie.
    - result_col: 1 si src gagne, 0 sinon (utilisé pour 'flip' si le score est gagnant-à-gauche).
    - assume_left_is_winner: mettre False si le texte SCORE n'est PAS orienté gagnant-à-gauche.
    """
    sets_list = df[score_col].map(parse_score_str)
    if result_col is not None and assume_left_is_winner:
        flip_series = (df[result_col].astype('int32') == 0)  # flip si src a perdu
    else:
        flip_series = pd.Series(False, index=df.index)

    feats = [summarize_sets(s, flip=f) for s, f in zip(sets_list, flip_series)]
    feats_df = pd.DataFrame.from_records(feats, index=df.index).add_prefix(prefix)
    # dtypes compacts
    float_cols = feats_df.select_dtypes(include=['float64']).columns
    feats_df[float_cols] = feats_df[float_cols].astype('float32')
    int_cols = feats_df.select_dtypes(include=['int64']).columns
    feats_df[int_cols] = feats_df[int_cols].astype('int16')
    return feats_df


In [10]:
score_feats = make_score_features(
    final_dataset,
    score_col="SCORE",
    result_col="RESULT",              # ta colonne binaire 0/1
    assume_left_is_winner=True,       # True si SCORE est gagnant-à-gauche
    prefix="score_"
)

# 2) Merge au DataFrame d'origine
final_dataset = final_dataset.join(score_feats)
final_dataset

Unnamed: 0,AGE_DIFF,ATP_POINTS_DIFF,ATP_RANK_DIFF,BEST_OF,DRAW_SIZE,ELO_DIFF,ELO_GRAD_LAST_100_DIFF,ELO_GRAD_LAST_10_DIFF,ELO_GRAD_LAST_200_DIFF,ELO_GRAD_LAST_25_DIFF,ELO_GRAD_LAST_3_DIFF,ELO_GRAD_LAST_50_DIFF,ELO_GRAD_LAST_5_DIFF,ELO_SURFACE_DIFF,H2H_DIFF,H2H_SURFACE_DIFF,HEIGHT_DIFF,N_GAMES_DIFF,P_1ST_IN_LAST_100_DIFF,P_1ST_IN_LAST_10_DIFF,P_1ST_IN_LAST_200_DIFF,P_1ST_IN_LAST_25_DIFF,P_1ST_IN_LAST_3_DIFF,P_1ST_IN_LAST_50_DIFF,P_1ST_IN_LAST_5_DIFF,P_1ST_WON_LAST_100_DIFF,P_1ST_WON_LAST_10_DIFF,P_1ST_WON_LAST_200_DIFF,P_1ST_WON_LAST_25_DIFF,P_1ST_WON_LAST_3_DIFF,P_1ST_WON_LAST_50_DIFF,P_1ST_WON_LAST_5_DIFF,P_2ND_WON_LAST_100_DIFF,P_2ND_WON_LAST_10_DIFF,P_2ND_WON_LAST_200_DIFF,P_2ND_WON_LAST_25_DIFF,P_2ND_WON_LAST_3_DIFF,P_2ND_WON_LAST_50_DIFF,P_2ND_WON_LAST_5_DIFF,P_ACE_LAST_100_DIFF,P_ACE_LAST_10_DIFF,P_ACE_LAST_200_DIFF,P_ACE_LAST_25_DIFF,P_ACE_LAST_3_DIFF,P_ACE_LAST_50_DIFF,P_ACE_LAST_5_DIFF,P_BP_SAVED_LAST_100_DIFF,P_BP_SAVED_LAST_10_DIFF,P_BP_SAVED_LAST_200_DIFF,P_BP_SAVED_LAST_25_DIFF,P_BP_SAVED_LAST_3_DIFF,P_BP_SAVED_LAST_50_DIFF,P_BP_SAVED_LAST_5_DIFF,P_DF_LAST_100_DIFF,P_DF_LAST_10_DIFF,P_DF_LAST_200_DIFF,P_DF_LAST_25_DIFF,P_DF_LAST_3_DIFF,P_DF_LAST_50_DIFF,P_DF_LAST_5_DIFF,TOURNEY_LEVEL,WIN_LAST_100_DIFF,WIN_LAST_10_DIFF,WIN_LAST_200_DIFF,WIN_LAST_25_DIFF,WIN_LAST_3_DIFF,WIN_LAST_50_DIFF,WIN_LAST_5_DIFF,date,p1_elo,p2_elo,p1_id,p2_id,RESULT,SCORE,score_n_sets,score_sets_src,score_sets_dst,score_games_src_total,score_games_dst_total,score_games_diff_total,score_sets_diff,score_bagels_src,score_bagels_dst,score_tiebreaks,score_super_tb_flag,score_super_tb_src_pts,score_super_tb_dst_pts,score_straight_sets,score_closeness,score_valid
0,5.0,1028.0,-69.0,3,32,0.000000,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,5.0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,A,0,0,0,0,0,0,0,19910107,1500.000000,1500.000000,101142,101746,1,6-4 6-1,2,2,0,12,5,7,2,0,0,0,0,0,0,1,0.588235,1
1,-10.3,257.0,-126.0,3,32,0.000000,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,-8.0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,A,0,0,0,0,0,0,0,19910107,1500.000000,1500.000000,101613,100587,1,6-3 6-2,2,2,0,12,5,7,2,0,0,0,0,0,0,1,0.588235,1
2,-3.7,-352.0,135.0,3,32,0.000000,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,0.0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,A,0,0,0,0,0,0,0,19910107,1500.000000,1500.000000,101601,101179,0,2-6 6-1 6-2,3,1,2,9,14,-5,-1,0,0,0,0,0,0,0,0.782609,1
3,2.0,19.0,-7.0,3,32,0.000000,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,7.0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,A,0,0,0,0,0,0,0,19910107,1500.000000,1500.000000,101117,101332,1,6-3 1-6 6-4,3,2,1,13,13,0,1,0,0,0,0,0,0,0,1.000000,1
4,-0.9,-734.0,162.0,3,32,0.000000,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,-11.0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,A,0,0,0,0,0,0,0,19910107,1500.000000,1500.000000,101901,101735,1,6-3 6-4,2,2,0,12,7,5,2,0,0,0,0,0,0,1,0.736842,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95370,3.2,-40.0,250.0,3,4,0.000000,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,8.0,0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,D,0,0,0,0,0,0,0,20240203,1500.000000,1500.000000,133933,207134,0,6-2 6-3,2,0,2,5,12,-7,-2,0,0,0,0,0,0,0,0.588235,1
95371,-2.2,-187.0,621.0,3,4,11.415089,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,0.000000,0,0,10.0,-1,-53.820988,-53.820988,-53.820988,-53.820988,-53.820988,-53.820988,-53.820988,-69.954545,-69.954545,-69.954545,-69.954545,-69.954545,-69.954545,-69.954545,-40.040541,-40.040541,-40.040541,-40.040541,-40.040541,-40.040541,-40.040541,-3.203704,-3.203704,-3.203704,-3.203704,-3.203704,-3.203704,-3.203704,-49.500000,-49.500000,-49.500000,-49.500000,-49.500000,-49.500000,-49.500000,-3.203704,-3.203704,-3.203704,-3.203704,-3.203704,-3.203704,-3.203704,D,0,0,0,0,0,0,0,20240202,1500.000000,1488.584911,132374,121411,0,6-2 6-3,2,0,2,5,12,-7,-2,0,0,0,0,0,0,0,0.588235,1
95372,2.1,21.0,-124.0,3,4,21.484266,0.0,0.0,0.0,0.0,0.000000e+00,0.0,0.0,22.190451,0,0,-3.0,2,72.147059,72.147059,72.147059,72.147059,72.147059,72.147059,72.147059,63.450893,63.450893,63.450893,63.450893,63.450893,63.450893,63.450893,47.714286,47.714286,47.714286,47.714286,47.714286,47.714286,47.714286,0.713235,0.713235,0.713235,0.713235,0.713235,0.713235,0.713235,70.333333,70.333333,70.333333,70.333333,70.333333,70.333333,70.333333,2.477941,2.477941,2.477941,2.477941,2.477941,2.477941,2.477941,D,0,0,0,0,0,0,0,20240203,1521.484266,1500.000000,208364,209943,1,4-6 6-1 6-1,3,2,1,16,8,8,1,0,0,0,0,0,0,0,0.666667,1
95373,-10.3,-434.0,480.0,3,4,21.029457,0.0,0.0,0.0,0.0,2.305865e-16,0.0,0.0,58.946102,0,0,10.0,-226,8.322480,6.413177,7.644898,7.970157,7.202387,7.201900,8.811394,-2.903634,-3.002937,-2.913438,-1.397046,-0.017734,-3.707832,-0.920725,2.842961,0.235478,0.350060,2.964308,-0.613553,1.651779,-0.863322,-2.280146,-2.584554,-2.241952,-1.703725,-0.876012,-2.251091,-0.513701,10.982150,14.552512,11.355782,9.835787,-5.698653,8.524708,6.877104,-1.190927,-1.091163,-0.910355,-0.749971,-1.297126,-0.634542,-1.540796,D,0,0,0,0,2,0,0,20240203,1532.743171,1511.713713,208364,105430,0,6-3 6-1,2,0,2,4,12,-8,-2,0,0,0,0,0,0,0,0.500000,1


In [11]:



final_dataset['date2'] = pd.to_datetime(final_dataset['date'], format='%Y%m%d')
final_dataset['t_days'] = (final_dataset['date2'] - final_dataset['date2'].min()).dt.days  # petit nombre
final_dataset['t_days'] 

0            0
1            0
2            0
3            0
4            0
         ...  
95370    12080
95371    12079
95372    12080
95373    12080
95374    12079
Name: t_days, Length: 95375, dtype: int64

In [12]:
final_dataset.columns

Index(['AGE_DIFF', 'ATP_POINTS_DIFF', 'ATP_RANK_DIFF', 'BEST_OF', 'DRAW_SIZE',
       'ELO_DIFF', 'ELO_GRAD_LAST_100_DIFF', 'ELO_GRAD_LAST_10_DIFF',
       'ELO_GRAD_LAST_200_DIFF', 'ELO_GRAD_LAST_25_DIFF',
       'ELO_GRAD_LAST_3_DIFF', 'ELO_GRAD_LAST_50_DIFF', 'ELO_GRAD_LAST_5_DIFF',
       'ELO_SURFACE_DIFF', 'H2H_DIFF', 'H2H_SURFACE_DIFF', 'HEIGHT_DIFF',
       'N_GAMES_DIFF', 'P_1ST_IN_LAST_100_DIFF', 'P_1ST_IN_LAST_10_DIFF',
       'P_1ST_IN_LAST_200_DIFF', 'P_1ST_IN_LAST_25_DIFF',
       'P_1ST_IN_LAST_3_DIFF', 'P_1ST_IN_LAST_50_DIFF', 'P_1ST_IN_LAST_5_DIFF',
       'P_1ST_WON_LAST_100_DIFF', 'P_1ST_WON_LAST_10_DIFF',
       'P_1ST_WON_LAST_200_DIFF', 'P_1ST_WON_LAST_25_DIFF',
       'P_1ST_WON_LAST_3_DIFF', 'P_1ST_WON_LAST_50_DIFF',
       'P_1ST_WON_LAST_5_DIFF', 'P_2ND_WON_LAST_100_DIFF',
       'P_2ND_WON_LAST_10_DIFF', 'P_2ND_WON_LAST_200_DIFF',
       'P_2ND_WON_LAST_25_DIFF', 'P_2ND_WON_LAST_3_DIFF',
       'P_2ND_WON_LAST_50_DIFF', 'P_2ND_WON_LAST_5_DIFF',
       'P_ACE_

In [13]:
features_cols = [
    'AGE_DIFF', 'ATP_POINTS_DIFF', 'ATP_RANK_DIFF', 'BEST_OF', 'DRAW_SIZE',
       'ELO_DIFF', 'ELO_GRAD_LAST_100_DIFF', 'ELO_GRAD_LAST_10_DIFF',
       'ELO_GRAD_LAST_200_DIFF', 'ELO_GRAD_LAST_25_DIFF',
       'ELO_GRAD_LAST_3_DIFF', 'ELO_GRAD_LAST_50_DIFF', 'ELO_GRAD_LAST_5_DIFF',
       'ELO_SURFACE_DIFF', 'H2H_DIFF', 'H2H_SURFACE_DIFF', 'HEIGHT_DIFF',
       'N_GAMES_DIFF', 'P_1ST_IN_LAST_100_DIFF', 'P_1ST_IN_LAST_10_DIFF',
       'P_1ST_IN_LAST_200_DIFF', 'P_1ST_IN_LAST_25_DIFF',
       'P_1ST_IN_LAST_3_DIFF', 'P_1ST_IN_LAST_50_DIFF', 'P_1ST_IN_LAST_5_DIFF',
       'P_1ST_WON_LAST_100_DIFF', 'P_1ST_WON_LAST_10_DIFF',
       'P_1ST_WON_LAST_200_DIFF', 'P_1ST_WON_LAST_25_DIFF',
       'P_1ST_WON_LAST_3_DIFF', 'P_1ST_WON_LAST_50_DIFF',
       'P_1ST_WON_LAST_5_DIFF', 'P_2ND_WON_LAST_100_DIFF',
       'P_2ND_WON_LAST_10_DIFF', 'P_2ND_WON_LAST_200_DIFF',
       'P_2ND_WON_LAST_25_DIFF', 'P_2ND_WON_LAST_3_DIFF',
       'P_2ND_WON_LAST_50_DIFF', 'P_2ND_WON_LAST_5_DIFF',
       'P_ACE_LAST_100_DIFF', 'P_ACE_LAST_10_DIFF', 'P_ACE_LAST_200_DIFF',
       'P_ACE_LAST_25_DIFF', 'P_ACE_LAST_3_DIFF', 'P_ACE_LAST_50_DIFF',
       'P_ACE_LAST_5_DIFF', 'P_BP_SAVED_LAST_100_DIFF',
       'P_BP_SAVED_LAST_10_DIFF', 'P_BP_SAVED_LAST_200_DIFF',
       'P_BP_SAVED_LAST_25_DIFF', 'P_BP_SAVED_LAST_3_DIFF',
       'P_BP_SAVED_LAST_50_DIFF', 'P_BP_SAVED_LAST_5_DIFF',
       'P_DF_LAST_100_DIFF', 'P_DF_LAST_10_DIFF', 'P_DF_LAST_200_DIFF',
       'P_DF_LAST_25_DIFF', 'P_DF_LAST_3_DIFF', 'P_DF_LAST_50_DIFF',
       'P_DF_LAST_5_DIFF', 'WIN_LAST_100_DIFF',
       'WIN_LAST_10_DIFF', 'WIN_LAST_200_DIFF', 'WIN_LAST_25_DIFF',
       'WIN_LAST_3_DIFF', 'WIN_LAST_50_DIFF', 'WIN_LAST_5_DIFF'
]
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
final_dataset[features_cols] = scaler.fit_transform(final_dataset[features_cols])


In [14]:
final_dataset.to_csv("./data/1finalDataset.csv", index=False)

In [15]:
date = 20240101
train_df = final_dataset[final_dataset["date"] < date].copy()
test_df  = final_dataset[final_dataset["date"] >= date].copy()
print(len(train_df))
print(len(test_df))

92429
2946


In [16]:
import torch
from torch_geometric.data import TemporalData
from torch_geometric.loader import TemporalDataLoader
def build_temporal_data(df_subset):
    return TemporalData(
        src = torch.tensor(df_subset["p1_id"].values, dtype=torch.long),
        dst = torch.tensor(df_subset["p2_id"].values, dtype=torch.long),
        t   = torch.tensor(df_subset["t_days"].values, dtype=torch.long),
        msg = torch.tensor(df_subset[features_cols].values, dtype=torch.float),
        y   = torch.tensor(df_subset["RESULT"].values, dtype=torch.float),
        closeness = torch.tensor(df_subset["score_closeness"].values,dtype=torch.float)
    )

In [17]:
full_data = build_temporal_data(final_dataset)
train_data = build_temporal_data(train_df)
test_data  = build_temporal_data(test_df)

for data in (full_data,train_data, test_data):
    data.src = data.src.long()    # src en ints longs
    data.dst = data.dst.long()    # dst en ints longs
    data.t   = data.t.long()     # timestamps en floats
    data.msg = data.msg.float()   # features en floats
    data.y   = data.y.float()     # labels en floats
    data.closeness = data.closeness.float()
# 4) Déplacer t et msg sur GPU
device = "cuda"
for data in (full_data,train_data, test_data):
    data.t   = data.t.to(device)
    data.msg = data.msg.to(device)
# 6. DataLoaders pour entraînement
train_loader = TemporalDataLoader(train_data, batch_size=32, neg_sampling_ratio=0)
test_loader  = TemporalDataLoader(test_data, batch_size=32, neg_sampling_ratio=0)


In [18]:

for batch in train_loader:
    print("=== Nouveau batch ===")
    print("src:", batch.src)
    print("dst:", batch.dst)
    print("t:", batch.t)
    print("msg:", batch.msg)
    print("y:", batch.y)
    print("Closeness:", batch.closeness)
    break

=== Nouveau batch ===
src: tensor([101142, 101613, 101601, 101117, 101901, 101377, 101409, 101407, 101481,
        101441, 101421, 101120, 101233, 100752, 101312, 101230, 101613, 101117,
        101377, 101409, 101767, 101120, 101233, 101119, 101142, 101234, 101179,
        101120, 101179, 100656, 101511, 101703])
dst: tensor([101746, 100587, 101179, 101332, 101735, 101439, 100772, 100954, 101532,
        101767, 101205, 101123, 101274, 101234, 101119, 102000, 101142, 101179,
        101901, 100954, 101481, 101205, 101234, 101230, 101179, 101119, 100954,
        101119, 101120, 100923, 101196, 101073])
t: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
msg: tensor([[ 0.9601,  0.5384, -0.5188,  ..., -0.0022,  0.0043, -0.0023],
        [-1.9789,  0.1350, -0.9501,  ..., -0.0022,  0.0043, -0.0023],
        [-0.7111, -0.1837,  1.0249,  ..., -0.0022,  0.0043, -0.0023],
        ...,
        [ 0.6143,  1.0605, -0

In [22]:
import pandas as pd

# Récupère toutes les valeurs uniques
all_ids = pd.unique(final_dataset[["p1_id", "p2_id"]].values.ravel())

# Crée un mapping {id_original -> id_compact}
id_map = {old_id: new_id for new_id, old_id in enumerate(all_ids)}

# Applique le mapping
final_dataset["p1_id"] = final_dataset["p1_id"].map(id_map)
final_dataset["p2_id"] = final_dataset["p2_id"].map(id_map)

# Nombre de noeuds réels
num_nodes = len(all_ids)
print("Nombre de joueurs:", num_nodes)


Nombre de joueurs: 1932


In [None]:
import matplotlib.pyplot as plt
from itertools import product

from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (
    LastAggregator,
    LastNeighborLoader,
    IdentityMessage
)
from tgn.model import MultiLayerTimeAwareGNN,MessageMLP,WinPredictorMLP,WinPredictor,SmallWinPredictor
from tgn.utils import train,evaluate,compute_alpha,train_debug,train_debug2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paramètres
memory_dim = 128
time_dim   = 32
embedding_dim = 128
in_channels = 128
hidden_channels = 32

num_layers = 2
heads = 4
dropout= 0.4
learning_rates = [1e-3, 4e-4]
weight_decays = [1e-4, 5e-4]
hidden_variants = [[512, 64], [256, 64]]

# Générer toutes les combinaisons
grid = list(product(learning_rates, weight_decays, hidden_variants))
run_id = 0
for lr, wd, hidden in grid:
    run_id += 1
    print(f"\n=== RUN {run_id} | lr={lr} | wd={wd} | hidden={hidden} ===")

    msg_dim = full_data.msg.size(-1)

    memory = TGNMemory(
        num_nodes=num_nodes,
        raw_msg_dim=msg_dim,
        memory_dim=memory_dim,
        time_dim=time_dim,
        message_module=MessageMLP(msg_dim, memory_dim, time_dim,2*memory_dim),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = MultiLayerTimeAwareGNN(in_channels,memory_dim,hidden_channels, 
                                 embedding_dim, msg_dim, memory.time_enc,
                                 num_layers,heads,dropout).to(device)
    
    win_pred = SmallWinPredictor(
        embed_dim=embedding_dim,
        match_dim=msg_dim,
        hidden = hidden 
    ).to(device)

    total_params = 0
    for model in [memory, gnn, win_pred]:
        model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"{model.__class__.__name__} params: {model_params:,}")
        total_params += model_params

    print(f"Total parameters: {total_params:,}")



    optimizer = torch.optim.AdamW(
        list(memory.parameters()) + list(gnn.parameters()) + list(win_pred.parameters()),
        lr=lr,weight_decay=wd
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # === Loaders ===


    train_loader_ngh = LastNeighborLoader(num_nodes=num_nodes, size=25, device=device)
    eval_loader_ngh  = LastNeighborLoader(num_nodes=num_nodes, size=25, device=device)

    assoc = torch.empty(num_nodes, dtype=torch.long, device=device)



    threshold = [0.6,0.65,0.7,0.75,0.8]
    num_epochs = 200

    import random

    train_variants = [
        (train_loader, full_data, train_data),

    ]
    best_val_ap = 0
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.9)
    hparams = {
    "learning_rate": lr,
    "weight_decay": wd,
    "memory_dim": memory_dim,
    "time_dim": time_dim,
    "embedding_dim": embedding_dim,
    "in_channels": in_channels,
    "hidden_channels": hidden_channels,
    "num_layers": num_layers,
    "heads": heads,
    "dropout": dropout,
    "hidden": hidden
    }


    history = History(outdir="runs/curves_run7",hparams=hparams)

    train_losses, train_aps, train_prec = [], [], []
    val_losses,   val_aps,   val_prec  = [], [], []

    for epoch in range(1, num_epochs + 1):
        alpha = compute_alpha(epoch, num_epochs)

        loader, full, train_data_split = random.choice(train_variants)

        loss, ap, prec = train_debug2(
            loader, memory, gnn, win_pred, full, train_loader_ngh, eval_loader_ngh,
            optimizer, device, assoc, train_data_split, alpha
        )

        train_losses.append(loss)
        train_aps.append(ap)
        train_prec.append(prec)

        val_ap, val_loss, prec_v, prec_at, well_dates, bad_dates = evaluate(
            test_loader, memory, gnn, win_pred, full_data, eval_loader_ngh,
            assoc, device, threshold, alpha
        )

        val_losses.append(val_loss)
        val_aps.append(val_ap)
        val_prec.append(prec_v)

        # --- LOG + SAUVEGARDE INCRÉMENTALE  ---
        history.log_epoch(
            epoch=epoch,
            train_loss=loss, train_ap=ap, train_prec=prec,
            val_loss=val_loss, val_ap=val_ap, val_prec=prec_v,
            prec_at=prec_at
        )
        
        history.save_tables()
        history.save_plots()
        torch.save({
        "memory_state": memory.state_dict(),
        "gnn_state": gnn.state_dict(),
        "win_pred_state": win_pred.state_dict(),
        "optimizer_state": optimizer.state_dict()
    }, f"models/run7/epoch{epoch}.pth")

    
    history.save_all()
    print("Courbes et CSV sauvegardés dans", history.outdir.resolve())


LR = 0.0004
212083
TGNMemory params: 464,192
MultiLayerTimeAwareGNN params: 174,528
SmallWinPredictor params: 190,849
Total parameters: 829,569


  return disable_fn(*args, **kwargs)
Training: 100%|██████████| 2889/2889 [00:56<00:00, 51.24batch/s]
Evaluating: 100%|██████████| 93/93 [00:00<00:00, 126.63batch/s]
Training: 100%|██████████| 2889/2889 [00:49<00:00, 58.52batch/s]
Evaluating: 100%|██████████| 93/93 [00:00<00:00, 135.34batch/s]
Training: 100%|██████████| 2889/2889 [00:59<00:00, 48.42batch/s]
Evaluating: 100%|██████████| 93/93 [00:00<00:00, 104.17batch/s]
Training: 100%|██████████| 2889/2889 [00:50<00:00, 57.36batch/s]
Evaluating: 100%|██████████| 93/93 [00:00<00:00, 136.43batch/s]
Training: 100%|██████████| 2889/2889 [00:53<00:00, 53.94batch/s]
Evaluating: 100%|██████████| 93/93 [00:00<00:00, 101.98batch/s]
Training: 100%|██████████| 2889/2889 [00:51<00:00, 55.68batch/s]
Evaluating: 100%|██████████| 93/93 [00:00<00:00, 143.93batch/s]
Training: 100%|██████████| 2889/2889 [00:51<00:00, 55.87batch/s]
Evaluating: 100%|██████████| 93/93 [00:00<00:00, 100.16batch/s]
Training: 100%|██████████| 2889/2889 [00:59<00:00, 48.18batc

Courbes et CSV sauvegardés dans /home/romain/tensorflow_project/tennis/random-forest-tennis/random-forest-tennis/runs/curves_run7


Most of the logic is in utils, because I wanted a common API with which to call getStats and updateStats. That way there is no possible data leakage. 

In addition, by doing this, when we predict a player. We can call getStats to get the stats from both players, making it easier to predict future matches.

Let's export this to 1finalDataset.csv:

In [20]:
final_dataset.to_csv("./data/1finalDataset.csv", index=False)

I also tried to do this with getStats from scratch (but it was kinda dumb, since I was recalculating ELO scores from scratch every time). So, that's why it's commented out.