[![Open In Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/wenh06/fl_seminar/blob/master/code/fedprox-and-fedsplit.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wenh06/fl_seminar/blob/master/code/fedprox-and-fedsplit.ipynb)

# 问题

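Fit $y = kx$, where $k$ is the parameter to be fitted, using mean squared error (MSE) as the loss. The code below works with sums of squared errors instead; dropping the constant factor $1/n$ does not change any minimizer.

Device 1 holds the points $(0, 2), (1, 2)$;

device 2 holds the points $(2, 0), (2, 1)$.

Fitting device 1 alone gives $k = 2$, fitting device 2 alone gives $k = \frac{1}{4}$, and fitting all the data together gives $k = \frac{4}{9}$. Note that the joint fit is not the mean of the individual fits ($\frac{9}{8}$). The two local losses are quadratics in $k$ with different curvatures (second derivatives 2 and 16), so the joint minimizer is the curvature-weighted average of the local minimizers, $\frac{2 \cdot 2 + 16 \cdot \frac{1}{4}}{2 + 16} = \frac{4}{9}$.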
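A minimal sketch that recomputes these numbers from the data points with exact rational arithmetic. For the no-intercept model $y = kx$ the least-squares slope is $\sum xy / \sum x^2$, and each device's sum of squared errors has curvature $2\sum x^2$ in $k$. The helper names `fit_slope` and `curvature` are our own, not part of the notebook.

```python
from fractions import Fraction

# Data points (x, y) for the two devices, as given above
device1 = [(0, 2), (1, 2)]
device2 = [(2, 0), (2, 1)]

def fit_slope(points):
    """Least-squares slope for y = kx (no intercept): sum(x*y) / sum(x^2)."""
    sxy = sum(Fraction(x * y) for x, y in points)
    sxx = sum(Fraction(x * x) for x, _ in points)
    return sxy / sxx

def curvature(points):
    """Second derivative in k of the sum of squared errors sum((y - kx)^2)."""
    return 2 * sum(Fraction(x * x) for x, _ in points)

k1_local, k2_local = fit_slope(device1), fit_slope(device2)
a1, a2 = curvature(device1), curvature(device2)

k_joint = fit_slope(device1 + device2)                       # fit on all data
k_weighted = (a1 * k1_local + a2 * k2_local) / (a1 + a2)     # curvature-weighted mean
k_plain_mean = (k1_local + k2_local) / 2                     # naive average

print(k1_local, k2_local, k_joint, k_weighted, k_plain_mean)
```

Exact fractions make the comparison between `k_joint` and `k_weighted` an equality check rather than a floating-point tolerance check.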

## First, some symbolic computations

In [None]:
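import numpy as np  # used by the numerical section (np.isinf, np.inf)
import pytest
import sympy as sp
from IPython.display import display  # built in inside Jupyter; explicit for other frontends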

In [None]:
k, z, s = sp.symbols("k,z,s")  # slope k; prox center z; prox step size s = 1/mu

In [None]:
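# Per-device losses (sums of squared errors)
func1 = 4 + (k-2)**2               # device 1: points (0, 2), (1, 2)
func2 = 4 * k**2 + (2*k-1)**2      # device 2: points (2, 0), (2, 1)
# Global objective f = E_k[F_k]: the average over the two devices, as in FedProx,
# so that identical devices give B = 1
func = (func1 + func2) / 2

# Squared gradient norms ("_ns" = norm squared)
grad_f1_ns = sp.diff(func1, k)**2
grad_f2_ns = sp.diff(func2, k)**2
grad_f_ns = sp.diff(func, k)**2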

In [None]:
display(grad_f1_ns), display(grad_f2_ns), display(grad_f_ns);

In [None]:
B_square = sp.simplify((grad_f1_ns + grad_f2_ns) / 2 / grad_f_ns)
B_square

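Bounded dissimilarity (FedProx): let $f = \mathbb{E}_k[F_k]$ and, wherever $\nabla f(w) \neq 0$, $B(w) = \sqrt{\mathbb{E}_k[\lVert \nabla F_k(w) \rVert^2]} \,/\, \lVert \nabla f(w) \rVert$. The assumption is that for some $\epsilon > 0$ there exists $B_{\epsilon}$ such that $B(w) \leqslant B_{\epsilon}$ for all $w \in \{ w ~|~ \lVert \nabla f(w) \rVert > \epsilon \}$.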
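A minimal numeric sketch of $B(w)^2$ for this example, assuming $f$ is the plain average of $F_1$ and $F_2$ (equal device weights). The gradients $\nabla F_1(w) = 2(w-2)$ and $\nabla F_2(w) = 16w - 4$ follow from `func1` and `func2` above. By Jensen's inequality $\mathbb{E}_k[\lVert \nabla F_k \rVert^2] \geqslant \lVert \mathbb{E}_k[\nabla F_k] \rVert^2$, so $B(w) \geqslant 1$ wherever it is defined. The function names are hypothetical.

```python
import numpy as np

def local_grads(w):
    """Gradients of F_1(w) = 4 + (w - 2)^2 and F_2(w) = 4w^2 + (2w - 1)^2."""
    return np.array([2 * (w - 2), 16 * w - 4])

def dissimilarity_sq(w):
    """B(w)^2 = E_k[|grad F_k(w)|^2] / |grad f(w)|^2, with f the average of F_1, F_2."""
    g = local_grads(w)
    return np.mean(g**2) / np.mean(g) ** 2

# Evaluate on a grid, staying away from the zero of grad f at w = 4/9
grid = np.linspace(-5, 5, 1001)
grid = grid[np.abs(grid - 4 / 9) > 1e-2]
b_sq = np.array([dissimilarity_sq(w) for w in grid])

print(dissimilarity_sq(0.0), dissimilarity_sq(2.0), b_sq.min())
```

At $w = 0$ both local gradients equal $-4$, which is why $B$ attains its lower bound 1 there.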

In [None]:
sp.plot(B_square, ylim=(0, 2))

In [None]:
# Second derivatives (curvatures) of the local losses: 2 and 16, hence the smoothness constant L >= 16 used below
sp.diff(func1, k, 2), sp.diff(func2, k, 2)

In [None]:
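prox_sf1 = func1 + (k-z)**2 / (2*s)  # proximal objective of device 1 around z
sp.simplify(prox_sf1)

In [None]:
prox_sf2 = func2 + (k-z)**2 / (2*s)  # proximal objective of device 2 around z
prox_sf2

In [None]:
# Moreau envelope of device 1: its proximal objective at the minimizer k = (4s + z) / (2s + 1)
m_sf1 = sp.simplify(prox_sf1.subs(k, (4*s+z) / (2*s+1)))
m_sf1

In [None]:
# Moreau envelope of device 2: its proximal objective at the minimizer k = (4s + z) / (16s + 1)
m_sf2 = sp.simplify(prox_sf2.subs(k, (4*s+z) / (16*s+1)))
m_sf2

In [None]:
# Fixed point of FedProx with exact local solves: z equals the average of the two proximal minimizers
theo_k = 4 * (9 * s + 1) / (32 * s + 9)
theo_k

In [None]:
# Sum of the envelope gradients at theo_k; expected to simplify to 0,
# i.e. theo_k is a stationary point of m_sf1 + m_sf2
sp.simplify(sp.diff(m_sf1, z).subs(z, theo_k) + sp.diff(m_sf2, z).subs(z, theo_k))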
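The closed form of `theo_k` can also be derived by hand. A worked derivation, assuming exact local solves so that one FedProx round maps $z$ to the average of the two proximal minimizers (the `k` values substituted into `m_sf1` and `m_sf2` above):

$$
\begin{aligned}
z = \frac{1}{2}\left(\frac{4s+z}{2s+1} + \frac{4s+z}{16s+1}\right)
&\iff 2z\,(2s+1)(16s+1) = (4s+z)(18s+2) \\
&\iff z\,(64s^2 + 18s) = 72s^2 + 8s \\
&\iff z = \frac{4(9s+1)}{32s+9} \qquad (s > 0).
\end{aligned}
$$

As $s \to 0$ this tends to $\frac{4}{9}$, the minimizer of the joint loss; as $s \to \infty$ it tends to $\frac{9}{8}$, the plain mean of the local minimizers $2$ and $\frac{1}{4}$.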

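Some observations on Theorem 4 of the FedProx paper:

1. In this example, take $\gamma = 0$ for simplicity (local problems solved exactly). $B$ cannot be too small (it is at least 1 by Jensen's inequality), $L_-$ can be taken as 0 (both local losses are convex), and $L \geqslant 16$. Only when $\mu$ is large enough, which here means $s$ is small enough ($s = \dfrac{1}{\mu}$), does $\rho$ in the theorem become positive.

2. Near a zero of $\lVert \nabla f \rVert$, unless that zero is also a zero of $\mathbb{E}_k[\lVert \nabla F_k \rVert^2]$, $B$ blows up to infinity. The assumption $\rho > 0$ then fails near the zero of $\lVert \nabla f \rVert$, and the inequality in the theorem becomes vacuous. When all devices have exactly the same distribution (the ideal case), $B \equiv 1$ and this problem does not arise. The FedSplit paper makes this point as well.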
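To make observation 2 concrete, a minimal sketch assuming equal device weights and the local gradients of this example. At $k = 4/9$ the two local gradients are $\mp 28/9$: they cancel in $\nabla f$ but not in $\mathbb{E}_k[\lVert \nabla F_k \rVert^2]$, so $B^2$ grows roughly like $1/\delta^2$ at distance $\delta$ from $4/9$. The identical-devices case (both devices holding device 2's data) is a hypothetical comparison, and the function names are our own.

```python
import numpy as np

def b_squared(grad_fns, w):
    """B(w)^2 = mean_k |grad F_k(w)|^2 / |mean_k grad F_k(w)|^2 for scalar w."""
    g = np.array([grad(w) for grad in grad_fns])
    return np.mean(g**2) / np.mean(g) ** 2

def grad_F1(w):
    return 2 * (w - 2)   # device 1 in this example

def grad_F2(w):
    return 16 * w - 4    # device 2 in this example

heterogeneous = [grad_F1, grad_F2]
identical = [grad_F2, grad_F2]   # hypothetical: both devices hold device 2's data

w_star = 4 / 9  # zero of grad f for the heterogeneous pair
print("local gradients at 4/9:", grad_F1(w_star), grad_F2(w_star))
for dist in (1e-1, 1e-2, 1e-3):
    w = w_star + dist
    print(f"w = 4/9 + {dist:g}: B^2 heterogeneous = {b_squared(heterogeneous, w):.3e}, "
          f"identical = {b_squared(identical, w):.3f}")
```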

## Numerical computation

In [None]:
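def f1(k):
    """Loss on device 1: (k - 2)^2 + 4."""
    return (k - 2)**2 + 4

def f2(k):
    """Loss on device 2: 4k^2 + (2k - 1)^2 = 8k^2 - 4k + 1."""
    return 8 * k**2 - 4 * k + 1

def prox1(z, s):
    """prox_{s f1}(z) = argmin_k f1(k) + (k - z)^2 / (2s).

    As s -> inf the proximal term vanishes, leaving the local minimizer k = 2.
    """
    if np.isinf(s):
        return 2
    return (4 * s + z) / (2 * s + 1)

def prox2(z, s):
    """prox_{s f2}(z); as s -> inf this is the local minimizer k = 1/4."""
    if np.isinf(s):
        return 0.25
    return (4 * s + z) / (16 * s + 1)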

In [None]:
true_k = 4 / 9  # minimizer of f1 + f2, i.e. all data fitted together

## FedProx

In [None]:
# Initial value
k_bar = 0
s = 1e-4  # prox step size, s = 1/mu

# Theoretical limit of k_bar; note that it depends on s
theo_k = 4 * (9 * s + 1) / (32 * s + 9)

delta = np.inf
n_iter = 0

# Sum of the two local proximal objectives (the constant 4 in f1 is dropped)
k1, k2 = k_bar, k_bar
loss1 = (k1-2)**2 + (k1-k_bar)**2 / (2*s)  # loss from device 1
loss2 = (2*k2 - 1)**2 + (2*k2)**2 + (k2-k_bar)**2 / (2*s)  # loss from device 2
loss = loss1 + loss2

print(f"init k_bar = {k_bar}, loss = {loss:.8f}")

while n_iter < 1e5 and delta > 1e-9:
    # Local step: each device solves its proximal problem exactly (gamma = 0)
    k1 = prox1(k_bar, s)
    k2 = prox2(k_bar, s)
    loss1 = (k1-2)**2 + (k1-k_bar)**2 / (2*s)  # loss from device 1
    loss2 = (2*k2 - 1)**2 + (2*k2)**2 + (k2-k_bar)**2 / (2*s)  # loss from device 2
    loss_decrease = loss - (loss1 + loss2)
    loss = loss1 + loss2
    # Server step: average the local solutions
    new_k_bar = 0.5 * (k1 + k2)
    delta = abs(k_bar - new_k_bar)
    k_bar = new_k_bar
    n_iter += 1
    print(f"n_iter = {n_iter}, k_bar = {k_bar:.8f}, loss = {loss:.12f}, loss decrease = {loss_decrease:.12f}")
    print(f"n_iter = {n_iter}, k1 = {k1:.8f}, loss1 = {loss1:.8f}, k2 = {k2:.8f}, loss2 = {loss2:.8f}")
    # k here is the sympy symbol from the symbolic section
    print(f"n_iter = {n_iter}, grad_f_ns = {grad_f_ns.subs(k, k_bar)}")

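As the output shows, the total loss keeps decreasing and $\bar{k}$ converges to the theoretical value `theo_k`. That value, however, is not the true solution of the whole problem: it is the stationary point of `m_sf1 + m_sf2` (the sum of Moreau envelopes checked symbolically above), which differs from the minimizer $\frac{4}{9}$ of $f_1 + f_2$ for every $s > 0$ and approaches it only as $s \to 0$.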
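One way to see why the limit is biased, sketched with the closed-form proxes above (helper names hypothetical). With exact local solves, averaging the two proximal points is exactly a gradient step of size $s/2$ on `m_sf1 + m_sf2`, because $\nabla m_{sF_i}(z) = (z - \mathrm{prox}_{sF_i}(z))/s$. Each envelope gradient is $1/s$-Lipschitz, so a step of $s/2$ is a descent step, which is also why the printed loss (the envelope sum at the previous $\bar{k}$, up to the dropped constant) keeps decreasing. The fixed point then moves with $s$.

```python
def prox_dev1(z0, step):
    """Closed-form prox of device 1's loss (same formula as prox1 above)."""
    return (4 * step + z0) / (2 * step + 1)

def prox_dev2(z0, step):
    """Closed-form prox of device 2's loss (same formula as prox2 above)."""
    return (4 * step + z0) / (16 * step + 1)

def envelope_sum_grad(z0, step):
    """Gradient of m_sf1 + m_sf2 at z0, using grad m_sfi(z) = (z - prox_i(z)) / s."""
    return ((z0 - prox_dev1(z0, step)) + (z0 - prox_dev2(z0, step))) / step

def fedprox_round(z0, step):
    """One FedProx round with exact local solves: average the two proximal points."""
    return 0.5 * (prox_dev1(z0, step) + prox_dev2(z0, step))

def fedprox_limit(step):
    """Closed-form fixed point, the same expression as theo_k above."""
    return 4 * (9 * step + 1) / (32 * step + 9)

# Averaging the proxes equals a gradient step of size s/2 on the envelope sum
for z0 in (-1.0, 0.0, 0.5, 3.0):
    gd_step = z0 - 0.5 * 1e-2 * envelope_sum_grad(z0, 1e-2)
    assert abs(fedprox_round(z0, 1e-2) - gd_step) < 1e-12

# The fixed point depends on s: near 4/9 for small s, drifting toward 9/8 for large s
for step in (1e-4, 1e-2, 1.0, 1e2):
    k_fix = fedprox_limit(step)
    residual = fedprox_round(k_fix, step) - k_fix
    print(f"s = {step:g}: limit = {k_fix:.6f}, gap to 4/9 = {k_fix - 4 / 9:.2e}, "
          f"fixed-point residual = {residual:.1e}")
```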

In [None]:
k_bar, theo_k, true_k

In [None]:
k_bar == pytest.approx(theo_k, abs=1e-5)

## FedSplit

In [None]:
# Initial value
k_bar = 0
s = 1e-2  # prox step size

z1, z2 = k_bar, k_bar  # per-device FedSplit variables
delta = np.inf
n_iter = 0

while n_iter < 1e5 and delta > 1e-10:
    # Local prox step, evaluated at the reflection 2 * k_bar - z_i
    tmp1 = prox1(2 * k_bar - z1, s)
    tmp2 = prox2(2 * k_bar - z2, s)
    # Local centering step
    z1 = z1 + 2 * (tmp1 - k_bar)
    z2 = z2 + 2 * (tmp2 - k_bar)
    # Server averaging step
    new_k_bar = 0.5 * (z1 + z2)
    delta = abs(k_bar - new_k_bar)
    k_bar = new_k_bar
    n_iter += 1
    print(f"n_iter = {n_iter}, k_bar = {k_bar:.16f}")

In [None]:
k_bar, true_k

In [None]:
k_bar == pytest.approx(true_k, abs=1e-5)
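By contrast with FedProx, whose limit `theo_k` moves with `s`, the fixed points of the FedSplit iteration in this strongly convex example coincide with the minimizer of $f_1 + f_2$, so the limit should be $\frac{4}{9}$ for every $s > 0$. A minimal sketch that reruns the iteration from the cell above for several step sizes. Vectorizing the two devices with NumPy and stopping on the change in $(z_1, z_2)$ rather than in `k_bar` are our own choices, and the function names are hypothetical.

```python
import numpy as np

def prox_local(v, s):
    """Closed-form prox_{s F_i}(v_i) for both devices (same formulas as prox1, prox2)."""
    return np.array([(4 * s + v[0]) / (2 * s + 1),
                     (4 * s + v[1]) / (16 * s + 1)])

def fedsplit(s, tol=1e-12, max_iter=100_000):
    """FedSplit iteration as in the cell above; returns (k_bar, number of iterations)."""
    z = np.zeros(2)
    k_bar = 0.0
    for n_iter in range(1, max_iter + 1):
        half = prox_local(2 * k_bar - z, s)   # local prox steps at the reflections
        z_new = z + 2 * (half - k_bar)        # local centering
        k_bar = z_new.mean()                  # server averaging
        if np.max(np.abs(z_new - z)) < tol:
            return k_bar, n_iter
        z = z_new
    return k_bar, max_iter

for step in (1e-2, 1e-1, 1.0, 10.0):
    k_hat, n = fedsplit(step)
    print(f"s = {step:g}: k_bar = {k_hat:.12f} after {n} iterations (4/9 = {4 / 9:.12f})")
```

Larger `s` still converges here but needs more iterations, since the reflected prox of device 2 contracts only by a factor of about $\lvert 1 - 16s \rvert / (1 + 16s)$.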