In [1]:
from copy import deepcopy
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import torchvision
import numpy as np
from openfl.experimental.interface import FLSpec, Aggregator, Collaborator
from openfl.experimental.runtime import LocalRuntime
from openfl.experimental.placement import aggregator, collaborator

n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

mnist_train = torchvision.datasets.MNIST('files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

mnist_test = torchvision.datasets.MNIST('files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

def FedAvg(models):
    new_model = models[0]
    state_dicts = [model.state_dict() for model in models]
    state_dict = new_model.state_dict()
    for key in models[1].state_dict():
        state_dict[key] = np.sum([state[key] for state in state_dicts],axis=0) / len(models)
    new_model.load_state_dict(state_dict)
    return new_model

def inference(network,test_loader):
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
      for data, target in test_loader:
        output = network(data)
        test_loss += F.nll_loss(output, target, size_average=False).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
      test_loss, correct, len(test_loader.dataset),
      100. * correct / len(test_loader.dataset)))
    accuracy = float(correct / len(test_loader.dataset))
    return accuracy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class FederatedFlow(FLSpec):

    def __init__(self, model = None, optimizer = None, **kwargs):
        super().__init__(**kwargs)
        if model is not None:
            self.model = model
            self.optimizer = optimizer
        else:
            self.model = Net()
            self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,
                                   momentum=momentum)

    @aggregator
    def start(self):
        print(f'Performing initialization for model')
        self.collaborators = self.runtime.collaborators
        self.private = 10
        self.next(self.aggregated_model_validation,foreach='collaborators',exclude=['private'])

    @collaborator
    def aggregated_model_validation(self):
        print(f'Performing aggregated model validation for collaborator {self.input}')
        self.agg_validation_score = inference(self.model,self.test_loader)
        print(f'{self.input} value of {self.agg_validation_score}')
        self.next(self.train)

    @collaborator
    def train(self):
        self.model.train()
        train_losses = []
        for batch_idx, (data, target) in enumerate(self.train_loader):
          self.optimizer.zero_grad()
          output = self.model(data)
          loss = F.nll_loss(output, target)
          loss.backward()
          self.optimizer.step()
          if batch_idx % log_interval == 0:
            print('Train Epoch: 1 [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
               batch_idx * len(data), len(self.train_loader.dataset),
              100. * batch_idx / len(self.train_loader), loss.item()))
            self.loss = loss.item()
            torch.save(self.model.state_dict(), 'model.pth')
            torch.save(self.optimizer.state_dict(), 'optimizer.pth')
        self.training_completed = True
        self.next(self.local_model_validation)

    @collaborator
    def local_model_validation(self):
        self.local_validation_score = inference(self.model,self.test_loader)
        print(f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')
        self.next(self.join, exclude=['training_completed'])

    @aggregator
    def join(self,inputs):
        self.average_loss = sum(input.loss for input in inputs)/len(inputs)
        self.aggregated_model_accuracy = sum(input.agg_validation_score for input in inputs)/len(inputs)
        self.local_model_accuracy = sum(input.local_validation_score for input in inputs)/len(inputs)
        print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')
        print(f'Average training loss = {self.average_loss}')
        print(f'Average local model validation values = {self.local_model_accuracy}')
        self.model = FedAvg([input.model for input in inputs])
        self.optimizer = [input.optimizer for input in inputs][0]
        self.next(self.end)
        
    @aggregator
    def end(self):
        print(f'This is the end of the flow')  

Aggregator step "start" registered
Collaborator step "aggregated_model_validation" registered
Collaborator step "train" registered
Collaborator step "local_model_validation" registered
Aggregator step "join" registered
Aggregator step "end" registered


In [3]:
# Setup participants
aggregator = Aggregator()
aggregator.private_attributes = {}

# Setup collaborators with private attributes
collaborator_names = ['Portland', 'Seattle', 'Chandler','Bangalore']
collaborators = [Collaborator(name=name) for name in collaborator_names]
for idx, collaborator in enumerate(collaborators):
    local_train = deepcopy(mnist_train)
    local_test = deepcopy(mnist_test)
    local_train.data = mnist_train.data[idx::len(collaborators)]
    local_train.targets = mnist_train.targets[idx::len(collaborators)]
    local_test.data = mnist_test.data[idx::len(collaborators)]
    local_test.targets = mnist_test.targets[idx::len(collaborators)]
    collaborator.private_attributes = {
            'train_loader': torch.utils.data.DataLoader(local_train,batch_size=batch_size_train, shuffle=True),
            'test_loader': torch.utils.data.DataLoader(local_test,batch_size=batch_size_train, shuffle=True)
    }

local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators)
print(f'Local runtime collaborators = {local_runtime._collaborators}')

model = None
best_model = None
optimizer = None
top_model_accuracy = 0
for i in range(2):
    print(f'Starting round {i}...')
    flflow = FederatedFlow(model,optimizer,checkpoint=True)
    flflow.runtime = local_runtime
    flflow.run()
    model = flflow.model
    optimizer = flflow.optimizer
    aggregated_model_accuracy = flflow.aggregated_model_accuracy
    if aggregated_model_accuracy > top_model_accuracy:
        print(f'Accuracy improved to {aggregated_model_accuracy} for round {i}')
        top_model_accuracy = aggregated_model_accuracy

Local runtime collaborators = {'Portland': <openfl.experimental.interface.participants.Collaborator object at 0x7f4f9c25dc40>, 'Seattle': <openfl.experimental.interface.participants.Collaborator object at 0x7f4f9c25db50>, 'Chandler': <openfl.experimental.interface.participants.Collaborator object at 0x7f4f9c25dc70>, 'Bangalore': <openfl.experimental.interface.participants.Collaborator object at 0x7f4f9c25dd00>}
Starting round 0...
Created flow FederatedFlow

Calling start
Performing initialization for model
Saving data artifacts for start
TaskDataStore init function invoked!
Saved data artifacts for start
Sending state from aggregator to collaborators
[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=1827)[0m Performing aggregated model validation for collaborator Portland




[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Test set: Avg. loss: 2.3264, Accuracy: 309/2500 (12%)
[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Portland value of 0.12359999865293503
[2m[36m(wrapper pid=1827)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1827)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=1816)[0m Performing aggregated model validation for collaborator Seattle




[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Test set: Avg. loss: 2.3319, Accuracy: 272/2500 (11%)
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Seattle value of 0.1088000014424324
[2m[36m(wrapper pid=1816)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1816)[0m TaskDataStore init function invoked!




[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=1850)[0m Performing aggregated model validation for collaborator Chandler
[2m[36m(wrapper pid=1827)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Calling train
[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Test set: Avg. loss: 2.3338, Accuracy: 284/2500 (11%)
[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Chandler value of 0.1136000007390976
[2m[36m(wrapper pid=1850)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1850)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1846)[0m 
[2m[36m(wrapper pid=1846)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=1846



[2m[36m(wrapper pid=1816)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Calling train
[2m[36m(wrapper pid=1846)[0m 
[2m[36m(wrapper pid=1846)[0m Test set: Avg. loss: 2.3345, Accuracy: 272/2500 (11%)
[2m[36m(wrapper pid=1846)[0m 
[2m[36m(wrapper pid=1846)[0m Bangalore value of 0.1088000014424324
[2m[36m(wrapper pid=1846)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1846)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1850)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Calling train




[2m[36m(wrapper pid=1846)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1846)[0m 
[2m[36m(wrapper pid=1846)[0m Calling train




[2m[36m(wrapper pid=1827)[0m Saving data artifacts for train
[2m[36m(wrapper pid=1827)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1827)[0m Saved data artifacts for train
[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Calling local_model_validation




[2m[36m(wrapper pid=1816)[0m Saving data artifacts for train
[2m[36m(wrapper pid=1816)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1816)[0m Saved data artifacts for train
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Calling local_model_validation
[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Test set: Avg. loss: 0.6976, Accuracy: 2014/2500 (81%)
[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Doing local model validation for collaborator Portland: 0.8055999875068665
[2m[36m(wrapper pid=1827)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=1827)[0m TaskDataStore init function invoked!




[2m[36m(wrapper pid=1827)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=1827)[0m Should transfer from local_model_validation to join
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Test set: Avg. loss: 0.6159, Accuracy: 2092/2500 (84%)
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Doing local model validation for collaborator Seattle: 0.8367999792098999
[2m[36m(wrapper pid=1816)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=1816)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1850)[0m Saving data artifacts for train
[2m[36m(wrapper pid=1850)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1816)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=1816)[0m Should transfer from local_model_validation to join




[2m[36m(wrapper pid=1850)[0m Saved data artifacts for train
[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Calling local_model_validation
[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Test set: Avg. loss: 0.7163, Accuracy: 2027/2500 (81%)
[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Doing local model validation for collaborator Chandler: 0.8108000159263611
[2m[36m(wrapper pid=1850)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=1850)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1846)[0m Saving data artifacts for train
[2m[36m(wrapper pid=1846)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1850)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=1850)[0m Should transfer from local_model_validation to join




[2m[36m(wrapper pid=1846)[0m Saved data artifacts for train
[2m[36m(wrapper pid=1846)[0m 
[2m[36m(wrapper pid=1846)[0m Calling local_model_validation
[2m[36m(wrapper pid=1846)[0m 
[2m[36m(wrapper pid=1846)[0m Test set: Avg. loss: 0.5898, Accuracy: 2106/2500 (84%)
[2m[36m(wrapper pid=1846)[0m 
[2m[36m(wrapper pid=1846)[0m Doing local model validation for collaborator Bangalore: 0.8424000144004822
[2m[36m(wrapper pid=1846)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=1846)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1846)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=1846)[0m Should transfer from local_model_validation to join
Next function = join

Calling join
Average aggregated model validation values = 0.11370000056922436
Average training loss = 0.8680389523506165
Average local model validation values = 0.8238999992609024
Saving data artifacts for join
TaskDataStore init function invok

  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


Created flow FederatedFlow

Calling start
Performing initialization for model
Saving data artifacts for start
TaskDataStore init function invoked!
Saved data artifacts for start
Sending state from aggregator to collaborators




[2m[36m(wrapper pid=1846)[0m 
[2m[36m(wrapper pid=1846)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=1846)[0m 
[2m[36m(wrapper pid=1846)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=1846)[0m Performing aggregated model validation for collaborator Portland
[2m[36m(wrapper pid=1846)[0m 
[2m[36m(wrapper pid=1846)[0m Test set: Avg. loss: 0.6733, Accuracy: 2117/2500 (85%)
[2m[36m(wrapper pid=1846)[0m 
[2m[36m(wrapper pid=1846)[0m Portland value of 0.8468000292778015
[2m[36m(wrapper pid=1846)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1846)[0m TaskDataStore init function invoked!




[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=1850)[0m Performing aggregated model validation for collaborator Seattle
[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Test set: Avg. loss: 0.6750, Accuracy: 2127/2500 (85%)
[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Seattle value of 0.8507999777793884
[2m[36m(wrapper pid=1850)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1850)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=1816)[0m Performing aggregated model validation for collaborator Ch



[2m[36m(wrapper pid=1846)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1846)[0m 
[2m[36m(wrapper pid=1846)[0m Calling train
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Test set: Avg. loss: 0.6856, Accuracy: 2101/2500 (84%)
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Chandler value of 0.840399980545044
[2m[36m(wrapper pid=1816)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1816)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1850)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1850)[0m 
[2m[36m(wrapper pid=1850)[0m Calling train
[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=1827)[0m Performing aggregated model validation for co



[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Test set: Avg. loss: 0.6671, Accuracy: 2121/2500 (85%)
[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Bangalore value of 0.8483999967575073
[2m[36m(wrapper pid=1827)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1827)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1816)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Calling train
[2m[36m(wrapper pid=1827)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=1827)[0m 
[2m[36m(wrapper pid=1827)[0m Calling train
[2m[36m(wrapper pid=1846)[0m Saving data artifacts for train
[2m[36m(wrapper pid=1846)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1850)[0m Saving data artifacts for train
[2m[36m(wrapper pid=1850)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1846)

[2m[36m(wrapper pid=1816)[0m Saving data artifacts for train
[2m[36m(wrapper pid=1816)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1850)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=1850)[0m Should transfer from local_model_validation to join
[2m[36m(wrapper pid=1816)[0m Saved data artifacts for train
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Calling local_model_validation
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Test set: Avg. loss: 0.3675, Accuracy: 2229/2500 (89%)
[2m[36m(wrapper pid=1816)[0m 
[2m[36m(wrapper pid=1816)[0m Doing local model validation for collaborator Chandler: 0.8916000127792358
[2m[36m(wrapper pid=1816)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=1816)[0m TaskDataStore init function invoked!
[2m[36m(wrapper pid=1816)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=1816)[0m Should transfer from l

In [4]:
run_id = flflow._run_id

In [5]:
import metaflow

In [6]:
from metaflow import Metaflow, Flow, Task, Step

In [7]:
m = Metaflow()
list(m)

[Flow('LinearFlow'),
 Flow('FederatedFlow'),
 Flow('VerticalFlow'),
 Flow('NewFlow'),
 Flow('BranchFlow'),
 Flow('TestFlow'),
 Flow('ForeachFlow')]

In [8]:
f = Flow('FederatedFlow').latest_run

In [9]:
f

Run('FederatedFlow/1664554700764836')

In [10]:
list(f)

[Step('FederatedFlow/1664554700764836/join'),
 Step('FederatedFlow/1664554700764836/local_model_validation'),
 Step('FederatedFlow/1664554700764836/train'),
 Step('FederatedFlow/1664554700764836/aggregated_model_validation'),
 Step('FederatedFlow/1664554700764836/start')]

In [11]:
s = Step(f'FederatedFlow/{run_id}/train')

In [12]:
s

Step('FederatedFlow/1664554700764836/train')

In [13]:
list(s)

[Task('FederatedFlow/1664554700764836/train/12'),
 Task('FederatedFlow/1664554700764836/train/10'),
 Task('FederatedFlow/1664554700764836/train/7'),
 Task('FederatedFlow/1664554700764836/train/6')]

In [14]:
t = Task(f'FederatedFlow/{run_id}/train/12')

In [15]:
t

Task('FederatedFlow/1664554700764836/train/12')

In [16]:
t.data

<MetaflowData: collaborators, model, agg_validation_score, loss, training_completed, train_loader, input, test_loader, optimizer>

In [17]:
t.data.input

TaskDataStore init function invoked!


'Bangalore'