In [1]:
from openfl.interface.interactive_api.federation import Federation

In [2]:
# Connect to a director running on this machine.
federation = Federation(
    client_id='client_id',
    director_node_fqdn='localhost',
    director_port='50051',
    # NOTE(review): TLS is disabled and no certificates are supplied; only
    # appropriate for a local director on localhost.
    tls=False,
    cert_chain=None,
    api_cert=None,
    api_private_key=None,
)

In [3]:
# Inspect the envoys (and their shards) known to the director.
shard_registry = federation.get_shard_registry()
shard_registry

{'envoy_one': {'shard_info': node_info {
    name: "envoy_one"
  }
  sample_shape: "2"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-11-28 15:18:22',
  'current_time': '2022-11-28 15:19:02',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'},
 'envoy_two': {'shard_info': node_info {
    name: "envoy_two"
  }
  sample_shape: "2"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-11-28 15:18:25',
  'current_time': '2022-11-28 15:19:02',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'},
 'envoy_three': {'shard_info': node_info {
    name: "envoy_three"
  }
  sample_shape: "2"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-11-28 15:18:59',
  'current_time': '2022-11-28 15:19:02',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'}}

In [4]:
from openfl.interface.interactive_api.experiment import FLExperiment
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Experiment handle bound to the federation created above.
# NOTE(review): the name contains a space and an apostrophe; confirm it is
# acceptable wherever OpenFL uses it (e.g. as an identifier or path).
fl_experiment = FLExperiment(
    federation=federation,
    experiment_name="pranav's test",
)

In [6]:
from torch import nn
import torch as t
import torch.nn.functional as F
import torch.optim as optim

class XOR(nn.Module):
    """Two-layer MLP for the XOR problem: Linear -> sigmoid -> Linear.

    Args:
        input_dim: number of input features (2 for XOR).
        output_dim: number of outputs (1: a single regression score).
        hidden_dim: width of the hidden layer (default 4, the value that was
            previously hard-coded).
    """

    def __init__(self, input_dim=2, output_dim=1, hidden_dim=4):
        super().__init__()
        self.lin1 = nn.Linear(input_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map a batch of shape (..., input_dim) to (..., output_dim)."""
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        return self.lin2(t.sigmoid(self.lin1(x)))


model = XOR()
# NOTE: the train task in this notebook computes its loss with F.mse_loss
# directly rather than through loss_func.
loss_func = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.2, momentum=0.9)

In [7]:
from openfl.interface.interactive_api.experiment import ModelInterface
# Wrap the model and optimizer for OpenFL; the PyTorch framework adapter
# presumably handles getting/setting the torch weights (confirm in OpenFL docs).
MI = ModelInterface(
    model,
    optimizer,
    framework_plugin='openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin',
)

In [8]:
from torch.utils.data import DataLoader


class CustomDataInterface(DataInterface):
    """Builds train/validation DataLoaders from the shard an Envoy assigns.

    Args:
        train_bs: batch size of the training loader (default 4, the value
            that was previously hard-coded).
        valid_bs: batch size of the validation loader (default 4).
        **kwargs: forwarded unchanged to ``DataInterface.__init__``.
    """

    def __init__(self, train_bs=4, valid_bs=4, **kwargs):
        # Initialize superclass with kwargs: this array will be passed
        # to get_data_loader methods
        super().__init__(**kwargs)
        self.train_bs = train_bs
        self.valid_bs = valid_bs

    @property
    def shard_descriptor(self):
        """The shard descriptor assigned by the setter below."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy. The 'train' and 'val'
        datasets are fetched from it once and cached for the loader methods.
        """
        self._shard_descriptor = shard_descriptor

        self.train_set = self._shard_descriptor.get_dataset('train')
        self.valid_set = self._shard_descriptor.get_dataset('val')

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract.

        Returns a shuffled DataLoader over the training split.
        """
        return DataLoader(self.train_set, batch_size=self.train_bs, shuffle=True)

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract.

        Returns a DataLoader over the validation split. It is not shuffled:
        the validation metric is order-independent, and a fixed order keeps
        validation passes reproducible.
        """
        return DataLoader(self.valid_set, batch_size=self.valid_bs, shuffle=False)

    def get_train_data_size(self):
        """
        Information for aggregation: number of training samples in the shard.
        """
        return len(self.train_set)

    def get_valid_data_size(self):
        """
        Information for aggregation: number of validation samples in the shard.
        """
        return len(self.valid_set)

DI = CustomDataInterface()

In [9]:
import numpy as np
import tqdm

TI = TaskInterface()

# Extra keyword arguments injected into the train task via TI.add_kwargs.
# NOTE(review): train() accepts batch_size but does not use it; the loader
# batch size is set in CustomDataInterface.
task_settings = dict(batch_size=32)
@TI.add_kwargs(**task_settings)
@TI.register_fl_task(model='model_', data_loader='train_loader',
                     device='device', optimizer='optim')
def train(model_, train_loader, optim, device, batch_size):
    """Run one local training pass of the XOR regressor.

    Args:
        model_: model supplied by the framework for this task; trained in place.
        train_loader: iterable of ``(data, target)`` batches.
        optim: optimizer supplied by the framework for ``model_``.
        device: device that the model and every batch are moved to.
        batch_size: injected from ``task_settings`` via ``TI.add_kwargs``;
            unused here because the loader already fixes its batch size.

    Returns:
        dict with ``'train_loss'``: the mean per-batch MSE, or NaN if the
        loader yielded no batches.
    """
    progress = tqdm.tqdm(train_loader, desc="train")
    model_.train()
    model_.to(device)

    losses = []
    for data, target in progress:
        data = data.to(device, dtype=t.float)
        # Flatten so (B,) and (B, 1) targets both line up with y_hat[:, 0];
        # a (B, 1) target would otherwise broadcast into a (B, B) loss.
        target = target.to(device, dtype=t.float).reshape(-1)

        optim.zero_grad()
        # Use the framework-provided model_, not the notebook-global `model`,
        # which may not hold this round's weights.
        y_hat = model_(data)
        loss = F.mse_loss(y_hat[:, 0], target)
        loss.backward()
        optim.step()

        batch_loss = loss.item()
        losses.append(batch_loss)
        progress.set_postfix(loss=batch_loss)

    # np.mean([]) warns and returns NaN; return NaN explicitly instead.
    return {'train_loss': float(np.mean(losses)) if losses else float('nan')}

@TI.register_fl_task(model='model_', data_loader='val_loader', device='device')
def validate(model_, val_loader, device):
    """Compute the classification accuracy of the XOR regressor.

    The model has a single output unit trained with MSE, so the prediction
    is that output thresholded at 0.5. (``argmax`` over a size-1 dimension
    always returns 0, which made the previous accuracy meaningless.)

    Args:
        model_: model supplied by the framework for this task.
        val_loader: iterable of ``(data, target)`` batches.
        device: device that the model and every batch are moved to. It is
            honoured now rather than overridden to CPU.

    Returns:
        dict with ``'acc'``: the fraction of samples whose thresholded
        prediction equals the target, or NaN if the loader was empty.
    """
    # Decision boundary for a 0/1 target regressed with MSE.
    threshold = 0.5

    model_.eval()
    model_.to(device)

    progress = tqdm.tqdm(val_loader, desc="validate")
    correct = 0
    total_samples = 0

    with t.no_grad():
        for data, target in progress:
            # NOTE(review): assumes binary 0/1 targets of shape (B,) or (B, 1).
            target = target.to(device).reshape(-1)
            output = model_(data.to(device, dtype=t.float))
            pred = (output[:, 0] > threshold).to(target.dtype)
            correct += pred.eq(target).sum().item()
            total_samples += target.shape[0]

    return {'acc': correct / total_samples if total_samples else float('nan')}

In [13]:
# Launch the federated experiment with the model, data, and task interfaces.
# opt_treatment='CONTINUE_GLOBAL' selects how optimizer state is carried
# between rounds (presumably the aggregated global state; confirm in OpenFL docs).
fl_experiment.start(
    model_provider=MI,
    task_keeper=TI,
    data_loader=DI,
    opt_treatment='CONTINUE_GLOBAL',
    rounds_to_train=3,
)

In [11]:
# Stream the experiment's reported metrics into this cell's output.
# NOTE(review): presumably blocks until the experiment finishes; confirm.
fl_experiment.stream_metrics()

In [12]:
# Clean up the data stored for this experiment.
# NOTE(review): confirm whether this removes local artifacts only or also
# director-side data before relying on it.
fl_experiment.remove_experiment_data()