In [1]:
import numpy as np
import scanpy as sp
import pandas as pd
import pickle
import matplotlib
import matplotlib.pyplot as plt
plt.style.use('ggplot')
%matplotlib inline

import sklearn as sk
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.model_selection import train_test_split

# Random seed; passed as `random_state` to train_test_split further down.
# NOTE(review): presumably must stay fixed so the split matches the one the
# pickled model (PBMC_RF.pkl) was trained on -- confirm.
seed = 2023 # DO NOT CHANGE!

In [2]:
# Record the versions of the key libraries so the run can be reproduced.
for label, module in (
    ("sklearn", sk),
    ("numpy", np),
    ("pandas", pd),
    ("matplotlib", matplotlib),
    ("scanpy", sp),
):
    print(f"{label} version: {module.__version__}")

sklearn version: 1.0.1
numpy version: 1.21.3
pandas version: 1.3.4
matplotlib version: 3.4.3
scanpy version: 1.9.1


# Train/Test Split

In [3]:
# Load the PBMC multiome dataset from an .h5ad file via scanpy (imported as `sp`).
# The path is relative to the notebook's working directory.
PBMC = sp.read_h5ad("../pbmc_multiome.h5ad")

In [4]:
# Keep a handle on the full, unfiltered feature matrix; it is used later to
# predict a label for every observation (`all_pred`).
data = PBMC.X

In [5]:
# Sanity check: report the dataset dimensions (observations x features).
print(f"The data has {PBMC.n_obs} observations and {PBMC.n_vars} features.")

The data has 9641 observations and 19607 features.


In [14]:
# Drop cell types whose share of all observations is below `cutoff`.
cutoff = 0.001

predicted_ids = PBMC.obs['predicted.id']
cell_types, type_numbers = np.unique(predicted_ids, return_counts=True)
type_fractions = type_numbers / len(predicted_ids)
bad_types = cell_types[type_fractions < cutoff]
print(bad_types)

# Despite its name, this mask is True for the rows that are KEPT
# (i.e. observations whose label is not one of the rare types).
bad_types_mask = np.isin(predicted_ids, bad_types, invert=True)
X = PBMC.X[bad_types_mask]
Y = predicted_ids[bad_types_mask]

print(Y.shape)

['ASDC' 'CD4 Proliferating' 'CD8 Proliferating' 'HSPC' 'ILC' 'cDC1' 'dnT']
(9619,)


In [15]:
# Hold out 25% of the filtered observations for testing. The split is seeded,
# so it is reproducible across runs.
# NOTE(review): the pickled model loaded later presumably was fitted on exactly
# this split -- changing test_size / random_state would invalidate the
# train-vs-test comparison. Confirm before touching.
X_train, X_test, Y_train, Y_test = train_test_split(
    X,
    Y,
    test_size=0.25,
    random_state=seed,
)

n_train, n_test = X_train.shape[0], X_test.shape[0]
print(f"{n_train} train samples\n{n_test} test samples\n{n_train/(n_train+n_test)*100:.2f}% of samples used for training")

7214 train samples
2405 test samples
75.00% of samples used for training


# Model Analysis

In [16]:
# Load the previously fitted classifier from disk.
# NOTE(review): pickle.load can execute arbitrary code while deserializing --
# only load PBMC_RF.pkl if it comes from a trusted source.
# NOTE(review): presumably a RandomForestClassifier fitted on X_train with a
# compatible sklearn version -- confirm.
with open('PBMC_RF.pkl', 'rb') as f:
    best_model = pickle.load(f)

In [17]:
# Predict labels for the held-out test set, the training set, and the full
# unfiltered matrix (`data`), in that order.
test_pred, train_pred, all_pred = (
    best_model.predict(features) for features in (X_test, X_train, data)
)

In [21]:
# Export predictions and ground-truth labels for downstream analysis:
#   - all_pred  : predictions for every observation in the unfiltered matrix
#   - Y_test    : true labels of the test split
#   - Y_train   : true labels of the training split
#   - test_pred : predictions for the test split
with open("PBMC_all_pred_RF.pkl", "wb") as f:
    pickle.dump(all_pred,f)

with open("PBMC_Y_test_RF.pkl", "wb") as f:
    pickle.dump(Y_test,f)

# Fix: this file previously received `all_pred` (copy-paste error), so the
# "Y_train" pickle silently contained predictions instead of training labels.
with open("PBMC_Y_train_RF.pkl", "wb") as f:
    pickle.dump(Y_train,f)

with open("PBMC_Test_pred_RF.pkl", "wb") as f:
    pickle.dump(test_pred,f)

## Accuracy

In [18]:
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# Plain accuracy on both splits.
train_acc = accuracy_score(Y_train, train_pred)
test_acc = accuracy_score(Y_test, test_pred)
print(f"Train accuracy: {train_acc:.5f}")
print(f"Test accuracy: {test_acc:.5f}")
print("")

# Balanced accuracy on both splits; compare against the plain accuracy above
# to see how much the majority classes dominate the score.
balanced_train_acc = balanced_accuracy_score(Y_train, train_pred)
balanced_test_acc = balanced_accuracy_score(Y_test, test_pred)
print(f"Balanced Train Accuracy: {balanced_train_acc:.5f}")
print(f"Balanced Test Accuracy: {balanced_test_acc:.5f}")
print("")

Train accuracy: 1.00000
Test accuracy: 0.66154

Balanced Train Accuracy: 1.00000
Balanced Test Accuracy: 0.23119



## F1 Score

In [20]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 on the held-out test set.
# Some classes are never predicted, which makes their precision undefined.
# zero_division=0 reports 0.0 for those entries (the same value the default
# "warn" setting produces) without flooding the output with
# UndefinedMetricWarning messages.
print(classification_report(Y_test, test_pred, zero_division=0))

                  precision    recall  f1-score   support

  B intermediate       0.65      0.92      0.76        78
        B memory       0.00      0.00      0.00        17
         B naive       0.00      0.00      0.00        29
       CD14 Mono       0.79      0.99      0.88       507
       CD16 Mono       0.00      0.00      0.00        74
       CD4 Naive       0.61      0.40      0.49       335
         CD4 TCM       0.57      0.94      0.71       657
         CD4 TEM       0.00      0.00      0.00        44
       CD8 Naive       1.00      0.18      0.30       137
         CD8 TCM       0.00      0.00      0.00         8
         CD8 TEM       0.69      0.67      0.68       306
            MAIT       0.00      0.00      0.00        43
              NK       0.81      0.53      0.64        74
NK Proliferating       0.00      0.00      0.00         1
   NK_CD56bright       0.00      0.00      0.00         4
     Plasmablast       0.00      0.00      0.00         3
            T

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
