In [1]:
from sklearn import model_selection
from sklearn import metrics
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder
import os
import numpy as np
import pandas as pd
from MicroBiome import MicroBiomeDataSet, Trainer, TrainTester, MultiTrainTester
from SNN import FeedForward, SiameseDataSet, SiameseModel
import seaborn as sns

# Ignore warning messages
if True:
    import warnings
    warnings.filterwarnings('ignore')

# Load Data

In [2]:
top_dir = '/project'
data_dir = os.path.join(top_dir, 'data')
preproc_dir = os.path.join(data_dir, 'preprocessed')
inp_dir = os.path.join(preproc_dir, 'Gupta_2020_Precompiled_Cleaned')

In [3]:
os.listdir(inp_dir)

['MatrixDataClean.csv', 'SampleMetaDataClean.csv', 'FeatMetaDataClean.csv']

In [4]:
MatrixData = np.loadtxt(os.path.join(inp_dir, 'MatrixDataClean.csv'))
MatrixData.shape

(4347, 903)

In [5]:
SampleMeta = pd.read_csv(os.path.join(inp_dir, 'SampleMetaDataClean.csv'))
SampleMeta

Unnamed: 0,6,study,Study No. (From VG sheet (V-*) from SB sheet (S-*)),Title of Paper,Author (year),Journal,Study Accession,Sample Accession or Sample ID,Sample title (ENA/SRA),Sample title (Paper),...,Use of milk or milk products (Yes/No/Sometimes),Use of Animal Product- Meat (Yes/No/Vegetarian/Vegan),Alcohol Consumption (Yes/No),Diet1,Intervention for study (Diet),Intervention for study (medication),Sequencing Platform,Library layout (SINGLE/PAIRED),Read count after Quality control,Healthy
0,SAMEA104142287,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142287,ZSL-004,ZSL-004,...,,,,,,,Illumina HiSeq 2000,PAIRED,43356775.0,Unhealthy
1,SAMEA104142288,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142288,ZSL-007,ZSL-007,...,,,,,,,Illumina HiSeq 2000,PAIRED,41073871.0,Unhealthy
2,SAMEA104142293,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142293,ZSL-010,ZSL-010,...,,,,,,,Illumina HiSeq 2000,PAIRED,40199673.0,Unhealthy
3,SAMEA104142291,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142291,ZSL-011,ZSL-011,...,,,,,,,Illumina HiSeq 2000,PAIRED,31054158.0,Unhealthy
4,SAMEA104142284,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142284,ZSL-019,ZSL-019,...,,,,,,,Illumina HiSeq 2000,PAIRED,36081150.0,Unhealthy
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4342,SAMEA4431948,S-7_Underweight,S-7,Two distinct metacommunities characterize the ...,He (2017),Gigascience,PRJEB15371,SAMEA4431948,SZAXPI029564-74,SZAXPI029564-74,...,,,,,No,No,Illumina HiSeq 2000,PAIRED,52212493.0,Unhealthy
4343,SAMEA4431949,S-7_Underweight,S-7,Two distinct metacommunities characterize the ...,He (2017),Gigascience,PRJEB15371,SAMEA4431949,SZAXPI029565-77,SZAXPI029565-77,...,,,,,No,No,Illumina HiSeq 2000,PAIRED,50635879.0,Unhealthy
4344,SAMEA4431951,S-7_Underweight,S-7,Two distinct metacommunities characterize the ...,He (2017),Gigascience,PRJEB15371,SAMEA4431951,SZAXPI029567-80,SZAXPI029567-80,...,,,,,No,No,Illumina HiSeq 2000,PAIRED,40712962.0,Unhealthy
4345,SAMEA4431964,S-7_Underweight,S-7,Two distinct metacommunities characterize the ...,He (2017),Gigascience,PRJEB15371,SAMEA4431964,SZAXPI029580-98,SZAXPI029580-98,...,,,,,No,No,Illumina HiSeq 2000,PAIRED,38177360.0,Unhealthy


In [6]:
FeatMeta = pd.read_csv(os.path.join(inp_dir, 'FeatMetaDataClean.csv'))
FeatMeta

Unnamed: 0.1,Unnamed: 0,feature
0,s__Abiotrophia_defectiva,s__Abiotrophia_defectiva
1,s__Acetobacter_unclassified,s__Acetobacter_unclassified
2,s__Achromobacter_piechaudii,s__Achromobacter_piechaudii
3,s__Achromobacter_unclassified,s__Achromobacter_unclassified
4,s__Achromobacter_xylosoxidans,s__Achromobacter_xylosoxidans
...,...,...
898,s__Weissella_koreensis,s__Weissella_koreensis
899,s__Weissella_paramesenteroides,s__Weissella_paramesenteroides
900,s__Weissella_unclassified,s__Weissella_unclassified
901,s__Wohlfahrtiimonas_chitiniclastica,s__Wohlfahrtiimonas_chitiniclastica


# Define and Train Model

In [7]:
MyFeedForward = FeedForward(100, [50, 25, 10])

In [8]:
MyFeedForward

FeedForward(
  (layers): ModuleList(
    (0): Linear(in_features=100, out_features=50, bias=True)
    (1): Linear(in_features=50, out_features=25, bias=True)
    (2): Linear(in_features=25, out_features=10, bias=True)
  )
)

In [9]:
MySNN = SiameseModel(MyFeedForward, predict_unknown = False)

In [10]:
# list(MySNN.logistic.parameters()) + list(MySNN.model.parameters())

In [11]:
MyTrainer = Trainer(model = MySNN, scale_X = True, use_pca = True, n_components = 100)
MyTrainTester = TrainTester(MyTrainer, metrics.accuracy_score)

In [12]:
# make binarized class matrix
y = SampleMeta['Healthy'].to_numpy().astype('str') == 'Healthy'
y = y.astype('int32')
OneHot = OneHotEncoder(sparse=False)
y = OneHot.fit_transform(y.reshape((y.shape[0], 1)))

In [13]:
y.shape

(4347, 2)

In [14]:
MatrixData.dtype

dtype('float64')

In [15]:
y.dtype

dtype('float64')

In [16]:
MyTrainTester.train(MatrixData, y)

#########################################
Epoch 1 of 5
__Training__
2021-03-06 00:41:08
Batch Mean Loss: 0.6978363990783691
Batch Mean Loss: 0.7005513310432434
Batch Mean Loss: 0.6980744004249573
Batch Mean Loss: 0.6950007081031799
Batch Mean Loss: 0.6930163502693176
Batch Mean Loss: 0.6927925944328308
Batch Mean Loss: 0.6914594769477844
Batch Mean Loss: 0.6903901100158691
Batch Mean Loss: 0.6945204138755798
Batch Mean Loss: 0.6905078291893005
Batch Mean Loss: 0.6920049786567688
Batch Mean Loss: 0.6954811215400696
Batch Mean Loss: 0.691545844078064
Batch Mean Loss: 0.6875942349433899
Batch Mean Loss: 0.6940100193023682
Batch Mean Loss: 0.6918597221374512
Batch Mean Loss: 0.6911461353302002
Batch Mean Loss: 0.6907747983932495
Batch Mean Loss: 0.6913939714431763
Batch Mean Loss: 0.6902841925621033
MEAN LOSS: 0.6930122344970703
__Validation__
2021-03-06 00:41:08
Batch Mean Loss: 0.6907351613044739
Batch Mean Loss: 0.689842700958252
Batch Mean Loss: 0.6895021200180054
MEAN LOSS: 0.69021033

In [17]:
for x in MySNN.model.parameters():
    print(x.dtype)

torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32


In [18]:
for x in MySNN.logistic.parameters():
    print(x.dtype)

torch.float32
torch.float32


In [19]:
y_pred_train = MyTrainTester.Trainer.predict(MyTrainTester.X_train)

In [20]:
y_pred_train_1D = OneHot.inverse_transform(y_pred_train)

In [21]:
y_train_1D = OneHot.inverse_transform(MyTrainTester.y_train)

In [22]:
metrics.balanced_accuracy_score(y_pred_train_1D, y_train_1D)

0.6699919733975461