In [12]:

import os
from typing import List, Dict


from omegaconf import DictConfig, OmegaConf

from flp2p.client import FLClient
from flp2p.data import build_client_loaders, get_dataset
from flp2p.graph_runner import run_rounds
from flp2p.networks.lenet5 import LeNet5
from flp2p.networks.resnet18 import make_resnet18
import logging
import pickle
import random
import numpy as np
import torch
from flp2p.utils import plot_topology, build_topology

In [None]:
from hydra import compose, initialize

# Compose the experiment config (config_name="config" under config_path="conf")
# with Hydra's Compose API instead of launching a Hydra application.
# NOTE(review): in a notebook, config_path is resolved relative to the notebook /
# working directory — confirm when running this from another location.
with initialize(version_base=None, config_path="conf", job_name="test_app"):
    cfg = compose(config_name="config")

# Show the composed config as this cell's output, so the recorded dump below is
# reproducible on Restart & Run All.
cfg

{'data': {'name': 'cifar10', 'root': './data', 'batch_size': 10, 'num_workers': 0}, 'model': {'name': 'resnet18', 'num_classes': 10, 'pretrained': False}, 'partition': {'name': 'dirichlet', 'num_clients': 80, 'strategy': 'dirichlet', 'dirichlet_alpha': 0.1, 'min_partition_size': 1}, 'client': {'learning_rate': 0.0001, 'weight_decay': 0.0005, 'momentum': 0}, 'train': {'rounds': 100, 'local_epochs': 1, 'progress': True, 'participation_rate': 1, 'lr_decay': 0}, 'graph': {'name': 'two_clusters', 'topology': 'two_clusters'}, 'seed': 42, 'use_cuda': True, 'mixing_matrix': 'metropolis_hasting', 'run_name': 'main_link_activation', 'consensus_lr': 0.001, 'old_gradients': False, 'same_distrib_test_set': True, 'decrease_consensus': True, 'main_link_activation': 0.4, 'gossip_epochs': 1}

In [23]:
# Build one model + FLClient per partition. All client models are loaded from the
# same `init_state`, so every client starts from identical weights.
device = torch.device("cuda" if torch.cuda.is_available() and cfg.use_cuda else "cpu")

# Seed every RNG imported by this notebook so that weight initialization (and any
# later sampling) is reproducible across Restart & Run All.
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)  # seeds the CPU and all CUDA devices

# Map `cfg.model.name` to the constructor that takes the model sub-config.
MODEL_BUILDERS = {
    "lenet5": LeNet5,
    "resnet18": make_resnet18,
}


def build_model(model_cfg: DictConfig, device: torch.device) -> torch.nn.Module:
    """Instantiate the architecture named by ``model_cfg.name`` and move it to ``device``.

    Args:
        model_cfg: the ``model`` section of the config; ``model_cfg.name`` selects
            the architecture and the whole section is passed to its constructor.
        device: device the freshly built model is moved to.

    Returns:
        The constructed model, already on ``device``.

    Raises:
        ValueError: if ``model_cfg.name`` is not a key of ``MODEL_BUILDERS``.
    """
    builder = MODEL_BUILDERS.get(model_cfg.name)
    if builder is None:
        raise ValueError(f"Unknown model: {model_cfg.name}")
    return builder(model_cfg).to(device)


# The reference weights must come from the SAME architecture the clients use;
# otherwise load_state_dict below fails on mismatched keys (e.g. for lenet5).
base_model = build_model(cfg.model, device)
# state_dict() tensors share storage with base_model; load_state_dict copies them
# into each client model, so clients do not alias base_model's parameters.
init_state = base_model.state_dict()

# Model + Clients
clients: List[FLClient] = []
for i in range(cfg.partition.num_clients):
    model = build_model(cfg.model, device)
    model.load_state_dict(init_state)
    # TODO(review): clients are created without data; build_client_loaders /
    # get_dataset are imported but not used here — confirm loaders are attached
    # elsewhere before training.
    train_loader, test_loader = None, None
    client = FLClient(
        model=model,
        device=device,
        train_loader=train_loader,
        test_loader=test_loader,
        config=cfg.client,
    )
    clients.append(client)

