In [1]:
import vamb

import numpy as np
import torch as torch
import torch.nn as nn
from torch.optim import Adam as Adam
from torch.utils.data import DataLoader as DataLoader
from torch.utils.data import TensorDataset

import wandb

from collections import namedtuple

import os

import json

import sys

import pandas as pd

torch.manual_seed(0)

<torch._C.Generator at 0x7fadb8f52c90>

# Load In Data

In [2]:
BASE_DIR = os.getcwd()

EXAMPLE_FASTA_FILE = '2021.01.26_15.46.45_sample_0'

In [3]:
vamb_inputs_base = os.path.join(BASE_DIR,'example_input_data/new_simulations/camisim_outputs/vamb_inputs')

contignames = vamb.vambtools.read_npz(os.path.join(vamb_inputs_base, 'contignames.npz'))

lengths = vamb.vambtools.read_npz(os.path.join(vamb_inputs_base, 'lengths.npz'))

tnfs = vamb.vambtools.read_npz(os.path.join(vamb_inputs_base, 'tnfs.npz'))
    
rpkms = vamb.vambtools.read_npz(os.path.join(vamb_inputs_base, 'rpkms.npz'))

# Adapt Through DataLoader

In [4]:
depthssum = rpkms.sum(axis=1)
mask = tnfs.sum(axis=1) != 0
mask &= depthssum != 0
depthssum = depthssum[mask]

rpkm = rpkms[mask].astype(np.float32, copy=False)
tnf = tnfs[mask].astype(np.float32, copy=False)

## lkj
def calculate_z_score(array):
    array_mean = array.mean(axis=0)
    array_std = array.std(axis=0)

    shape = np.copy(array.shape)
    shape[0] = 1
    shape = tuple(shape)

    array_mean.shape = shape
    array_mean.shape = shape

    array = (array - array_mean) / array_std
    
    return(array)
    
rpkm = calculate_z_score(rpkm)
tnf = calculate_z_score(tnf)
depthstensor = torch.from_numpy(rpkm)
tnftensor = torch.from_numpy(tnf)

n_workers = 4

dataset = TensorDataset(depthstensor, tnftensor)
dataloader = DataLoader(dataset=dataset, batch_size=256, drop_last=True,
                             shuffle=True, num_workers=n_workers, pin_memory=True)



ncontigs, nsamples = dataset.tensors[0].shape

# VAE Model

In [5]:
class DISENTANGLED_BETA_VAE(torch.nn.Module):
    def __init__(self, nsamples, config):
        super(DISENTANGLED_BETA_VAE, self).__init__()
  
        # SET UP AND CONFIGURE THE MODEL
        self.ntnf = tnfs.shape[1]
        
        self.nlatent = config.nlatent
        self.dropout = config.dropout
        self.learning_rate = config.learning_rate
        self.alpha = config.alpha
        self.beta = config.beta
        self.nepochs = config.nepochs
        
        nhiddens = [512, 512]
        
        self.nsamples = nsamples
        self.cuda_on = False
        
        if self.cuda_on:
            self.cuda()

        self.encoderlayers = torch.nn.ModuleList()
        self.encodernorms = torch.nn.ModuleList()
        self.decoderlayers = torch.nn.ModuleList()
        self.decodernorms = torch.nn.ModuleList()


        # ENCODER LAYERS
        self.encoderlayers.append( torch.nn.Linear((self.nsamples + self.ntnf), 512) )
        self.encodernorms.append( torch.nn.BatchNorm1d(512) )

        self.encoderlayers.append( torch.nn.Linear(512, 512) )
        self.encodernorms.append( torch.nn.BatchNorm1d(512) )


        # LATENT LAYERS
        self.mu = torch.nn.Linear(512, self.nlatent)
        self.logsigma = torch.nn.Linear(512, self.nlatent)


        # DECODER LAYRS
        self.decoderlayers.append(torch.nn.Linear(self.nlatent, 512))
        self.decodernorms.append(torch.nn.BatchNorm1d(512))

        self.decoderlayers.append(torch.nn.Linear(512, 512))
        self.decodernorms.append(torch.nn.BatchNorm1d(512))


        # RECONSTRUCTION LAYER
        self.outputlayer = torch.nn.Linear(512, (self.nsamples + self.ntnf) )


        # ACTIVATIONS
        self.relu = torch.nn.LeakyReLU()
        self.softplus = torch.nn.Softplus()
        self.dropoutlayer = torch.nn.Dropout(p=self.dropout)

        
    ###
    # ENCODE NEW CONTIGS TO LATENT SPACE
    ###
    def encode(self, data_loader):
        self.eval()

        new_data_loader = DataLoader(dataset=data_loader.dataset,
                                      batch_size=data_loader.batch_size,
                                      shuffle=False,
                                      drop_last=False,
                                      num_workers=4,
                                      pin_memory=data_loader.pin_memory)

        depths_array, tnf_array = data_loader.dataset.tensors
        length = len(depths_array)

        latent = np.empty((length, self.nlatent), dtype=np.float32)

        row = 0
        with torch.no_grad():
            for depths, tnf in new_data_loader:
                if self.cuda_on:
                    depths = depths.cuda()
                    tnf = tnf.cuda()

                # Evaluate
                out_depths, out_tnf, mu, logsigma = self(depths, tnf)

                if self.cuda_on:
                    mu = mu.cpu()

                latent[row: row + len(mu)] = mu
                row += len(mu)

        assert row == length
        return latent
    
    ###
    # SPECIFIC ENCODING AND DECODING FUNCTIONS
    ###
    # REPARAMATERIZE
    def reparameterize(self, mu, logsigma):
        epsilon = torch.randn(mu.size(0), mu.size(1))

        if self.cuda_on:
            epsilon = epsilon.cuda()

        epsilon.requires_grad = True

        # See comment above regarding softplus
        latent = mu + epsilon * torch.exp(logsigma/2)

        return latent
    
    
    # ENCODE CONTIGS
    def encode_contigs(self, tensor):
        tensors = list()

        # Hidden layers
        for encoderlayer, encodernorm in zip(self.encoderlayers, self.encodernorms):
            tensor = encodernorm(self.dropoutlayer(self.relu(encoderlayer(tensor))))
            tensors.append(tensor)

        # Latent layers
        mu = self.mu(tensor)
        logsigma = self.softplus(self.logsigma(tensor))

        return mu, logsigma
    
    
    # DECODE CONTIGS
    def decode_contigs(self, tensor):
        tensors = list()

        for decoderlayer, decodernorm in zip(self.decoderlayers, self.decodernorms):
            tensor = decodernorm(self.dropoutlayer(self.relu(decoderlayer(tensor))))
            tensors.append(tensor)

        reconstruction = self.outputlayer(tensor)

        # Decompose reconstruction to depths and tnf signal
        depths_out = reconstruction.narrow(1, 0, self.nsamples)
        tnf_out = reconstruction.narrow(1, self.nsamples, tnfs.shape[1])

        return depths_out, tnf_out
    
    
    ###
    # LOSS CALCULATION
    ###
    # CALCULATE LOSS
    def calc_loss(self, depths_in, depths_out, tnf_in, tnf_out, mu, logsigma):
        ce = (depths_out - depths_in).pow(2).sum(dim=1).mean()
        ce_weight = 1 - 0.15 # alpha

        sse = (tnf_out - tnf_in).pow(2).sum(dim=1).mean()
        kld = -0.5 * (1 + logsigma - mu.pow(2) - logsigma.exp()).sum(dim=1).mean()

        sse_weight = 0.15 / self.ntnf # alpha / ntnf
        # BETA PARAMETER HERE
        kld_weight = 1 / (self.nlatent * self.beta)
        loss = ce * ce_weight + sse * sse_weight + kld * kld_weight

        return loss, ce, sse, kld
    

    ###
    # TRAINING FUNCTIONS
    ###
    # FORWARD
    def forward(self, depths, tnf):
        tensor = torch.cat((depths, tnf), 1)
        mu, logsigma = self.encode_contigs(tensor)
        latent = self.reparameterize(mu, logsigma)
        depths_out, tnf_out = self.decode_contigs(latent)

        return depths_out, tnf_out, mu, logsigma   
        
     
    
    # TRAIN SPECIFIC EPOCH
    def trainepoch(self, data_loader, epoch, optimizer, batchsteps):
        self.train()

        epoch_loss, epoch_kldloss, epoch_sseloss, epoch_celoss = 0, 0, 0, 0

        if epoch in batchsteps:
            data_loader = DataLoader(dataset=data_loader.dataset,
                                      batch_size=data_loader.batch_size * 2,
                                      shuffle=True,
                                      drop_last=True,
                                      num_workers=data_loader.num_workers,
                                      pin_memory=data_loader.pin_memory)

        for depths_in, tnf_in in data_loader:
            depths_in.requires_grad = True
            tnf_in.requires_grad = True

            # CUDA ENABLING
            if self.cuda_on:
                depths_in = depths_in.cuda()
                tnf_in = tnf_in.cuda()

            optimizer.zero_grad()

            depths_out, tnf_out, mu, logsigma = self(depths_in, tnf_in)

            loss, ce, sse, kld = self.calc_loss(depths_in, depths_out, tnf_in,
                                                  tnf_out, mu, logsigma)

            loss.backward()
            optimizer.step()

            epoch_loss = epoch_loss + loss.data.item()
            epoch_kldloss = epoch_kldloss + kld.data.item()
            epoch_sseloss = epoch_sseloss + sse.data.item()
            epoch_celoss = epoch_celoss + ce.data.item()

        print('\tEpoch: {}\tLoss: {:.6f}\tCE: {:.7f}\tSSE: {:.6f}\tKLD: {:.4f}\tBatchsize: {}'.format(
              epoch + 1,
              epoch_loss / len(data_loader),
              epoch_celoss / len(data_loader),
              epoch_sseloss / len(data_loader),
              epoch_kldloss / len(data_loader),
              data_loader.batch_size,
              ))
        wandb.log({
            "epoch": (epoch+1), 
            "loss": epoch_loss / len(data_loader),
            "CELoss": epoch_celoss / len(data_loader),
            "SSELoss": epoch_sseloss / len(data_loader),
            "KLDLoss": epoch_kldloss / len(data_loader),
            "Batchsize": data_loader.batch_size
        })

        return data_loader
    
    
    
    # TRAIN MODEL    
    def trainmodel(self, dataloader, batchsteps=[25, 75, 150, 300], modelfile=None):
        
        batchsteps_set = set()
        
        ncontigs, nsamples = dataloader.dataset.tensors[0].shape
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        
        
        # TRAIN EPOCH
        for epoch in range(self.nepochs):
            dataloader = self.trainepoch(dataloader, epoch, optimizer, batchsteps_set)


In [9]:
wandb_on = True
sweep_on = True


if wandb_on:
    
    #beta_vae = DISENTANGLED_BETA_VAE(nsamples=rpkms.shape[1], config=config)
    #wandb.watch(beta_vae)
    #beta_vae.trainmodel(dataloader, batchsteps=None)

    if sweep_on:
        sweep_config = {
        'method': 'bayes',
        'metric': {
          'name': 'loss',
          'goal': 'minimize'   
            },
            'parameters': {
                'nepochs': {
                    'values': [5]
                },
                'dropout': {
                    'values': [0.2, 0.4, 0.6]
                },
                'learning_rate': {
                    'values': [1e-2, 1e-3, 1e-4, 1e-5]
                },
                'alpha': {
                    'values': [0.15]
                },
                'beta': {
                    'values': [1, 200, 400, 800]
                },
                'nlatent': {
                    'values': [32, 64]
                }
            }
        }
        
        
        def train():
            config_defaults = {
                'nepochs': 5,
                'dropout': 0.2,
                'learning_rate': 1e-3,
                'alpha': 0.15,
                'beta': 200,
                'nlatent': 32
            }
            
            wandb.init(project='cs_230_vae', entity='pmccaffrey6', config=config_defaults)
            
            config = wandb.config
            beta_vae = DISENTANGLED_BETA_VAE(nsamples=rpkms.shape[1], config=config)
            wandb.watch(beta_vae)
            beta_vae.trainmodel(dataloader, batchsteps=None)
            print('rundir:', wandb.run.dir)
            torch.save(beta_vae.state_dict(), os.path.join(wandb.run.dir, "model.h5"))
            

        sweep_id = wandb.sweep(sweep_config, entity="pmccaffrey6", project="cs_230_vae")
        wandb.agent(sweep_id, function=train)




Create sweep with ID: 4t8iicmb
Sweep URL: https://wandb.ai/pmccaffrey6/cs_230_vae/sweeps/4t8iicmb


[34m[1mwandb[0m: Agent Starting Run: 4v085q82 with config:
[34m[1mwandb[0m: 	alpha: 0.15
[34m[1mwandb[0m: 	beta: 400
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	learning_rate: 0.01
[34m[1mwandb[0m: 	nepochs: 5
[34m[1mwandb[0m: 	nlatent: 32


	Epoch: 1	Loss: 7.789081	CE: 6.8864641	SSE: 136.909744	KLD: 22223.3975	Batchsize: 256
	Epoch: 2	Loss: 1.811571	CE: 1.4561278	SSE: 128.357265	KLD: 4952.7538	Batchsize: 256
	Epoch: 3	Loss: 1.086201	CE: 1.0144000	SSE: 93.712982	KLD: 1119.8180	Batchsize: 256
	Epoch: 4	Loss: 0.942016	CE: 0.9212229	SSE: 83.734911	KLD: 474.0213	Batchsize: 256
	Epoch: 5	Loss: 0.759025	CE: 0.6718147	SSE: 71.402379	KLD: 1075.1774	Batchsize: 256
rundir: /home/pathinformatics/jupyter_projects/vamb/stanford_cs230_project/wandb/run-20210312_124043-4v085q82/files


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,5.0
loss,0.75902
CELoss,0.67181
SSELoss,71.40238
KLDLoss,1075.17744
Batchsize,256.0
_runtime,5.0
_timestamp,1615574448.0
_step,4.0


0,1
epoch,▁▃▅▆█
loss,█▂▁▁▁
CELoss,█▂▁▁▁
SSELoss,█▇▃▂▁
KLDLoss,█▂▁▁▁
Batchsize,▁▁▁▁▁
_runtime,▁▅▅██
_timestamp,▁▅▅██
_step,▁▃▅▆█


[34m[1mwandb[0m: Agent Starting Run: pzq1gak3 with config:
[34m[1mwandb[0m: 	alpha: 0.15
[34m[1mwandb[0m: 	beta: 200
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	learning_rate: 1e-05
[34m[1mwandb[0m: 	nepochs: 5
[34m[1mwandb[0m: 	nlatent: 64
[34m[1mwandb[0m: Currently logged in as: [33mpmccaffrey6[0m (use `wandb login --relogin` to force relogin)


	Epoch: 1	Loss: 1.382310	CE: 1.3895778	SSE: 136.772055	KLD: 25.4240	Batchsize: 256
	Epoch: 2	Loss: 1.286904	CE: 1.2762968	SSE: 137.376575	KLD: 25.4554	Batchsize: 256
	Epoch: 3	Loss: 1.320336	CE: 1.3168200	SSE: 136.684402	KLD: 25.3900	Batchsize: 256
	Epoch: 4	Loss: 1.252808	CE: 1.2375339	SSE: 136.582260	KLD: 25.5730	Batchsize: 256
	Epoch: 5	Loss: 1.307951	CE: 1.3029193	SSE: 136.285968	KLD: 25.5372	Batchsize: 256
rundir: /home/pathinformatics/jupyter_projects/vamb/stanford_cs230_project/wandb/run-20210312_124054-pzq1gak3/files


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,5.0
loss,1.30795
CELoss,1.30292
SSELoss,136.28597
KLDLoss,25.53721
Batchsize,256.0
_runtime,4.0
_timestamp,1615574458.0
_step,4.0


0,1
epoch,▁▃▅▆█
loss,█▃▅▁▄
CELoss,█▃▅▁▄
SSELoss,▄█▄▃▁
KLDLoss,▂▄▁█▇
Batchsize,▁▁▁▁▁
_runtime,▁▅▅▅█
_timestamp,▁▅▅▅█
_step,▁▃▅▆█


[34m[1mwandb[0m: Agent Starting Run: mx66bv25 with config:
[34m[1mwandb[0m: 	alpha: 0.15
[34m[1mwandb[0m: 	beta: 200
[34m[1mwandb[0m: 	dropout: 0.4
[34m[1mwandb[0m: 	learning_rate: 1e-05
[34m[1mwandb[0m: 	nepochs: 5
[34m[1mwandb[0m: 	nlatent: 64


	Epoch: 1	Loss: 1.359319	CE: 1.3629724	SSE: 136.523907	KLD: 25.2294	Batchsize: 256
	Epoch: 2	Loss: 1.284897	CE: 1.2755204	SSE: 136.459439	KLD: 25.3030	Batchsize: 256
	Epoch: 3	Loss: 1.356670	CE: 1.3596638	SSE: 136.636572	KLD: 25.2163	Batchsize: 256
	Epoch: 4	Loss: 1.368580	CE: 1.3750954	SSE: 135.811417	KLD: 25.1574	Batchsize: 256
	Epoch: 5	Loss: 1.308759	CE: 1.3027259	SSE: 136.972925	KLD: 25.1814	Batchsize: 256
rundir: /home/pathinformatics/jupyter_projects/vamb/stanford_cs230_project/wandb/run-20210312_124103-mx66bv25/files


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,5.0
loss,1.30876
CELoss,1.30273
SSELoss,136.97292
KLDLoss,25.18136
Batchsize,256.0
_runtime,3.0
_timestamp,1615574466.0
_step,4.0


0,1
epoch,▁▃▅▆█
loss,▇▁▇█▃
CELoss,▇▁▇█▃
SSELoss,▅▅▆▁█
KLDLoss,▄█▄▁▂
Batchsize,▁▁▁▁▁
_runtime,▁▅▅██
_timestamp,▁▅▅██
_step,▁▃▅▆█


[34m[1mwandb[0m: Agent Starting Run: 4w18k5ry with config:
[34m[1mwandb[0m: 	alpha: 0.15
[34m[1mwandb[0m: 	beta: 400
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	learning_rate: 0.01
[34m[1mwandb[0m: 	nepochs: 5
[34m[1mwandb[0m: 	nlatent: 32


	Epoch: 1	Loss: 6.617038	CE: 6.9429825	SSE: 138.622148	KLD: 6574.4154	Batchsize: 256
	Epoch: 2	Loss: 2.507680	CE: 2.0788482	SSE: 124.260435	KLD: 7164.1227	Batchsize: 256


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


In [13]:
# GET THE BEST MODEL'S PARAMS FROM THE WANDB API
api = wandb.Api()
print()
sweep = api.sweep(f"pmccaffrey6/cs_230_vae/{sweep_id}")

runs = sorted(sweep.runs, key=lambda run: run.summary.get("loss", 0), reverse=True)
loss = runs[0].summary.get("loss", 0)

best_run = runs[0].name

print(f"Best run {best_run} with {loss}% validation accuracy")
best_params_dict = json.loads(runs[0].json_config)
print('best_params_dict:', best_params_dict)

def convert(dictionary):
    return namedtuple('GenericDict', dictionary.keys())(**dictionary)

best_params = convert( {key:best_params_dict[key]['value'] for key in best_params_dict.keys()} )


Best run volcanic-sweep-7 with 1.3094314575195312% validation accuracy
best_params_dict: {'beta': {'value': 200, 'desc': None}, 'alpha': {'value': 0.15, 'desc': None}, 'dropout': {'value': 0.6, 'desc': None}, 'nepochs': {'value': 5, 'desc': None}, 'nlatent': {'value': 32, 'desc': None}, 'learning_rate': {'value': 1e-05, 'desc': None}}


In [21]:
runs[0].file("model.h5").download(replace=True)

<_io.TextIOWrapper name='./model.h5' mode='r' encoding='UTF-8'>

In [29]:
beta_vae = DISENTANGLED_BETA_VAE(nsamples=rpkms.shape[1], config=best_params)
beta_vae.load_state_dict(torch.load('model.h5'))

latent = beta_vae.encode(dataloader)
print("Latent shape:", latent.shape)

latent_output_path = os.path.join(BASE_DIR, 'example_input_data/new_simulations/camisim_outputs/vamb_outputs/latent_space.npy')
with open(latent_output_path, 'wb') as outfile:
    np.save(outfile, latent)

Latent shape: (1342, 32)


# Leverage Other VAMB Tools for Clustering and Post Steps

In [36]:
contig_mapping_table = pd.read_csv(os.path.join(BASE_DIR, f"example_input_data/new_simulations/camisim_outputs/{EXAMPLE_FASTA_FILE}/contigs/gsa_mapping.tsv"), sep='\t')

contig_mapping_output_path = os.path.join(BASE_DIR, 'example_input_data/new_simulations/camisim_outputs/vamb_outputs/encoding_mapping.tsv')    

contig_mapping_table[contig_mapping_table['#anonymous_contig_id'].isin(contignames)].reset_index().drop('index', axis=1).set_index(
    '#anonymous_contig_id').reindex(contignames).to_csv(contig_mapping_output_path, sep='\t')

In [30]:
filtered_labels = [n for (n,m) in zip(contignames, mask) if m]
cluster_iterator = vamb.cluster.cluster(latent, labels=filtered_labels)
clusters = dict(cluster_iterator)

medoid, contigs = next(iter(clusters.items()))
print('First key:', medoid, '(of type:', type(medoid), ')')
print('Type of values:', type(contigs))
print('First element of value:', next(iter(contigs)), 'of type:', type(next(iter(contigs))))

First key: S0C34345 (of type: <class 'numpy.str_'> )
Type of values: <class 'set'>
First element of value: S0C34345 of type: <class 'numpy.str_'>


In [31]:
def filterclusters(clusters, lengthof):
    filtered_bins = dict()
    for medoid, contigs in clusters.items():
        binsize = sum(lengthof[contig] for contig in contigs)
    
        if binsize >= 200000:
            filtered_bins[medoid] = contigs
    
    return filtered_bins
        
lengthof = dict(zip(contignames, lengths))
filtered_bins = filterclusters(vamb.vambtools.binsplit(clusters, 'C'), lengthof)
print('Number of bins before splitting and filtering:', len(clusters))
print('Number of bins after splitting and filtering:', len(filtered_bins))

Number of bins before splitting and filtering: 1163
Number of bins after splitting and filtering: 3


In [32]:
vamb_outputs_base = os.path.join(BASE_DIR, 'example_input_data/new_simulations/camisim_outputs/vamb_outputs')

if not os.path.exists(vamb_outputs_base):
    os.mkdir(vamb_outputs_base)
    

# This writes a .tsv file with the clusters and corresponding sequences
with open(os.path.join(vamb_outputs_base, 'clusters_dvae.tsv'), 'w') as file:
    vamb.vambtools.write_clusters(file, filtered_bins)

# Only keep contigs in any filtered bin in memory
keptcontigs = set.union(*filtered_bins.values())



# decompress fasta.gz if present
fasta_path = os.path.join(BASE_DIR, f"example_input_data/new_simulations/camisim_outputs/{EXAMPLE_FASTA_FILE}/contigs/anonymous_gsa.fasta.gz")
if os.path.exists(fasta_path) and not os.path.exists(fasta_path.replace('.fasta.gz','.fasta')):
    !gzip -dk $fasta_path


with open(os.path.join(BASE_DIR, f"example_input_data/new_simulations/camisim_outputs/{EXAMPLE_FASTA_FILE}/contigs/anonymous_gsa.fasta"), 'rb') as file:
    fastadict = vamb.vambtools.loadfasta(file, keep=keptcontigs)


bindir = os.path.join(vamb_outputs_base, 'dvae_bins')
if not os.path.exists(bindir):
    os.mkdir(bindir)
vamb.vambtools.write_bins(bindir, filtered_bins, fastadict, maxbins=500)