In [None]:
import jax 
import jax.numpy as jnp
from jax import random, scipy
from typing import NamedTuple
from tqdm import tqdm
import matplotlib.pyplot as plt
import wandb
guassian_pdf = jax.scipy.stats.multivariate_normal.pdf 

## Utils

In [None]:

def get_psd(sigma):
    M, v = jnp.linalg.eig(sigma)
    v_inv = jnp.linalg.inv(v) 
    E = jnp.diag(jnp.clip(M,0,jnp.inf))+0.0001
    return (v@E@v_inv).astype(float)
    
def _outer(a,b):
    return jnp.outer(a,b)
outer = jax.vmap(_outer, (0,0),0)
outer_clients = jax.jit(jax.vmap(outer,(0,0),0))


def _get_prob(a, mean, sigma): #should return the shape (L,n)
    return guassian_pdf(a,mean,sigma)
_get_prob_1 = jax.vmap(_get_prob, (0,None,None),0)
get_prob = jax.vmap(_get_prob_1,(None,0,None),1)    

@jax.jit
def client_run(data, means,sigma, SS_1, SS_2, vss_1, vss_2):
    def _get_ss2(prob_vec,data):
        #probvec: (n), data: (n,d)
        prob_vec = jnp.expand_dims(prob_vec, axis=1) # (n,1)
        return jnp.mean(prob_vec*data ,axis=0) #(n,d) -> (d)
    get_ss2 = jax.vmap(_get_ss2, (0,None),0) #(L,d)
    prob = get_prob(data,means,sigma) #(n,L)
    sum_prob = jnp.sum(prob, axis=1)  #(n)
    sum_prob = jnp.expand_dims(sum_prob,1) + 0.000001#(n,1)
    raw_ss_1 = prob/sum_prob #(n,L)
    ss_1 = jnp.mean(raw_ss_1, axis=0) #(L)
    ss_2 = get_ss2(raw_ss_1.T, data) #(L,d)
    delta1 = ss_1-SS_1-vss_1
    delta2 = ss_2-SS_2-vss_2
    return delta1, delta2

def theta_client_run(data, means, sigma):
    mixture = None
    return means, sigma, mixture

parallel_client_run = jax.vmap(client_run, (0,None,None,None,None,0,0),0)

def generate_data(key, hyper:NamedTuple):
    #Do not jax.jit
    key, *subkey = random.split(key, num=10+hyper.L)
    subkey = (i for i in subkey)
    means = random.normal(next(subkey), shape=(hyper.L,hyper.d))
    sigma = jnp.cov(random.normal(next(subkey), shape=(hyper.d,hyper.d))@random.normal(next(subkey),shape=(hyper.d,100_000)))
    source_data = [random.multivariate_normal(next(subkey), means[i], sigma,shape=(hyper.N,)) for i in range(hyper.L)]
    logits  = random.dirichlet(next(subkey),1.5* jnp.array([1.0 for _ in range(hyper.L)]))
    mixture_weights = jax.nn.softmax(logits)
    cuts = random.categorical(next(subkey), logits, shape=(hyper.N,))
    raw_data = jnp.vstack([dt[cuts==idx] for idx,dt in enumerate(source_data)])
    clients_data = jnp.array(jnp.split(random.permutation(next(subkey),raw_data), hyper.n))
    return key,means,sigma,mixture_weights,clients_data 

@jax.jit
def compute_likelyhood(data, mixture, means,sigma ):
    #Need to use the prob_function
    likelyhood = get_prob(data, means, sigma)
    likelyhood = jnp.expand_dims(mixture,axis=1) *likelyhood 
    return jnp.mean(likelyhood)



## Main

In [None]:
def main(hyper, wandb_log=False):
    
    key = random.PRNGKey(hyper.seed)
    # Generate the data
    key, t_means,t_sigma,t_mixture_weights, clients_data =  generate_data(key, hyper)
    yTy = outer_clients(clients_data,clients_data)
    yTy = jnp.mean(yTy,axis=(0,1))
    #Initialize the variables
    key, means, sigma, mixture_weights, _ =  generate_data(key, hyper)
    SS_1 , SS_2 = jnp.zeros(shape=(hyper.L,)), jnp.zeros(shape=(hyper.L, hyper.d))
    V1 , V2 =  jnp.zeros(shape=(hyper.L,)), jnp.zeros(shape=(hyper.L, hyper.d))
    H1 , H2 =  jnp.zeros(shape=(hyper.L,)), jnp.zeros(shape=(hyper.L, hyper.d))
    #Variables list
    Vss1 = jnp.zeros(shape=(hyper.n, hyper.L))
    Vss2 = jnp.zeros(shape=(hyper.n,hyper.L, hyper.d))
    optimal_value = compute_likelyhood(clients_data,t_mixture_weights,t_means,t_sigma)
    optimal_gap = []
    S1_norm = []
    S2_norm = []
    for i in tqdm(range(hyper.epochs)):
        key, *subkey = random.split(key,num=hyper.n+2)
        subkey = (_ for _ in subkey)  
        indicator = random.bernoulli(next(subkey),hyper.p,shape=(hyper.n,))
        participate = jnp.nonzero(indicator)[0]
        batch_data = []
        for client in list(participate):
            batch_data.append(clients_data[client][random.permutation(next(subkey),20)[:hyper.b]])
        batch_data = jnp.array(batch_data)                
        delta1, delta2 = parallel_client_run(batch_data, means, sigma, SS_1, SS_2, Vss1[participate],Vss2[participate])
        
        #vectorize Vss updates. 
        Vss1 = Vss1.at[participate].set(hyper.alpha*delta1)
        Vss2 = Vss2.at[participate].set(hyper.alpha*delta2)
    
        #Central Updates:
        d1 = jnp.mean(delta1, axis=0)
        d2 = jnp.mean(delta2,axis=0)
        H1 = V1 + d1
        H2 = V2 + d2
        SS_1 = SS_1+hyper.gamma*H1
        SS_2 = SS_2+hyper.gamma*H2
        V1 = V1+hyper.alpha*d1
        V2 = V2+hyper.alpha*d2

        #calculate the new variables 
        mixture_weights = (1/jnp.sum(SS_1))*SS_1
        means = jnp.expand_dims((1/SS_1),axis=1)*SS_2
        sigma = yTy - jnp.sum(jnp.expand_dims(SS_1, axis=(1,2))*outer(means, means), axis=0)
        sigma = get_psd(sigma)
        sigma = 0.5*(sigma.T+sigma)
        optimal_gap.append(float(jnp.abs(optimal_value -compute_likelyhood(clients_data,mixture_weights,means,sigma))))
        S1_norm.append(float(jnp.linalg.norm(d1)))
        S2_norm.append(float(jnp.linalg.norm(d2)))
        if(wandb_log):
            log_dict={
                "S1 Update": S1_norm[-1],
                "S2 Update": S2_norm[-1],
                "Optimality Gap": optimal_gap[-1] 
            }
            wandb.log(log_dict)
    print("optimality gap  ", optimal_gap)
    print("S1 updated norm ", S1_norm)
    print("S2 updated norm ", S2_norm)
    print(optimal_gap)

## Experiment Running

In [13]:
class Hyper(NamedTuple):
    #Hyperparameters with default settings..
    p = 0.95 # partipation rate
    L = 2 # length
    b = 100 #batch size
    N = 10_000 #total number of examples
    n = 1000 # number of agents
    d = 3 # dimension of the experiment
    gamma = 10e-2 #hyper parameter
    alpha = 10e-2 # the total things
    epochs = 300
    repeats = 1
    seed = 42

In [14]:
wandb.login()
wandb.init(project='FedSur', name='FedEMList', config=Hyper()._asdict())
main(Hyper(), wandb_log=True)



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Optimality Gap,▁▄▅▇▇█▇▇▇▇▇▇▇▇▆▆▇████▇▇▆▇▇▇▆▆▇▇▇▇▇▆▇▇██▇
S1 Update,█▆▁▃▃▃▁▂▃▂▅▁▁▁▅▂▁▁▁▁▆▂▂▂▆▁▆▂▂▁▂▁▁▂▂▂▁▆▁▁
S2 Update,█▆▁▃▃▃▁▂▃▂▆▁▁▁▅▂▁▁▁▁▆▂▂▂▆▂▆▂▂▁▂▁▁▂▂▂▁▆▁▁

0,1
Optimality Gap,0.0
S1 Update,0.0
S2 Update,0.0


[34m[1mwandb[0m: wandb version 0.12.21 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


100%|██████████| 300/300 [10:58<00:00,  2.19s/it]

optimality gap   [0.01543755829334259, 0.014359993860125542, 0.013379056006669998, 0.01237981766462326, 0.011330239474773407, 0.010198567062616348, 0.00896645151078701, 0.0076258014887571335, 0.0061538685113191605, 0.004557985812425613, 0.0028155073523521423, 0.000891316682100296, 0.001168716698884964, 0.0033973418176174164, 0.005834717303514481, 0.008427154272794724, 0.011101581156253815, 0.013914022594690323, 0.012519292533397675, 0.013621263206005096, 0.01203065738081932, 0.01245751604437828, 0.011935733258724213, 0.01150529459118843, 0.011042993515729904, 0.010511212050914764, 0.009964130818843842, 0.009328536689281464, 0.00866185873746872, 0.007962614297866821, 0.007239244878292084, 0.006511811167001724, 0.0057641491293907166, 0.004986613988876343, 0.004230432212352753, 0.003482423722743988, 0.002747759222984314, 0.002038329839706421, 0.0013547390699386597, 0.0006654597818851471, 2.559274435043335e-06, 0.0006343573331832886, 0.0012248754501342773, 0.0017833523452281952, 0.00230529


