# Convert Your ML Training Script to Federated Learning
This tutorial guides you through converting a local machine learning training script into a federated learning setup using FEDn and Scaleout Studio. While most of our examples use a remote package, where the training and validation functions are defined on the server, this example defines them locally.

## ML training example script
We will use the MNIST example to show which parts of a training script need to be modified for federated learning (FL) training.

The cells below show a typical ML training flow implemented in PyTorch. Further down, we show how to transform it into a federated training flow with FednClient.

In [1]:
import torch
from torchvision import transforms 
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Download and load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

# Check available device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
 
# Define the neural network model
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(784, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 10)

    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x.reshape(x.size(0), 784)))
        x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.log_softmax(self.fc3(x), dim=1)
        return x

def test(model, data_loader):
    model.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    loss /= len(data_loader.dataset)
    accuracy = correct/len(data_loader.dataset)
    return loss, accuracy

def validate_model(model):
    # Report loss and accuracy on both the training set and the test set
    training_loss, training_accuracy = test(model, train_loader)
    test_loss, test_accuracy = test(model, test_loader)

    print(f"training loss: {training_loss:.4f}, "
          f"training accuracy: {training_accuracy:.4f}, "
          f"test loss: {test_loss:.4f}, "
          f"test accuracy: {test_accuracy:.4f}")
    
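def train_epoc(model, optimizer, criterion):
    # test() leaves the model in eval mode, which disables dropout,
    # so switch back to training mode before every epoch
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()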

In [None]:
lr = 0.01
model = Net().to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = torch.nn.NLLLoss()

# Training loop
epochs = 2
for epoch in range(epochs):
    print("epoch: ", epoch+1, "/", epochs)
    train_epoc(model, optimizer, criterion)
    validate_model(model)

## Initiate FL server
Before we build the FL flow, we need to create an FL project in Scaleout Studio.
- Register a FEDn account: https://fedn.scaleoutsystems.com
- Create a project
- Go to the project settings and copy the project URL, admin token, and client token

In [None]:
# Host name only, without "https://" (the scheme is added when connecting the client below)
project_url = "<paste-api-url-here>"
admin_token = "<paste-admin-token-here>"
client_token = "<paste-client-token-here>"
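
To keep tokens out of a shared notebook, the three values can be read from environment variables instead of being pasted into a cell. The sketch below is one possible way to do that. The helper `read_setting` and the variable names `FEDN_PROJECT_URL`, `FEDN_ADMIN_TOKEN`, and `FEDN_CLIENT_TOKEN` are hypothetical and not defined by FEDn.

```python
import os


def read_setting(name, env=None):
    """Return a non-empty setting, or raise a clear error (hypothetical helper)."""
    env = os.environ if env is None else env
    value = env.get(name, "").strip()
    # Treat an unfilled "<paste-...>" placeholder the same as a missing value
    if not value or value.startswith("<paste-"):
        raise ValueError(f"{name} is not set")
    return value


# An explicit mapping is used here so that the sketch runs without real credentials;
# omit the second argument to read from os.environ instead.
example_env = {
    "FEDN_PROJECT_URL": "example-project.example.com",
    "FEDN_ADMIN_TOKEN": "admin-token-value",
    "FEDN_CLIENT_TOKEN": "client-token-value",
}
example_url = read_setting("FEDN_PROJECT_URL", example_env)
```

Failing early with a named variable makes a missing token easier to spot than an authentication error later in the notebook.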

## Initiate a model state
We create a new model state and save it locally first. We then use the APIClient to upload the model state to the server and set it as the active model.
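
As a rough sanity check on what gets uploaded, the number of values in the seed model can be worked out from the layer sizes of the Net class above. The sketch below assumes float32 parameters (PyTorch's default) and estimates only the raw payload; the actual file size will differ because of format overhead. `count_parameters` is a hypothetical helper.

```python
# (in_features, out_features) for fc1, fc2, fc3 in the Net class above
layers = [(784, 64), (64, 32), (32, 10)]


def count_parameters(layers):
    """Each linear layer has in*out weights plus out biases."""
    return sum(n_in * n_out + n_out for n_in, n_out in layers)


total = count_parameters(layers)
size_bytes = total * 4  # assuming float32, i.e. 4 bytes per value
print(total, size_bytes)  # 52650 210600
```

The state dict of this model holds six arrays (a weight and a bias per layer), so the saved list should contain six entries.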

In [None]:
from fedn.utils.helpers.helpers import get_helper
from fedn import APIClient
HELPER_MODULE = "numpyhelper"
helper = get_helper(HELPER_MODULE)

model = Net().to(device)
parameters_np = [val.cpu().numpy() for _, val in model.state_dict().items()]
seed_local_path = "mnist_seed.npz"
helper.save(parameters_np, seed_local_path)


client = APIClient(host=project_url, token=admin_token, secure=True, verify=True)

result = client.set_active_model(seed_local_path)
print(result)


## FL training client

We need two wrapper functions, load_weights_into_model and extract_weights_from_model, to pass the model state between the PyTorch model and the FednClient.
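
The idea behind the two wrappers can be illustrated without FEDn or PyTorch. In the sketch below, plain NumPy serialization stands in for the helper and an ordered dict of arrays stands in for the model's state dict. The functions `arrays_to_stream` and `stream_to_state` are hypothetical, and storing arrays under their index is an assumption made for this illustration, not a description of the helper's file format.

```python
import collections
import io

import numpy as np


def arrays_to_stream(arrays):
    """Serialize a list of arrays into an in-memory byte stream."""
    stream = io.BytesIO()
    # Store each array under its position so the order can be restored on load
    np.savez(stream, **{str(i): a for i, a in enumerate(arrays)})
    stream.seek(0)
    return stream


def stream_to_state(stream, keys):
    """Rebuild an ordered name -> array mapping from a byte stream."""
    with np.load(io.BytesIO(stream.getvalue())) as data:
        arrays = [data[str(i)] for i in range(len(data.files))]
    # zip() would silently truncate on a length mismatch, so check explicitly
    if len(arrays) != len(keys):
        raise ValueError(f"expected {len(keys)} arrays, got {len(arrays)}")
    return collections.OrderedDict(zip(keys, arrays))


state = collections.OrderedDict(
    [("fc1.weight", np.full((2, 3), 0.5)), ("fc1.bias", np.arange(2.0))]
)
stream = arrays_to_stream(list(state.values()))
rebuilt = stream_to_state(stream, list(state.keys()))
```

Because the arrays travel as a list without names, the receiving side has to pair them with the state dict keys by position, which is why both sides must use the same model definition.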

In [31]:
import io
import os
import collections
from fedn.utils.helpers.helpers import get_helper
HELPER_MODULE = "numpyhelper"
helper = get_helper(HELPER_MODULE)

def load_weights_into_model(weights, model):
    
    inpath = helper.get_tmp_path()
    with open(inpath, "wb") as fh:
        fh.write(weights.getbuffer())
    weights = helper.load(inpath)
    os.unlink(inpath)
    params_dict = zip(model.state_dict().keys(), weights)
    state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
    model.load_state_dict(state_dict, strict=True)

    

def extract_weights_from_model(model):

    # Convert the PyTorch state dict into a list of numpy arrays
    updated_weights = [val.cpu().numpy() for _, val in model.state_dict().items()]
    outpath = helper.get_tmp_path()
    helper.save(updated_weights, outpath)
    with open(outpath, "rb") as fr:
        out_model = io.BytesIO(fr.read())
    os.unlink(outpath)

    return out_model

The FednClient needs a train_callback and a validate_callback.

In [56]:

from torchvision import transforms 
from torchvision import datasets
from torch.utils.data import DataLoader
from fedn.utils.helpers.helpers import save_metadata
import json

model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = torch.nn.NLLLoss()

def training_round(weights, client_settings):

    # Convert from numpy array to correct pytorch format
    load_weights_into_model(weights, model)

    # Training loop
    local_epochs = 1
    for epoch in range(local_epochs): 
        train_epoc(model, optimizer, criterion)

    metadata = {"training_metadata":{
        # num_examples is mandatory
        "num_examples": train_loader.dataset.data.shape[0],
        "batch_size": train_loader.batch_size,
        "epochs": local_epochs,
        "lr": lr,
    }}
   
    out_model = extract_weights_from_model(model)

    outpath = "temp"
    save_metadata(metadata, outpath)
    with open(outpath + "-metadata", "r") as fh:
        training_metadata = json.loads(fh.read())
    
    os.unlink(outpath + "-metadata")
    return out_model, training_metadata
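
A common aggregation scheme in federated learning, federated averaging (FedAvg), weights each client's update by its number of training examples, which is one reason a client reports a count such as num_examples. The sketch below illustrates that weighting with NumPy. It is not FEDn's aggregator, and `weighted_average` is a hypothetical function.

```python
import numpy as np


def weighted_average(client_updates):
    """Average per-layer arrays, weighting each client by its example count.

    client_updates: list of (list_of_arrays, num_examples) pairs.
    """
    total = sum(n for _, n in client_updates)
    if total <= 0:
        raise ValueError("total number of examples must be positive")
    n_layers = len(client_updates[0][0])
    return [
        sum(arrays[i] * (n / total) for arrays, n in client_updates)
        for i in range(n_layers)
    ]


# Two toy clients with a single one-dimensional "layer" each
client_a = ([np.array([1.0, 1.0])], 300)
client_b = ([np.array([5.0, 5.0])], 100)
averaged = weighted_average([client_a, client_b])
```

Here client_a contributes three quarters of the data, so the result (2.0 in each position) lies closer to its values than to those of client_b.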

In [51]:
def validate_weights(weights):

    # Convert from numpy array to correct pytorch format
    load_weights_into_model(weights, model)
    
    
    training_loss, training_accuracy = test(model, train_loader)
    test_loss, test_accuracy = test(model, test_loader)
    # JSON schema
    performance = {
        "training_loss": training_loss,
        "training_accuracy": training_accuracy,
        "test_loss": test_loss,
        "test_accuracy": test_accuracy,
    }

    return performance

We then initialize the FednClient with our training and validation callbacks.
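
Before connecting to the network, it can help to call both callbacks once locally and check what they return. The sketch below uses stub callbacks with the same signatures as training_round and validate_weights. The helper `dry_run_callbacks` is hypothetical, and its checks reflect the conventions used in this tutorial rather than a complete FEDn contract.

```python
import io
import json


def dry_run_callbacks(train_callback, validate_callback, weights):
    """Call both callbacks once and check the shape of what they return."""
    out_model, metadata = train_callback(weights, {})
    if not isinstance(out_model, io.BytesIO):
        raise TypeError("train callback should return the model as io.BytesIO")
    if "num_examples" not in metadata.get("training_metadata", {}):
        raise ValueError("training_metadata must contain num_examples")
    out_model.seek(0)
    metrics = validate_callback(out_model)
    json.dumps(metrics)  # raises TypeError if a metric is not JSON serializable
    return metadata, metrics


# Stub callbacks standing in for training_round and validate_weights
def stub_train(weights, client_settings):
    return io.BytesIO(weights.getvalue()), {"training_metadata": {"num_examples": 60000}}


def stub_validate(weights):
    return {"test_accuracy": 0.5}


metadata, metrics = dry_run_callbacks(stub_train, stub_validate, io.BytesIO(b"demo"))
```

With the real functions from this tutorial, the call would be dry_run_callbacks(training_round, validate_weights, extract_weights_from_model(model)).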

In [None]:
from fedn.network.clients.fedn_client import FednClient, ConnectToApiResult
import uuid

fedn_client = FednClient(train_callback=training_round, validate_callback=validate_weights)

## Connect the client to the FEDn network

In [None]:
name = "<write-client-name-here>"

fedn_client.set_name(name)

client_id = str(uuid.uuid4())
fedn_client.set_client_id(client_id)

controller_config = {
    "name": name,
    "client_id": client_id,
    "package": "local",
    "preferred_combiner": "",
}


result, combiner_config = fedn_client.connect_to_api("https://"+project_url+"/", client_token, controller_config)

print("result: ", result)
print(combiner_config)
if result != ConnectToApiResult.Assigned:
    print("Failed to connect to API, exiting.")
    exit(1)

result: bool = fedn_client.init_grpchandler(config=combiner_config, client_name=name, token=client_token)
fedn_client.run()
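
The cell above builds the API address as "https://" + project_url + "/", so a project URL pasted with its scheme or a trailing slash would produce a malformed address. The sketch below shows one way to normalize the value first. `build_api_url` is a hypothetical helper, not part of FEDn.

```python
def build_api_url(project_url):
    """Return 'https://<host>/' whether or not the input has a scheme or trailing slash."""
    host = project_url.strip()
    # Drop a scheme if the value was pasted with one
    for prefix in ("https://", "http://"):
        if host.startswith(prefix):
            host = host[len(prefix):]
            break
    host = host.rstrip("/")
    if not host:
        raise ValueError("project_url is empty")
    return f"https://{host}/"


print(build_api_url("https://example-project.example.com/"))  # https://example-project.example.com/
```

The result could then be passed as the first argument to connect_to_api in place of the string concatenation.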