# Restoring Trained Models

In [1]:
import pandas as pd
from preprocess import GSE
from sklearn.preprocessing import MinMaxScaler
from models import load_weights, plot_latent, plot_reconstructions

## Loading a trained model

In [2]:
# Restore a previously trained model. load_weights returns a pair that is
# unpacked into the dataset object and the model; both are passed to
# plot_latent in later cells.
# NOTE(review): the arguments are presumably the dataset name and the model
# type -- confirm against models.load_weights.
dataset, model = load_weights('all_data_new', 'vae')

  if __name__ == '__main__':


### Generating cell labels for all data.

In [3]:
# Read only the first three columns of the tab-separated combined data file;
# the preview below shows them to be tumor, dataset and cell.
label_df = pd.read_csv(
    'data/all_data_new.txt',
    sep='\t',
    usecols=[0, 1, 2],
)
label_df.head()

  interactivity=interactivity, compiler=compiler, result=result)


Unnamed: 0,tumor,dataset,cell
0,MUV1,GSE102130,MUV1-P04-B12
1,MUV1,GSE102130,MUV1-P04-C08
2,MUV1,GSE102130,MUV1-P04-D09
3,MUV1,GSE102130,MUV1-P04-D10
4,MUV1,GSE102130,MUV1-P04-E03


In [4]:
# Build one label per cell for the bokeh plot, formatted "<dataset>-<cell>".
# The two columns are addressed by name and concatenated as whole columns.
# The earlier row-wise lambda used c[0] / c[1]; integer keys on a
# label-indexed Series are deprecated as positional lookups in pandas >= 2.1,
# and the per-row Python call was unnecessary.
cell_labels = (label_df['dataset'] + '-' + label_df['cell']).tolist()

In [5]:
# Plot the restored model's latent space for the full dataset, passing the
# "<dataset>-<cell>" labels as cell_names.
# NOTE(review): no suffix is passed here; the warning further down says plots
# may be replaced when no suffix is given -- confirm this is the intended
# default output.
plot_latent(dataset, model, cell_names=cell_labels)

## Passing a single dataset through the autoencoder trained on all data

In [6]:
# Load the gene table; the preview below shows a Gene column and a Count
# column.
# NOTE(review): Count is presumably the number of datasets containing the
# gene (the next cell keeps Count == 8 as "genes in all datasets") -- confirm.
genes = pd.read_csv('data/gene_counts.txt')
genes.head()

Unnamed: 0,Gene,Count
0,SNX3,8
1,INMT,8
2,DHODH,8
3,MIEN1,8
4,EXD2,8


In [7]:
# keep only genes in all datasets
filter_genes = genes[genes['Count'] == 8]['Gene']
filter_genes.head()

0     SNX3
1     INMT
2    DHODH
3    MIEN1
4     EXD2
Name: Gene, dtype: object

Load a new dataset (GSE57872):

In [8]:
# Create the project's GSE wrapper and preview its .data table (the output
# below shows cells as rows and genes as columns).
# NOTE(review): the dataset id GSE57872 is not passed explicitly; presumably
# GSE defaults to it -- confirm against preprocess.GSE.
# NOTE(review): class_name='patient_id' presumably selects the column used as
# the class label -- confirm.
gse57872 = GSE(class_name='patient_id')
gse57872.data.head()

Unnamed: 0,A2M,AAAS,AAK1,AAMP,AARS,AARSD1,AASDH,AASDHPPT,AASS,AATF,...,ZSCAN30,ZSWIM6,ZSWIM7,ZUFSP,ZW10,ZWILCH,ZXDC,ZYG11B,ZYX,ZZZ3
MGH264_A01,-3.80147,-3.8899,-3.985616,2.651558,2.170748,-2.550822,4.80733,3.96117,-0.192665,3.614482,...,6.262256,2.909466,-3.118284,-1.538324,-1.550699,-1.558581,-1.920271,3.007439,-2.509017,-2.149696
MGH264_A02,-3.80147,-3.8899,-3.158708,2.358992,-6.041791,-0.056092,3.606735,-2.63225,2.249388,6.857517,...,2.91234,-1.821098,-3.118284,-1.538324,-1.550699,-1.558581,-3.06862,2.53956,2.164481,-2.149696
MGH264_A03,-3.80147,-3.8899,1.733125,-5.820241,-6.041791,-0.576957,-2.473517,-4.354127,0.063178,-2.570976,...,-2.593571,-1.821098,5.521892,-1.538324,-1.550699,-1.558581,0.174665,-0.165409,0.734268,-2.149696
MGH264_A04,-3.80147,-3.8899,-1.665669,3.514271,-6.041791,-3.699171,4.509461,-4.354127,2.985972,-2.570976,...,4.453042,4.952177,-0.854351,-1.538324,-1.550699,-1.558581,-3.06862,-1.884744,-2.509017,-2.149696
MGH264_A05,-3.80147,3.742495,-2.166992,-5.820241,2.094729,4.021873,5.535007,4.019633,2.56037,-2.570976,...,-2.593571,-1.821098,4.328808,-1.538324,7.021985,-1.558581,4.590947,-0.128456,-2.509017,-2.149696


Note: `all_data_new.txt` contains all datasets, filtered by gene and then scaled, so the same pipeline must be applied here. The GSE object provides the data scaled on all datasets.

*WARNING*: You should provide a suffix to `plot_latent`; otherwise some existing plots might be overwritten.

In [9]:
# Restrict the expression table to the genes selected above, then rescale
# each gene (column) with a freshly fitted MinMaxScaler.
# NOTE(review): the scaler is fitted on this dataset alone -- confirm this
# matches how all_data_new.txt was scaled.
filtered_gse57872 = gse57872.data[filter_genes]
scaled_values = MinMaxScaler().fit_transform(filtered_gse57872.values)

# Wrap the scaled array back into a DataFrame that keeps the original cell
# index and gene column labels.
data_scaled = pd.DataFrame(
    scaled_values,
    index=filtered_gse57872.index,
    columns=filtered_gse57872.columns,
)
data_scaled.head()

Unnamed: 0,SNX3,INMT,DHODH,MIEN1,EXD2,LTA4H,REV3L,CCND2,MTMR9,NUP85,...,AP3M1,TRIM37,NUDT21,CCNI,DLG1,ACTL6A,ZRANB2,PPIL3,ACTR2,OTUD7B
MGH264_A01,0.0,0.0,0.0,0.0,0.0,0.376972,0.477716,0.769334,0.593638,0.747905,...,0.0,0.0,0.753674,0.648484,0.567578,0.0,0.793461,0.553562,0.0,0.0
MGH264_A02,0.690598,0.27079,0.0,0.0,0.825481,0.0,0.800515,0.736555,0.130644,0.0,...,0.0,0.0,0.0,0.569593,0.607381,0.860326,0.83667,0.0,0.377934,0.0
MGH264_A03,0.0,0.400174,0.359327,0.0,0.867033,0.328774,0.568145,0.805525,0.0,0.0,...,0.0,0.685409,0.0,0.844729,0.380794,0.0,0.0,0.0,0.309564,0.0
MGH264_A04,0.0,0.245337,0.0,0.0,0.611076,0.0,0.758784,0.604455,0.0,0.0,...,0.0,0.0,0.0,0.390571,0.0,0.0,0.777052,0.0,0.366423,0.0
MGH264_A05,0.0,0.249359,0.901454,0.0,0.867866,0.529915,0.776914,0.695886,0.549622,0.758096,...,0.760453,0.375964,0.0,0.792104,0.685979,0.832832,0.820849,0.0,0.507898,0.67508


In [10]:
# Replace the GSE object's data_scaled attribute with the filtered, rescaled
# matrix (as a plain array) before plotting.
# NOTE(review): presumably plot_latent reads dataset.data_scaled -- confirm.
gse57872.data_scaled = data_scaled.values
# A suffix is passed, as the warning above recommends, so the plots from this
# call do not replace earlier ones.
plot_latent(gse57872, model, suffix='gse57872')