[![Open In Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/wenh06/fl_seminar/blob/master/code/fedprox-and-fedsplit.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wenh06/fl_seminar/blob/master/code/fedprox-and-fedsplit.ipynb)

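# Problem

Fit $y = kx$, where $k$ is the parameter to be fitted, using the mean squared error (MSE) as the loss.

device1 holds the points $(0, 2), (1, 2)$,

device2 holds the points $(2, 0), (2, 1)$.

Fitting on device1 alone gives $k = 2$, fitting on device2 alone gives $k = \frac{1}{4}$, and fitting on all the data together gives $k = \frac{4}{9}$. Note that the joint fit is not the mean of the two local fits, which would be $\frac{9}{8}$. The reason is that the loss is quadratic in $k$, not linear. Written as sums of squared errors, the local losses are $(k-2)^2 + 4$ for device1 and $(2k)^2 + (2k-1)^2$ for device2, with second derivatives 2 and 16. The joint minimizer is therefore the curvature-weighted average $\frac{2 \cdot 2 + 16 \cdot \frac{1}{4}}{2 + 16} = \frac{4}{9}$, not the plain mean.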
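
The three minimizers, and the gap between the joint fit and the plain mean, can be checked with the closed-form least-squares solution for a line through the origin, $k = \sum_i x_i y_i / \sum_i x_i^2$. The following is a minimal sketch; the helper `fit_k` and the variable names are illustrative and not part of the original notebook.

```python
def fit_k(points):
    """Least-squares slope of y = k * x through the origin."""
    sxy = sum(x * y for x, y in points)
    sxx = sum(x * x for x, _ in points)
    return sxy / sxx


device1 = [(0, 2), (1, 2)]
device2 = [(2, 0), (2, 1)]

k_dev1 = fit_k(device1)             # local fit on device1
k_dev2 = fit_k(device2)             # local fit on device2
k_joint = fit_k(device1 + device2)  # fit on the pooled data
k_mean = 0.5 * (k_dev1 + k_dev2)    # plain mean of the local fits

# Weight each local fit by sum(x^2), which is half the curvature of its loss.
w1 = sum(x * x for x, _ in device1)
w2 = sum(x * x for x, _ in device2)
k_weighted = (w1 * k_dev1 + w2 * k_dev2) / (w1 + w2)

print(f"device1: {k_dev1}, device2: {k_dev2}")
print(f"joint: {k_joint:.6f}, plain mean: {k_mean:.6f}, weighted mean: {k_weighted:.6f}")
```

The weighted mean coincides with the joint fit, while the plain mean does not.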

In [None]:
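import numpy as np


def prox1(k, s):
    """Proximal operator of device1's loss (x - 2)^2 at k, with step size s."""
    if np.isinf(s):
        return 2  # s -> inf: the minimizer of device1's own loss
    return (4 * s + k) / (2 * s + 1)


def prox2(k, s):
    """Proximal operator of device2's loss (2x)^2 + (2x - 1)^2 at k, with step size s."""
    if np.isinf(s):
        return 0.25  # s -> inf: the minimizer of device2's own loss
    return (4 * s + k) / (16 * s + 1)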
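
The two closed forms follow from setting the derivative of $f_j(x) + \frac{(x - k)^2}{2s}$ to zero, where $f_1(x) = (x-2)^2$ and $f_2(x) = (2x)^2 + (2x-1)^2$. As a sanity check, the sketch below evaluates that derivative at the returned value and confirms it vanishes. The helpers `grad_obj1` and `grad_obj2` are hypothetical names introduced here; `prox1` and `prox2` are repeated, using `math.isinf`, only to keep the block self-contained.

```python
import math


def prox1(k, s):
    if math.isinf(s):
        return 2
    return (4 * s + k) / (2 * s + 1)


def prox2(k, s):
    if math.isinf(s):
        return 0.25
    return (4 * s + k) / (16 * s + 1)


def grad_obj1(x, k, s):
    # d/dx [ (x - 2)^2 + (x - k)^2 / (2 s) ]
    return 2 * (x - 2) + (x - k) / s


def grad_obj2(x, k, s):
    # d/dx [ (2x)^2 + (2x - 1)^2 + (x - k)^2 / (2 s) ]
    return 8 * x + 4 * (2 * x - 1) + (x - k) / s


# The derivative should vanish at the proximal point for any k and finite s.
residuals = []
for k in (-1.0, 0.0, 0.5, 3.0):
    for s in (1e-2, 1.0, 10.0):
        residuals.append(abs(grad_obj1(prox1(k, s), k, s)))
        residuals.append(abs(grad_obj2(prox2(k, s), k, s)))

max_residual = max(residuals)
print(f"largest first-order residual: {max_residual:.2e}")
```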

In [None]:
true_k = 4 / 9

## FedProx

In [None]:
# initial values
k_bar = 0
s = 1e-2

# theoretical limit of k_bar (the fixed point of the iteration below); note that it depends on s
theo_k = 4 * (9 * s + 1) / (32 * s + 9)

delta = np.inf
n_iter = 0

k1, k2 = k_bar, k_bar
loss1 = (k1-2)**2 + (k1-k_bar)**2 / (2*s)  # loss from 1
loss2 = (2*k2 - 1)**2 + (2*k2)**2 + (k2-k_bar)**2 / (2*s)  # loss from 2
loss = loss1 + loss2

print(f"init k_bar = {k_bar}, loss = {loss:.8f}")

while n_iter < 1e5 and delta > 1e-9:
    k1 = prox1(k_bar, s)
    k2 = prox2(k_bar, s)
    loss1 = (k1-2)**2 + (k1-k_bar)**2 / (2*s)  # loss from 1
    loss2 = (2*k2 - 1)**2 + (2*k2)**2 + (k2-k_bar)**2 / (2*s)  # loss from 2
    loss_decrease = loss - (loss1 + loss2)
    loss = loss1 + loss2
    new_k_bar = 0.5 * (k1 + k2)
    delta = abs(k_bar - new_k_bar)
    k_bar = new_k_bar
    n_iter += 1
    print(f"n_iter = {n_iter}, k_bar = {k_bar:.8f}, loss = {loss:.12f}, loss decrease = {loss_decrease:.12f}")
    print(f"n_iter = {n_iter}, k1 = {k1:.8f}, loss1 = {loss1:.8f}, k2 = {k2:.8f}, loss2 = {loss2:.8f}")

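As the output shows, the total loss keeps decreasing and `k_bar` converges to the theoretical value `theo_k`. However, this limit is not the true solution of the overall problem: for $s = 10^{-2}$, `theo_k` is approximately 0.4678, whereas `true_k` is $\frac{4}{9} \approx 0.4444$.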
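
The dependence of the FedProx limit on $s$ can be made explicit by repeating the iteration for several step sizes and comparing with the closed form $\frac{4(9s+1)}{32s+9}$. This is a minimal sketch under the same setup; `fedprox_limit` and `theo_limit` are hypothetical helper names, and the proximal updates are inlined from `prox1` and `prox2`. As $s \to 0$ the limit approaches $\frac{4}{9}$, while as $s \to \infty$ it approaches $\frac{9}{8}$, the plain mean of the two local minimizers.

```python
def fedprox_limit(s, k_bar=0.0, max_iter=100_000, tol=1e-12):
    """Iterate k_bar <- mean of the two proximal updates until the step is below tol."""
    for _ in range(max_iter):
        k1 = (4 * s + k_bar) / (2 * s + 1)   # same closed form as prox1
        k2 = (4 * s + k_bar) / (16 * s + 1)  # same closed form as prox2
        new_k_bar = 0.5 * (k1 + k2)
        if abs(new_k_bar - k_bar) < tol:
            return new_k_bar
        k_bar = new_k_bar
    return k_bar


def theo_limit(s):
    """Fixed point of the averaged proximal update (same formula as theo_k)."""
    return 4 * (9 * s + 1) / (32 * s + 9)


true_k = 4 / 9
for s in (1e-3, 1e-2, 1e-1, 1.0, 10.0):
    k_num, k_theo = fedprox_limit(s), theo_limit(s)
    print(f"s = {s:g}: k_bar = {k_num:.8f}, theo_k = {k_theo:.8f}, "
          f"gap to true_k = {k_theo - true_k:.8f}")
```

A smaller $s$ shrinks the gap to `true_k`, at the price of slower convergence of the iteration.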

In [None]:
k_bar, theo_k, true_k

## FedSplit

In [None]:
# initial values
k_bar = 0
s = 1e-2

z1, z2 = k_bar, k_bar
delta = np.inf
n_iter = 0

while n_iter < 1e5 and delta > 1e-9:
    tmp1 = prox1(2 * k_bar - z1, s)
    tmp2 = prox2(2 * k_bar - z2, s)
    z1 = z1 + 2 * (tmp1 - k_bar)
    z2 = z2 + 2 * (tmp2 - k_bar)
    new_k_bar = 0.5 * (z1 + z2)
    delta = abs(k_bar - new_k_bar)
    k_bar = new_k_bar
    n_iter += 1
    print(f"n_iter = {n_iter}, k_bar = {k_bar:.8f}")

In [None]:
k_bar, true_k
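
Why FedSplit can reach `true_k` regardless of the step size $s$ can be seen from its fixed point. The following is a short derivation sketch, assuming the iteration converges. The notation $f_1(k) = (k-2)^2$ and $f_2(k) = (2k)^2 + (2k-1)^2$ for the local losses is introduced here and is not part of the original notebook; $\bar{k}$ stands for `k_bar`.

At a fixed point, $z_j$ no longer changes, so the update $z_j \leftarrow z_j + 2\,(\mathrm{prox}_j(2\bar{k} - z_j) - \bar{k})$ forces

$$\mathrm{prox}_j(2\bar{k} - z_j) = \bar{k}, \qquad j = 1, 2.$$

The first-order condition of the proximal step, $f_j'(x) + \frac{x - (2\bar{k} - z_j)}{s} = 0$ at $x = \bar{k}$, then reads

$$f_j'(\bar{k}) + \frac{z_j - \bar{k}}{s} = 0.$$

Summing over both devices and using $\bar{k} = \frac{1}{2}(z_1 + z_2)$ cancels the terms involving $s$:

$$f_1'(\bar{k}) + f_2'(\bar{k}) = 2(\bar{k} - 2) + 8\bar{k} + 4(2\bar{k} - 1) = 18\bar{k} - 8 = 0 \quad\Longrightarrow\quad \bar{k} = \frac{4}{9}.$$

In contrast to the FedProx limit, this fixed point is the minimizer of the total loss.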