In [1]:
from copy import deepcopy
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import torchvision
import numpy as np
from openfl.experimental.interface import FLSpec, Aggregator, Collaborator
from openfl.experimental.runtime import LocalRuntime
from openfl.experimental.placement import aggregator, collaborator

n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

mnist_train = torchvision.datasets.MNIST('files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

mnist_test = torchvision.datasets.MNIST('files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

def FedAvg(models):
    new_model = models[0]
    state_dicts = [model.state_dict() for model in models]
    state_dict = new_model.state_dict()
    for key in models[1].state_dict():
        state_dict[key] = np.sum([state[key] for state in state_dicts],axis=0) / len(models)
    new_model.load_state_dict(state_dict)
    return new_model

def inference(network,test_loader):
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
      for data, target in test_loader:
        output = network(data)
        test_loss += F.nll_loss(output, target, size_average=False).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
      test_loss, correct, len(test_loader.dataset),
      100. * correct / len(test_loader.dataset)))
    accuracy = float(correct / len(test_loader.dataset))
    return accuracy

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class FederatedFlow(FLSpec):

    def __init__(self, model = None, optimizer = None, **kwargs):
        super().__init__(**kwargs)
        if model is not None:
            self.model = model
            self.optimizer = optimizer
        else:
            self.model = Net()
            self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,
                                   momentum=momentum)

    @aggregator
    def start(self):
        """
        Start step.
        """
        print(f'Performing initialization for model')
        self.collaborators = self.runtime.collaborators
        self.private = 10
        self.next(self.aggregated_model_validation,foreach='collaborators',exclude=['private'])

    @collaborator
    def aggregated_model_validation(self):
        """
        Perform Aggregated model validation.
        """
        print(f'Performing aggregated model validation for collaborator {self.input}')
        self.agg_validation_score = inference(self.model,self.test_loader)
        print(f'{self.input} value of {self.agg_validation_score}')
        self.next(self.train)

    @collaborator
    def train(self):
        """
        Local Model Training.
        """
        self.model.train()
        self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,
                                   momentum=momentum)
        train_losses = []
        for batch_idx, (data, target) in enumerate(self.train_loader):
          self.optimizer.zero_grad()
          output = self.model(data)
          loss = F.nll_loss(output, target)
          loss.backward()
          self.optimizer.step()
          if batch_idx % log_interval == 0:
            print('Train Epoch: 1 [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
               batch_idx * len(data), len(self.train_loader.dataset),
              100. * batch_idx / len(self.train_loader), loss.item()))
            self.loss = loss.item()
            torch.save(self.model.state_dict(), 'model.pth')
            torch.save(self.optimizer.state_dict(), 'optimizer.pth')
        self.training_completed = True
        self.next(self.local_model_validation)

    @collaborator
    def local_model_validation(self):
        """
        Local Model Validation.
        """
        self.local_validation_score = inference(self.model,self.test_loader)
        print(f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')
        self.next(self.join, exclude=['training_completed'])

    @aggregator
    def join(self,inputs):
        """
        Model Aggregation.
        """
        self.average_loss = sum(input.loss for input in inputs)/len(inputs)
        self.aggregated_model_accuracy = sum(input.agg_validation_score for input in inputs)/len(inputs)
        self.local_model_accuracy = sum(input.local_validation_score for input in inputs)/len(inputs)
        print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')
        print(f'Average training loss = {self.average_loss}')
        print(f'Average local model validation values = {self.local_model_accuracy}')
        self.model = FedAvg([input.model for input in inputs])
        self.optimizer = [input.optimizer for input in inputs][0]
        self.next(self.end)
        
    @aggregator
    def end(self):
        """
        End step.
        """
        print(f'This is the end of the flow')  

Aggregator step "start" registered
Collaborator step "aggregated_model_validation" registered
Collaborator step "train" registered
Collaborator step "local_model_validation" registered
Aggregator step "join" registered
Aggregator step "end" registered


In [4]:
# Setup participants
aggregator = Aggregator()
aggregator.private_attributes = {}

# Setup collaborators with private attributes
collaborator_names = ['Portland']#, 'Seattle', 'Chandler','Bangalore']
collaborators = [Collaborator(name=name) for name in collaborator_names]
for idx, collaborator in enumerate(collaborators):
    local_train = deepcopy(mnist_train)
    local_test = deepcopy(mnist_test)
    local_train.data = mnist_train.data[idx::len(collaborators)]
    local_train.targets = mnist_train.targets[idx::len(collaborators)]
    local_test.data = mnist_test.data[idx::len(collaborators)]
    local_test.targets = mnist_test.targets[idx::len(collaborators)]
    collaborator.private_attributes = {
            'train_loader': torch.utils.data.DataLoader(local_train,batch_size=batch_size_train, shuffle=True),
            'test_loader': torch.utils.data.DataLoader(local_test,batch_size=batch_size_train, shuffle=True)
    }

local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators,
                            backend='single_process' )
print(f'Local runtime collaborators = {local_runtime._collaborators}')

model = None
best_model = None
optimizer = None
top_model_accuracy = 0
for i in range(1):
    print(f'Starting round {i}...')
    flflow = FederatedFlow(model,optimizer,checkpoint=True)
    flflow.runtime = local_runtime
    flflow.run()
    model = flflow.model
    optimizer = flflow.optimizer
    aggregated_model_accuracy = flflow.aggregated_model_accuracy
    if aggregated_model_accuracy > top_model_accuracy:
        print(f'Accuracy improved to {aggregated_model_accuracy} for round {i}')
        top_model_accuracy = aggregated_model_accuracy

Local runtime collaborators = {'Portland': <openfl.experimental.interface.participants.Collaborator object at 0x7fdd61f38310>}
Starting round 0...
Created flow FederatedFlow

Calling start
Performing initialization for model
Saving data artifacts for start
Saved data artifacts for start
Sending state from aggregator to collaborators

Calling aggregated_model_validation
Performing aggregated model validation for collaborator Portland


  return F.log_softmax(x)



Test set: Avg. loss: 2.3316, Accuracy: 1137/10000 (11%)

Portland value of 0.1137000024318695
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation

Calling train
Saving data artifacts for train
Saved data artifacts for train

Calling local_model_validation

Test set: Avg. loss: 0.1995, Accuracy: 9384/10000 (94%)

Doing local model validation for collaborator Portland: 0.9383999705314636
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Next function = join

Calling join
Average aggregated model validation values = 0.1137000024318695
Average training loss = 0.3847687244415283
Average local model validation values = 0.9383999705314636


IndexError: list index out of range

In [5]:
from openfl.experimental.utilities.ui import InspectFlow

In [6]:
if flflow._checkpoint:
    InspectFlow(flflow, flflow._run_id, show_html=True)

Flowgraph generated at :/home/keerti/.metaflow/FederatedFlow/1669704190572687


<3>init: (31948) ERROR: UtilConnectToInteropServer:307: connect failed 2
<3>init: (31958) ERROR: UtilConnectToInteropServer:307: connect failed 2
<3>init: (31960) ERROR: UtilConnectToInteropServer:307: connect failed 2
<3>init: (31962) ERROR: UtilConnectToInteropServer:307: connect failed 2


In [None]:
run_id = flflow._run_id

In [None]:
import metaflow

In [None]:
from metaflow import Metaflow, Flow, Task, Step

In [None]:
m = Metaflow()
list(m)

In [None]:
f = Flow('FederatedFlow').latest_run

In [None]:
f

In [None]:
list(f)

In [None]:
s = Step(f'FederatedFlow/{run_id}/train')

In [None]:
s

In [None]:
list(s)

In [None]:
t = Task(f'FederatedFlow/{run_id}/train/7')

In [None]:
t

In [None]:
t.data

In [None]:
t.data.input