# Convert Your ML Training Script to Federated Learning
This tutorial guides you through taking your local machine learning training script and training it in a federated setting using FEDn and Scaleout Studio. Start by installing the following packages into your environment:
- `pip install fedn`
- `pip install numpy`
- `pip install torch`
- `pip install torchvision`
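
To confirm the installation, a quick check like the one below can help. It is a minimal sketch using only the standard library: `importlib.util.find_spec` reports whether a top-level module can be found without importing it. The helper name `missing_packages` is hypothetical.

```python
import importlib.util


def missing_packages(modules):
    """Return the module names from `modules` that cannot be found."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Module names of the packages installed above
    required = ["fedn", "numpy", "torch", "torchvision"]
    missing = missing_packages(required)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```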

## ML training example script
We use the MNIST example to show which parts of a training script need to be modified for federated learning (FL).

The cells below show a typical ML training flow implemented in PyTorch. Further down, we show how to transform it into a federated training flow with `FednClient`.

```python
import torch


# Define the neural network model: a small fully connected classifier for
# 28x28 grayscale MNIST images (flattened to 784 inputs) and 10 classes
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 10)

    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x.reshape(x.size(0), 784)))
        x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
        x = torch.nn.functional.relu(self.fc2(x))
        # Log-probabilities, to be paired with NLLLoss
        x = torch.nn.functional.log_softmax(self.fc3(x), dim=1)
        return x
```
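
Because the FL code later converts `model.state_dict()` to a list of NumPy arrays and back by position, it helps to know the order and shapes of `Net`'s parameters. The sketch below derives them from the layer sizes alone, assuming PyTorch's conventions that `Linear` stores its weight as `(out_features, in_features)` and that `state_dict` lists each layer's weight before its bias, in the order the layers are assigned in `__init__`. It runs without PyTorch; `net_parameter_shapes` is a hypothetical helper.

```python
from math import prod

# (name, in_features, out_features) for each Linear layer of Net, in definition order
LAYERS = [("fc1", 784, 64), ("fc2", 64, 32), ("fc3", 32, 10)]


def net_parameter_shapes(layers):
    """Return (state_dict key, shape) pairs in the order PyTorch lists them."""
    shapes = []
    for name, n_in, n_out in layers:
        shapes.append((f"{name}.weight", (n_out, n_in)))
        shapes.append((f"{name}.bias", (n_out,)))
    return shapes


shapes = net_parameter_shapes(LAYERS)
for key, shape in shapes:
    print(key, shape)

total = sum(prod(shape) for _, shape in shapes)
print("total parameters:", total)  # 52650
```

With PyTorch available, `sum(p.numel() for p in Net().parameters())` should give the same total.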


## Device Configuration

This cell checks which hardware accelerator is available (a CUDA GPU or Apple's MPS backend) and falls back to the CPU otherwise, so that training runs on the fastest available device.

```python
# Prefer a CUDA GPU, then Apple Silicon (MPS), then fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
```

## Loading and Preparing the MNIST Dataset
This cell downloads the MNIST dataset, wraps the training and test splits in data loaders, and creates the model, optimizer, and loss function used for training.

```python
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Hyperparameters
batch_size = 64
lr = 0.01
model = Net().to(device)

# Download and load the MNIST dataset (60,000 training and 10,000 test images)
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Shuffling is unnecessary for evaluation
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Plain SGD; NLLLoss matches the log_softmax output of Net
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = torch.nn.NLLLoss()
```
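
As a worked example of what these settings mean: with `batch_size = 64` and the `DataLoader` default `drop_last=False`, the final smaller batch is kept, so the number of batches per pass is the dataset size divided by 64, rounded up.

```python
import math

batch_size = 64
n_train, n_test = 60_000, 10_000  # MNIST split sizes

# drop_last=False (the DataLoader default) keeps the final partial batch
train_batches = math.ceil(n_train / batch_size)
test_batches = math.ceil(n_test / batch_size)
last_train_batch = n_train - (train_batches - 1) * batch_size

print(train_batches, test_batches, last_train_batch)  # 938 157 32
```

In the notebook, `len(train_loader)` and `len(test_loader)` report the same batch counts.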

## Define validation and training functions
This cell defines the training and validation functions for the neural network.

```python
def test(model, data_loader):
    """Return the mean loss and the accuracy of `model` on `data_loader`."""
    model.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    loss /= len(data_loader.dataset)
    accuracy = correct / len(data_loader.dataset)
    return loss, accuracy


def validate_model(model):
    training_loss, training_accuracy = test(model, train_loader)
    test_loss, test_accuracy = test(model, test_loader)

    print(f"training loss: {training_loss:.4f}, "
          f"training accuracy: {training_accuracy:.4f}, "
          f"test loss: {test_loss:.4f}, "
          f"test accuracy: {test_accuracy:.4f}")


def train_epoc(model, optimizer, criterion):
    """Train `model` for one epoch over `train_loader`."""
    # test() switches the model to eval mode, so re-enable training mode
    # (this turns dropout back on) before every epoch
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
```
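
The metrics computed by `test()` can be reproduced by hand. The NumPy sketch below (independent of PyTorch; `mean_nll_and_accuracy` is a hypothetical name) applies the same steps to a batch of log-probabilities: sum the negative log-probability of each true class, divide by the number of samples, and count how often the arg-max matches the target.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def mean_nll_and_accuracy(log_probs, targets):
    """Mean negative log-likelihood and accuracy, as in test()."""
    n = len(targets)
    # Summed NLL, like F.nll_loss(..., reduction='sum'), then divided by n
    loss = -log_probs[np.arange(n), targets].sum() / n
    # Predicted class = index of the largest log-probability
    accuracy = (log_probs.argmax(axis=1) == targets).mean()
    return float(loss), float(accuracy)


logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 0.3, 3.0],
                   [1.0, 2.0, 0.0]])
targets = np.array([0, 2, 0])
loss, accuracy = mean_nll_and_accuracy(log_softmax(logits), targets)
print(f"loss: {loss:.4f}, accuracy: {accuracy:.4f}")
```

Two of the three predictions match their targets here, so the accuracy is 2/3.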

## Finally, the training loop
This cell trains the neural network using the training and validation functions defined earlier.

```python
# Training loop
epochs = 2
for epoch in range(epochs):
    print(f"epoch: {epoch + 1}/{epochs}")
    train_epoc(model, optimizer, criterion)
    validate_model(model)
```

## Initiate FL server
Before we create our FL flow, we need to create an FL project in Scaleout Studio:
- Register a FEDn account: https://fedn.scaleoutsystems.com
- Create a project.
- Go to the project settings, copy the project URL, and generate a client token.

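```python
# Project URL from the project settings, without the "https://" prefix
# (the prefix is added when the client connects to the API further down)
project_url = "<paste-project-url-here>"
client_token = "<paste-client-token-here>"
```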
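
Because the connection code builds the API address as `f"https://{project_url}/"`, `project_url` should be the bare host (and path), without a scheme or trailing slash. If you paste a full URL, a small hypothetical helper like `normalize_project_url` can strip those parts; it uses only `urllib.parse` from the standard library, and the example host is made up.

```python
from urllib.parse import urlsplit


def normalize_project_url(url):
    """Strip an optional scheme and trailing slashes: 'https://host/path/' -> 'host/path'."""
    url = url.strip()
    # urlsplit only recognizes the host when a scheme or a leading '//' is present
    parts = urlsplit(url if "://" in url else "//" + url)
    return (parts.netloc + parts.path).rstrip("/")


print(normalize_project_url("https://example-project.example.com/"))  # example-project.example.com
```

In the notebook you could then write `project_url = normalize_project_url(project_url)`.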

### Create a seed model
We initialize a new model and save its parameters locally as a list of NumPy arrays, using FEDn's `numpyhelper`.

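```python
from fedn.utils.helpers.helpers import get_helper

# The numpyhelper serializes model parameters as a list of NumPy arrays
HELPER_MODULE = "numpyhelper"
helper = get_helper(HELPER_MODULE)

# Initialize a fresh model and convert its parameters (in state_dict order) to NumPy arrays
model = Net().to(device)
parameters_np = [val.cpu().numpy() for _, val in model.state_dict().items()]
seed_local_path = "mnist_seed.npz"
helper.save(parameters_np, seed_local_path)
```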
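
Before uploading, you can inspect the seed file with NumPy. This sketch assumes `helper.save` writes a standard NumPy `.npz` archive, as the file name suggests, and does not rely on how the arrays are named inside it. `np.load` on an `.npz` file returns a mapping of array names to arrays, so printing each shape lets you compare against `Net`'s parameter shapes. The demo writes its own small archive so it runs on its own; `describe_npz` is a hypothetical helper.

```python
import os
import tempfile

import numpy as np


def describe_npz(path):
    """Return (array name, shape) pairs for every array in an .npz archive."""
    with np.load(path) as archive:
        return [(name, archive[name].shape) for name in archive.files]


if __name__ == "__main__":
    # Demo archive with arrays shaped like Net's fc1 weight and bias
    with tempfile.TemporaryDirectory() as tmp:
        demo_path = os.path.join(tmp, "demo.npz")
        np.savez(demo_path, np.zeros((64, 784)), np.zeros(64))
        for name, shape in describe_npz(demo_path):
            print(name, shape)
```

In the notebook, `describe_npz(seed_local_path)` would list the arrays in the seed model.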


### Upload the seed model to the server
Navigate to https://fedn.scaleoutsystems.com/models/add-seed-model and upload the seed model generated in the previous cell.

## FL training client
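Now that the server is up and running with a seed model, we can set up the FL training client, which connects to the server and trains the model when requested. First, we define two helpers that convert between the serialized weights exchanged with the server (a byte stream) and the PyTorch model.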

```python
import io
import os
import collections

import torch
from fedn.utils.helpers.helpers import get_helper

HELPER_MODULE = "numpyhelper"
helper = get_helper(HELPER_MODULE)


def load_weights_into_model(weights, model):
    """Load serialized weights (an io.BytesIO from the server) into `model`."""
    # Write the bytes to a temporary file so the helper can load them
    inpath = helper.get_tmp_path()
    with open(inpath, "wb") as fh:
        fh.write(weights.getbuffer())
    weights = helper.load(inpath)
    os.unlink(inpath)

    # Pair the NumPy arrays with the state_dict keys (same order as when they were saved)
    params_dict = zip(model.state_dict().keys(), weights)
    state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
    model.load_state_dict(state_dict, strict=True)


def extract_weights_from_model(model):
    """Serialize the parameters of `model` into an io.BytesIO for the server."""
    # Convert from PyTorch tensors to NumPy arrays
    updated_weights = [val.cpu().numpy() for _, val in model.state_dict().items()]
    outpath = helper.get_tmp_path()
    helper.save(updated_weights, outpath)
    with open(outpath, "rb") as fr:
        out_model = io.BytesIO(fr.read())
    os.unlink(outpath)

    return out_model
```
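
The two helpers above go through a temporary file because the helper's `save` and `load` work with file paths. To see the round trip they perform (arrays to bytes and back, with the order preserved so that `zip` can pair them with the `state_dict` keys), here is an in-memory sketch using `np.savez` with an `io.BytesIO` buffer. It illustrates the idea only and is not FEDn's serialization format; the function names are hypothetical.

```python
import io

import numpy as np


def arrays_to_bytes(arrays):
    """Serialize a list of arrays into an in-memory .npz buffer."""
    buffer = io.BytesIO()
    # Name each array by its position so the order can be restored on load
    np.savez(buffer, **{str(i): a for i, a in enumerate(arrays)})
    buffer.seek(0)
    return buffer


def bytes_to_arrays(buffer):
    """Load the arrays back, sorted numerically by their positional names."""
    with np.load(buffer) as archive:
        return [archive[name] for name in sorted(archive.files, key=int)]


weights = [np.random.rand(64, 784), np.random.rand(64)]
restored = bytes_to_arrays(arrays_to_bytes(weights))
print([w.shape for w in restored])  # [(64, 784), (64,)]
```

The names are sorted as integers because a plain string sort would place `"10"` before `"2"` once a model has more than ten parameter tensors.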

## Define the training logic
This cell defines the training logic for the FL client. This is where the training algorithm is modified to fit the federated learning setting. This code will run when the server sends a request to the client to train the model.

```python
import json
import os

from fedn.utils.helpers.helpers import save_metadata

# Model, optimizer and loss used by the FL client
model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = torch.nn.NLLLoss()


def training_round(weights, client_settings):
    """Train on local data, starting from the global weights sent by the server."""
    # Convert from the serialized NumPy arrays to PyTorch weights
    load_weights_into_model(weights, model)

    # Metrics before local training
    validate_model(model)

    # Local training loop
    local_epochs = 1
    for epoch in range(local_epochs):
        train_epoc(model, optimizer, criterion)

    # Metrics after local training
    validate_model(model)

    metadata = {"training_metadata": {
        # num_examples is mandatory: the number of local training examples
        "num_examples": train_loader.dataset.data.shape[0],
        "batch_size": train_loader.batch_size,
        "epochs": local_epochs,
        "lr": lr,
    }}

    out_model = extract_weights_from_model(model)

    # save_metadata writes the metadata as JSON to "<outpath>-metadata";
    # read it back and remove the file
    outpath = "temp"
    save_metadata(metadata, outpath)
    with open(outpath + "-metadata", "r") as fh:
        training_metadata = json.loads(fh.read())

    os.unlink(outpath + "-metadata")
    return out_model, training_metadata
```
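
A field like `num_examples` is typically required because FedAvg-style aggregation weights each client's update by how much data it trained on. The NumPy sketch below shows that weighting for per-client lists of parameter arrays. It illustrates the idea only and is not FEDn's aggregator code; `weighted_average` is a hypothetical name.

```python
import numpy as np


def weighted_average(client_weights, num_examples):
    """FedAvg-style average: each client's arrays are weighted by its num_examples.

    client_weights: list over clients, each a list of NumPy arrays with matching shapes.
    num_examples: list of ints, one per client.
    """
    total = sum(num_examples)
    averaged = []
    # Walk over the parameter arrays position by position (fc1.weight, fc1.bias, ...)
    for arrays in zip(*client_weights):
        averaged.append(sum(n * a for n, a in zip(num_examples, arrays)) / total)
    return averaged


client_a = [np.ones((2, 2)), np.zeros(2)]
client_b = [3 * np.ones((2, 2)), np.ones(2)]
result = weighted_average([client_a, client_b], num_examples=[1, 3])
print(result[0])  # every entry is (1*1 + 3*3) / 4 = 2.5
print(result[1])  # every entry is (1*0 + 3*1) / 4 = 0.75
```

The client with three times as many examples pulls the average three times as strongly toward its parameters.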

## Define the validation logic
This cell defines the validation logic for the FL client. This code will run when the server sends a request to the client to validate the model. 

```python
def validate_weights(weights):
    """Evaluate the weights sent by the server on the local training and test data."""
    # Convert from the serialized NumPy arrays to PyTorch weights
    load_weights_into_model(weights, model)

    training_loss, training_accuracy = test(model, train_loader)
    test_loss, test_accuracy = test(model, test_loader)

    # Metrics reported back to the server as a JSON-serializable dict
    performance = {
        "training_loss": training_loss,
        "training_accuracy": training_accuracy,
        "test_loss": test_loss,
        "test_accuracy": test_accuracy,
    }

    return performance
```

## Configure the client
This cell configures the client that connects to the server and trains the model. We generate a client ID and set a client name. The server uses the client ID to identify the client, so a client that reconnects with the same ID is recognized as the same client; note that `uuid.uuid4()` creates a new ID every time the cell runs. The name is for your convenience and should be something descriptive.

```python
import uuid

from fedn.network.clients.fedn_client import FednClient, ConnectToApiResult

# Register the functions to call on train and validate requests from the server
fedn_client = FednClient(train_callback=training_round, validate_callback=validate_weights)

name = "<write-client-name-here>"
client_id = str(uuid.uuid4())

fedn_client.set_name(name)
fedn_client.set_client_id(client_id)

# Client configuration sent to the API when connecting
controller_config = {
    "name": name,
    "client_id": client_id,
    "package": "local",
    "preferred_combiner": "",
}
```
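
Since `uuid.uuid4()` produces a new ID every time the cell runs, the server sees each restart as a new client. To keep the same identity across restarts, you could store the ID in a local file and reuse it. A minimal sketch, assuming a file name of your choice (`client_id.txt` and `load_or_create_client_id` are hypothetical):

```python
import uuid
from pathlib import Path


def load_or_create_client_id(path="client_id.txt"):
    """Return the client ID stored in `path`, creating and saving a new one if needed."""
    id_file = Path(path)
    if id_file.exists():
        return id_file.read_text().strip()
    client_id = str(uuid.uuid4())
    id_file.write_text(client_id)
    return client_id
```

To use it, replace `client_id = str(uuid.uuid4())` with `client_id = load_or_create_client_id()` before calling `set_client_id`.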

## Connect client to fedn network
The client first connects to the server's REST API over HTTP(S) and then handles the rest of the communication over gRPC. The REST call returns the gRPC endpoint (the combiner configuration) that the client should connect to, so it works as a discovery service. If the gRPC endpoint is already known (for example, if there is only one combiner in the project), the client can connect directly to it.

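```python
import sys

# Discover the gRPC endpoint to use through the REST API
result, combiner_config = fedn_client.connect_to_api(f"https://{project_url}/", client_token, controller_config)

if result != ConnectToApiResult.Assigned:
    print("Failed to connect to API, exiting.")
    sys.exit(1)

# Open the gRPC connection to the assigned combiner
connected = fedn_client.init_grpchandler(config=combiner_config, client_name=name, token=client_token)

if not connected:
    print("Failed to connect over gRPC, exiting.")
    sys.exit(1)
else:
    print("Client connected to network.")

# Start listening for train and validation requests from the server
fedn_client.run()
```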

## Start training session
Now navigate to https://fedn.scaleoutsystems.com/sessions/create-session and create a new session. From the session page you can start the session and the client will start training the model. You can follow the training progress in the session page.