In [6]:
import numpy as np

# Data generation

In [9]:
np.random.seed(0)
data = np.random.uniform(-10, 10, 1000)
data.sort()
data[:10], data[-10:]

(array([-9.9890807 , -9.972333  , -9.94593572, -9.92279297, -9.90609048,
        -9.89694071, -9.84231793, -9.77145083, -9.76571832, -9.75927554]),
 array([9.78819555, 9.80677895, 9.80690003, 9.83780658, 9.84022487,
        9.88801579, 9.95924503, 9.97694013, 9.98555988, 9.99617156]))

# Loss

In [10]:
def value(x, data_batch):
    return np.mean([(x - x_0) ** 4 for x_0 in data_batch])
    
def gradient(x, data_batch):
    return 4 * np.mean([(x - x_0) ** 3 for x_0 in data_batch])

def hessian(x, data_batch):
    return 12 * np.mean([(x - x_0) ** 2 for x_0 in data_batch])

## Optimal value

In [11]:
step_size = 1

x0 = 1000
x = x0
N = 20
for i in range(N):
    grad = gradient(x, data)
    hess = hessian(x, data)
    x -= step_size * 1 / hess * grad
    f_value = value(x, data)
    print(f'Step#{i+1}/{N}: x: {x}, loss: {f_value}')

x_opt = x
f_opt = f_value

Step#1/20: x: 666.616964281754, loss: 197658689927.21805
Step#2/20: x: 444.3503542438573, loss: 39054057585.9382
Step#3/20: x: 296.1557394712065, loss: 7718984085.153355
Step#4/20: x: 197.3340286041539, loss: 1526779340.4356422
Step#5/20: x: 131.41496411368178, loss: 302490358.3123419
Step#6/20: x: 87.4121064719466, loss: 60150389.78977136
Step#7/20: x: 57.99183832064399, loss: 12056561.465732094
Step#8/20: x: 38.251331746437266, loss: 2457112.1547158738
Step#9/20: x: 24.90220650295319, loss: 516918.01650052355
Step#10/20: x: 15.725537166015714, loss: 114331.86339315155
Step#11/20: x: 9.216235719233138, loss: 26578.7988397845
Step#12/20: x: 4.410145948503721, loss: 6321.060472057016
Step#13/20: x: 1.1155071887606867, loss: 2282.4037393535136
Step#14/20: x: 0.06940950048197525, loss: 2046.9200936817354
Step#15/20: x: 0.04194366641634431, loss: 2046.767086627283
Step#16/20: x: 0.0419405005684303, loss: 2046.767086625251
Step#17/20: x: 0.04194050056839399, loss: 2046.767086625251
Step#18/

# Running experiments

## Parameters

In [12]:
x0 = 100
n_local_steps = 10
n_clients = 10
batch_size = 8
n_comms = 100
client_lr = 0.00001
n_seeds = 3
c_0 = 0
c_1 = 1

## Local SGD

In [13]:
def calculate_clipped_lr(c_0, c_1, g_p):
    return 1 / (c_0 + c_1 * abs(g_p))

In [14]:
def local_sgd(
    x0, 
    f_opt,
    data, 
    client_lr, 
    c_0,
    c_1,
    batch_size, 
    n_clients, 
    n_local_steps,
    n_comms,
    n_seeds
):
    seed_loss_vals = {}
    data_per_par = len(data) // n_clients
    client_data_list = [data[i * data_per_par : (i + 1) * data_per_par] for i in range(n_clients)]
    print(f'Number of data points per client: {data_per_par}')
    print(f'Client data lens: {[len(d) for d in client_data_list]}')
    x_p = x0
    print(f'Comm [{0}/{n_comms}] f-f_opt={value(x_p, data) - f_opt:.4f}')

    for seed in range(n_seeds):
        np.random.seed(seed)
        loss_vals_list = [0.0 for _ in range(n_comms)]
        seed_loss_vals[seed] = loss_vals_list

        for p in range(n_comms):
            g_p = 0
            x_p_m = x_p

            for m in range(n_clients):
                # local run
                client_data = client_data_list[m]

                for i in range(n_local_steps):
                    input_batch = np.random.choice(client_data, batch_size)
                    stoch_grad = gradient(x_p_m, input_batch)
                    x_p_m -= client_lr * stoch_grad

                g_p += x_p - x_p_m
            g_p *= 1 / (client_lr * n_clients * n_local_steps)
            server_lr = calculate_clipped_lr(c_0, c_1, g_p)
            x_p -= server_lr * g_p
            loss_vals_list[p] = value(x_p, data) - f_opt
            print(f'Seed [{seed+1}/{n_seeds}] | Comm [{p+1}/{n_comms}] | f-f_opt={value(x_p, data) - f_opt:.4f}')

        print()

    return seed_loss_vals 

In [15]:
seed_loss_vals = local_sgd(
    x0, 
    f_opt, 
    data, 
    client_lr, 
    c_0, 
    c_1, 
    batch_size, 
    n_clients,
    n_local_steps,
    n_comms,
    n_seeds
)

Number of data points per client: 100
Client data lens: [100, 100, 100, 100, 100, 100, 100, 100, 100, 100]
Comm [0/100] f-f_opt=102351720.7111
Seed [1/3] | Comm [1/100] f-f_opt=98361307.7969
Seed [1/3] | Comm [2/100] f-f_opt=94489108.1201
Seed [1/3] | Comm [3/100] f-f_opt=90732755.7231
Seed [1/3] | Comm [4/100] f-f_opt=87089908.6481
Seed [1/3] | Comm [5/100] f-f_opt=83558248.9376
Seed [1/3] | Comm [6/100] f-f_opt=80135482.6339
Seed [1/3] | Comm [7/100] f-f_opt=76819339.7793
Seed [1/3] | Comm [8/100] f-f_opt=73607574.4161
Seed [1/3] | Comm [9/100] f-f_opt=70497964.5867
Seed [1/3] | Comm [10/100] f-f_opt=67488312.3334
Seed [1/3] | Comm [11/100] f-f_opt=64576443.6986
Seed [1/3] | Comm [12/100] f-f_opt=61760208.7245
Seed [1/3] | Comm [13/100] f-f_opt=59037481.4536
Seed [1/3] | Comm [14/100] f-f_opt=56406159.9282
Seed [1/3] | Comm [15/100] f-f_opt=53864166.1905
Seed [1/3] | Comm [16/100] f-f_opt=51409446.2830
Seed [1/3] | Comm [17/100] f-f_opt=49039970.2480
Seed [1/3] | Comm [18/100] f-f_op

## Baseline 1

In [16]:
def local_sgd_baseline_1(
    x0, 
    f_opt,
    data, 
    c_0,
    c_1,
    batch_size, 
    n_clients, 
    n_local_steps,
    n_comms,
    n_seeds
):
    seed_loss_vals = {}
    data_per_par = len(data) // n_clients
    client_data_list = [data[i * data_per_par : (i + 1) * data_per_par] for i in range(n_clients)]
    print(f'Number of data points per client: {data_per_par}')
    print(f'Client data lens: {[len(d) for d in client_data_list]}')
    x_p = x0
    print(f'Comm [{0}/{n_comms}] f-f_opt={value(x_p, data) - f_opt:.4f}')

    for seed in range(n_seeds):
        np.random.seed(seed)
        loss_vals_list = [0.0 for _ in range(n_comms)]
        seed_loss_vals[seed] = loss_vals_list
        for p in range(n_comms):
            x_p_m = x_p
            x_p = 0

            for m in range(n_clients):
                # local run
                client_data = client_data_list[m]

                for i in range(n_local_steps):
                    input_batch = np.random.choice(client_data, batch_size)
                    stoch_grad = gradient(x_p_m, input_batch)
                    client_lr = calculate_clipped_lr(c_0, c_1, stoch_grad)
                    x_p_m -= client_lr * stoch_grad

                x_p += x_p_m
            x_p /= n_clients
            loss_vals_list[p] = value(x_p, data) - f_opt
            print(f'Seed [{seed+1}/{n_seeds}] | Comm [{p+1}/{n_comms}] | f-f_opt={value(x_p, data) - f_opt:.4f}')
        print()

    return seed_loss_vals 

In [17]:
seed_loss_vals = local_sgd_baseline_1(
    x0,
    f_opt,
    data, 
    c_0,
    c_1,
    batch_size,
    n_clients,
    n_local_steps,
    n_comms,
    n_seeds
)

Number of data points per client: 100
Client data lens: [100, 100, 100, 100, 100, 100, 100, 100, 100, 100]
Comm [0/100] f-f_opt=102351720.7111
Seed [1/3] | Comm [1/100] f-f_opt=4855882.9051
Seed [1/3] | Comm [2/100] f-f_opt=35610.4517
Seed [1/3] | Comm [3/100] f-f_opt=63.3303
Seed [1/3] | Comm [4/100] f-f_opt=39.5980
Seed [1/3] | Comm [5/100] f-f_opt=26.0351
Seed [1/3] | Comm [6/100] f-f_opt=83.5965
Seed [1/3] | Comm [7/100] f-f_opt=26.0351
Seed [1/3] | Comm [8/100] f-f_opt=39.5980
Seed [1/3] | Comm [9/100] f-f_opt=63.3303
Seed [1/3] | Comm [10/100] f-f_opt=39.5980
Seed [1/3] | Comm [11/100] f-f_opt=63.3303
Seed [1/3] | Comm [12/100] f-f_opt=39.5980
Seed [1/3] | Comm [13/100] f-f_opt=5.0682
Seed [1/3] | Comm [14/100] f-f_opt=0.3566
Seed [1/3] | Comm [15/100] f-f_opt=39.5980
Seed [1/3] | Comm [16/100] f-f_opt=26.0351
Seed [1/3] | Comm [17/100] f-f_opt=83.5965
Seed [1/3] | Comm [18/100] f-f_opt=26.0351
Seed [1/3] | Comm [19/100] f-f_opt=83.5965
Seed [1/3] | Comm [20/100] f-f_opt=26.0351


# Local-SGD Baseline 2

In [23]:
def clip(vector, clip_level):
    return vector * min(1, clip_level / np.linalg.norm(vector))

In [52]:
def local_sgd_baseline_2(
    x0, 
    f_opt,
    data, 
    server_lr,
    client_lr,
    client_cl,
    batch_size, 
    n_clients, 
    n_local_steps,
    n_comms,
    n_seeds
):
    seed_loss_vals = {}
    data_per_par = len(data) // n_clients
    client_data_list = [data[i * data_per_par : (i + 1) * data_per_par] for i in range(n_clients)]
    print(f'Number of data points per client: {data_per_par}')
    print(f'Client data lens: {[len(d) for d in client_data_list]}')
    x_p = x0
    print(f'Comm [{0}/{n_comms}] f-f_opt={value(x_p, data) - f_opt:.4f}')

    for seed in range(n_seeds):
        np.random.seed(seed)
        loss_vals_list = [0.0 for _ in range(n_comms)]
        seed_loss_vals[seed] = loss_vals_list
        for p in range(n_comms):
            x_p_m = x_p
            g_p = 0

            for m in range(n_clients):
                # local run
                client_data = client_data_list[m]

                for i in range(n_local_steps):
                    input_batch = np.random.choice(client_data, batch_size)
                    stoch_grad = gradient(x_p_m, input_batch)
                    x_p_m -= client_lr * stoch_grad

                g_p_m = 1 / (client_lr * n_local_steps) * (x_p - x_p_m)
                g_p += clip(g_p_m, client_cl)
                print(f'Seed [{seed+1}/{n_seeds}] | Comm [{p+1}/{n_comms}] | Client [{m + 1}/{n_clients}] | f-f_opt={value(x_p, client_data) - f_opt:.4f}')

            g_p /= n_clients
            print(g_p)
            x_p -= server_lr * g_p
            loss_vals_list[p] = value(x_p, data) - f_opt
            print(f'Seed [{seed+1}/{n_seeds}] | Comm [{p+1}/{n_comms}] | f-f_opt={value(x_p, data) - f_opt:.4f}')

        print()

    return seed_loss_vals 

In [73]:
server_lr = 0.01
client_cl = 1000
client_lr = 0.00001

In [74]:
seed_loss_vals = local_sgd_baseline_2(
    x0,
    f_opt,
    data, 
    server_lr,
    client_lr,
    client_cl,
    batch_size,
    n_clients,
    n_local_steps,
    n_comms,
    n_seeds
)

Number of data points per client: 100
Client data lens: [100, 100, 100, 100, 100, 100, 100, 100, 100, 100]
Comm [0/100] f-f_opt=102351720.7111
Seed [1/3] | Comm [1/100] | Client [1/10] | f-f_opt=140913614.3975
Seed [1/3] | Comm [1/100] | Client [2/10] | f-f_opt=131363478.1364
Seed [1/3] | Comm [1/100] | Client [3/10] | f-f_opt=121648833.8893
Seed [1/3] | Comm [1/100] | Client [4/10] | f-f_opt=113348352.4173
Seed [1/3] | Comm [1/100] | Client [5/10] | f-f_opt=105632754.3105
Seed [1/3] | Comm [1/100] | Client [6/10] | f-f_opt=97313687.8045
Seed [1/3] | Comm [1/100] | Client [7/10] | f-f_opt=89086520.4439
Seed [1/3] | Comm [1/100] | Client [8/10] | f-f_opt=81809858.6472
Seed [1/3] | Comm [1/100] | Client [9/10] | f-f_opt=74097375.0094
Seed [1/3] | Comm [1/100] | Client [10/10] | f-f_opt=68302732.0553
1000.0
Seed [1/3] | Comm [1/100] | f-f_opt=67488312.3334
Seed [1/3] | Comm [2/100] | Client [1/10] | f-f_opt=95877762.4286
Seed [1/3] | Comm [2/100] | Client [2/10] | f-f_opt=88741976.7326
Se