There is an endemic problem today in Bioinformatics. With Deep Learning$^{\text{TM}}$ becoming the next Hot New Thing, the Bioinformatics community has worked quickly to catch up. However, I think some important lessons of the Machine Learning community have not filtered through to the Bioinformatics community, specifically the big one: how to choose a test set.

# What purpose does a test set serve?

The fundamental purpose of a test set is to provide a fair and honest evaluation of the performance of the trained model. In other words, the test set is used to answer the question: if we train the model and use it on future data, how well can we expect it to perform? The honesty part is absolutely critical, because otherwise we will overpromise the performance. This is why the machine learning community hammers on the concept of a proper test set. It also guides how the test set should be chosen. I'll illustrate this with some examples.

## Examples

Suppose we want to build a model to predict stock prices. If we train the model today, then we would use it to predict prices for tomorrow. Crucially, no information from tomorrow will be seen by the model during training. For example, if stocks A and B are correlated (say they're in the same industry and will be similarly affected by underlying economic conditions), then we cannot use the fact that stock A is up tomorrow to predict that stock B is up tomorrow (except in a Granger sense, meaning we can use the current rise in A to predict a likely future rise in B). Likewise, if some economic event causes a general shift in stock prices, that shift would be apparent if the model saw 80% of tomorrow's prices, but not if it saw none of them. A random split leaks exactly this kind of information, so the proper split to simulate deployment is a time-based split. If we don't do this, we will fool ourselves about the performance of the algorithm, and we might deploy an algorithm that under-performs and loses us money.
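A minimal sketch of such a split, assuming each row carries a sortable timestamp. `time_based_split` is a hypothetical helper, not part of any library; scikit-learn's `TimeSeriesSplit` offers a cross-validation version of the same idea.

```python
import numpy as np

def time_based_split(timestamps, test_frac=0.2):
    """Return (train_idx, test_idx) so that every training row precedes every test row.

    `timestamps` can be any sortable array (e.g. days since some origin).
    Rows exactly at the cutoff go to the test set.
    """
    timestamps = np.asarray(timestamps)
    # cutoff at the (1 - test_frac) quantile of the observed times
    cutoff = np.quantile(timestamps, 1 - test_frac)
    train_idx = np.flatnonzero(timestamps < cutoff)
    test_idx = np.flatnonzero(timestamps >= cutoff)
    return train_idx, test_idx

# toy example: 10 trading days, stored in shuffled order
days = np.array([3, 0, 7, 1, 9, 4, 8, 2, 6, 5])
train_idx, test_idx = time_based_split(days, test_frac=0.2)
print(np.sort(days[train_idx]).tolist(), np.sort(days[test_idx]).tolist())
# [0, 1, 2, 3, 4, 5, 6, 7] [8, 9]
```

Note that the row order in the data does not matter; only the timestamps decide which side of the split a row lands on.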

Or consider the following example from [r/MachineLearning](https://www.reddit.com/r/MachineLearning/comments/c4ylga/d_misuse_of_deep_learning_in_nature_journals/): a Nature paper proposed a deep neural network to predict the location of aftershocks.  If we train a model today and an earthquake happens tomorrow, then what information is available to predict the location of the subsequent aftershocks?  We can use past earthquakes and aftershocks, as well as information we get from the initial earthquake.  What we don't get is aftershocks from the current earthquake.  Therefore, a proper split would be either by time (as above) or by earthquakes grouped with their aftershocks.  As the post above shows, doing the latter type of split results in a simple regularized regression having better test set performance, which indicates that the deep neural network is over-fitting.  
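Grouping aftershocks with their mainshock amounts to a grouped K-fold split: rows that share a group ID always land on the same side. A NumPy sketch, where `group_kfold` and the toy quake IDs are hypothetical (scikit-learn's `GroupKFold` enforces the same constraint):

```python
import numpy as np

def group_kfold(groups, n_splits=5, seed=0):
    """Yield (train_idx, test_idx) pairs in which no group is split across train and test.

    `groups` labels which cluster each row belongs to (e.g. a mainshock ID shared
    by the earthquake and all of its aftershocks). Whole groups are shuffled and
    dealt into `n_splits` folds.
    """
    groups = np.asarray(groups)
    rng = np.random.default_rng(seed)
    unique_groups = rng.permutation(np.unique(groups))
    for fold_groups in np.array_split(unique_groups, n_splits):
        test_mask = np.isin(groups, fold_groups)
        yield np.flatnonzero(~test_mask), np.flatnonzero(test_mask)

# toy example: 4 mainshocks (A to D), each followed by some aftershocks
quake_ids = np.array(["A", "A", "A", "B", "B", "C", "C", "C", "C", "D"])
for train_idx, test_idx in group_kfold(quake_ids, n_splits=2):
    # no mainshock ID ever appears on both sides of the split
    assert not set(quake_ids[train_idx].tolist()) & set(quake_ids[test_idx].tolist())
    print(sorted(set(quake_ids[test_idx].tolist())))
```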

Now consider examples from bioinformatics. Suppose we want to build a model to predict which guide RNAs will be effective. If we want to use the model to help design guides for a new experiment, then we typically would not have access to a previous experiment in the same cell type with the same target phenotype (the one we are trying to select for). If we did, we could just use that experiment to select which guides to use. In particular, if we use the same experiment to both train and evaluate the model, then confounders such as batch effects will make us over-confident in our predictions. [One paper](https://www.nature.com/articles/nbt.4061) clearly showed this with an out-of-sample test set (hidden in the supplementary material), where simple regularized regression showed better performance than their proposed deep learning model.

Now consider the problem of predicting gene expression from other modalities, such as the promoter sequence plus open chromatin data for the particular cell type. To deploy the model, we would feed it the genetic sequence and open chromatin data to predict gene expression for a sample that has no gene expression data, such as a new patient sample. The key here is that there is a lot of biological variation (from cell type to cell type, or person to person) as well as batch effects. These effects account for much of the variation in expression, but they won't be available to the model in production. Therefore, if the model is able to "see" those batch effects during evaluation (say, with a simple random train-test split), we will overestimate its accuracy.
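To make the leakage concrete, here is a small NumPy simulation. Everything in it is a made-up assumption for illustration: five batches, a fixed per-batch offset added to the outcome, two nuisance features that act as a batch "fingerprint", and a 1-nearest-neighbour regressor chosen only because it is easy to write from scratch. Under a random split, each test sample's nearest neighbour comes from its own batch, so the model copies the batch offset and looks excellent; when a whole batch is held out, that shortcut disappears.

```python
import numpy as np

def r2(y_true, y_pred):
    """Coefficient of determination, same definition as sklearn.metrics.r2_score."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1 - ss_res / ss_tot

def nn_predict(X_train, y_train, X_test):
    """1-nearest-neighbour regression: copy the label of the closest training row."""
    d2 = ((X_test[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=2)
    return y_train[d2.argmin(axis=1)]

rng = np.random.default_rng(0)
n_batches, n_per_batch = 5, 200
batch = np.repeat(np.arange(n_batches), n_per_batch)

# hypothetical setup: each batch shifts the outcome by a fixed offset ...
offsets = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
# ... and leaves a "fingerprint" in two nuisance features (batch centres on a circle)
angles = 2 * np.pi * np.arange(n_batches) / n_batches
centres = 3 * np.column_stack([np.cos(angles), np.sin(angles)])
fingerprint = centres[batch] + rng.normal(scale=0.3, size=(batch.size, 2))

z = rng.normal(size=batch.size)  # the genuine biological signal
y = 0.3 * z + offsets[batch] + rng.normal(scale=0.1, size=batch.size)
X = np.column_stack([z, fingerprint])

def cv_r2(fold_of_row):
    """Mean held-out R^2 over the folds defined by `fold_of_row`."""
    scores = []
    for k in np.unique(fold_of_row):
        test = fold_of_row == k
        y_hat = nn_predict(X[~test], y[~test], X[test])
        scores.append(r2(y[test], y_hat))
    return float(np.mean(scores))

random_folds = rng.permutation(batch.size) % n_batches  # ignores batches entirely
r2_random = cv_r2(random_folds)
r2_by_batch = cv_r2(batch)  # hold out one whole batch per fold
print(f"random split R^2: {r2_random:.2f}, split by batch R^2: {r2_by_batch:.2f}")
```

The exact scores depend on the seed; the point is the gap between the two evaluation schemes, with the held-out-batch score coming out negative.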


# Example: guide RNA design

To clearly illustrate how this issue arises, we'll use the guide RNA example above. Let's say we want to build a model to pick CRISPRko (CRISPR knockout) guides with better on-target activity. To train the model we'll use the [Toronto Knockout Library](http://tko.ccbr.utoronto.ca/) dataset, a collection of CRISPRko gene essentiality screens in 5 different cell lines. To keep gene-level biology from masquerading as guide efficacy, and to avoid using the training data itself to decide which genes are hits, we'll subset the training data to previously known essential genes (from http://www.ncbi.nlm.nih.gov/pubmed/24987113).

First, we need to convert the raw counts to log2 fold changes. We'll do this using all guides, before subsetting to the essential genes.

## Preprocessing


```{toggle}

# this was done in R outside and is not run in this notebook
tko_loc = '/Users/tim.daley/blog/timydaley.github.io/crispr_tko/'
libs = c("DLD1", "GBM", "HCT116_1", "HeLa", "RPE1")
df_list = list()
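for(l in libs){
  loc = paste0(tko_loc, "readcount-", l, "-lib1")
  # stringsAsFactors = FALSE keeps GENE_CLONE as character so strsplit() works (the default was TRUE before R 4.0)
  x = read.table(loc, header = TRUE, stringsAsFactors = FALSE)
  df_list[[l]] = x
}
for(l in libs){
  # the guide sequence is the second "_"-separated field of GENE_CLONE
  df_list[[l]]["SEQ"] = sapply(df_list[[l]]$GENE_CLONE, function(s) unlist(strsplit(s, "_"))[2])
}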
design_matrices = list()
counts_list = list()
# we need custom design matrices for each experiment because the designs are not identical
# DLD1
counts_list[["DLD1"]] = df_list[["DLD1"]][c("DLD_T0", "DLD_ETOH_R1", "DLD_ETOH_R2", "DLD_ETOH_R3")]
design_matrices[["DLD1"]] = data.frame(condition = c(0, 1, 1, 1), row.names = colnames(counts_list[["DLD1"]]))
# GBM
counts_list[["GBM"]] = df_list[["GBM"]][c("T0", "T21A", "T21B")]
design_matrices[["GBM"]] = data.frame(condition = c(0, 1, 1), row.names = colnames(counts_list[["GBM"]]))
# HCT116_1
counts_list[["HCT116_1"]] = df_list[["HCT116_1"]][c("LIB1_T0", "LIB1_T18_A", "LIB1_T18_B")]
design_matrices[["HCT116_1"]] = data.frame(condition = c(0, 1, 1), row.names = colnames(counts_list[["HCT116_1"]]))
# HeLa
counts_list[["HeLa"]] = df_list[["HeLa"]][c("T0", "T18A", "T18B", "T18C")]
design_matrices[["HeLa"]] = data.frame(condition = c(0, 1, 1, 1), row.names = colnames(counts_list[["HeLa"]]))
# RPE1
counts_list[["RPE1"]] = df_list[["RPE1"]][c("T0", "T18A", "T18B")]
design_matrices[["RPE1"]] = data.frame(condition = c(0, 1, 1), row.names = colnames(counts_list[["RPE1"]]))
# now compute log2 fold changes
log2fc_list = list()
for(l in libs){
  d = DESeq2::DESeqDataSetFromMatrix(countData = counts_list[[l]],
                                     colData = design_matrices[[l]],
                                     design = ~condition)
  d = DESeq2::DESeq(d)
  d = DESeq2::results(d)
  log2fc_list[[l]] = data.frame(d, seq = df_list[[l]]$SEQ, gene = df_list[[l]]$GENE)
}
# now subset to known positive genes
essential_genes = read.table(file = paste0(tko_loc, "reference_essentials_and_nonessentials_sym_hgnc_entrez/constitutive_core_essentials_hg-Table1.tsv"), header = TRUE)

# what we really want is a table with log2fc, guide sequence, gene, and cell type
log2fc = data.frame()
for(l in libs){
  log2fc = rbind(log2fc, data.frame(log2fc_list[[l]][c("seq", "gene", "log2FoldChange")], lib = l))
}
log2fc['essential'] = 1*(log2fc$gene %in% essential_genes$Gene)
write.table(log2fc, file = paste0(tko_loc, "CombinedLog2FoldChanges.txt"), quote = FALSE, sep = '\t', row.names = FALSE)
```
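For intuition, the normalization behind this step can be sketched in NumPy. This is a simplified stand-in, not what `DESeq2::DESeq` computes: it borrows DESeq2's median-of-ratios size factors (skipping guides with a zero count in any sample) but replaces the per-guide negative binomial GLM with a ratio of mean normalized counts. The helper names and the pseudocount of 0.5 are hypothetical choices.

```python
import numpy as np

def median_of_ratios_size_factors(counts):
    """Per-sample size factors via the median-of-ratios idea used by DESeq2.

    `counts` is a (guides x samples) array. Guides with a zero in any sample are
    skipped, since their geometric mean would be zero.
    """
    counts = np.asarray(counts, dtype=float)
    usable = (counts > 0).all(axis=1)
    log_counts = np.log(counts[usable])
    log_geo_means = log_counts.mean(axis=1, keepdims=True)
    return np.exp(np.median(log_counts - log_geo_means, axis=0))

def simple_log2fc(counts, is_treated, pseudocount=0.5):
    """log2(mean normalized treated / mean normalized control) for each guide.

    Simplified: DESeq2 fits a negative binomial GLM per guide, so its
    log2FoldChange values will not match this exactly.
    """
    counts = np.asarray(counts, dtype=float)
    is_treated = np.asarray(is_treated, dtype=bool)
    normalized = counts / median_of_ratios_size_factors(counts)
    treated = normalized[:, is_treated].mean(axis=1)
    control = normalized[:, ~is_treated].mean(axis=1)
    return np.log2((treated + pseudocount) / (control + pseudocount))

# toy example: column 0 is T0, columns 1-2 are end-point replicates sequenced
# twice as deeply; guide 0 is depleted 4-fold, the other guides are unchanged
t0 = np.array([100, 200, 300, 400, 500])
counts = np.column_stack([t0, 2 * t0, 2 * t0])
counts[0, 1:] = 50
print(np.round(simple_log2fc(counts, is_treated=[False, True, True]), 2))
```

The size factors absorb the difference in sequencing depth, so only the depleted guide ends up with a non-zero log2 fold change.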

```python
import pandas as pd
import numpy as np

log2fc_df = pd.read_csv("../CombinedLog2FoldChanges.txt", sep='\t')
print(log2fc_df.shape)
log2fc_df.head()
```

```
(456600, 5)
```

| | seq | gene | log2FoldChange | lib | essential |
|---|---|---|---|---|---|
| 0 | CACCTTCGAGCTGCTGCGCG | A1BG | -0.198332 | DLD1 | 0 |
| 1 | AAGAGCGCCTCGGTCCCAGC | A1BG | -0.631673 | DLD1 | 0 |
| 2 | TGGACTTCCAGCTACGGCGC | A1BG | -1.315708 | DLD1 | 0 |
| 3 | CACTGGCGCCATCGAGAGCC | A1BG | 0.989644 | DLD1 | 0 |
| 4 | GCTCGGGCTTGTCCACAGGA | A1BG | 0.021679 | DLD1 | 0 |


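```python
# DESeq2 reports NA log2 fold changes for guides with zero counts in every sample; treat those as no change
log2fc_df['log2FoldChange'] = log2fc_df['log2FoldChange'].fillna(0)
```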

```python
# breakdown between libraries
log2fc_df['lib'].value_counts()
```

```
HCT116_1    91320
GBM         91320
RPE1        91320
HeLa        91320
DLD1        91320
Name: lib, dtype: int64
```

```python
# every guide should be 20 nt long
log2fc_df['seq'].str.len().value_counts()
```

```
20    456600
Name: seq, dtype: int64
```

```python
# keep guides against known essential genes, plus the guides labelled 'chr10'
essential_data = log2fc_df[(log2fc_df['essential'] == 1) | (log2fc_df['gene'] == 'chr10')].copy()
essential_data = essential_data.reset_index(drop=True)
essential_data['lib'].value_counts()
```

```
GBM         2888
HCT116_1    2888
RPE1        2888
HeLa        2888
DLD1        2888
Name: lib, dtype: int64
```

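```python
# one column per (position, base) pair, e.g. "0_A"
seq_array = pd.DataFrame(np.array([list(x) for x in essential_data['seq']]))
seq_array_1hot = pd.get_dummies(seq_array)
print(seq_array_1hot.shape)
seq_array_1hot.head()
```

```
(14440, 76)
```

| | 0_A | 0_C | 0_G | 0_T | 1_A | 1_C | 1_G | 1_T | 2_A | 2_C | ... | 16_G | 17_A | 17_C | 17_G | 18_A | 18_C | 18_G | 19_A | 19_C | 19_G |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | ... | 1 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 |
| 1 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 |
| 2 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 1 | 0 | ... | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 1 | 0 | 0 |
| 3 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | ... | 1 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 1 | 0 |
| 4 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 |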


```python
seq_array_1hot.columns.values
```

```
array(['0_A', '0_C', '0_G', '0_T', '1_A', '1_C', '1_G', '1_T', '2_A',
       '2_C', '2_G', '2_T', '3_A', '3_C', '3_G', '3_T', '4_A', '4_C',
       '4_G', '4_T', '5_A', '5_C', '5_G', '5_T', '6_A', '6_C', '6_G',
       '6_T', '7_A', '7_C', '7_G', '7_T', '8_A', '8_C', '8_G', '8_T',
       '9_A', '9_C', '9_G', '9_T', '10_A', '10_C', '10_G', '10_T', '11_A',
       '11_C', '11_G', '11_T', '12_A', '12_C', '12_G', '12_T', '13_A',
       '13_C', '13_G', '13_T', '14_A', '14_C', '14_G', '14_T', '15_A',
       '15_C', '15_G', '15_T', '16_A', '16_C', '16_G', '17_A', '17_C',
       '17_G', '18_A', '18_C', '18_G', '19_A', '19_C', '19_G'],
      dtype=object)
```

```python
seq_array[19].value_counts()
```

```
C    6780
A    4180
G    3480
Name: 19, dtype: int64
```

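OK, so no guide in this subset appears to have a T at positions 16 through 19, which means `pd.get_dummies` never created the `16_T` to `19_T` columns (76 columns instead of 80). Let's reindex to add them back as all-zero columns.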

```python
# append the four missing T columns, filled with zeros
new_cols = seq_array_1hot.columns.values.tolist() + ['16_T', '17_T', '18_T', '19_T']
print(new_cols)
seq_array_1hot = seq_array_1hot.reindex(columns=new_cols, fill_value=0)
seq_array_1hot.head()
```

```
['0_A', '0_C', '0_G', '0_T', '1_A', '1_C', '1_G', '1_T', '2_A', '2_C', '2_G', '2_T', '3_A', '3_C', '3_G', '3_T', '4_A', '4_C', '4_G', '4_T', '5_A', '5_C', '5_G', '5_T', '6_A', '6_C', '6_G', '6_T', '7_A', '7_C', '7_G', '7_T', '8_A', '8_C', '8_G', '8_T', '9_A', '9_C', '9_G', '9_T', '10_A', '10_C', '10_G', '10_T', '11_A', '11_C', '11_G', '11_T', '12_A', '12_C', '12_G', '12_T', '13_A', '13_C', '13_G', '13_T', '14_A', '14_C', '14_G', '14_T', '15_A', '15_C', '15_G', '15_T', '16_A', '16_C', '16_G', '17_A', '17_C', '17_G', '18_A', '18_C', '18_G', '19_A', '19_C', '19_G', '16_T', '17_T', '18_T', '19_T']
```

| | 0_A | 0_C | 0_G | 0_T | 1_A | 1_C | 1_G | 1_T | 2_A | 2_C | ... | 18_A | 18_C | 18_G | 19_A | 19_C | 19_G | 16_T | 17_T | 18_T | 19_T |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | ... | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | ... | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
| 2 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 1 | 0 | ... | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | ... | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 |
| 4 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
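An alternative to reindexing after the fact is to encode against a fixed alphabet, so that every (position, base) column exists whether or not it is observed. The sketch below uses a hypothetical helper, `one_hot_guides`, in plain NumPy; in pandas, casting each column to a `pd.Categorical` with `categories=list("ACGT")` before calling `pd.get_dummies` should achieve the same thing.

```python
import numpy as np

BASES = np.array(list("ACGT"))

def one_hot_guides(seqs, bases=BASES):
    """One-hot encode equal-length sequences against a fixed alphabet.

    Returns (matrix, column_names) with one column per (position, base) pair,
    named like the notebook's columns ("0_A", "0_C", ...). Because the alphabet
    is fixed up front, a base that never occurs at some position still gets an
    (all-zero) column.
    """
    chars = np.array([list(s) for s in seqs])  # shape (n_guides, guide_length)
    matrix = (chars[:, :, None] == bases[None, None, :]).astype(np.int8)
    n_guides, length, n_bases = matrix.shape
    names = [f"{pos}_{b}" for pos in range(length) for b in bases]
    return matrix.reshape(n_guides, length * n_bases), names

# toy example: no guide has a T at the last position, yet "2_T" still exists
X_toy, cols = one_hot_guides(["GGA", "CAC", "TGG"])
print(X_toy.shape, cols[-4:])
# (3, 12) ['2_A', '2_C', '2_G', '2_T']
```

The columns come out in positional order (`16_G, 16_T, 17_A, ...`) rather than with the T columns appended at the end as in the reindexed frame above; the information available to the model is the same.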


To control for variable gene effect sizes, I'll include one-hot gene indicators.

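```python
gene_one_hot = pd.get_dummies(essential_data['gene'], prefix='gene')
print(gene_one_hot.shape)
gene_one_hot.head()
```

```
(14440, 321)
```

| | gene_ACTL6A | gene_ACTR6 | gene_ALYREF | gene_ANAPC4 | gene_ANAPC5 | gene_AP2S1 | gene_AQR | gene_ARCN1 | gene_ARL5B | gene_ATP6V0D1 | ... | gene_XIAP | gene_XPO1 | gene_YY1 | gene_ZBTB48 | gene_ZC3H13 | gene_ZC3H18 | gene_ZFR | gene_ZNF160 | gene_ZNF207 | gene_chr10 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |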


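```python
# features: one-hot guide sequence plus one-hot gene indicators (rows are aligned by index)
X = seq_array_1hot.merge(gene_one_hot, left_index=True, right_index=True)
# target: log2 fold change of each guide
y = essential_data['log2FoldChange']
```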

```python
from IPython.display import display, HTML
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.metrics import mean_squared_error, explained_variance_score
from tabulate import tabulate
```

## Simple train-test split

First, let's look at a simple random split. Since there are 5 libraries, I'll use 5-fold cross-validation, so each test fold is 20% of the data. Note that the rows are ordered by library, with an equal number of guides per library, so a standard unshuffled 5-fold CV split would put exactly one library in each fold. To get a random split instead, I'll shuffle the data frame before computing the CV scores.
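To see why the shuffle matters, here is a toy sketch of how unshuffled K-fold assigns rows to folds. scikit-learn's `KFold`, which `cross_val_score(..., cv=5)` uses for regressors, takes consecutive blocks of rows when `shuffle=False` (the default); `contiguous_folds` is a hypothetical helper mimicking that, and `libs` is a small made-up stand-in for `essential_data['lib']`.

```python
import numpy as np

# hypothetical stand-in for essential_data['lib']: rows ordered by library, equal counts
libs = np.repeat(["DLD1", "GBM", "HCT116_1", "HeLa", "RPE1"], 4)

def contiguous_folds(n_rows, n_splits):
    """Fold label for each row under unshuffled K-fold: consecutive blocks of rows."""
    labels = np.empty(n_rows, dtype=int)
    for k, block in enumerate(np.array_split(np.arange(n_rows), n_splits)):
        labels[block] = k
    return labels

unshuffled = contiguous_folds(len(libs), 5)
# permuting the fold labels is equivalent to shuffling the rows first
shuffled = unshuffled[np.random.default_rng(0).permutation(len(libs))]

for name, folds in [("unshuffled", unshuffled), ("shuffled", shuffled)]:
    libs_per_fold = [len(set(libs[folds == k].tolist())) for k in range(5)]
    print(name, libs_per_fold)
```

Without shuffling, every fold contains exactly one library; after shuffling, the folds mix several libraries.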

```python
# shuffle the rows so that each CV fold mixes guides from all five libraries
shuffled_df = X.copy()
shuffled_df['y'] = y.copy()
shuffled_df = shuffled_df.sample(frac=1, replace=False).reset_index(drop=True)
shuffled_y = shuffled_df['y']
shuffled_X = shuffled_df.drop(['y'], axis=1)
rf_model = RandomForestRegressor()
# 5-fold CV; the default score for a regressor is R^2
cv_scores = cross_val_score(rf_model, shuffled_X, shuffled_y, cv=5)
cv_scores = pd.DataFrame({'cv': [1, 2, 3, 4, 5],
                          'score': cv_scores})
display(HTML(cv_scores.to_html()))
```

| | cv | score |
|---|---|---|
| 0 | 1 | 0.309174 |
| 1 | 2 | 0.27467 |
| 2 | 3 | 0.31301 |
| 3 | 4 | 0.336566 |
| 4 | 5 | 0.346031 |


```python
cv_scores_summary = pd.DataFrame({'mean': [cv_scores['score'].mean()],
                                  'sd': [cv_scores['score'].std()]})
display(HTML(cv_scores_summary.to_html()))
```

| | mean | sd |
|---|---|---|
| 0 | 0.31589 | 0.027787 |


## Split by library

Now let's take a look at what happens when we split by library. Since the rows are ordered by library and there are an equal number of guides per library, a standard unshuffled 5-fold cross-validation holds out exactly one library per fold.
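Relying on row order works here, but it breaks silently if the rows are ever shuffled or the libraries have unequal sizes. A sketch of an explicit leave-one-library-out split keyed on the library label (`leave_one_group_out` is a hypothetical helper):

```python
import numpy as np

def leave_one_group_out(groups):
    """Yield (held_out_group, train_idx, test_idx), one split per distinct group.

    Unlike relying on row order plus unshuffled K-fold, this stays correct if the
    rows are shuffled or filtered, or if the groups have unequal sizes.
    """
    groups = np.asarray(groups)
    for g in np.unique(groups):
        test_mask = groups == g
        yield g, np.flatnonzero(~test_mask), np.flatnonzero(test_mask)

# toy example with unequal library sizes, in shuffled order
libs = np.array(["HeLa", "DLD1", "GBM", "DLD1", "HeLa", "GBM", "DLD1"])
for held_out, train_idx, test_idx in leave_one_group_out(libs):
    print(held_out, test_idx.tolist())
```

With the notebook's own objects, scikit-learn provides the same split: `cross_val_score(rf_model, X, y, groups=essential_data['lib'], cv=LeaveOneGroupOut())`, after importing `LeaveOneGroupOut` from `sklearn.model_selection`.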

```python
rf_model = RandomForestRegressor()
# unshuffled 5-fold CV: rows are ordered by library with equal counts, so each fold holds out one library
lib_cv_scores = cross_val_score(rf_model, X, y, cv=5)
lib_cv_scores = pd.DataFrame({'lib': essential_data['lib'].unique(),
                              'score': lib_cv_scores})
display(HTML(lib_cv_scores.to_html()))
```

| | lib | score |
|---|---|---|
| 0 | DLD1 | 0.466004 |
| 1 | GBM | 0.49741 |
| 2 | HCT116_1 | 0.322926 |
| 3 | HeLa | -0.278183 |
| 4 | RPE1 | 0.352822 |


```python
lib_cv_scores_summary = pd.DataFrame({'mean': [lib_cv_scores['score'].mean()],
                                      'sd': [lib_cv_scores['score'].std()]})
display(HTML(lib_cv_scores_summary.to_html()))
```

| | mean | sd |
|---|---|---|
| 0 | 0.272196 | 0.316336 |


# Interpretation

Note that the average $R^{2}$ (the default score for `RandomForestRegressor`) is lower when splitting by library (0.27 versus 0.32). In addition, the standard deviation across folds is more than ten times higher (0.32 versus 0.03), because performance depends heavily on which library is held out: some libraries are predicted very well (DLD1 and GBM, at about 0.47 and 0.50) and one is predicted very poorly (HeLa, with a negative $R^{2}$). HeLa's poor predictability was also noted in a [previous project](https://genomebiology.biomedcentral.com/articles/10.1186/s13059-020-01972-x) I was involved in with [Sunil Bodapati](https://www.linkedin.com/in/sunil-bodapati/).
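A negative score like HeLa's can look odd at first. $R^{2} = 1 - SS_{\text{res}}/SS_{\text{tot}}$, where $SS_{\text{tot}}$ is measured around the mean of the test fold, so a constant prediction equal to that mean scores exactly 0 and anything worse goes below zero. A small check with made-up numbers (`r_squared` is a hypothetical helper matching the definition used by `sklearn.metrics.r2_score`):

```python
import numpy as np

def r_squared(y_true, y_pred):
    """R^2 = 1 - SS_res / SS_tot, with SS_tot taken around the mean of y_true."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

y_test = [1.0, 2.0, 3.0]
print(r_squared(y_test, [2.0, 2.0, 2.0]))  # predicting the test mean: 0.0
print(r_squared(y_test, [5.0, 5.0, 5.0]))  # a systematic offset: 1 - 29/2 = -13.5
```

A systematic shift between the training and test libraries, as a batch effect can produce, is one way to end up below zero.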

The cell lines, in fold order, are:
- DLD1, [male colorectal cancer cell line](https://www.atcc.org/products/ccl-221);
- GBM, glioblastoma (don't know the exact cell line);
- HCT116, [male colorectal carcinoma cell line](https://imanislife.com/collections/cell-lines/hct116-cells/);
- HeLa, [female cervical cancer](https://en.wikipedia.org/wiki/HeLa);
- RPE1, [female immortalized retinal pigment epithelium](https://web.expasy.org/cellosaurus/CVCL_4388).

Note that the first two, DLD1 and GBM, have the highest test set scores. It seems reasonable that DLD1 would be predicted well, since HCT116, a similar colorectal cell line, is in the training data when DLD1 is held out; the reverse is weaker, though, as HCT116 itself scores only 0.32. It is also reasonable that HeLa is very difficult to predict, since the karyotype of HeLa is completely haywire. I really have no hypotheses about the good test score of GBM. Critically, what we're missing is the metadata, such as the specific experimental design and who prepared the libraries. In my experience, such details are crucial to evaluating the quality of a sequencing-based experiment. When researchers outside the organization use a publicly available ML tool, their batch effects will be new to the tool, which is exactly the situation that splitting by library simulates.