In [12]:
# based on clonealign framework, re-write with pyro
import os
from collections import defaultdict
import torch
import numpy as np
import pandas as pd
import scipy.stats
from torch.distributions import constraints
from torch.nn import Softplus
from matplotlib import pyplot
%matplotlib inline

import pyro
import pyro.distributions as dist
from pyro.distributions.util import broadcast_shape
from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.ops.indexing import Vindex

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.5.1')
pyro.enable_validation(True)

In [2]:
# input data

expr_csv = pd.read_csv('data/SPECTRUM-OV-022_expr_clonealign_input.csv', header = 0, index_col=0)
cnv_csv = pd.read_csv('data/SPECTRUM-OV-022_cnv_clonealign_input.csv', header = 0, index_col=0)

# cast cnv greater than 6
cnv = torch.tensor(cnv_csv.values, dtype=torch.float)
cnv = torch.transpose(cnv, 0, 1)

cnv[cnv > 6] = 6

expr = torch.tensor(expr_csv.values, dtype = torch.float)
expr = torch.transpose(expr, 0, 1)

In [13]:
# input data: cnv, expr
# cnv: clone_count * gene_count
# expr: cell_count * gene_count

# use the 
def inverse_softplus(x):
    return x + torch.log(-torch.expm1(-x))

@config_enumerate
def clonealign_pyro(cnv, expr):
    num_of_clones = len(cnv)
    num_of_cells = len(expr)
    num_of_genes = len(expr[0])

    softplus = Softplus()

    # initialize per_copy_expr using the data (This typically speeds up convergence)
    expr = expr * 2000 / torch.reshape(torch.sum(expr, 1), (num_of_cells, 1))
    per_copy_expr_guess = torch.mean(expr, 0)

    # calculate copy number mean
    copy_number_mean = torch.mean(cnv, 0)

    # draw chi from gamma
    chi = pyro.sample('chi', dist.Gamma(torch.ones(6) * 2, torch.ones(6)).to_event(1))

    with pyro.plate('gene', num_of_genes):
        # draw per_copy_expr from softplus-transformed Normal distribution
        per_copy_expr = pyro.sample('per_copy_expr',
                                    dist.Normal(inverse_softplus(per_copy_expr_guess), torch.ones(num_of_genes)))
        

        per_copy_expr = softplus(per_copy_expr)
        
        # instead of softplus-transformed normal, use negative binomial instead for per_copy_expr
        # per_copy_expr = pyro.sample('per_copy_expr', dist.NegativeBinomial())

        # draw w from Normal
        w = pyro.sample('w', dist.Normal(torch.zeros(6), torch.sqrt(chi)).to_event(1))

        # sample the gene_type_score from uniform distribution.
        # the score reflects how much the copy number influence expression.
        gene_type_score = pyro.sample('gene_type_score', dist.Dirichlet(torch.ones(2) * 0.1))

    #gene_type = pyro.sample('gene_type', dist.Bernoulli(probs = gene_type_score))

    with pyro.plate('cell', num_of_cells):
        # draw clone_assign_prob from Dir
        clone_assign_prob = pyro.sample('clone_assign_prob', dist.Dirichlet(torch.ones(num_of_clones) * 0.1))
        # draw clone_assign from Cat
        clone_assign = pyro.sample('clone_assign', dist.Categorical(clone_assign_prob))

        # draw psi from Normal
        psi = pyro.sample('psi', dist.Normal(torch.zeros(6), torch.ones(6)).to_event(1))        

        expected_expr = per_copy_expr * (
                    Vindex(cnv)[clone_assign] * gene_type_score[:, 0] + copy_number_mean * gene_type_score[:, 1]) * torch.exp(
            torch.matmul(psi, torch.transpose(w, 0, 1)))

        print(expected_expr.shape)
        print(expr.shape)
        # draw expr from Multinomial
        pyro.sample('obs', dist.Multinomial(total_count = 2000, probs=expected_expr, validate_args=False), obs=expr)
        
        

In [14]:
# initialize Adam optimizer
optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})

# TraceEnum_ELBO will marginalize out the assignments of datapoints to clusters
elbo = TraceEnum_ELBO(max_plate_nesting=1)

pyro.clear_param_store()

# AutoGuide
global_guide = AutoDelta(poutine.block(clonealign_pyro, \
                                       expose = ['chi', 'per_copy_expr', 'w', 'k', 'gene_type_score', \
                                                 'clone_assign_prob', 'psi']))
# put together SVI object
svi = SVI(clonealign_pyro, global_guide, optim, loss=elbo)

In [15]:
guide_trace = poutine.trace(global_guide).get_trace(cnv, expr)
model_trace = poutine.trace(poutine.replay(poutine.enum(clonealign_pyro, -3), guide_trace)).get_trace(cnv, expr)
model_trace.compute_log_prob()

with open('test.txt', 'w') as f:
    f.write(model_trace.format_shapes())

torch.Size([5139, 4059])
torch.Size([5139, 4059])


In [16]:
gradient_norms = defaultdict(list)
svi.loss(clonealign_pyro, global_guide, cnv, expr)  # Initializes param store.

torch.Size([5139, 4059])


12474650.0

In [159]:
losses = []
max_iter = 200
rel_tol = 1e-5
print('Start Inference.')
for i in range(max_iter if not smoke_test else 2):
    loss = svi.step(cnv, expr)
    
    if i >= 1:
        loss_diff = abs((losses[-1] - loss)/losses[-1])
        if loss_diff < rel_tol:
            print('ELBO converged at iteration ' + str(i))
            break
    
    losses.append(loss)
    
    print('.' if i % 200 else '\n', end='')

Start Inference.

.......................................................................................................................................................................................................

In [160]:
map_estimates = global_guide(cnv, expr)

clone_assign_prob = map_estimates['clone_assign_prob']
gene_type_score = map_estimates['gene_type_score']

per_copy_expr = map_estimates['per_copy_expr']
psi = map_estimates['psi']
chi = map_estimates['chi']
w = map_estimates['w']

In [161]:
clone_assign_prob

tensor([[6.6291e-19, 6.6291e-19, 1.0000e+00],
        [1.0000e+00, 6.6288e-19, 6.6281e-19],
        [1.0000e+00, 6.6283e-19, 6.6266e-19],
        ...,
        [8.1535e-19, 5.6970e-19, 1.0000e+00],
        [6.6291e-19, 1.0000e+00, 6.6291e-19],
        [1.0000e+00, 6.6291e-19, 6.6291e-19]], grad_fn=<ExpandBackward>)

In [162]:
def clonealign_pyro_simulation(cnv, expr, per_copy_expr, psi, chi, w):
    num_of_clones = len(cnv)
    num_of_cells = len(expr)
    num_of_genes = len(expr[0])

    softplus = Softplus()
    
    # calculate copy number mean
    copy_number_mean = torch.mean(cnv, 0)
    
    per_copy_expr = softplus(per_copy_expr)


    with pyro.plate('gene', num_of_genes):
        # sample the gene_type_score from uniform distribution.
        # the score reflects how much the copy number influence expression.
        gene_type_score = pyro.sample('gene_type_score', dist.Dirichlet(torch.ones(2) * 0.1))

    #gene_type = pyro.sample('gene_type', dist.Bernoulli(probs = gene_type_score))

    with pyro.plate('cell', num_of_cells):
        # draw clone_assign_prob from Dir
        clone_assign_prob = pyro.sample('clone_assign_prob', dist.Dirichlet(torch.ones(num_of_clones)))
        # draw clone_assign from Cat
        clone_assign = pyro.sample('clone_assign', dist.Categorical(clone_assign_prob))
        

        expected_expr = per_copy_expr * (
                    Vindex(cnv)[clone_assign] * gene_type_score[:, 0] + copy_number_mean * gene_type_score[:, 1]) * torch.exp(
            torch.matmul(psi, torch.transpose(w, 0, 1)))

        # draw expr from Multinomial
        expr_simulated = pyro.sample('obs', dist.Multinomial(total_count = 2000, probs=expected_expr, validate_args=False))
    
    return expr_simulated, gene_type_score, clone_assign_prob, clone_assign

In [163]:
# simulate expr and gene_type_score data

num_of_datasets = 1

for i in range(num_of_datasets):
    expr_simulated, gene_type_score_simuated, clone_assign_prob_simulated, clone_assign_simulated = clonealign_pyro_simulation(cnv, expr, per_copy_expr, psi, chi, w)
    
    expr_simulated = torch.transpose(expr_simulated, 0, 1)
    
    expr_simulated_dataframe = pd.DataFrame(expr_simulated.data.numpy())
    gene_type_score_simuated_dataframe = pd.DataFrame(gene_type_score_simuated.data.numpy())
    clone_assign_prob_simulated_dataframe = pd.DataFrame(clone_assign_prob_simulated.data.numpy())
    clone_assign_simulated_dataframe = pd.DataFrame(clone_assign_simulated.data.numpy())
    
    # rename
    cell_name = {i:c for i, c in enumerate(expr_csv.columns)}
    gene_name = {i:c for i, c in enumerate(expr_csv.index)}
    clone_name = {i:c for i, c in enumerate(cnv_csv.columns)}
    

    expr_simulated_dataframe.rename(index = gene_name, inplace = True)
    expr_simulated_dataframe.rename(columns = cell_name, inplace = True)
    
    gene_type_score_simuated_dataframe.rename(index = gene_name, inplace = True)
    clone_assign_simulated_dataframe.rename(index = cell_name, inplace = True)
    
    clone_assign_prob_simulated_dataframe.rename(index = cell_name, inplace = True)
    clone_assign_prob_simulated_dataframe.rename(index = clone_name, inplace = True)
    
    expr_simulated_dataframe.to_csv("data/expr_simulated_" + str(i) + ".csv")
    gene_type_score_simuated_dataframe.to_csv("data/gene_type_score_simulated_" + str(i) + ".csv")
    clone_assign_prob_simulated_dataframe.to_csv("data/clone_assign_prob_simulated_" + str(i) + ".csv")
    
    clone_assign_simulated_dataframe.to_csv("data/clone_assign_simulated_" + str(i) + ".csv")



In [2]:
import torch


In [3]:
a = torch.ones(10)
a = torch.reshape(a, (10, 1))
b = torch.zeros(10)
b = torch.reshape(b, (10, 1))

In [5]:
b

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]])

In [6]:
c = torch.cat((a, b), 1)

In [11]:
import random
random_indices = random.sample(range(10), 5)
random_indices

[4, 7, 5, 8, 0]

In [13]:
a = torch.ones(10)
a
a[[1, 2, 3, 4]]

tensor([1., 1., 1., 1.])

In [12]:
c[:, [1, 1, 0]]

tensor([[0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.]])

In [3]:
expr_csv = pd.read_csv('data/SPECTRUM-OV-022_expr_clonealign_input.csv', header = 0, index_col=0)

In [6]:
expr_csv.index[[1, 4, 7]]

Index(['HCRTR1', 'SPOCD1', 'KHDRBS1'], dtype='object')