<a href="https://colab.research.google.com/github/yashlal/Deepfake-Microbiomes/blob/main/architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from newsolver import predict_community_fullnp
import numpy as np
import pandas as pd
import random as rd
from numba import njit
from numba.typed import List
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pickle
import torch.optim as optim
import time
from math import sqrt
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from modules import regenerate_PWMatrix
from scipy.stats import wasserstein_distance as WD

train_size, test_size = 1000, 25

# Keep only species (columns) that are non-zero in at least this many samples (rows).
MIN_NONZERO_SAMPLES = 85

data = pd.read_excel('RealData.xlsx', index_col=0)
specs = data.columns.tolist()

# Per column: number of rows whose value is truthy (non-zero / non-empty).
# The counts are position-aligned with `specs`, so duplicate column labels are handled safely.
nonzero_counts = data.astype(bool).sum(axis=0)
trimmed_specs = [
    spec
    for spec, count in zip(specs, nonzero_counts)
    if count >= MIN_NONZERO_SAMPLES
]
dim1 = len(trimmed_specs)

# Copy into a numba typed List; this is what gets passed to the @njit
# `generate_matrix` (presumably to avoid numba's reflected-list path — confirm).
typed_trimmed_specs = List()
for spec in trimmed_specs:
    typed_trimmed_specs.append(spec)

@njit()
def get_LT(full_ar):
    """Flatten the strictly lower triangle of ``full_ar`` into a list.

    Entries are emitted row by row: row ``r`` contributes columns ``0 .. r-1``,
    so an ``n x n`` input yields ``n * (n - 1) // 2`` values. The diagonal and
    the upper triangle are skipped. Assumes each row has at least as many
    columns as its row index.
    """
    lower = []
    n_rows = len(full_ar)
    # Row 0 has no entries below the diagonal, so start at row 1.
    for r in range(1, n_rows):
        row = full_ar[r]
        for c in range(r):
            lower.append(row[c])
    return lower

@njit()
def generate_matrix(comm, tolerance):
    """Build a random ``n x n`` pairwise matrix, where ``n = len(comm)``.

    For every unordered pair ``(row, col)`` with ``col < row``, a fair coin flip
    (``rd.random() < 0.5``) decides the orientation. Either ``[row, col] = 0.99``
    and ``[col, row] = 0.01``, or the reverse. The diagonal stays 0.

    Only the length of ``comm`` is used. ``tolerance`` is accepted but unused.
    """
    n = len(comm)
    mat = np.zeros((n, n))
    # The diagonal is already 0 from np.zeros; visit each off-diagonal pair once,
    # in the same (row-major, lower-triangle) order so the RNG draws line up.
    for row in range(n):
        for col in range(row):
            if rd.random() < 0.5:
                below, above = 0.99, 0.01
            else:
                below, above = 0.01, 0.99
            mat[row, col] = below
            mat[col, row] = above
    return mat

def datagen():
    """Draw one synthetic training pair.

    Returns:
        A tuple ``(community, lower_triangle)``:
        - ``community`` is the result of ``predict_community_fullnp`` on a freshly
          randomised pairwise matrix over ``trimmed_specs``.
        - ``lower_triangle`` is that matrix's strictly-lower-triangular entries,
          as produced by ``get_LT``.
    """
    interactions = generate_matrix(typed_trimmed_specs, 0)
    community = predict_community_fullnp(interactions, trimmed_specs, verb=False)
    # Flatten after simulating, matching the original call order.
    lower_triangle = get_LT(interactions)
    return community, lower_triangle

# Select the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == 'cuda':
    print('CUDA device selected!')
else:
    print('CUDA device not available. CPU selected')


class MyNet(nn.Module):
    """Two-layer MLP mapping a community vector to pairwise-interaction entries.

    In this notebook the input is the community vector produced by
    ``predict_community_fullnp``. The output has one value per strictly-lower-
    triangular cell of an ``n_species x n_species`` matrix, i.e.
    ``n_species * (n_species - 1) // 2`` values. This is the same layout that
    ``get_LT`` produces for the training targets.

    Args:
        hyperparam: width of the single hidden layer.
        n_species: number of species. Defaults to 462, the size previously
            hard-coded here (462 * 461 // 2 == 231 * 461).
    """

    def __init__(self, hyperparam, n_species=462):
        super().__init__()
        n_pairs = n_species * (n_species - 1) // 2
        self.fc1 = nn.Linear(n_species, hyperparam)
        self.fc2 = nn.Linear(hyperparam, n_pairs)

    def forward(self, x):
        """Return the raw (no output activation) interaction predictions for ``x``."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Held-out evaluation set: community vectors (inputs) and the lower-triangle
# interaction entries (targets) that generated them, already on `device`.
mytest_x, mytest_y = [], []

pbar1 = tqdm(range(test_size))
pbar1.set_description('Generating Test Set')
for _ in pbar1:
    community, lower_tri = datagen()
    x_tensor = torch.from_numpy(community).float().to(device)
    y_tensor = torch.FloatTensor(lower_tri).to(device)
    mytest_x.append(x_tensor)
    mytest_y.append(y_tensor)

def test_net_comm(model, test_x, test_size):
    """Evaluate ``model`` by re-simulating communities from its predictions.

    For each of the first ``test_size`` community tensors in ``test_x``:
    1. The network predicts lower-triangle interaction entries.
    2. ``regenerate_PWMatrix(..., 462)`` turns them back into a matrix (presumably
       the full 462x462 pairwise matrix — confirm against ``modules``).
    3. That matrix is passed to ``predict_community_fullnp``.

    Prints the Wasserstein distance between the predicted and real community,
    followed by both vectors.
    """
    pbar2 = tqdm(range(test_size))
    pbar2.set_description('Testing Neural Net')
    # Inference only: skip building autograd graphs for every prediction.
    with torch.no_grad():
        # Iterate the progress bar itself so it actually advances.
        for i in pbar2:
            cm_real = test_x[i]
            output = model(cm_real).detach().cpu().tolist()
            mat_y = np.array(regenerate_PWMatrix(output, 462))
            cm_pred = predict_community_fullnp(mat_y, trimmed_specs)
            print(f'Test {i}: WD Distance {WD(cm_pred, cm_real.tolist())}')
            print(cm_real)
            print(cm_pred)

def train_net(model, train_size, opt=None, loss_fn=None):
    """Train ``model`` online on ``train_size`` freshly generated samples.

    Each step draws one new ``(community, lower_triangle)`` pair from
    ``datagen``, runs a forward pass, and applies one optimizer update. Samples
    are never reused.

    Args:
        model: network mapping a community vector to lower-triangle entries.
        train_size: number of update steps, which equals the number of samples drawn.
        opt: optimizer to step. Defaults to the module-level ``optimizer``
            (created in the ``__main__`` block).
        loss_fn: loss criterion. Defaults to the module-level ``criterion``.
            The printed value assumes a sum-reduced MSE (as configured in
            ``__main__``), so ``sqrt(loss / numel)`` is the per-entry RMSE.
    """
    if opt is None:
        opt = optimizer
    if loss_fn is None:
        loss_fn = criterion

    pbar3 = tqdm(range(train_size))
    pbar3.set_description('Training Neural Net')
    for i in pbar3:
        opt.zero_grad()
        x, y = datagen()
        inputs = torch.from_numpy(x).float().to(device)
        true_y = torch.FloatTensor(y).to(device)
        # The model already lives on `device`, so its output (and the loss) do too.
        output = model(inputs)
        loss = loss_fn(output, true_y)
        # Sum-reduced loss / number of entries = mean squared error, so s is an RMSE.
        s = sqrt(loss.item() / output.numel())
        # NOTE: each iteration is a single-sample step, not a full epoch; the label is
        # kept unchanged for log compatibility.
        print(f'Epoch {i}: Loss {s}')
        loss.backward()
        opt.step()

if __name__ == '__main__':
    net = MyNet(500).to(device)

    # Multi-GPU support: wrap the model when more than one GPU is visible.
    # NOTE(review): nn.DataParallel scatters inputs along dim 0, and train_net /
    # test_net_comm feed single unbatched vectors — confirm the >1 GPU path.
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        print(f'Using {n_gpus} GPUs')
        net = nn.DataParallel(net)
    elif n_gpus == 1:
        print(f'Using {n_gpus} GPU')

    # Sum-reduced MSE; train_net turns it into a per-entry RMSE for logging.
    criterion = nn.MSELoss(reduction='sum')
    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    train_net(net, train_size=train_size)
    test_net_comm(net, mytest_x, test_size=test_size)

CUDA device selected!


HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))


Using 1 GPU


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Epoch 0: Loss 0.5783359555030881
Epoch 1: Loss 0.5771655138605322
Epoch 2: Loss 0.576429368526639
Epoch 3: Loss 0.5760899947292261
Epoch 4: Loss 0.5766614013329548
Epoch 5: Loss 0.5764677395946479
Epoch 6: Loss 0.5755941808779508
Epoch 7: Loss 0.5744500171149578
Epoch 8: Loss 0.5740881645945797
Epoch 9: Loss 0.5732401275596033
Epoch 10: Loss 0.5728895163751896
Epoch 11: Loss 0.5726164319244491
Epoch 12: Loss 0.5718633019684557
Epoch 13: Loss 0.5718589722434011
Epoch 14: Loss 0.5712317779565426
Epoch 15: Loss 0.5692217877280734
Epoch 16: Loss 0.5706428177593823
Epoch 17: Loss 0.5694703487506009
Epoch 18: Loss 0.5675778957481298
Epoch 19: Loss 0.5689501664574785
Epoch 20: Loss 0.5665519014923842
Epoch 21: Loss 0.5662203091445586
Epoch 22: Loss 0.5640077898613242
Epoch 23: Loss 0.5647916591577152
Epoch 24: Loss 0.5638721709343691
Epoch 25: Loss 0.5632257929207558
Epoch 26: Loss 0.5638223382331718
Epoch 27: Loss 0.5604060393479505
Epoch 28: Loss 0.5617543332424746
Epoch 29: Loss 0.56027835

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

Test 0: WD Distance 0.0016623473355560896
tensor([7.0293e-03, 1.3420e-10, 5.6086e-03, 2.2052e-03, 4.8137e-03, 6.8048e-03,
        2.3457e-03, 5.4238e-03, 2.4142e-03, 1.4660e-09, 4.7089e-05, 3.4366e-05,
        1.6404e-03, 1.1305e-03, 8.0148e-12, 4.4286e-03, 1.0161e-04, 7.7640e-03,
        5.4769e-09, 4.4954e-03, 1.9802e-04, 3.5651e-03, 6.3024e-07, 2.6603e-03,
        8.7708e-05, 1.1659e-03, 3.1354e-03, 7.2004e-04, 5.4799e-07, 4.9041e-09,
        2.6650e-03, 5.5177e-03, 4.0611e-04, 6.8677e-08, 6.3740e-04, 2.0643e-08,
        4.9608e-09, 6.2817e-03, 2.7845e-07, 2.2442e-06, 2.7270e-03, 2.9437e-03,
        8.3038e-04, 9.7701e-05, 5.3444e-03, 2.4969e-03, 3.6861e-08, 2.2191e-09,
        1.3216e-03, 1.8338e-04, 2.4443e-11, 1.3353e-04, 1.3397e-09, 5.0086e-06,
        2.3142e-03, 4.2917e-04, 2.3170e-03, 4.8978e-04, 4.5240e-08, 4.9238e-03,
        5.4143e-09, 1.6037e-03, 5.5925e-03, 2.4443e-03, 6.6724e-03, 4.7396e-03,
        6.8501e-06, 1.5804e-03, 3.7542e-03, 8.1224e-03, 9.1577e-12, 3.7598e-03