In [71]:
import vamb

import numpy as np
import torch as torch
import torch.nn as nn
from torch.optim import Adam as Adam
from torch.utils.data import DataLoader as DataLoader
from torch.utils.data import TensorDataset

import wandb

from collections import namedtuple

import os

import glob

import json

import sys

import pandas as pd

import ast

# Seed PyTorch's global random number generator so that random draws made later
# in this notebook (e.g. the torch.randn noise in the VAE's reparameterize step)
# are repeatable across runs.
torch.manual_seed(0)

<torch._C.Generator at 0x7f9c32c61c90>

# Define Model

In [2]:
class DISENTANGLED_BETA_VAE(torch.nn.Module):
    """Beta-weighted variational autoencoder over per-contig depth and TNF features.

    The network input is the concatenation of ``nsamples`` depth columns and
    ``ntnf`` TNF columns. Two 512-unit hidden layers encode it into ``mu`` and
    ``logsigma`` (each of size ``nlatent``); a mirrored decoder reconstructs the
    depth and TNF parts.

    Parameters
    ----------
    nsamples : int
        Number of depth columns in the input.
    config : object
        Must expose the attributes ``nlatent``, ``dropout``, ``learning_rate``,
        ``alpha``, ``beta`` and ``nepochs``.
    ntnf : int, optional
        Number of TNF columns in the input. When omitted, the value is read
        from the notebook-global ``tnfs`` array (``tnfs.shape[1]``), which is
        what earlier versions of this class always did.
    """

    def __init__(self, nsamples, config, ntnf=None):
        super(DISENTANGLED_BETA_VAE, self).__init__()

        # SET UP AND CONFIGURE THE MODEL
        if ntnf is None:
            # Backward-compatible fallback: rely on the global `tnfs` array
            # being defined before the model is constructed.
            ntnf = tnfs.shape[1]
        self.ntnf = ntnf

        self.nlatent = config.nlatent
        self.dropout = config.dropout
        self.learning_rate = config.learning_rate
        self.alpha = config.alpha
        self.beta = config.beta
        self.nepochs = config.nepochs

        nhiddens = [512, 512]

        self.nsamples = nsamples
        self.cuda_on = False

        self.encoderlayers = torch.nn.ModuleList()
        self.encodernorms = torch.nn.ModuleList()
        self.decoderlayers = torch.nn.ModuleList()
        self.decodernorms = torch.nn.ModuleList()

        ninput = self.nsamples + self.ntnf

        # ENCODER LAYERS
        # Layers are created in the same order as before so that parameter
        # names and seeded initialisation stay unchanged.
        for nin, nout in zip([ninput] + nhiddens, nhiddens):
            self.encoderlayers.append(torch.nn.Linear(nin, nout))
            self.encodernorms.append(torch.nn.BatchNorm1d(nout))

        # LATENT LAYERS
        self.mu = torch.nn.Linear(nhiddens[-1], self.nlatent)
        self.logsigma = torch.nn.Linear(nhiddens[-1], self.nlatent)

        # DECODER LAYERS
        for nin, nout in zip([self.nlatent] + nhiddens[::-1], nhiddens[::-1]):
            self.decoderlayers.append(torch.nn.Linear(nin, nout))
            self.decodernorms.append(torch.nn.BatchNorm1d(nout))

        # RECONSTRUCTION LAYER
        self.outputlayer = torch.nn.Linear(nhiddens[0], ninput)

        # ACTIVATIONS
        self.relu = torch.nn.LeakyReLU()
        self.softplus = torch.nn.Softplus()
        self.dropoutlayer = torch.nn.Dropout(p=self.dropout)

    ###
    # ENCODE NEW CONTIGS TO LATENT SPACE
    ###
    def encode(self, data_loader):
        """Return the latent means for every row of ``data_loader.dataset``.

        The dataset is re-wrapped in an unshuffled loader with
        ``drop_last=False`` so that every row is encoded, in dataset order.

        Returns
        -------
        numpy.ndarray
            float32 array of shape ``(len(dataset), nlatent)``.
        """
        self.eval()

        new_data_loader = DataLoader(dataset=data_loader.dataset,
                                     batch_size=data_loader.batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=1,
                                     pin_memory=data_loader.pin_memory)

        depths_array, tnf_array = data_loader.dataset.tensors
        length = len(depths_array)

        latent = np.empty((length, self.nlatent), dtype=np.float32)

        row = 0
        with torch.no_grad():
            for depths, tnf in new_data_loader:
                if self.cuda_on:
                    depths = depths.cuda()
                    tnf = tnf.cuda()

                # Evaluate; only the latent mean is kept
                out_depths, out_tnf, mu, logsigma = self(depths, tnf)

                # .cpu() is a no-op for tensors that already live on the CPU
                mu = mu.cpu().numpy()

                latent[row: row + len(mu)] = mu
                row += len(mu)

        assert row == length
        return latent

    ###
    # SPECIFIC ENCODING AND DECODING FUNCTIONS
    ###
    # REPARAMETERIZE
    def reparameterize(self, mu, logsigma):
        """Sample ``mu + eps * exp(logsigma / 2)`` with ``eps ~ N(0, 1)``."""
        # randn_like draws the noise on the same device and dtype as `mu`.
        # The noise is a constant for backpropagation, so it does not need
        # requires_grad.
        epsilon = torch.randn_like(mu)

        latent = mu + epsilon * torch.exp(logsigma / 2)

        return latent

    # ENCODE CONTIGS
    def encode_contigs(self, tensor):
        """Run the encoder stack and return ``(mu, logsigma)``."""
        # Hidden layers: linear -> LeakyReLU -> dropout -> batch norm
        for encoderlayer, encodernorm in zip(self.encoderlayers, self.encodernorms):
            tensor = encodernorm(self.dropoutlayer(self.relu(encoderlayer(tensor))))

        # Latent layers; softplus keeps logsigma non-negative
        mu = self.mu(tensor)
        logsigma = self.softplus(self.logsigma(tensor))

        return mu, logsigma

    # DECODE CONTIGS
    def decode_contigs(self, tensor):
        """Run the decoder stack and split the output into ``(depths, tnf)``."""
        for decoderlayer, decodernorm in zip(self.decoderlayers, self.decodernorms):
            tensor = decodernorm(self.dropoutlayer(self.relu(decoderlayer(tensor))))

        reconstruction = self.outputlayer(tensor)

        # Decompose reconstruction to depths and tnf signal
        depths_out = reconstruction.narrow(1, 0, self.nsamples)
        tnf_out = reconstruction.narrow(1, self.nsamples, self.ntnf)

        return depths_out, tnf_out

    ###
    # LOSS CALCULATION
    ###
    # CALCULATE LOSS
    def calc_loss(self, depths_in, depths_out, tnf_in, tnf_out, mu, logsigma):
        """Return ``(loss, ce, sse, kld)`` for one batch.

        ``ce`` and ``sse`` are the batch means of the per-row summed squared
        errors of the depth and TNF reconstructions; ``kld`` is the batch mean
        of the KL divergence computed with ``logsigma`` as log-variance.
        The total is ``ce * (1 - alpha) + sse * alpha / ntnf
        + kld / (nlatent * beta)``.
        """
        ce = (depths_out - depths_in).pow(2).sum(dim=1).mean()
        sse = (tnf_out - tnf_in).pow(2).sum(dim=1).mean()
        kld = -0.5 * (1 + logsigma - mu.pow(2) - logsigma.exp()).sum(dim=1).mean()

        # Use the configured alpha rather than a hard-coded 0.15
        ce_weight = 1 - self.alpha
        sse_weight = self.alpha / self.ntnf
        # BETA PARAMETER HERE
        kld_weight = 1 / (self.nlatent * self.beta)
        loss = ce * ce_weight + sse * sse_weight + kld * kld_weight

        return loss, ce, sse, kld

    ###
    # TRAINING FUNCTIONS
    ###
    # FORWARD
    def forward(self, depths, tnf):
        """Encode, sample and decode; return ``(depths_out, tnf_out, mu, logsigma)``."""
        tensor = torch.cat((depths, tnf), 1)
        mu, logsigma = self.encode_contigs(tensor)
        latent = self.reparameterize(mu, logsigma)
        depths_out, tnf_out = self.decode_contigs(latent)

        return depths_out, tnf_out, mu, logsigma

    # TRAIN SPECIFIC EPOCH
    def trainepoch(self, data_loader, epoch, optimizer, batchsteps):
        """Train for one epoch, log the mean losses, and return the loader used.

        If ``epoch`` is in ``batchsteps`` the loader is rebuilt with twice the
        batch size before training; the (possibly new) loader is returned so
        the caller can pass it to the next epoch.

        Raises
        ------
        ValueError
            If the loader yields no batches.
        """
        self.train()

        epoch_loss, epoch_kldloss, epoch_sseloss, epoch_celoss = 0, 0, 0, 0

        if epoch in batchsteps:
            doubled_batchsize = data_loader.batch_size * 2
            # With drop_last=True a batch larger than the dataset would leave
            # nothing to train on, so keep the current loader in that case.
            if doubled_batchsize <= len(data_loader.dataset):
                data_loader = DataLoader(dataset=data_loader.dataset,
                                         batch_size=doubled_batchsize,
                                         shuffle=True,
                                         drop_last=True,
                                         num_workers=data_loader.num_workers,
                                         pin_memory=data_loader.pin_memory)

        n_batches = len(data_loader)
        if n_batches == 0:
            raise ValueError('data_loader yields no batches; the batch size '
                             'exceeds the dataset size with drop_last=True')

        for depths_in, tnf_in in data_loader:
            # CUDA ENABLING (same switch as in encode)
            if self.cuda_on:
                depths_in = depths_in.cuda()
                tnf_in = tnf_in.cuda()

            optimizer.zero_grad()

            depths_out, tnf_out, mu, logsigma = self(depths_in, tnf_in)

            loss, ce, sse, kld = self.calc_loss(depths_in, depths_out, tnf_in,
                                                tnf_out, mu, logsigma)

            loss.backward()
            optimizer.step()

            epoch_loss = epoch_loss + loss.item()
            epoch_kldloss = epoch_kldloss + kld.item()
            epoch_sseloss = epoch_sseloss + sse.item()
            epoch_celoss = epoch_celoss + ce.item()

        print('\tEpoch: {}\tLoss: {:.6f}\tCE: {:.7f}\tSSE: {:.6f}\tKLD: {:.4f}\tBatchsize: {}'.format(
              epoch + 1,
              epoch_loss / n_batches,
              epoch_celoss / n_batches,
              epoch_sseloss / n_batches,
              epoch_kldloss / n_batches,
              data_loader.batch_size,
              ))
        wandb.log({
            "epoch": (epoch+1),
            "loss": epoch_loss / n_batches,
            "CELoss": epoch_celoss / n_batches,
            "SSELoss": epoch_sseloss / n_batches,
            "KLDLoss": epoch_kldloss / n_batches,
            "Batchsize": data_loader.batch_size
        })

        return data_loader

    # TRAIN MODEL
    def trainmodel(self, dataloader, batchsteps=(25, 75, 150, 300), modelfile=None):
        """Train for ``self.nepochs`` epochs with Adam.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader
            Loader yielding ``(depths, tnf)`` batches.
        batchsteps : iterable of int or None
            Zero-based epoch indices at which the batch size is doubled.
            ``None`` disables doubling.
        modelfile : path or file-like, optional
            If given, the trained ``state_dict`` is written there with
            ``torch.save`` once training finishes.
        """
        # Previously this set was left empty, so `batchsteps` had no effect.
        batchsteps_set = set(batchsteps) if batchsteps is not None else set()

        optimizer = Adam(self.parameters(), lr=self.learning_rate)

        # TRAIN EPOCH
        for epoch in range(self.nepochs):
            dataloader = self.trainepoch(dataloader, epoch, optimizer, batchsteps_set)

        if modelfile is not None:
            torch.save(self.state_dict(), modelfile)

# Load Data, Encode, Cluster, and Run CheckM

In [9]:
# Absolute path of the directory the notebook was started in; later cells join
# simulation-relative paths onto it.
BASE_DIR = os.path.abspath(os.curdir)

In [15]:
def filterclusters(clusters, lengthof, minsize=10000):
    """Keep only the clusters whose contigs add up to at least ``minsize``.

    Parameters
    ----------
    clusters : dict
        Maps a cluster name (medoid) to an iterable of contig names.
    lengthof : mapping
        Maps each contig name to its length.
    minsize : int, optional
        Minimum summed contig length for a cluster to be kept
        (default 10000, the previously hard-coded threshold).

    Returns
    -------
    dict
        The subset of ``clusters`` passing the size filter, with the original
        value objects.
    """
    print('CLUSTERS:', len(clusters))

    return {
        medoid: contigs
        for medoid, contigs in clusters.items()
        if sum(lengthof[contig] for contig in contigs) >= minsize
    }

In [16]:
SIM_FASTA_FILES =  glob.glob('example_input_data/new_simulations/complexity_sim*/*sample_0*')


for SIM_FASTA_FILE in SIM_FASTA_FILES:
    vamb_inputs_base = os.path.join(SIM_FASTA_FILE,'vamb_inputs')
    
    contignames = vamb.vambtools.read_npz(os.path.join(vamb_inputs_base, 'contignames.npz'))
    lengths = vamb.vambtools.read_npz(os.path.join(vamb_inputs_base, 'lengths.npz'))
    tnfs = vamb.vambtools.read_npz(os.path.join(vamb_inputs_base, 'tnfs.npz'))   
    rpkms = vamb.vambtools.read_npz(os.path.join(vamb_inputs_base, 'rpkms.npz'))
    
    print(tnfs.shape)
    print(rpkms.shape)
    
    
    # ADAPT THROUGH DATALOADER
    depthssum = rpkms.sum(axis=1)
    mask = tnfs.sum(axis=1) != 0
    mask &= depthssum != 0
    depthssum = depthssum[mask]

    rpkm = rpkms[mask].astype(np.float32, copy=False)
    tnf = tnfs[mask].astype(np.float32, copy=False)

    ## lkj
    def calculate_z_score(array):
        array_mean = array.mean(axis=0)
        array_std = array.std(axis=0)

        shape = np.copy(array.shape)
        shape[0] = 1
        shape = tuple(shape)

        array_mean.shape = shape
        array_mean.shape = shape

        array = (array - array_mean) / array_std

        return(array)

    rpkm = calculate_z_score(rpkm)
    tnf = calculate_z_score(tnf)
    depthstensor = torch.from_numpy(rpkm)
    tnftensor = torch.from_numpy(tnf)

    n_workers = 1

    dataset = TensorDataset(depthstensor, tnftensor)
    dataloader = DataLoader(dataset=dataset, batch_size=256, drop_last=True,
                                 shuffle=True, num_workers=n_workers, pin_memory=False)



    ncontigs, nsamples = dataset.tensors[0].shape
    
    
    # RUN BETA VAE
    best_params_dict = {
        'nepochs': 5,
        'dropout': 0.2,
        'learning_rate': 1e-3,
        'alpha': 0.15,
        'beta': 800,
        'nlatent': 32
    }

    best_params = namedtuple('GenericDict', best_params_dict.keys())(**best_params_dict)

    beta_vae = DISENTANGLED_BETA_VAE(nsamples=rpkms.shape[1], config=best_params)
    beta_vae.load_state_dict(torch.load('model.h5'))

    latent = beta_vae.encode(dataloader)
    print("Latent shape:", latent.shape)

    latent_output_path = os.path.join(SIM_FASTA_FILE, 'latent_space.npy')
    with open(latent_output_path, 'wb') as outfile:
        np.save(outfile, latent)
        
        
    # CONTIG MAPPING
    contig_mapping_table_path = os.path.join(BASE_DIR, f"{SIM_FASTA_FILE}/contigs/gsa_mapping.tsv")
    contig_mapping_table_comp = contig_mapping_table_path + '.gz'
    if (not os.path.exists(contig_mapping_table_path)) and (os.path.exists(contig_mapping_table_comp)):
        !gzip -d $contig_mapping_table_comp
        
    contig_mapping_table = pd.read_csv(contig_mapping_table_path, sep='\t')
    

    contig_mapping_output_path = os.path.join(BASE_DIR, f"{SIM_FASTA_FILE}/encoding_mapping.tsv")    

    contig_mapping_table[contig_mapping_table['#anonymous_contig_id'].isin(contignames)].reset_index().drop('index', axis=1).set_index(
        '#anonymous_contig_id').reindex(contignames).to_csv(contig_mapping_output_path, sep='\t')
    
    
    filtered_labels = [n for (n,m) in zip(contignames, mask) if m]
    cluster_iterator = vamb.cluster.cluster(latent, labels=filtered_labels)
    clusters = dict(cluster_iterator)

    medoid, contigs = next(iter(clusters.items()))
    print('First key:', medoid, '(of type:', type(medoid), ')')
    print('Type of values:', type(contigs))
    print('First element of value:', next(iter(contigs)), 'of type:', type(next(iter(contigs))))
    
    
    # FILTER CLUSTERS  
    lengthof = dict(zip(contignames, lengths))
    filtered_bins = filterclusters(vamb.vambtools.binsplit(clusters, 'C'), lengthof)
    print('Number of bins before splitting and filtering:', len(clusters))
    print('Number of bins after splitting and filtering:', len(filtered_bins))
    
    
    # SAVE OUTPUTS
    vamb_outputs_base = os.path.join(BASE_DIR, SIM_FASTA_FILE)


    # This writes a .tsv file with the clusters and corresponding sequences
    with open(os.path.join(vamb_outputs_base, 'clusters_dvae.tsv'), 'w') as file:
        vamb.vambtools.write_clusters(file, filtered_bins)

    # Only keep contigs in any filtered bin in memory
    keptcontigs = set.union(*filtered_bins.values())

    # decompress fasta.gz if present
    fasta_path = os.path.join(BASE_DIR, f"{SIM_FASTA_FILE}/contigs/anonymous_gsa.fasta.gz")
    if os.path.exists(fasta_path) and not os.path.exists(fasta_path.replace('.fasta.gz','.fasta')):
        !gzip -dk $fasta_path


    with open(os.path.join(BASE_DIR, f"{SIM_FASTA_FILE}/contigs/anonymous_gsa.fasta"), 'rb') as file:
        fastadict = vamb.vambtools.loadfasta(file, keep=keptcontigs)

    bindir = os.path.join(vamb_outputs_base, 'dvae_bins')
    if not os.path.exists(bindir):
        os.mkdir(bindir)
    vamb.vambtools.write_bins(bindir, filtered_bins, fastadict, maxbins=500)
    
    
    # RUN CHECKM
    CHECKM_OUTDIR = os.path.join(BASE_DIR, SIM_FASTA_FILE, 'checkm_results')

    if not os.path.exists(CHECKM_OUTDIR):
        os.mkdir(CHECKM_OUTDIR)

        
    bins_inpath = os.path.join(BASE_DIR, SIM_FASTA_FILE, 'dvae_bins')
    bins_inpath_clean = os.path.join(BASE_DIR, SIM_FASTA_FILE, 'dvae_bins_clean')
    
    if not os.path.exists(bins_inpath_clean):
        os.mkdir(bins_inpath_clean)

    for bin_file in glob.glob(os.path.join(bins_inpath,'*')):
        bin_outfile = bin_file.replace('dvae_bins','dvae_bins_clean')
        !sed -e 's/\r$//' $bin_file > $bin_outfile

    !~/miniconda3/envs/vamb_env/bin/checkm lineage_wf -t 32 -x fna $bins_inpath_clean $CHECKM_OUTDIR

(774, 103)
(774, 1)
Latent shape: (774, 32)
First key: S0C34658 (of type: <class 'numpy.str_'> )
Type of values: <class 'set'>
First element of value: S0C34658 of type: <class 'numpy.str_'>
CLUSTERS: 567
Number of bins before splitting and filtering: 567
Number of bins after splitting and filtering: 6
[2021-03-14 21:21:05] INFO: CheckM v1.1.3
[2021-03-14 21:21:05] INFO: checkm lineage_wf -t 32 -x fna /home/pathinformatics/jupyter_projects/vamb/stanford_cs230_project/example_input_data/new_simulations/complexity_sim_10_genera_250_genomes/2021.03.14_07.34.03_sample_0/dvae_bins_clean /home/pathinformatics/jupyter_projects/vamb/stanford_cs230_project/example_input_data/new_simulations/complexity_sim_10_genera_250_genomes/2021.03.14_07.34.03_sample_0/checkm_results
[2021-03-14 21:21:05] INFO: [CheckM - tree] Placing bins in reference genome tree.
[2021-03-14 21:21:05] INFO: Identifying marker genes in 23 bins with 32 threads:
    Finished processing 23 of 23 (100.00%) bins.
[2021-03-14 21:2

# Compile Outputs

In [86]:
SIM_FASTA_FILES =  glob.glob('example_input_data/new_simulations/complexity_sim*/*sample_0*')

# One entry per simulation directory, in glob order
num_generas = []
num_genomes = []
num_bins = []
num_mappable_bins = []
mean_completenesses = []
mean_contaminations = []


for SIM_FASTA_FILE in SIM_FASTA_FILES:
    # The parent directory is named like complexity_sim_<genera>_genera_<genomes>_genomes
    sim_name_parts = os.path.basename(os.path.dirname(SIM_FASTA_FILE)).split('_')
    num_genera, num_genome = sim_name_parts[2], sim_name_parts[4]

    checkm_data = os.path.join(SIM_FASTA_FILE, 'checkm_results')

    num_bin = len(glob.glob(os.path.join(checkm_data,'bins','*')))

    # Each row is "<bin id>\t<python dict literal of statistics>"
    t1 = pd.read_csv( os.path.join(checkm_data, 'storage', 'bin_stats_ext.tsv'), sep ='\t', names=['contig', 'data'])

    # Parse each statistics dict once and filter on the numeric value. The old
    # substring test on "'Completeness': 0.0" also discarded bins whose
    # completeness merely started with 0.0 (e.g. 0.04).
    bin_stats = [ast.literal_eval(entry) for entry in t1['data'].values]
    mappable_stats = [stats for stats in bin_stats if stats['Completeness'] != 0]

    num_mappable_bin = len(mappable_stats)

    completeness = np.array([stats['Completeness'] for stats in mappable_stats])
    contamination = np.array([stats['Contamination'] for stats in mappable_stats])

    # NaN (without a numpy warning) when no bin is mappable
    mean_completeness = completeness.mean() if num_mappable_bin else float('nan')
    mean_contamination = contamination.mean() if num_mappable_bin else float('nan')

    num_generas.append(num_genera)
    num_genomes.append(num_genome)
    num_bins.append(num_bin)
    num_mappable_bins.append(num_mappable_bin)
    mean_completenesses.append(mean_completeness)
    mean_contaminations.append(mean_contamination)


# `best_params_dict` comes from the cell that ran the model
experiment_beta = best_params_dict['beta']
experiment_stats = pd.DataFrame(data={
    'num_genera':num_generas,
    'num_genomes':num_genomes,
    'num_bins':num_bins,
    'num_mappable_bins':num_mappable_bins,
    'mean_completeness':mean_completenesses,
    'mean_contamination':mean_contaminations,
    'beta':experiment_beta,
    'dropout':best_params_dict['dropout']
})

experiment_stats.to_csv(f"experiment_stats_beta{experiment_beta}.tsv", sep='\t', index=False)