In [None]:
# !pip install dask-ml[xgboost]
!pip install xgboost
!pip install scikit-learn

# Train XGBoost on 256 PCA components

In [1]:
from os.path import join

import xgboost as xgb
import numpy as np
import dask.dataframe as dd

In [2]:
DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_simple_norm_parquet'

In [3]:
x_train = np.load(join(DATA_PATH, 'pca/x_pca_training_train_split_256.npy'))
y_train = dd.read_parquet(join(DATA_PATH, 'train'), columns='cell_type').compute().to_numpy()

x_val = np.load(join(DATA_PATH, 'pca/x_pca_training_val_split_256.npy'))
y_val = dd.read_parquet(join(DATA_PATH, 'val'), columns='cell_type').compute().to_numpy()

class_weights = np.load(join(DATA_PATH, 'class_weights.npy'))

In [4]:
class_weights = {i: weight for i, weight in enumerate(np.load(join(DATA_PATH, 'class_weights.npy')))}
weights = np.array([class_weights[label] for label in y_train])

In [5]:
clf = xgb.XGBClassifier(
    tree_method='gpu_hist',
    gpu_id=0,
    n_estimators=1000,
    eta=0.1,
    subsample=0.75,
    max_depth=20,
    n_jobs=20,
    early_stopping_rounds=10
)
clf = clf.fit(
    x_train, y_train, sample_weight=weights, 
    eval_set=[(x_val, y_val)]
)

[0]	validation_0-mlogloss:2.10118
[1]	validation_0-mlogloss:1.93694
[2]	validation_0-mlogloss:1.81939
[3]	validation_0-mlogloss:1.72482
[4]	validation_0-mlogloss:1.64597
[5]	validation_0-mlogloss:1.57715
[6]	validation_0-mlogloss:1.51643
[7]	validation_0-mlogloss:1.46269
[8]	validation_0-mlogloss:1.41416
[9]	validation_0-mlogloss:1.37082
[10]	validation_0-mlogloss:1.33098
[11]	validation_0-mlogloss:1.29495
[12]	validation_0-mlogloss:1.26117
[13]	validation_0-mlogloss:1.23021
[14]	validation_0-mlogloss:1.20158
[15]	validation_0-mlogloss:1.17507
[16]	validation_0-mlogloss:1.15024
[17]	validation_0-mlogloss:1.12724
[18]	validation_0-mlogloss:1.10570
[19]	validation_0-mlogloss:1.08525
[20]	validation_0-mlogloss:1.06623
[21]	validation_0-mlogloss:1.04858
[22]	validation_0-mlogloss:1.03150
[23]	validation_0-mlogloss:1.01581
[24]	validation_0-mlogloss:1.00115
[25]	validation_0-mlogloss:0.98713
[26]	validation_0-mlogloss:0.97399
[27]	validation_0-mlogloss:0.96165
[28]	validation_0-mlogloss:0.9

In [6]:
clf.save_model('model6.json')