# Torch Regression Example - Interactive API

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
# Number of input columns in each sample; the remaining columns are treated as targets.
NUM_FEATURES: int = 1
# Step size handed to the SGD optimizer.
LEARNING_RATE: float = 0.5

## Torch Definitions

### Model

In [None]:
class LRModel(nn.Module):
    """Linear regression model: a single fully connected (affine) layer."""

    def __init__(self, in_features: int, out_features: int) -> None:
        """Build the model.

        Args:
            in_features: number of input features per sample.
            out_features: number of predicted values per sample.
        """
        super().__init__()
        # The attribute name ``fc`` fixes the state_dict keys ("fc.weight", "fc.bias").
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine map ``x @ W.T + b`` to the input batch."""
        layer = self.fc
        return layer(x)

In [None]:
# One scalar prediction per sample from NUM_FEATURES inputs.
model = LRModel(in_features=NUM_FEATURES, out_features=1)

### Optimizer

In [None]:
# Plain stochastic gradient descent over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

### Loss function

In [None]:
# Mean squared error averaged over all elements of the batch.
loss_fn = nn.MSELoss(reduction='mean')

## Federation

In [None]:
import copy

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

### Dataset

In [None]:
class LRDataset(DataInterface):
    """Data interface exposing a shard's 'train' and 'val' datasets as DataLoaders.

    The datasets become available only after ``shard_descriptor`` has been
    assigned; until then every loader/size accessor raises ``ValueError``.
    """

    def __init__(self, train_bs: int = 1024, val_bs: int = 1024, **kwargs):
        """
        Args:
            train_bs: batch size of the training DataLoader.
            val_bs: batch size of the validation DataLoader.
            **kwargs: forwarded unchanged to ``DataInterface.__init__``.
        """
        super().__init__(**kwargs)
        self._train_bs = train_bs
        self._val_bs = val_bs
        # Initialised so that reading ``shard_descriptor`` before it has been set
        # returns None instead of raising AttributeError.
        self._shard_descriptor = None
        self._train_data = None
        self._val_data = None

    @staticmethod
    def _require(data, message):
        """Return ``data`` unchanged, or raise ``ValueError(message)`` if it is None."""
        if data is None:
            raise ValueError(message)
        return data

    @property
    def shard_descriptor(self):
        """The shard descriptor currently attached (None until assigned)."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        Fetches the 'train' and 'val' datasets from the descriptor.
        """
        self._shard_descriptor = shard_descriptor
        self._train_data = self._shard_descriptor.get_dataset('train')
        self._val_data = self._shard_descriptor.get_dataset('val')

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract.

        Returns a shuffled DataLoader over the training data; ``kwargs`` are
        accepted for interface compatibility but ignored.

        Raises:
            ValueError: if the training data has not been set.
        """
        train_data = self._require(self._train_data, "train data is not set")
        return torch.utils.data.DataLoader(train_data, batch_size=self._train_bs, shuffle=True)

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract.

        Returns an unshuffled DataLoader over the validation data; ``kwargs`` are
        accepted for interface compatibility but ignored.

        Raises:
            ValueError: if the validation data has not been set.
        """
        val_data = self._require(self._val_data, "validation data is not set")
        return torch.utils.data.DataLoader(val_data, batch_size=self._val_bs)

    def get_train_data_size(self):
        """
        Information for aggregation: number of training samples.

        Raises:
            ValueError: if the training data has not been set.
        """
        return len(self._require(self._train_data, "train data is not set"))

    def get_valid_data_size(self):
        """
        Information for aggregation: number of validation samples.

        Raises:
            ValueError: if the validation data has not been set.
        """
        return len(self._require(self._val_data, "validation data is not set"))

In [None]:
# Created with the default batch sizes; the datasets are attached later via shard_descriptor.
fl_dataset = LRDataset()

### Register model

In [None]:
# Import path of the PyTorch framework adapter plugin handed to ModelInterface.
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
model_interface = ModelInterface(
    model=model,
    optimizer=optimizer,
    framework_plugin=framework_adapter,
)

# Independent deep copy of the model as it is right now (before any training),
# presumably kept for later comparison with the federated result.
initial_model = copy.deepcopy(model)

### Register tasks

In [None]:
# Registry that the federated tasks (train / validate) are attached to via decorators.
task_interface = TaskInterface()

# Task interface currently supports only standalone functions.
@task_interface.add_kwargs(**{'loss_fn': loss_fn})
@task_interface.register_fl_task(model='model', data_loader='train_loader', device='device', optimizer='optimizer')
def train(model, train_loader, optimizer, device, loss_fn):
    """Run one training pass over ``train_loader`` and report the average loss.

    Each batch is indexed as a 2-D tensor: the first ``NUM_FEATURES`` columns are
    model inputs and the remaining columns are regression targets.

    Returns:
        ``{'train_mse': float}`` -- per-batch losses averaged with each batch
        weighted by its number of samples, so a smaller final batch is not
        over-weighted. This equals the per-sample mean when ``loss_fn`` averages
        over the batch (as the default ``nn.MSELoss()`` does). NaN if the loader
        yields no batches.
    """
    model.to(device)
    model.train()

    total_loss = 0.0
    num_samples = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(data[:, :NUM_FEATURES]), data[:, NUM_FEATURES:])
        loss.backward()
        optimizer.step()
        batch_size = data.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    # NaN matches what np.mean([]) produced for an empty loader, minus the RuntimeWarning.
    return {'train_mse': total_loss / num_samples if num_samples else float('nan')}


@task_interface.add_kwargs(**{'loss_fn': loss_fn})
@task_interface.register_fl_task(model='model', data_loader='val_loader', device='device')
def validate(model, val_loader, device, loss_fn):
    """Evaluate the model on ``val_loader`` without tracking gradients.

    Each batch is indexed as a 2-D tensor: the first ``NUM_FEATURES`` columns are
    model inputs and the remaining columns are regression targets.

    Returns:
        ``{'val_mse': float}`` -- per-batch losses averaged with each batch
        weighted by its number of samples, so the result is the dataset-level
        mean when ``loss_fn`` averages over the batch (as the default
        ``nn.MSELoss()`` does). NaN if the loader yields no batches.
    """
    model.to(device)
    model.eval()

    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            loss = loss_fn(model(data[:, :NUM_FEATURES]), data[:, NUM_FEATURES:])
            batch_size = data.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

    # NaN matches what np.mean([]) produced for an empty loader, minus the RuntimeWarning.
    return {'val_mse': total_loss / num_samples if num_samples else float('nan')}

### Create Federation

In [None]:
from openfl.interface.interactive_api.federation import Federation

In [None]:
# Use the same client identifier that appears in the signed certificate.
client_id = 'frontend'
director_node_fqdn = 'localhost'
director_port = 50050

# NOTE(review): tls=False presumably turns off TLS for the director connection;
# acceptable for a local demo, confirm before pointing at a remote director.
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False,
)

### Run Federation

In [None]:
# Create an experiment within the federation.
experiment_name = 'torch_linear_regression_experiment'
fl_experiment = FLExperiment(
    federation=federation,
    experiment_name=experiment_name,
)

In [None]:
# Zip the workspace and Python requirements so they can be transferred to the
# collaborator nodes, then start the experiment for 10 training rounds.
fl_experiment.start(
    model_provider=model_interface,
    task_keeper=task_interface,
    data_loader=fl_dataset,
    rounds_to_train=10,
)

In [None]:
# Stream the experiment's metrics.
fl_experiment.stream_metrics()