In [1]:
# Python version 3.11.7

# Load general use packages
import pandas as pd
import numpy as np
import os

# Load torch for making tensors
import torch

# Local scripts
from multi_mlp import simple_FC, JointMLP

## Define Inputs

Here, we will define one model per omic. Hidden layers can be set and should account for data size. DeepIMV cannot handle missing values, so a dataset with missing values removed and/or imputed is needed. 

In [None]:
# Load datasets
rna_16s = pd.read_csv("../../Dataset/Scaled/16S_Edata.csv").drop("Feature", axis = 1)
print("16S has " + str(len(rna_16s)) + " features")

metaproteomics = pd.read_csv("../../Dataset/Scaled/Metaproteomics.csv").drop("Feature", axis = 1)
print("Metaproteomics has " + str(len(metaproteomics)) + " features")

metabpos = pd.read_csv("../../Dataset/Scaled/Metabolomics_Positive.csv").drop("Feature", axis = 1)
print("Metabolomics Positive has " + str(len(metabpos)) + " features")

metabneg = pd.read_csv("../../Dataset/Scaled/Metabolomics_Negative.csv").drop("Feature", axis = 1)
print("Metabolomics Negative has " + str(len(metabneg)) + " features")

16S has 8 features
Metaproteomics has 2726 features
Metabolomics Positive has 2800 features
Metabolomics Negative has 1752 features


## Hyperparameter Tuning

Here we will test different grid sizes: 

| Approximate Proportional Size of Datasets | 16S Size | Metaproteomics & Metabolomics Sizes |
|---|---|---|
| 1/4 | 2 | 512 |
| 1/2 | 4 | 1024 |
| 1 | 8 | 2048 |
| 2 | 16 | 4096 | 
| 4 | 32 | 8192 | 

In [3]:
# Define sizes
small_size = [2, 4, 8, 16, 32]
large_size = [512, 1024, 2048, 4096, 8192]

# Load splits 
splits = []
with open("splits.txt", "r") as file:
    for line in file:
        values = line.replace("\n", "").split(" ")
        values = [int(x)-1 for x in values] # Correct for indexing changes between R and python
        splits.append(values)


# Hold all final loss values
final_loss = []

# Iterate through sizes 
for el in range(len(small_size)):

    # Iterate through splits
    for el2 in range(len(splits)):

        print("Element Size Test:", el, "and Crossfold Number:", el2)

        # Define groups (thankfully always 2 00-week and 6 post-00 week)
        groups = torch.tensor([0,0,1,1,1,1,1,1])

        # Define models 
        rna_16s_train = simple_FC(input_size = rna_16s.shape[0], hidden_sizes = [small_size[el], 64], prediction_dim = 2, dropout = 0.5)
        metap_train = simple_FC(input_size = metaproteomics.shape[0], hidden_sizes = [large_size[el], 64], prediction_dim = 2, dropout = 0.5)
        metab_pos_train = simple_FC(input_size = metabpos.shape[0], hidden_sizes = [large_size[el], 64], prediction_dim = 2, dropout = 0.5)
        metab_neg_train = simple_FC(input_size = metabneg.shape[0], hidden_sizes = [large_size[el], 64], prediction_dim = 2, dropout = 0.5)

        # Define joint model
        joint_mlp = JointMLP(marginal_models = [rna_16s_train, metap_train, metab_pos_train, metab_neg_train], hidden_dim = 64, hooks=False)

        # Optimize parameters 
        optimizer_mlp = torch.optim.AdamW(joint_mlp.parameters(), lr=1e-4)

        # Define data views with subsetting
        views = [torch.tensor(rna_16s.iloc[:,splits[el2]].T.values, dtype = torch.float32),
                torch.tensor(metaproteomics.iloc[:,splits[el2]].T.values, dtype = torch.float32),
                torch.tensor(metabpos.iloc[:,splits[el2]].T.values, dtype = torch.float32),
                torch.tensor(metabneg.iloc[:,splits[el2]].T.values, dtype = torch.float32)]

        acc_mlp = []

        # Time: 20 seconds             
        for i in range(1500):

            # Update the mlp
            yhat, h, yhats, hiddens = joint_mlp(*views)

            # pass the predictions and distributions to the loss function and update parameters
            _, _, loss = joint_mlp.loss(groups, yhat, yhats)

            optimizer_mlp.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(joint_mlp.parameters(), 2.0)
            optimizer_mlp.step()

        final_loss.append(loss.item())

Element Size Test: 0 and Crossfold Number: 0
Element Size Test: 0 and Crossfold Number: 1
Element Size Test: 0 and Crossfold Number: 2
Element Size Test: 0 and Crossfold Number: 3
Element Size Test: 0 and Crossfold Number: 4
Element Size Test: 0 and Crossfold Number: 5
Element Size Test: 0 and Crossfold Number: 6
Element Size Test: 0 and Crossfold Number: 7
Element Size Test: 0 and Crossfold Number: 8
Element Size Test: 0 and Crossfold Number: 9
Element Size Test: 0 and Crossfold Number: 10
Element Size Test: 0 and Crossfold Number: 11
Element Size Test: 0 and Crossfold Number: 12
Element Size Test: 0 and Crossfold Number: 13
Element Size Test: 0 and Crossfold Number: 14
Element Size Test: 1 and Crossfold Number: 0
Element Size Test: 1 and Crossfold Number: 1
Element Size Test: 1 and Crossfold Number: 2
Element Size Test: 1 and Crossfold Number: 3
Element Size Test: 1 and Crossfold Number: 4
Element Size Test: 1 and Crossfold Number: 5
Element Size Test: 1 and Crossfold Number: 6
Eleme

In [4]:
sizes = [0.25 for x in range(15)]
sizes.extend([0.5 for x in range(15)])
sizes.extend([1 for x in range(15)])
sizes.extend([2 for x in range(15)])
sizes.extend([4 for x in range(15)])

pd.DataFrame({
    "Size": sizes,
    "Loss": final_loss
}).to_csv("DeepIMV_tuning.csv", index = False)

## Define and Run Models

In [3]:
# Define all four models. Predicition dim is the number of categories.
rna_16s_model = simple_FC(input_size = rna_16s.shape[0], hidden_sizes = [16, 64], prediction_dim = 2, dropout = 0.5)
metap_model = simple_FC(input_size = metaproteomics.shape[0], hidden_sizes = [4096, 64], prediction_dim = 2, dropout = 0.5)
metab_pos_model = simple_FC(input_size = metabpos.shape[0], hidden_sizes = [4096, 64], prediction_dim = 2, dropout = 0.5)
metab_neg_model = simple_FC(input_size = metabneg.shape[0], hidden_sizes = [4096, 64], prediction_dim = 2, dropout = 0.5)

In [4]:
# Define joint model
joint_mlp = JointMLP(marginal_models = [rna_16s_model, metap_model, metab_pos_model, metab_neg_model], hidden_dim = 64, hooks=False)

In [5]:
# Optimize parameters 
optimizer_mlp = torch.optim.AdamW(joint_mlp.parameters(), lr=1e-4)

In [6]:
groups = torch.tensor([0,0,0,1,1,1,1,1,1,1,1,1])

# Define data views
views = [torch.tensor(rna_16s.T.values, dtype = torch.float32),
        torch.tensor(metaproteomics.T.values, dtype = torch.float32),
        torch.tensor(metabpos.T.values, dtype = torch.float32),
        torch.tensor(metabneg.T.values, dtype = torch.float32)]

acc_mlp = []

# Time: 20 seconds             
for i in range(1500):

    # Update the mlp
    yhat, h, yhats, hiddens = joint_mlp(*views)

    # pass the predictions and distributions to the loss function and update parameters
    _, _, loss = joint_mlp.loss(groups, yhat, yhats)

    optimizer_mlp.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(joint_mlp.parameters(), 2.0)
    optimizer_mlp.step()

    print(f'Epoch {i+1} loss: {loss.item():.3f}')

Epoch 1 loss: 0.390
Epoch 2 loss: 0.342
Epoch 3 loss: 0.310
Epoch 4 loss: 0.293
Epoch 5 loss: 0.258
Epoch 6 loss: 0.213
Epoch 7 loss: 0.236
Epoch 8 loss: 0.206
Epoch 9 loss: 0.207
Epoch 10 loss: 0.215
Epoch 11 loss: 0.169
Epoch 12 loss: 0.197
Epoch 13 loss: 0.170
Epoch 14 loss: 0.148
Epoch 15 loss: 0.154
Epoch 16 loss: 0.159
Epoch 17 loss: 0.122
Epoch 18 loss: 0.121
Epoch 19 loss: 0.148
Epoch 20 loss: 0.120
Epoch 21 loss: 0.143
Epoch 22 loss: 0.121
Epoch 23 loss: 0.135
Epoch 24 loss: 0.133
Epoch 25 loss: 0.116
Epoch 26 loss: 0.099
Epoch 27 loss: 0.118
Epoch 28 loss: 0.122
Epoch 29 loss: 0.109
Epoch 30 loss: 0.106
Epoch 31 loss: 0.142
Epoch 32 loss: 0.100
Epoch 33 loss: 0.087
Epoch 34 loss: 0.113
Epoch 35 loss: 0.096
Epoch 36 loss: 0.093
Epoch 37 loss: 0.096
Epoch 38 loss: 0.104
Epoch 39 loss: 0.103
Epoch 40 loss: 0.106
Epoch 41 loss: 0.095
Epoch 42 loss: 0.093
Epoch 43 loss: 0.106
Epoch 44 loss: 0.093
Epoch 45 loss: 0.112
Epoch 46 loss: 0.089
Epoch 47 loss: 0.104
Epoch 48 loss: 0.110
E

In [7]:
joint_mlp.eval()

JointMLP(
  (margin_models): ModuleList(
    (0): simple_FC(
      (fc1): Linear(in_features=8, out_features=16, bias=True)
      (fc2): Linear(in_features=16, out_features=64, bias=True)
      (fc_out): Linear(in_features=64, out_features=2, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (1): simple_FC(
      (fc1): Linear(in_features=2726, out_features=4096, bias=True)
      (fc2): Linear(in_features=4096, out_features=64, bias=True)
      (fc_out): Linear(in_features=64, out_features=2, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (2): simple_FC(
      (fc1): Linear(in_features=2800, out_features=4096, bias=True)
      (fc2): Linear(in_features=4096, out_features=64, bias=True)
      (fc_out): Linear(in_features=64, out_features=2, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (3): simple_FC(
      (fc1): Linear(in_features=1752, out_features=4096, bias=True)
      (fc2): Linear(in_features=4096, out_features=64, bias

In [8]:
for name, param in joint_mlp.named_parameters():
    if 'weight' in name:
        print(f"Layer: {name}, Size: {param.data.shape}")

Layer: margin_models.0.fc1.weight, Size: torch.Size([16, 8])
Layer: margin_models.0.fc2.weight, Size: torch.Size([64, 16])
Layer: margin_models.0.fc_out.weight, Size: torch.Size([2, 64])
Layer: margin_models.1.fc1.weight, Size: torch.Size([4096, 2726])
Layer: margin_models.1.fc2.weight, Size: torch.Size([64, 4096])
Layer: margin_models.1.fc_out.weight, Size: torch.Size([2, 64])
Layer: margin_models.2.fc1.weight, Size: torch.Size([4096, 2800])
Layer: margin_models.2.fc2.weight, Size: torch.Size([64, 4096])
Layer: margin_models.2.fc_out.weight, Size: torch.Size([2, 64])
Layer: margin_models.3.fc1.weight, Size: torch.Size([4096, 1752])
Layer: margin_models.3.fc2.weight, Size: torch.Size([64, 4096])
Layer: margin_models.3.fc_out.weight, Size: torch.Size([2, 64])
Layer: fc1.weight, Size: torch.Size([64, 64])
Layer: fc2.weight, Size: torch.Size([2, 64])


In [9]:
# V Matrix (V1)
joint_mlp.margin_models[0].fc1.weight

Parameter containing:
tensor([[ 0.2790,  0.2309, -0.0063, -0.2579,  0.0850, -0.3016,  0.0200, -0.0869],
        [-0.2498, -0.1668, -0.0803, -0.3449,  0.2276,  0.3328,  0.1798,  0.3152],
        [-0.0541,  0.0271, -0.1936,  0.2319,  0.2634, -0.3089,  0.1669,  0.2696],
        [-0.3063, -0.0737, -0.0033, -0.0562, -0.1863,  0.1694, -0.2140,  0.3235],
        [ 0.1182, -0.3317, -0.1386,  0.1005, -0.0682,  0.2948,  0.2535,  0.2093],
        [-0.1675,  0.0162, -0.1980,  0.2635, -0.1848,  0.3976,  0.1379, -0.0448],
        [ 0.1726, -0.0071, -0.3636, -0.1580, -0.2271, -0.0712, -0.0891,  0.0193],
        [ 0.2241,  0.2741,  0.1876,  0.2943, -0.2548,  0.2008, -0.1256, -0.2199],
        [-0.0502, -0.3924,  0.2249, -0.1255, -0.3172,  0.1670,  0.0772,  0.1702],
        [-0.1431, -0.0323, -0.3135,  0.3107, -0.1539,  0.2071, -0.1078,  0.1275],
        [-0.0265, -0.0756,  0.1487,  0.2065, -0.4041,  0.0938,  0.2345,  0.0952],
        [ 0.3062, -0.2431,  0.1785, -0.3100,  0.3237,  0.1039,  0.1919, -0.2

In [10]:
# Weights for Features in V1 (V2)
joint_mlp.margin_models[0].fc2.weight

Parameter containing:
tensor([[ 0.2131, -0.2432, -0.0552,  ...,  0.2356, -0.1717, -0.0754],
        [ 0.2080, -0.2044,  0.1308,  ..., -0.0826,  0.0228, -0.1007],
        [-0.1071,  0.0178, -0.0762,  ..., -0.1366, -0.2381, -0.0738],
        ...,
        [ 0.1545,  0.2002, -0.1474,  ...,  0.1817, -0.1117,  0.2896],
        [ 0.1910,  0.2340, -0.1562,  ...,  0.0482,  0.0609,  0.0120],
        [-0.0745,  0.1786,  0.2316,  ..., -0.1642, -0.1805, -0.2076]],
       requires_grad=True)

In [11]:
# Each V2 weight for its class [0 or 1] (V3)
joint_mlp.margin_models[0].fc_out.weight

Parameter containing:
tensor([[-0.1288, -0.1215, -0.0963,  0.0599,  0.0039,  0.0739, -0.0406,  0.0122,
          0.1709,  0.0703, -0.0570, -0.1277,  0.0548,  0.0993, -0.0809,  0.0239,
         -0.0780,  0.0338, -0.1634, -0.0497,  0.1177,  0.0447, -0.0412, -0.1204,
         -0.1221, -0.0976,  0.0885,  0.0990,  0.0912,  0.0005,  0.0464, -0.1301,
          0.0812, -0.0997, -0.0258,  0.0805, -0.1269,  0.0561, -0.1737, -0.0667,
          0.0629, -0.1432,  0.1169, -0.0914, -0.0677, -0.0933, -0.0600, -0.0720,
          0.0337,  0.0050, -0.0960, -0.1327, -0.0326, -0.0997,  0.1280, -0.0814,
         -0.0429, -0.0659, -0.0628, -0.1549,  0.0058, -0.0214, -0.1393, -0.0646],
        [ 0.0765,  0.0541, -0.0277, -0.1206,  0.1063, -0.0142, -0.0745,  0.1606,
         -0.1578,  0.0716,  0.1374,  0.1087, -0.0935, -0.1224, -0.0281, -0.0026,
          0.1334, -0.0084,  0.0819,  0.0893, -0.0786,  0.0533, -0.0921,  0.0290,
          0.1141, -0.0405, -0.0929, -0.1405, -0.0682,  0.0143,  0.0282, -0.0307,
     

In [12]:
joint_mlp.margin_models[2].fc_out.weight

Parameter containing:
tensor([[ 0.1054,  0.0732, -0.0925, -0.1017,  0.1042, -0.0540, -0.0952, -0.0391,
          0.0752, -0.0404,  0.0474, -0.0726, -0.1169,  0.0741, -0.0197,  0.0486,
          0.0855, -0.0506,  0.0540, -0.0649,  0.0024, -0.0510,  0.0341, -0.0650,
          0.0863,  0.0556, -0.0969,  0.1085,  0.0249, -0.0529, -0.1295,  0.0832,
          0.0008, -0.0110,  0.1011,  0.0792, -0.0973, -0.1034,  0.0873,  0.0810,
         -0.1242, -0.0622, -0.0250, -0.0249,  0.0772,  0.0910,  0.0708, -0.0841,
          0.0697,  0.0870,  0.1200,  0.0933,  0.0778, -0.0024,  0.0644,  0.0394,
         -0.1050, -0.0313,  0.1104,  0.0147,  0.0391, -0.1236, -0.0438, -0.1183],
        [-0.0874,  0.0761,  0.0955,  0.0800, -0.0843, -0.0491, -0.0430, -0.0830,
         -0.0780,  0.0260, -0.0871, -0.0731, -0.0145,  0.0468,  0.1188,  0.0956,
         -0.0449,  0.0463,  0.1070, -0.0012, -0.1100,  0.0723,  0.0003, -0.0532,
          0.0298,  0.0102, -0.0806, -0.0366,  0.0241,  0.0113, -0.0301,  0.0929,
     

In [13]:
import shap

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


In [14]:
import torch.nn as nn

# need to wrap the model in this class to get around some issues with the SHAP package
class JointMLPWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
 
    def forward(self, *datas):
        yhat, _, _, _ = self.model(*datas)
        return yhat
    
joint_model_wrp = JointMLPWrapper(joint_mlp)

# make the explainer object
explainer = shap.DeepExplainer(joint_model_wrp, views)

In [15]:
shap_values = explainer.shap_values(views, check_additivity=False)

In [16]:
# First index is which class (00wk or post-00wk). The second index is which view. The third index is the shapley values by samples per feature.
# Shapley values were summed across samples. 

shap_vals = pd.concat([
    pd.DataFrame({
        "Shapley Value": pd.DataFrame(shap_values[0][0]).sum().tolist(),
        "Feature": pd.read_csv("../../Dataset/Model_Ready/16S_Edata.csv")["Feature"].tolist(),
        "View": ["16S" for x in range(len(rna_16s))]
    }),
    pd.DataFrame({
        "Shapley Value": pd.DataFrame(shap_values[0][1]).sum().tolist(),
        "Feature": pd.read_csv("../../Dataset/Model_Ready/Metaproteomics.csv")["Feature"].tolist(),
        "View": ["metaproteomics" for x in range(len(metaproteomics))]
    }),
    pd.DataFrame({
        "Shapley Value": pd.DataFrame(shap_values[0][2]).sum().tolist(),
        "Feature": pd.read_csv("../../Dataset/Model_Ready/Metabolomics_Positive.csv")["Feature"].tolist(),
        "View": ["metabolomics positive" for x in range(len(metabpos))]
    }),
    pd.DataFrame({
        "Shapley Value": pd.DataFrame(shap_values[0][3]).sum().tolist(),
        "Feature": pd.read_csv("../../Dataset/Model_Ready/Metabolomics_Negative.csv")["Feature"].tolist(),
        "View": ["metabolomics negative" for x in range(len(metabneg))]
    }),
]).reset_index(drop = True)

shap_vals.to_csv("DeepIMV_ShapValues.csv", index = False)

shap_vals


Unnamed: 0,Shapley Value,Feature,View
0,-1.502011e-05,Variovorax,16S
1,-3.093563e-05,Sphingopyxis,16S
2,-9.967419e-06,Ensifer,16S
3,-2.631962e-06,Rhodococcus,16S
4,-5.466928e-05,Dyadobacter,16S
...,...,...,...
7281,-7.611933e-06,4.95_213.12411,metabolomics negative
7282,-4.591290e-07,1.17_330.841266,metabolomics negative
7283,-1.815025e-06,7.15_303.123453,metabolomics negative
7284,-1.290065e-05,5.78_226.035199,metabolomics negative
