**TODO**

In [1]:
import jax 
import jax.numpy as jnp
from jax import random, scipy
from typing import NamedTuple
import matplotlib.pyplot as plt
guassian_pdf = jax.scipy.stats.multivariate_normal.pdf 

## Utils

In [None]:
def get_psd(sigma):
    M, v = jnp.linalg.eig(sigma)
    v_inv = jnp.linalg.inv(v) 
    E = jnp.diag(jnp.clip(M,0,jnp.inf))+0.0001
    return (v@E@v_inv).astype(float)
    
def _outer(a,b):
    return jnp.outer(a,b)
outer = jax.vmap(_outer, (0,0),0)
outer_clients = jax.vmap(outer,(0,0),0)


def _get_prob(a, mean, sigma): #should return the shape (L,n)
    return guassian_pdf(a,mean,sigma)
_get_prob_1 = jax.vmap(_get_prob, (0,None,None),0)
get_prob = jax.vmap(_get_prob_1,(None,0,None),1)    


def client_run(data, means,sigma, SS_1, SS_2, vss_1, vss_2):
    def _get_ss2(prob_vec,data):
        #probvec: (n), data: (n,d)
        prob_vec = jnp.expand_dims(prob_vec, axis=1) # (n,1)
        return jnp.mean(prob_vec*data ,axis=0) #(n,d) -> (d)
    get_ss2 = jax.vmap(_get_ss2, (0,None),0) #(L,d)
    
    prob = get_prob(data,means,sigma) #(n,L)
    sum_prob = jnp.sum(prob, axis=1)  #(n)
    sum_prob = jnp.expand_dims(sum_prob,1) + 0.000001#(n,1)
    raw_ss_1 = prob/sum_prob #(n,L)
    ss_1 = jnp.mean(raw_ss_1, axis=0) #(L)
    ss_2 = get_ss2(raw_ss_1.T, data) #(L,d)
    delta1 = ss_1-SS_1-vss_1
    delta2 = ss_2-SS_2-vss_2
    return delta1, delta2

def generate_data(key, hyper:Hyper):
    #beaware of for loop here
    key, *subkey = random.split(key, num=12)
    subkey = (i for i in subkey)
    means = random.normal(next(subkey), shape=(hyper.L,hyper.d))
    sigma = jnp.cov(random.normal(next(subkey), shape=(hyper.d,hyper.d))@random.normal(next(subkey),shape=(hyper.d,100_000)))
    source_data = [random.multivariate_normal(next(subkey), means[i], sigma,shape=(hyper.N,)) for i in range(hyper.L)]
    logits  = random.dirichlet(next(subkey),1.5* jnp.array([1.0 for _ in range(hyper.L)]))
    mixture_weights = jax.nn.softmax(logits)
    cuts = random.categorical(next(subkey), logits, shape=(hyper.N,))
    raw_data = jnp.vstack([dt[cuts==idx] for idx,dt in enumerate(source_data)])
    clients_data = jnp.array(jnp.split(random.permutation(next(subkey),raw_data), hyper.n))
    return key,means,sigma,mixture_weights,clients_data 


def compute_likelyhood(data, mixture, means,sigma ):
    #Need to use the prob_function
    likelyhood = get_prob(data, means, sigma)
    likelyhood = jnp.expand_dims(mixture,axis=1) *likelyhood 
    return jnp.mean(likelyhood)

## Main

In [None]:
def main():
    hyper = Hyper()
    key = random.PRNGKey(hyper.seed)
    # Generate the data
    key, t_means,t_sigma,t_mixture_weights, clients_data =  generate_data(key, hyper)
    yTy = outer_clients(clients_data,clients_data)
    yTy = jnp.mean(yTy,axis=(0,1))
    #Initialize the variables
    key, means, sigma, mixture_weights, _ =  generate_data(key, hyper)
    SS_1 , SS_2 = (1/hyper.L)*jnp.ones(shape=(hyper.L,)), (1/hyper.d)*jnp.zeros(shape=(hyper.L, hyper.d))
    V1 , V2 =  jnp.ones(shape=(hyper.L,)), jnp.ones(shape=(hyper.L, hyper.d))
    H1 , H2 =  jnp.ones(shape=(hyper.L,)), jnp.ones(shape=(hyper.L, hyper.d))
    #Variables list
    Vss1 = [1/hyper.n * jnp.zeros(shape=(hyper.L,)) for i in range(hyper.n)] #(L)
    Vss2 = [1/hyper.n * jnp.zeros(shape=(hyper.L,hyper.d)) for i in range(hyper.n)] #(L,d)

    for i in range(hyper.epochs):
        key, *subkey = random.split(key,num=hyper.n+2)
        subkey = (_ for _ in subkey)  
        indicator = random.bernoulli(next(subkey),hyper.p,shape=(hyper.n,))
        participate = jnp.nonzero(indicator)[0]
        delta1, delta2 = [],[]
        for client in list(participate):
            batch_data = clients_data[client][random.permutation(next(subkey),20)[:hyper.b]]        
            d1, d2 = client_run(batch_data, means, sigma, SS_1, SS_2, Vss1[client],Vss2[client])
            delta1.append(d1)
            delta2.append(d2)
            #client updates
            Vss1[client] = Vss1[client] + hyper.alpha*d1
            Vss2[client] = Vss2[client] + hyper.alpha*d2      
        #central updates:
        d1 = (1/len(delta1))*sum(delta1)
        d2 = (1/len(delta2))*sum(delta2)        
        H1 = V1 + d1
        H2 = V2 + d2
        SS_1 = SS_1+hyper.gamma*H1
        SS_2 = SS_2+hyper.gamma*H2
        V1 = V1+hyper.alpha*d1
        V2 = V2+hyper.alpha*d2

        #calculate the new variables 
        mixture_weights = (1/jnp.sum(SS_1))*SS_1
        means = jnp.expand_dims((1/SS_1),axis=1)*SS_2
        sigma = yTy - jnp.sum(jnp.expand_dims(SS_1, axis=(1,2))*outer(means, means), axis=0)
        sigma = get_psd(sigma)
        sigma = 0.5*(sigma.T+sigma)
        print(compute_likelyhood(clients_data,mixture_weights,means,sigma))

main()

0.012980999
0.011927749
0.009679693
0.009350866
0.0086539285
0.0074894493
0.007350084
0.0065777563


KeyboardInterrupt: 

In [187]:
mixture_weights = jnp.array([0.52507305,0.47492698])
means  = jnp.array([[0.26824555, 0.59139985, 0.38074902] ,[0.3420444  ,0.5129838 ,0.323579 ]])
sigma = [[ 0.07922603, -0.33615774, -0.06213453] ,[-0.33615774 , 2.123206   ,0.763678 ],[-0.06213453 ,0.763678   ,2.0263853 ]]
sigma = jnp.array(sigma)
key = random.PRNGKey(10)
data = random.normal(key,shape=(20,3))

def _get_ss2(prob_vec,data):
    #probvec: (n), data: (n,d)
    prob_vec = jnp.expand_dims(prob_vec, axis=1) # (n,1)
    return jnp.mean(prob_vec*data ,axis=0) #(n,d) -> (d)
get_ss2 = jax.vmap(_get_ss2, (0,None),0) #(L,d)
prob = get_prob(data,means,sigma) #(n,L)


In [188]:
# mixture_weights = jnp.array([0.52507305,0.47492698])
# means  = jnp.array([[0.26824555, 0.59139985, 0.38074902] ,[0.3420444  ,0.5129838 ,0.323579 ]])
# sigma = [[ 0.07922603, -0.33615774, -0.06213453] ,[-0.33615774 , 2.123206   ,0.763678 ],[-0.06213453 ,0.763678   ,2.0263853 ]]
# sigma = jnp.array(sigma)
# key = random.PRNGKey(10)
# data = random.normal(key,shape=(20,3))

# def _get_ss2(prob_vec,data):
#     #probvec: (n), data: (n,d)
#     prob_vec = jnp.expand_dims(prob_vec, axis=1) # (n,1)
#     return jnp.mean(prob_vec*data ,axis=0) #(n,d) -> (d)
# get_ss2 = jax.vmap(_get_ss2, (0,None),0) #(L,d)
# prob = get_prob(data,means,sigma) #(n,L)
# print(prob)
# sum_prob = jnp.sum(prob, axis=1)  #(n)
# sum_prob = jnp.expand_dims(sum_prob,1) #(n,1)

[[2.1549885e-20 4.5955199e-22]
 [7.9120917e-04 2.1809743e-04]
 [0.0000000e+00 0.0000000e+00]
 [5.5985961e-06 8.0959131e-07]
 [8.0906075e-06 4.5888020e-05]
 [1.3341231e-02 3.1152757e-02]
 [1.7404336e-01 1.7396888e-01]
 [4.8305420e-03 1.2926011e-02]
 [1.5241438e-01 1.5923110e-01]
 [6.0815236e-04 2.1850737e-03]
 [6.9629946e-10 5.9858958e-11]
 [8.7312621e-08 1.0458122e-08]
 [2.5043380e-33 1.8214096e-35]
 [1.1300329e-28 9.8329362e-27]
 [0.0000000e+00 0.0000000e+00]
 [7.6777539e-21 1.6249882e-22]
 [1.3452614e-31 1.4565859e-29]
 [2.9830981e-11 4.0415721e-10]
 [3.2318218e-03 9.5317699e-04]
 [6.0610386e-20 1.4087968e-21]]
