In [1]:
from copy import deepcopy
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import torchvision
import numpy as np
from openfl.experimental.interface import FLSpec, Aggregator, Collaborator
from openfl.experimental.runtime import LocalRuntime
from openfl.experimental.placement import aggregator, collaborator

n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

mnist_train = torchvision.datasets.MNIST('files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

mnist_test = torchvision.datasets.MNIST('files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

def FedAvg(models):
    new_model = models[0]
    state_dicts = [model.state_dict() for model in models]
    state_dict = new_model.state_dict()
    for key in models[1].state_dict():
        state_dict[key] = np.sum([state[key] for state in state_dicts],axis=0) / len(models)
    new_model.load_state_dict(state_dict)
    return new_model

def inference(network,test_loader):
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
      for data, target in test_loader:
        output = network(data)
        test_loss += F.nll_loss(output, target, size_average=False).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
      test_loss, correct, len(test_loader.dataset),
      100. * correct / len(test_loader.dataset)))
    accuracy = float(correct / len(test_loader.dataset))
    return accuracy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class FederatedFlow(FLSpec):

    def __init__(self, model = None, optimizer = None, **kwargs):
        super().__init__(**kwargs)
        if model is not None:
            self.model = model
            self.optimizer = optimizer
        else:
            self.model = Net()
            self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,
                                   momentum=momentum)

    @aggregator
    def start(self):
        """
        Start step.
        """
        print(f'Performing initialization for model')
        self.collaborators = self.runtime.collaborators
        self.private = 10
        self.next(self.aggregated_model_validation,foreach='collaborators',exclude=['private'])

    @collaborator
    def aggregated_model_validation(self):
        """
        Perform Aggregated model validation.
        """
        print(f'Performing aggregated model validation for collaborator {self.input}')
        self.agg_validation_score = inference(self.model,self.test_loader)
        print(f'{self.input} value of {self.agg_validation_score}')
        self.next(self.train)

    @collaborator
    def train(self):
        """
        Local Model Training.
        """
        self.model.train()
        self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,
                                   momentum=momentum)
        train_losses = []
        for batch_idx, (data, target) in enumerate(self.train_loader):
          self.optimizer.zero_grad()
          output = self.model(data)
          loss = F.nll_loss(output, target)
          loss.backward()
          self.optimizer.step()
          if batch_idx % log_interval == 0:
            print('Train Epoch: 1 [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
               batch_idx * len(data), len(self.train_loader.dataset),
              100. * batch_idx / len(self.train_loader), loss.item()))
            self.loss = loss.item()
            torch.save(self.model.state_dict(), 'model.pth')
            torch.save(self.optimizer.state_dict(), 'optimizer.pth')
        self.training_completed = True
        self.next(self.local_model_validation)

    @collaborator
    def local_model_validation(self):
        """
        Local Model Validation.
        """
        self.local_validation_score = inference(self.model,self.test_loader)
        print(f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')
        self.next(self.join, exclude=['training_completed'])

    @aggregator
    def join(self,inputs):
        """
        Model Aggregation.
        """
        self.average_loss = sum(input.loss for input in inputs)/len(inputs)
        self.aggregated_model_accuracy = sum(input.agg_validation_score for input in inputs)/len(inputs)
        self.local_model_accuracy = sum(input.local_validation_score for input in inputs)/len(inputs)
        print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')
        print(f'Average training loss = {self.average_loss}')
        print(f'Average local model validation values = {self.local_model_accuracy}')
        self.model = FedAvg([input.model for input in inputs])
        self.optimizer = [input.optimizer for input in inputs][0]
        self.next(self.end)
        
    @aggregator
    def end(self):
        """
        End step.
        """
        print(f'This is the end of the flow')  

Aggregator step "start" registered
Collaborator step "aggregated_model_validation" registered
Collaborator step "train" registered
Collaborator step "local_model_validation" registered
Aggregator step "join" registered
Aggregator step "end" registered


In [3]:
# Setup participants
aggregator = Aggregator()
aggregator.private_attributes = {}

# Setup collaborators with private attributes
collaborator_names = ['Portland', 'Seattle', 'Chandler','Bangalore']
collaborators = [Collaborator(name=name) for name in collaborator_names]
for idx, collaborator in enumerate(collaborators):
    local_train = deepcopy(mnist_train)
    local_test = deepcopy(mnist_test)
    local_train.data = mnist_train.data[idx::len(collaborators)]
    local_train.targets = mnist_train.targets[idx::len(collaborators)]
    local_test.data = mnist_test.data[idx::len(collaborators)]
    local_test.targets = mnist_test.targets[idx::len(collaborators)]
    collaborator.private_attributes = {
            'train_loader': torch.utils.data.DataLoader(local_train,batch_size=batch_size_train, shuffle=True),
            'test_loader': torch.utils.data.DataLoader(local_test,batch_size=batch_size_train, shuffle=True)
    }

local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators)
print(f'Local runtime collaborators = {local_runtime._collaborators}')

model = None
best_model = None
optimizer = None
top_model_accuracy = 0
for i in range(2):
    print(f'Starting round {i}...')
    flflow = FederatedFlow(model,optimizer,checkpoint=True)
    flflow.runtime = local_runtime
    flflow.run()
    model = flflow.model
    optimizer = flflow.optimizer
    aggregated_model_accuracy = flflow.aggregated_model_accuracy
    if aggregated_model_accuracy > top_model_accuracy:
        print(f'Accuracy improved to {aggregated_model_accuracy} for round {i}')
        top_model_accuracy = aggregated_model_accuracy

2022-12-05 21:17:35,105	INFO worker.py:1528 -- Started a local Ray instance.


Local runtime collaborators = {'Portland': <openfl.experimental.interface.participants.Collaborator object at 0x7f811c5ebdf0>, 'Seattle': <openfl.experimental.interface.participants.Collaborator object at 0x7f811c5ebd00>, 'Chandler': <openfl.experimental.interface.participants.Collaborator object at 0x7f811c5ebe50>, 'Bangalore': <openfl.experimental.interface.participants.Collaborator object at 0x7f811c5ebee0>}
Starting round 0...
Created flow FederatedFlow

Calling start
Performing initialization for model
Saving data artifacts for start
Saved data artifacts for start
Sending state from aggregator to collaborators
[2m[36m(wrapper pid=452531)[0m 
[2m[36m(wrapper pid=452531)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=452531)[0m 
[2m[36m(wrapper pid=452531)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=452531)[0m Performing aggregated model validation for collaborator Portland




[2m[36m(wrapper pid=452531)[0m 
[2m[36m(wrapper pid=452531)[0m Test set: Avg. loss: 2.3264, Accuracy: 309/2500 (12%)
[2m[36m(wrapper pid=452531)[0m 
[2m[36m(wrapper pid=452531)[0m Portland value of 0.12359999865293503
[2m[36m(wrapper pid=452531)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=452532)[0m 
[2m[36m(wrapper pid=452532)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=452532)[0m 
[2m[36m(wrapper pid=452532)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=452532)[0m Performing aggregated model validation for collaborator Seattle




[2m[36m(wrapper pid=452532)[0m 
[2m[36m(wrapper pid=452532)[0m Test set: Avg. loss: 2.3319, Accuracy: 272/2500 (11%)
[2m[36m(wrapper pid=452532)[0m 
[2m[36m(wrapper pid=452532)[0m Seattle value of 0.1088000014424324
[2m[36m(wrapper pid=452529)[0m 
[2m[36m(wrapper pid=452529)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=452529)[0m 
[2m[36m(wrapper pid=452529)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=452529)[0m Performing aggregated model validation for collaborator Chandler




[2m[36m(wrapper pid=452532)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=452531)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=452531)[0m 
[2m[36m(wrapper pid=452531)[0m Calling train




[2m[36m(wrapper pid=452529)[0m 
[2m[36m(wrapper pid=452529)[0m Test set: Avg. loss: 2.3338, Accuracy: 284/2500 (11%)
[2m[36m(wrapper pid=452529)[0m 
[2m[36m(wrapper pid=452529)[0m Chandler value of 0.1136000007390976
[2m[36m(wrapper pid=452529)[0m Saving data artifacts for aggregated_model_validation




[2m[36m(wrapper pid=452532)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=452532)[0m 
[2m[36m(wrapper pid=452532)[0m Calling train
[2m[36m(wrapper pid=452529)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=452529)[0m 
[2m[36m(wrapper pid=452529)[0m Calling train




[2m[36m(wrapper pid=452531)[0m Saving data artifacts for train
[2m[36m(wrapper pid=452531)[0m Saved data artifacts for train
[2m[36m(wrapper pid=452531)[0m 
[2m[36m(wrapper pid=452531)[0m Calling local_model_validation




[2m[36m(wrapper pid=452531)[0m 
[2m[36m(wrapper pid=452531)[0m Test set: Avg. loss: 0.6865, Accuracy: 2004/2500 (80%)
[2m[36m(wrapper pid=452531)[0m 
[2m[36m(wrapper pid=452531)[0m Doing local model validation for collaborator Portland: 0.8015999794006348
[2m[36m(wrapper pid=452531)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=452532)[0m Saving data artifacts for train
[2m[36m(wrapper pid=452531)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=452531)[0m Should transfer from local_model_validation to join
[2m[36m(wrapper pid=452532)[0m Saved data artifacts for train
[2m[36m(wrapper pid=452532)[0m 
[2m[36m(wrapper pid=452532)[0m Calling local_model_validation




[2m[36m(wrapper pid=452529)[0m Saving data artifacts for train
[2m[36m(wrapper pid=452532)[0m 
[2m[36m(wrapper pid=452532)[0m Test set: Avg. loss: 0.6087, Accuracy: 2077/2500 (83%)
[2m[36m(wrapper pid=452532)[0m 
[2m[36m(wrapper pid=452532)[0m Doing local model validation for collaborator Seattle: 0.8307999968528748
[2m[36m(wrapper pid=452532)[0m Saving data artifacts for local_model_validation




[2m[36m(wrapper pid=452529)[0m Saved data artifacts for train
[2m[36m(wrapper pid=452529)[0m 
[2m[36m(wrapper pid=452529)[0m Calling local_model_validation
[2m[36m(wrapper pid=452532)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=452532)[0m Should transfer from local_model_validation to join
[2m[36m(wrapper pid=452529)[0m 
[2m[36m(wrapper pid=452529)[0m Test set: Avg. loss: 0.7447, Accuracy: 1975/2500 (79%)
[2m[36m(wrapper pid=452529)[0m 
[2m[36m(wrapper pid=452529)[0m Doing local model validation for collaborator Chandler: 0.7900000214576721
[2m[36m(wrapper pid=452529)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=452529)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=452529)[0m Should transfer from local_model_validation to join


[2m[36m(raylet)[0m Spilled 2116 MiB, 8 objects, write throughput 213 MiB/s. Set RAY_verbose_spill_logs=0 to disable this message.


[2m[36m(wrapper pid=452873)[0m 
[2m[36m(wrapper pid=452873)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=452873)[0m 
[2m[36m(wrapper pid=452873)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=452873)[0m Performing aggregated model validation for collaborator Bangalore




[2m[36m(wrapper pid=452873)[0m 
[2m[36m(wrapper pid=452873)[0m Test set: Avg. loss: 2.3345, Accuracy: 272/2500 (11%)
[2m[36m(wrapper pid=452873)[0m 
[2m[36m(wrapper pid=452873)[0m Bangalore value of 0.1088000014424324
[2m[36m(wrapper pid=452873)[0m Saving data artifacts for aggregated_model_validation




[2m[36m(wrapper pid=452873)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=452873)[0m 
[2m[36m(wrapper pid=452873)[0m Calling train
[2m[36m(wrapper pid=452873)[0m Saving data artifacts for train




[2m[36m(wrapper pid=452873)[0m Saved data artifacts for train
[2m[36m(wrapper pid=452873)[0m 
[2m[36m(wrapper pid=452873)[0m Calling local_model_validation
[2m[36m(wrapper pid=452873)[0m 
[2m[36m(wrapper pid=452873)[0m Test set: Avg. loss: 0.6157, Accuracy: 2102/2500 (84%)
[2m[36m(wrapper pid=452873)[0m 
[2m[36m(wrapper pid=452873)[0m Doing local model validation for collaborator Bangalore: 0.8407999873161316
[2m[36m(wrapper pid=452873)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=452873)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=452873)[0m Should transfer from local_model_validation to join

Calling join
Average aggregated model validation values = 0.11370000056922436
Average training loss = 0.987363874912262
Average local model validation values = 0.8157999962568283
Saving data artifacts for join
Saved data artifacts for join

Calling end
This is the end of the flow
Saving data artifacts for end
Sav

  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


Created flow FederatedFlow

Calling start
Performing initialization for model
Saving data artifacts for start
Saved data artifacts for start
Sending state from aggregator to collaborators
[2m[36m(wrapper pid=453007)[0m 
[2m[36m(wrapper pid=453007)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=453007)[0m 
[2m[36m(wrapper pid=453007)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=453007)[0m Performing aggregated model validation for collaborator Portland
[2m[36m(wrapper pid=453006)[0m 
[2m[36m(wrapper pid=453006)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=453006)[0m 
[2m[36m(wrapper pid=453006)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=453006)[0m Performing aggregated model validation for collaborator Seattle




[2m[36m(wrapper pid=453007)[0m 
[2m[36m(wrapper pid=453007)[0m Test set: Avg. loss: 0.6670, Accuracy: 2098/2500 (84%)
[2m[36m(wrapper pid=453007)[0m 
[2m[36m(wrapper pid=453007)[0m Portland value of 0.8392000198364258
[2m[36m(wrapper pid=453006)[0m 
[2m[36m(wrapper pid=453006)[0m Test set: Avg. loss: 0.6717, Accuracy: 2114/2500 (85%)
[2m[36m(wrapper pid=453006)[0m 
[2m[36m(wrapper pid=453006)[0m Seattle value of 0.8456000089645386
[2m[36m(wrapper pid=453007)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=453006)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=453058)[0m 
[2m[36m(wrapper pid=453058)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=453058)[0m 
[2m[36m(wrapper pid=453058)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=453058)[0m Performing aggregated model validation for collaborator Chandler




[2m[36m(wrapper pid=453058)[0m 
[2m[36m(wrapper pid=453058)[0m Test set: Avg. loss: 0.6791, Accuracy: 2093/2500 (84%)
[2m[36m(wrapper pid=453058)[0m 
[2m[36m(wrapper pid=453058)[0m Chandler value of 0.8371999859809875
[2m[36m(wrapper pid=453007)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=453007)[0m 
[2m[36m(wrapper pid=453007)[0m Calling train




[2m[36m(wrapper pid=453006)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=453006)[0m 
[2m[36m(wrapper pid=453006)[0m Calling train
[2m[36m(wrapper pid=453058)[0m Saving data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=453058)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=453058)[0m 
[2m[36m(wrapper pid=453058)[0m Calling train




[2m[36m(wrapper pid=453007)[0m Saving data artifacts for train
[2m[36m(wrapper pid=453006)[0m Saving data artifacts for train
[2m[36m(wrapper pid=453007)[0m Saved data artifacts for train
[2m[36m(wrapper pid=453007)[0m 
[2m[36m(wrapper pid=453007)[0m Calling local_model_validation
[2m[36m(wrapper pid=453006)[0m Saved data artifacts for train
[2m[36m(wrapper pid=453006)[0m 
[2m[36m(wrapper pid=453006)[0m Calling local_model_validation




[2m[36m(wrapper pid=453007)[0m 
[2m[36m(wrapper pid=453007)[0m Test set: Avg. loss: 0.3601, Accuracy: 2245/2500 (90%)
[2m[36m(wrapper pid=453007)[0m 
[2m[36m(wrapper pid=453007)[0m Doing local model validation for collaborator Portland: 0.8980000019073486
[2m[36m(wrapper pid=453007)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=453006)[0m 
[2m[36m(wrapper pid=453006)[0m Test set: Avg. loss: 0.3554, Accuracy: 2244/2500 (90%)
[2m[36m(wrapper pid=453006)[0m 
[2m[36m(wrapper pid=453006)[0m Doing local model validation for collaborator Seattle: 0.897599995136261
[2m[36m(wrapper pid=453006)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=453058)[0m Saving data artifacts for train
[2m[36m(wrapper pid=453007)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=453007)[0m Should transfer from local_model_validation to join
[2m[36m(wrapper pid=453006)[0m Saved data artifacts for local_m

[2m[36m(raylet)[0m Spilled 4233 MiB, 16 objects, write throughput 242 MiB/s.


[2m[36m(wrapper pid=453058)[0m 
[2m[36m(wrapper pid=453058)[0m Test set: Avg. loss: 0.3535, Accuracy: 2253/2500 (90%)
[2m[36m(wrapper pid=453058)[0m 
[2m[36m(wrapper pid=453058)[0m Doing local model validation for collaborator Chandler: 0.901199996471405
[2m[36m(wrapper pid=453058)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=453058)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=453058)[0m Should transfer from local_model_validation to join




[2m[36m(wrapper pid=453277)[0m 
[2m[36m(wrapper pid=453277)[0m Running aggregated_model_validation in a new process
[2m[36m(wrapper pid=453277)[0m 
[2m[36m(wrapper pid=453277)[0m Calling aggregated_model_validation
[2m[36m(wrapper pid=453277)[0m Performing aggregated model validation for collaborator Bangalore
[2m[36m(wrapper pid=453277)[0m 
[2m[36m(wrapper pid=453277)[0m Test set: Avg. loss: 0.6620, Accuracy: 2119/2500 (85%)
[2m[36m(wrapper pid=453277)[0m 
[2m[36m(wrapper pid=453277)[0m Bangalore value of 0.847599983215332
[2m[36m(wrapper pid=453277)[0m Saving data artifacts for aggregated_model_validation




[2m[36m(wrapper pid=453277)[0m Saved data artifacts for aggregated_model_validation
[2m[36m(wrapper pid=453277)[0m 
[2m[36m(wrapper pid=453277)[0m Calling train
[2m[36m(wrapper pid=453277)[0m Saving data artifacts for train
[2m[36m(wrapper pid=453277)[0m Saved data artifacts for train
[2m[36m(wrapper pid=453277)[0m 
[2m[36m(wrapper pid=453277)[0m Calling local_model_validation




[2m[36m(wrapper pid=453277)[0m 
[2m[36m(wrapper pid=453277)[0m Test set: Avg. loss: 0.3556, Accuracy: 2240/2500 (90%)
[2m[36m(wrapper pid=453277)[0m 
[2m[36m(wrapper pid=453277)[0m Doing local model validation for collaborator Bangalore: 0.8960000276565552
[2m[36m(wrapper pid=453277)[0m Saving data artifacts for local_model_validation
[2m[36m(wrapper pid=453277)[0m Saved data artifacts for local_model_validation
[2m[36m(wrapper pid=453277)[0m Should transfer from local_model_validation to join

Calling join
Average aggregated model validation values = 0.842399999499321
Average training loss = 0.6987558454275131
Average local model validation values = 0.8982000052928925
Saving data artifacts for join
Saved data artifacts for join

Calling end
This is the end of the flow
Saving data artifacts for end
Saved data artifacts for end
Accuracy improved to 0.842399999499321 for round 1


In [4]:
from openfl.experimental.utilities.ui import InspectFlow

In [5]:
if flflow._checkpoint:
    InspectFlow(flflow, flflow._run_id, show_html=True)

Flowgraph generated at :/home/scngupta/.metaflow/FederatedFlow/1670255320094610


In [6]:
run_id = flflow._run_id

In [7]:
import metaflow

In [8]:
from metaflow import Metaflow, Flow, Task, Step

In [9]:
m = Metaflow()
list(m)

[Flow('FederatedMnistFlowWithWatermarking'),
 Flow('FederatedFlow_MNIST_Watermarking'),
 Flow('FederatedFlow')]

In [10]:
f = Flow('FederatedFlow').latest_run

In [11]:
f

Run('FederatedFlow/1670255320094610')

In [12]:
list(f)

[Step('FederatedFlow/1670255320094610/end'),
 Step('FederatedFlow/1670255320094610/join'),
 Step('FederatedFlow/1670255320094610/local_model_validation'),
 Step('FederatedFlow/1670255320094610/train'),
 Step('FederatedFlow/1670255320094610/aggregated_model_validation'),
 Step('FederatedFlow/1670255320094610/start')]

In [13]:
s = Step(f'FederatedFlow/{run_id}/train')

In [14]:
s

Step('FederatedFlow/1670255320094610/train')

In [15]:
list(s)

[Task('FederatedFlow/1670255320094610/train/1')]

In [16]:
t = Task(f'FederatedFlow/{run_id}/train/7')

MetaflowNotFound: Task('FederatedFlow/1670255320094610/train/7') does not exist

In [None]:
t

In [None]:
t.data

In [None]:
t.data.input