## Import des librairies

In [1]:
import pandas as pd 
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier
from pickle import dump 
import lightgbm as lgb

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

## Chargement des données

In [2]:
data = pd.read_csv("../data/processed/data_preprocessed.csv")
data_with_encoded_cat_columns = pd.read_csv("../data/processed/data_with_encoded_categories_preprocessed.csv")

In [3]:
data.head()

Unnamed: 0,p2_age,p2_hand,p2_ht,p2_id,p2_ioc,p2_rank_points,match_num,round,surface,tourney_level,tourney_name,p1_age,p1_hand,p1_ht,p1_id,p1_ioc,p1_rank_points,p1_won,tourney_month,p1_is_seed_player,p2_is_seed_player,p1_new_rank,p2_new_rank
0,0.90462,L,0.204632,104269,ESP,-0.237747,-1.558052,R64,Hard,M,Cincinnati Masters,0.576739,R,0.2156288,104542,FRA,0.098019,0,0.912765,-0.715611,-0.70978,Top 30,Top 30-100
1,0.728708,R,-0.546553,104571,CYP,-0.203132,0.749218,R32,Grass,A,Stuttgart,-0.375267,R,-4.275467e-15,105526,GER,-0.425967,1,0.241991,-0.715611,-0.70978,Top 30-100,Top 30-100
2,-0.367358,L,-0.546553,105385,USA,-0.378871,-0.586125,R16,Hard,G,US Open,0.599352,R,-0.5365192,104527,SUI,2.23873,1,0.912765,1.397407,-0.70978,Top 30,Top 30-100
3,0.219015,L,-0.246079,104745,ESP,2.291827,-1.363667,QF,Clay,A,Rio de Janeiro,0.305383,R,-0.987808,104655,URU,-0.006251,0,-1.099557,1.397407,1.408888,Top 30,Top 30
4,-0.062895,R,0.505105,104997,NED,-0.471533,-1.473537,R32,Hard,A,Zagreb,0.581262,R,-0.23566,104433,CAN,-0.596592,0,-1.099557,-0.715611,-0.70978,Under 100,Top 30-100


In [4]:
data_with_encoded_cat_columns.head()

Unnamed: 0,p2_age,p2_hand,p2_ht,p2_id,p2_ioc,p2_rank_points,match_num,round,surface,tourney_level,tourney_name,p1_age,p1_hand,p1_ht,p1_id,p1_ioc,p1_rank_points,p1_won,tourney_month,p1_is_seed_player,p2_is_seed_player,p1_new_rank,p2_new_rank
0,0.90462,0,0.204632,50,25,-0.237747,-1.558052,6,3,4,21,0.576739,1,0.2156288,79,27,0.098019,0,0.912765,-0.715611,-0.70978,0,1
1,0.728708,1,-0.546553,80,18,-0.203132,0.749218,5,2,0,359,-0.375267,1,-4.275467e-15,241,30,-0.425967,1,0.241991,-0.715611,-0.70978,1,1
2,-0.367358,0,-0.546553,213,83,-0.378871,-0.586125,4,3,3,363,0.599352,1,-0.5365192,76,72,2.23873,1,0.912765,1.397407,-0.70978,0,1
3,0.219015,0,-0.246079,107,25,2.291827,-1.363667,2,1,0,347,0.305383,1,-0.987808,99,80,-0.006251,0,-1.099557,1.397407,1.408888,0,0
4,-0.062895,1,0.505105,143,56,-0.471533,-1.473537,5,3,0,372,0.581262,1,-0.23566,67,11,-0.596592,0,-1.099557,-0.715611,-0.70978,2,1


In [5]:
data.dtypes

p2_age               float64
p2_hand               object
p2_ht                float64
p2_id                  int64
p2_ioc                object
p2_rank_points       float64
match_num            float64
round                 object
surface               object
tourney_level         object
tourney_name          object
p1_age               float64
p1_hand               object
p1_ht                float64
p1_id                  int64
p1_ioc                object
p1_rank_points       float64
p1_won                 int64
tourney_month        float64
p1_is_seed_player    float64
p2_is_seed_player    float64
p1_new_rank           object
p2_new_rank           object
dtype: object

In [6]:
def preprocess(data):
    data[["p1_id", "p2_id"]] = data[["p1_id", "p2_id"]].astype('category')
    object_columns = data.select_dtypes(['object']).columns
    data[object_columns] = data[object_columns].astype('category')
    return data

In [7]:
data = preprocess(data)
data_with_encoded_cat_columns = preprocess(data_with_encoded_cat_columns)

In [8]:
X = data.drop(columns=['p1_won'])
y = data['p1_won']

In [9]:
categorical_columns = X.select_dtypes(['category']).columns

In [10]:
categorical_columns

Index(['p2_hand', 'p2_id', 'p2_ioc', 'round', 'surface', 'tourney_level',
       'tourney_name', 'p1_hand', 'p1_id', 'p1_ioc', 'p1_new_rank',
       'p2_new_rank'],
      dtype='object')

In [11]:
categorical_indexes = []
for element in categorical_columns:
    categorical_indexes.append(X.columns.to_list().index(element))

## Entraînement des modèles

In [12]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

In [14]:
catboost_model = CatBoostClassifier(iterations=1500,
                           learning_rate=0.01,
                           eval_metric="Precision",
                           depth=2)

In [15]:
catboost_model.fit(X_train, y_train, categorical_indexes)

0:	learn: 0.7062197	total: 61.4ms	remaining: 1m 32s
1:	learn: 0.7002935	total: 66.5ms	remaining: 49.8s
2:	learn: 0.6945848	total: 70.4ms	remaining: 35.1s
3:	learn: 0.5802222	total: 73.4ms	remaining: 27.4s
4:	learn: 0.6904003	total: 76.2ms	remaining: 22.8s
5:	learn: 0.6936808	total: 79.2ms	remaining: 19.7s
6:	learn: 0.6939793	total: 82.3ms	remaining: 17.6s
7:	learn: 0.5819621	total: 85.1ms	remaining: 15.9s
8:	learn: 0.5819621	total: 87.4ms	remaining: 14.5s
9:	learn: 0.5822981	total: 90ms	remaining: 13.4s
10:	learn: 0.5823374	total: 92.6ms	remaining: 12.5s
11:	learn: 0.6768881	total: 95.5ms	remaining: 11.8s
12:	learn: 0.6767045	total: 98.1ms	remaining: 11.2s
13:	learn: 0.6768881	total: 101ms	remaining: 10.7s
14:	learn: 0.6768881	total: 103ms	remaining: 10.2s
15:	learn: 0.6769886	total: 106ms	remaining: 9.83s
16:	learn: 0.6796117	total: 109ms	remaining: 9.51s
17:	learn: 0.6800689	total: 112ms	remaining: 9.23s
18:	learn: 0.6793044	total: 115ms	remaining: 8.97s
19:	learn: 0.6796006	total: 1

186:	learn: 0.7018152	total: 642ms	remaining: 4.51s
187:	learn: 0.7020192	total: 645ms	remaining: 4.5s
188:	learn: 0.7017723	total: 648ms	remaining: 4.5s
189:	learn: 0.7015685	total: 651ms	remaining: 4.49s
190:	learn: 0.7018736	total: 654ms	remaining: 4.48s
191:	learn: 0.7021190	total: 657ms	remaining: 4.47s
192:	learn: 0.7017544	total: 660ms	remaining: 4.47s
193:	learn: 0.7014895	total: 663ms	remaining: 4.46s
194:	learn: 0.7027193	total: 666ms	remaining: 4.46s
195:	learn: 0.7035897	total: 669ms	remaining: 4.45s
196:	learn: 0.7026197	total: 671ms	remaining: 4.44s
197:	learn: 0.7026585	total: 675ms	remaining: 4.44s
198:	learn: 0.7036885	total: 678ms	remaining: 4.43s
199:	learn: 0.7033777	total: 681ms	remaining: 4.43s
200:	learn: 0.7036885	total: 684ms	remaining: 4.42s
201:	learn: 0.7035444	total: 687ms	remaining: 4.42s
202:	learn: 0.7042427	total: 690ms	remaining: 4.41s
203:	learn: 0.7033102	total: 693ms	remaining: 4.4s
204:	learn: 0.7030452	total: 696ms	remaining: 4.4s
205:	learn: 0.70

357:	learn: 0.7101743	total: 1.21s	remaining: 3.84s
358:	learn: 0.7109850	total: 1.21s	remaining: 3.84s
359:	learn: 0.7104677	total: 1.21s	remaining: 3.83s
360:	learn: 0.7109581	total: 1.21s	remaining: 3.83s
361:	learn: 0.7102388	total: 1.22s	remaining: 3.82s
362:	learn: 0.7102388	total: 1.22s	remaining: 3.82s
363:	learn: 0.7101537	total: 1.22s	remaining: 3.82s
364:	learn: 0.7099777	total: 1.23s	remaining: 3.82s
365:	learn: 0.7102652	total: 1.23s	remaining: 3.81s
366:	learn: 0.7101801	total: 1.24s	remaining: 3.81s
367:	learn: 0.7102388	total: 1.24s	remaining: 3.81s
368:	learn: 0.7109897	total: 1.24s	remaining: 3.81s
369:	learn: 0.7109043	total: 1.25s	remaining: 3.81s
370:	learn: 0.7106435	total: 1.25s	remaining: 3.8s
371:	learn: 0.7106753	total: 1.25s	remaining: 3.79s
372:	learn: 0.7105583	total: 1.25s	remaining: 3.79s
373:	learn: 0.7107605	total: 1.26s	remaining: 3.79s
374:	learn: 0.7104731	total: 1.26s	remaining: 3.78s
375:	learn: 0.7102123	total: 1.26s	remaining: 3.78s
376:	learn: 0

552:	learn: 0.7182813	total: 1.78s	remaining: 3.05s
553:	learn: 0.7180474	total: 1.79s	remaining: 3.05s
554:	learn: 0.7186613	total: 1.79s	remaining: 3.05s
555:	learn: 0.7185155	total: 1.79s	remaining: 3.04s
556:	learn: 0.7186929	total: 1.79s	remaining: 3.04s
557:	learn: 0.7194478	total: 1.8s	remaining: 3.04s
558:	learn: 0.7190418	total: 1.8s	remaining: 3.03s
559:	learn: 0.7193660	total: 1.8s	remaining: 3.03s
560:	learn: 0.7198618	total: 1.81s	remaining: 3.02s
561:	learn: 0.7202683	total: 1.81s	remaining: 3.02s
562:	learn: 0.7201788	total: 1.81s	remaining: 3.02s
563:	learn: 0.7201220	total: 1.81s	remaining: 3.01s
564:	learn: 0.7204716	total: 1.82s	remaining: 3.01s
565:	learn: 0.7204148	total: 1.82s	remaining: 3s
566:	learn: 0.7202357	total: 1.82s	remaining: 3s
567:	learn: 0.7203820	total: 1.83s	remaining: 3s
568:	learn: 0.7206749	total: 1.83s	remaining: 2.99s
569:	learn: 0.7200651	total: 1.83s	remaining: 2.99s
570:	learn: 0.7210248	total: 1.83s	remaining: 2.98s
571:	learn: 0.7209350	to

743:	learn: 0.7270867	total: 2.37s	remaining: 2.4s
744:	learn: 0.7268263	total: 2.37s	remaining: 2.4s
745:	learn: 0.7269380	total: 2.37s	remaining: 2.4s
746:	learn: 0.7269939	total: 2.38s	remaining: 2.4s
747:	learn: 0.7269562	total: 2.38s	remaining: 2.39s
748:	learn: 0.7269750	total: 2.38s	remaining: 2.39s
749:	learn: 0.7266025	total: 2.39s	remaining: 2.39s
750:	learn: 0.7264905	total: 2.39s	remaining: 2.38s
751:	learn: 0.7268632	total: 2.39s	remaining: 2.38s
752:	learn: 0.7266585	total: 2.4s	remaining: 2.38s
753:	learn: 0.7267144	total: 2.4s	remaining: 2.37s
754:	learn: 0.7267144	total: 2.4s	remaining: 2.37s
755:	learn: 0.7263050	total: 2.41s	remaining: 2.37s
756:	learn: 0.7266025	total: 2.41s	remaining: 2.36s
757:	learn: 0.7263610	total: 2.41s	remaining: 2.36s
758:	learn: 0.7262124	total: 2.42s	remaining: 2.36s
759:	learn: 0.7261003	total: 2.42s	remaining: 2.35s
760:	learn: 0.7262684	total: 2.42s	remaining: 2.35s
761:	learn: 0.7258592	total: 2.42s	remaining: 2.35s
762:	learn: 0.72598

927:	learn: 0.7289490	total: 2.95s	remaining: 1.82s
928:	learn: 0.7289490	total: 2.95s	remaining: 1.81s
929:	learn: 0.7290217	total: 2.96s	remaining: 1.81s
930:	learn: 0.7288725	total: 2.96s	remaining: 1.81s
931:	learn: 0.7289834	total: 2.96s	remaining: 1.8s
932:	learn: 0.7290771	total: 2.96s	remaining: 1.8s
933:	learn: 0.7291326	total: 2.97s	remaining: 1.8s
934:	learn: 0.7289834	total: 2.97s	remaining: 1.79s
935:	learn: 0.7290389	total: 2.97s	remaining: 1.79s
936:	learn: 0.7291880	total: 2.98s	remaining: 1.79s
937:	learn: 0.7290771	total: 2.98s	remaining: 1.79s
938:	learn: 0.7291880	total: 2.98s	remaining: 1.78s
939:	learn: 0.7293925	total: 2.99s	remaining: 1.78s
940:	learn: 0.7293372	total: 2.99s	remaining: 1.78s
941:	learn: 0.7293372	total: 2.99s	remaining: 1.77s
942:	learn: 0.7293372	total: 3s	remaining: 1.77s
943:	learn: 0.7291880	total: 3s	remaining: 1.77s
944:	learn: 0.7292818	total: 3s	remaining: 1.76s
945:	learn: 0.7294310	total: 3.01s	remaining: 1.76s
946:	learn: 0.7291326	to

1106:	learn: 0.7303601	total: 3.53s	remaining: 1.25s
1107:	learn: 0.7303601	total: 3.54s	remaining: 1.25s
1108:	learn: 0.7302658	total: 3.54s	remaining: 1.25s
1109:	learn: 0.7303210	total: 3.54s	remaining: 1.24s
1110:	learn: 0.7302658	total: 3.55s	remaining: 1.24s
1111:	learn: 0.7301165	total: 3.55s	remaining: 1.24s
1112:	learn: 0.7301165	total: 3.55s	remaining: 1.24s
1113:	learn: 0.7300225	total: 3.56s	remaining: 1.23s
1114:	learn: 0.7297242	total: 3.56s	remaining: 1.23s
1115:	learn: 0.7297794	total: 3.56s	remaining: 1.23s
1116:	learn: 0.7297794	total: 3.57s	remaining: 1.22s
1117:	learn: 0.7300776	total: 3.57s	remaining: 1.22s
1118:	learn: 0.7296304	total: 3.57s	remaining: 1.22s
1119:	learn: 0.7297794	total: 3.57s	remaining: 1.21s
1120:	learn: 0.7297242	total: 3.58s	remaining: 1.21s
1121:	learn: 0.7297242	total: 3.58s	remaining: 1.21s
1122:	learn: 0.7296137	total: 3.58s	remaining: 1.2s
1123:	learn: 0.7296690	total: 3.59s	remaining: 1.2s
1124:	learn: 0.7296137	total: 3.59s	remaining: 1

1283:	learn: 0.7314344	total: 4.12s	remaining: 692ms
1284:	learn: 0.7316924	total: 4.12s	remaining: 689ms
1285:	learn: 0.7316924	total: 4.12s	remaining: 686ms
1286:	learn: 0.7318413	total: 4.13s	remaining: 683ms
1287:	learn: 0.7317867	total: 4.13s	remaining: 680ms
1288:	learn: 0.7318413	total: 4.13s	remaining: 676ms
1289:	learn: 0.7316924	total: 4.13s	remaining: 673ms
1290:	learn: 0.7318265	total: 4.14s	remaining: 670ms
1291:	learn: 0.7318265	total: 4.14s	remaining: 667ms
1292:	learn: 0.7318265	total: 4.15s	remaining: 664ms
1293:	learn: 0.7316378	total: 4.15s	remaining: 661ms
1294:	learn: 0.7317321	total: 4.15s	remaining: 658ms
1295:	learn: 0.7316775	total: 4.16s	remaining: 654ms
1296:	learn: 0.7316775	total: 4.16s	remaining: 651ms
1297:	learn: 0.7318265	total: 4.16s	remaining: 648ms
1298:	learn: 0.7318265	total: 4.17s	remaining: 645ms
1299:	learn: 0.7316775	total: 4.17s	remaining: 641ms
1300:	learn: 0.7316775	total: 4.17s	remaining: 638ms
1301:	learn: 0.7318811	total: 4.18s	remaining:

1452:	learn: 0.7326954	total: 4.7s	remaining: 152ms
1453:	learn: 0.7326410	total: 4.7s	remaining: 149ms
1454:	learn: 0.7326954	total: 4.71s	remaining: 146ms
1455:	learn: 0.7326954	total: 4.71s	remaining: 142ms
1456:	learn: 0.7326954	total: 4.71s	remaining: 139ms
1457:	learn: 0.7323972	total: 4.72s	remaining: 136ms
1458:	learn: 0.7323428	total: 4.72s	remaining: 133ms
1459:	learn: 0.7322883	total: 4.72s	remaining: 129ms
1460:	learn: 0.7324374	total: 4.73s	remaining: 126ms
1461:	learn: 0.7321937	total: 4.73s	remaining: 123ms
1462:	learn: 0.7324374	total: 4.74s	remaining: 120ms
1463:	learn: 0.7321392	total: 4.74s	remaining: 117ms
1464:	learn: 0.7321392	total: 4.74s	remaining: 113ms
1465:	learn: 0.7325463	total: 4.75s	remaining: 110ms
1466:	learn: 0.7323972	total: 4.75s	remaining: 107ms
1467:	learn: 0.7323428	total: 4.75s	remaining: 104ms
1468:	learn: 0.7326410	total: 4.76s	remaining: 100ms
1469:	learn: 0.7322482	total: 4.76s	remaining: 97.2ms
1470:	learn: 0.7322482	total: 4.76s	remaining: 

<catboost.core.CatBoostClassifier at 0x7f9aa853f9d0>

In [26]:
lgb_model = lgb.LGBMClassifier(categorical_feature=categorical_indexes,
                              learning_rate=0.1,
                              eval_metric="Precision",
                              num_iterations=2000)

In [27]:
lgb_model.fit(X_train, y_train)

Please use categorical_feature argument of the Dataset constructor to pass this parameter.




LGBMClassifier(categorical_feature=[1, 3, 4, 7, 8, 9, 10, 12, 14, 15, 20, 21],
               eval_metric='Precision', num_iterations=2000)

In [28]:
y_pred_catboost = catboost_model.predict(X_test)

In [29]:
y_pred_lgb = lgb_model.predict(X_test)

## Evaluation, Comparaison et Validation des modèles

In [30]:
print(classification_report(y_test, y_pred_catboost))

              precision    recall  f1-score   support

           0       0.71      0.73      0.72      1209
           1       0.74      0.71      0.72      1253

    accuracy                           0.72      2462
   macro avg       0.72      0.72      0.72      2462
weighted avg       0.72      0.72      0.72      2462



In [31]:
print(classification_report(y_test, y_pred_lgb))

              precision    recall  f1-score   support

           0       0.66      0.68      0.67      1209
           1       0.68      0.66      0.67      1253

    accuracy                           0.67      2462
   macro avg       0.67      0.67      0.67      2462
weighted avg       0.67      0.67      0.67      2462



In [32]:
catboost_model.save_model('../models/catboost_model.pkl')

In [33]:
lgb_model.booster_.save_model('../models/lgbm_model.pkl')

<lightgbm.basic.Booster at 0x7f9a4832c0a0>

In [34]:
pd.DataFrame({'columns': X.columns, 'feature_importance': catboost_model.feature_importances_}).sort_values(by="feature_importance", ascending=False)

Unnamed: 0,columns,feature_importance
10,tourney_name,20.158569
14,p1_id,18.670931
3,p2_id,16.917043
5,p2_rank_points,12.795306
16,p1_rank_points,11.86314
20,p1_new_rank,5.451546
21,p2_new_rank,4.641654
19,p2_is_seed_player,3.708885
18,p1_is_seed_player,2.443946
4,p2_ioc,0.933132


In [35]:
pd.DataFrame({'columns': X.columns, 'feature_importance': lgb_model.feature_importances_}).sort_values(by="feature_importance", ascending=False)

Unnamed: 0,columns,feature_importance
16,p1_rank_points,8186
5,p2_rank_points,7661
6,match_num,7449
11,p1_age,7215
0,p2_age,7183
14,p1_id,4653
3,p2_id,4613
10,tourney_name,3963
17,tourney_month,2116
13,p1_ht,1555
