In [1]:
import numpy as np
import numpy.linalg as linalg
from sklearn.linear_model import Lasso
from copy import deepcopy, copy 
from typing import NamedTuple
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils._testing import ignore_warnings
import matplotlib.pyplot as plt
import wandb

## Utils

In [2]:
def get_psd(sigma, eps=0.0001):
    """Project a square matrix onto the PSD cone and add eigenvalue jitter.

    Negative eigenvalues are clipped to zero and ``eps`` is then added to every
    eigenvalue, i.e. the result is ``V @ diag(max(lambda, 0) + eps) @ V^{-1}``.
    For symmetric input this is symmetric and strictly positive definite when
    ``eps > 0``.

    Parameters
    ----------
    sigma : array_like, shape (k, k)
        Square matrix to project.
    eps : float, default 1e-4
        Jitter added to each clipped eigenvalue (the diagonal of the
        eigen-decomposition, not every matrix entry).

    Returns
    -------
    ndarray of float, shape (k, k)
    """
    sigma = np.asarray(sigma)
    if np.allclose(sigma, sigma.T):
        # Symmetric input: eigh returns real eigenpairs with orthonormal
        # eigenvectors, so the inverse of the eigenvector matrix is its transpose.
        eigvals, eigvecs = np.linalg.eigh(sigma)
        eigvecs_inv = eigvecs.T
    else:
        eigvals, eigvecs = np.linalg.eig(sigma)
        eigvecs_inv = np.linalg.inv(eigvecs)
    # Jitter goes on the eigenvalues only; adding it to the whole matrix would
    # also fill the off-diagonal entries of the eigenvalue matrix.
    clipped = np.diag(np.clip(eigvals, 0, np.inf) + eps)
    # np.real drops any (numerically tiny) imaginary part explicitly instead of
    # relying on astype(float), which warns and discards it silently.
    return np.real(eigvecs @ clipped @ eigvecs_inv).astype(float)

def get_obj_fct(lasso, hyper):
    """Return a closure that evaluates the dictionary-learning objective.

    The closure ``obj_fct_n(Theta, client_codeword)`` computes

        sum over clients nn in range(hyper.n), rows y of client_codeword[nn]:
            0.5 * ||y - h @ Theta||_2^2 + hyper.lmbd * ||h||_1
        + 0.5 * hyper.eta * ||Theta||_F^2

    where ``h`` is ``lasso.coef_`` after ``lasso.fit(Theta.T, y)``. The same
    ``lasso`` object is refit for every row, and ConvergenceWarnings from the
    fit are suppressed.

    NOTE(review): if ``lasso`` is sklearn's ``Lasso(alpha=hyper.lmbd)`` (as
    built in ``main``), its fit objective is
    ``(1/(2*n_samples))*||y - Xw||^2 + alpha*||w||_1``, with
    ``n_samples = Theta.shape[1]`` here. The codes it returns therefore do not
    minimise the per-row term above. Confirm whether this scaling mismatch is
    intended.
    """

    @ignore_warnings(category=ConvergenceWarning)
    def obj_fct_n(Theta, client_codeword):
        # TODO(author): the original carried an unexplained "This needs fixing"
        # note here -- possibly the lambda scaling mismatch described above.
        design = Theta.T
        total = 0.0
        for nn in range(hyper.n):
            rows = client_codeword[nn]
            for row_idx in range(rows.shape[0]):
                target = rows[row_idx]
                lasso.fit(design, target)  # sparse code for this row
                code = lasso.coef_
                total += 0.5 * linalg.norm(target - code @ Theta) ** 2 + hyper.lmbd * linalg.norm(code, 1)
        ridge = 0.5 * hyper.eta * linalg.norm(Theta) ** 2
        return total + ridge

    return obj_fct_n


# def quant(x, s):
#     uni_noise = np.random.uniform(size=x.shape)
#     #np.max , np.norm depends on the need for sparsity 
#     quant = np.floor(s/np.norm(x)*np.abs(x)+uni_noise)
#     reconstruct  = quant-uni_noise
#     return (1/s)*np.max(x)*np.sign(x)*reconstruct

# def quant(x,s):
#     uni_noise = np.random.uniform(size=x.shape)
#     #np.max , np.norm depends on the need for sparsity 
#     quant = np.floor(s/linalg.norm(x)*np.abs(x)+uni_noise)
#     reconstruct  = quant-uni_noise
#     return (1/s)*linalg.norm(x)*np.sign(x)*reconstruct


def quant(x, s):
    """Unbiased stochastic quantization of ``x`` relative to its L2 norm.

    Each entry becomes ``sign(x_i) * ||x|| * k_i / s`` with integer
    ``k_i = floor(s * |x_i| / ||x|| + u_i)`` and ``u_i ~ U[0, 1)``. Stochastic
    rounding makes ``E[k_i] = s * |x_i| / ||x||``, so the output equals ``x``
    in expectation.

    Parameters
    ----------
    x : ndarray
        Array to quantize (any shape).
    s : number
        Number of quantization levels; must be non-zero.

    Returns
    -------
    ndarray, same shape as ``x``. An all-zero ``x`` returns zeros instead of
    the NaNs produced by ``0 / 0``.
    """
    # Draw the rounding noise first and unconditionally, so the global RNG
    # stream is the same whether or not x is zero.
    noise = np.random.uniform(size=x.shape)
    norm = linalg.norm(x)
    if norm == 0:
        return np.zeros_like(x, dtype=float)
    levels = np.floor(s * np.abs(x) / norm + noise)
    return (1 / s) * np.sign(x) * norm * levels


def generate_data(hyper):
    """Draw a synthetic sparse-coding dataset from the global NumPy RNG.

    Returns ``(shared_dict, client_codeword, client_data)``:

    * ``shared_dict`` -- ``(hyper.d, hyper.m)`` standard-normal matrix shared
      by every client.
    * ``client_data`` -- list of ``hyper.n`` arrays of shape
      ``(hyper.T, hyper.d)``: standard-normal entries, each kept with
      probability 0.2 (Bernoulli mask) and zeroed otherwise.
    * ``client_codeword`` -- list of ``hyper.n`` arrays of shape
      ``(hyper.T, hyper.m)``: ``client_data[i] @ shared_dict`` plus standard
      normal noise scaled by 1e-4 when ``hyper.homo_switch == 0`` and by 0
      otherwise.

    The draw order is fixed: the dictionary first, then for each client its
    dense values, its mask and its noise. Results are therefore reproducible
    under ``np.random.seed``.

    NOTE(review): ``homo_switch`` only toggles this tiny additive noise; every
    client samples from the same distribution in both settings. Confirm this
    is the intended heterogeneity model.
    """
    shared_dict = np.random.normal(size=(hyper.d, hyper.m))
    noise_scale = 0.0001 if hyper.homo_switch == 0 else 0

    sparse_codes = []
    observations = []
    for _ in range(hyper.n):
        dense = np.random.normal(size=(hyper.T, hyper.d))
        mask = np.random.binomial(1, 0.2, size=(hyper.T, hyper.d))
        codes = dense * mask
        sparse_codes.append(codes)
        # The noise is drawn even when noise_scale is 0, keeping the RNG stream aligned.
        jitter = noise_scale * np.random.normal(size=(hyper.T, hyper.m))
        observations.append(codes @ shared_dict + jitter)
    return shared_dict, observations, sparse_codes


## Experiment

In [3]:
def _client_surrogate_theta(theta, codes_batch, fit_client, eta):
    """Closed-form client dictionary from one mini-batch of observations.

    For each row ``y`` of ``codes_batch`` a sparse code
    ``h = fit_client(theta.T, y)`` is computed. With ``S1 = mean(h h^T)`` and
    ``S2 = mean(y h^T)``, the result is ``solve(S1 + eta*I, S2.T)``. This
    equals ``(S2 @ inv(S1 + eta*I)).T`` because ``S1 + eta*I`` is symmetric.
    """
    sparse = np.array([fit_client(theta.T, y) for y in codes_batch])  # one code per row
    batch = sparse.shape[0]  # actual rows drawn (may be < BatchSize if BatchSize > T)
    ss1 = sparse.T @ sparse / batch
    ss2 = codes_batch.T @ sparse / batch
    return np.linalg.solve(ss1 + eta * np.eye(ss1.shape[0]), ss2.T)


def main(hyper, wandb_log=False, log_every=50):
    """Run federated dictionary learning with client control variates.

    Each repetition starts from a second independent ``generate_data`` draw's
    dictionary. Every round, ``hyper.Participate`` distinct clients are sampled.
    Each one:

    * draws a mini-batch of at most ``hyper.BatchSize`` rows,
    * computes a closed-form dictionary ``client_theta`` from Lasso codes,
    * sends ``Delta = client_theta - theta - V_theta[nn]``, quantized with
      ``quant`` when ``hyper.squant != 0``.

    The server then updates ``theta += step_size * (Va_theta + D_theta)``.

    Parameters
    ----------
    hyper : Hyper-like
        Configuration object (attributes as in ``Hyper``).
    wandb_log : bool
        If true, log the recorded metrics with ``wandb.log``.
    log_every : int, default 50
        Record the objective and ``||D_theta||`` every ``log_every`` iterations.

    Returns
    -------
    (obj_sto_all, theta_norm_all) : lists with one inner list per repetition,
    holding the objective values and ``||D_theta||`` at the recorded iterations.
    """
    np.random.seed(hyper.seed)
    _, client_codeword, _ = generate_data(hyper)
    # Second draw: only its dictionary is used, as the initial iterate.
    init_dict, _, _ = generate_data(hyper)
    lasso = Lasso(alpha=hyper.lmbd, fit_intercept=False)
    obj_fct_n = get_obj_fct(lasso, hyper)

    p = hyper.Participate / hyper.n  # per-round participation fraction
    obj_sto_all = []
    theta_norm_all = []

    @ignore_warnings(category=ConvergenceWarning)
    def fit_client(x, y):
        lasso.fit(x, y)
        # Copy so each stored code is independent of later fits of the shared estimator.
        return lasso.coef_.copy()

    for rep in range(hyper.rep_times):
        theta = init_dict.copy()
        V_theta = [np.zeros_like(theta) for _ in range(hyper.n)]
        Msg_theta = [np.zeros_like(theta) for _ in range(hyper.n)]
        Va_theta = np.zeros_like(theta)
        D_theta = np.zeros_like(theta)
        obj_sto = []
        theta_norm = []
        print("rep = ", rep, " \n")
        for t in range(hyper.IterNum):
            if t % log_every == 0:
                obj_sto.append(obj_fct_n(theta, client_codeword))
                theta_norm.append(linalg.norm(D_theta))
                print("t ", t, " obj = ", obj_sto[-1])
                if wandb_log:
                    wandb.log({
                        "Theta Update Norm": theta_norm[-1],
                        "Objective Value": obj_sto[-1],
                    })

            idx_PP = np.random.permutation(hyper.n)[:hyper.Participate]
            # Client updates
            for nn in idx_PP:
                idx_batch = np.random.permutation(hyper.T)[:hyper.BatchSize]
                client_theta = _client_surrogate_theta(
                    theta, client_codeword[nn][idx_batch], fit_client, hyper.eta)
                Delta = client_theta - theta - V_theta[nn]

                # Message construction: quantize ONCE and use the same message
                # for the local control variate and the server. The server's
                # Va_theta tracks the mean of the clients' V_theta only if
                # both are updated with identical messages.
                msg = quant(Delta, hyper.squant) if hyper.squant != 0 else Delta
                Msg_theta[nn] = msg
                V_theta[nn] = V_theta[nn] + (hyper.alpha / p) * msg

            # Server update with diminishing step 0.1 / sqrt(t + 0.1) (hard-coded constants).
            step_size = 0.1 / np.sqrt(t + 0.1)
            msg_sum = sum(Msg_theta[i] for i in idx_PP)
            D_theta = (1 / (hyper.n * p)) * msg_sum
            Ha_theta = Va_theta + D_theta
            Va_theta = Va_theta + (hyper.alpha / (hyper.n * p)) * msg_sum
            theta = theta + step_size * Ha_theta

        obj_sto_all.append(obj_sto)
        theta_norm_all.append(theta_norm)
    return obj_sto_all, theta_norm_all

## Run the experiment:

In [4]:
class Hyper(NamedTuple):
    """Experiment configuration.

    Every setting is type-annotated so that ``typing.NamedTuple`` registers it
    as a real field. ``Hyper()._asdict()`` therefore lists all settings (e.g.
    for a wandb config), and individual values can be overridden, e.g.
    ``Hyper(seed=0, IterNum=500)``.
    """
    seed: int = 42
    # Problem dimensions
    d: int = 15  # original (sparse code) dimension
    m: int = 10  # codeword dimension
    n: int = 20  # number of workers
    T: int = 50  # samples per worker
    # 0 for heterogeneous data, 1 for homogeneous.
    # NOTE(review): generate_data only uses this to toggle a 1e-4 additive noise -- confirm.
    homo_switch: int = 0
    # 1: aggregate over the parameter space; 0: aggregate over the ... (the
    # original comment was cut off -- TODO complete). Not referenced by main().
    surrogate: int = 1
    # Loss function parameters
    eta: float = 0.1  # dictionary l2 penalty weight
    lmbd: float = 0.05  # codeword l1 penalty weight
    rep_times: int = 1  # number of repetitions of the experiment
    # Algorithm parameters
    BatchSize: int = 20  # batch size per client per round
    Participate: int = 5  # number of active workers per round
    alpha: float = 0.01  # step size of the control-variate (V_theta / Va_theta) updates
    squant: int = 0  # quantization levels; 0 disables quantization
    IterNum: int = 10_000  # iterations per repetition
    

In [5]:
# Set to True to log this run to Weights & Biases.
WANDB_LOG = False
hyper = Hyper()
if WANDB_LOG:
    wandb.login()
    wandb.init(project='FedSur', name='Fed Dict Par Learn', config=hyper._asdict())
# Keep the recorded curves so later cells can plot them.
obj_sto_all, theta_norm_all = main(hyper, wandb_log=WANDB_LOG)

rep =  0  

t  0  obj =  549.9240389094546
t  50  obj =  642.1328465620952
t  100  obj =  717.2954556980202
t  150  obj =  774.5473009570931
t  200  obj =  812.3616495180538
t  250  obj =  841.1223709985038
t  300  obj =  866.0859934566923
t  350  obj =  888.3366994752896
t  400  obj =  902.424246714035
t  450  obj =  908.8740350624294
t  500  obj =  911.7025454910272
t  550  obj =  916.0173224454983
t  600  obj =  921.5427562063003
t  650  obj =  925.5650244851655
t  700  obj =  931.8777895788724
t  750  obj =  938.4938153635626
t  800  obj =  942.4765577604682
t  850  obj =  945.8479023391911
t  900  obj =  947.7786013832239
t  950  obj =  950.867654139467
t  1000  obj =  953.290601903602
t  1050  obj =  955.9411041272496
t  1100  obj =  958.1655110911607
t  1150  obj =  958.7510838742255
t  1200  obj =  958.4302018396165
t  1250  obj =  956.8889652604394
t  1300  obj =  955.2755106478737
t  1350  obj =  954.3874876141064
t  1400  obj =  953.0983331799604
t  1450  obj =  951.83683905