In [1]:
from sklearn import model_selection
from sklearn import metrics
from sklearn.preprocessing import StandardScaler
import os
import numpy as np
import pandas as pd
from MicroBiome import MicroBiomeDataSet, Trainer, TrainTester, MultiTrainTester
from SNN import FeedForward, SiameseDataSet, SiameseModel
import seaborn as sns

# Ignore warning messages
if True:
    import warnings
    warnings.filterwarnings('ignore')

# Load Data

In [2]:
# Root of the project tree. Defaults to '/project' (the original hard-coded
# location); set MICROBIOME_PROJECT_DIR to run the notebook from another
# location without editing this cell.
top_dir = os.environ.get('MICROBIOME_PROJECT_DIR') or '/project'
data_dir = os.path.join(top_dir, 'data')
preproc_dir = os.path.join(data_dir, 'preprocessed')
# Directory holding the cleaned matrix, sample-metadata and feature-metadata files.
inp_dir = os.path.join(preproc_dir, 'Gupta_2020_Precompiled_Cleaned')

In [3]:
# List the files available in the cleaned-input directory.
os.listdir(inp_dir)

['MatrixDataClean.csv', 'SampleMetaDataClean.csv', 'FeatMetaDataClean.csv']

In [4]:
# Load the numeric data matrix. np.loadtxt is called with its default
# (whitespace) delimiter even though the file is named .csv.
# NOTE(review): presumably rows are samples and columns are features — confirm
# against the preprocessing step that wrote this file.
MatrixData = np.loadtxt(os.path.join(inp_dir, 'MatrixDataClean.csv'))
MatrixData.shape

(4347, 903)

In [5]:
# Read the per-sample metadata table; its 'Healthy' column is used later to
# build the class labels.
# NOTE(review): the frame shows an 'Unnamed: 0' column, presumably an index that
# was written out with the CSV — confirm; index_col=0 would absorb it.
SampleMeta = pd.read_csv(os.path.join(inp_dir, 'SampleMetaDataClean.csv'))
SampleMeta

Unnamed: 0,6,study,Study No. (From VG sheet (V-*) from SB sheet (S-*)),Title of Paper,Author (year),Journal,Study Accession,Sample Accession or Sample ID,Sample title (ENA/SRA),Sample title (Paper),...,Use of milk or milk products (Yes/No/Sometimes),Use of Animal Product- Meat (Yes/No/Vegetarian/Vegan),Alcohol Consumption (Yes/No),Diet1,Intervention for study (Diet),Intervention for study (medication),Sequencing Platform,Library layout (SINGLE/PAIRED),Read count after Quality control,Healthy
0,SAMEA104142287,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142287,ZSL-004,ZSL-004,...,,,,,,,Illumina HiSeq 2000,PAIRED,43356775.0,Unhealthy
1,SAMEA104142288,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142288,ZSL-007,ZSL-007,...,,,,,,,Illumina HiSeq 2000,PAIRED,41073871.0,Unhealthy
2,SAMEA104142293,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142293,ZSL-010,ZSL-010,...,,,,,,,Illumina HiSeq 2000,PAIRED,40199673.0,Unhealthy
3,SAMEA104142291,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142291,ZSL-011,ZSL-011,...,,,,,,,Illumina HiSeq 2000,PAIRED,31054158.0,Unhealthy
4,SAMEA104142284,V-2_ACVD,V-2,The gut microbiome in atherosclerotic cardiova...,Jie (2017),Nature communications,PRJEB21528,SAMEA104142284,ZSL-019,ZSL-019,...,,,,,,,Illumina HiSeq 2000,PAIRED,36081150.0,Unhealthy
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4342,SAMEA4431948,S-7_Underweight,S-7,Two distinct metacommunities characterize the ...,He (2017),Gigascience,PRJEB15371,SAMEA4431948,SZAXPI029564-74,SZAXPI029564-74,...,,,,,No,No,Illumina HiSeq 2000,PAIRED,52212493.0,Unhealthy
4343,SAMEA4431949,S-7_Underweight,S-7,Two distinct metacommunities characterize the ...,He (2017),Gigascience,PRJEB15371,SAMEA4431949,SZAXPI029565-77,SZAXPI029565-77,...,,,,,No,No,Illumina HiSeq 2000,PAIRED,50635879.0,Unhealthy
4344,SAMEA4431951,S-7_Underweight,S-7,Two distinct metacommunities characterize the ...,He (2017),Gigascience,PRJEB15371,SAMEA4431951,SZAXPI029567-80,SZAXPI029567-80,...,,,,,No,No,Illumina HiSeq 2000,PAIRED,40712962.0,Unhealthy
4345,SAMEA4431964,S-7_Underweight,S-7,Two distinct metacommunities characterize the ...,He (2017),Gigascience,PRJEB15371,SAMEA4431964,SZAXPI029580-98,SZAXPI029580-98,...,,,,,No,No,Illumina HiSeq 2000,PAIRED,38177360.0,Unhealthy


In [6]:
# Read the per-feature metadata table (species names in the 'feature' column).
# NOTE(review): the frame shows both 'Unnamed: 0.1' and 'Unnamed: 0' columns,
# presumably indexes saved on two successive round trips — confirm.
FeatMeta = pd.read_csv(os.path.join(inp_dir, 'FeatMetaDataClean.csv'))
FeatMeta

Unnamed: 0.1,Unnamed: 0,feature
0,s__Abiotrophia_defectiva,s__Abiotrophia_defectiva
1,s__Acetobacter_unclassified,s__Acetobacter_unclassified
2,s__Achromobacter_piechaudii,s__Achromobacter_piechaudii
3,s__Achromobacter_unclassified,s__Achromobacter_unclassified
4,s__Achromobacter_xylosoxidans,s__Achromobacter_xylosoxidans
...,...,...
898,s__Weissella_koreensis,s__Weissella_koreensis
899,s__Weissella_paramesenteroides,s__Weissella_paramesenteroides
900,s__Weissella_unclassified,s__Weissella_unclassified
901,s__Wohlfahrtiimonas_chitiniclastica,s__Wohlfahrtiimonas_chitiniclastica


# Define and Train Model

In [7]:
# Feed-forward network: 100 inputs followed by linear layers of width 50, 25 and 10.
# NOTE(review): the input width of 100 presumably has to match the
# n_components = 100 passed to Trainer further down — keep the two in sync.
MyFeedForward = FeedForward(100, [50, 25, 10])

In [8]:
# Show the layer structure of the network.
MyFeedForward

FeedForward(
  (layers): ModuleList(
    (0): Linear(in_features=100, out_features=50, bias=True)
    (1): Linear(in_features=50, out_features=25, bias=True)
    (2): Linear(in_features=25, out_features=10, bias=True)
  )
)

In [9]:
# Wrap the feed-forward network in a SiameseModel, with predict_unknown disabled.
MySNN = SiameseModel(MyFeedForward, predict_unknown = False)

In [10]:
# list(MySNN.logistic.parameters()) + list(MySNN.model.parameters())

In [11]:
# Build the Trainer around the Siamese model with scale_X and use_pca switched on.
# NOTE(review): presumably X is scaled and reduced to n_components = 100 PCA
# components before reaching the model (whose input width is 100) — confirm in
# MicroBiome.Trainer.
MyTrainer = Trainer(model = MySNN, scale_X = True, use_pca = True, n_components = 100)

In [12]:
# Build the two-column binarized class matrix:
# column 0 flags samples that are not 'Healthy', column 1 flags 'Healthy' ones.
is_healthy = SampleMeta['Healthy'].to_numpy().astype('str') == 'Healthy'
y = np.column_stack((~is_healthy, is_healthy)).astype('int32')

In [13]:
# Expect (n_samples, 2): one column per class.
y.shape

(4347, 2)

In [14]:
# Sanity check on the class labels: the two columns must disagree in every row.
np.all(y[:, 0] != y[:, 1])

True

In [15]:
# Check the feature-matrix dtype (the model parameters printed further down are float32).
MatrixData.dtype

dtype('float64')

In [16]:
# Check the label-matrix dtype.
y.dtype

dtype('int32')

In [17]:
# Fit on the complete data matrix and label matrix; no hold-out split is made here.
# NOTE(review): no random seed is set in the visible cells, so the logged losses
# are presumably not reproducible from run to run — confirm and seed if needed.
MyTrainer.fit(MatrixData, y)

#########################################
Epoch 1 of 5
__Training__
2021-03-05 21:39:09
Batch Mean Loss: 0.7015756368637085
Batch Mean Loss: 0.6975278854370117
Batch Mean Loss: 0.7057262659072876
Batch Mean Loss: 0.6953636407852173
Batch Mean Loss: 0.7015439867973328
Batch Mean Loss: 0.7038464546203613
Batch Mean Loss: 0.6987776160240173
Batch Mean Loss: 0.6923776865005493
Batch Mean Loss: 0.6986624598503113
Batch Mean Loss: 0.698875904083252
Batch Mean Loss: 0.6931143999099731
Batch Mean Loss: 0.6965702176094055
Batch Mean Loss: 0.7009853720664978
Batch Mean Loss: 0.7026235461235046
Batch Mean Loss: 0.7006507515907288
Batch Mean Loss: 0.7026547789573669
Batch Mean Loss: 0.710139811038971
Batch Mean Loss: 0.6961067914962769
Batch Mean Loss: 0.6927386522293091
Batch Mean Loss: 0.7048537731170654
__Validation__
2021-03-05 21:39:10
Batch Mean Loss: 0.6970341205596924
Batch Mean Loss: 0.6933556199073792
Batch Mean Loss: 0.6929722428321838
#########################################
Epoch 2 o

In [18]:
# Report the dtype of every parameter tensor in the wrapped network.
for param in MySNN.model.parameters():
    print(param.dtype)

torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32


In [19]:
# Report the dtype of every parameter tensor in the model's `logistic` component.
for param in MySNN.logistic.parameters():
    print(param.dtype)

torch.float32
torch.float32


In [20]:
# NOTE(review): this import sits mid-notebook; it belongs in the first cell with
# the other imports so the dependency is visible up front.
import torch

In [21]:
# Random float32 batch of 20 rows x 100 columns, used to smoke-test the network.
X = torch.tensor(np.random.rand(20, 100), dtype=torch.float32)

In [22]:
# Inspect the random test batch.
X

tensor([[0.6175, 0.6117, 0.0071,  ..., 0.6244, 0.2956, 0.1055],
        [0.4565, 0.2184, 0.4165,  ..., 0.9945, 0.1759, 0.0181],
        [0.4939, 0.1788, 0.3665,  ..., 0.1843, 0.0809, 0.4283],
        ...,
        [0.7543, 0.6195, 0.2974,  ..., 0.4243, 0.0826, 0.5147],
        [0.8989, 0.7825, 0.2278,  ..., 0.1247, 0.3854, 0.6434],
        [0.7804, 0.2172, 0.1378,  ..., 0.5620, 0.3459, 0.3152]])

In [23]:
# Sanity-check the output shape for the random batch. The module is called
# directly rather than through .forward() so that nn.Module.__call__ runs any
# registered hooks; calling .forward() directly silently skips them.
MySNN.model(X).shape

torch.Size([20, 10])

In [24]:
# Inspect the TrainDL attribute (a torch DataLoader) held by the Siamese model.
MySNN.TrainDL

<torch.utils.data.dataloader.DataLoader at 0x7fb03acc1510>

In [25]:
# Predict on the same matrix the trainer was fitted on: these are in-sample
# predictions, not an estimate of performance on unseen samples.
y_pred = MyTrainer.predict(MatrixData)

In [26]:
# Inspect the model's max_dist attribute.
# NOTE(review): presumably a distance value computed during fit/predict —
# confirm its meaning in SNN.SiameseModel.
MyTrainer.model.max_dist

tensor(0.0847)

In [27]:
# Inspect the predicted class matrix (two columns, like y).
y_pred

array([[0, 1],
       [0, 1],
       [1, 0],
       ...,
       [1, 0],
       [1, 0],
       [1, 0]], dtype=int32)

In [28]:
# Confirm there is one prediction row per sample.
y_pred.shape

(4347, 2)

In [30]:
# Count the predictions whose second column equals 1.
(y_pred[:, 1] == 1).sum()

2818