In [None]:
!pip install git+https://github.com/securefederatedai/openfl.git
!pip install -r ../workflow_interface_requirements.txt
!pip install torch
!pip install torchvision

In [None]:
from copy import deepcopy
import numpy as np
import torch
import torchvision
from time import time
from torchvision import datasets, transforms
from torch import nn, optim

from openfl.experimental.interface import FLSpec, Aggregator, Collaborator
from openfl.experimental.runtime import LocalRuntime
from openfl.experimental.placement import aggregator, collaborator

# Data preprocessing: convert every image to a tensor, then normalize it with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Training split, downloaded into ./mnist on first use and served in batches
# of 2048. The loader is deliberately not shuffled: the flow in this notebook
# addresses batches by position (`batch_num`) on both the image side and the
# label side, so both must see the batches in the same order.
trainset = datasets.MNIST(
    'mnist',
    download=True,
    train=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=2048,
    shuffle=False,
)

# Held-out split, served in batches of 64.
testset = datasets.MNIST(
    'mnist',
    download=True,
    train=False,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=64,
    shuffle=False,
)

# Define our model segments.
# The seed is fixed before any layer is created so that the weight
# initialization below is reproducible. The two models are built in the same
# order every time (label model first, then data model).
torch.manual_seed(0)

input_size = 784            # length of one flattened input row
hidden_sizes = [128, 640]   # widths of the two hidden layers of the data model
output_size = 10            # number of outputs of the label model

# Label-side segment: maps the data model's final hidden representation to
# log-probabilities over the outputs.
label_model = nn.Sequential(
    nn.Linear(hidden_sizes[-1], output_size),
    nn.LogSoftmax(dim=1),
)
label_model_optimizer = optim.SGD(label_model.parameters(), lr=0.03)

# Data-side segment: two fully connected layers, each followed by a ReLU.
data_model = nn.Sequential(
    nn.Linear(input_size, hidden_sizes[0]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[0], hidden_sizes[-1]),
    nn.ReLU(),
)
data_model_optimizer = optim.SGD(data_model.parameters(), lr=0.03)

In [None]:
class VerticalTwoPartyFlow(FLSpec):
    """Vertical two-party training flow that processes one batch per run.

    The collaborator side runs ``data_model`` on the images of its
    ``trainloader``; the aggregator side runs ``label_model`` on that output
    together with the labels of its own ``trainloader``. Each ``run()``
    handles the batch at position ``batch_num``; ``join`` then advances
    ``batch_num`` by one. ``data_remaining`` ends up ``False`` when the
    aggregator's loader has no batch at that position.

    ``trainloader``, ``label_model``, ``label_model_optimizer``,
    ``data_model`` and ``data_model_optimizer`` are not created by this class;
    the steps read them from ``self`` and expect them to be supplied as
    participant private attributes.
    """

    def __init__(self, batch_num):
        """Create the flow.

        Args:
            batch_num: Zero-based position of the batch the next run should
                process.
        """
        super().__init__()
        self.batch_num = batch_num

    @aggregator
    def start(self):
        """Record the collaborators, clear label-model gradients, fan out."""
        self.collaborators = self.runtime.collaborators
        print(f'Batch_num = {self.batch_num}')
        # 1) Zero the gradients
        self.label_model_optimizer.zero_grad()
        self.next(self.data_model_forward_pass, foreach='collaborators')

    @collaborator
    def data_model_forward_pass(self):
        """Run the data model on the images of batch ``batch_num``.

        Sets ``data_model_output_local`` (graph-attached output, kept for the
        later local backward pass) and ``data_model_output`` (a detached copy
        that requires grad, read by the aggregator step). When the loader has
        no batch at position ``batch_num``, ``data_model_output_local`` stays
        ``None`` and ``data_model_output`` is not assigned by this call.
        """
        # None (rather than an empty string) marks "no batch processed".
        self.data_model_output_local = None
        for idx, (images, _) in enumerate(self.trainloader):
            # Skip forward to the batch this run is responsible for.
            if idx < self.batch_num:
                continue
            self.data_model_optimizer.zero_grad()
            # Flatten every image to a single row before the linear layers.
            images = images.view(images.shape[0], -1)
            model_output = self.data_model(images)
            self.data_model_output_local = model_output
            # Detached leaf: the label-side backward pass stores its gradient
            # on this tensor without reaching into the data model's graph.
            self.data_model_output = model_output.detach().requires_grad_()
            break
        self.next(self.label_model_forward_pass)

    @aggregator
    def label_model_forward_pass(self, inputs):
        """Run the label model on batch ``batch_num`` and update it.

        Args:
            inputs: Results of the collaborator step; only ``inputs[0]`` is
                read, so a single collaborator is assumed.

        Sets ``data_remaining`` (whether a batch existed at ``batch_num``) and
        ``grad_to_local`` (gradient of the loss with respect to the
        collaborator's ``data_model_output``, or ``None`` if no batch was
        processed).
        """
        criterion = nn.NLLLoss()
        # None (rather than an empty list) marks "no gradient computed".
        self.grad_to_local = None
        total_loss = 0
        self.data_remaining = False
        for idx, (_, labels) in enumerate(self.trainloader):
            # Skip forward to the batch this run is responsible for.
            if idx < self.batch_num:
                continue
            self.data_remaining = True
            data_model_output = inputs[0].data_model_output
            pred = self.label_model(data_model_output)
            loss = criterion(pred, labels)
            loss.backward()
            # Gradient handed back to the collaborator for its own backward.
            self.grad_to_local = data_model_output.grad.clone()
            self.label_model_optimizer.step()
            # Accumulate a plain Python float, not a graph-attached tensor.
            total_loss += loss.item()
            break
        print(f'Total loss = {total_loss}')
        self.next(self.data_model_backprop, foreach='collaborators')

    @collaborator
    def data_model_backprop(self):
        """Backpropagate ``grad_to_local`` through the data model and step it.

        Does nothing except advance the flow when ``data_remaining`` is False.
        """
        if self.data_remaining:
            # The optimizer is rebuilt from the model's current parameters.
            # NOTE(review): presumably needed because the optimizer carried in
            # the attributes may not reference these parameter objects at this
            # step - confirm. The learning rate duplicates the value used when
            # the optimizer is first created in the setup cell; keep in sync.
            self.data_model_optimizer = optim.SGD(self.data_model.parameters(), lr=0.03)
            self.data_model_optimizer.zero_grad()
            self.data_model_output_local.backward(self.grad_to_local)
            self.data_model_optimizer.step()
        self.next(self.join)

    @aggregator
    def join(self, inputs):
        """Advance ``batch_num`` by one; ``inputs`` is not read."""
        print(f'Join batch_num = {self.batch_num}')
        self.batch_num += 1
        self.next(self.end)

    @aggregator
    def end(self):
        """Final step of the flow; only prints a message."""
        print('This is the end of the flow')

In [None]:
# Setup participants
# No placeholder `Aggregator()` is created here: the aggregator is constructed
# once, fully configured, further down in this cell, so an empty instance
# would be discarded without ever being used.

def callable_to_initialize_aggregator_private_attributes(train_loader, label_model, label_model_optimizer):
    """Build the private-attribute mapping for the aggregator.

    Args:
        train_loader: Loader stored under the key ``"trainloader"``.
        label_model: Model stored under the key ``"label_model"``.
        label_model_optimizer: Optimizer stored under the key
            ``"label_model_optimizer"``.

    Returns:
        dict: The three arguments, unmodified, keyed by attribute name.
    """
    private_attributes = dict(
        trainloader=train_loader,
        label_model=label_model,
        label_model_optimizer=label_model_optimizer,
    )
    return private_attributes

# Build the aggregator participant. Its private attributes come from the
# callable defined above; the extra keyword arguments carry the same names as
# that callable's parameters.
# NOTE: this assignment rebinds the name `aggregator`, which the import cell
# also binds to the placement decorator. Once this cell has run, re-executing
# a cell that applies `@aggregator` picks up this object instead of the
# decorator, so restart the kernel before re-running the flow definition.
aggregator = Aggregator(
    name="agg",
    private_attributes_callable=callable_to_initialize_aggregator_private_attributes,
    train_loader=trainloader,
    label_model=label_model,
    label_model_optimizer=label_model_optimizer,
)

# Names of the collaborators to create; their private attributes are also
# set up through a callable.
collaborator_names = ['Portland']

def callable_to_initialize_collaborator_private_attributes(index, data_model, data_model_optimizer, train_loader):
    """Build the private-attribute mapping for one collaborator.

    Args:
        index: Position of the collaborator; accepted but not used here.
        data_model: Model stored under the key ``"data_model"``.
        data_model_optimizer: Optimizer stored under the key
            ``"data_model_optimizer"``.
        train_loader: Loader that is deep-copied and stored under the key
            ``"trainloader"``, so the collaborator gets its own copy.

    Returns:
        dict: The model and optimizer as passed in, plus the copied loader.
    """
    own_loader = deepcopy(train_loader)
    return dict(
        data_model=data_model,
        data_model_optimizer=data_model_optimizer,
        trainloader=own_loader,
    )

# One Collaborator per configured name. Every collaborator is handed the same
# data model, optimizer and training loader objects; the callable is what
# turns them into that collaborator's private attributes.
collaborators = []
for idx, collaborator_name in enumerate(collaborator_names):
    collaborators.append(
        Collaborator(
            name=collaborator_name,
            private_attributes_callable=callable_to_initialize_collaborator_private_attributes,
            index=idx,
            data_model=data_model,
            data_model_optimizer=data_model_optimizer,
            train_loader=trainloader,
        )
    )

# Local runtime that executes the flow for the participants above, using the
# 'ray' backend.
local_runtime = LocalRuntime(
    aggregator=aggregator,
    collaborators=collaborators,
    backend='ray',
)
print(f'Local runtime collaborators = {local_runtime.collaborators}')

# Training driver. Every round starts a fresh flow at batch 0 and keeps
# re-running it; each run handles a single batch and advances the flow's
# `batch_num`, until the flow reports that no batch was left to process.
epochs = 100
batch_num = 0
for i in range(epochs):
    print(f'Starting round {i}')
    vflow = VerticalTwoPartyFlow(batch_num=0)
    vflow.runtime = local_runtime
    data_remaining = True
    while data_remaining:
        vflow.run()
        # Read back what the run left on the flow object.
        batch_num, data_remaining = vflow.batch_num, vflow.data_remaining
        print(f'Continuing training loop: batch_num = {batch_num}')
    # Reset the position on the finished flow (the next round builds a new
    # flow object anyway).
    vflow.batch_num = 0

In [None]:
# Keep the run id of the last flow object left over from the training loop.
# NOTE(review): `_run_id` is an underscore-prefixed (private) attribute of the
# flow object; confirm there is no public accessor before relying on it.
run_id = vflow._run_id

In [None]:
from metaflow import Metaflow, Flow, Task, Step

In [None]:
# Create a Metaflow client object and display the items it yields when
# iterated (the bare last expression is rendered as the cell output).
# NOTE(review): presumably these are the flows recorded so far - confirm
# against the Metaflow client documentation.
m = Metaflow()
list(m)

In [None]:
# Look up the flow named 'VerticalTwoPartyFlow' and keep its `latest_run`.
f = Flow('VerticalTwoPartyFlow').latest_run