In [14]:
import numpy as np
from sklearn.datasets import load_svmlight_file

def load_data(file):
    """Load an svmlight-style ranking file and return it sorted by query id.

    The feature matrix and labels are parsed by ``load_svmlight_file``; the
    query id is taken from the last whitespace-separated token of each data
    line, parsed as an integer.

    Parameters
    ----------
    file : str
        Path to the data file. It is read twice: once by
        ``load_svmlight_file`` and once here to extract the query ids.

    Returns
    -------
    (x, y, qid)
        Dense feature array, label array and integer query-id array, all
        reordered so that rows are sorted by ascending query id.

    Raises
    ------
    ValueError
        If the number of query ids found does not match the number of rows
        returned by ``load_svmlight_file``.
    """
    x, y = load_svmlight_file(file)

    qid = []
    with open(file) as f:
        for line in f:
            # load_svmlight_file ignores everything from '#' onward and skips
            # lines that are empty after that; skip the same lines here so
            # the ids stay aligned with the rows of x / y.
            if not line.split("#", 1)[0].strip():
                continue
            qid.append(int(line.split()[-1]))
    # Explicit dtype keeps an integer array even when the file has no rows.
    qid = np.array(qid, dtype=int)

    if qid.shape[0] != y.shape[0]:
        raise ValueError(
            "found %d query ids but %d data rows in %r"
            % (qid.shape[0], y.shape[0], file)
        )

    # mergesort is a stable sort: rows sharing a query id keep their file order.
    order = np.argsort(qid, kind="mergesort")
    return x.toarray()[order], y[order], qid[order]

In [12]:
def split(x, y, qid, train_ratio, seed=None):
    """Randomly partition rows into a train part and a test part.

    ``int(n * train_ratio)`` row indices are drawn without replacement for
    the train part; all remaining rows form the test part. Both parts keep
    the original relative row order (indices are ascending).

    NOTE: rows are sampled individually, independent of ``qid``, so rows that
    share a query id can end up in both parts.

    Parameters
    ----------
    x, y, qid : array-like, indexable by an integer index array
        Features, labels and query ids; ``qid.shape[0]`` defines the row count.
    train_ratio : float
        Fraction of rows assigned to the train part.
    seed : int or None, optional
        When given, a dedicated ``np.random.RandomState(seed)`` is used so the
        split is reproducible. When ``None`` (default) the global NumPy RNG
        is used, exactly as before.

    Returns
    -------
    tuple
        ``(train_x, train_y, train_qid, test_x, test_y, test_qid)``.
    """
    n = qid.shape[0]
    n_train = int(n * train_ratio)

    # Global RNG by default, so an earlier np.random.seed() call still applies.
    rng = np.random if seed is None else np.random.RandomState(seed)
    train_inds = np.sort(rng.choice(n, n_train, replace=False))

    # Complement via a boolean mask; flatnonzero yields ascending indices.
    in_train = np.zeros(n, dtype=bool)
    in_train[train_inds] = True
    test_inds = np.flatnonzero(~in_train)

    return (
        x[train_inds], y[train_inds], qid[train_inds],
        x[test_inds], y[test_inds], qid[test_inds],
    )

In [16]:
# First cut: 0.8 of the rows become the train part. The rows left over are
# cut again with ratio 0.8 into the final test and validation parts.
x, y, qid = load_data("imat2009_learning.txt")
first_cut = split(x, y, qid, 0.8)
train_x, train_y, train_qid = first_cut[:3]
test_x, test_y, test_qid, val_x, val_y, val_qid = split(*first_cut[3:], 0.8)

[0.00000e+00 1.00000e+00 1.01620e-02 0.00000e+00 0.00000e+00 0.00000e+00
 6.69220e-01 0.00000e+00 0.00000e+00 9.87800e-03 6.75667e-01 1.00000e+00
 7.46552e-01 3.92200e-03 7.45100e-02 0.00000e+00 1.00000e+00 9.21569e-01
 0.00000e+00 0.00000e+00 0.00000e+00 1.00000e+00 1.00000e+00 1.00000e+00
 0.00000e+00 0.00000e+00 2.00000e-01 8.10427e-01 0.00000e+00 0.00000e+00
 6.88351e-01 0.00000e+00 7.26000e-04 0.00000e+00 0.00000e+00 0.00000e+00
 0.00000e+00 6.66670e-02 0.00000e+00 1.32760e-02 1.00000e+00 1.00000e+00
 5.08100e-03 0.00000e+00 3.92200e-03 0.00000e+00 0.00000e+00 5.05000e-04
 1.00000e+00 0.00000e+00 1.58840e-01 0.00000e+00 1.62300e-03 7.84310e-02
 0.00000e+00 5.08191e-01 2.03000e-04 0.00000e+00 2.39216e-01 0.00000e+00
 0.00000e+00 0.00000e+00 0.00000e+00 0.00000e+00 1.00000e+00 0.00000e+00
 2.04080e-02 2.60420e-02 7.08000e-04 1.00000e+00 3.63000e-04 4.98039e-01
 9.66216e-01 3.52941e-01 0.00000e+00 5.00000e-01 6.69220e-01 0.00000e+00
 1.00000e+00 1.00000e+00 0.00000e+00 0.00000e+00 0.

In [17]:
from catboost import Pool
# Wrap each partition in a CatBoost Pool, passing its query ids as group_id.
train, test, val = (
    Pool(data=features, label=labels, group_id=groups)
    for features, labels, groups in (
        (train_x, train_y, train_qid),
        (test_x, test_y, test_qid),
        (val_x, val_y, val_qid),
    )
)

In [20]:
from catboost import CatBoost

# No loss_function is set here, so CatBoost falls back to its default
# objective; custom_metric adds NDCG (top=20) to the metrics it reports.
parameters = {
    'custom_metric': ['NDCG:top=20'],
}

model = CatBoost(params=parameters)
model.fit(train, eval_set=test)

0:	learn: 0.9309695	test: 0.9332768	best: 0.9332768 (0)	total: 23.3ms	remaining: 23.2s
1:	learn: 0.9255467	test: 0.9278003	best: 0.9278003 (1)	total: 45.8ms	remaining: 22.8s
2:	learn: 0.9200991	test: 0.9223048	best: 0.9223048 (2)	total: 71.2ms	remaining: 23.7s
3:	learn: 0.9150020	test: 0.9172600	best: 0.9172600 (3)	total: 99ms	remaining: 24.7s
4:	learn: 0.9101257	test: 0.9123483	best: 0.9123483 (4)	total: 127ms	remaining: 25.2s
5:	learn: 0.9054398	test: 0.9077070	best: 0.9077070 (5)	total: 152ms	remaining: 25.2s
6:	learn: 0.9009480	test: 0.9032350	best: 0.9032350 (6)	total: 180ms	remaining: 25.6s
7:	learn: 0.8968122	test: 0.8990285	best: 0.8990285 (7)	total: 203ms	remaining: 25.1s
8:	learn: 0.8926154	test: 0.8949543	best: 0.8949543 (8)	total: 225ms	remaining: 24.8s
9:	learn: 0.8885970	test: 0.8909099	best: 0.8909099 (9)	total: 256ms	remaining: 25.3s
10:	learn: 0.8846605	test: 0.8870549	best: 0.8870549 (10)	total: 280ms	remaining: 25.2s
11:	learn: 0.8810301	test: 0.8834032	best: 0.88340

95:	learn: 0.7756878	test: 0.7815389	best: 0.7815389 (95)	total: 2.36s	remaining: 22.2s
96:	learn: 0.7753743	test: 0.7812474	best: 0.7812474 (96)	total: 2.39s	remaining: 22.2s
97:	learn: 0.7749043	test: 0.7808387	best: 0.7808387 (97)	total: 2.41s	remaining: 22.2s
98:	learn: 0.7745595	test: 0.7805271	best: 0.7805271 (98)	total: 2.43s	remaining: 22.1s
99:	learn: 0.7742507	test: 0.7802449	best: 0.7802449 (99)	total: 2.46s	remaining: 22.1s
100:	learn: 0.7739384	test: 0.7799438	best: 0.7799438 (100)	total: 2.48s	remaining: 22.1s
101:	learn: 0.7736357	test: 0.7796711	best: 0.7796711 (101)	total: 2.5s	remaining: 22s
102:	learn: 0.7730068	test: 0.7791373	best: 0.7791373 (102)	total: 2.52s	remaining: 22s
103:	learn: 0.7727485	test: 0.7789240	best: 0.7789240 (103)	total: 2.55s	remaining: 22s
104:	learn: 0.7724529	test: 0.7786993	best: 0.7786993 (104)	total: 2.58s	remaining: 22s
105:	learn: 0.7721094	test: 0.7783874	best: 0.7783874 (105)	total: 2.6s	remaining: 21.9s
106:	learn: 0.7718272	test: 0.

194:	learn: 0.7515486	test: 0.7604898	best: 0.7604898 (194)	total: 4.75s	remaining: 19.6s
195:	learn: 0.7514444	test: 0.7603783	best: 0.7603783 (195)	total: 4.77s	remaining: 19.6s
196:	learn: 0.7513109	test: 0.7602370	best: 0.7602370 (196)	total: 4.8s	remaining: 19.6s
197:	learn: 0.7511450	test: 0.7600938	best: 0.7600938 (197)	total: 4.82s	remaining: 19.5s
198:	learn: 0.7510042	test: 0.7599808	best: 0.7599808 (198)	total: 4.85s	remaining: 19.5s
199:	learn: 0.7508368	test: 0.7598317	best: 0.7598317 (199)	total: 4.87s	remaining: 19.5s
200:	learn: 0.7506434	test: 0.7596913	best: 0.7596913 (200)	total: 4.89s	remaining: 19.5s
201:	learn: 0.7505196	test: 0.7596076	best: 0.7596076 (201)	total: 4.92s	remaining: 19.4s
202:	learn: 0.7503896	test: 0.7595026	best: 0.7595026 (202)	total: 4.94s	remaining: 19.4s
203:	learn: 0.7501918	test: 0.7593258	best: 0.7593258 (203)	total: 4.96s	remaining: 19.4s
204:	learn: 0.7500292	test: 0.7592097	best: 0.7592097 (204)	total: 4.98s	remaining: 19.3s
205:	learn:

289:	learn: 0.7382704	test: 0.7497818	best: 0.7497818 (289)	total: 7.11s	remaining: 17.4s
290:	learn: 0.7381636	test: 0.7497083	best: 0.7497083 (290)	total: 7.13s	remaining: 17.4s
291:	learn: 0.7379980	test: 0.7495677	best: 0.7495677 (291)	total: 7.16s	remaining: 17.4s
292:	learn: 0.7378979	test: 0.7494905	best: 0.7494905 (292)	total: 7.19s	remaining: 17.3s
293:	learn: 0.7378306	test: 0.7494283	best: 0.7494283 (293)	total: 7.21s	remaining: 17.3s
294:	learn: 0.7376998	test: 0.7492996	best: 0.7492996 (294)	total: 7.24s	remaining: 17.3s
295:	learn: 0.7376188	test: 0.7492272	best: 0.7492272 (295)	total: 7.27s	remaining: 17.3s
296:	learn: 0.7375101	test: 0.7491713	best: 0.7491713 (296)	total: 7.29s	remaining: 17.3s
297:	learn: 0.7373342	test: 0.7489958	best: 0.7489958 (297)	total: 7.32s	remaining: 17.2s
298:	learn: 0.7372638	test: 0.7489776	best: 0.7489776 (298)	total: 7.34s	remaining: 17.2s
299:	learn: 0.7371097	test: 0.7488629	best: 0.7488629 (299)	total: 7.37s	remaining: 17.2s
300:	learn

385:	learn: 0.7294770	test: 0.7432014	best: 0.7432014 (385)	total: 9.46s	remaining: 15.1s
386:	learn: 0.7293362	test: 0.7430859	best: 0.7430859 (386)	total: 9.49s	remaining: 15s
387:	learn: 0.7292478	test: 0.7430100	best: 0.7430100 (387)	total: 9.51s	remaining: 15s
388:	learn: 0.7291358	test: 0.7429313	best: 0.7429313 (388)	total: 9.53s	remaining: 15s
389:	learn: 0.7289968	test: 0.7428322	best: 0.7428322 (389)	total: 9.56s	remaining: 14.9s
390:	learn: 0.7289408	test: 0.7427969	best: 0.7427969 (390)	total: 9.58s	remaining: 14.9s
391:	learn: 0.7288546	test: 0.7427396	best: 0.7427396 (391)	total: 9.6s	remaining: 14.9s
392:	learn: 0.7287467	test: 0.7426677	best: 0.7426677 (392)	total: 9.62s	remaining: 14.9s
393:	learn: 0.7286256	test: 0.7425737	best: 0.7425737 (393)	total: 9.65s	remaining: 14.8s
394:	learn: 0.7285371	test: 0.7425121	best: 0.7425121 (394)	total: 9.67s	remaining: 14.8s
395:	learn: 0.7284096	test: 0.7423840	best: 0.7423840 (395)	total: 9.7s	remaining: 14.8s
396:	learn: 0.7283

482:	learn: 0.7209812	test: 0.7370396	best: 0.7370396 (482)	total: 11.8s	remaining: 12.6s
483:	learn: 0.7209381	test: 0.7370297	best: 0.7370297 (483)	total: 11.8s	remaining: 12.6s
484:	learn: 0.7208508	test: 0.7369696	best: 0.7369696 (484)	total: 11.8s	remaining: 12.6s
485:	learn: 0.7207895	test: 0.7369211	best: 0.7369211 (485)	total: 11.9s	remaining: 12.5s
486:	learn: 0.7207049	test: 0.7368730	best: 0.7368730 (486)	total: 11.9s	remaining: 12.5s
487:	learn: 0.7206281	test: 0.7368344	best: 0.7368344 (487)	total: 11.9s	remaining: 12.5s
488:	learn: 0.7205875	test: 0.7368145	best: 0.7368145 (488)	total: 11.9s	remaining: 12.5s
489:	learn: 0.7205501	test: 0.7367904	best: 0.7367904 (489)	total: 11.9s	remaining: 12.4s
490:	learn: 0.7204930	test: 0.7367649	best: 0.7367649 (490)	total: 12s	remaining: 12.4s
491:	learn: 0.7204398	test: 0.7367390	best: 0.7367390 (491)	total: 12s	remaining: 12.4s
492:	learn: 0.7203313	test: 0.7366437	best: 0.7366437 (492)	total: 12s	remaining: 12.4s
493:	learn: 0.72

581:	learn: 0.7142736	test: 0.7327777	best: 0.7327777 (581)	total: 14.2s	remaining: 10.2s
582:	learn: 0.7141894	test: 0.7327151	best: 0.7327151 (582)	total: 14.2s	remaining: 10.1s
583:	learn: 0.7141371	test: 0.7326940	best: 0.7326940 (583)	total: 14.2s	remaining: 10.1s
584:	learn: 0.7140684	test: 0.7326570	best: 0.7326570 (584)	total: 14.2s	remaining: 10.1s
585:	learn: 0.7140154	test: 0.7326500	best: 0.7326500 (585)	total: 14.3s	remaining: 10.1s
586:	learn: 0.7139681	test: 0.7326380	best: 0.7326380 (586)	total: 14.3s	remaining: 10.1s
587:	learn: 0.7138945	test: 0.7325969	best: 0.7325969 (587)	total: 14.3s	remaining: 10s
588:	learn: 0.7138371	test: 0.7325747	best: 0.7325747 (588)	total: 14.3s	remaining: 10s
589:	learn: 0.7137629	test: 0.7325202	best: 0.7325202 (589)	total: 14.4s	remaining: 9.98s
590:	learn: 0.7137250	test: 0.7324997	best: 0.7324997 (590)	total: 14.4s	remaining: 9.96s
591:	learn: 0.7136867	test: 0.7324803	best: 0.7324803 (591)	total: 14.4s	remaining: 9.93s
592:	learn: 0.

679:	learn: 0.7086441	test: 0.7295682	best: 0.7295682 (679)	total: 16.5s	remaining: 7.77s
680:	learn: 0.7085949	test: 0.7295606	best: 0.7295606 (680)	total: 16.5s	remaining: 7.75s
681:	learn: 0.7085629	test: 0.7295366	best: 0.7295366 (681)	total: 16.6s	remaining: 7.72s
682:	learn: 0.7085095	test: 0.7295076	best: 0.7295076 (682)	total: 16.6s	remaining: 7.7s
683:	learn: 0.7084619	test: 0.7294724	best: 0.7294724 (683)	total: 16.6s	remaining: 7.67s
684:	learn: 0.7083856	test: 0.7294350	best: 0.7294350 (684)	total: 16.6s	remaining: 7.65s
685:	learn: 0.7083177	test: 0.7293938	best: 0.7293938 (685)	total: 16.7s	remaining: 7.62s
686:	learn: 0.7082465	test: 0.7293428	best: 0.7293428 (686)	total: 16.7s	remaining: 7.6s
687:	learn: 0.7082022	test: 0.7293182	best: 0.7293182 (687)	total: 16.7s	remaining: 7.58s
688:	learn: 0.7081284	test: 0.7292668	best: 0.7292668 (688)	total: 16.7s	remaining: 7.55s
689:	learn: 0.7080837	test: 0.7292354	best: 0.7292354 (689)	total: 16.8s	remaining: 7.53s
690:	learn: 

774:	learn: 0.7038431	test: 0.7270238	best: 0.7270238 (774)	total: 19.1s	remaining: 5.54s
775:	learn: 0.7038058	test: 0.7270111	best: 0.7270111 (775)	total: 19.1s	remaining: 5.51s
776:	learn: 0.7037424	test: 0.7269713	best: 0.7269713 (776)	total: 19.1s	remaining: 5.49s
777:	learn: 0.7036762	test: 0.7269353	best: 0.7269353 (777)	total: 19.2s	remaining: 5.47s
778:	learn: 0.7036328	test: 0.7269197	best: 0.7269197 (778)	total: 19.2s	remaining: 5.45s
779:	learn: 0.7035879	test: 0.7269114	best: 0.7269114 (779)	total: 19.2s	remaining: 5.42s
780:	learn: 0.7035508	test: 0.7268961	best: 0.7268961 (780)	total: 19.3s	remaining: 5.4s
781:	learn: 0.7035033	test: 0.7268843	best: 0.7268843 (781)	total: 19.3s	remaining: 5.38s
782:	learn: 0.7034741	test: 0.7268618	best: 0.7268618 (782)	total: 19.3s	remaining: 5.35s
783:	learn: 0.7034411	test: 0.7268373	best: 0.7268373 (783)	total: 19.4s	remaining: 5.33s
784:	learn: 0.7034130	test: 0.7268237	best: 0.7268237 (784)	total: 19.4s	remaining: 5.31s
785:	learn:

868:	learn: 0.6996681	test: 0.7249974	best: 0.7249974 (868)	total: 21.9s	remaining: 3.3s
869:	learn: 0.6996224	test: 0.7249609	best: 0.7249609 (869)	total: 21.9s	remaining: 3.27s
870:	learn: 0.6995790	test: 0.7249463	best: 0.7249463 (870)	total: 21.9s	remaining: 3.25s
871:	learn: 0.6995469	test: 0.7249427	best: 0.7249427 (871)	total: 22s	remaining: 3.22s
872:	learn: 0.6995286	test: 0.7249353	best: 0.7249353 (872)	total: 22s	remaining: 3.2s
873:	learn: 0.6994760	test: 0.7249285	best: 0.7249285 (873)	total: 22s	remaining: 3.17s
874:	learn: 0.6994354	test: 0.7249141	best: 0.7249141 (874)	total: 22s	remaining: 3.15s
875:	learn: 0.6994120	test: 0.7249141	best: 0.7249141 (875)	total: 22.1s	remaining: 3.12s
876:	learn: 0.6993572	test: 0.7248871	best: 0.7248871 (876)	total: 22.1s	remaining: 3.1s
877:	learn: 0.6993071	test: 0.7248719	best: 0.7248719 (877)	total: 22.1s	remaining: 3.07s
878:	learn: 0.6992666	test: 0.7248457	best: 0.7248457 (878)	total: 22.2s	remaining: 3.05s
879:	learn: 0.6992433

964:	learn: 0.6955647	test: 0.7232145	best: 0.7232145 (964)	total: 24.7s	remaining: 896ms
965:	learn: 0.6955385	test: 0.7232023	best: 0.7232023 (965)	total: 24.7s	remaining: 871ms
966:	learn: 0.6954999	test: 0.7232057	best: 0.7232023 (965)	total: 24.8s	remaining: 845ms
967:	learn: 0.6954596	test: 0.7231790	best: 0.7231790 (967)	total: 24.8s	remaining: 820ms
968:	learn: 0.6954115	test: 0.7231506	best: 0.7231506 (968)	total: 24.8s	remaining: 794ms
969:	learn: 0.6953835	test: 0.7231351	best: 0.7231351 (969)	total: 24.8s	remaining: 768ms
970:	learn: 0.6953491	test: 0.7231263	best: 0.7231263 (970)	total: 24.9s	remaining: 743ms
971:	learn: 0.6953042	test: 0.7231021	best: 0.7231021 (971)	total: 24.9s	remaining: 717ms
972:	learn: 0.6952785	test: 0.7230951	best: 0.7230951 (972)	total: 24.9s	remaining: 692ms
973:	learn: 0.6952502	test: 0.7230826	best: 0.7230826 (973)	total: 25s	remaining: 666ms
974:	learn: 0.6952162	test: 0.7230559	best: 0.7230559 (974)	total: 25s	remaining: 641ms
975:	learn: 0.

<catboost.core.CatBoost at 0x7f4b09e51128>

In [None]:
# Same setup as the earlier run, but trained with the PairLogitPairwise loss.
parameters = {
    'custom_metric': ['NDCG:top=20'],
    'loss_function': 'PairLogitPairwise',
}

model = CatBoost(params=parameters)
model.fit(train, eval_set=test)

In [None]:
# Same setup as the earlier runs, but trained with the YetiRankPairwise loss.
parameters = {
    'custom_metric': ['NDCG:top=20'],
    'loss_function': 'YetiRankPairwise',
}

model = CatBoost(params=parameters)
model.fit(train, eval_set=test)