In [2]:
import numpy as np
import numpy.linalg as linalg
from sklearn.linear_model import Lasso
from copy import deepcopy, copy 
from typing import NamedTuple
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils._testing import ignore_warnings
import matplotlib.pyplot as plt
import wandb

## Utils

In [3]:
def get_psd(sigma):
    M, v = np.linalg.eig(sigma)
    v_inv = np.linalg.inv(v) 
    E = np.diag(np.clip(M,0,np.inf))+0.0001
    return (v@E@v_inv).astype(float)    

def get_obj_fct(lasso, hyper):
    @ignore_warnings(category=ConvergenceWarning)
    def obj_fct_n(Theta, client_codeword):
        #This needs fiving
        obj_fct = 0.0
        for nn in range(hyper.n):
            for t in range(client_codeword[nn].shape[0]):
                y = client_codeword[nn][t]
                lasso.fit(Theta.T, y) # involves solving a LASSO problem
                hj = lasso.coef_
                obj_fct += 0.5*linalg.norm(y - hj@Theta )**2 + hyper.lmbd*linalg.norm( hj, 1)
        return (obj_fct + 0.5*hyper.eta* linalg.norm( Theta )**2)
    return obj_fct_n

# def quant(x, s):
#     uni_noise = np.random.uniform(size=x.shape)
#     #np.max , np.norm depends on the need for sparsity 
#     max = np.max(np.abs(x))
#     quant = np.floor(s/max*np.abs(x)+uni_noise)
#     reconstruct  = quant-uni_noise
#     return (1/s)*max*np.sign(x)*reconstruct

# def quant(x,s):
#     uni_noise = np.random.uniform(size=x.shape)
#     #np.max , np.norm depends on the need for sparsity 
#     quant = np.floor(s/linalg.norm(x)*np.abs(x)+uni_noise)
#     reconstruct  = quant-uni_noise
#     return (1/s)*linalg.norm(x)*np.sign(x)*reconstruct


def quant(x, s):
    return (1/s) * np.sign(x) * linalg.norm(x) * np.floor(s* np.abs(x)/(linalg.norm(x))+np.random.uniform(size=x.shape))


def generate_data(hyper):
    shared_dict = np.random.normal(size=(hyper.d,hyper.m))
    client_codeword , client_data = [], [] 
    if hyper.homo_switch==0:
        noise = 0.0001
    else:
        noise = 0
    for nn in range(hyper.n): 
        client_data.append(np.random.normal(size=(hyper.T,hyper.d))* np.random.binomial(
            1, 0.2,size=(hyper.T,hyper.d)))
        client_codeword.append(client_data[nn]@shared_dict + noise*np.random.normal(size=(hyper.T,hyper.m)))
    return shared_dict,client_codeword,client_data


## Experiment

In [4]:
def main(hyper,wandb_log=False):
    np.random.seed(hyper.seed)
    _ , client_codeword, _ = generate_data(hyper)
    init_dict , _ , _ = generate_data(hyper)
    lasso = Lasso(alpha=hyper.lmbd, fit_intercept=False)
    obj_fct_n = get_obj_fct(lasso, hyper)

    p = (hyper.Participate/hyper.n)
    obj_sto_all = []
    SS1_norm_all = []
    SS2_norm_all = []
    @ignore_warnings(category=ConvergenceWarning)
    def fit_client(x,y):
        fitted = lasso.fit(x,y)
        return fitted.coef_

    for rep in range(hyper.rep_times):
        Theta = deepcopy(init_dict)
        V_SS1 = [np.zeros((hyper.d,hyper.d)) for _ in range(hyper.n)] 
        V_SS2 = [np.zeros((hyper.m,hyper.d)) for _ in range(hyper.n)] 
        Msg_SS1 = [np.zeros((hyper.d,hyper.d)) for _ in range(hyper.n)] 
        Msg_SS2 = [np.zeros((hyper.m,hyper.d)) for _ in range(hyper.n)] 
        Va_SS1, Ha_SS1, Sa_SS1 = (np.zeros((hyper.d,hyper.d)) for _ in range(3))
        Va_SS2, Ha_SS2, Sa_SS2 = (np.zeros((hyper.m,hyper.d)) for _ in range(3))
        obj_sto = []
        SS1_norm = []
        SS2_norm = []
        D_SS1 = np.zeros((hyper.d,hyper.d))
        D_SS2 = np.zeros((hyper.m,hyper.d))
        print("rep = ", rep, " \n")
        for t in range(hyper.IterNum):
            if t%50 == 0:
                obj_sto.append(obj_fct_n(Theta, client_codeword))
                SS1_norm.append(linalg.norm(D_SS1))
                SS2_norm.append(linalg.norm(D_SS2))
                print( "t ",t ," obj = ",  obj_sto[-1])
                
                if wandb_log==True:
                    log_dict={
                        "SS1 norm": SS1_norm[-1],
                        "SS2 norm": SS2_norm[-1],
                        "Objective Value": obj_sto[-1] 
                    }
                    wandb.log(log_dict)

            idx_PP = np.random.permutation(hyper.n)[:hyper.Participate]
            #for each client
            for nn in idx_PP:
                idx_batch = np.random.permutation(hyper.T)[:hyper.BatchSize]
                yy = client_codeword[nn][idx_batch]
                #client message
                SS1 = np.zeros((hyper.d,hyper.d))
                SS2 = np.zeros((hyper.m,hyper.d))
                for b in range(hyper.BatchSize):
                    hj = fit_client(Theta.T, yy[b])
                    SS1 += np.outer(hj,hj)
                    SS2 += np.outer(yy[b],hj)

                ##Message conostruction
                Delta1 = SS1/hyper.BatchSize - Sa_SS1 - V_SS1[nn]
                Delta2 = SS2/hyper.BatchSize - Sa_SS2 - V_SS2[nn]
                V_SS1[nn] = V_SS1[nn] + (hyper.alpha/p)*quant(Delta1, hyper.squant)
                V_SS2[nn] = V_SS2[nn] + (hyper.alpha/p)*quant(Delta2, hyper.squant)
                Msg_SS1[nn] = quant(Delta1, hyper.squant)
                Msg_SS2[nn] = quant(Delta2, hyper.squant); 
            

            #Server Update
            step_size = 0.1 / np.sqrt(t+0.1)
            D_SS1 = (1/(hyper.n*p))*sum([Msg_SS1[i] for i in idx_PP])
            D_SS2 = (1/(hyper.n*p))*sum([Msg_SS2[i] for i in idx_PP])    
            Ha_SS1 = Va_SS1 + D_SS1
            Ha_SS2 = Va_SS2 + D_SS2
            Sa_SS1 = Sa_SS1 + step_size*Ha_SS1
            Sa_SS1 = get_psd(0.5*(Sa_SS1+Sa_SS1.T))
            Sa_SS2 = Sa_SS2 + step_size*Ha_SS2
            Va_SS1 = Va_SS1 + (hyper.alpha/(hyper.n*p))*sum([Msg_SS1[i] for i in idx_PP])
            Va_SS2 = Va_SS2 + (hyper.alpha/(hyper.n*p))*sum([Msg_SS2[i] for i in idx_PP])
            Theta = Sa_SS2 @ ( np.linalg.inv(Sa_SS1 + hyper.eta*np.eye(hyper.d)))
            Theta = Theta.T  

        obj_sto_all.append(obj_sto)
        SS1_norm_all.append(SS1_norm)
        SS2_norm_all.append(SS2_norm)
    return obj_sto_all, SS1_norm_all, SS2_norm_all
    

## Run the experiment:

In [5]:
class Hyper(NamedTuple):
    seed = 42
    ## Problems dimensions
    d = 15 # original dimension
    m = 10 # codeword dimension
    n = 20 # number of workers
    T = 50 # Samples per worker
    homo_switch = 0 #0 for heter data, 1 for homo
    surrogate = 1 #1 aggregating over the prameter space, 0 over the 
    #Loss function parameters
    eta = 0.1 #dict l2 penalty weight
    lmbd = 0.05 #codeword l1 penalty weight
    rep_times = 1 # no. of repetition to run the experiment
    # parameters for the algorithms 
    BatchSize = 3 # batch size per client per run
    Participate = 5 # no of active workers each run
    alpha = 0.01 # step size
    squant = 15 # quantization level
    IterNum = 10_000 # no. of iterations per experiment
    

In [8]:
hyper = Hyper()
wandb.login()
wandb.init(project='FedSur', name='Fed Dict Learn', config=Hyper()._asdict())
main(Hyper(), wandb_log=True)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwaihegz[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.21 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


rep =  0  

t  0  obj =  549.9240389094546
t  50  obj =  642.8881256255041
t  100  obj =  600.7213755657879
t  150  obj =  580.2741355922307
t  200  obj =  560.2479859943293
t  250  obj =  553.2154144908319
t  300  obj =  541.3065628857597
t  350  obj =  530.0526142533867
t  400  obj =  519.2120061624136
t  450  obj =  510.2768460986564
t  500  obj =  504.54583369281886
t  550  obj =  502.0047713224124
t  600  obj =  503.26040668345667
t  650  obj =  504.79703785173496
t  700  obj =  506.5009098947856
t  750  obj =  508.1739583261809
t  800  obj =  507.8123066506428
t  850  obj =  509.35651146591863
t  900  obj =  510.75225966712054
t  950  obj =  513.0693675098083
t  1000  obj =  511.1723010785509
t  1050  obj =  513.1363432371218
t  1100  obj =  512.5616474817665
t  1150  obj =  511.8374260853703
t  1200  obj =  512.4850229981087
t  1250  obj =  512.3838516918522
t  1300  obj =  512.2792724668045
t  1350  obj =  512.7711645850487
t  1400  obj =  513.5981811218772
t  1450  obj =  514.

([[549.9240389094546,
   642.8881256255041,
   600.7213755657879,
   580.2741355922307,
   560.2479859943293,
   553.2154144908319,
   541.3065628857597,
   530.0526142533867,
   519.2120061624136,
   510.2768460986564,
   504.54583369281886,
   502.0047713224124,
   503.26040668345667,
   504.79703785173496,
   506.5009098947856,
   508.1739583261809,
   507.8123066506428,
   509.35651146591863,
   510.75225966712054,
   513.0693675098083,
   511.1723010785509,
   513.1363432371218,
   512.5616474817665,
   511.8374260853703,
   512.4850229981087,
   512.3838516918522,
   512.2792724668045,
   512.7711645850487,
   513.5981811218772,
   514.6108112249935,
   514.7546213598735,
   514.0954487969841,
   512.9686289241806,
   512.8346049991729,
   512.8563829933843,
   512.4089040297324,
   513.4299658458231,
   513.6793474759075,
   515.3153796895468,
   516.7106104680184,
   517.9947847613722,
   518.0147213078087,
   518.9476989935297,
   521.3042361677662,
   524.5603875969738,
   52