# Convert Your ML Training Script to Federated Learning
This tutorial will guide you how to take your local machine learning training script and train it in a federated setting using FEDn and Scaleout studio. 

## ML training example script
We will use the mnist example to show what parts of the training algorithm that needs to be modified to fit for federated learning (FL) training.

The cells below shows a typicall ML training algorithm implemented in pytorch.

In [1]:
import torch
 

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(784, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 10)

    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x.reshape(x.size(0), 784)))
        x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.log_softmax(self.fc3(x), dim=1)
        return x


In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [3]:
def test(model, data_loader):
    model.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    loss /= len(data_loader.dataset)
    accuracy = correct/len(data_loader.dataset)
    return loss, accuracy
    




def validate_model(model):
    
    training_loss, training_accuracy = test(model, train_loader)
    test_loss, test_accuracy = test(model, test_loader)
    
    
    print(f"training loss: {training_loss:.4f}, "
      f"training accuracy: {training_accuracy:.4f}, "
      f"test loss: {test_loss:.4f}, "
      f"test accuracy: {test_accuracy:.4f}")

In [4]:
def train_epoc(model, optimizer, criterion):
    
    
    for batch_idx, (data, target) in enumerate(train_loader): 
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


In [5]:
from torchvision import transforms 
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F


# Download and load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

batch_size = 64
lr = 0.01
model = Net().to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = torch.nn.NLLLoss()

# Training loop
epochs = 2
for epoch in range(epochs):
    print("epoch: ", epoch+1, "/", epochs)
    train_epoc(model, optimizer, criterion)
    validate_model(model)
    

epoch:  1 / 2
training loss: 1.1047, training accuracy: 0.7173, test loss: 1.0874, test accuracy: 0.7244
epoch:  2 / 2
training loss: 0.5217, training accuracy: 0.8538, test loss: 0.5054, test accuracy: 0.8565


## Initiate FL server

- Register a FEDn account: https://fedn.scaleoutsystems.com
- Create a project
- Go to project settings copy project url, admin token and client token

In [None]:
project_url = "<paste-api-url-here>"
admin_token = "<paste-admin-token-here>"
client_token = "<paste-client-token-here>"

In [6]:

project_url = "api.fedn.scaleoutsystems.com/testapi-rhe-fedn-reducer"
             
admin_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ0b2tlbl90eXBlIjoiYWNjZXNzIiwiZXhwIjoxNzQ0NTM4MDQ0LCJpYXQiOjE3NDE5NDYwNDQsImp0aSI6IjdjY2VjNTdlZWY0ZjQ0ZjRhYzc3YTJiZjA4YjQxNzc4IiwidXNlcl9pZCI6MjE0LCJjcmVhdG9yIjoiTWF0dGlhcyIsInJvbGUiOiJhZG1pbiIsInByb2plY3Rfc2x1ZyI6InRlc3RhcGktcmhlIn0.MW1gC08qFj7_F7Cz4AzLtCXiP6bgom-m2MY-Gl08ySc"
client_token ="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ0b2tlbl90eXBlIjoiYWNjZXNzIiwiZXhwIjoxNzQ0NTQyMjYzLCJpYXQiOjE3NDE5NTAyNjMsImp0aSI6IjVhNjQ4ZWNiNTk0YzQ0YTM5YzVkZDgyY2EzNjMwMWVjIiwidXNlcl9pZCI6MjE0LCJjcmVhdG9yIjoiTWF0dGlhcyIsInJvbGUiOiJjbGllbnQiLCJwcm9qZWN0X3NsdWciOiJ0ZXN0YXBpLXJoZSJ9.uwWENDYfGq_FKk3tQjcaDZNix_bZQlT0h1TWohvkSlg"

In [30]:
model.state_dict()['fc1.weight'].device

device(type='mps', index=0)

## FL training client

In [31]:
import io
import os
from fedn.utils.helpers.helpers import get_helper
HELPER_MODULE = "numpyhelper"
helper = get_helper(HELPER_MODULE)

def load_weights_into_model(weights, model):

    print("model state dict device init: ", model.state_dict()['fc1.weight'].device)
    
    inpath = helper.get_tmp_path()
    with open(inpath, "wb") as fh:
        fh.write(weights.getbuffer())
    weights = helper.load(inpath)
    os.unlink(inpath)
    params_dict = zip(model.state_dict().keys(), weights)
    state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
    model.load_state_dict(state_dict, strict=True)
    print("model state dict device after: ", model.state_dict()['fc1.weight'].device)

    

def extract_weights_from_model(model):

    # Convert from pytorch weights format numpy array 
    updated_weights = [val.cpu().numpy() for _, val in model.state_dict().items()]
    outpath = helper.get_tmp_path()
    helper.save(updated_weights, outpath)
    with open(outpath, "rb") as fr:
        out_model = io.BytesIO(fr.read())
    os.unlink(outpath)

    return out_model

In [32]:
len(train_dataset)

60000

In [56]:
import collections

from torchvision import transforms 
from torchvision import datasets
from torch.utils.data import DataLoader
from fedn.utils.helpers.helpers import save_metadata
import json

model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = torch.nn.NLLLoss()

def training_round(weights, client_settings):

    # Convert from numpy array to correct pytorch format
    load_weights_into_model(weights, model)
    
    validate_model(model)
    

    # Training loop
    local_epochs = 1
    for epoch in range(local_epochs): 
        train_epoc(model, optimizer, criterion)
        #for batch_idx, (data, target) in enumerate(train_loader):
            #print(batch_idx)
        #    data, target = data.to(device), target.to(device)
        #    optimizer.zero_grad()
        #    output = model(data)
        #    loss = criterion(output, target)
        #    loss.backward()
        #    optimizer.step()

    validate_model(model)
    

    metadata = {"training_metadata":{
        # num_examples are mandatory
        "num_examples": len(train_dataset),
        "batch_size": batch_size,
        "epochs": local_epochs,
        "lr": lr,
    }}
   
    
    out_model = extract_weights_from_model(model)

    outpath = "temp"
    save_metadata(metadata, outpath)
    with open(outpath + "-metadata", "r") as fh:
                training_metadata = json.loads(fh.read())
    
    os.unlink(outpath + "-metadata")
    return out_model, training_metadata

In [51]:
def validate_weights(weights):

    # Convert from numpy array to correct pytorch format
    load_weights_into_model(weights, model)
    
    
    training_loss, training_accuracy = test(model, train_loader)
    test_loss, test_accuracy = test(model, test_loader)
    # JSON schema
    performance = {
        "training_loss": training_loss,
        "training_accuracy": training_accuracy,
        "test_loss": test_loss,
        "test_accuracy": test_accuracy,
    }

    return performance

## Initiate a model state
We initiate a new model state and save it locally first. We can then use the APIClient to upload the model state to the server. 

In [52]:
from fedn.utils.helpers.helpers import get_helper
from fedn import APIClient
HELPER_MODULE = "numpyhelper"
helper = get_helper(HELPER_MODULE)

model = Net().to(device)
parameters_np = [val.cpu().numpy() for _, val in model.state_dict().items()]
seed_local_path = "mnist_seed.npz"
helper.save(parameters_np, seed_local_path)


client = APIClient(host=project_url, token=admin_token, secure=True, verify=True)

result = client.set_active_model(seed_local_path)
print(result)


Mattias API client
{'message': 'Model added successfully', 'success': True}


In [44]:
client.get_combiners()

{'count': 1,
 'result': [{'address': 'testapi-rhe-fedn-combiner',
   'fqdn': 'grpc.fedn.scaleoutsystems.com',
   'id': '67c5737a85f391be69665f0f',
   'ip': '',
   'name': 'testapi-rhe-fedn',
   'parent': 'localhost',
   'port': 12080,
   'updated_at': '2025-03-03 09:16:42.988092'}]}

In [45]:
#helper.load(seed_local_path)
#parameters_np

TODO: Niclas describe how to start the api

In [57]:
from fedn.network.clients.fedn_client import FednClient, ConnectToApiResult
import uuid

fedn_client = FednClient(train_callback=training_round, validate_callback=validate_weights)

name = "<write-client-name-here>"
name = "helena"

# Varför sätter vi name och id både här och i controller_config? 
fedn_client.set_name(name)

#client_id = str(uuid.uuid4())
fedn_client.set_client_id(client_id)
print(client_id)


2025-03-17 17:10:27 [INFO] Setting client name to: helena
2025-03-17 17:10:27 [INFO] Setting client ID to: fa371b54-751b-465a-a997-5c9d0ad25780


fa371b54-751b-465a-a997-5c9d0ad25780


In [47]:
# TODO: Describe controller config

In [48]:
controller_config = {
    "name": name,
    "client_id": client_id,
    "package": "local",
    "preferred_combiner": "",
}



## Connect client to fedn network

In [58]:
result, combiner_config = fedn_client.connect_to_api("https://"+project_url+"/", client_token, controller_config)
#result, combiner_config = fedn_client.connect_to_api(api_url, token2, controller_config)

print("result: ", result)
print( combiner_config)
if result != ConnectToApiResult.Assigned:
    print("Failed to connect to API, exiting.")
    exit (1)

result: bool = fedn_client.init_grpchandler(config=combiner_config, client_name=name, token=client_token)
print("result: ", result)
if not result:
    exit (1)

2025-03-17 17:10:33 [INFO] Connecting to API endpoint: https://api.fedn.scaleoutsystems.com/testapi-rhe-fedn-reducer/api/v1/clients/add
2025-03-17 17:10:33 [INFO] Connect to FEDn Api - Client assigned to controller
2025-03-17 17:10:33 [INFO] Connecting (GRPC) to grpc.fedn.scaleoutsystems.com:443
2025-03-17 17:10:33 [INFO] Successfully initialized GRPC connection


result:  ConnectToApiResult.Assigned
<fedn.network.clients.fedn_client.GrpcConnectionOptions object at 0x144a20980>
result:  True


In [None]:
fedn_client.run()

2025-03-17 17:10:33 [INFO] Listening to task stream.


Starting client...Mattias


2025-03-17 17:10:45 [INFO] Sending status message to combiner.
2025-03-17 17:10:45 [INFO] Received task request of type 2 for model_id b9252577-a59c-4e92-8e42-c71cd37e976f
2025-03-17 17:10:45 [INFO] Downloading model from combiner.
2025-03-17 17:10:45 [INFO] Sending status message to combiner.
2025-03-17 17:10:45 [INFO] Running train callback with model ID: b9252577-a59c-4e92-8e42-c71cd37e976f


model state dict device init:  mps:0
model state dict device after:  mps:0
training loss: 2.3039, training accuracy: 0.1022, test loss: 2.3035, test accuracy: 0.1010


2025-03-17 17:10:53 [INFO] Uploading model to combiner.
2025-03-17 17:10:54 [INFO] Sending model update to combiner.
2025-03-17 17:10:54 [INFO] Sending status message to combiner.


training loss: 0.8770, training accuracy: 0.7761, test loss: 0.8617, test accuracy: 0.7816


2025-03-17 17:10:57 [INFO] Sending status message to combiner.
2025-03-17 17:10:57 [INFO] Received task request of type 4 for model_id c7bdeced-566e-4a20-96cd-4f385b424ac4
2025-03-17 17:10:57 [INFO] Sending status message to combiner.
2025-03-17 17:10:57 [INFO] Downloading model from combiner.
2025-03-17 17:10:57 [INFO] Running validate callback with model ID: c7bdeced-566e-4a20-96cd-4f385b424ac4


model state dict device init:  mps:0
model state dict device after:  mps:0


2025-03-17 17:11:00 [INFO] Sending model validation to combiner.
2025-03-17 17:11:00 [INFO] Sending status message to combiner.
2025-03-17 17:11:00 [INFO] Sending status message to combiner.
2025-03-17 17:11:00 [INFO] Received task request of type 2 for model_id c7bdeced-566e-4a20-96cd-4f385b424ac4
2025-03-17 17:11:00 [INFO] Downloading model from combiner.
2025-03-17 17:11:00 [INFO] Sending status message to combiner.
2025-03-17 17:11:00 [INFO] Running train callback with model ID: c7bdeced-566e-4a20-96cd-4f385b424ac4


model state dict device init:  mps:0
model state dict device after:  mps:0
training loss: 0.8770, training accuracy: 0.7761, test loss: 0.8617, test accuracy: 0.7816


2025-03-17 17:11:08 [INFO] Uploading model to combiner.
2025-03-17 17:11:08 [INFO] Sending model update to combiner.
2025-03-17 17:11:08 [INFO] Sending status message to combiner.


training loss: 0.4571, training accuracy: 0.8729, test loss: 0.4464, test accuracy: 0.8759


2025-03-17 17:11:12 [INFO] Sending status message to combiner.
2025-03-17 17:11:12 [INFO] Received task request of type 4 for model_id a66894cc-a69e-49c9-ac86-a9812e1f3581
2025-03-17 17:11:12 [INFO] Sending status message to combiner.
2025-03-17 17:11:12 [INFO] Downloading model from combiner.
2025-03-17 17:11:12 [INFO] Running validate callback with model ID: a66894cc-a69e-49c9-ac86-a9812e1f3581


model state dict device init:  mps:0
model state dict device after:  mps:0


2025-03-17 17:11:15 [INFO] Sending model validation to combiner.
2025-03-17 17:11:15 [INFO] Sending status message to combiner.
2025-03-17 17:11:15 [INFO] Sending status message to combiner.
2025-03-17 17:11:16 [INFO] Received task request of type 2 for model_id a66894cc-a69e-49c9-ac86-a9812e1f3581
2025-03-17 17:11:16 [INFO] Downloading model from combiner.
2025-03-17 17:11:16 [INFO] Sending status message to combiner.
2025-03-17 17:11:16 [INFO] Running train callback with model ID: a66894cc-a69e-49c9-ac86-a9812e1f3581


model state dict device init:  mps:0
model state dict device after:  mps:0
training loss: 0.4571, training accuracy: 0.8729, test loss: 0.4464, test accuracy: 0.8759


2025-03-17 17:11:24 [INFO] Uploading model to combiner.
2025-03-17 17:11:24 [INFO] Sending model update to combiner.
2025-03-17 17:11:24 [INFO] Sending status message to combiner.


training loss: 0.3759, training accuracy: 0.8942, test loss: 0.3659, test accuracy: 0.8953


2025-03-17 17:11:28 [INFO] Sending status message to combiner.
2025-03-17 17:11:28 [INFO] Received task request of type 4 for model_id f5a478b0-441c-45d1-a88a-b0ea26e2d396
2025-03-17 17:11:28 [INFO] Sending status message to combiner.
2025-03-17 17:11:28 [INFO] Downloading model from combiner.
2025-03-17 17:11:29 [INFO] Running validate callback with model ID: f5a478b0-441c-45d1-a88a-b0ea26e2d396


model state dict device init:  mps:0
model state dict device after:  mps:0


2025-03-17 17:11:31 [INFO] Sending model validation to combiner.
2025-03-17 17:11:31 [INFO] Sending status message to combiner.
2025-03-17 17:11:31 [INFO] Sending status message to combiner.
2025-03-17 17:11:31 [INFO] Received task request of type 2 for model_id f5a478b0-441c-45d1-a88a-b0ea26e2d396
2025-03-17 17:11:31 [INFO] Downloading model from combiner.
2025-03-17 17:11:32 [INFO] Sending status message to combiner.
2025-03-17 17:11:32 [INFO] Running train callback with model ID: f5a478b0-441c-45d1-a88a-b0ea26e2d396


model state dict device init:  mps:0
model state dict device after:  mps:0
training loss: 0.3759, training accuracy: 0.8942, test loss: 0.3659, test accuracy: 0.8953


2025-03-17 17:11:40 [INFO] Uploading model to combiner.
2025-03-17 17:11:40 [INFO] Sending model update to combiner.
2025-03-17 17:11:40 [INFO] Sending status message to combiner.


training loss: 0.3431, training accuracy: 0.9028, test loss: 0.3364, test accuracy: 0.9051


2025-03-17 17:11:43 [INFO] Sending status message to combiner.
2025-03-17 17:11:43 [INFO] Received task request of type 4 for model_id ad2902d1-a69a-4bfa-9267-017dca3386e0
2025-03-17 17:11:43 [INFO] Sending status message to combiner.
2025-03-17 17:11:43 [INFO] Downloading model from combiner.
2025-03-17 17:11:43 [INFO] Running validate callback with model ID: ad2902d1-a69a-4bfa-9267-017dca3386e0


model state dict device init:  mps:0
model state dict device after:  mps:0


2025-03-17 17:11:46 [INFO] Sending model validation to combiner.
2025-03-17 17:11:46 [INFO] Sending status message to combiner.
2025-03-17 17:11:46 [INFO] Sending status message to combiner.
2025-03-17 17:11:46 [INFO] Received task request of type 2 for model_id ad2902d1-a69a-4bfa-9267-017dca3386e0
2025-03-17 17:11:46 [INFO] Downloading model from combiner.
2025-03-17 17:11:46 [INFO] Sending status message to combiner.
2025-03-17 17:11:47 [INFO] Running train callback with model ID: ad2902d1-a69a-4bfa-9267-017dca3386e0


model state dict device init:  mps:0
model state dict device after:  mps:0
training loss: 0.3431, training accuracy: 0.9028, test loss: 0.3364, test accuracy: 0.9051


2025-03-17 17:11:55 [INFO] Uploading model to combiner.
2025-03-17 17:11:55 [INFO] Sending model update to combiner.
2025-03-17 17:11:55 [INFO] Sending status message to combiner.


training loss: 0.3149, training accuracy: 0.9101, test loss: 0.3091, test accuracy: 0.9106


2025-03-17 17:11:58 [INFO] Sending status message to combiner.
2025-03-17 17:11:58 [INFO] Received task request of type 4 for model_id 1a1df8d2-2ec9-4f4d-904b-0ac3c3fab80b
2025-03-17 17:11:58 [INFO] Sending status message to combiner.
2025-03-17 17:11:58 [INFO] Downloading model from combiner.
2025-03-17 17:11:58 [INFO] Running validate callback with model ID: 1a1df8d2-2ec9-4f4d-904b-0ac3c3fab80b


model state dict device init:  mps:0
model state dict device after:  mps:0


2025-03-17 17:12:01 [INFO] Sending model validation to combiner.
2025-03-17 17:12:01 [INFO] Sending status message to combiner.
2025-03-17 22:03:44 [INFO] Heartbeat successful.
2025-03-17 22:03:44 [INFO] Heartbeat successful.
2025-03-17 22:03:45 [INFO] Sending heartbeat to combiner
2025-03-17 22:03:45 [INFO] Heartbeat successful.
2025-03-17 22:03:46 [INFO] Sending heartbeat to combiner
