In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from tqdm import tqdm as tqdm

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO, Predictive, NUTS, MCMC
from pyro.infer.autoguide import AutoDelta, AutoNormal
from pyro.ops.indexing import Vindex
from pyro.optim import Adam

## Simulated sequence with varying copy number

In [None]:
# Seeded NumPy random generator used by the simulation cell below, so the
# synthetic sequence is reproducible.
rng = np.random.RandomState(42)

In [None]:
# Shared configuration dictionary read by the simulation and by both the
# PyTorch and the Pyro model.
args = {}
args['hidden_dim'] = 2  # number of hidden states
args['n'] = 20  # number of Binomial trials per observation
args['x_states'] = np.array([1,2])  # value x attached to each hidden state; the emission probability is x/(x+1)

In [None]:
# Simulate a two-state hidden sequence and its Binomial observations.
# probs_x = [P(keep current state), P(switch to the other state)] at each step.
probs_x = [0.99, 0.01]

len_seq = 5000 
x_seq = []  # hidden state value at every step
y_seq = []  # Binomial observation at every step
x_states = args['x_states'] 
x_state = 1
for i in range(len_seq):
    # draw == 0 -> stay, draw == 1 -> switch
    draw = rng.choice(np.arange(args['hidden_dim']), p=probs_x)
    if draw == 0:
        # NOTE: boolean-mask indexing returns a length-1 array, not a scalar.
        x_state = x_states[x_states == x_state]
    elif draw == 1:
        # The "other" state, again as a length-1 array.
        x_state = x_states[x_states != x_state]
        
    x_seq.append(x_state)
    # Emission ~ Binomial(n, x/(x+1)). Because x_state is a length-1 array the
    # draw is one too, so torch.tensor(y_seq) has shape (len_seq, 1); later
    # cells transpose it to (1, len_seq) and rely on that layout.
    y_seq.append(rng.binomial(args['n'], x_state/(x_state+1)))
    


In [None]:
# Red line: expected count n * x/(x+1) of the true hidden state at each step.
# Dots: the simulated Binomial observations.
plt.figure(figsize=(12,2))
plt.plot(np.array(x_seq)/(np.array(x_seq) + 1) * args['n'], color='red')
plt.scatter(np.arange(len_seq), np.array(y_seq), s=2)
#plt.ylim([0,1])

## PyTorch implementation

In [None]:
def log_domain_matmul(log_A, log_B):
    """
    Matrix product carried out entirely in log space.

    log_A : (rows, inner) tensor of log-values
    log_B : (inner, cols) tensor of log-values
    returns : (rows, cols) tensor with
              out[i, j] = logsumexp_k( log_A[i, k] + log_B[k, j] )

    This is the log-domain counterpart of the ordinary product
    out[i, j] = sum_k A[i, k] * B[k, j].
    """
    rows, inner = log_A.shape[0], log_A.shape[1]
    cols = log_B.shape[1]

    # Align both operands on a shared (rows, inner, cols) grid so that the
    # addition pairs every log_A[i, k] with every log_B[k, j].
    pairwise = log_A.reshape(rows, inner, 1) + log_B.reshape(1, inner, cols)

    # Reduce over the shared inner index k.
    return pairwise.logsumexp(dim=1)

def maxmul(log_A, log_B):
    """
    Max-plus analogue of a matrix product.

    log_A : (m, n) tensor
    log_B : (n, p) tensor
    returns : (values, indices), both (m, p), where
              values[i, j]  = max_k   ( log_A[i, k] + log_B[k, j] )
              indices[i, j] = argmax_k( log_A[i, k] + log_B[k, j] )
    """
    # Broadcasting (m, n, 1) against (1, n, p) yields the same (m, n, p) grid
    # of pairwise sums as explicitly stacking copies of each operand.
    pairwise = log_A.unsqueeze(2) + log_B.unsqueeze(0)

    # Maximise over the shared index k, keeping the winning k as well.
    best_val, best_idx = pairwise.max(dim=1)
    return best_val, best_idx

In [None]:
# batch, state, 

class TransitionModel(torch.nn.Module):
    """
    Learnable HMM transition matrix, stored unnormalised.

    The matrix is normalised with a log-softmax over dim 0, i.e. every column
    is a probability distribution over the rows.
    """

    def __init__(self, N_states):
        super().__init__()
        self.N_states = N_states
        self.unnormalized_transition_matrix = torch.nn.Parameter(
            torch.randn(N_states, N_states)
        )

    def _log_transition_matrix(self):
        # Column-wise log-softmax of the raw parameter.
        return torch.nn.functional.log_softmax(
            self.unnormalized_transition_matrix, dim=0
        )

    def forward(self, log_alpha):
        """
        Apply one transition step in the log domain (sum-product).

        log_alpha : (batch, state) tensor
        returns   : (batch, state) tensor
        """
        log_A = self._log_transition_matrix()
        # log_domain_matmul expects the state axis first, hence the transposes.
        propagated = log_domain_matmul(log_A, log_alpha.t())
        return propagated.t()

    def maxmul(self, log_alpha):
        """
        Apply one transition step in the log domain (max-product).

        returns : (max values, argmax indices), each (batch, state)
        """
        log_A = self._log_transition_matrix()
        best_val, best_idx = maxmul(log_A, log_alpha.t())
        return best_val.t(), best_idx.t()

class EmissionModel(torch.nn.Module):
    """
    Fixed (parameter-free) Binomial emission model.

    Hidden state k carries the value ``x_states[k]`` and emits
    ``Binomial(n, x_states[k] / (1 + x_states[k]))``.
    """

    def __init__(self, N_states, x_states, n):
        """
        N_states : number of hidden states
        x_states : array-like of length N_states with the value of each state
        n        : number of Binomial trials
        """
        super(EmissionModel, self).__init__()
        self.N_states = N_states
        self.x_states = x_states
        self.n = n

    def forward(self, x_t):
        """
        Log-probability of each observation under each hidden state.

        x_t : tensor of shape (batch,) holding the observed counts
        out : tensor of shape (batch, state), in the log domain
        """
        batch_size = x_t.shape[0]
        # Build the state values on the same device as the observations;
        # creating them on the CPU breaks as soon as x_t lives on the GPU.
        xs = torch.as_tensor(self.x_states, device=x_t.device).reshape(1, self.N_states)
        probs = xs / (1 + xs)
        # (1, state) probabilities broadcast against (batch, 1) observations.
        logprob = torch.distributions.Binomial(self.n, probs).log_prob(
            x_t.reshape((batch_size, 1))
        )
        return logprob

class HMM(torch.nn.Module):
    """
    Hidden Markov model with a learnable transition matrix and state prior
    and a fixed Binomial emission model.

    The transition matrix is normalised over dim 0, so column ``j`` is the
    distribution of the next state given current state ``j``.
    """

    def __init__(self, args):
        """
        args : dict with keys 'hidden_dim' (number of states), 'x_states'
               (value of each state) and 'n' (number of Binomial trials).
        """
        super(HMM, self).__init__()
        self.N_states = args['hidden_dim']
        self.x_states = args['x_states']
        self.n = args['n']

        # A
        self.transition_model = TransitionModel(self.N_states)

        # b(x_t)
        self.emission_model = EmissionModel(self.N_states, self.x_states, self.n)

        # pi
        self.unnormalized_state_priors = torch.nn.Parameter(torch.randn(self.N_states))

        # use the GPU
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda: self.cuda()

    def sample(self, T=10):
        """
        Draw one sequence of length T from the model.

        Returns (x, z): the list of observations and the list of hidden
        state indices, both of length T.
        """
        state_priors = torch.nn.functional.softmax(self.unnormalized_state_priors, dim=0)
        transition_matrix = torch.nn.functional.softmax(self.transition_model.unnormalized_transition_matrix, dim=0)

        # sample initial state
        z_t = torch.distributions.categorical.Categorical(state_priors).sample().item()
        z = []; x = []
        z.append(z_t)
        for t in range(0,T):
            # sample emission
            x_t = torch.distributions.Binomial(self.n, self.x_states[z_t] / (1 + self.x_states[z_t])).sample().item()
            x.append(x_t)

            # sample transition (column z_t is the next-state distribution)
            z_t = torch.distributions.categorical.Categorical(transition_matrix[:,z_t]).sample().item()
            if t < T-1: z.append(z_t)

        return x, z

    def forward(self, x, T):
        """
        x : IntTensor of shape (batch size, T_max)
        T : IntTensor of shape (batch size)

        Compute log p(x) for each example in the batch (forward algorithm).
        T = length of each example
        Returns a tensor of shape (batch size, 1).
        """
        if self.is_cuda:
            x = x.cuda()
            T = T.cuda()

        batch_size = x.shape[0]; T_max = x.shape[1]
        log_state_priors = torch.nn.functional.log_softmax(self.unnormalized_state_priors, dim=0)
        log_alpha = torch.zeros(batch_size, T_max, self.N_states)
        if self.is_cuda: log_alpha = log_alpha.cuda()

        log_alpha[:, 0, :] = self.emission_model(x[:,0]) + log_state_priors
        for t in range(1, T_max):
            log_alpha[:, t, :] = self.emission_model(x[:,t]) + self.transition_model(log_alpha[:, t-1, :])

        # Select the sum for the final timestep (each x may have different length).
        log_sums = log_alpha.logsumexp(dim=2)
        log_probs = torch.gather(log_sums, 1, T.view(-1,1) - 1)
        return log_probs

    def viterbi(self, x, T):
        """
        x : IntTensor of shape (batch size, T_max)
        T : IntTensor of shape (batch size)

        Find argmax_z log p(x, z) for each x in the batch.

        Returns (z_star, best_path_scores): z_star is a list with one decoded
        path (list of state indices, length T[i]) per batch element, and
        best_path_scores is a (batch size, 1) tensor of the paths' log scores.
        """
        if self.is_cuda:
            x = x.cuda()
            T = T.cuda()

        batch_size = x.shape[0]; T_max = x.shape[1]
        log_state_priors = torch.nn.functional.log_softmax(self.unnormalized_state_priors, dim=0)
        log_delta = torch.zeros(batch_size, T_max, self.N_states).float()
        psi = torch.zeros(batch_size, T_max, self.N_states).long()
        if self.is_cuda:
            log_delta = log_delta.cuda()
            psi = psi.cuda()

        log_delta[:, 0, :] = self.emission_model(x[:,0]) + log_state_priors
        for t in range(1, T_max):
            max_val, argmax_val = self.transition_model.maxmul(log_delta[:, t-1, :])
            log_delta[:, t, :] = self.emission_model(x[:,t]) + max_val
            psi[:, t, :] = argmax_val

        # Get the log probability of the best path
        log_max = log_delta.max(dim=2)[0]
        best_path_scores = torch.gather(log_max, 1, T.view(-1,1) - 1)

        # Back-tracking is awkward to parallelise across the batch, so it is
        # done separately for each example. Both the inner loop and the append
        # must sit inside the per-example loop, otherwise only the last
        # example is decoded.
        z_star = []
        for i in range(0, batch_size):
            last_t = int(T[i]) - 1
            # Path is built backwards from the best final state ...
            path = [log_delta[i, last_t, :].max(dim=0)[1].item()]
            for t in range(last_t, 0, -1):
                path.append(psi[i, t, path[-1]].item())
            # ... and flipped into chronological order at the end.
            path.reverse()
            z_star.append(path)

        return z_star, best_path_scores # return both the best path and its log probability

    def EM(self):
        """Placeholder for an EM (Baum-Welch) fit; not implemented."""
        raise NotImplementedError

    def amaeba_optimisation_interface(self, transition_matrix, x, T):
        """
        Objective for a derivative-free optimiser (e.g. Nelder-Mead).

        transition_matrix : flat NumPy array with N_states * N_states
                            unnormalised transition-matrix entries
        x, T              : same as in ``forward``

        Overwrites the model's transition parameter with the given values,
        prints the log-likelihood of the first batch row and returns its
        negative.
        """
        mat = transition_matrix.reshape((self.N_states, self.N_states)).astype('float32')
        # Keep the replacement parameter on the device the model already uses.
        device = self.transition_model.unnormalized_transition_matrix.device
        self.transition_model.unnormalized_transition_matrix = torch.nn.Parameter(
            torch.tensor(mat, device=device)
        )
        # No gradients are needed for a derivative-free search.
        with torch.no_grad():
            loss = self.forward(x, T).cpu().data.numpy()
        print(loss[0])
        return -loss[0]

In [None]:
from scipy.optimize import minimize

# Fit the transition matrix with the derivative-free Nelder-Mead search
# instead of gradient descent.
model = HMM(args)
# Switch off gradient tracking on the model's initial parameters.
for p in model.parameters():
    p.requires_grad = False

# One-sequence batch: observations transposed to (1, len_seq), plus its length.
dataset = (torch.tensor(y_seq).T, torch.tensor([len(y_seq)]))

# The objective overwrites the model's transition parameter with the four
# candidate (unnormalised) entries and returns the negative log-likelihood.
# The second argument is the initial guess; the search stops after 50 iterations.
output = minimize(model.amaeba_optimisation_interface,
         np.array([0.12,-0.51,0.2,-0.2]),
         args=tuple(dataset),
         method='Nelder-Mead',
         options={'maxiter': 50})

In [None]:
# Normalise the optimised entries the same way the model does (softmax over
# dim 0) to read them as transition probabilities.
torch.nn.functional.softmax(torch.tensor(output.x.reshape((2,2))), dim=0)

In [None]:
from tqdm import tqdm 

class Trainer:
    """Full-batch maximum-likelihood trainer that drives a model with Adam."""

    def __init__(self, model, lr):
        """
        model : torch module whose call ``model(x, T)`` returns log-probabilities
        lr    : Adam learning rate
        """
        self.model = model
        self.lr = lr
        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=self.lr,
            weight_decay=0.00001,
        )

    def train(self, dataset, n_steps = 1000):
        """
        Run ``n_steps`` optimisation steps on the whole dataset.

        dataset : tuple (x, T) passed straight to the model
        returns : list with the loss (mean negative log-probability) of every step
        """
        self.model.train()
        observations, lengths = dataset
        history = []
        for _ in tqdm(range(n_steps)):
            nll = -self.model(observations, lengths).mean()
            self.optimizer.zero_grad()
            nll.backward()
            self.optimizer.step()
            history.append(nll.item())
        return history

In [None]:
# Fresh, randomly initialised model for the gradient-based fit.
model = HMM(args)
# One-sequence batch: observations transposed to (1, len_seq), plus its length.
dataset = (torch.tensor(y_seq).T, torch.tensor([len(y_seq)]))

trainer = Trainer(model, lr=0.005)

In [None]:
# Run 1500 Adam steps; loss_list holds the loss recorded at every step.
loss_list = trainer.train(dataset, n_steps=1500)

In [None]:
# Training curve: loss per optimisation step.
plt.plot(loss_list)

In [None]:
# Viterbi decoding with the trained model. Despite its name, `random_states`
# holds the decoded most-likely state paths; `logprob` holds their log scores.
random_states, logprob = model.viterbi(*dataset)

In [None]:
# The batch holds a single sequence, so keep the first (only) decoded path.
pred = np.array(random_states)[0]

In [None]:
# Compare decoded states with the truth on the scale of expected counts.
# `pred` holds state indices (0/1); adding 1 maps them onto the state values
# [1, 2] stored in args['x_states'].
plt.figure(figsize=(12,2))
# Orange: expected count implied by the decoded state.
plt.scatter(np.arange(len_seq), np.array(pred + 1)/(np.array(pred + 1) + 1) * args['n'], color='orange')
# Red: expected count of the true hidden state; small dots: observations.
plt.plot(np.array(x_seq)/(np.array(x_seq) + 1) * args['n'], color='red')
plt.scatter(np.arange(len_seq), np.array(y_seq), s=2)
#plt.ylim([0,1])

In [None]:
# Learned transition probabilities (softmax over dim 0, so each column sums to 1).
torch.nn.functional.softmax(model.transition_model.unnormalized_transition_matrix, dim=0).cpu().data.numpy()

### Pyro implementation
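(The cells in this section are listed in reverse execution order: read and run them from the last cell upward.)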

In [None]:
# Expected counts implied by the inferred states (orange) against the
# observations (small dots). `pred` holds state indices; adding 1 maps them
# onto the state values [1, 2] stored in args['x_states'].
plt.figure(figsize=(12,2))
#plt.plot(np.array(x_seq)/(np.array(x_seq) + 1) * args['n'], color='red')
plt.scatter(np.arange(len_seq), np.array(pred + 1)/(np.array(pred + 1) + 1) * args['n'], color='orange')
plt.scatter(np.arange(len_seq), np.array(y_seq), s=2)
#plt.ylim([0,1])

In [None]:
# Collect the inferred hidden state of every time step. The model creates one
# site "x_t" per step for t in range(length), so iterate over `length` rather
# than a hard-coded count; the plot of `pred` pairs it with np.arange(len_seq).
pred = np.array([inf_trace[f'x_{i}']['value'].numpy() for i in range(length)])

In [None]:
# Sample the global (continuous) sites once from the trained guide and pin the
# model to those values, leaving only the discrete states to be inferred.
guide_trace = poutine.trace(guide).get_trace(data, length, args)
trained_model = poutine.replay(model, trace=guide_trace)


def classifier(data, temperature=0):
    """
    Infer the discrete hidden states of `data` with the trained model.

    `temperature` is forwarded to `infer_discrete`. Returns the node
    dictionary of the resulting trace, keyed by site name.
    """
    decoder = infer_discrete(
        trained_model,
        temperature=temperature,
        first_available_dim=-2,  # dim -1 is taken by the data plate
    )
    decoded_trace = poutine.trace(decoder).get_trace(data, length, args)
    return decoded_trace.nodes


inf_trace = classifier(data)

In [None]:
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete


In [None]:
# Draw 20 samples from the model with its latent sites taken from the trained guide.
trace = Predictive(model, guide=guide, num_samples=20)(data, length, args)


In [None]:
# Average the sampled transition matrices over the sample axis.
trace['probs_x'].mean(0)

In [None]:
# SVI loss per optimisation step.
plt.plot(loss_list)

In [None]:
# Observations as a float32 tensor. Each element of y_seq is a length-1 array,
# so the transpose yields shape (1, len_seq): one sequence, time on the last axis.
data = torch.tensor(np.array(y_seq).astype('float32')).T
length = data.shape[-1]  # number of time steps

n_steps = 1000
loss_list = []  # loss returned by every SVI step
for step in tqdm(range(n_steps)):
    loss = svi.step(data, length, args=args)
    loss_list.append(loss)

In [None]:
# Reproducible, clean start: fix the seed and drop any previously learned parameters.
pyro.set_rng_seed(42)
pyro.clear_param_store()
# The guide only covers sites whose name starts with "probs_" (the transition
# matrix); the discrete x_t sites are hidden from it and are marked for
# parallel enumeration in the model instead.
guide = AutoNormal(
    poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_"))
)
optim = Adam({"lr": 0.005})
# The model uses a single plate dimension (dim=-1).
elbo = TraceEnum_ELBO(
            max_plate_nesting=1,
            )

svi = SVI(model, guide, optim, elbo)

In [None]:
def model(sequence, length, args, batch_size=None, include_prior=True):
    """
    Pyro HMM with a Dirichlet prior on the transition matrix and Binomial
    emissions.

    sequence      : tensor of shape (num_sequences, length) with observed counts
    length        : number of time steps to iterate over
    args          : dict with 'hidden_dim', 'x_states' and 'n'
    batch_size    : unused; kept for interface compatibility
    include_prior : when False the prior term on `probs_x` is masked out

    Sites: "probs_x" (transition matrix, one Dirichlet row per current state),
    "x_t" (hidden state, enumerated in parallel) and "y_t" (observed count)
    for every t in range(length).
    """
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args['hidden_dim']) + 0.1).to_event(1),
        )

    # Loop-invariant tensors, built once instead of at every time step.
    state_values = torch.tensor(args['x_states'])
    total_count = torch.tensor(np.array(args['n']).astype('float32'))

    # The observed site at step t is sequence[:, t], i.e. one value per
    # sequence, so the plate must have one slot per sequence. Sizing it by
    # `length` would broadcast every emission term to `length` copies and
    # over-count the likelihood by that factor.
    majAll_plate = pyro.plate("majAll", sequence.shape[0], dim=-1)

    x = 0
    for t in pyro.markov(range(length)):
        # x is overwritten at every step; only the previous state is needed
        # to sample the next one.
        x = pyro.sample(
            "x_{}".format(t),
            dist.Categorical(probs_x[x]),
            infer={"enumerate": "parallel"},
        )
        with majAll_plate:
            # Value of the current hidden state; emission probability is p/(p+1).
            p = Vindex(state_values)[x]
            pyro.sample(
                "y_{}".format(t),
                dist.Binomial(total_count, p/(p + 1)),
                obs=sequence[:,t],
            )
