## Import des librairies

In [1]:
import pandas as pd 
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier
from pickle import dump 
import lightgbm as lgb

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

## Chargement des données

In [2]:
data = pd.read_csv("../data/processed/data_preprocessed.csv")

In [3]:
data.head()

Unnamed: 0,p2_age,p2_hand,p2_ht,p2_id,p2_ioc,p2_rank_points,match_num,round,surface,tourney_level,tourney_name,p1_age,p1_hand,p1_ht,p1_id,p1_ioc,p1_rank_points,p1_won,tourney_month,p1_is_seed_player,p2_is_seed_player,p1_new_rank,p2_new_rank
0,0.90462,L,0.204632,104269,ESP,-0.237747,-1.558052,R64,Hard,M,Cincinnati Masters,0.576739,R,0.2156288,104542,FRA,0.098019,0,0.912765,-0.715611,-0.70978,Top 30,Top 30-100
1,0.728708,R,-0.546553,104571,CYP,-0.203132,0.749218,R32,Grass,A,Stuttgart,-0.375267,R,-4.275467e-15,105526,GER,-0.425967,1,0.241991,-0.715611,-0.70978,Top 30-100,Top 30-100
2,-0.367358,L,-0.546553,105385,USA,-0.378871,-0.586125,R16,Hard,G,US Open,0.599352,R,-0.5365192,104527,SUI,2.23873,1,0.912765,1.397407,-0.70978,Top 30,Top 30-100
3,0.219015,L,-0.246079,104745,ESP,2.291827,-1.363667,QF,Clay,A,Rio de Janeiro,0.305383,R,-0.987808,104655,URU,-0.006251,0,-1.099557,1.397407,1.408888,Top 30,Top 30
4,-0.062895,R,0.505105,104997,NED,-0.471533,-1.473537,R32,Hard,A,Zagreb,0.581262,R,-0.23566,104433,CAN,-0.596592,0,-1.099557,-0.715611,-0.70978,Under 100,Top 30-100


In [4]:
data.dtypes

p2_age               float64
p2_hand               object
p2_ht                float64
p2_id                  int64
p2_ioc                object
p2_rank_points       float64
match_num            float64
round                 object
surface               object
tourney_level         object
tourney_name          object
p1_age               float64
p1_hand               object
p1_ht                float64
p1_id                  int64
p1_ioc                object
p1_rank_points       float64
p1_won                 int64
tourney_month        float64
p1_is_seed_player    float64
p2_is_seed_player    float64
p1_new_rank           object
p2_new_rank           object
dtype: object

In [5]:
data[["p1_id", "p2_id"]] = data[["p1_id", "p2_id"]].astype('category')
object_columns = data.select_dtypes(['object']).columns
data[object_columns] = data[object_columns].astype('category')

In [6]:
X = data.drop(columns=['p1_won'])
y = data['p1_won']

In [7]:
categorical_columns = X.select_dtypes(['category']).columns

In [8]:
categorical_columns

Index(['p2_hand', 'p2_id', 'p2_ioc', 'round', 'surface', 'tourney_level',
       'tourney_name', 'p1_hand', 'p1_id', 'p1_ioc', 'p1_new_rank',
       'p2_new_rank'],
      dtype='object')

In [9]:
categorical_indexes = []
for element in categorical_columns:
    categorical_indexes.append(X.columns.to_list().index(element))

## Entraînement des modèles

In [10]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

In [11]:
catboost_model = CatBoostClassifier(iterations=1500,
                           learning_rate=0.01,
                           depth=2)

In [12]:
catboost_model.fit(X_train, y_train, categorical_indexes)

0:	learn: 0.6923388	total: 70.3ms	remaining: 1m 45s
1:	learn: 0.6915426	total: 76.1ms	remaining: 57s
2:	learn: 0.6907667	total: 80.2ms	remaining: 40s
3:	learn: 0.6899982	total: 83ms	remaining: 31s
4:	learn: 0.6892528	total: 85.9ms	remaining: 25.7s
5:	learn: 0.6885068	total: 88.8ms	remaining: 22.1s
6:	learn: 0.6876964	total: 91.6ms	remaining: 19.5s
7:	learn: 0.6869961	total: 93.9ms	remaining: 17.5s
8:	learn: 0.6862968	total: 96.6ms	remaining: 16s
9:	learn: 0.6855810	total: 99.8ms	remaining: 14.9s
10:	learn: 0.6848348	total: 103ms	remaining: 13.9s
11:	learn: 0.6841841	total: 106ms	remaining: 13.2s
12:	learn: 0.6835470	total: 109ms	remaining: 12.5s
13:	learn: 0.6828666	total: 112ms	remaining: 11.9s
14:	learn: 0.6821823	total: 115ms	remaining: 11.3s
15:	learn: 0.6815260	total: 117ms	remaining: 10.9s
16:	learn: 0.6809373	total: 121ms	remaining: 10.5s
17:	learn: 0.6803191	total: 124ms	remaining: 10.2s
18:	learn: 0.6797270	total: 126ms	remaining: 9.83s
19:	learn: 0.6790756	total: 129ms	remain

197:	learn: 0.6117926	total: 653ms	remaining: 4.29s
198:	learn: 0.6115237	total: 656ms	remaining: 4.29s
199:	learn: 0.6113203	total: 659ms	remaining: 4.29s
200:	learn: 0.6110499	total: 662ms	remaining: 4.28s
201:	learn: 0.6108747	total: 666ms	remaining: 4.28s
202:	learn: 0.6105108	total: 669ms	remaining: 4.27s
203:	learn: 0.6102402	total: 672ms	remaining: 4.27s
204:	learn: 0.6100640	total: 675ms	remaining: 4.26s
205:	learn: 0.6097073	total: 678ms	remaining: 4.26s
206:	learn: 0.6093565	total: 681ms	remaining: 4.25s
207:	learn: 0.6091758	total: 683ms	remaining: 4.25s
208:	learn: 0.6089978	total: 686ms	remaining: 4.24s
209:	learn: 0.6088171	total: 690ms	remaining: 4.24s
210:	learn: 0.6084805	total: 693ms	remaining: 4.23s
211:	learn: 0.6083053	total: 696ms	remaining: 4.23s
212:	learn: 0.6080339	total: 699ms	remaining: 4.22s
213:	learn: 0.6077045	total: 702ms	remaining: 4.22s
214:	learn: 0.6074473	total: 705ms	remaining: 4.21s
215:	learn: 0.6071253	total: 708ms	remaining: 4.21s
216:	learn: 

387:	learn: 0.5805401	total: 1.24s	remaining: 3.54s
388:	learn: 0.5803991	total: 1.24s	remaining: 3.54s
389:	learn: 0.5802509	total: 1.24s	remaining: 3.53s
390:	learn: 0.5801196	total: 1.24s	remaining: 3.53s
391:	learn: 0.5800627	total: 1.25s	remaining: 3.53s
392:	learn: 0.5799698	total: 1.25s	remaining: 3.52s
393:	learn: 0.5799217	total: 1.25s	remaining: 3.52s
394:	learn: 0.5798027	total: 1.26s	remaining: 3.52s
395:	learn: 0.5796404	total: 1.26s	remaining: 3.52s
396:	learn: 0.5795792	total: 1.26s	remaining: 3.51s
397:	learn: 0.5795284	total: 1.27s	remaining: 3.51s
398:	learn: 0.5794533	total: 1.27s	remaining: 3.51s
399:	learn: 0.5793082	total: 1.27s	remaining: 3.5s
400:	learn: 0.5792226	total: 1.28s	remaining: 3.5s
401:	learn: 0.5790677	total: 1.28s	remaining: 3.5s
402:	learn: 0.5789469	total: 1.28s	remaining: 3.5s
403:	learn: 0.5787845	total: 1.29s	remaining: 3.5s
404:	learn: 0.5786062	total: 1.29s	remaining: 3.49s
405:	learn: 0.5785536	total: 1.29s	remaining: 3.49s
406:	learn: 0.578

582:	learn: 0.5632289	total: 1.82s	remaining: 2.86s
583:	learn: 0.5631427	total: 1.82s	remaining: 2.86s
584:	learn: 0.5630515	total: 1.82s	remaining: 2.85s
585:	learn: 0.5629737	total: 1.83s	remaining: 2.85s
586:	learn: 0.5628815	total: 1.83s	remaining: 2.85s
587:	learn: 0.5628347	total: 1.83s	remaining: 2.85s
588:	learn: 0.5627768	total: 1.84s	remaining: 2.84s
589:	learn: 0.5626986	total: 1.84s	remaining: 2.84s
590:	learn: 0.5626089	total: 1.84s	remaining: 2.84s
591:	learn: 0.5625890	total: 1.85s	remaining: 2.83s
592:	learn: 0.5625197	total: 1.85s	remaining: 2.83s
593:	learn: 0.5624535	total: 1.85s	remaining: 2.83s
594:	learn: 0.5624177	total: 1.86s	remaining: 2.83s
595:	learn: 0.5623391	total: 1.86s	remaining: 2.82s
596:	learn: 0.5622968	total: 1.86s	remaining: 2.82s
597:	learn: 0.5622205	total: 1.87s	remaining: 2.82s
598:	learn: 0.5621562	total: 1.87s	remaining: 2.81s
599:	learn: 0.5620682	total: 1.87s	remaining: 2.81s
600:	learn: 0.5619721	total: 1.88s	remaining: 2.81s
601:	learn: 

765:	learn: 0.5547005	total: 2.4s	remaining: 2.3s
766:	learn: 0.5546908	total: 2.4s	remaining: 2.3s
767:	learn: 0.5546577	total: 2.41s	remaining: 2.3s
768:	learn: 0.5546269	total: 2.41s	remaining: 2.29s
769:	learn: 0.5545932	total: 2.42s	remaining: 2.29s
770:	learn: 0.5545415	total: 2.42s	remaining: 2.29s
771:	learn: 0.5545118	total: 2.42s	remaining: 2.28s
772:	learn: 0.5545062	total: 2.42s	remaining: 2.28s
773:	learn: 0.5544528	total: 2.43s	remaining: 2.28s
774:	learn: 0.5543897	total: 2.43s	remaining: 2.27s
775:	learn: 0.5543438	total: 2.43s	remaining: 2.27s
776:	learn: 0.5543212	total: 2.44s	remaining: 2.27s
777:	learn: 0.5542837	total: 2.44s	remaining: 2.26s
778:	learn: 0.5542384	total: 2.44s	remaining: 2.26s
779:	learn: 0.5541816	total: 2.44s	remaining: 2.26s
780:	learn: 0.5541564	total: 2.45s	remaining: 2.25s
781:	learn: 0.5541107	total: 2.45s	remaining: 2.25s
782:	learn: 0.5541031	total: 2.46s	remaining: 2.25s
783:	learn: 0.5540503	total: 2.46s	remaining: 2.25s
784:	learn: 0.554

940:	learn: 0.5491551	total: 2.99s	remaining: 1.77s
941:	learn: 0.5491256	total: 2.99s	remaining: 1.77s
942:	learn: 0.5490831	total: 2.99s	remaining: 1.77s
943:	learn: 0.5490547	total: 3s	remaining: 1.76s
944:	learn: 0.5490477	total: 3s	remaining: 1.76s
945:	learn: 0.5490192	total: 3s	remaining: 1.76s
946:	learn: 0.5489894	total: 3.01s	remaining: 1.76s
947:	learn: 0.5489800	total: 3.01s	remaining: 1.75s
948:	learn: 0.5489427	total: 3.02s	remaining: 1.75s
949:	learn: 0.5489189	total: 3.02s	remaining: 1.75s
950:	learn: 0.5488921	total: 3.02s	remaining: 1.74s
951:	learn: 0.5488512	total: 3.02s	remaining: 1.74s
952:	learn: 0.5488287	total: 3.03s	remaining: 1.74s
953:	learn: 0.5487951	total: 3.03s	remaining: 1.74s
954:	learn: 0.5487461	total: 3.04s	remaining: 1.73s
955:	learn: 0.5486983	total: 3.04s	remaining: 1.73s
956:	learn: 0.5486638	total: 3.04s	remaining: 1.73s
957:	learn: 0.5486209	total: 3.04s	remaining: 1.72s
958:	learn: 0.5486091	total: 3.05s	remaining: 1.72s
959:	learn: 0.5486004

1111:	learn: 0.5446830	total: 3.57s	remaining: 1.24s
1112:	learn: 0.5446673	total: 3.57s	remaining: 1.24s
1113:	learn: 0.5446552	total: 3.57s	remaining: 1.24s
1114:	learn: 0.5446286	total: 3.58s	remaining: 1.24s
1115:	learn: 0.5446068	total: 3.58s	remaining: 1.23s
1116:	learn: 0.5445926	total: 3.58s	remaining: 1.23s
1117:	learn: 0.5445647	total: 3.59s	remaining: 1.23s
1118:	learn: 0.5445451	total: 3.59s	remaining: 1.22s
1119:	learn: 0.5445309	total: 3.59s	remaining: 1.22s
1120:	learn: 0.5445167	total: 3.6s	remaining: 1.22s
1121:	learn: 0.5444898	total: 3.6s	remaining: 1.21s
1122:	learn: 0.5444840	total: 3.6s	remaining: 1.21s
1123:	learn: 0.5444504	total: 3.61s	remaining: 1.21s
1124:	learn: 0.5444138	total: 3.61s	remaining: 1.2s
1125:	learn: 0.5443836	total: 3.61s	remaining: 1.2s
1126:	learn: 0.5443606	total: 3.61s	remaining: 1.2s
1127:	learn: 0.5443454	total: 3.62s	remaining: 1.19s
1128:	learn: 0.5443185	total: 3.62s	remaining: 1.19s
1129:	learn: 0.5442941	total: 3.62s	remaining: 1.19s

1268:	learn: 0.5418010	total: 4.15s	remaining: 756ms
1269:	learn: 0.5417880	total: 4.16s	remaining: 753ms
1270:	learn: 0.5417726	total: 4.16s	remaining: 750ms
1271:	learn: 0.5417667	total: 4.17s	remaining: 747ms
1272:	learn: 0.5417527	total: 4.17s	remaining: 744ms
1273:	learn: 0.5417196	total: 4.17s	remaining: 741ms
1274:	learn: 0.5416971	total: 4.18s	remaining: 738ms
1275:	learn: 0.5416846	total: 4.18s	remaining: 734ms
1276:	learn: 0.5416788	total: 4.19s	remaining: 731ms
1277:	learn: 0.5416607	total: 4.19s	remaining: 728ms
1278:	learn: 0.5416368	total: 4.19s	remaining: 725ms
1279:	learn: 0.5416285	total: 4.2s	remaining: 721ms
1280:	learn: 0.5416206	total: 4.2s	remaining: 718ms
1281:	learn: 0.5416080	total: 4.2s	remaining: 715ms
1282:	learn: 0.5415934	total: 4.21s	remaining: 712ms
1283:	learn: 0.5415800	total: 4.21s	remaining: 709ms
1284:	learn: 0.5415675	total: 4.21s	remaining: 705ms
1285:	learn: 0.5415463	total: 4.22s	remaining: 702ms
1286:	learn: 0.5415360	total: 4.22s	remaining: 69

1428:	learn: 0.5395187	total: 4.74s	remaining: 235ms
1429:	learn: 0.5395132	total: 4.74s	remaining: 232ms
1430:	learn: 0.5394985	total: 4.75s	remaining: 229ms
1431:	learn: 0.5394905	total: 4.75s	remaining: 226ms
1432:	learn: 0.5394853	total: 4.75s	remaining: 222ms
1433:	learn: 0.5394795	total: 4.76s	remaining: 219ms
1434:	learn: 0.5394570	total: 4.76s	remaining: 216ms
1435:	learn: 0.5394451	total: 4.76s	remaining: 212ms
1436:	learn: 0.5394228	total: 4.77s	remaining: 209ms
1437:	learn: 0.5394223	total: 4.77s	remaining: 206ms
1438:	learn: 0.5393979	total: 4.78s	remaining: 202ms
1439:	learn: 0.5393895	total: 4.78s	remaining: 199ms
1440:	learn: 0.5393846	total: 4.78s	remaining: 196ms
1441:	learn: 0.5393709	total: 4.79s	remaining: 193ms
1442:	learn: 0.5393616	total: 4.79s	remaining: 189ms
1443:	learn: 0.5393608	total: 4.79s	remaining: 186ms
1444:	learn: 0.5393557	total: 4.8s	remaining: 183ms
1445:	learn: 0.5393445	total: 4.8s	remaining: 179ms
1446:	learn: 0.5393233	total: 4.81s	remaining: 1

<catboost.core.CatBoostClassifier at 0x7fb3a8e6acd0>

In [13]:
lgb_model = lgb.LGBMClassifier(categorical_feature='auto',
                              learning_rate=0.1,
                              num_iterations=2000)

In [14]:
lgb_model.fit(X_train, y_train)

Please use categorical_feature argument of the Dataset constructor to pass this parameter.


LGBMClassifier(categorical_feature='auto', num_iterations=2000)

In [15]:
y_pred_catboost = catboost_model.predict(X_test)

In [16]:
y_pred_lgb = lgb_model.predict(X_test)

## Evaluation, Comparaison et Validation des modèles

In [17]:
print(classification_report(y_test, y_pred_catboost))

              precision    recall  f1-score   support

           0       0.71      0.73      0.72      1209
           1       0.74      0.71      0.72      1253

    accuracy                           0.72      2462
   macro avg       0.72      0.72      0.72      2462
weighted avg       0.72      0.72      0.72      2462



In [18]:
print(classification_report(y_test, y_pred_lgb))

              precision    recall  f1-score   support

           0       0.66      0.68      0.67      1209
           1       0.68      0.66      0.67      1253

    accuracy                           0.67      2462
   macro avg       0.67      0.67      0.67      2462
weighted avg       0.67      0.67      0.67      2462



In [19]:
catboost_model.save_model('../models/catboost_model.pkl')

In [20]:
lgb_model.booster_.save_model('../models/lgbm_model.pkl')

<lightgbm.basic.Booster at 0x7fb3a8e86c10>

In [21]:
pd.DataFrame({'columns': X.columns, 'feature_importance': catboost_model.feature_importances_}).sort_values(by="feature_importance", ascending=False)

Unnamed: 0,columns,feature_importance
10,tourney_name,20.158569
14,p1_id,18.670931
3,p2_id,16.917043
5,p2_rank_points,12.795306
16,p1_rank_points,11.86314
20,p1_new_rank,5.451546
21,p2_new_rank,4.641654
19,p2_is_seed_player,3.708885
18,p1_is_seed_player,2.443946
4,p2_ioc,0.933132


In [22]:
pd.DataFrame({'columns': X.columns, 'feature_importance': lgb_model.feature_importances_}).sort_values(by="feature_importance", ascending=False)

Unnamed: 0,columns,feature_importance
16,p1_rank_points,8186
5,p2_rank_points,7661
6,match_num,7449
11,p1_age,7215
0,p2_age,7183
14,p1_id,4653
3,p2_id,4613
10,tourney_name,3963
17,tourney_month,2116
13,p1_ht,1555
