In [1]:
import jax 
import jax.numpy as jnp
from jax import random
from typing import NamedTuple
from tqdm import tqdm
import matplotlib.pyplot as plt
guassian_pdf = jax.scipy.stats.multivariate_normal.pdf 
from utils import outer_clients, generate_data,compute_likelyhood, get_prob, outer

## Utils

In [16]:
def theta_client_run(data,sigma, means):
    def _weighted_means(data,weights):
        return jnp.average(data,axis=0,weights=weights)
    
    weighted_means = jax.vmap(_weighted_means, (None,0),0)
    prob =  get_prob(data, means, sigma) #(L,n)
    norm_prob = jnp.expand_dims((1/jnp.sum(prob, axis=1)),axis=1)*prob 
    print(data.shape, norm_prob.shape)
    mu = weighted_means(data,norm_prob.T)
    #norm_prob -> (l,N)
    mixture_weights  = jnp.sum(norm_prob,axis=1)
    mixture_weights  = (1/jnp.sum(mixture_weights))*mixture_weights
    sigma = jnp.mean(outer(data,data),axis=0)-jnp.average(
        outer(mu,mu),axis=0,weights=mixture_weights)
    return sigma, means, mixture_weights

## Main

In [17]:
class Hyper(NamedTuple):
    #Hyperparameters with default settings..
    p = 0.95 # partipation rate
    L = 2 # length
    b = 100 #batch size
    N = 10_000 #total number of examples
    n = 1000 # number of agents 
    d = 3 # dimension of the experiment 
    gamma = 10e-2 #hyper parameter
    alpha = 10e-2 # the total things
    epochs = 100
    repeats = 1
    seed = 42

In [18]:
def main():
    hyper = Hyper()
    key = random.PRNGKey(hyper.seed)
    # Generate the data
    key, t_means,t_sigma,t_mixture_weights, clients_data =  generate_data(key, hyper)
    yTy = outer_clients(clients_data,clients_data)
    yTy = jnp.mean(yTy,axis=(0,1))
    #Initialize the variables
    key, means, sigma, mixture_weights, _ =  generate_data(key, hyper)
    optimal_value = compute_likelyhood(clients_data,t_mixture_weights,t_means,t_sigma)
    optimal_gap = []
    for i in tqdm(range(hyper.epochs)):
        key, *subkey = random.split(key,num=hyper.n+2)
        subkey = (_ for _ in subkey)  
        indicator = random.bernoulli(next(subkey),hyper.p,shape=(hyper.n,))
        participate = jnp.nonzero(indicator)[0]
        clients_sigma, clients_means, clients_mixture = [],[],[]
        for client in list(participate):
            batch_data = clients_data[client][random.permutation(next(subkey),20)[:hyper.b]]        
            c_sigma,c_means, c_mixture = theta_client_run(batch_data, sigma, means)
            clients_sigma.append(c_sigma)
            clients_means.append(c_means)
            clients_mixture.append(c_mixture)
        #central updates:
        means = (1/len(clients_means))*sum(clients_means)
        mixture_weights = (1/len(clients_mixture))*sum(clients_mixture)
        sigma = (1/len(clients_sigma))*sum(clients_sigma)
        optimal_gap.append(jnp.abs(optimal_value -compute_likelyhood(clients_data,mixture_weights,means,sigma)))
    print("optimality gap")
    print(optimal_gap)
main()

  0%|          | 0/100 [00:00<?, ?it/s]

(20, 3) (20, 2)





ValueError: Length of weights not compatible with specified axis.