In [42]:
import glob
import gzip
import math
import os
import shutil
import sys
import warnings

import numpy as np
import pandas as pd
import torch as torch
import torch.nn as nn
import vamb
from torch.optim import Adam as Adam
from torch.utils.data import DataLoader as DataLoader

# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(0)

<torch._C.Generator at 0x7f1810085c50>

In [2]:
# Root for the data paths built below: the directory the notebook is run from.
BASE_DIR = os.getcwd()

# Name of the sample directory that is interpolated into the input paths below.
EXAMPLE_FASTA_FILE = '2021.01.26_15.46.45_sample_0'

In [3]:
# Directory that holds the pre-computed VAMB input arrays
vamb_inputs_base = os.path.join(BASE_DIR,'example_input_data/new_simulations/camisim_outputs/vamb_inputs')

# Read the four .npz inputs in a fixed order and unpack them into their names
contignames, lengths, tnfs, rpkms = (
    vamb.vambtools.read_npz(os.path.join(vamb_inputs_base, npz_name))
    for npz_name in ('contignames.npz', 'lengths.npz', 'tnfs.npz', 'rpkms.npz')
)

In [4]:
# Build the training DataLoader from the abundance (rpkms) and TNF arrays.
# `mask` is zipped with `contignames` further down to select contig names
# (presumably one truthy/falsy flag per input contig - confirm against vamb docs).
dataloader, mask = vamb.encode.make_dataloader(rpkms, tnfs)

In [5]:
# The first tensor of the dataset is 2-D; unpack its two dimensions
# as (number of contigs, number of samples).
ncontigs, nsamples = dataloader.dataset.tensors[0].shape

In [7]:
# Report how many rows the dataloader's dataset holds.
print("Number of contigs to encode:", ncontigs)

Number of contigs to encode: 1342


In [24]:
class DISENTANGLED_VAE(torch.nn.Module):
    """Variational autoencoder over concatenated abundance (depth) and TNF features.

    Encoder and decoder are stacks of Linear -> LeakyReLU -> Dropout ->
    BatchNorm1d blocks. The decoder output is split back into `nsamples`
    depth columns (softmax-normalised when nsamples > 1) and `ntnf` TNF columns.

    Inputs:
        nsamples: Number of depth columns (samples) per contig
        nhiddens: Sizes of the hidden layers [None = [512, 512]]
        nlatent: Number of latent dimensions [32]
        alpha: Weight of the TNF error versus the depth error in the loss [0.15]
        beta: Divisor applied to the KL-divergence weight [1]
        dropout: Dropout probability [0.2]
        cuda: Run on GPU. None means "use the GPU if one is available" [None]

    NOTE(review): the KL-divergence weight is 1 / (nlatent * beta), so a larger
    beta *weakens* the KL term. A beta-VAE used for disentanglement normally
    multiplies the KL term by beta - confirm which convention is intended.
    """

    def __init__(self, nsamples, nhiddens=None, nlatent=32, alpha=0.15,
                 beta=1, dropout=0.2, cuda=None):
        super(DISENTANGLED_VAE, self).__init__()

        # A fresh list per instance: a list default argument would be shared
        # between every model created without an explicit `nhiddens`.
        nhiddens = [512, 512] if nhiddens is None else list(nhiddens)

        # Fall back to the CPU instead of crashing on machines without a GPU.
        if cuda is None:
            cuda = torch.cuda.is_available()

        # Initialize simple attributes
        self.usecuda = bool(cuda)
        self.nsamples = nsamples
        self.ntnf = 103
        self.alpha = alpha
        self.beta = beta
        self.nhiddens = nhiddens
        self.nlatent = nlatent
        self.dropout = dropout

        # Initialize lists for holding hidden layers
        self.encoderlayers = torch.nn.ModuleList()
        self.encodernorms = torch.nn.ModuleList()
        self.decoderlayers = torch.nn.ModuleList()
        self.decodernorms = torch.nn.ModuleList()

        # Encoder: input -> hidden[0] -> ... -> hidden[-1]
        for nin, nout in zip([self.nsamples + self.ntnf] + self.nhiddens, self.nhiddens):
            self.encoderlayers.append(torch.nn.Linear(nin, nout))
            self.encodernorms.append(torch.nn.BatchNorm1d(nout))

        # Latent layers
        self.mu = torch.nn.Linear(self.nhiddens[-1], self.nlatent)
        self.logsigma = torch.nn.Linear(self.nhiddens[-1], self.nlatent)

        # Decoder: latent -> hidden[-1] -> ... -> hidden[0]
        for nin, nout in zip([self.nlatent] + self.nhiddens[::-1], self.nhiddens[::-1]):
            self.decoderlayers.append(torch.nn.Linear(nin, nout))
            self.decodernorms.append(torch.nn.BatchNorm1d(nout))

        # Reconstruction (output) layer
        self.outputlayer = torch.nn.Linear(self.nhiddens[0], self.nsamples + self.ntnf)

        # Activation functions
        self.relu = torch.nn.LeakyReLU()
        self.softplus = torch.nn.Softplus()
        self.dropoutlayer = torch.nn.Dropout(p=self.dropout)

        if self.usecuda:
            self.cuda()

    def _encode(self, tensor):
        """Run the encoder stack and return (mu, logsigma) of the latent distribution."""
        # Hidden layers
        for encoderlayer, encodernorm in zip(self.encoderlayers, self.encodernorms):
            tensor = encodernorm(self.dropoutlayer(self.relu(encoderlayer(tensor))))

        # Latent layers
        mu = self.mu(tensor)

        # Softplus keeps this value positive. Despite the name it is used as a
        # log-variance: `reparameterize` takes exp(logsigma / 2) as the standard
        # deviation and `calc_loss` takes exp(logsigma) as the variance.
        logsigma = self.softplus(self.logsigma(tensor))

        return mu, logsigma

    def reparameterize(self, mu, logsigma):
        """Sample a latent vector as mu + epsilon * exp(logsigma / 2), epsilon ~ N(0, 1)."""
        # The noise is drawn on the CPU and then moved next to `mu`; it is a
        # constant with respect to the parameters, so it needs no gradient.
        epsilon = torch.randn(mu.size(0), mu.size(1)).to(device=mu.device, dtype=mu.dtype)

        return mu + epsilon * torch.exp(logsigma / 2)

    def _decode(self, tensor):
        """Run the decoder stack and return (depths_out, tnf_out)."""
        for decoderlayer, decodernorm in zip(self.decoderlayers, self.decodernorms):
            tensor = decodernorm(self.dropoutlayer(self.relu(decoderlayer(tensor))))

        reconstruction = self.outputlayer(tensor)

        # Decompose reconstruction to depths and tnf signal
        depths_out = reconstruction.narrow(1, 0, self.nsamples)
        tnf_out = reconstruction.narrow(1, self.nsamples, self.ntnf)

        # If multiple samples, apply softmax
        if self.nsamples > 1:
            depths_out = torch.softmax(depths_out, dim=1)

        return depths_out, tnf_out

    def forward(self, depths, tnf):
        """Encode, sample and decode one batch.

        Output: (depths_out, tnf_out, mu, logsigma)
        """
        tensor = torch.cat((depths, tnf), 1)
        mu, logsigma = self._encode(tensor)
        latent = self.reparameterize(mu, logsigma)
        depths_out, tnf_out = self._decode(latent)

        return depths_out, tnf_out, mu, logsigma

    def calc_loss(self, depths_in, depths_out, tnf_in, tnf_out, mu, logsigma):
        """Compute the weighted training loss.

        Output: (loss, ce, sse, kld) where `loss` is the weighted sum of the
        depth error `ce`, the TNF squared error `sse` and the KL divergence `kld`.
        """
        # If multiple samples, use cross entropy, else use SSE for abundance
        if self.nsamples > 1:
            # Add 1e-9 to depths_out to avoid numerical instability.
            ce = - ((depths_out + 1e-9).log() * depths_in).sum(dim=1).mean()
            ce_weight = (1 - self.alpha) / math.log(self.nsamples)
        else:
            ce = (depths_out - depths_in).pow(2).sum(dim=1).mean()
            ce_weight = 1 - self.alpha

        sse = (tnf_out - tnf_in).pow(2).sum(dim=1).mean()
        kld = -0.5 * (1 + logsigma - mu.pow(2) - logsigma.exp()).sum(dim=1).mean()

        sse_weight = self.alpha / self.ntnf
        # Larger beta -> smaller KLD weight (see the NOTE in the class docstring).
        kld_weight = 1 / (self.nlatent * self.beta)
        loss = ce * ce_weight + sse * sse_weight + kld * kld_weight

        return loss, ce, sse, kld

    def trainepoch(self, data_loader, epoch, optimizer, batchsteps, logfile):
        """Train for one epoch.

        Inputs:
            data_loader: DataLoader yielding (depths, tnf) batches
            epoch: Zero-based epoch number
            optimizer: Optimizer over this model's parameters
            batchsteps: Container of epochs at which the batch size is doubled
            logfile: Text file handle for the progress line, or None for no logging
        Output: The data loader used in this epoch (a new one if the batch size
                was doubled), to be passed to the next epoch.
        Raises: ValueError if the data loader yields no batches.
        """
        self.train()

        epoch_loss = 0
        epoch_kldloss = 0
        epoch_sseloss = 0
        epoch_celoss = 0

        if epoch in batchsteps:
            data_loader = DataLoader(dataset=data_loader.dataset,
                                      batch_size=data_loader.batch_size * 2,
                                      shuffle=True,
                                      drop_last=True,
                                      num_workers=data_loader.num_workers,
                                      pin_memory=data_loader.pin_memory)

        nbatches = len(data_loader)
        if nbatches == 0:
            # Without this the epoch would silently train on nothing and the
            # averages below would fail with an opaque ZeroDivisionError.
            raise ValueError('Data loader yields no batches (batch size {}): '
                             'too few sequences for this batch size'.format(data_loader.batch_size))

        for depths_in, tnf_in in data_loader:
            if self.usecuda:
                depths_in = depths_in.cuda()
                tnf_in = tnf_in.cuda()

            optimizer.zero_grad()

            depths_out, tnf_out, mu, logsigma = self(depths_in, tnf_in)

            loss, ce, sse, kld = self.calc_loss(depths_in, depths_out, tnf_in,
                                                tnf_out, mu, logsigma)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_kldloss += kld.item()
            epoch_sseloss += sse.item()
            epoch_celoss += ce.item()

        # Same convention as trainmodel: logfile=None means "do not log".
        if logfile is not None:
            print('\tEpoch: {}\tLoss: {:.6f}\tCE: {:.7f}\tSSE: {:.6f}\tKLD: {:.4f}\tBatchsize: {}'.format(
                  epoch + 1,
                  epoch_loss / nbatches,
                  epoch_celoss / nbatches,
                  epoch_sseloss / nbatches,
                  epoch_kldloss / nbatches,
                  data_loader.batch_size,
                  ), file=logfile)

        return data_loader

    def encode(self, data_loader):
        """Encode a data loader to a latent representation with VAE
        Input: data_loader: DataLoader over the (depths, tnf) dataset to encode
        Output: A (n_contigs x n_latent) Numpy array of latent repr.
        """

        self.eval()

        # Unshuffled and without dropping the last batch, so that row i of the
        # output corresponds to row i of the dataset.
        new_data_loader = DataLoader(dataset=data_loader.dataset,
                                      batch_size=data_loader.batch_size,
                                      shuffle=False,
                                      drop_last=False,
                                      num_workers=1,
                                      pin_memory=data_loader.pin_memory)

        depths_array, tnf_array = data_loader.dataset.tensors
        length = len(depths_array)

        # We make a Numpy array instead of a Torch array because, if we create
        # a Torch array, then convert it to Numpy, Numpy will believe it doesn't
        # own the memory block, and array resizes will not be permitted.
        latent = np.empty((length, self.nlatent), dtype=np.float32)

        row = 0
        with torch.no_grad():
            for depths, tnf in new_data_loader:
                # Move input to GPU if requested
                if self.usecuda:
                    depths = depths.cuda()
                    tnf = tnf.cuda()

                # Evaluate; only the latent means are kept
                out_depths, out_tnf, mu, logsigma = self(depths, tnf)

                # .cpu() is a no-op for tensors that already live on the CPU
                latent[row: row + len(mu)] = mu.cpu().numpy()
                row += len(mu)

        assert row == length
        return latent

    def save(self, filehandle):
        """Saves the VAE to a path or binary opened file. Load with
        DISENTANGLED_VAE.load
        Input: Path or binary opened filehandle
        Output: None
        """
        state = {'nsamples': self.nsamples,
                 'alpha': self.alpha,
                 'beta': self.beta,
                 'dropout': self.dropout,
                 'nhiddens': self.nhiddens,
                 'nlatent': self.nlatent,
                 'state': self.state_dict(),
                }

        torch.save(state, filehandle)

    @classmethod
    def load(cls, path, cuda=False, evaluate=True):
        """Instantiates a VAE from a model file.
        Inputs:
            path: Path to model file (or binary opened file) as created by
                  DISENTANGLED_VAE.save or DISENTANGLED_VAE.trainmodel.
            cuda: If network should work on GPU [False]
            evaluate: Return network in evaluation mode [True]
        Output: VAE with weights and parameters matching the saved network.
        """

        # Forcibly load to CPU even if the model was saved as a GPU model
        dictionary = torch.load(path, map_location=lambda storage, loc: storage)

        nsamples = dictionary['nsamples']
        alpha = dictionary['alpha']
        beta = dictionary['beta']
        dropout = dictionary['dropout']
        nhiddens = dictionary['nhiddens']
        nlatent = dictionary['nlatent']
        state = dictionary['state']

        # __init__ moves the network to the GPU when cuda is true
        vae = cls(nsamples, nhiddens, nlatent, alpha, beta, dropout, cuda)
        vae.load_state_dict(state)

        if evaluate:
            vae.eval()

        return vae

    def trainmodel(self, dataloader, nepochs=500, lrate=1e-3,
                   batchsteps=(25, 75, 150, 300), logfile=None, modelfile=None):
        """Train the network in place.

        Inputs:
            dataloader: DataLoader yielding (depths, tnf) batches
            nepochs: Number of epochs to train [500]
            lrate: Learning rate of the Adam optimizer [0.001]
            batchsteps: Epochs at which the batch size doubles, or None for
                        a constant batch size [25, 75, 150, 300]
            logfile: Text file handle to log progress to, or None [None]
            modelfile: Path or binary opened file to save the model to [None]
        Output: None
        Raises: ValueError if the batch size after all doublings that fall
                within `nepochs` exceeds the number of sequences.
        """
        # Get number of features
        ncontigs, nsamples = dataloader.dataset.tensors[0].shape

        # Collect to a list first so that any iterable is accepted
        batchsteps = [] if batchsteps is None else list(batchsteps)
        batchsteps_set = set(batchsteps)

        # Only steps that are actually reached during training double the batch
        # size. Doubled loaders drop the last incomplete batch, so a batch size
        # above the number of sequences would leave nothing to train on.
        ndoublings = sum(1 for epoch in range(nepochs) if epoch in batchsteps_set)
        if ndoublings:
            last_batchsize = dataloader.batch_size * 2 ** ndoublings
            if ncontigs < last_batchsize:
                raise ValueError('Last batch size of {} exceeds the number of sequences ({}). '
                                 'Use fewer batchsteps or a smaller starting batch '
                                 'size'.format(last_batchsize, ncontigs))

        optimizer = Adam(self.parameters(), lr=lrate)

        if logfile is not None:
            print('\tNetwork properties:', file=logfile)
            print('\tCUDA:', self.usecuda, file=logfile)
            print('\tAlpha:', self.alpha, file=logfile)
            print('\tBeta:', self.beta, file=logfile)
            print('\tDropout:', self.dropout, file=logfile)
            print('\tN hidden:', ', '.join(map(str, self.nhiddens)), file=logfile)
            print('\tN latent:', self.nlatent, file=logfile)
            print('\n\tTraining properties:', file=logfile)
            print('\tN epochs:', nepochs, file=logfile)
            print('\tStarting batch size:', dataloader.batch_size, file=logfile)
            batchsteps_string = ', '.join(map(str, sorted(batchsteps))) if batchsteps_set else "None"
            print('\tBatchsteps:', batchsteps_string, file=logfile)
            print('\tLearning rate:', lrate, file=logfile)
            print('\tN sequences:', ncontigs, file=logfile)
            print('\tN samples:', nsamples, file=logfile, end='\n\n')

        # Train
        for epoch in range(nepochs):
            dataloader = self.trainepoch(dataloader, epoch, optimizer, batchsteps_set, logfile)

        # Saving stays best-effort (a failed save must not discard the trained
        # model held in memory), but the failure is reported instead of hidden.
        if modelfile is not None:
            try:
                self.save(modelfile)
            except Exception as error:
                warnings.warn('Could not save model to {!r}: {}'.format(modelfile, error))

        return None

In [53]:
# One abundance column per sample in `rpkms`; beta=500 is forwarded to the model.
dvae = DISENTANGLED_VAE(nsamples=rpkms.shape[1], beta=500)

# Make sure the model directory exists so that a fresh checkout can run this cell.
model_path = 'vamb_models/model_disentangled_vae.pt'
os.makedirs(os.path.dirname(model_path), exist_ok=True)

with open(model_path, 'wb') as modelfile:
    dvae.trainmodel(dataloader, nepochs=300, modelfile=modelfile, batchsteps=None, logfile=sys.stdout)

	Network properties:
	CUDA: True
	Alpha: 0.15
	Beta: 500
	Dropout: 0.2
	N hidden: 512, 512
	N latent: 32

	Training properties:
	N epochs: 300
	Starting batch size: 256
	Batchsteps: None
	Learning rate: 0.001
	N sequences: 1342
	N samples: 1

	Epoch: 1	Loss: 1.051648	CE: 1.0207675	SSE: 125.618611	KLD: 16.8946	Batchsize: 256
	Epoch: 2	Loss: 0.764932	CE: 0.7391028	SSE: 92.642772	KLD: 28.4512	Batchsize: 256
	Epoch: 3	Loss: 0.549276	CE: 0.5172049	SSE: 73.568968	KLD: 40.2013	Batchsize: 256
	Epoch: 4	Loss: 0.430502	CE: 0.3867863	SSE: 67.739410	KLD: 49.3478	Batchsize: 256
	Epoch: 5	Loss: 0.383236	CE: 0.3369763	SSE: 63.406117	KLD: 71.4713	Batchsize: 256
	Epoch: 6	Loss: 0.331206	CE: 0.2884325	SSE: 56.460612	KLD: 61.0259	Batchsize: 256
	Epoch: 7	Loss: 0.291122	CE: 0.2460990	SSE: 53.417184	KLD: 66.3411	Batchsize: 256
	Epoch: 8	Loss: 0.265027	CE: 0.2204272	SSE: 50.279342	KLD: 71.0589	Batchsize: 256
	Epoch: 9	Loss: 0.214108	CE: 0.1627020	SSE: 48.901179	KLD: 73.5299	Batchsize: 256
	Epoch: 10	Loss: 0

In [60]:
# Encode every contig in the dataloader to its latent representation.
latent = dvae.encode(dataloader)

print(latent.shape)

latent_output_path = os.path.join(BASE_DIR, 'example_input_data/new_simulations/camisim_outputs/vamb_outputs/latent_space.npy')

# Make sure the output directory exists before writing into it.
os.makedirs(os.path.dirname(latent_output_path), exist_ok=True)
with open(latent_output_path, 'wb') as outfile:
    np.save(outfile, latent)

(1342, 32)


In [61]:
# Names of the contigs whose mask entry is truthy, in input order
filtered_labels = [name for name, keep in zip(contignames, mask) if keep]

# Cluster the latent vectors and collect the yielded pairs into a dict
cluster_iterator = vamb.cluster.cluster(latent, labels=filtered_labels)
clusters = dict(cluster_iterator)

# Look at one entry to show what the keys and values of `clusters` are
medoid, contigs = next(iter(clusters.items()))
print('First key:', medoid, '(of type:', type(medoid), ')')
print('Type of values:', type(contigs))
print(
    'First element of value:', next(iter(contigs)),
    'of type:', type(next(iter(contigs))),
)

First key: S0C50940 (of type: <class 'numpy.str_'> )
Type of values: <class 'set'>
First element of value: S0C50940 of type: <class 'numpy.str_'>


In [62]:
# Display the latent matrix as the cell's output (NumPy abbreviates the repr).
latent

array([[ 3.7632344 ,  0.9128553 , -0.99542826, ...,  3.4221175 ,
         2.475425  , -0.8712864 ],
       [-1.8906579 , -0.74544764, -0.12259644, ..., -0.789257  ,
        -3.511978  ,  0.4764389 ],
       [ 0.21262881, -1.123332  , -0.57075024, ..., -0.7388763 ,
        -3.0588045 ,  0.7617666 ],
       ...,
       [-1.2402482 ,  1.0010657 ,  0.03230138, ..., -1.0659124 ,
        -2.618636  ,  1.291593  ],
       [ 4.1214113 , -1.9292675 ,  0.835457  , ...,  4.6487226 ,
         0.13279133,  0.4252876 ],
       [ 1.7570639 , -0.84198314,  1.0179619 , ..., -0.31993234,
        -0.5479913 ,  0.8817526 ]], dtype=float32)

In [63]:
def filterclusters(clusters, lengthof, minsize=50000):
    """Keep only the bins whose contigs add up to at least `minsize`.

    Inputs:
        clusters: Mapping of bin name (medoid) -> iterable of contig names
        lengthof: Mapping of contig name -> contig length
        minsize: Smallest total length a bin must have to be kept [50000]
    Output: A new dict with the bins that passed, in their original order.
            The values are the same objects as in `clusters` (not copies).
    Raises: KeyError if a contig has no entry in `lengthof`.
    """
    return {
        medoid: contigs
        for medoid, contigs in clusters.items()
        if sum(lengthof[contig] for contig in contigs) >= minsize
    }
        
# Look-up table from contig name to contig length
lengthof = {name: length for name, length in zip(contignames, lengths)}

# Split the clusters with separator 'C', then drop the bins that are too small
filtered_bins = filterclusters(
    vamb.vambtools.binsplit(clusters, 'C'),
    lengthof,
)

print('Number of bins before splitting and filtering:', len(clusters))
print('Number of bins after splitting and filtering:', len(filtered_bins))

Number of bins before splitting and filtering: 534
Number of bins after splitting and filtering: 10


In [64]:
contig_mapping_table = pd.read_csv(os.path.join(BASE_DIR, f"example_input_data/new_simulations/camisim_outputs/{EXAMPLE_FASTA_FILE}/contigs/gsa_mapping.tsv"), sep='\t')

contig_mapping_output_path = os.path.join(BASE_DIR, 'example_input_data/new_simulations/camisim_outputs/vamb_outputs/encoding_mapping.tsv')

# Make sure the output directory exists before writing into it.
os.makedirs(os.path.dirname(contig_mapping_output_path), exist_ok=True)

# Keep only the rows of encoded contigs and write them in the order of `contignames`.
# set_index replaces the old row index, so no reset_index/drop step is needed first.
(
    contig_mapping_table[contig_mapping_table['#anonymous_contig_id'].isin(contignames)]
    .set_index('#anonymous_contig_id')
    .reindex(contignames)
    .to_csv(contig_mapping_output_path, sep='\t')
)

In [65]:
vamb_outputs_base = os.path.join(BASE_DIR, 'example_input_data/new_simulations/camisim_outputs/vamb_outputs')

# makedirs also creates missing parent directories and is safe to re-run
os.makedirs(vamb_outputs_base, exist_ok=True)

# This writes a .tsv file with the clusters and corresponding sequences
with open(os.path.join(vamb_outputs_base, 'clusters_dvae.tsv'), 'w') as file:
    vamb.vambtools.write_clusters(file, filtered_bins)

# Only keep contigs in any filtered bin in memory.
# set().union(...) also works when no bin passed the filter, whereas
# set.union(*[]) raises a TypeError.
keptcontigs = set().union(*filtered_bins.values())

# Decompress fasta.gz if present and not decompressed yet. The .gz file is kept.
# Done in Python rather than through the shell so that paths containing spaces
# work and no external gzip binary is required.
fasta_path = os.path.join(BASE_DIR, f"example_input_data/new_simulations/camisim_outputs/{EXAMPLE_FASTA_FILE}/contigs/anonymous_gsa.fasta.gz")
fasta_plain_path = fasta_path[:-len('.gz')]
if os.path.exists(fasta_path) and not os.path.exists(fasta_plain_path):
    with gzip.open(fasta_path, 'rb') as compressed, open(fasta_plain_path, 'wb') as plain:
        shutil.copyfileobj(compressed, plain)

with open(fasta_plain_path, 'rb') as file:
    fastadict = vamb.vambtools.loadfasta(file, keep=keptcontigs)

bindir = os.path.join(vamb_outputs_base, 'dvae_bins')
os.makedirs(bindir, exist_ok=True)

# Remove the files left in the bin directory by a previous run
files = glob.glob(os.path.join(bindir, '*'))
for f in files:
    if os.path.isfile(f):
        os.remove(f)

vamb.vambtools.write_bins(bindir, filtered_bins, fastadict, maxbins=500)