In [4]:
from copy import deepcopy
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import torchvision
import numpy as np
from openfl.experimental.interface import FLSpec, Aggregator, Collaborator
from openfl.experimental.runtime import LocalRuntime
from openfl.experimental.placement import aggregator, collaborator

n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

mnist_train = torchvision.datasets.MNIST('files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

mnist_test = torchvision.datasets.MNIST('files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

def FedAvg(models):
    new_model = models[0]
    state_dicts = [model.state_dict() for model in models]
    state_dict = new_model.state_dict()
    for key in models[1].state_dict():
        state_dict[key] = np.sum([state[key] for state in state_dicts],axis=0) / len(models)
    new_model.load_state_dict(state_dict)
    return new_model

def inference(network,test_loader):
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
      for data, target in test_loader:
        output = network(data)
        test_loss += F.nll_loss(output, target, size_average=False).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
      test_loss, correct, len(test_loader.dataset),
      100. * correct / len(test_loader.dataset)))
    accuracy = float(correct / len(test_loader.dataset))
    return accuracy

In [5]:
class FederatedFlow(FLSpec):

    def __init__(self, model = None, optimizer = None, **kwargs):
        super().__init__(**kwargs)
        if model is not None:
            self.model = model
            self.optimizer = optimizer
        else:
            self.model = Net()
            self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,
                                   momentum=momentum)

    @aggregator
    def start(self):
        print(f'Performing initialization for model')
        self.collaborators = self.runtime.collaborators
        self.private = 10
        self.next(self.aggregated_model_validation,foreach='collaborators',exclude=['private'])

    @collaborator
    def aggregated_model_validation(self):
        print(f'Performing aggregated model validation for collaborator {self.input}')
        self.agg_validation_score = inference(self.model,self.test_loader)
        print(f'{self.input} value of {self.agg_validation_score}')
        self.next(self.train)

    @collaborator
    def train(self):
        self.model.train()
        self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,
                                   momentum=momentum)
        train_losses = []
        for batch_idx, (data, target) in enumerate(self.train_loader):
          self.optimizer.zero_grad()
          output = self.model(data)
          loss = F.nll_loss(output, target)
          loss.backward()
          self.optimizer.step()
          if batch_idx % log_interval == 0:
            print('Train Epoch: 1 [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
               batch_idx * len(data), len(self.train_loader.dataset),
              100. * batch_idx / len(self.train_loader), loss.item()))
            self.loss = loss.item()
            torch.save(self.model.state_dict(), 'model.pth')
            torch.save(self.optimizer.state_dict(), 'optimizer.pth')
        self.training_completed = True
        self.next(self.local_model_validation)

    @collaborator
    def local_model_validation(self):
        self.local_validation_score = inference(self.model,self.test_loader)
        print(f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')
        self.next(self.join, exclude=['training_completed'])

    @aggregator
    def join(self,inputs):
        self.average_loss = sum(input.loss for input in inputs)/len(inputs)
        self.aggregated_model_accuracy = sum(input.agg_validation_score for input in inputs)/len(inputs)
        self.local_model_accuracy = sum(input.local_validation_score for input in inputs)/len(inputs)
        print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')
        print(f'Average training loss = {self.average_loss}')
        print(f'Average local model validation values = {self.local_model_accuracy}')
        self.model = FedAvg([input.model for input in inputs])
        self.optimizer = [input.optimizer for input in inputs][0]
        self.next(self.end)
        
    @aggregator
    def end(self):
        print(f'This is the end of the flow')  

Aggregator step "start" registered
Collaborator step "aggregated_model_validation" registered
Collaborator step "train" registered
Collaborator step "local_model_validation" registered
Aggregator step "join" registered
Aggregator step "end" registered


In [6]:
# Setup participants
aggregator = Aggregator()
aggregator.private_attributes = {}

# Setup collaborators with private attributes
collaborator_names = ['Portland', 'Seattle', 'Chandler','Bangalore']
collaborators = [Collaborator(name=name) for name in collaborator_names]
for idx, collaborator in enumerate(collaborators):
    local_train = deepcopy(mnist_train)
    local_test = deepcopy(mnist_test)
    local_train.data = mnist_train.data[idx::len(collaborators)]
    local_train.targets = mnist_train.targets[idx::len(collaborators)]
    local_test.data = mnist_test.data[idx::len(collaborators)]
    local_test.targets = mnist_test.targets[idx::len(collaborators)]
    collaborator.private_attributes = {
            'train_loader': torch.utils.data.DataLoader(local_train,batch_size=batch_size_train, shuffle=True),
            'test_loader': torch.utils.data.DataLoader(local_test,batch_size=batch_size_train, shuffle=True)
    }

local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators)
print(f'Local runtime collaborators = {local_runtime._collaborators}')

model = None
best_model = None
optimizer = None
top_model_accuracy = 0
for i in range(2):
    print(f'Starting round {i}...')
    flflow = FederatedFlow(model,optimizer,checkpoint=True)
    flflow.runtime = local_runtime
    flflow.run()
    model = flflow.model
    optimizer = flflow.optimizer
    aggregated_model_accuracy = flflow.aggregated_model_accuracy
    if aggregated_model_accuracy > top_model_accuracy:
        print(f'Accuracy improved to {aggregated_model_accuracy} for round {i}')
        top_model_accuracy = aggregated_model_accuracy

Local runtime collaborators = {'Portland': <openfl.experimental.interface.participants.Collaborator object at 0x7f28d0e7ed00>, 'Seattle': <openfl.experimental.interface.participants.Collaborator object at 0x7f28d0e7eeb0>, 'Chandler': <openfl.experimental.interface.participants.Collaborator object at 0x7f28d0e7ef70>, 'Bangalore': <openfl.experimental.interface.participants.Collaborator object at 0x7f28d0e7e520>}
Starting round 0...
Created flow FederatedFlow

Calling start
Performing initialization for model
Saving data artifacts for start
Saved data artifacts for start
No transition point detected!
foreach_methods = []
Sending state from aggregator to collaborators




[2m[36m(wrapper pid=1687134)[0m 
[2m[36m(wrapper pid=1687134)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=1687134)[0m 
[2m[36m(wrapper pid=1687134)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=1687134)[0m Performing aggregated model validation for collaborator Portland
[2m[36m(wrapper pid=1687143)[0m 
[2m[36m(wrapper pid=1687143)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=1687143)[0m 
[2m[36m(wrapper pid=1687143)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=1687143)[0m Performing aggregated model validation for collaborator Seattle
[2m[36m(wrapper pid=1687134)[0m 
[2m[36m(wrapper pid=1687134)[0m Test set: Avg. loss: 2.3264, Accuracy: 309/2500 (12%)
[2m[36m(wrapper pid=1687134)[0m 
[2m[36m(wrapper pid=1687134)[0m Portland value of 0.12359999865293503
[2m[36m(wrapper pid=1687134)[0m Saving data artifacts for aggregated_model_validation




[2m[36m(wrapper pid=1687137)[0m 
[2m[36m(wrapper pid=1687137)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=1687137)[0m 
[2m[36m(wrapper pid=1687137)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=1687137)[0m Performing aggregated model validation for collaborator Chandler
[2m[36m(wrapper pid=1687143)[0m 
[2m[36m(wrapper pid=1687143)[0m Test set: Avg. loss: 2.3319, Accuracy: 272/2500 (11%)
[2m[36m(wrapper pid=1687143)[0m 
[2m[36m(wrapper pid=1687143)[0m Seattle value of 0.1088000014424324
[2m[36m(wrapper pid=1687143)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1687135)[0m 
[2m[36m(wrapper pid=1687135)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=1687135)[0m 
[2m[36m(wrapper pid=1687135)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=1687135)[0m Performing aggregated model validation for collaborator Bangalore




[2m[36m(wrapper pid=1687137)[0m 
[2m[36m(wrapper pid=1687137)[0m Test set: Avg. loss: 2.3338, Accuracy: 284/2500 (11%)
[2m[36m(wrapper pid=1687137)[0m 
[2m[36m(wrapper pid=1687137)[0m Chandler value of 0.1136000007390976
[2m[36m(wrapper pid=1687137)[0m Saving data artifacts for aggregated_model_validation




[2m[36m(wrapper pid=1687134)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1687134)[0m No transition point detected!
[2m[36m(wrapper pid=1687134)[0m foreach_methods = ['aggregated_model_validation']
[2m[36m(wrapper pid=1687134)[0m 
[2m[36m(wrapper pid=1687134)[0m Calling train
[2m[36m(wrapper pid=1687135)[0m 
[2m[36m(wrapper pid=1687135)[0m Test set: Avg. loss: 2.3345, Accuracy: 272/2500 (11%)
[2m[36m(wrapper pid=1687135)[0m 
[2m[36m(wrapper pid=1687135)[0m Bangalore value of 0.1088000014424324
[2m[36m(wrapper pid=1687135)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1687143)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1687143)[0m No transition point detected!
[2m[36m(wrapper pid=1687143)[0m foreach_methods = ['aggregated_model_validation']
[2m[36m(wrapper pid=1687143)[0m 
[2m[36m(wrapper pid=1687143)[0m Calling train




[2m[36m(wrapper pid=1687137)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1687137)[0m No transition point detected!
[2m[36m(wrapper pid=1687137)[0m foreach_methods = ['aggregated_model_validation']
[2m[36m(wrapper pid=1687137)[0m 
[2m[36m(wrapper pid=1687137)[0m Calling train








[2m[36m(wrapper pid=1687135)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1687135)[0m No transition point detected!
[2m[36m(wrapper pid=1687135)[0m foreach_methods = ['aggregated_model_validation']
[2m[36m(wrapper pid=1687135)[0m 
[2m[36m(wrapper pid=1687135)[0m Calling train
[2m[36m(wrapper pid=1687134)[0m Saving data artifacts for train
[2m[36m(wrapper pid=1687134)[0m Saved data artifacts for train
[2m[36m(wrapper pid=1687134)[0m No transition point detected!
[2m[36m(wrapper pid=1687134)[0m foreach_methods = ['aggregated_model_validation']
[2m[36m(wrapper pid=1687134)[0m 
[2m[36m(wrapper pid=1687134)[0m Calling local_model_validation




[2m[36m(wrapper pid=1687143)[0m Saving data artifacts for train
[2m[36m(wrapper pid=1687143)[0m Saved data artifacts for train
[2m[36m(wrapper pid=1687143)[0m No transition point detected!
[2m[36m(wrapper pid=1687143)[0m foreach_methods = ['aggregated_model_validation']
[2m[36m(wrapper pid=1687143)[0m 
[2m[36m(wrapper pid=1687143)[0m Calling local_model_validation




[2m[36m(wrapper pid=1687134)[0m 
[2m[36m(wrapper pid=1687134)[0m Test set: Avg. loss: 0.7068, Accuracy: 2050/2500 (82%)
[2m[36m(wrapper pid=1687134)[0m 
[2m[36m(wrapper pid=1687134)[0m Doing local model validation for collaborator Portland: 0.8199999928474426
[2m[36m(wrapper pid=1687134)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=1687137)[0m Saving data artifacts for train
[2m[36m(wrapper pid=1687134)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=1687134)[0m No transition point detected!
[2m[36m(wrapper pid=1687134)[0m foreach_methods = ['aggregated_model_validation']
[2m[36m(wrapper pid=1687134)[0m Sending state from collaborator to aggregator
[2m[36m(wrapper pid=1687134)[0m 
[2m[36m(wrapper pid=1687134)[0m Calling join


RayTaskError(TypeError): [36mray::wrapper()[39m (pid=1687134, ip=10.4.0.4)
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/placement/placement.py", line 18, in wrapper
    f()
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/placement/placement.py", line 77, in wrapper
    f(*args, **kwargs)
  File "/tmp/ipykernel_1686978/1826581246.py", line 25, in aggregated_model_validation
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/interface/fl_spec.py", line 310, in next
    self.execute_task(cln, f, parent_func, **kwargs)
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/interface/fl_spec.py", line 277, in execute_task
    to_exec()
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/placement/placement.py", line 77, in wrapper
    f(*args, **kwargs)
  File "/tmp/ipykernel_1686978/1826581246.py", line 47, in train
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/interface/fl_spec.py", line 310, in next
    self.execute_task(cln, f, parent_func, **kwargs)
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/interface/fl_spec.py", line 277, in execute_task
    to_exec()
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/placement/placement.py", line 77, in wrapper
    f(*args, **kwargs)
  File "/tmp/ipykernel_1686978/1826581246.py", line 53, in local_model_validation
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/interface/fl_spec.py", line 310, in next
    self.execute_task(cln, f, parent_func, **kwargs)
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/interface/fl_spec.py", line 277, in execute_task
    to_exec()
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/placement/placement.py", line 48, in wrapper
    f(*args, **kwargs)
TypeError: join() missing 1 required positional argument: 'inputs'

[2m[36m(wrapper pid=1687137)[0m Saved data artifacts for train
[2m[36m(wrapper pid=1687137)[0m No transition point detected!
[2m[36m(wrapper pid=1687137)[0m foreach_methods = ['aggregated_model_validation']
[2m[36m(wrapper pid=1687137)[0m 
[2m[36m(wrapper pid=1687137)[0m Calling local_model_validation
[2m[36m(wrapper pid=1687143)[0m 
[2m[36m(wrapper pid=1687143)[0m Test set: Avg. loss: 0.6056, Accuracy: 2104/2500 (84%)
[2m[36m(wrapper pid=1687143)[0m 
[2m[36m(wrapper pid=1687143)[0m Doing local model validation for collaborator Seattle: 0.8416000008583069
[2m[36m(wrapper pid=1687143)[0m Saving data artifacts for local_model_validation




[2m[36m(wrapper pid=1687135)[0m Saving data artifacts for train
[2m[36m(wrapper pid=1687143)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=1687143)[0m No transition point detected!
[2m[36m(wrapper pid=1687143)[0m foreach_methods = ['aggregated_model_validation']
[2m[36m(wrapper pid=1687143)[0m Sending state from collaborator to aggregator
[2m[36m(wrapper pid=1687143)[0m 
[2m[36m(wrapper pid=1687143)[0m Calling join
[2m[36m(wrapper pid=1687135)[0m Saved data artifacts for train
[2m[36m(wrapper pid=1687135)[0m No transition point detected!
[2m[36m(wrapper pid=1687135)[0m foreach_methods = ['aggregated_model_validation']
[2m[36m(wrapper pid=1687135)[0m 
[2m[36m(wrapper pid=1687135)[0m Calling local_model_validation




[2m[36m(wrapper pid=1687137)[0m 
[2m[36m(wrapper pid=1687137)[0m Test set: Avg. loss: 0.7255, Accuracy: 1983/2500 (79%)
[2m[36m(wrapper pid=1687137)[0m 
[2m[36m(wrapper pid=1687137)[0m Doing local model validation for collaborator Chandler: 0.7932000160217285
[2m[36m(wrapper pid=1687137)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=1687137)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=1687137)[0m No transition point detected!
[2m[36m(wrapper pid=1687137)[0m foreach_methods = ['aggregated_model_validation']
[2m[36m(wrapper pid=1687137)[0m Sending state from collaborator to aggregator
[2m[36m(wrapper pid=1687137)[0m 
[2m[36m(wrapper pid=1687137)[0m Calling join
[2m[36m(wrapper pid=1687135)[0m 
[2m[36m(wrapper pid=1687135)[0m Test set: Avg. loss: 0.6053, Accuracy: 2094/2500 (84%)
[2m[36m(wrapper pid=1687135)[0m 
[2m[36m(wrapper pid=1687135)[0m Doing local model validation for collaborator Ban

2022-11-29 23:28:20,503	ERROR worker.py:399 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): [36mray::wrapper()[39m (pid=1687137, ip=10.4.0.4)
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/placement/placement.py", line 18, in wrapper
    f()
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/placement/placement.py", line 77, in wrapper
    f(*args, **kwargs)
  File "/tmp/ipykernel_1686978/1826581246.py", line 25, in aggregated_model_validation
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/interface/fl_spec.py", line 310, in next
    self.execute_task(cln, f, parent_func, **kwargs)
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/interface/fl_spec.py", line 277, in execute_task
    to_exec()
  File "/home/pfoley1/anaconda3/envs/py3.8/lib/python3.8/site-packages/openfl/experimental/placement/placem

In [None]:
run_id = flflow._run_id

In [None]:
import metaflow

In [None]:
from metaflow import Metaflow, Flow, Task, Step

In [None]:
m = Metaflow()
list(m)

In [None]:
f = Flow('FederatedFlow').latest_run

In [None]:
f

In [None]:
list(f)

In [None]:
s = Step(f'FederatedFlow/{run_id}/train')

In [None]:
s

In [None]:
list(s)

In [None]:
t = Task(f'FederatedFlow/{run_id}/train/7')

In [None]:
t

In [None]:
t.data

In [None]:
t.data.input