# Train Model

We will train a CNN on 3959 immature cells ('Blast, no lineage spec', 'Myelocyte', 'Promyelocyte', 'Metamyelocyte', 'Promonocyte') from 85 samples (53 AML, 32 APL) to predict whether each cell came from an AML or an APL patient.

## Load Data

In [5]:
import sys
sys.path.append('../')
from DeepAPL.DeepAPL import DeepAPL_SC
import os
import numpy as np
import pickle
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve,roc_auc_score
import warnings
warnings.filterwarnings('ignore')

gpu = 1
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

# Train classifier on the discovery cohort
classes = ['AML','APL']
# Select only immature cell types
cell_types = ['Blast, no lineage spec','Myelocyte','Promyelocyte','Metamyelocyte','Promonocyte']
device = '/device:GPU:'+ str(gpu)
DAPL = DeepAPL_SC('Blast_S_'+str(gpu),device=device)
DAPL.Import_Data(directory='../Data/All', Load_Prev_Data=False, classes=classes,
                 include_cell_types=cell_types,sample=None)

## Train Model

In [6]:
%%capture
folds = 100
seeds = np.array(range(folds))
epochs_min = 25
graph_seed = 0
DAPL.Monte_Carlo_CrossVal(folds=folds,seeds=seeds,epochs_min=epochs_min,
                          stop_criterion=0.25,test_size=0.25,graph_seed=graph_seed)

INFO:tensorflow:Blast_S_1/models/model_0/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_1/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_2/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_3/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_4/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_5/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_6/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_7/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_8/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
...

INFO:tensorflow:Blast_S_1/models/model_74/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_75/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_76/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_77/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_78/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_79/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_80/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_81/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
INFO:tensorflow:Blast_S_1/models/model_82/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.
...
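
The training call above runs 100 Monte Carlo cross-validation folds, each holding out 25% of the data. As a rough illustration of that resampling scheme only, here is a minimal NumPy sketch. It is not DeepAPL's implementation: whether DeepAPL splits by cell or by sample, and how it consumes `seeds`, is not shown in this notebook, and the helper `monte_carlo_splits` is hypothetical.

```python
import numpy as np

def monte_carlo_splits(n_items, folds, test_size, seeds):
    """Yield (train_idx, test_idx) for each fold.

    Hypothetical helper: every fold draws a fresh random split, so test
    sets from different folds may overlap (unlike k-fold cross-validation).
    """
    n_test = int(round(n_items * test_size))
    for fold in range(folds):
        # One seed per fold makes each split reproducible on its own.
        rng = np.random.default_rng(seeds[fold])
        perm = rng.permutation(n_items)
        yield np.sort(perm[n_test:]), np.sort(perm[:n_test])

# Values taken from the notebook; splitting at the cell level is an assumption.
n_cells = 3959
folds = 100
seeds = np.arange(folds)
splits = list(monte_carlo_splits(n_cells, folds, test_size=0.25, seeds=seeds))

# Count how often each cell lands in a held-out set across all folds.
held_out_counts = np.zeros(n_cells, dtype=int)
for _, test_idx in splits:
    held_out_counts[test_idx] += 1
```

Because the folds are drawn independently, each item is held out a varying number of times, so predictions for an item can be averaged over the folds in which it was in the test set.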

Get per-cell predictions and save the relevant data for downstream analysis.

In [7]:
DAPL.Get_Cell_Predicted()
with open('Cell_Preds.pkl','wb') as f:
    pickle.dump(DAPL.Cell_Pred,f,protocol=4)
with open('Cell_Masks.pkl','wb') as f:
    pickle.dump(DAPL.w,f,protocol=4)
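
For downstream analysis, the saved files can be read back with `pickle.load`. The sketch below shows the round trip on a stand-in object; the dictionary `cell_preds` is hypothetical, since the actual type and layout of `DAPL.Cell_Pred` are not shown in this notebook.

```python
import os
import pickle
import tempfile

# Hypothetical stand-in for DAPL.Cell_Pred; the real object's structure is an assumption.
cell_preds = {"sample_01": [0.91, 0.12], "sample_02": [0.33]}

# Write to a temporary directory so the sketch does not overwrite real results.
out_dir = tempfile.mkdtemp()
path = os.path.join(out_dir, "Cell_Preds.pkl")

# Save with the same protocol used in the notebook.
with open(path, "wb") as f:
    pickle.dump(cell_preds, f, protocol=4)

# Reload in a later session or a separate analysis script.
with open(path, "rb") as f:
    reloaded = pickle.load(f)
```

If the pickled object depends on a third-party library (for example a NumPy array or a pandas DataFrame), that library must be importable in the environment that loads the file.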