In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import scipy.stats as stats
from scipy.optimize import minimize

import autograd.numpy as auto_np
from autograd import hessian

from utils import is_positive_definite, flat_sigma_to_sigma_matrix, sigma_matrix_to_flat_sigma
from init import init


def likelyhood_ys(w, x, s, z, gamma, beta_0, beta_1, sigma_matrix):
    """Joint likelihood of the observed outcomes and treatment indicators.

    Each observation contributes the normal density of its observed outcome
    multiplied by the probability of its observed treatment status given that
    outcome.

    Parameters
    ----------
    w, x : indexable by observation
        ``w[i]`` enters the selection index ``w[i] . gamma``; ``x[i]`` enters
        the outcome means ``x[i] . beta_0`` and ``x[i] . beta_1``.
    s : sequence
        Treatment indicators. ``0`` selects the untreated branch, any other
        value the treated branch.
    z : indexable
        ``z[0][i]`` is read for untreated observations, ``z[1][i]`` for
        treated ones.
    gamma, beta_0, beta_1 : array-like
        Coefficient vectors of the selection and the two outcome equations.
    sigma_matrix : 3x3 indexable
        Index 0 refers to the latent selection variable, 1 to the untreated
        outcome, 2 to the treated outcome. Diagonal entries are used as
        variances. The conditional variance ``1 - cov**2 / var`` assumes
        ``sigma_matrix[0][0] == 1`` -- TODO confirm against ``init``.

    Returns
    -------
    float
        Product of the per-observation contributions. This is a raw product
        of densities, so it can underflow to 0.0 when there are many
        observations.
    """
    prod = 1
    # Iterate over observations; ``z`` only has two rows (one per outcome).
    for i in range(len(s)):
        treated = s[i] != 0
        if treated:
            outcome, beta, k = z[1][i], beta_1, 2
        else:
            outcome, beta, k = z[0][i], beta_0, 1

        variance = sigma_matrix[k][k]
        covariance = sigma_matrix[0][k]
        outcome_mean = np.dot(x[i], beta)

        # ``scale`` is a standard deviation, hence the square root.
        fn = stats.norm.pdf(outcome, loc=outcome_mean, scale=variance ** 0.5)

        # Distribution of the latent selection variable given the outcome.
        cond_mean = np.dot(w[i], gamma) + (covariance / variance) * (outcome - outcome_mean)
        cond_sd = (1 - covariance ** 2 / variance) ** 0.5
        threshold = -cond_mean / cond_sd

        if treated:
            # Treated units have a positive latent variable: P(s* > 0 | z).
            selection_prob = stats.norm.sf(threshold)
        else:
            # Untreated units have a non-positive one: P(s* <= 0 | z).
            selection_prob = stats.norm.cdf(threshold)

        prod *= fn * selection_prob

    return prod

def g_sigma(w, s, x, z, beta_0, beta_1, gamma, sigma_matrix, g_0, G_0):
    """Unnormalised target density of the covariance matrix.

    The value is the product of the data likelihood given ``sigma_matrix``,
    a multivariate normal density (mean ``g_0``, covariance ``G_0``)
    evaluated at the flattened matrix, and a positive-definiteness indicator.

    Parameters
    ----------
    w, s, x, z, beta_0, beta_1, gamma
        Forwarded unchanged to ``likelyhood_ys``.
    sigma_matrix : 3x3 indexable
        Candidate covariance matrix.
    g_0, G_0
        Mean and covariance of the normal density on the flattened matrix.

    Returns
    -------
    float
        ``0.0`` when ``sigma_matrix`` is not positive definite, otherwise
        likelihood times prior density.
    """
    # Check admissibility first: outside the positive-definite region the
    # likelihood takes square roots of negative numbers, so the indicator
    # could only turn a NaN/complex value into another NaN.
    # Assumes is_positive_definite returns a truthy/falsy scalar -- TODO confirm in utils.
    if not is_positive_definite(sigma_matrix):
        return 0.0

    # Likelihood of the data given sigma.
    likelihood = likelyhood_ys(w, x, s, z, gamma, beta_0, beta_1, sigma_matrix)

    # Density of the flattened sigma given g_0, G_0.
    flat_sigma = sigma_matrix_to_flat_sigma(sigma_matrix)
    prior_density = stats.multivariate_normal.pdf(flat_sigma, g_0, G_0)

    return likelihood * prior_density

def q_sigma(w, s, x, z, beta_0, beta_1, gamma, sigma_matrix, g_0, G_0, sample=True):
    """Multivariate-t proposal tailored to the target ``g_sigma``.

    The proposal is centred at the mode of ``log g_sigma``, found numerically
    starting from ``sigma_matrix``. Its shape matrix is the inverse Hessian
    of ``-log g_sigma`` at that mode. All optimisation is done on the
    flattened parametrisation of the matrix.

    Parameters
    ----------
    w, s, x, z, beta_0, beta_1, gamma, g_0, G_0
        Forwarded unchanged to ``g_sigma``.
    sigma_matrix : 3x3 indexable
        Starting point of the mode search and, when ``sample`` is False, the
        point at which the proposal density is evaluated.
    sample : bool, default True
        If True draw a new matrix from the proposal; otherwise return the
        proposal density at ``sigma_matrix``.

    Returns
    -------
    A sigma matrix (``sample=True``) or a float density (``sample=False``).

    Notes
    -----
    Assumes ``flat_sigma_to_sigma_matrix`` takes the flat vector produced by
    ``sigma_matrix_to_flat_sigma`` and is its inverse -- TODO confirm in utils.
    ``multivariate_t`` is used with its default degrees of freedom.
    """
    def neg_log_g(flat_sigma):
        # Objective for the minimiser: -log g at the matrix encoded by flat_sigma.
        candidate = flat_sigma_to_sigma_matrix(flat_sigma)
        with np.errstate(divide="ignore", invalid="ignore"):
            value = -np.log(g_sigma(w, s, x, z, beta_0, beta_1, gamma,
                                    candidate, g_0, G_0))
        # g is 0 for inadmissible matrices; report that as an infinitely bad point.
        return value if np.isfinite(value) else np.inf

    current_flat = np.asarray(sigma_matrix_to_flat_sigma(sigma_matrix), dtype=float)

    # BFGS returns both the optimum and an approximation of the inverse
    # Hessian of the objective there, which is the curvature-based scale the
    # proposal needs. g_sigma goes through scipy.stats, which autograd does
    # not trace, so the Hessian is not taken from autograd.
    result = minimize(neg_log_g, x0=current_flat, method="BFGS")
    mode = result.x
    V = result.hess_inv

    if sample:
        # Draw in the flat parametrisation, then rebuild the matrix callers expect.
        draw = stats.multivariate_t.rvs(loc=mode, shape=V)
        return flat_sigma_to_sigma_matrix(draw)

    # Density of the supplied sigma under the proposal.
    return stats.multivariate_t.pdf(current_flat, loc=mode, shape=V)
        

def sample_sigma(old_sigma_matrix, s, w, x, z, beta_0, beta_1, gamma, g_0, G_0):
    """One Metropolis-Hastings step for the covariance matrix.

    A candidate is drawn from ``q_sigma`` and accepted with probability
    ``min(1, [g(new) / q(new)] / [g(old) / q(old)])``, where ``g`` is
    ``g_sigma`` and ``q`` is the proposal density returned by
    ``q_sigma(..., sample=False)``.

    Parameters
    ----------
    old_sigma_matrix
        Current value of the chain.
    s, w, x, z, beta_0, beta_1, gamma, g_0, G_0
        Forwarded to ``g_sigma`` and ``q_sigma``.

    Returns
    -------
    The candidate if it is accepted, otherwise ``old_sigma_matrix``.
    """
    def importance_weight(candidate):
        # Target density divided by proposal density at ``candidate``.
        target = g_sigma(w, s, x, z, beta_0, beta_1, gamma, candidate, g_0, G_0)
        proposal = q_sigma(w, s, x, z, beta_0, beta_1, gamma, candidate,
                           g_0, G_0, sample=False)
        return target / proposal

    new_sigma_matrix = q_sigma(w, s, x, z, beta_0, beta_1, gamma,
                               old_sigma_matrix, g_0, G_0, sample=True)

    old_weight = importance_weight(old_sigma_matrix)
    new_weight = importance_weight(new_sigma_matrix)

    if old_weight > 0:
        acceptance_ratio = new_weight / old_weight
    else:
        # The chain sits on a zero-density point: move to any candidate with
        # positive density rather than dividing by zero.
        acceptance_ratio = np.inf if new_weight > 0 else 0.0

    if acceptance_ratio > np.random.uniform():
        return new_sigma_matrix
    return old_sigma_matrix

    
def sample_zi_star(w, x, s, z, gamma, beta_0, beta_1, sigma):
    """Draw the latent selection variable and the unobserved outcome.

    For every observation ``i``:

    * ``si_star[i]`` is drawn from a unit-variance normal centred at
      ``w[i] . gamma``, truncated to ``(-inf, 0]`` when ``s[i] == 0`` and to
      ``[0, inf)`` otherwise;
    * the outcome that was not observed (``z[1][i]`` when ``s[i] == 0``,
      ``z[0][i]`` otherwise) is replaced by a normal draw around its
      regression mean.

    Parameters
    ----------
    w, x : indexable by observation
        Covariate rows of the selection and outcome equations.
    s : sequence
        Treatment indicators.
    z : indexable
        ``z[0]`` / ``z[1]`` hold the untreated / treated outcomes. The input
        is not modified; a copy is completed and returned.
    gamma, beta_0, beta_1 : array-like
        Coefficient vectors.
    sigma : 3x3 indexable
        ``sigma[1][1]`` and ``sigma[2][2]`` are used as the variances of the
        untreated and treated outcomes.

    Returns
    -------
    (si_star, z)
        The latent draws (1-D array of length ``len(s)``) and the completed
        copy of ``z``.
    """
    import copy

    # Work on a copy so repeated calls never alter the caller's data.
    z = copy.deepcopy(z)
    si_star = np.zeros(len(s))
    for i in range(len(s)):
        index = np.dot(w[i], gamma)
        # truncnorm takes its bounds in standardised units,
        # (bound - loc) / scale, so truncating at 0 with scale 1 means -loc.
        cut = -index
        if s[i] == 0:
            si_star[i] = stats.truncnorm.rvs(a=-np.inf, b=cut, loc=index, scale=1)
            # TODO: this draw does not condition on si_star yet.
            z[1][i] = stats.norm.rvs(loc=np.dot(x[i], beta_1),
                                     scale=sigma[2][2] ** 0.5)
        else:
            si_star[i] = stats.truncnorm.rvs(a=cut, b=np.inf, loc=index, scale=1)
            # TODO: this draw does not condition on si_star yet.
            z[0][i] = stats.norm.rvs(loc=np.dot(x[i], beta_0),
                                     scale=sigma[1][1] ** 0.5)
    return si_star, z

def sample_beta(w, x, s, z, gamma, beta_0,
                beta_1, sigma, B_0, beta_0_init, s_star):
    """Draw the stacked coefficient vector from its normal full conditional.

    The three equations (selection, untreated outcome, treated outcome) are
    treated as a seemingly-unrelated-regressions system. Observation ``i``
    has latent response ``(s_star[i], z[0][i], z[1][i])`` and a 3-row
    block-diagonal design matrix with ``w[i]``, ``x[i]`` and ``x[i]`` on the
    diagonal. With prior ``N(beta_0_init, B_0)`` the conditional is
    ``N(beta_hat, B)`` where::

        B        = inv(inv(B_0) + sum_i X_i' inv(sigma) X_i)
        beta_hat = B (inv(B_0) beta_0_init + sum_i X_i' inv(sigma) z_i)

    Parameters
    ----------
    w, x : array-like, one row per observation
        Covariates of the selection and outcome equations.
    s : sequence
        Treatment indicators; only its length is used here.
    z : indexable
        ``z[0]`` / ``z[1]`` are the completed untreated / treated outcomes.
    gamma, beta_0, beta_1
        Current coefficients; unused, kept for a uniform call signature.
    sigma : (3, 3) array-like
        Covariance matrix of the three latent responses.
    B_0, beta_0_init
        Prior covariance and prior mean of the stacked coefficient vector.
    s_star : array-like
        Latent selection variable, one value per observation.

    Returns
    -------
    numpy.ndarray
        One draw of the stacked vector, ordered as
        (selection coefficients, untreated coefficients, treated coefficients).
    """
    w = np.asarray(w, dtype=float)
    x = np.asarray(x, dtype=float)
    # Accept a single covariate given as a flat vector.
    if w.ndim == 1:
        w = w[:, None]
    if x.ndim == 1:
        x = x[:, None]

    n_obs = len(s)
    k_w, k_x = w.shape[1], x.shape[1]
    dim = k_w + 2 * k_x

    # Row i is the latent response vector of observation i.
    new_z = np.column_stack([s_star, z[0], z[1]]).astype(float)

    sigma_inv = np.linalg.inv(sigma)
    B_0_inv = np.linalg.inv(B_0)

    precision = np.array(B_0_inv, dtype=float)
    weighted_mean = np.dot(B_0_inv, np.asarray(beta_0_init, dtype=float))
    for i in range(n_obs):
        big_X = np.zeros((3, dim))
        big_X[0, :k_w] = w[i]
        big_X[1, k_w:k_w + k_x] = x[i]
        big_X[2, k_w + k_x:] = x[i]
        Xt_sigma_inv = np.dot(big_X.T, sigma_inv)
        precision += np.dot(Xt_sigma_inv, big_X)
        weighted_mean += np.dot(Xt_sigma_inv, new_z[i])

    B = np.linalg.inv(precision)
    # Remove round-off asymmetry before using B as a covariance matrix.
    B = (B + B.T) / 2
    beta_hat = np.dot(B, weighted_mean)

    new_beta = stats.multivariate_normal.rvs(mean=beta_hat, cov=B)
    return new_beta

def super_update(w, x, s, z, gamma,  sigma_matrix, beta_0, beta_1, g_0, G_0, beta_0_init, B_0):
    """Run one sweep of the sampler and return ``(new_sigma, new_beta)``.

    The sweep chains three steps, each fed with the output of the previous:
    ``sample_sigma`` draws the covariance matrix, ``sample_zi_star`` draws
    the latent selection variable and completes the outcomes under that
    matrix, and ``sample_beta`` draws the coefficients from the completed
    data.
    """
    # Step 1: covariance matrix.
    sigma_draw = sample_sigma(sigma_matrix, s, w, x, z,
                              beta_0, beta_1, gamma, g_0, G_0)

    # Step 2: latent selection variable and completed outcomes.
    selection_latent, completed_z = sample_zi_star(
        w, x, s, z, gamma, beta_0, beta_1, sigma_draw
    )

    # Step 3: coefficients given the completed data.
    beta_draw = sample_beta(
        w, x, s, completed_z, gamma, beta_0, beta_1,
        sigma_draw, B_0, beta_0_init, selection_latent,
    )

    return sigma_draw, beta_draw

def main_gaussian(w, x, s, z, gamma, beta_0, beta_1, sigma,
                  g_0=None, G_0=None, beta_0_init=None, B_0=None,
                  n_iter=10_000):
    """Run the sampler for the Gaussian model and return the last draw.

    Parameters
    ----------
    w, x, s, z
        Data, forwarded to ``super_update``.
    gamma, beta_0, beta_1 : sequences
        Starting coefficients. Their lengths define how each stacked draw is
        split back into the three blocks.
    sigma
        Starting covariance matrix.
    g_0, G_0, beta_0_init, B_0
        Prior hyper-parameters required by ``super_update``. They default to
        ``None`` only to keep the original positional signature valid; all
        four must be supplied.
    n_iter : int, default 10_000
        Number of sweeps.

    Returns
    -------
    (beta, sigma)
        The coefficient vector and covariance matrix after the last sweep.

    Raises
    ------
    ValueError
        If any of the prior hyper-parameters is missing.
    """
    missing = [
        name
        for name, value in (("g_0", g_0), ("G_0", G_0),
                            ("beta_0_init", beta_0_init), ("B_0", B_0))
        if value is None
    ]
    if missing:
        raise ValueError(
            "main_gaussian requires the prior hyper-parameters: " + ", ".join(missing)
        )

    n_gamma, n_beta_0 = len(gamma), len(beta_0)
    # Defined up front so the return value exists even when n_iter == 0.
    beta = np.concatenate([gamma, beta_0, beta_1])

    for _ in range(n_iter):
        sigma, beta = super_update(w, x, s, z, gamma, sigma, beta_0, beta_1,
                                   g_0, G_0, beta_0_init, B_0)
        # Feed the new coefficients back so the next sweep conditions on them.
        gamma = beta[:n_gamma]
        beta_0 = beta[n_gamma:n_gamma + n_beta_0]
        beta_1 = beta[n_gamma + n_beta_0:]

    return beta, sigma

In [2]:
# Unpack the eleven objects returned by the project's init() helper
# (presumably the data plus starting values and prior hyper-parameters --
# confirm against init.py).
s, w, x, y, z, sigma_matrix, beta, g_0, G_0, beta_0_init, B_0 = init()
# Split the stacked coefficient vector: the first 5 entries go to gamma,
# the next 4 to beta_0 and the remainder to beta_1.
gamma, beta_0, beta_1 = beta[:5], beta[5:9], beta[9:]

In [3]:
# Scratch check. This bare expression is not displayed, because a cell only
# shows the value of its last expression.
z[1][0]
# Density of z[1][0] under a normal centred at x[0] . beta_1 with the default
# unit scale; the recorded output of this cell is 0.0.
stats.norm.pdf(z[1][0],  loc=np.dot(x[0], beta_1))

0.0

In [8]:
def log_g_sigma(sigma_matrix):
    """Log of ``g_sigma`` as a function of ``sigma_matrix`` alone.

    Every other argument of ``g_sigma`` is read from the notebook's globals.
    """
    target_density = g_sigma(w, s, x, z, beta_0, beta_1, gamma,
                             sigma_matrix, g_0, G_0)
    return np.log(target_density)


# Hessian of log g with respect to the entries of sigma_matrix, evaluated at
# the current sigma_matrix.
# NOTE(review): autograd normally needs autograd.numpy operations to trace a
# function; confirm this call can differentiate through np / scipy.stats.
H = hessian(log_g_sigma)(sigma_matrix)

In [None]:
# Display the Hessian H assigned in the preceding cell.
H

In [16]:
# Run a single sweep of the sampler on the globals loaded from init(); the
# recorded output of this cell is a ValueError.
super_update(w, x, s, z, gamma,  sigma_matrix, beta_0, beta_1, g_0, G_0, beta_0_init, B_0)

ValueError: setting an array element with a sequence.

## Mixture

In [5]:
# m is presumably the number of mixture components (it is not used below);
# p holds the two selection probabilities.
m=2
p=[0.5, 0.5]
# Draw a label in {0, 1} for each of the len(s) observations with probabilities p.
v = np.random.choice([0,1], size=len(s), p=p)
# TODO: build the x, y, w, z arrays for each group
# TODO: apply super_update to each group
# TODO: update v (page 7 of the paper) 10_000 times