## Plotting

In [37]:
%matplotlib inline

import os
import sys
import torch
import hydra
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns
sys.path.insert(0, "..")
from floral.dataset import get_data
from floral.floral import Floral
from floral.client import FloralClient
from floral.model import Router

mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['figure.figsize'] = (12, 9)

TASK = "synthetic_linear"
# MODEL_DIR = f"/Users/zelig/Desktop/code/zeligism/FLoRAL-flower/outputs/test_{TASK}/id=test_clustering,lr=0.1/seed=0/"
MODEL_DIR = f"/Users/zelig/Desktop/code/zeligism/FLoRAL-flower/outputs/test_{TASK}/id=test_lr_lessdata,lr=0.1/seed=0/"

In [38]:
with hydra.initialize(version_base=None, config_path="../floral/conf"):
    cfg = hydra.compose(config_name="base", overrides=[f"task@_global_={TASK}"])
    cfg.task = TASK
    cfg.dataset.simple = "simple" in TASK
cfg

{'task': 'synthetic_linear', 'logdir': '???', 'show_cfg': False, 'wandb': False, 'is_rnn': False, 'experiment': 'experiment', 'identifier': 'identifier', 'num_rounds': 1000, 'local_epochs': 1, 'model': {'_target_': 'torch.nn.Linear', 'in_features': '${dataset.dim}', 'out_features': '${dataset.dim_out}'}, 'dataset': {'_target_': 'floral.dataset.SyntheticDataset', 'linear': True, 'simple': False, 'data_path': 'data', 'num_clients': 10, 'num_clusters': 2, 'samples_per_client': '???', 'dim': 10, 'dim_out': 3, 'uv_constant': 2.0, 'rank': 1, 'label_noise_std': 0.0}, 'deterministic': True, 'seed': 0, 'batch_size': 4, 'test_batch_size': 128, 'train_proportion': 0.8, 'dataloader': {'num_workers': 0}, 'task_dir_prefix': '', 'lr': 0.1, 'lora_lr': '${lr}', 'router_lr': '${lr}', 'router_entropy': 0.0, 'lora_penalty': 0.0, 'weight_decay': 0.0, 'task_dir': '${task_dir_prefix}id=${identifier},lr=${lr}', 'floral': {'num_clients': '${dataset.num_clients}', 'num_clusters': '${dataset.num_clusters}', 'ran

In [39]:
model = hydra.utils.instantiate(cfg.model)
model = Floral(model, **cfg.floral)
state_dict = torch.load(os.path.join(MODEL_DIR, 'model.pt'))
model.load_state_dict(state_dict)
model

Floral(
  (base_model): Linear(in_features=10, out_features=3, bias=True)
  (router): FloralRouter()
  (lora_modules): ModuleDict(
    (/): LoraLinearExperts()
  )
)

In [40]:
list(model.lora_modules.keys())

['/']

In [41]:
model.base_model.weight.round().int()

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0, -1, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  1,  0,  0,  0,  0,  0,  0,  0,  0]], dtype=torch.int32)

In [42]:
torch.round(model.base_model.weight, decimals=2)

tensor([[ 0.0200, -0.0100, -0.1300, -0.2400, -0.1300,  0.0100, -0.2800,  0.1100,
         -0.6300, -0.6500],
        [-0.1600,  0.1300,  0.1900, -0.4200, -0.1400,  0.2800,  0.0200,  0.0200,
          0.0600, -0.0800],
        [ 0.0600,  0.7300, -0.1500,  0.2900, -0.1700, -0.1500,  0.2000,  0.0300,
         -0.1200, -0.2900]], grad_fn=<RoundBackward1>)

In [43]:
model.lora_modules["/"].fuse().round().int()

tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]], dtype=torch.int32)

In [44]:
display(model.lora_modules["/"].weight_in.round().int())
display(model.lora_modules["/"].weight_out.round().int())

tensor([[[ 0, -1, -1,  0,  1,  0,  0,  0,  0,  0]],

        [[-1, -1,  0,  0,  0,  0,  0,  0,  1,  0]]], dtype=torch.int32)

tensor([[[0],
         [0],
         [0]],

        [[0],
         [0],
         [0]]], dtype=torch.int32)

In [45]:
torch.round(model.lora_modules["/"].fuse(), decimals=2)

tensor([[[-0., -0., -0., -0., 0., 0., 0., 0., -0., 0.],
         [-0., -0., -0., -0., 0., 0., 0., 0., -0., 0.],
         [-0., -0., -0., -0., 0., 0., 0., 0., -0., 0.]],

        [[-0., -0., -0., 0., 0., -0., -0., 0., 0., 0.],
         [-0., -0., -0., 0., 0., -0., -0., 0., 0., 0.],
         [-0., -0., -0., 0., 0., -0., -0., 0., 0., 0.]]],
       grad_fn=<RoundBackward1>)

In [46]:
router_weights = Router.load_router_weights(os.path.join(MODEL_DIR, 'pvt'))
torch.stack(router_weights).softmax(-1).mul(100).round()

tensor([[100.,   0.],
        [  0., 100.],
        [100.,   0.],
        [  0., 100.],
        [100.,   0.],
        [  0., 100.],
        [100.,   0.],
        [  0., 100.],
        [100.,   0.],
        [  0., 100.]])

# Evaluate

In [47]:
cfg.dataset.uv_constant = 0.0
cfg.floral.constant = 0.0
data_loaders, _ = get_data(cfg)
train_loaders, test_loaders = zip(*data_loaders)

In [48]:
from floral.utils import init_device, evaluate
loss_fn = hydra.utils.instantiate(cfg.loss_fn)

with torch.no_grad():
    for i in range(cfg.floral.num_clients):
        # get model
        model = hydra.utils.instantiate(cfg.model)
        model = Floral(model, **cfg.floral)
        model.client_id = i
        state_dict = torch.load(os.path.join(MODEL_DIR, 'model.pt'))
        state_dict["router.weight"] = router_weights[i].clone().detach()
        model.load_state_dict(state_dict)
        # evaluate
        metrics = evaluate(model, loss_fn, init_device(), test_loaders[i], i)
        print("Loss =", metrics["loss"].get_avg())


INFO flwr 2024-03-16 23:14:04,255 | training.py:60 | Test | Client 0: [1/1] loss=3.0597	acc=0.0000
INFO flwr 2024-03-16 23:14:04,259 | training.py:60 | Test | Client 1: [1/1] loss=1.9917	acc=0.0000
INFO flwr 2024-03-16 23:14:04,263 | training.py:60 | Test | Client 2: [1/1] loss=1.0573	acc=0.0000
INFO flwr 2024-03-16 23:14:04,268 | training.py:60 | Test | Client 3: [1/1] loss=2.5298	acc=0.0000
INFO flwr 2024-03-16 23:14:04,274 | training.py:60 | Test | Client 4: [1/1] loss=0.4687	acc=0.0000
INFO flwr 2024-03-16 23:14:04,279 | training.py:60 | Test | Client 5: [1/1] loss=1.9839	acc=0.0000
INFO flwr 2024-03-16 23:14:04,283 | training.py:60 | Test | Client 6: [1/1] loss=2.4523	acc=0.0000
INFO flwr 2024-03-16 23:14:04,288 | training.py:60 | Test | Client 7: [1/1] loss=2.2146	acc=0.0000
INFO flwr 2024-03-16 23:14:04,292 | training.py:60 | Test | Client 8: [1/1] loss=2.5278	acc=0.0000
INFO flwr 2024-03-16 23:14:04,297 | training.py:60 | Test | Client 9: [1/1] loss=0.5341	acc=0.0000


Loss = 3.0597381591796875
Loss = 1.9916549921035767
Loss = 1.0573395490646362
Loss = 2.5298025608062744
Loss = 0.46870318055152893
Loss = 1.9838873147964478
Loss = 2.4523427486419678
Loss = 2.2145512104034424
Loss = 2.5277822017669678
Loss = 0.5340678095817566


In [56]:
dataset = train_loaders[0].dataset.dataset.fl_dataset

### Base

In [77]:
W = dataset.W[0, :, :dataset.dim_out]
biases = dataset.biases[:, :dataset.samples_per_client, :]

In [72]:
print(W)
print(torch.round(model.base_model.weight.T, decimals=2))

tensor([[-5.9892e-01, -8.4366e-02,  9.8465e-02],
        [ 2.9673e-01,  7.3350e-01,  8.4598e-03],
        [ 3.2082e-01,  9.7339e-02,  5.2858e-01],
        [-3.0456e-01,  1.9034e-01,  1.0540e-01],
        [-1.9617e-02,  1.0636e-01,  1.1560e-01],
        [-5.0183e-01, -5.1696e-01, -4.8517e-02],
        [ 4.4820e-04,  1.3353e-01,  2.0249e-01],
        [ 5.4738e-02, -3.0238e-01,  5.7204e-01],
        [-1.9363e-01,  4.9456e-01, -3.4126e-01],
        [ 1.2901e-01,  1.8344e-01, -1.8919e-01]])
tensor([[ 0.0200, -0.1600,  0.0600],
        [-0.0100,  0.1300,  0.7300],
        [-0.1300,  0.1900, -0.1500],
        [-0.2400, -0.4200,  0.2900],
        [-0.1300, -0.1400, -0.1700],
        [ 0.0100,  0.2800, -0.1500],
        [-0.2800,  0.0200,  0.2000],
        [ 0.1100,  0.0200,  0.0300],
        [-0.6300,  0.0600, -0.1200],
        [-0.6500, -0.0800, -0.2900]], grad_fn=<RoundBackward1>)


In [None]:
print(W)
print(torch.round(model.base_model.weight.T, decimals=2))

In [75]:
x = torch.randn(3, dataset.dim)
h_true = x.matmul(W)
h_hat = x.matmul(model.base_model.weight.T)
print(torch.round(h_true, decimals=2))
print(torch.round(h_hat, decimals=2))

tensor([[ 0.3700,  0.5300, -0.2200],
        [ 1.2700,  3.5800, -0.3800],
        [-1.9200, -2.8400,  0.7200]])
tensor([[ 0.3600,  0.1400,  0.5100],
        [-0.5400, -0.9300,  2.7200],
        [ 2.2200, -1.2000, -1.1300]], grad_fn=<RoundBackward1>)


### Lora

In [60]:
for cluster_id in range(cfg.floral.num_clusters):
    print("Cluster =", cluster_id)
    uv = dataset.Ws[cluster_id] - dataset.W[0]
    uv = uv[:, :cfg.dataset.dim_out]
    uv_hats = model.lora_modules['/'].fuse().transpose(1,2)
    print(uv)
    print(torch.round(uv_hats[cluster_id], decimals=2))
    print("Error =", (uv.unsqueeze(0) - uv_hats).flatten(1).pow(2).sum(1).sqrt().min().item())
    print()

Cluster = 0
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
tensor([[-0., -0., -0.],
        [-0., -0., -0.],
        [-0., -0., -0.],
        [-0., -0., -0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [-0., -0., -0.],
        [0., 0., 0.]], grad_fn=<RoundBackward1>)
Error = 0.0

Cluster = 1
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
tensor([[-0., -0., -0.],
        [-0., -0., -0.],
        [-0., -0., -0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [-0., -0., -0.],
        [-0., -0., -0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], grad_fn=<RoundBackward1

In [61]:
for cluster_id in range(cfg.floral.num_clusters):
    print("Cluster =", cluster_id)
    h_true = x.matmul(dataset.Ws[cluster_id] - dataset.W)
    h_hat = x.matmul(model.lora_modules['/'].fuse()[cluster_id].T)
    print(torch.round(h_true, decimals=2))
    print(torch.round(h_hat, decimals=2))
    print()

Cluster = 0
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], grad_fn=<RoundBackward1>)

Cluster = 1
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], grad_fn=<RoundBackward1>)



In [62]:
probs = torch.stack(list(router_weights.values())).softmax(-1)
# loras = torch.einsum("kc,c...->k...", (probs, model.lora_modules['/'].fuse()))
# signs = loras.sign().prod(0)
# gmean = loras.abs().add(1e-10).log().mean(dim=0).exp()
# gmean[signs <= 0] = 0.
loras = model.lora_modules['/']
lora = loras.fuse()
signs = lora.sign().prod(0)
lora_gmean = lora.abs().add(1e-10).log().mean(dim=0).exp()
lora_gmean[signs <= 0] = 0.
torch.round(lora_gmean, decimals=2)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<RoundBackward1>)