In [None]:
# !pip install wandb --upgrade

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# import libraries 
import pandas as pd 
import numpy as np 
import scipy.stats
from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from tqdm import trange

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import random
import math

from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
import matplotlib.pyplot as plt
import wandb

In [None]:
# CODICE DA METTERE PRIMA DELLA NET


# subtype mapping
# 0 -> healthy
# 1 -> tumor
dict_label = {'TCGA-LUAD0' : 0,
              'TCGA-LUAD1' : 1,
              'TCGA-LUSC0' : 2,
              'TCGA-LUSC1' : 3}

# dict_label = {'TCGA-LUAD0' : 0,
#               'TCGA-LUAD1' : 1,
#               'TCGA-LUSC0' : 2,
#               'TCGA-LUSC1' : 3,
#               'TCGA-CPTAC0' : 4,
#               'TCGA-CPTAC1' : 5,
#               'TCGA-BRCA0' : 6,
#               'TCGA-BRCA1' : 7,
#               'TCGA-KIRC0' : 8,
#               'TCGA-KIRC1' : 9
#               }

# input files
file_meth = "/content/drive/MyDrive/University/Bioinformatics/dataset/meth_feat_sel_dataset.csv"
file_mRNA = "/content/drive/MyDrive/University/Bioinformatics/dataset/mRNA_feat_sel_dataset.csv"
file_miRNA = "/content/drive/MyDrive/University/Bioinformatics/dataset/miRNA_feat_sel_dataset.csv"

# leggo in un DF i 3 csv feat_sel
df_meth = pd.read_csv(file_meth)
df_mRNA = pd.read_csv(file_mRNA)
df_miRNA = pd.read_csv(file_miRNA)

# number of features (da passare alla rete)
# -2 perchè non consideriamo il case_id e il subtype
meth_dim = df_meth.shape[1] - 2
mRNA_dim = df_mRNA.shape[1] - 2
miRNA_dim = df_miRNA.shape[1] - 2 

# join su case_id 
dataset = reduce(lambda left,right: pd.merge(left,right,on=['case_id','subtype']), [df_miRNA,df_mRNA,df_meth] )
#dataset = reduce(lambda left,right: pd.merge(left,right,on=['case_id']), [df_miRNA,df_mRNA,df_meth] )

# just a check
print(dataset.shape)
print(meth_dim + miRNA_dim + mRNA_dim + 2)

# train_test_split stratified su colonna subtype
# we are done!

(1121, 6741)
6741


In [None]:
# 1. Start a new run
wandb.init(project='bioinfo', entity='latte')


# # 2. Save model inputs and hyperparameters
# config = wandb.config

In [None]:
def seed_everything(seed): 
    '''
    seed_everything(seed)
        Ensure deterministic behaviour. Use a seed that will be used to seed pytroch, numpy and python right at the start of main process.

    Parameters
    ----------
    seed : int
        the seed
    
    See also: 
    ----------
        ref: https://discuss.pytorch.org/t/reproducibility-with-all-the-bells-and-whistles/81097
        ref: https://pytorch.org/docs/stable/notes/randomness.html
    '''
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # if using 'cuda'
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled=False
    torch.backends.cudnn.deterministic=True

# set random seed
SEED = 769
seed_everything(SEED)

In [None]:
# other params
TRAIN_CLASSIFIER=False
NORMALIZATION=False

In [None]:
# Net (only classifier)
class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd
    def forward(self, x):
        return self.lambd(x)

class Encoder (nn.Module):
    def __init__(self,meth_dim,mRNA_dim,miRNA_dim):
        super().__init__()
        self.meth_dim,self.mRNA_dim,self.miRNA_dim = meth_dim,mRNA_dim,miRNA_dim
        self.fc1_meth = nn.Linear(meth_dim, 25)
        self.m_meth = nn.BatchNorm1d(25)

        self.fc1_mRNA = nn.Linear(mRNA_dim, 25)
        self.m_mRNA = nn.BatchNorm1d(25)

        self.fc1_miRNA = nn.Linear(miRNA_dim, 25)
        self.m_miRNA = nn.BatchNorm1d(25)

        self.fc2 = nn.Linear(75, 75)
        self.m = nn.BatchNorm1d(75)
        self.fcm = nn.Linear(75, 75)
        self.fcstd = nn.Linear(75, 75)
        self.gl = nn.GELU()

    def forward(self, x): 
        def sampling(args):
            z_mean, z_log_var = args
            return z_mean + torch.exp(0.5 * z_log_var) * torch.normal(z_mean.size(), seed=0)
        #recupero dati input
        x_m = torch.tensor(x[:,:self.meth_dim])
        x_R = torch.tensor(x[:,self.meth_dim:self.meth_dim+self.mRNA_dim])
        x_p = torch.tensor(x[:,self.meth_dim+self.mRNA_dim:self.meth_dim+self.mRNA_dim+self.miRNA_dim])

        #independent layers for each omic
        x_meth = self.gl(self.m_meth(self.fc1_meth(x_m)))
        x_mRNA = self.gl(self.m_mRNA(self.fc1_mRNA(x_R)))
        x_prot = self.gl(self.m_miRNA(self.fc1_miRNA(x_p)))
        
        #concatenazione dei vari output da processare fino allo shared layer
        x = torch.cat([x_meth,x_mRNA,x_prot],axis = 1 )
        x = self.gl(self.m(self.fc2(x)))
        mean = self.fcm(x)
        std = self.fcstd(x)
        #z = LambdaLayer(sampling, [mean, var])
        return x

### GAN
# Discriminator - D
class Decoder(nn.Module):
    def __init__(self,meth_dim,mRNA_dim,miRNA_dim):
        super().__init__()
        self.dec = nn.Sequential(
                      nn.Linear(75, 75),
                      nn.GELU(),
                      nn.Linear(75,meth_dim+mRNA_dim+miRNA_dim)          
                      )
        
    def forward(self, x):
        return self.dec(x)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.disc = nn.Sequential(
                nn.Linear(75, 75),
                nn.GELU(),
                nn.Linear(75,1),
                nn.Sigmoid()) #because we will normalize inputs. 
    def forward(self, x):
        return self.disc(x)

        # Generator - G
class Discriminator_CL(nn.Module):
    def __init__(self):
        super().__init__()
        self.disc = nn.Sequential(
                nn.Linear(75, 75),
                nn.GELU(),
                nn.Linear(75,5),
                nn.Sigmoid() )
    def forward(self, x):
        return self.disc(x)

In [None]:
X = np.array(dataset.drop(['subtype','case_id'],axis=1).values)
Y = dataset.subtype.values

In [None]:
def toone(x):
  arr = np.zeros(5)
  arr[int(x)]=1
  return arr

In [None]:
# split dataset (non serve per la GAN)
X_train, X_test, Y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=SEED, stratify = Y)

In [None]:
sweep_config = {
    'method': 'random'
    }

metric = {
    'name': 'accuracy',
    'goal': 'maximize'   
    }

sweep_config['metric'] = metric

In [None]:
parameters_dict = {
    'optimizer': {
        'values': ['adam'] #, 'sgd']
        },
    'learning_rate': {
        # a flat distribution between 0 and 0.1
        'distribution': 'uniform',
        'min': 1e-5,
        'max': 1e-3
      },
    'batch_size': {
        'values': [32,64]
      }
    }
sweep_config['parameters'] = parameters_dict

In [None]:
import pprint

pprint.pprint(sweep_config)

{'method': 'random',
 'metric': {'goal': 'maximize', 'name': 'accuracy'},
 'parameters': {'batch_size': {'values': [32, 64]},
                'learning_rate': {'distribution': 'uniform',
                                  'max': 0.001,
                                  'min': 1e-05},
                'optimizer': {'values': ['adam']}}}


In [None]:
sweep_id = wandb.sweep(sweep_config, project="bioinfo")



Create sweep with ID: fdq4fe7f
Sweep URL: https://wandb.ai/latte/bioinfo/sweeps/fdq4fe7f


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(config=None):
    # Initialize a new wandb run
    with wandb.init(config=config):
      # If called by wandb.agent, as below,
      # this config will be set by Sweep Controller
      config = wandb.config

      # TRAINING 8=====D  
      encoder = Encoder(meth_dim,mRNA_dim,miRNA_dim).to(device)
      decoder = Decoder(meth_dim,mRNA_dim,miRNA_dim).to(device)
      discriminator = Discriminator().to(device)
      discriminator_cl = Discriminator_CL().to(device)
      #noise = torch.randn((config.batch_size,in_dimension)).to(device)

      #transforms = transforms.Compose(
      #    [transforms.ToTensor() ,transforms.Normalize((mean,),(stdev,))])

      # optimizer
      if config.optimizer == "sgd":
        optim_disc = optim.SGD(discriminator.parameters(),lr = config.learning_rate,momentum=0.9)
        optim_disc_cl = optim.SGD(discriminator_cl.parameters(),lr = config.learning_rate,momentum=0.9)
        optim_gen = optim.SGD(encoder.parameters(),lr = config.learning_rate,momentum=0.9)
        optim_dec = optim.SGD(decoder.parameters(),lr = config.learning_rate,momentum=0.9)
      elif config.optimizer == "adam":
        optim_disc = optim.Adam(discriminator.parameters(),lr = config.learning_rate)
        optim_disc_cl = optim.Adam(discriminator_cl.parameters(),lr = config.learning_rate)
        optim_gen = optim.Adam(encoder.parameters(),lr = config.learning_rate)
        optim_dec = optim.Adam(decoder.parameters(),lr = config.learning_rate)

      bce = nn.BCELoss()
      mse = nn.MSELoss()
      cle = nn.CrossEntropyLoss()

      # SummaryWriter -> log tensorboard
      # writer_fake = SummaryWriter(f"runs/GAN_bio/fake")
      # writer_real = SummaryWriter(f"runs/GAN_bio/real")

      step = 0
      num_epochs = 50
      epochs = trange(num_epochs)
      lossesD, lossesCL, lossesG, lossesDec = [],[],[],[]
      accuracies_test = []

      # TRAINING GAN
      for epoch in epochs:
          # batch
          permutation = torch.randperm(X_train.shape[0])
          lossD, lossG = 0,0

          # compute metrics curr_epoch -> mean over all batches
          lossD_curr_epoch = list()
          lossE_curr_epoch = list()
          lossDec_curr_epoch = list()
          loss_discl_curr_epoch = list()
          accuracy_curr_epoch = list()

          # train a Batch (dim BS)
          for i in range(0,X_train.shape[0], config.batch_size):
              
              # extract X,Y (dimension config.batch_size)
              indices = permutation[i:i+config.batch_size]
              batch_x, batch_y = torch.FloatTensor(X_train)[indices], torch.LongTensor(Y_train)[indices]
              # generate fake img

              latent_real = torch.Tensor(np.random.normal(size=(batch_x.shape[0], 75)))

              fake = encoder(batch_x)
              
              # train discriminator max - LOG(D(real)) + log(1-D(G(z))) opt when D(real)==1 and D(G(z))==0
              # real img
              disc_real = discriminator(latent_real).view(-1)
              lossD_real = bce(disc_real,torch.ones_like(disc_real)) #max LOG(D(real)) part of the loss
              # fake img
              disc_fake = discriminator(fake).view(-1)  
              lossD_fake = bce(disc_fake,torch.zeros_like(disc_real))
              # final disc loss
              lossD = (lossD_real + lossD_fake) / 2
              discriminator.zero_grad()
              #lossD.backward(retain_graph=True) 
              
              # train generator min log (1-D(G(z))) opt when D(G(z))==1 ----> max log(D(G(z)))
              output = discriminator(fake).view(-1)
              # we use ones because of the structure of BCELoss ylog(D(G(z))) + (1-y)log(1-D(G(z)))
              lossE = bce(output, torch.ones_like(output)) 
              encoder.zero_grad()
            # lossE.backward()

              # reconstruction loss
              decoded = decoder(fake)
              lossDec = mse(decoded,batch_x)
              decoder.zero_grad()

              #classification loss 
              result = discriminator_cl(fake)
              loss_discl = cle(result,batch_y)
              discriminator_cl.zero_grad()


              lossf = lossD+lossE+lossDec+100000*loss_discl
              lossf.backward()
              optim_dec.step()
              optim_gen.step()
              optim_disc.step()
              optim_disc_cl.step()
              winners = result.argmax(dim=1)
              corrects = (winners == batch_y)
              accuracy = corrects.sum().float() / float( batch_y.size(0) )

              lat_test = encoder(torch.FloatTensor(X_test))
              win_test = discriminator_cl(lat_test)
              
              winners_test = win_test.argmax(dim=1)
              corrects_test = (winners_test == torch.Tensor(y_test))
              accuracy_test = corrects_test.sum().float() / float( torch.Tensor(y_test).size(0) )


              # update metrics curr_epoch
              lossD_curr_epoch.append(lossD.item())
              lossE_curr_epoch.append(lossE.item())
              lossDec_curr_epoch.append(lossDec.item())
              loss_discl_curr_epoch.append(loss_discl.item())
              accuracy_curr_epoch.append(accuracy_test.item())

              # lossD --> GAN fake/norm
              # lossE --> GAN
              # lossDec --> reconstruction loss (autoenc)
              # lossCL --> classificatore finale
              epochs.set_description("lossD %.2f lossE %.2f lossDec %.2f lossCL %.2f accuracy %.2f" %(lossD,lossE,lossDec,loss_discl,accuracy_test))
          
          # logga tutto -> metrics curr_epoch
          wandb.log({"lossD": np.mean(np.array(lossD_curr_epoch)),
                    "lossE": np.mean(np.array(lossE_curr_epoch)),
                    "lossDec": np.mean(np.array(lossDec_curr_epoch)),
                    "loss_discl": np.mean(np.array(loss_discl_curr_epoch)),
                    "accuracy": np.mean(np.array(accuracy_curr_epoch)) })
          
          # append loss computed on the last batch_size
          lossesD.append(lossD)
          lossesG.append(lossG)  
          lossesDec.append(lossDec)
          accuracies_test.append(accuracy_test)


In [None]:
# Save the model in the exchangeable ONNX format
#for model,name in zip([encoder, decoder, discriminator, discriminator_cl],['encoder','decoder','disciminator','discriminator_clf']):

# filename = f"encoder.onnx"
# torch.onnx.export(encoder, batch_x, filename)
# wandb.save(filename)

In [None]:
wandb.agent(sweep_id, train, count=10)

[34m[1mwandb[0m: Agent Starting Run: eua6as7o with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	learning_rate: 6.133991228763558e-05
[34m[1mwandb[0m: 	optimizer: adam


lossD 0.70 lossE 0.42 lossDec 0.45 lossCL 0.92 accuracy 0.98: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
lossD,0.70121
lossE,0.40933
lossDec,0.53516
loss_discl,0.92093
accuracy,0.97995
_runtime,54.0
_timestamp,1630766813.0
_step,49.0


0,1
lossD,▁▂▃▃▅▆▆▇▇▇███████▇▇▇▇▆▆▆▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂
lossE,█▇▆▆▄▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossDec,███▇▇▆▅▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_discl,█▇▇▇▆▆▆▅▅▄▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy,▁▅▆▇████████████████████████████████████
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: drt81fsv with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	learning_rate: 0.00043102185465157625
[34m[1mwandb[0m: 	optimizer: adam
[34m[1mwandb[0m: Currently logged in as: [33mlatte[0m (use `wandb login --relogin` to force relogin)


lossD 0.59 lossE 0.37 lossDec 0.58 lossCL 0.97 accuracy 0.99: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
lossD,0.55597
lossE,0.40478
lossDec,0.4241
loss_discl,0.9081
accuracy,0.99437
_runtime,55.0
_timestamp,1630766873.0
_step,49.0


0,1
lossD,▆██▇▅▅▄▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossE,█▂▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossDec,█▆▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_discl,█▆▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy,▁▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████▇█▇█████████▇████
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: og700q11 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	learning_rate: 0.00039739539668902797
[34m[1mwandb[0m: 	optimizer: adam


lossD 0.54 lossE 0.42 lossDec 0.41 lossCL 0.91 accuracy 0.99: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
lossD,0.55687
lossE,0.40578
lossDec,0.43669
loss_discl,0.9063
accuracy,0.9839
_runtime,53.0
_timestamp,1630766933.0
_step,49.0


0,1
lossD,▆███▆▆▅▅▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossE,█▄▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossDec,█▆▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_discl,█▆▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy,▁▇▇█████████████████████████████████████
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: b265yo0m with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	learning_rate: 0.0008340250309291658
[34m[1mwandb[0m: 	optimizer: adam


lossD 0.55 lossE 0.41 lossDec 0.48 lossCL 0.91 accuracy 0.98: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
lossD,0.55527
lossE,0.4039
lossDec,0.42561
loss_discl,0.90595
accuracy,0.9857
_runtime,56.0
_timestamp,1630766995.0
_step,49.0


0,1
lossD,▇█▆▆▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossE,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossDec,█▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_discl,█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy,▁▆▇▇▇▇▇▇▇▇▇▇████████████▇▇██████████████
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: fofpnq3n with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	learning_rate: 0.0008735326720281971
[34m[1mwandb[0m: 	optimizer: adam


lossD 0.55 lossE 0.41 lossDec 0.35 lossCL 0.90 accuracy 0.99: 100%|██████████| 50/50 [00:51<00:00,  1.04s/it]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
lossD,0.55239
lossE,0.4046
lossDec,0.39262
loss_discl,0.90562
accuracy,0.98378
_runtime,55.0
_timestamp,1630767056.0
_step,49.0


0,1
lossD,██▆▅▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossE,█▂▂▂▁▁▁▁▁▁▁▁▂▁▂▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁
lossDec,█▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▂▂▂▁▁▁▁▁▁
loss_discl,█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy,▁▄▄▄▅▅▅▅▆▇▆▆▇▇▇▇▇▇▇▇▇▇▆▆▇▇▆▇▇█▇▆▇▇▇▇▇▇▇▆
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: x8opztr1 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	learning_rate: 0.0005008865024966185
[34m[1mwandb[0m: 	optimizer: adam


lossD 0.57 lossE 0.40 lossDec 0.44 lossCL 0.91 accuracy 0.99: 100%|██████████| 50/50 [00:51<00:00,  1.04s/it]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
lossD,0.55442
lossE,0.40641
lossDec,0.43249
loss_discl,0.90627
accuracy,0.9893
_runtime,57.0
_timestamp,1630767122.0
_step,49.0


0,1
lossD,▆██▇▅▅▄▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossE,█▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossDec,█▅▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_discl,█▆▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy,▁▇▇▇▇▇▇▇████████████████████████████████
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: ltj6c7t3 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	learning_rate: 0.0009953839604422514
[34m[1mwandb[0m: 	optimizer: adam


lossD 0.56 lossE 0.40 lossDec 0.41 lossCL 0.90 accuracy 0.99: 100%|██████████| 50/50 [00:51<00:00,  1.03s/it]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
lossD,0.5522
lossE,0.40489
lossDec,0.39447
loss_discl,0.90552
accuracy,0.98919
_runtime,58.0
_timestamp,1630767187.0
_step,49.0


0,1
lossD,▇█▇▅▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossE,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossDec,█▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▂▂▂▁▁▂▁▁▁▁▁▁
loss_discl,█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy,▁▆▆▇▆▆▆▆▇▇▇▇██▇▇▇▇██▇▇▇▇▇██▇▇▇█▇▇▇█▇▇▇██
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: d6f3h1xn with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	learning_rate: 0.00033643109048153144
[34m[1mwandb[0m: 	optimizer: adam


lossD 0.57 lossE 0.41 lossDec 0.47 lossCL 0.91 accuracy 0.99: 100%|██████████| 50/50 [00:28<00:00,  1.75it/s]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
lossD,0.57042
lossE,0.40667
lossDec,0.48516
loss_discl,0.90705
accuracy,0.98671
_runtime,34.0
_timestamp,1630767228.0
_step,49.0


0,1
lossD,▅▅▇▇██▇▇▆▆▆▅▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
lossE,█▆▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossDec,██▇▅▃▃▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_discl,█▇▆▆▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy,▁▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████████████████████
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: 78mf83sh with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	learning_rate: 0.0009439739986132525
[34m[1mwandb[0m: 	optimizer: adam


lossD 0.55 lossE 0.41 lossDec 0.45 lossCL 0.91 accuracy 0.99: 100%|██████████| 50/50 [00:28<00:00,  1.73it/s]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
lossD,0.55321
lossE,0.4072
lossDec,0.40842
loss_discl,0.90583
accuracy,0.98468
_runtime,34.0
_timestamp,1630767271.0
_step,49.0


0,1
lossD,▅██▇▅▅▄▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossE,█▂▁▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁
lossDec,█▆▄▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_discl,█▆▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy,▁▅▆▆▆▆▆▆▆▆▇▇▇██▇▇▇▇██████████▇▇▇▇▇▇█▇██▇
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: mgj6kwrk with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	learning_rate: 0.0004313933415651649
[34m[1mwandb[0m: 	optimizer: adam


lossD 0.56 lossE 0.41 lossDec 0.50 lossCL 0.91 accuracy 0.99: 100%|██████████| 50/50 [00:28<00:00,  1.74it/s]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
lossD,0.5631
lossE,0.40591
lossDec,0.47016
loss_discl,0.90755
accuracy,0.98806
_runtime,34.0
_timestamp,1630767313.0
_step,49.0


0,1
lossD,▆▇██▇▇▇▆▅▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁
lossE,█▄▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lossDec,█▇▆▅▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_discl,█▇▆▅▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy,▁▅▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████████████
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


In [None]:
### GAN
config = wandb.config

# BEST RUN HYPERPARAMS
hyperparameters = {
              'epochs' : 100,
              'batch_size': 32,
              'learning_rate': 0.000431,
              'optimizer': 'adam'}
              
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# tell wandb to get started
with wandb.init(project="bioinfo", config=hyperparameters):
  # access all HPs through wandb.config, so logging matches execution!
  config = wandb.config

  encoder = Encoder(meth_dim,mRNA_dim,miRNA_dim).to(device)
  decoder = Decoder(meth_dim,mRNA_dim,miRNA_dim).to(device)
  discriminator = Discriminator().to(device)
  discriminator_cl = Discriminator_CL().to(device)
  #noise = torch.randn((config.batch_size,in_dimension)).to(device)

  #transforms = transforms.Compose(
  #    [transforms.ToTensor() ,transforms.Normalize((mean,),(stdev,))])

  # optimizer
  if config.optimizer == "sgd":
    optim_disc = optim.SGD(discriminator.parameters(),lr = config.learning_rate,momentum=0.9)
    optim_disc_cl = optim.SGD(discriminator_cl.parameters(),lr = config.learning_rate,momentum=0.9)
    optim_gen = optim.SGD(encoder.parameters(),lr = config.learning_rate,momentum=0.9)
    optim_dec = optim.SGD(decoder.parameters(),lr = config.learning_rate,momentum=0.9)
  elif config.optimizer == "adam":
    optim_disc = optim.Adam(discriminator.parameters(),lr = config.learning_rate)
    optim_disc_cl = optim.Adam(discriminator_cl.parameters(),lr = config.learning_rate)
    optim_gen = optim.Adam(encoder.parameters(),lr = config.learning_rate)
    optim_dec = optim.Adam(decoder.parameters(),lr = config.learning_rate)

  bce = nn.BCELoss()
  mse = nn.MSELoss()
  cle = nn.CrossEntropyLoss()

  # SummaryWriter -> log tensorboard
  # writer_fake = SummaryWriter(f"runs/GAN_bio/fake")
  # writer_real = SummaryWriter(f"runs/GAN_bio/real")

  step = 0
  epochs = trange(config.epochs)
  lossesD, lossesCL, lossesG, lossesDec = [],[],[],[]
  accuracies_test = []

  # TRAINING GAN
  for epoch in epochs:
      # batch
      permutation = torch.randperm(X_train.shape[0])
      lossD, lossG = 0,0

      # compute metrics curr_epoch -> mean over all batches
      lossD_curr_epoch = list()
      lossE_curr_epoch = list()
      lossDec_curr_epoch = list()
      loss_discl_curr_epoch = list()
      accuracy_curr_epoch = list()

      # train a Batch (dim BS)
      for i in range(0,X_train.shape[0], config.batch_size):
          
          # extract X,Y (dimension config.batch_size)
          indices = permutation[i:i+config.batch_size]
          batch_x, batch_y = torch.FloatTensor(X_train)[indices], torch.LongTensor(Y_train)[indices]
          # generate fake img

          latent_real = torch.Tensor(np.random.normal(size=(batch_x.shape[0], 75)))

          fake = encoder(batch_x)
          
          # train discriminator max - LOG(D(real)) + log(1-D(G(z))) opt when D(real)==1 and D(G(z))==0
          # real img
          disc_real = discriminator(latent_real).view(-1)
          lossD_real = bce(disc_real,torch.ones_like(disc_real)) #max LOG(D(real)) part of the loss
          # fake img
          disc_fake = discriminator(fake).view(-1)  
          lossD_fake = bce(disc_fake,torch.zeros_like(disc_real))
          # final disc loss
          lossD = (lossD_real + lossD_fake) / 2
          discriminator.zero_grad()
          #lossD.backward(retain_graph=True) 
          
          # train generator min log (1-D(G(z))) opt when D(G(z))==1 ----> max log(D(G(z)))
          output = discriminator(fake).view(-1)
          # we use ones because of the structure of BCELoss ylog(D(G(z))) + (1-y)log(1-D(G(z)))
          lossE = bce(output, torch.ones_like(output)) 
          encoder.zero_grad()
        # lossE.backward()

          # reconstruction loss
          decoded = decoder(fake)
          lossDec = mse(decoded,batch_x)
          decoder.zero_grad()

          #classification loss 
          result = discriminator_cl(fake)
          loss_discl = cle(result,batch_y)
          discriminator_cl.zero_grad()


          lossf = loss_discl+lossDec #lossD+lossE+lossDec+100000*loss_discl
          lossf.backward()
          optim_dec.step()
          # optim_gen.step()
          # optim_disc.step()
          optim_disc_cl.step()
          winners = result.argmax(dim=1)
          corrects = (winners == batch_y)
          accuracy = corrects.sum().float() / float( batch_y.size(0) )

          lat_test = encoder(torch.FloatTensor(X_test))
          win_test = discriminator_cl(lat_test)
          
          winners_test = win_test.argmax(dim=1)
          corrects_test = (winners_test == torch.Tensor(y_test))
          accuracy_test = corrects_test.sum().float() / float( torch.Tensor(y_test).size(0) )


          # update metrics curr_epoch
          lossD_curr_epoch.append(lossD.item())
          lossE_curr_epoch.append(lossE.item())
          lossDec_curr_epoch.append(lossDec.item())
          loss_discl_curr_epoch.append(loss_discl.item())
          accuracy_curr_epoch.append(accuracy_test.item())

          # lossD --> GAN fake/norm
          # lossE --> GAN
          # lossDec --> reconstruction loss (autoenc)
          # lossCL --> classificatore finale
          epochs.set_description("lossD %.2f lossE %.2f lossDec %.2f lossCL %.2f accuracy %.2f" %(lossD,lossE,lossDec,loss_discl,accuracy_test))
      
      # logga tutto -> metrics curr_epoch
      wandb.log({"lossD": np.mean(np.array(lossD_curr_epoch)),
                 "lossE": np.mean(np.array(lossE_curr_epoch)),
                 "lossDec": np.mean(np.array(lossDec_curr_epoch)),
                 "loss_discl": np.mean(np.array(loss_discl_curr_epoch)),
                 "accuracy": np.mean(np.array(accuracy_curr_epoch)) })
      
      # append loss computed on the last batch_size
      lossesD.append(lossD)
      lossesG.append(lossG)  
      lossesDec.append(lossDec)
      accuracies_test.append(accuracy_test)