In [1]:
import copy
import os
import shlex
import time
from pathlib import Path
from typing import Dict, List, Union

import yaml

In [2]:
def generate_envoy_configs(config: Dict,
                           save_path: Union[str, Path] = '../envoy/',
                           n_cols_total: int = 10,
                           n_cols_corrupt: int = 1,
                           assign_gpus: bool = False,
                           image_hw: str = '300,400') -> List[Path]:
    """Write one envoy config file per collaborator and return the file paths.

    Files are named ``<k>_envoy_config.yaml`` for ``k`` in ``1..n_cols_total``
    and are written into ``save_path``. For the envoy at zero-based position
    ``i`` the written config differs from the template as follows:

    * ``params.cuda_devices`` is ``[i]`` when ``assign_gpus`` is true, else ``[]``;
    * ``shard_descriptor.params.corrupt`` is ``True`` for the first
      ``n_cols_corrupt`` envoys and ``False`` for the rest;
    * ``shard_descriptor.params.rank_worldsize`` is ``'<i+1>,<n_cols_total>'``;
    * ``shard_descriptor.params.enforce_image_hw`` is ``image_hw``.

    Args:
        config: Template config. Must contain the ``params`` and
            ``shard_descriptor -> params`` mappings. It is NOT modified.
        save_path: Existing directory the files are written to.
        n_cols_total: Number of config files to generate.
        n_cols_corrupt: How many of the leading envoys are flagged corrupt.
        assign_gpus: Whether to give each envoy its own CUDA device index.
        image_hw: Value stored under ``enforce_image_hw``.

    Returns:
        Absolute paths of the generated files, in envoy order.

    Raises:
        KeyError: If the template lacks one of the required mappings.
        OSError: If a file cannot be opened for writing.
    """
    save_dir = Path(save_path)
    config_paths = [(save_dir / f'{i}_envoy_config.yaml').absolute()
                    for i in range(1, n_cols_total + 1)]

    for i, path in enumerate(config_paths):
        # Work on a private copy: the caller's template must not end up
        # carrying the settings of whichever envoy was written last.
        envoy_config = copy.deepcopy(config)

        shard_params = envoy_config['shard_descriptor']['params']
        shard_params['enforce_image_hw'] = image_hw
        shard_params['corrupt'] = i < n_cols_corrupt
        shard_params['rank_worldsize'] = f'{i + 1},{n_cols_total}'
        envoy_config['params']['cuda_devices'] = [i] if assign_gpus else []

        with open(path, "w") as stream:
            yaml.safe_dump(envoy_config, stream)

    return config_paths
            
def remove_configs(config_paths):
    """Delete the given config files.

    Cleanup is idempotent: a file that is already gone is skipped instead of
    raising, so the function can safely be called more than once (for example
    from a ``finally`` block after a partially failed run).

    Args:
        config_paths: Iterable of file paths (``str`` or ``pathlib.Path``).
    """
    for path in config_paths:
        Path(path).unlink(missing_ok=True)

## Start the Director service

In [3]:
# cwd = Path.cwd()
# director_workspace_path = Path('../director/').absolute()
# director_config_file = director_workspace_path / 'director_config.yaml'
# director_logfile = director_workspace_path / 'director.log'
# director_logfile.unlink(missing_ok=True)
# # 

# os.environ['main_folder'] = str(cwd)
# os.environ['director_workspace_path'] = str(director_workspace_path)
# os.environ['director_logfile'] = str(director_logfile)
# os.environ['director_config_file'] = str(director_config_file)

In [4]:
# %%script /bin/bash --bg
# cd $director_workspace_path
# fx director start --disable-tls -c $director_config_file > $director_logfile &
# cd $main_folder

In [5]:
def start_director():
    """Launch ``fx director start`` in the background from ``../director/``.

    The command is run with the Director workspace as the working directory,
    reads ``director_config.yaml`` from that workspace and redirects its
    standard output to ``director.log`` there (a previous log file is removed
    first). The caller's working directory is restored afterwards, even when
    removing the log or spawning the command raises.

    The function returns as soon as the shell has backgrounded the command;
    it does not wait for the Director to become ready.
    """
    cwd = Path.cwd()
    director_workspace_path = Path('../director/').absolute()
    director_config_file = director_workspace_path / 'director_config.yaml'
    director_logfile = director_workspace_path / 'director.log'

    os.chdir(director_workspace_path)
    try:
        director_logfile.unlink(missing_ok=True)
        # Paths are quoted so a workspace located under a directory whose name
        # contains spaces or shell metacharacters does not break the command.
        os.system('fx director start --disable-tls '
                  f'-c {shlex.quote(str(director_config_file))} '
                  f'> {shlex.quote(str(director_logfile))} &')
    finally:
        # Never leave the notebook kernel stranded in the Director workspace.
        os.chdir(cwd)

## Start Envoys

In [6]:
# envoy_workspace_path = Path('../envoy/').absolute()
def start_envoys(config_paths: List[Path],
                 director_host: str = 'localhost',
                 director_port: int = 50050) -> None:
    """Launch one background ``fx envoy start`` process per config file.

    Make sure the Director port matches the one from the Director config file.

    Envoys are named ``env_1``, ``env_2``, ... by their position in
    ``config_paths``. All of them are started from the directory that contains
    the first config file, and each one redirects its standard output to
    ``env_<k>.log`` in that directory. The caller's working directory is
    restored afterwards, even if spawning a command raises.

    Args:
        config_paths: Envoy config files, one per envoy to start. An empty
            sequence starts nothing.
        director_host: Host name passed to the envoy via ``-dh``.
        director_port: Port passed to the envoy via ``-dp``.
    """
    if not config_paths:
        return

    envoy_workspace_path = Path(config_paths[0]).parent
    cwd = Path.cwd()
    os.chdir(envoy_workspace_path)
    try:
        for i, path in enumerate(config_paths):
            envoy_name = f'env_{i + 1}'
            # Quote user-controlled values so unusual paths/hosts cannot
            # split or alter the shell command.
            os.system(f'fx envoy start -n {envoy_name} --disable-tls '
                      f'--envoy-config-path {shlex.quote(str(path))} '
                      f'-dh {shlex.quote(str(director_host))} -dp {int(director_port)} '
                      f'>{envoy_name}.log &')
    finally:
        os.chdir(cwd)
    
# start_envoys(config_paths)

In [7]:
def start_federation(envoy_config_path='../envoy/envoy_config.yaml',
                    n_cols_total=3, n_cols_corrupt=0, deacticate_cols_indeces=(),
                    assign_gpus=True):
    """Generate envoy configs, then start the Director and the active Envoys.

    Args:
        envoy_config_path: YAML file used as the template for every envoy.
        n_cols_total: Number of envoy configs to generate.
        n_cols_corrupt: How many of the leading envoys are flagged corrupt.
        deacticate_cols_indeces: Zero-based positions (negative values count
            from the end) of generated configs whose envoys must NOT be
            started. All positions refer to the full, original list of
            generated configs.
        assign_gpus: Forwarded to ``generate_envoy_configs``.

    Raises:
        IndexError: If a position in ``deacticate_cols_indeces`` is outside
            the range of generated configs.
    """
    # Read the original envoy config file content
    with open(Path(envoy_config_path), "r") as stream:
        orig_config = yaml.safe_load(stream)

    # Write new configs
    config_paths = generate_envoy_configs(orig_config, n_cols_total=n_cols_total,
                                          n_cols_corrupt=n_cols_corrupt, assign_gpus=assign_gpus)

    # Resolve every requested position against the ORIGINAL list first.
    # Deleting entries one by one would shift the remaining positions, so
    # e.g. (0, 1) on three configs would drop the 1st and 3rd, not the 1st
    # and 2nd.
    n_generated = len(config_paths)
    inactive = set()
    for idx in deacticate_cols_indeces:
        if not -n_generated <= idx < n_generated:
            raise IndexError(
                f'collaborator index {idx} is out of range for {n_generated} generated configs')
        inactive.add(idx % n_generated)
    config_paths = [path for position, path in enumerate(config_paths)
                    if position not in inactive]

    # Start Director and Envoys; the pauses give each service time to come up
    # before the next step talks to it.
    start_director()
    time.sleep(2)
    start_envoys(config_paths)
    time.sleep(2)
#     remove_configs(config_paths)

## Run experiments

In [8]:
# from openfl.component.aggregation_functions import AggregationFunction
# import numpy as np
# class One_Good_Envoy(AggregationFunction):
#     def __init__(self, col_name='env_3', weight_scale: float = 0.5):
#         self.good_col = col_name
#         self.weight_scale = weight_scale

#     def call(self, local_tensors, *_) -> np.ndarray:
#         weights = [x.weight if self.good_col == x.col_name else x.weight * self.weight_scale
#                    for x in local_tensors]
#         tensors = np.array([x.tensor for x in local_tensors])
#         return np.average(tensors, weights=weights, axis=0)

In [9]:
from testbook import testbook


def start_experiment_from_ipynb(experiment_name, notebook_path='./PyTorch_Kvasir_UNet.ipynb',
                                n_rounds=40, train_batch_size=6, lr=1e-4, weight_decay=0,
                               opt_treatment='CONTINUE_GLOBAL'):
    """Open a notebook through ``testbook`` and inject an experiment-launch snippet.

    The injected Python source seeds torch, builds a ``UNet`` model with an
    Adam optimizer, wraps them in a ``ModelInterface``, creates the dataset
    and an ``FLExperiment`` named ``experiment_name``, starts it and streams
    its metrics. The names used by the snippet (``UNet``, ``optim``,
    ``federation``, ``TI``, ...) are not defined here; they are expected to
    exist in the target notebook's namespace.

    Args:
        experiment_name: Name embedded (double-quoted) in the snippet.
        notebook_path: Notebook handed to ``testbook``.
        n_rounds: Value substituted for ``rounds_to_train``.
        train_batch_size: Value substituted for ``train_bs``.
        lr: Learning rate substituted into the optimizer construction.
        weight_decay: Weight decay substituted into the optimizer construction.
        opt_treatment: Value embedded (single-quoted) as ``opt_treatment``.
    """
    # Each entry is one physical line of the injected source.
    snippet_lines = [
        'torch.manual_seed(0) \n',
        'model_unet = UNet() \n',
        f'optimizer_adam = optim.Adam(model_unet.parameters(), lr={lr}, weight_decay={weight_decay}) \n',
        'MI = ModelInterface(model=model_unet, optimizer=optimizer_adam, framework_plugin=framework_adapter) \n',
        f'fed_dataset = KvasirSD(train_bs={train_batch_size}, valid_bs=8) \n',
        f'fl_experiment = FLExperiment(federation=federation, experiment_name="{experiment_name}") \n',
        'fl_experiment.start(model_provider=MI, task_keeper=TI, data_loader=fed_dataset, \n',
        f'rounds_to_train={n_rounds}, \n',
        f"opt_treatment='{opt_treatment}', \n",
        "device_assignment_policy='CUDA_PREFERRED') \n",
        'fl_experiment.stream_metrics()',
    ]
    injected_source = ''.join(snippet_lines)

    # NOTE(review): the notebook is opened with execute=True; the original
    # author noted that only some cells (e.g. range(0, 23)) may need to run and
    # that the notebook's own `start_experiment` cell should not be executed —
    # confirm the target notebook is arranged accordingly.
    with testbook(notebook_path, execute=True, timeout=None) as notebook:
        notebook.inject(injected_source, pop=True)

## Stop the Federation and Clean up

In [18]:
# Stop all services by killing the `fx` processes (the Director and Envoy
# commands in this notebook are all launched through `fx`).
# Notebook-shell equivalent: !pkill fx
# NOTE(review): os.system returns the shell's wait status, which is what the
# cell displays; the recorded 256 presumably means exit code 1, i.e. pkill
# found no matching process at the time — confirm.
os.system('pkill fx')

256

# Run all together 

In [19]:
# Experiment settings for the end-to-end run.
n_cols_total=3
# NOTE(review): this value is overwritten by the loop variable below before
# it is ever read.
n_cols_corrupt=0
n_rounds=50
train_batch_size=6
lr=1.5e-4
weight_decay=1e-6
# One full federation run per listed number of corrupt collaborators.
for n_cols_corrupt in [2]:
    try:
        # deacticate_cols_indeces is forwarded to start_federation, which
        # leaves the selected generated configs out when starting envoys.
        start_federation(n_cols_total=n_cols_total, n_cols_corrupt=n_cols_corrupt,
                        deacticate_cols_indeces=(0,1))
        # Encode the run settings in the experiment name.
        experiment_name = (
            f'KVSR_rounds{n_rounds}_bs{train_batch_size}_'
            f'off{n_cols_corrupt}_of{n_cols_total}'
        )
        start_experiment_from_ipynb(experiment_name=experiment_name,
                                n_rounds=n_rounds, train_batch_size=train_batch_size,
                               lr=lr, weight_decay=weight_decay)
    finally:
        # Always tear the services down, even if the experiment fails.
        os.system('pkill fx')
    

File ‘kvasir_data/kvasir.zip’ already there; not retrieving.
You should consider upgrading via the '/home/idavidyu/.virtualenvs/corrupt-envoy/bin/python -m pip install --upgrade pip' command.
  new_state[k] = pt.from_numpy(tensor_dict.pop(k)).to(device)
train: 100%|██████████| 49/49 [00:10<00:00,  4.56it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.11it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.04it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.53it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.02it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.22it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.52it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.01it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.11it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.50it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.13it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.95it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.51it/s]
validate: 100%|████████

validate: 100%|██████████| 6/6 [00:01<00:00,  4.82it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.50it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.99it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.10it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.50it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.85it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.07it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.47it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.02it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.80it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.11it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.08it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.99it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.94it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]


validate: 100%|██████████| 6/6 [00:01<00:00,  5.12it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.21it/s]
train: 100%|██████████| 49/49 [00:11<00:00,  4.44it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.75it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.24it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.08it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.09it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.10it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.93it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.91it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.01it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.92it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.00it/s]


train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:10<00:00,  1.68s/it]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.95it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.81it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.10it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.47it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.97it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.19it/s]
train: 100%|██████████| 49/49 [00:11<00:00,  4.45it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.08it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.93it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.93it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.07it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.88it/s]


validate: 100%|██████████| 6/6 [00:01<00:00,  5.13it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.92it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.77it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.03it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.31it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.95it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.08it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.50it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.98it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.34it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.51it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.93it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.91it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]


validate: 100%|██████████| 6/6 [00:01<00:00,  4.89it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.87it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.08it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.03it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.02it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.11it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.98it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.25it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.94it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.14it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.91it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.06it/s]


train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.77it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.09it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.90it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.17it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.76it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.18it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.88it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.04it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.87it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.99it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.06it/s]


validate: 100%|██████████| 6/6 [00:01<00:00,  5.11it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.80it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.84it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.70it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.25it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.10it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.02it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.90it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.16it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.00it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.05it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.50it/s]


validate: 100%|██████████| 6/6 [00:01<00:00,  5.29it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.02it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.50it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.88it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.39it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.48it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.91it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.11it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.03it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.96it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.47it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.95it/s]
validate: 100%|██████████| 6/6 [00:01<00:00,  5.05it/s]
train: 100%|██████████| 49/49 [00:10<00:00,  4.49it/s]
validate: 100%|██████████| 6/6 [00:10<00:00,  1.76s/it]
validate: 100%|██████████| 6/6 [00:01<00:00,  4.97it/s]

In [12]:
# Scratch cell: indexes a literal list and displays its first element (1).
# NOTE(review): the value is not assigned to anything here — looks like a
# leftover experiment that can probably be deleted; confirm.
[1,2,3,4][0]

1