In [12]:
import jax 
import jax.numpy as jnp
from jax import random
from typing import NamedTuple
from tqdm import tqdm
import wandb
guassian_pdf = jax.scipy.stats.multivariate_normal.pdf 
from em_utils import get_prob, generate_data, compute_likelyhood, outer_clients, nearestPD

## Utils

In [13]:

@jax.jit
def client_run(data, means,sigma, SS_1, SS_2, vss_1, vss_2):
    def _get_ss2(prob_vec,data):
        #probvec: (n), data: (n,d)
        prob_vec = jnp.expand_dims(prob_vec, axis=1) # (n,1)
        return jnp.mean(prob_vec*data ,axis=0) #(n,d) -> (d)
    get_ss2 = jax.vmap(_get_ss2, (0,None),0) #(L,d)
    prob = get_prob(data,means,sigma) #(n,L)
    sum_prob = jnp.sum(prob, axis=1)  #(n)
    sum_prob = jnp.expand_dims(sum_prob,1) + 0.000001#(n,1)
    raw_ss_1 = prob/sum_prob #(n,L)
    ss_1 = jnp.mean(raw_ss_1, axis=0) #(L)
    ss_2 = get_ss2(raw_ss_1.T, data) #(L,d)
    delta1 = ss_1-SS_1-vss_1
    delta2 = ss_2-SS_2-vss_2
    return delta1, delta2


parallel_client_run = jax.vmap(client_run, (0,None,None,None,None,0,0),0)

## Main

In [14]:
def main(hyper, wandb_log=False):
    
    key = random.PRNGKey(hyper.seed)
    # Generate the data
    key, t_means,t_sigma,t_mixture_weights, clients_data =  generate_data(key, hyper)
    yTy = outer_clients(clients_data,clients_data)
    yTy = jnp.mean(yTy,axis=(0,1))
    #Initialize the variables
    key, means, sigma, mixture_weights, _ =  generate_data(key, hyper)
    SS_1 , SS_2 = jnp.zeros(shape=(hyper.L,)), jnp.zeros(shape=(hyper.L, hyper.d))
    V1 , V2 =  jnp.zeros(shape=(hyper.L,)), jnp.zeros(shape=(hyper.L, hyper.d))
    H1 , H2 =  jnp.zeros(shape=(hyper.L,)), jnp.zeros(shape=(hyper.L, hyper.d))
    #Variables list
    Vss1 = jnp.zeros(shape=(hyper.n, hyper.L))
    Vss2 = jnp.zeros(shape=(hyper.n,hyper.L, hyper.d))
    optimal_value = compute_likelyhood(clients_data,t_mixture_weights,t_means,t_sigma)
    optimal_gap = []
    S1_norm = []
    S2_norm = []
    likelyhood = []
    for i in tqdm(range(hyper.epochs)):
        key, *subkey = random.split(key,num=hyper.n+2)
        subkey = (_ for _ in subkey)  
        indicator = random.bernoulli(next(subkey),hyper.p,shape=(hyper.n,))
        participate = jnp.nonzero(indicator)[0]
        batch_data = []
        for client in list(participate):
            batch_data.append(clients_data[client][random.permutation(next(subkey),20)[:hyper.b]])
        batch_data = jnp.array(batch_data)                
        delta1, delta2 = parallel_client_run(batch_data, means, sigma, SS_1, SS_2, Vss1[participate],Vss2[participate])
        
        #vectorize Vss updates. 
        Vss1 = Vss1.at[participate].set(Vss1[participate]+hyper.alpha/hyper.p*delta1)
        Vss2 = Vss2.at[participate].set(Vss2[participate]+hyper.alpha/hyper.p*delta2)
    
        #Central Updates:
        d1 = jnp.mean(delta1, axis=0)
        d2 = jnp.mean(delta2,axis=0)
        H1 = V1 + d1
        H2 = V2 + d2
        SS_1 = SS_1+hyper.gamma*H1
        SS_2 = SS_2+hyper.gamma*H2
        V1 = V1+hyper.alpha/hyper.p*d1
        V2 = V2+hyper.alpha/hyper.p*d2

        #calculate the new variables 
        mixture_weights = (1/jnp.sum(SS_1))*SS_1
        means = jnp.expand_dims((1/SS_1),axis=1)*SS_2
        sigma = yTy - jnp.sum(jnp.expand_dims(SS_1, axis=(1,2))*outer(means, means), axis=0)
        sigma = nearestPD(sigma)
        optimal_gap.append(float(jnp.abs(optimal_value -compute_likelyhood(clients_data,mixture_weights,means,sigma))))
        S1_norm.append(float(jnp.linalg.norm(d1)))
        S2_norm.append(float(jnp.linalg.norm(d2)))
        likelyhood.append(float(compute_likelyhood(clients_data,mixture_weights,means,sigma)))
        if(wandb_log):
            log_dict={
                "S1 Update": S1_norm[-1],
                "S2 Update": S2_norm[-1],
                "Optimality Gap": optimal_gap[-1] ,
                "Likelyhood": likelyhood[-1]
            }
            wandb.log(log_dict)
    print("optimality gap  ", optimal_gap)
    print("S1 updated norm ", S1_norm)
    print("S2 updated norm ", S2_norm)
    print(optimal_gap)

## Experiment Running

In [15]:

class Hyper(NamedTuple):
    #Hyperparameters with default settings..
    p:float= 0.65 # partipation rate
    L:int = 2 # length
    b:int = 5 #batch size
    N:int = 1_000 #total number of examples
    n:int = 10 # number of agents
    d:int = 3 # dimension of the experiment
    gamma:float = 10e-2 #hyper parameter
    alpha:float = 10e-2 # the total things
    epochs:int = 50
    repeats:int = 1
    seed:int = 42

In [16]:
wandb.login()
wandb.init(entity='waihegz',project='FedSur', name='FedEMList')
wandb.config.update(Hyper()._asdict())
main(Hyper(), wandb_log=True)



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Likelyhood,▁▁▂▂▃▃▄▄▅▆▆▆▇▇▇▇▇▇█████▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▆
Optimality Gap,█▇▇▆▅▄▃▃▁▁▁▂▂▃▃▃▃▃▄▄▄▄▄▄▃▃▃▂▃▃▃▃▃▃▃▃▃▂▂▂
S1 Update,█▆▄▄▃▂▂▁▂▂▃▂▂▂▂▂▂▂▃▂▂▂▂▃▁▁▁▃▂▂▂▃▃▂▂▃▂▂▂▂
S2 Update,██▆▄▂▂▂▂▂▂▄▁▃▂▃▂▃▃▂▃▂▂▂▄▂▂▃▃▁▃▁▃▄▂▃▂▃▃▂▂

0,1
Likelyhood,0.03493
Optimality Gap,0.00251
S1 Update,0.13799
S2 Update,0.21166


100%|██████████| 50/50 [00:02<00:00, 20.18it/s]

optimality gap   [0.134805366396904, 0.1246090680360794, 0.11190737783908844, 0.10513995587825775, 0.09289431571960449, 0.08098989725112915, 0.06738331913948059, 0.0491432249546051, 0.03885588049888611, 0.02432587742805481, 0.01496782898902893, 0.0063037872314453125, 0.010718464851379395, 0.02273768186569214, 0.02563542127609253, 0.02635371685028076, 0.038302987813949585, 0.043362170457839966, 0.04098597168922424, 0.04971066117286682, 0.05101165175437927, 0.050673067569732666, 0.06840512156486511, 0.06810864806175232, 0.06408259272575378, 0.06968116760253906, 0.059958070516586304, 0.06677475571632385, 0.054405391216278076, 0.04885604977607727, 0.041637420654296875, 0.04495254158973694, 0.04196181893348694, 0.025435209274291992, 0.04614818096160889, 0.04252183437347412, 0.05124926567077637, 0.04907238483428955, 0.04101797938346863, 0.050353437662124634, 0.03743758797645569, 0.03665462136268616, 0.038530588150024414, 0.03779307007789612, 0.034408241510391235, 0.03737759590148926, 0.02721


