# Data Owner

<div class="alert alert-info">
Before running this notebook, run the Data Scientist notebook to host a model in PyGrid.
</div>

In [1]:
import base64
import json
from typing import Any

import gym
import requests
import torch as T
from torch import nn
from websocket import create_connection

from syft import serialize
from syft.federated.model_serialization import (
    deserialize_model_params,
    wrap_model_params,
)

## Step 1: Specify configuration and helper functions

In [2]:
# Identity of the hosted model; sent with the authenticate and cycle requests,
# so it is expected to match what the Data Scientist notebook hosted.
NAME = "q-learning"
VERSION = "0.0.0"

# host:port of the PyGrid node, used for both the ws:// and http:// requests.
GRID_ADDRESS = "localhost:5000"

# Shape the flat Q-table is viewed as:
# (current sum, dealer card, usable ace, action).
BLACKJACK_DIMS = (32, 11, 2, 2)

In [3]:
def prettify(json_data: object) -> str:
    """Render ``json_data`` as indented JSON that is easy to read.

    ``json.dumps`` escapes newlines inside string values as a backslash
    followed by ``n``; those escapes are turned back into real line breaks so
    multi-line strings display as multiple lines.
    """
    rendered = json.dumps(json_data, indent=2)
    escaped_newline, real_newline = "\\n", "\n"
    return rendered.replace(escaped_newline, real_newline)

In [4]:
def set_params(model: nn.Module, params: list[T.Tensor]) -> None:
    """Overwrite ``model``'s parameters, in order, with copies of ``params``.

    Each ``nn.Parameter`` object is kept; only its underlying data is replaced
    by a detached clone. Later in-place updates to the model therefore never
    write back into the tensors in ``params``.

    Raises:
        ValueError: if ``params`` does not hold exactly one tensor per model
            parameter. Without this check ``zip`` would silently ignore the
            surplus on either side and leave the model partly un-updated.
    """
    model_params = list(model.parameters())
    if len(model_params) != len(params):
        raise ValueError(
            f"Expected {len(model_params)} parameter tensors, got {len(params)}"
        )
    for p, p_new in zip(model_params, params):
        p.data = p_new.detach().clone().data


def calculate_diff(
    original_params: list[T.Tensor], trained_params: list[T.Tensor]
) -> list[T.Tensor]:
    """Return the element-wise difference ``original - trained`` per parameter.

    The result is what gets reported back to PyGrid. The sign convention
    (original minus trained) is preserved; presumably the server expects
    exactly this orientation — confirm against the aggregation code before
    changing it.

    Raises:
        ValueError: if the two lists differ in length, which would otherwise
            produce a silently truncated diff.
    """
    if len(original_params) != len(trained_params):
        raise ValueError(
            f"Parameter count mismatch: {len(original_params)} original vs "
            f"{len(trained_params)} trained"
        )
    return [old - new for old, new in zip(original_params, trained_params)]

In [5]:
def send_ws_message(grid_address: str, data: object) -> Any:
    """Send one JSON message over a fresh WebSocket and return the decoded reply.

    Opens ``ws://<grid_address>``, sends ``data`` serialized as JSON, waits for
    a single reply and parses it as JSON. The connection is closed in a
    ``finally`` block, so each call releases its socket even if sending or
    receiving raises.
    """
    ws = create_connection("ws://" + grid_address)
    try:
        ws.send(json.dumps(data))
        message = ws.recv()
    finally:
        ws.close()
    return json.loads(message)


def get_model_params(
    grid_address: str, worker_id: str, request_key: str, model_id: str
) -> list[T.Tensor]:
    """Download this cycle's model parameters from the ``get-model`` endpoint.

    Args:
        grid_address: ``host:port`` of the PyGrid node.
        worker_id: Worker id returned by the authenticate request.
        request_key: Key returned by the cycle request.
        model_id: Model id returned by the cycle request.

    Returns:
        The deserialized parameter tensors.

    Raises:
        requests.HTTPError: if the server answers with an error status, rather
            than passing an error body to the deserializer.
    """
    get_params = {
        "worker_id": worker_id,
        "request_key": request_key,
        "model_id": model_id,
    }
    response = requests.get(
        f"http://{grid_address}/model-centric/get-model", params=get_params
    )
    # An error page is not a serialized model; fail with the HTTP status instead.
    response.raise_for_status()
    return deserialize_model_params(response.content)


def retrieve_model_params(grid_address: str, name: str, version: str) -> list[T.Tensor]:
    """Fetch the parameters of the latest checkpoint of a hosted model.

    Args:
        grid_address: ``host:port`` of the PyGrid node.
        name: Hosted model name.
        version: Hosted model version.

    Returns:
        The deserialized parameter tensors of the ``"latest"`` checkpoint.

    Raises:
        requests.HTTPError: if the server answers with an error status, rather
            than passing an error body to the deserializer.
    """
    get_params = {
        "name": name,
        "version": version,
        "checkpoint": "latest",
    }

    response = requests.get(
        f"http://{grid_address}/model-centric/retrieve-model", params=get_params
    )
    # An error page is not a serialized model; fail with the HTTP status instead.
    response.raise_for_status()
    return deserialize_model_params(response.content)


def send_auth_request(grid_address: str, name: str, version: str) -> Any:
    """Send a ``model-centric/authenticate`` message and return the reply.

    The message names the hosted model and version to authenticate against.
    The reply's ``data`` section carries the ``worker_id`` used by later calls.
    """
    model_ref = {"model_name": name, "model_version": version}
    return send_ws_message(
        grid_address, {"type": "model-centric/authenticate", "data": model_ref}
    )


def send_cycle_request(
    grid_address: str, name: str, version: str, worker_id: str
) -> Any:
    """Ask the node for a training cycle on ``name``/``version``; return the reply.

    The reply is expected to carry the ``request_key``, ``model_id`` and
    ``client_config`` that the rest of the notebook reads.
    """
    cycle_data = {
        "worker_id": worker_id,
        "model": name,
        "version": version,
        # Presumably the worker's self-reported network metrics; fixed values here.
        "ping": 1,
        "download": 10000,
        "upload": 10000,
    }
    request = {"type": "model-centric/cycle-request", "data": cycle_data}
    return send_ws_message(grid_address, request)


def send_diff_report(
    grid_address: str, worker_id: str, request_key: str, diff: list[T.Tensor]
) -> Any:
    """Report a parameter diff for the current cycle and return the server's reply.

    ``diff`` is wrapped with ``wrap_model_params`` and turned into bytes with
    syft's ``serialize(...).SerializeToString()``. The bytes are then base64
    encoded so they can travel inside the JSON message.

    Returns:
        The decoded JSON reply. It is returned (as the ``Any`` annotation and
        the sibling ``send_*`` helpers suggest), not discarded, so callers can
        inspect the outcome of the report.
    """
    serialized_diff = serialize(wrap_model_params(diff)).SerializeToString()
    message = {
        "type": "model-centric/report",
        "data": {
            "worker_id": worker_id,
            "request_key": request_key,
            "diff": base64.b64encode(serialized_diff).decode("ascii"),
        },
    }
    return send_ws_message(grid_address, message)

## Step 2: Define the model and training loop

In [6]:
# Blackjack observation as unpacked by QLearningAgent.get_q_values_for_observation:
# (current sum, dealer card, usable ace).
BlackjackObservation = tuple[int, int, bool]

class QLearningAgent(nn.Module):
    """Tabular Q-learning agent for Blackjack.

    The Q-table is stored as the weight of a bias-free ``nn.Linear`` layer,
    with shape ``(1, prod(BLACKJACK_DIMS))``. Presumably this is so it can be
    exchanged with PyGrid as an ordinary model parameter. The table is always
    accessed reshaped to ``BLACKJACK_DIMS``: (current sum, dealer card,
    usable ace, action). Gradients are disabled because the Q-learning rule
    updates the table in place, not backpropagation.

    Args:
        alpha: Learning rate of the Q-learning update.
        gamma: Discount factor applied to the bootstrapped next-state value.
    """

    def __init__(self, alpha: float, gamma: float) -> None:
        super().__init__()
        self.alpha = alpha  # learning rate
        self.gamma = gamma  # discount rate
        # Size the table from BLACKJACK_DIMS so it always matches the reshape
        # in get_q_values_for_observation (32 * 11 * 2 * 2 = 1408 entries).
        self.network = nn.Linear(T.Size(BLACKJACK_DIMS).numel(), 1, bias=False)
        for p in self.parameters():
            p.requires_grad = False
        # Every state-action value starts at zero.
        nn.init.zeros_(self.network.weight)

    def act(self, observation: BlackjackObservation) -> int:
        """Return the greedy action (highest Q-value) for ``observation``.

        No exploration is performed. ``argmax`` returns the first maximal
        index, so ties (e.g. an all-zero row) resolve to action 0.
        """
        output = self.get_q_values_for_observation(observation)
        return int(output.argmax())

    def get_q_values_for_observation(
        self, observation: BlackjackObservation
    ) -> T.Tensor:
        """Return the Q-values of every action for ``observation``.

        ``update`` writes into the returned tensor. That relies on ``reshape``
        plus integer indexing producing a view of the weight, which holds
        while the weight is contiguous.
        """
        q_table = self.network.weight.reshape(BLACKJACK_DIMS)
        current_sum, dealer_card, usable_ace = observation
        return q_table[current_sum][dealer_card][int(usable_ace)]

    def update(
        self,
        observation: BlackjackObservation,
        action: int,
        reward: float,
        observation_next: BlackjackObservation,
        done: bool = False,
    ) -> None:
        """Apply one Q-learning step to ``Q(observation, action)`` in place.

        The target and update are::

            target = reward                                         if done
            target = reward + gamma * max_a Q(observation_next, a)  otherwise
            Q(s, a) <- Q(s, a) + alpha * (target - Q(s, a))

        Args:
            observation: State the action was taken in.
            action: Index of the action taken.
            reward: Reward received for the transition.
            observation_next: Observation returned after the action.
            done: Whether the transition ended the episode. A terminal
                transition has no successor value, so it should not bootstrap
                from ``observation_next``. Defaults to ``False``, which
                reproduces the previous always-bootstrap behaviour.
        """
        q_values = self.get_q_values_for_observation(observation)
        if done:
            target = reward
        else:
            max_next_q_value = self.get_q_values_for_observation(
                observation_next
            ).max()
            target = reward + self.gamma * max_next_q_value
        # In-place write through the view updates the underlying weight.
        q_values[action] = q_values[action] + self.alpha * (
            target - q_values[action]
        )

In [7]:
def run_iteration(agent: QLearningAgent, environment: gym.Env, train: bool) -> float:
    """Play one episode and return the sum of its rewards (undiscounted).

    Uses the classic gym API: ``reset()`` returns the observation and
    ``step()`` returns ``(observation, reward, done, info)``. When ``train`` is
    true, the agent is updated after every step.
    """
    total_reward = 0.0
    state = environment.reset()
    while True:
        action = agent.act(state)
        next_state, reward, done, _ = environment.step(action)
        total_reward += reward
        if train:
            # NOTE(review): `done` is not passed to the agent, so this update
            # bootstraps from next_state even when the episode has just ended.
            # TODO: make terminal updates use the reward alone.
            agent.update(state, action, reward, next_state)
        state = next_state
        if done:
            return total_reward

def run_epoch(
    n_iterations: int, agent: QLearningAgent, train: bool = True
) -> tuple[list[T.Tensor], list[float]]:
    """Run ``n_iterations`` episodes of ``Blackjack-v0`` with ``agent``.

    Args:
        n_iterations: Number of episodes to play.
        agent: Agent that acts (and, if ``train``, learns).
        train: Whether to update the agent after every step.

    Returns:
        ``(params, returns)``. ``params`` are detached copies of the agent's
        parameters after the last episode, and ``returns`` holds the return of
        each episode. Copies are returned because the live ``Parameter``
        objects keep changing under later training or ``set_params`` calls.
        That would silently alter a snapshot such as ``trained_params`` that a
        caller holds on to.
    """
    environment = gym.make("Blackjack-v0")
    try:
        rets = [run_iteration(agent, environment, train) for _ in range(n_iterations)]
    finally:
        environment.close()
    params = [p.detach().clone() for p in agent.parameters()]
    return params, rets

## Step 3: Authenticate with the PyGrid node

In [8]:
# The worker_id from this reply identifies us in the cycle request, the model
# download and the diff report below.
auth_response = send_auth_request(grid_address=GRID_ADDRESS, name=NAME, version=VERSION)
worker_id = auth_response["data"]["worker_id"]

## Step 4: Request a training cycle

In [9]:
cycle_response = send_cycle_request(
    grid_address=GRID_ADDRESS, name=NAME, version=VERSION, worker_id=worker_id
)
# NOTE(review): these lookups assume the cycle was granted; if the server
# replies without these keys (rejection format unverified), a KeyError follows.
request_key = cycle_response["data"]["request_key"]
model_id = cycle_response["data"]["model_id"]

# Training hyperparameters delivered with the cycle (presumably configured by
# the Data Scientist when hosting the model).
client_config = cycle_response["data"]["client_config"]
alpha = client_config["alpha"]
gamma = client_config["gamma"]
n_train_iterations = client_config["n_train_iterations"]
n_test_iterations = client_config["n_test_iterations"]

## Step 5: Download the model parameters and load them into the local model

In [10]:
downloaded_params = get_model_params(
    grid_address=GRID_ADDRESS,
    worker_id=worker_id,
    request_key=request_key,
    model_id=model_id,
)

In [11]:
# set_params copies the tensors, so downloaded_params stays untouched by
# training and can serve as the baseline for the diff in Step 7.
local_agent = QLearningAgent(alpha=alpha, gamma=gamma)
set_params(model=local_agent, params=downloaded_params)

## Step 6: Train the local model

In [12]:
# Baseline: mean episode return of the downloaded model, without learning.
_, pre_rets = run_epoch(n_iterations=n_test_iterations, agent=local_agent, train=False)
print(f"Pre-training performance: {sum(pre_rets) / n_test_iterations}")

Pre-training performance: -0.2056


In [13]:
trained_params, _ = run_epoch(
    n_iterations=n_train_iterations, agent=local_agent, train=True
)

In [14]:
# Same evaluation as the baseline, now with the locally trained Q-table.
_, post_rets = run_epoch(n_iterations=n_test_iterations, agent=local_agent, train=False)
print(f"Post-training performance: {sum(post_rets) / n_test_iterations}")

Post-training performance: -0.1172


## Step 7: Calculate and send back the diff

In [15]:
diff = calculate_diff(original_params=downloaded_params, trained_params=trained_params)
send_diff_report(
    grid_address=GRID_ADDRESS, worker_id=worker_id, request_key=request_key, diff=diff
)

## Step 8: Evaluate the latest checkpoint of the remote model

In [16]:
# Replace the locally trained weights with the node's latest checkpoint.
new_model_params = retrieve_model_params(grid_address=GRID_ADDRESS, name=NAME, version=VERSION)
set_params(model=local_agent, params=new_model_params)

_, updated_rets = run_epoch(n_iterations=n_test_iterations, agent=local_agent, train=False)
print(f"Updated model performance: {sum(updated_rets) / n_test_iterations}")

Updated model performance: -0.1336


<div class="alert alert-info">
You can re-run this notebook to simulate multiple data owners contributing to the federated learning process.
</div>