# Data Owner

<div class="alert alert-info">
Before running this notebook, run the Data Scientist notebook to host a model in PyGrid.
</div>

In [1]:
import base64
import json
from math import prod
import random
from typing import Any

import gym
import requests
import torch as T
from torch import nn
from websocket import create_connection

from syft import serialize
from syft.federated.model_serialization import (
    deserialize_model_params,
    wrap_model_params,
)

## Step 1: Specify configuration and helper functions

In [2]:
# Name and version of the federated-learning process to join on the grid.
NAME, VERSION = "q-learning", "0.0.0"
# host:port of the grid node; helpers prepend "ws://" or "http://" as needed.
GRID_ADDRESS = "localhost:5000"
# Shape of the Q-table: one axis per discretized observation component
# (cart position, cart velocity, pole angle, pole angular velocity),
# followed by a final axis with one entry per action.
CARTPOLE_DIMS = (1, 1, 12, 6, 2)

# Raw observation: (cart position, cart velocity, pole angle, pole angular velocity)
CartPoleObservation = tuple[float, float, float, float]
# The same four components after each has been mapped to an integer bin index.
DiscretizedCartPoleObservation = tuple[int, int, int, int]

In [3]:
def prettify(json_data: object) -> str:
    """Render *json_data* as 2-space-indented JSON with escaped newlines expanded.

    ``json.dumps`` writes a line break inside a string value as the
    two-character sequence backslash + ``n``; each such sequence is replaced
    by a real line break so multi-line string values print readably.
    """
    rendered = json.dumps(json_data, indent=2)
    escaped_newline, real_newline = "\\n", "\n"
    return rendered.replace(escaped_newline, real_newline)

In [4]:
def set_params(model: nn.Module, params: list[T.Tensor]) -> None:
    """Overwrite *model*'s parameters, in order, with copies of *params*.

    Parameters are paired positionally with ``zip``, so pairing stops at the
    shorter of the two sequences. Each incoming tensor is detached and cloned
    first, so the model never shares storage with the tensors passed in.
    """
    for target, source in zip(model.parameters(), params):
        independent_copy = source.detach().clone()
        target.data = independent_copy.data


def calculate_diff(
    original_params: list[T.Tensor], trained_params: list[T.Tensor]
) -> list[T.Tensor]:
    """Return ``original - trained`` for each positionally paired parameter.

    The result has one tensor per pair; pairing stops at the shorter input.
    """
    diffs = []
    for before, after in zip(original_params, trained_params):
        diffs.append(before - after)
    return diffs

In [5]:
def send_ws_message(grid_address: str, data: object) -> Any:
    """Send *data* as JSON over a fresh WebSocket and return the decoded reply.

    Args:
        grid_address: ``host:port`` of the grid node; ``ws://`` is prepended.
        data: JSON-serializable message to send.

    Returns:
        The first message received in response, parsed from JSON.
    """
    ws = create_connection("ws://" + grid_address)
    try:
        ws.send(json.dumps(data))
        message = ws.recv()
    finally:
        # A new connection is opened on every call, so close it even when
        # send/recv fails; otherwise each request leaves a socket open.
        ws.close()
    return json.loads(message)


def get_model_params(
    grid_address: str, worker_id: str, request_key: str, model_id: str
) -> list[T.Tensor]:
    """Download the parameters of the model assigned to this worker's cycle.

    Issues ``GET http://<grid_address>/model-centric/get-model`` with
    ``worker_id``, ``request_key`` and ``model_id`` as query parameters and
    deserializes the response body.

    Raises:
        requests.HTTPError: if the grid answers with a 4xx/5xx status.
    """
    get_params = {
        "worker_id": worker_id,
        "request_key": request_key,
        "model_id": model_id,
    }
    response = requests.get(
        f"http://{grid_address}/model-centric/get-model", params=get_params
    )
    # Fail with the HTTP error itself rather than handing an error body to
    # the deserializer, which would surface as an unrelated decode failure.
    response.raise_for_status()
    return deserialize_model_params(response.content)


def retrieve_model_params(grid_address: str, name: str, version: str) -> list[T.Tensor]:
    """Fetch the latest checkpoint of the hosted model *name*/*version*.

    Issues ``GET http://<grid_address>/model-centric/retrieve-model`` with
    ``name``, ``version`` and ``checkpoint="latest"`` as query parameters and
    deserializes the response body.

    Raises:
        requests.HTTPError: if the grid answers with a 4xx/5xx status.
    """
    get_params = {
        "name": name,
        "version": version,
        "checkpoint": "latest",
    }

    response = requests.get(
        f"http://{grid_address}/model-centric/retrieve-model", params=get_params
    )
    # Fail with the HTTP error itself rather than handing an error body to
    # the deserializer, which would surface as an unrelated decode failure.
    response.raise_for_status()
    return deserialize_model_params(response.content)


def send_auth_request(grid_address: str, name: str, version: str) -> Any:
    """Send a ``model-centric/authenticate`` message and return the grid's reply.

    The message carries the model name and version this worker wants to
    train; it is delivered through ``send_ws_message``.
    """
    payload = {"model_name": name, "model_version": version}
    request = {"type": "model-centric/authenticate", "data": payload}
    return send_ws_message(grid_address, request)


def send_cycle_request(
    grid_address: str, name: str, version: str, worker_id: str
) -> Any:
    """Send a ``model-centric/cycle-request`` message and return the grid's reply.

    Besides the worker id and the model name/version, the message reports
    fixed ``ping``/``download``/``upload`` figures (hard-coded here;
    presumably connection metrics the grid uses for worker selection —
    confirm against the PyGrid protocol).
    """
    payload = {
        "worker_id": worker_id,
        "model": name,
        "version": version,
        "ping": 1,
        "download": 10000,
        "upload": 10000,
    }
    request = {"type": "model-centric/cycle-request", "data": payload}
    return send_ws_message(grid_address, request)


def send_diff_report(
    grid_address: str, worker_id: str, request_key: str, diff: list[T.Tensor]
) -> Any:
    """Report the parameter *diff* for this cycle and return the grid's reply.

    The diff is wrapped and serialized to bytes, then base64-encoded to ASCII
    so it can travel inside the JSON ``model-centric/report`` message.

    Args:
        grid_address: ``host:port`` of the grid node.
        worker_id: Id issued to this worker on authentication.
        request_key: Key issued with the accepted cycle request.
        diff: One tensor per model parameter.

    Returns:
        The grid's JSON-decoded response, so callers can inspect the outcome
        of the report (as with the other ``send_*`` helpers).
    """
    serialized_diff = serialize(wrap_model_params(diff)).SerializeToString()
    message = {
        "type": "model-centric/report",
        "data": {
            "worker_id": worker_id,
            "request_key": request_key,
            "diff": base64.b64encode(serialized_diff).decode("ascii"),
        },
    }
    return send_ws_message(grid_address, message)

In [6]:
def clip(min_value, max_value, value):
    """Clamp *value* to the closed interval ``[min_value, max_value]``.

    The upper bound is applied first and the lower bound second, so when the
    bounds are inverted the lower bound wins.
    """
    upper_bounded = value if value < max_value else max_value
    return upper_bounded if upper_bounded > min_value else min_value

def discretize(observation: CartPoleObservation) -> DiscretizedCartPoleObservation:
    """Map a raw observation to integer bin indices.

    Cart position and cart velocity are ignored and always map to bin 0.
    The pole angle is shifted by 0.209, clamped to ``[0, 0.417]`` and split
    into 12 equal-width bins (indices 0-11). The pole angular velocity is
    shifted by 3.0, clamped to ``[0, 5.999]`` and split into 6 unit-width
    bins (indices 0-5).
    """
    _, _, angle, angular_velocity = observation

    # Clamping just below the full span keeps the top value inside the last bin.
    shifted_angle = clip(0.0, 0.417, angle + 0.209)
    angle_bin = int(shifted_angle // (0.418 / 12))

    shifted_velocity = clip(0.0, 5.999, angular_velocity + 3.0)
    velocity_bin = int(shifted_velocity // (6.0 / 6))

    # Cart position and cart velocity each collapse to a single bin.
    return (0, 0, angle_bin, velocity_bin)

## Step 2: Define the model and training loop

In [7]:
class QLearningAgent(nn.Module):
    """Tabular Q-learning agent whose whole state lives in one ``nn.Linear``.

    The layer's weight, reshaped to ``CARTPOLE_DIMS``, serves as the Q-table,
    and its one-element bias stores the current exploration rate (epsilon),
    so both are exposed through ``parameters()``. Gradients are disabled on
    every parameter; the table is edited in place by ``update``.
    """

    def __init__(
        self,
        alpha: float,
        gamma: float,
        min_epsilon: float,
        epsilon_reduction: float,
    ) -> None:
        super().__init__()
        self.name = "q-learning"
        self.alpha = alpha  # learning rate
        self.gamma = gamma  # discount rate
        self.min_epsilon = min_epsilon  # floor for the exploration rate
        self.epsilon_reduction = epsilon_reduction  # subtracted per action
        self.network = nn.Linear(prod(CARTPOLE_DIMS), 1, bias=True)
        for param in self.parameters():
            param.requires_grad_(False)

    def get_epsilon(self, train: bool) -> float:
        """Return the current epsilon, decaying it first when *train* is true.

        In training mode the stored value is lowered by ``epsilon_reduction``
        (never below ``min_epsilon``) and written back into the bias before
        being read.
        """
        bias = self.network.bias
        if train:
            decayed = max(self.min_epsilon, bias.item() - self.epsilon_reduction)
            bias.data = T.tensor([decayed], requires_grad=False).data

        return bias.item()

    def get_q_values_for_observation(
        self, observation: CartPoleObservation
    ) -> T.Tensor:
        """Return the per-action Q-values for the cell *observation* falls in."""
        q_table = self.network.weight.reshape(CARTPOLE_DIMS)
        # Indexing with the four bin indices leaves only the action axis.
        return q_table[discretize(observation)]

    def act(self, observation: CartPoleObservation, train: bool) -> int:
        """Pick an action: random with probability epsilon, else the greedy one."""
        exploring = random.random() < self.get_epsilon(train)
        if not exploring:
            return int(self.get_q_values_for_observation(observation).argmax())

        return random.randrange(2)

    def update(
        self,
        observation: CartPoleObservation,
        action: int,
        reward: float,
        observation_next: CartPoleObservation,
    ) -> None:
        """Apply one Q-learning update for the observed transition.

        Moves ``Q[observation][action]`` toward
        ``reward + gamma * max(Q[observation_next])`` by a fraction ``alpha``.
        """
        q_values = self.get_q_values_for_observation(observation)
        best_next_value = self.get_q_values_for_observation(observation_next).max()
        td_target = reward + self.gamma * best_next_value
        current = q_values[action]
        q_values[action] = current + self.alpha * (td_target - current)

In [8]:
def run_iteration(agent: QLearningAgent, environment: gym.Env, train: bool) -> float:
    """Play one episode and return its total (undiscounted) reward.

    Works with both environment APIs: the older one where ``reset()`` returns
    the observation and ``step()`` a 4-tuple, and the newer one where
    ``reset()`` returns ``(observation, info)`` and ``step()`` a 5-tuple with
    separate ``terminated``/``truncated`` flags.

    Args:
        agent: Agent choosing the actions; updated after every step when
            *train* is true.
        environment: Environment to play in; it is reset at the start.
        train: Whether the agent explores/decays epsilon and learns.

    Returns:
        Sum of the rewards collected until the episode ends.
    """
    ret = 0.0
    reset_result = environment.reset()
    # Newer API: reset() yields (observation, info_dict).
    if (
        isinstance(reset_result, tuple)
        and len(reset_result) == 2
        and isinstance(reset_result[1], dict)
    ):
        observation = reset_result[0]
    else:
        observation = reset_result
    done = False

    while not done:
        action = agent.act(observation, train)
        step_result = environment.step(action)
        if len(step_result) == 5:
            # Newer API splits the end-of-episode signal in two.
            observation_next, reward, terminated, truncated, _ = step_result
            done = terminated or truncated
        else:
            observation_next, reward, done, _ = step_result
        ret += reward
        if train:
            agent.update(observation, action, reward, observation_next)
        observation = observation_next

    return ret

def run_epoch(
    n_iterations: int,
    agent: QLearningAgent,
    train: bool = True,
    period: int = 100,
) -> tuple[list[T.Tensor], list[float]]:
    """Play *n_iterations* CartPole episodes with *agent*.

    Args:
        n_iterations: Number of episodes to play.
        agent: Agent to run; it learns during the episodes when *train* is true.
        train: Passed through to ``run_iteration``.
        period: After every *period* episodes, print the average return of
            the last *period* episodes.

    Returns:
        ``(parameters, returns)``: the agent's parameters as a list and the
        per-episode returns in play order.
    """
    environment = gym.make("CartPole-v1")
    rets = []

    try:
        for i in range(n_iterations):
            ret = run_iteration(agent, environment, train)
            rets.append(ret)
            if (i + 1) % period == 0:
                print(
                    f"[federated {agent.name} agent] Epoch {i + 1} Average return per game: "
                    + f"{sum(rets[-period:]) / period} from {period} games"
                )
    finally:
        # Release the environment created above, even if an episode fails.
        environment.close()

    return list(agent.parameters()), rets

## Step 3: Authenticate for cycle

In [9]:
# Authenticate against the grid for this model name/version; the reply's
# worker_id identifies this data owner in the requests that follow.
auth_response = send_auth_request(GRID_ADDRESS, NAME, VERSION)
worker_id = auth_response["data"]["worker_id"]

## Step 4: Make cycle request

In [10]:
# Ask to join the current training cycle for this model.
cycle_response = send_cycle_request(GRID_ADDRESS, NAME, VERSION, worker_id)
# request_key and model_id are needed to download the model and report the diff.
request_key = cycle_response["data"]["request_key"]
model_id = cycle_response["data"]["model_id"]
# Training hyperparameters are supplied by the grid in client_config.
client_config = cycle_response["data"]["client_config"]
alpha = client_config["alpha"]
gamma = client_config["gamma"]
min_epsilon = client_config["min_epsilon"]
epsilon_reduction = client_config["epsilon_reduction"]
n_train_iterations = client_config["n_train_iterations"]
n_test_iterations = client_config["n_test_iterations"]

## Step 5: Download the model parameters and set local model parameters accordingly

In [11]:
# Kept unmodified for the rest of the notebook: the diff is computed against it.
downloaded_params = get_model_params(GRID_ADDRESS, worker_id, request_key, model_id)

In [12]:
# Build the local agent with the grid-supplied hyperparameters, then load
# copies of the downloaded parameters into it.
local_agent = QLearningAgent(
    alpha=alpha,
    gamma=gamma,
    min_epsilon=min_epsilon,
    epsilon_reduction=epsilon_reduction,
)
set_params(local_agent, downloaded_params)

## Step 6: Train the local model

In [13]:
# Baseline: evaluate the downloaded model without training (train=False).
_, pre_rets = run_epoch(n_test_iterations, local_agent, train=False)
print(f"Pre-training performance: {sum(pre_rets) / n_test_iterations}")

[federated q-learning agent] Epoch 100 Average return per game: 136.9 from 100 games
Pre-training performance: 136.9


In [14]:
# Train for the number of episodes requested by the grid's client_config
# rather than a hard-coded count.
trained_params, _ = run_epoch(n_train_iterations, local_agent, train=True)

[federated q-learning agent] Epoch 100 Average return per game: 181.49 from 100 games
[federated q-learning agent] Epoch 200 Average return per game: 181.41 from 100 games
[federated q-learning agent] Epoch 300 Average return per game: 196.04 from 100 games


In [15]:
# Evaluate the locally trained model without further training (train=False).
_, post_rets = run_epoch(n_test_iterations, local_agent, train=False)
print(f"Post-training performance: {sum(post_rets) / n_test_iterations}")

[federated q-learning agent] Epoch 100 Average return per game: 180.94 from 100 games
Post-training performance: 180.94


## Step 7: Calculate and send back the diff

In [16]:
# The diff is (downloaded - trained) per parameter; report it for this cycle.
diff = calculate_diff(downloaded_params, trained_params)
send_diff_report(GRID_ADDRESS, worker_id, request_key, diff)

## Step 8: Test updated remote model performance

In [17]:
# Fetch the latest checkpoint from the grid and load it into the local agent.
new_model_params = retrieve_model_params(GRID_ADDRESS, NAME, VERSION)
set_params(local_agent, new_model_params)

# Evaluate the retrieved model without training (train=False).
_, updated_rets = run_epoch(n_test_iterations, local_agent, train=False)
print(f"Updated model performance: {sum(updated_rets) / n_test_iterations}")

[federated q-learning agent] Epoch 100 Average return per game: 185.21 from 100 games
Updated model performance: 185.21


<div class="alert alert-info">
You can re-run this notebook to simulate multiple data owners contributing to the federated learning process.
</div>