# Federated Reinforcement Learning for Recommendation Systems

This notebook implements a federated reinforcement learning system for recommendations.

## Environment Setup

First, let's install RecSim and import all required packages.

In [None]:
!pip install recsim-v2

Collecting recsim-v2
  Downloading recsim_v2-0.2.7.tar.gz (71 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.8/71.8 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gin-config-v2 (from recsim-v2)
  Downloading gin_config_v2-0.8.0.tar.gz (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: recsim-v2, gin-config-v2
  Building wheel for recsim-v2 (setup.py) ... [?25l[?25hdone
  Created wheel for recsim-v2: filename=recsim_v2-0.2.7-py3-none-any.whl size=109746 sha256=f7d64d5adfee4231465a0650c191a1cc9897e029d65aac3e204f3a888359f7b7
  Stored in directory: /root/.cache/pip/wheels/66/0c/8d/b2de8b95d998c4e167919fd3c0f2101d3fb36523a4105dda8d
  Building wheel for gin-config-v2 (setup.py) ... [?25l[?25hdone
  Created wheel for gin-co

In [None]:
import torch
import numpy as np
import random
import datetime
import time
from collections import deque
from pathlib import Path
from torch import nn
from torch.nn.utils import weight_norm
from gym import spaces
from scipy import stats
import matplotlib.pyplot as plt

# RecSim imports
from recsim import document
from recsim import user
from recsim.choice_model import MultinomialLogitChoiceModel
from recsim.simulator import environment
from recsim.simulator import recsim_gym

  from jax import xla_computation as _xla_computation


## Replay Memory Implementation

Implementation of experience replay buffer for storing transitions

In [None]:
class ReplayMemory():
    """Fixed-capacity ring buffer of transitions held in preallocated tensors.

    Every field of a transition (state, action, reward, click, next state,
    terminal flag) lives in its own tensor whose first axis is the slot index.
    Once ``capacity`` transitions have been written, new ones overwrite the
    oldest slots.
    """

    def __init__(self, capacity, state_shape, action_shape):
        """Allocate zero-filled storage for ``capacity`` transitions.

        Args:
            capacity: number of slots in the buffer.
            state_shape: tuple giving the shape of one state.
            action_shape: tuple giving the shape of one action (the click
                buffer uses the same shape).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.capacity = capacity

        per_state = (capacity,) + state_shape
        per_action = (capacity,) + action_shape
        per_scalar = (capacity,)

        self.state_memory = torch.zeros(per_state, device=self.device)
        self.action_memory = torch.zeros(per_action, device=self.device)
        self.reward_memory = torch.zeros(per_scalar, device=self.device)
        self.next_state_memory = torch.zeros(per_state, device=self.device)
        self.terminals_memory = torch.zeros(per_scalar, dtype=torch.bool, device=self.device)
        self.click_memory = torch.zeros(per_action, dtype=torch.int, device=self.device)

        # Write cursor and "has wrapped at least once" flag.
        self.position = 0
        self.full = False

    def push(self, state, action, reward, click, next_state, done):
        """Write one transition at the cursor, then advance the cursor."""
        slot = self.position
        for storage, value in (
            (self.state_memory, state),
            (self.action_memory, action),
            (self.reward_memory, reward),
            (self.click_memory, click),
            (self.next_state_memory, next_state),
            (self.terminals_memory, done),
        ):
            storage[slot] = value

        self.position = (slot + 1) % self.capacity
        # The cursor returning to 0 means every slot has been written once.
        if self.position == 0:
            self.full = True

    def recall(self, indices):
        """Return ``(states, actions, rewards, clicks, next_states, terminals)`` at ``indices``."""
        ordered_storage = (
            self.state_memory,
            self.action_memory,
            self.reward_memory,
            self.click_memory,
            self.next_state_memory,
            self.terminals_memory,
        )
        return tuple(storage[indices] for storage in ordered_storage)

    def __len__(self):
        """Number of valid transitions currently stored."""
        if self.full:
            return self.capacity
        return self.position

## Neural Network Models

Implementation of the Q-Network and MLP Network architectures used in the recommendation system.

In [None]:
class QNet(nn.Module):
    """Pair of identically shaped MLPs: a trainable ``online`` net and a frozen ``target`` net.

    Each network is ``Linear -> Mish`` repeated five times followed by
    ``Linear -> Tanh``, so outputs lie in (-1, 1).

    NOTE(review): the two networks are initialised independently; nothing in
    this class copies the online weights into the target. Confirm that the
    caller synchronises them before the target is first used.

    Args:
        input_dim: size of the input feature vector.
        output_dim: number of outputs.
        hidden_dim: width of every hidden layer (default 4096, the value that
            was previously hard-coded).
    """

    # Number of hidden-to-hidden Linear layers between the input and output layers.
    _NUM_INNER_LAYERS = 4

    def __init__(self, input_dim, output_dim, hidden_dim=4096):
        super().__init__()
        # Built in this order (online first) so random initialisation consumes
        # the RNG exactly as the hand-written version did.
        self.online = self._build_network(input_dim, output_dim, hidden_dim)
        self.target = self._build_network(input_dim, output_dim, hidden_dim)

        # The target network is never trained directly.
        self.target.eval()
        for p in self.target.parameters():
            p.requires_grad = False

    @classmethod
    def _build_network(cls, input_dim, output_dim, hidden_dim):
        """Build one MLP; layer indices match the original hand-written Sequential."""
        layers = [nn.Linear(input_dim, hidden_dim), nn.Mish()]
        for _ in range(cls._NUM_INNER_LAYERS):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.Mish()]
        layers += [nn.Linear(hidden_dim, output_dim), nn.Tanh()]
        return nn.Sequential(*layers)

    def forward(self, inputs, model):
        """Run ``inputs`` through the selected network.

        Args:
            inputs: input tensor whose last dimension is ``input_dim``.
            model: ``"online"`` or ``"target"``.

        Raises:
            ValueError: if ``model`` is neither ``"online"`` nor ``"target"``
                (previously this silently returned ``None``).
        """
        if model == "online":
            return self.online(inputs)
        if model == "target":
            return self.target(inputs)
        raise ValueError(f"Unknown model {model!r}; expected 'online' or 'target'.")

In [None]:
class MLPNet(nn.Module):
    """Pair of identically shaped MLPs: a trainable ``online`` net and a frozen ``target`` net.

    Each network is ``Linear -> Mish`` repeated four times followed by
    ``Linear -> Tanh``, so outputs lie in (-1, 1).

    NOTE(review): the two networks are initialised independently; nothing in
    this class copies the online weights into the target. Confirm that the
    caller synchronises them before the target is first used.

    Args:
        input_dim: size of the input feature vector.
        output_dim: number of outputs.
        hidden_dim: width of every hidden layer (default 2048, the value that
            was previously hard-coded).
    """

    # Number of hidden-to-hidden Linear layers between the input and output layers.
    _NUM_INNER_LAYERS = 3

    def __init__(self, input_dim, output_dim, hidden_dim=2048):
        super().__init__()
        # Built in this order (online first) so random initialisation consumes
        # the RNG exactly as the hand-written version did.
        self.online = self._build_network(input_dim, output_dim, hidden_dim)
        self.target = self._build_network(input_dim, output_dim, hidden_dim)

        # The target network is never trained directly.
        self.target.eval()
        for p in self.target.parameters():
            p.requires_grad = False

    @classmethod
    def _build_network(cls, input_dim, output_dim, hidden_dim):
        """Build one MLP; layer indices match the original hand-written Sequential."""
        layers = [nn.Linear(input_dim, hidden_dim), nn.Mish()]
        for _ in range(cls._NUM_INNER_LAYERS):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.Mish()]
        layers += [nn.Linear(hidden_dim, output_dim), nn.Tanh()]
        return nn.Sequential(*layers)

    def forward(self, inputs, model):
        """Run ``inputs`` through the selected network.

        Args:
            inputs: input tensor whose last dimension is ``input_dim``.
            model: ``"online"`` or ``"target"``.

        Raises:
            ValueError: if ``model`` is neither ``"online"`` nor ``"target"``
                (previously this silently returned ``None``).
        """
        if model == "online":
            return self.online(inputs)
        if model == "target":
            return self.target(inputs)
        raise ValueError(f"Unknown model {model!r}; expected 'online' or 'target'.")

## Slate Q-Learning Implementation

This section implements the SlateQ class which handles slate-based Q-learning for recommendation selection.

In [None]:
class SlateQ():
    """Slate-level helpers for slate-based Q-learning.

    Provides document scoring, click probabilities within a slate, greedy
    slate construction, and the bootstrapped target value for a batch.

    All 1-D inputs are flattened with ``reshape(-1)`` rather than
    ``squeeze()``; ``squeeze()`` turns a one-element tensor into a 0-dim
    tensor, which broke every shape check when ``slate_size == 1``.
    """

    def __init__(self, user_features, doc_features, num_of_candidates, slate_size, batch_size):
        """Store the problem dimensions and pick the torch device (CUDA if available)."""
        self.user_features = user_features
        self.num_of_candidates = num_of_candidates
        self.doc_features = doc_features
        self.slate_size = slate_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size

    def score_documents_torch(self, user_obs, doc_obs, no_click_mass=1.0, is_mnl=True, min_normalizer=-1.0):
        """Score every candidate document for one user.

        Args:
            user_obs: tensor with ``user_features`` elements.
            doc_obs: tensor with ``num_of_candidates`` elements.
            no_click_mass: raw score appended for the "no click" option.
            is_mnl: if True, softmax the scores (including the no-click entry);
                otherwise subtract ``min_normalizer`` from each of them.
            min_normalizer: offset used when ``is_mnl`` is False.

        Returns:
            ``(doc_scores, no_click_score)``: a ``[num_of_candidates]`` tensor
            and a 0-dim tensor.
        """
        user_obs = user_obs.reshape(-1)
        doc_obs = doc_obs.reshape(-1)
        assert user_obs.shape == torch.Size([self.user_features])
        assert doc_obs.shape == torch.Size([self.num_of_candidates])

        # Outer product (candidates x user features), summed over the user axis.
        scores = torch.sum(doc_obs.view(-1, 1) * user_obs.view(1, -1), dim=1)

        all_scores = torch.cat([scores, torch.tensor([no_click_mass], device=self.device)], dim=0)

        if is_mnl:
            all_scores = torch.nn.functional.softmax(all_scores, dim=0)
        else:
            all_scores = all_scores - min_normalizer

        assert all_scores.shape == torch.Size([self.num_of_candidates + 1])
        return all_scores[:-1], all_scores[-1]

    def compute_probs_torch(self, slate, scores_torch, score_no_click_torch):
        """Click probability of each slate item, normalised over the slate plus no-click.

        Args:
            slate: integer tensor holding ``slate_size`` candidate indices.
            scores_torch: tensor of ``num_of_candidates`` document scores.
            score_no_click_torch: tensor holding the single no-click score.

        Returns:
            ``[slate_size]`` tensor of probabilities (the no-click entry is dropped).
        """
        slate = slate.reshape(-1)
        scores_torch = scores_torch.reshape(-1)
        assert slate.shape == torch.Size([self.slate_size])
        assert scores_torch.shape == torch.Size([self.num_of_candidates])

        all_scores = torch.cat([
            torch.gather(scores_torch, 0, slate).reshape(-1),
            score_no_click_torch.reshape(-1)
        ], dim=0)

        all_probs = all_scores / torch.sum(all_scores)
        assert all_probs.shape == torch.Size([self.slate_size + 1])
        return all_probs[:-1]

    def select_slate_greedy(self, s_no_click, s, q):
        """Greedily build a slate of ``slate_size`` distinct candidates.

        At each step the still-available candidate maximising
        ``(numerator + s*q) / (denominator + s)`` is added, where numerator and
        denominator accumulate ``s*q`` and ``s`` of the items already chosen
        (the denominator starts at ``s_no_click``).

        Args:
            s_no_click: no-click score (scalar or 0-dim tensor).
            s: tensor of ``num_of_candidates`` document scores.
            q: tensor of ``num_of_candidates`` Q-values.

        Returns:
            ``[slate_size]`` integer tensor of chosen indices in ascending
            index order (not in the order they were picked).
        """
        s = s.reshape(-1)
        q = q.reshape(-1)
        assert s.shape == torch.Size([self.num_of_candidates])
        assert q.shape == torch.Size([self.num_of_candidates])

        weighted = s * q
        numerator = torch.tensor(0., device=self.device)
        denominator = torch.tensor(0., device=self.device) + s_no_click
        # 1.0 marks a candidate that may still be picked, 0.0 one already in the slate.
        available = torch.ones(q.size(0), device=self.device)

        for _ in range(self.slate_size):
            gain = (numerator + weighted) / (denominator + s)
            # Shift so every available candidate scores >= 1 while masked ones
            # score exactly 0; the argmax can then never pick a masked index.
            k = torch.argmax((gain - torch.min(gain) + 1) * available, dim=0)
            picked = torch.nn.functional.one_hot(k, available.shape[0]).bool()
            available = available.masked_fill(picked, 0.)
            numerator = numerator + weighted[k]
            denominator = denominator + s[k]

        # reshape(-1) keeps a one-item slate 1-D (squeeze() would make it 0-dim).
        output_slate = torch.where(available == 0)[0].reshape(-1)
        assert output_slate.shape == torch.Size([self.slate_size])
        return output_slate

    def compute_target_greedy_q(self, reward, gamma, next_q_values, next_states, terminals):
        """Bootstrapped targets ``reward + gamma * Q(next greedy slate)`` for a batch.

        For each sample the next-state slate is chosen greedily and its value is
        the click-probability-weighted sum of the selected items' Q-values.
        Terminal transitions contribute no bootstrap term.

        Args:
            reward: ``[batch_size]`` tensor.
            gamma: discount factor.
            next_q_values: ``[batch_size, num_of_candidates]`` tensor.
            next_states: ``[batch_size, state_dim]`` tensor laid out as user
                features followed by document features.
            terminals: ``[batch_size]`` tensor of done flags.

        Returns:
            ``[batch_size]`` tensor of target Q-values.
        """
        assert reward.shape == torch.Size([self.batch_size])
        assert next_q_values.shape == torch.Size([self.batch_size, self.num_of_candidates])

        doc_end = self.user_features + self.num_of_candidates * self.doc_features
        next_user_obs = next_states[:, :self.user_features]
        next_doc_obs = next_states[:, self.user_features:doc_end]

        assert next_user_obs.shape == torch.Size([self.batch_size, self.user_features])
        # This check only holds when doc_features == 1.
        assert next_doc_obs.shape == torch.Size([self.batch_size, self.num_of_candidates])

        next_greedy_q_list = []
        for i in range(self.batch_size):
            s, s_no_click = self.score_documents_torch(next_user_obs[i], next_doc_obs[i])
            q = next_q_values[i]
            slate = self.select_slate_greedy(s_no_click, s, q)
            p_selected = self.compute_probs_torch(slate, s, s_no_click)
            q_selected = torch.gather(q, 0, slate)
            next_greedy_q_list.append(torch.sum(p_selected * q_selected))

        next_greedy_q_values = torch.stack(next_greedy_q_list)
        target_q_values = reward + gamma * next_greedy_q_values * (1. - terminals.float())

        assert target_q_values.shape == torch.Size([self.batch_size])
        return target_q_values

## Recommendation Environment Implementation

This section implements the recommendation environment classes including document and user state handling.

In [None]:
class LTSDocument(document.AbstractDocument):
    """A document described by a single scalar feature, its ``kaleness``."""

    def __init__(self, doc_id, kaleness):
        # The feature is stored before the base initializer runs, as in the
        # original ordering.
        self.kaleness = kaleness
        super().__init__(doc_id)

    def create_observation(self):
        """Return the document's observable features as a one-element array."""
        features = [self.kaleness]
        return np.array(features)

    @staticmethod
    def observation_space():
        """Gym space of one document observation: a single value in [0, 1]."""
        return spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)

    def __str__(self):
        template = "Document {} with kaleness {}."
        return template.format(self._doc_id, self.kaleness)

class LTSDocumentSampler(document.AbstractDocumentSampler):
    """Samples documents with sequential ids and uniformly random kaleness."""

    def __init__(self, seed, doc_ctor=LTSDocument, **kwargs):
        """Create a sampler whose random stream is seeded with ``seed``.

        Args:
            seed: seed for the sampler's own ``RandomState``.
            doc_ctor: document class to instantiate for each sample.
            **kwargs: forwarded to the base sampler.
        """
        super().__init__(doc_ctor, **kwargs)
        self.seed = seed
        # Assigned after the base initializer so this generator is the one in use.
        self._rng = np.random.RandomState(self.seed)
        self._doc_count = 0

    def sample_document(self):
        """Return a new document with the next id and a kaleness drawn from [0, 1)."""
        doc_features = {
            'doc_id': self._doc_count,
            'kaleness': self._rng.random_sample(),
        }
        self._doc_count += 1
        return self._doc_ctor(**doc_features)

class LTSUserState(user.AbstractUserState):
    """Latent user state: net kaleness exposure, satisfaction and remaining time budget."""

    def __init__(self, memory_discount, sensitivity, innovation_stddev,
                 choc_mean, choc_stddev, kale_mean, kale_stddev,
                 net_kaleness_exposure, time_budget, observation_noise_stddev=0.1):
        """Store the model parameters and derive the initial satisfaction.

        Args:
            memory_discount, sensitivity, innovation_stddev: transition-model
                parameters.
            choc_mean, choc_stddev, kale_mean, kale_stddev: engagement
                distribution parameters.
            net_kaleness_exposure: initial exposure; satisfaction starts at
                ``sigmoid(sensitivity * net_kaleness_exposure)``.
            time_budget: initial session budget.
            observation_noise_stddev: standard deviation of the noise added to
                the satisfaction in ``create_observation``; 0 disables it.
        """
        # Transition model parameters
        self.memory_discount = memory_discount
        self.sensitivity = sensitivity
        self.innovation_stddev = innovation_stddev

        # Engagement parameters
        self.choc_mean = choc_mean
        self.choc_stddev = choc_stddev
        self.kale_mean = kale_mean
        self.kale_stddev = kale_stddev

        # State variables
        self.net_kaleness_exposure = net_kaleness_exposure
        self.satisfaction = 1 / (1 + np.exp(-sensitivity * net_kaleness_exposure))
        self.time_budget = time_budget
        self._observation_noise = observation_noise_stddev

    def create_observation(self):
        """Return the satisfaction plus truncated Gaussian noise, as a one-element array.

        The latent state is not exposed directly; only this noisy satisfaction is.
        The noise has standard deviation ``observation_noise_stddev`` and is
        truncated at +/- (1 / stddev) standard deviations, i.e. to [-1, 1].
        With a standard deviation of 0 the exact satisfaction is returned
        (previously this case raised ``ZeroDivisionError``).
        """
        if self._observation_noise == 0:
            return np.array([self.satisfaction, ])

        # truncnorm takes its clip points in units of the standard deviation.
        clip_high = 1.0 / (1.0 * self._observation_noise)
        clip_low = -clip_high
        noise = stats.truncnorm(clip_low, clip_high, loc=0.0,
                                scale=self._observation_noise).rvs()
        noisy_sat = self.satisfaction + noise
        return np.array([noisy_sat, ])

    @staticmethod
    def observation_space():
        """Gym space of the user observation: a single value in [-2, 2]."""
        return spaces.Box(shape=(1,), dtype=np.float32, low=-2.0, high=2.0)

    def score_document(self, doc_obs):
        """Score a document observation as ``1 - doc_obs``."""
        return 1 - doc_obs

class LTSStaticUserSampler(user.AbstractUserSampler):
    """Samples users that share fixed parameters and differ only in initial exposure."""

    _state_parameters = None

    def __init__(self,
                 user_ctor=LTSUserState,
                 memory_discount=0.9,
                 sensitivity=0.01,
                 innovation_stddev=0.05,
                 choc_mean=5.0,
                 choc_stddev=1.0,
                 kale_mean=4.0,
                 kale_stddev=1.0,
                 time_budget=122,
                 **kwargs):
        """Record the shared user parameters, then initialise the base sampler.

        ``**kwargs`` is forwarded unchanged to the base sampler.
        """
        self._state_parameters = dict(
            memory_discount=memory_discount,
            sensitivity=sensitivity,
            innovation_stddev=innovation_stddev,
            choc_mean=choc_mean,
            choc_stddev=choc_stddev,
            kale_mean=kale_mean,
            kale_stddev=kale_stddev,
            time_budget=time_budget,
        )
        super().__init__(user_ctor, **kwargs)

    def sample_user(self):
        """Draw a starting exposure and build a user from the stored parameters.

        The draw is a uniform sample centred on zero and scaled by
        ``1 / (1 - memory_discount)``. It is written into the stored parameter
        dict, so that dict always holds the most recent draw.
        """
        params = self._state_parameters
        centred_draw = self._rng.random_sample() - .5
        exposure_scale = 1 / (1.0 - params['memory_discount'])
        params['net_kaleness_exposure'] = centred_draw * exposure_scale
        return self._user_ctor(**params)

class LTSResponse(user.AbstractResponse):
    """A user's response to one document: whether it was clicked and the engagement."""

    # Upper bound advertised by ``response_space`` for the engagement value.
    MAX_ENGAGEMENT_MAGNITUDE = 100.0

    def __init__(self, clicked=False, engagement=0.0):
        self.clicked, self.engagement = clicked, engagement

    def create_observation(self):
        """Return ``{'click': 0 or 1, 'engagement': 0-dim array}``."""
        observation = {}
        observation['click'] = int(self.clicked)
        observation['engagement'] = np.array(self.engagement)
        return observation

    @classmethod
    def response_space(cls):
        """Gym space describing one response observation."""
        click_space = spaces.Discrete(2)
        engagement_space = spaces.Box(
            low=0.0,
            high=cls.MAX_ENGAGEMENT_MAGNITUDE,
            shape=tuple(),
            dtype=np.float32)
        return spaces.Dict({
            'click': click_space,
            'engagement': engagement_space,
        })

## RecSim Environment and User Model Implementation

This section implements the main recommendation environment and user model classes.

In [None]:
def user_init(self, slate_size, seed=0):
    """Initializer installed as ``LTSUserModel.__init__``.

    Builds a seeded static user sampler, hands it to the base user model
    together with the response class and slate size, and attaches a
    multinomial-logit choice model.
    """
    user_sampler = LTSStaticUserSampler(LTSUserState, seed=seed)
    # Explicit two-argument super(): this function is defined outside any class
    # body, so the zero-argument form is not available here.
    super(LTSUserModel, self).__init__(LTSResponse, user_sampler, slate_size)
    self.choice_model = MultinomialLogitChoiceModel({})

def simulate_response(self, slate_documents):
    """Simulate the user's response to a slate.

    One empty response is created per slate document. The choice model scores
    the documents against the current user state and picks at most one; the
    picked document's response is filled in by ``self._generate_response``.

    Args:
        slate_documents: documents shown to the user, in slate order.

    Returns:
        List of responses aligned with ``slate_documents``. If the choice
        model selects nothing, all responses are left unclicked.
    """
    # List of empty responses
    responses = [self._response_model_ctor() for _ in slate_documents]

    # Get click from choice model
    doc_observations = [doc.create_observation() for doc in slate_documents]
    self.choice_model.score_documents(self._user_state, doc_observations)
    selected_index = self.choice_model.choose_item()

    # A choice model may report "no click" as None; indexing with None would
    # raise a TypeError, so return the untouched responses instead.
    if selected_index is None:
        return responses

    # Populate clicked item
    self._generate_response(slate_documents[selected_index],
                            responses[selected_index])
    return responses

def generate_response(self, doc, response):
    """Mark ``response`` as clicked and sample its engagement.

    The log-engagement is drawn from a normal distribution. Its mean
    interpolates ``choc_mean`` (weight ``kaleness``) and ``kale_mean``
    (weight ``1 - kaleness``) and is then scaled by the user's satisfaction;
    its standard deviation interpolates ``choc_stddev`` and ``kale_stddev``
    with the same weights. The stored engagement is the exponential of the draw.
    """
    response.clicked = True

    state = self._user_state
    kaleness = doc.kaleness

    # linear interpolation between choc and kale
    mean_log = kaleness * state.choc_mean + (1 - kaleness) * state.kale_mean
    mean_log = mean_log * state.satisfaction
    std_log = kaleness * state.choc_stddev + ((1 - kaleness) * state.kale_stddev)

    response.engagement = np.exp(np.random.normal(loc=mean_log, scale=std_log))

def update_state(self, slate_documents, responses):
    """Advance the user state using the first clicked document, if any.

    On a click the net kaleness exposure becomes
    ``memory_discount * exposure - 2 * (kaleness - 0.5) + innovation`` with
    Gaussian innovation, the satisfaction is recomputed as a sigmoid of
    ``sensitivity * exposure``, and the time budget drops by one. When nothing
    was clicked the state (including the time budget) is left untouched.
    """
    clicked_doc = next(
        (doc for doc, response in zip(slate_documents, responses) if response.clicked),
        None,
    )
    if clicked_doc is None:
        return

    state = self._user_state
    innovation = np.random.normal(scale=state.innovation_stddev)
    updated_exposure = (state.memory_discount * state.net_kaleness_exposure
                        - 2.0 * (clicked_doc.kaleness - 0.5)
                        + innovation)
    state.net_kaleness_exposure = updated_exposure
    state.satisfaction = 1 / (1.0 + np.exp(-state.sensitivity * updated_exposure))
    state.time_budget -= 1

def is_terminal(self):
    """Report whether the session has ended, i.e. the time budget is used up."""
    remaining_budget = self._user_state.time_budget
    return remaining_budget <= 0

def clicked_engagement_reward(responses):
    """Return the total engagement over clicked responses (0.0 if none were clicked)."""
    clicked_engagements = (response.engagement for response in responses if response.clicked)
    return sum(clicked_engagements, 0.0)

# Assemble the LTSUserModel class dynamically from the module-level functions
# defined above, with user.AbstractUserModel as its base class.
# user_init refers to the name `LTSUserModel` in its explicit super() call, so
# this assignment must have run before any instance is constructed.
LTSUserModel = type("LTSUserModel", (user.AbstractUserModel,),
                    {"__init__": user_init,
                     "is_terminal": is_terminal,
                     "update_state": update_state,
                     "simulate_response": simulate_response,
                     "_generate_response": generate_response})

class RecsimEnv():
    """Two RecSim gym environments (ids 0 and 1) built around one shared user model.

    Each environment has its own document sampler (seeded separately), while
    both are constructed with the same ``LTSUserModel`` instance.
    """

    def __init__(self, num_candidates, slate_size, resample_documents, env_seed_0, env_seed_1):
        """Build both environments.

        Args:
            num_candidates: number of candidate documents per step.
            slate_size: number of documents per slate (must not exceed
                ``num_candidates``).
            resample_documents: forwarded to ``environment.Environment``.
            env_seed_0: document-sampler seed for environment 0.
            env_seed_1: document-sampler seed for environment 1.
        """
        assert num_candidates >= slate_size
        self.num_candidates = num_candidates
        self.slate_size = slate_size
        self.resample_documents = resample_documents

        # Document models
        self.doc_model_1 = LTSDocumentSampler(env_seed_0)
        self.doc_model_2 = LTSDocumentSampler(env_seed_1)

        # User model (one instance, passed to both environments)
        self.user_model = LTSUserModel(slate_size)

        # Environments
        self.env_0 = self._make_environment(self.doc_model_1)
        self.env_1 = self._make_environment(self.doc_model_2)

        self.lts_gym_env_0 = recsim_gym.RecSimGymEnv(self.env_0, clicked_engagement_reward)
        self.lts_gym_env_1 = recsim_gym.RecSimGymEnv(self.env_1, clicked_engagement_reward)

        # Weight of the Q-value deviation penalty used by compute_reward.
        self.lambda_attack = 0.1

    def _make_environment(self, doc_model):
        """Create a RecSim environment for ``doc_model`` with the shared settings."""
        return environment.Environment(
            self.user_model,
            doc_model,
            self.num_candidates,
            self.slate_size,
            self.resample_documents,
        )

    def _gym_env_for(self, id):
        """Return the gym environment for ``id`` (0 or 1); raise ValueError otherwise."""
        if id == 0:
            return self.lts_gym_env_0
        if id == 1:
            return self.lts_gym_env_1
        raise ValueError(f"Invalid id {id}.")

    @staticmethod
    def _split_observation(observation):
        """Split an observation dict into ``(user, doc_id, doc_fea, response)``.

        Relies on the dict holding exactly three values, in user/doc/response order.
        """
        user, doc, response = observation.values()
        doc_id = np.array(list(doc.keys())).astype(int)
        doc_fea = np.array(list(doc.values()))
        return user, doc_id, doc_fea, response

    def env_ini(self, id):
        """Reset environment ``id`` and return its initial observation.

        Returns:
            ``(user, doc_id, doc_fea, click, engagement, reward, done)`` where
            the arrays are float32, click and engagement are all zeros,
            reward is 0 and done is False.
        """
        observation = self._gym_env_for(id).reset()
        user, doc_id, doc_fea, _ = self._split_observation(observation)

        click = np.zeros([self.slate_size], dtype=int)
        engagement = np.zeros([self.slate_size])
        reward = np.array(0.)
        done = False

        return (user.astype(np.float32), doc_id.astype(np.float32), doc_fea.astype(np.float32),
                click.astype(np.float32), engagement.astype(np.float32), reward.astype(np.float32), done)

    def compute_reward(self, responses, agent_q_values, clean_q_values):
        """Total engagement minus ``lambda_attack`` times the L2 distance between the Q-value tensors."""
        # Compute expected user engagement (sum of engagement values)
        engagement_reward = sum(response['engagement'] for response in responses)

        # Compute attack penalty
        q_gap = torch.norm(agent_q_values - clean_q_values, p=2).item()
        attack_penalty = self.lambda_attack * q_gap

        # Total reward
        return engagement_reward - attack_penalty

    def env_step(self, slate, id, agent_q_values=None, clean_q_values=None):
        """Step environment ``id`` with ``slate``.

        When both Q-value tensors are given the reward comes from
        ``compute_reward``; otherwise the environment's own reward is used.

        Returns:
            ``(user, doc_id, doc_fea, click, engagement, reward, done)`` with
            float32 arrays and the environment's done flag.
        """
        output = self._gym_env_for(id).step(slate)

        user, doc_id, doc_fea, response = self._split_observation(output[0])
        click = np.array([item['click'] for item in response])
        engagement = np.array([item['engagement'] for item in response])
        done = output[2]

        # Compute the reward using the new reward function
        use_penalised_reward = agent_q_values is not None and clean_q_values is not None
        if use_penalised_reward:
            reward = self.compute_reward(response, agent_q_values, clean_q_values)
        else:
            reward = np.array(output[1])

        return (user.astype(np.float32), doc_id.astype(np.float32), doc_fea.astype(np.float32),
                click.astype(np.float32), engagement.astype(np.float32), reward.astype(np.float32), done)

## Logger Implementation

This section implements the logging functionality to track training progress and metrics.

In [None]:
class Logger:
    """Accumulates per-step training metrics and reports 100-episode moving averages.

    Metrics are tracked for two agents ("alpha" and "beta"), each in a clean
    and an adversarial ("adv") variant. ``log_step`` accumulates within an
    episode, ``log_episode`` closes the episode, and ``record`` prints the
    moving averages and appends them to ``log.txt``.

    Loss and Q-value moving averages are computed over whatever episodes are
    available (up to the last 100), exactly like the reward and length
    averages. Previously they were reported as 0 until 100 episodes existed.
    """

    # Number of most recent episodes included in every moving average.
    _WINDOW = 100

    def __init__(self, save_dir):
        """Create the logger.

        Args:
            save_dir: directory that will hold ``log.txt`` (``str`` or
                ``Path``); it is created if it does not exist.
        """
        save_dir = Path(save_dir)
        # record() appends to the log file, which requires the directory to exist.
        save_dir.mkdir(parents=True, exist_ok=True)
        self.save_log = save_dir / "log.txt"

        # Episode statistics
        self.ep_rewards_alpha = []
        self.ep_rewards_alpha_adv = []
        self.ep_rewards_beta = []
        self.ep_rewards_beta_adv = []
        self.ep_lengths = []
        self.ep_avg_losses_alpha = []
        self.ep_avg_losses_alpha_adv = []
        self.ep_avg_qs_alpha = []
        self.ep_avg_qs_alpha_adv = []
        self.ep_avg_losses_beta = []
        self.ep_avg_losses_beta_adv = []
        self.ep_avg_qs_beta = []
        self.ep_avg_qs_beta_adv = []

        # Moving averages
        self.moving_avg_ep_rewards_alpha = []
        self.moving_avg_ep_rewards_alpha_adv = []
        self.moving_avg_ep_rewards_beta = []
        self.moving_avg_ep_rewards_beta_adv = []
        self.moving_avg_ep_lengths = []

        # Initialize current episode stats
        self.init_episode()

        # Timing
        self.record_time = time.time()

    @classmethod
    def _window_mean(cls, values):
        """Mean of the last ``_WINDOW`` entries rounded to 3 decimals; 0 if ``values`` is empty."""
        if not values:
            return 0
        return np.round(np.mean(values[-cls._WINDOW:]), 3)

    @staticmethod
    def _safe_average(total, count):
        """Return ``total / count``, or 0 when nothing was accumulated."""
        return total / count if count > 0 else 0

    def init_episode(self):
        """Reset statistics for the current episode."""
        self.curr_ep_reward_alpha = 0.0
        self.curr_ep_reward_alpha_adv = 0.0
        self.curr_ep_reward_beta = 0.0
        self.curr_ep_reward_beta_adv = 0.0
        self.curr_ep_length = 0

        # Clean state metrics
        self.curr_ep_loss_alpha = 0.0
        self.curr_ep_q_alpha = 0.0
        self.curr_ep_loss_length_alpha = 0
        self.curr_ep_loss_beta = 0.0
        self.curr_ep_q_beta = 0.0
        self.curr_ep_loss_length_beta = 0

        # Adversarial state metrics
        self.curr_ep_loss_alpha_adv = 0.0
        self.curr_ep_q_alpha_adv = 0.0
        self.curr_ep_loss_length_alpha_adv = 0
        self.curr_ep_loss_beta_adv = 0.0
        self.curr_ep_q_beta_adv = 0.0
        self.curr_ep_loss_length_beta_adv = 0

    def log_step(self, total_reward_alpha, total_reward_alpha_adv, loss_alpha_clean, loss_alpha_adv,
                 q_alpha_clean, q_alpha_adv, reward_beta_clean, reward_beta_adv,
                 loss_beta_clean, loss_beta_adv, q_beta_clean, q_beta_adv):
        """Log step-level metrics for both clean and adversarial states.

        Any reward or loss passed as ``None`` is skipped. A Q-value is only
        accumulated together with its loss, so it must be provided whenever
        the matching loss is not ``None``.
        """
        # Regular rewards
        if total_reward_alpha is not None:
            self.curr_ep_reward_alpha += total_reward_alpha
        if reward_beta_clean is not None:
            self.curr_ep_reward_beta += reward_beta_clean

        # Adversarial rewards
        if total_reward_alpha_adv is not None:
            self.curr_ep_reward_alpha_adv += total_reward_alpha_adv
        if reward_beta_adv is not None:
            self.curr_ep_reward_beta_adv += reward_beta_adv

        # Increment episode length
        self.curr_ep_length += 1

        # Regular Q-values and losses
        if loss_alpha_clean is not None:
            self.curr_ep_loss_alpha += loss_alpha_clean
            self.curr_ep_q_alpha += q_alpha_clean
            self.curr_ep_loss_length_alpha += 1
        if loss_beta_clean is not None:
            self.curr_ep_loss_beta += loss_beta_clean
            self.curr_ep_q_beta += q_beta_clean
            self.curr_ep_loss_length_beta += 1

        # Adversarial Q-values and losses
        if loss_alpha_adv is not None:
            self.curr_ep_loss_alpha_adv += loss_alpha_adv
            self.curr_ep_q_alpha_adv += q_alpha_adv
            self.curr_ep_loss_length_alpha_adv += 1
        if loss_beta_adv is not None:
            self.curr_ep_loss_beta_adv += loss_beta_adv
            self.curr_ep_q_beta_adv += q_beta_adv
            self.curr_ep_loss_length_beta_adv += 1

    def log_episode(self):
        """Aggregate episode-level statistics."""
        self.ep_rewards_alpha.append(self.curr_ep_reward_alpha)
        self.ep_rewards_alpha_adv.append(self.curr_ep_reward_alpha_adv)
        self.ep_rewards_beta.append(self.curr_ep_reward_beta)
        self.ep_rewards_beta_adv.append(self.curr_ep_reward_beta_adv)
        self.ep_lengths.append(self.curr_ep_length)

        # Moving averages for rewards and lengths
        self.moving_avg_ep_rewards_alpha.append(self._window_mean(self.ep_rewards_alpha))
        self.moving_avg_ep_rewards_alpha_adv.append(self._window_mean(self.ep_rewards_alpha_adv))
        self.moving_avg_ep_rewards_beta.append(self._window_mean(self.ep_rewards_beta))
        self.moving_avg_ep_rewards_beta_adv.append(self._window_mean(self.ep_rewards_beta_adv))
        self.moving_avg_ep_lengths.append(self._window_mean(self.ep_lengths))

        # Average clean losses and Q-values (both divided by the number of logged losses)
        self.ep_avg_losses_alpha.append(
            self._safe_average(self.curr_ep_loss_alpha, self.curr_ep_loss_length_alpha))
        self.ep_avg_qs_alpha.append(
            self._safe_average(self.curr_ep_q_alpha, self.curr_ep_loss_length_alpha))
        self.ep_avg_losses_beta.append(
            self._safe_average(self.curr_ep_loss_beta, self.curr_ep_loss_length_beta))
        self.ep_avg_qs_beta.append(
            self._safe_average(self.curr_ep_q_beta, self.curr_ep_loss_length_beta))

        # Average adversarial losses and Q-values
        self.ep_avg_losses_alpha_adv.append(
            self._safe_average(self.curr_ep_loss_alpha_adv, self.curr_ep_loss_length_alpha_adv))
        self.ep_avg_qs_alpha_adv.append(
            self._safe_average(self.curr_ep_q_alpha_adv, self.curr_ep_loss_length_alpha_adv))
        self.ep_avg_losses_beta_adv.append(
            self._safe_average(self.curr_ep_loss_beta_adv, self.curr_ep_loss_length_beta_adv))
        self.ep_avg_qs_beta_adv.append(
            self._safe_average(self.curr_ep_q_beta_adv, self.curr_ep_loss_length_beta_adv))

        # Reset episode stats
        self.init_episode()

    def record(self, episode, epsilon, step):
        """Print and save episode-level metrics to the log.

        Must be called after at least one ``log_episode``; the reward and
        length averages are read from the lists that ``log_episode`` fills.

        Args:
            episode: current episode number (int).
            epsilon: current exploration rate.
            step: current global step (int).
        """
        # Retrieve moving averages for clean and adversarial metrics
        mean_length = self.moving_avg_ep_lengths[-1]
        mean_reward_alpha = self.moving_avg_ep_rewards_alpha[-1]
        mean_reward_alpha_adv = self.moving_avg_ep_rewards_alpha_adv[-1]
        mean_reward_beta = self.moving_avg_ep_rewards_beta[-1]
        mean_reward_beta_adv = self.moving_avg_ep_rewards_beta_adv[-1]

        # Clean losses and Q-values
        mean_loss_alpha = self._window_mean(self.ep_avg_losses_alpha)
        mean_q_alpha = self._window_mean(self.ep_avg_qs_alpha)
        mean_loss_beta = self._window_mean(self.ep_avg_losses_beta)
        mean_q_beta = self._window_mean(self.ep_avg_qs_beta)

        # Adversarial losses and Q-values
        mean_loss_alpha_adv = self._window_mean(self.ep_avg_losses_alpha_adv)
        mean_q_alpha_adv = self._window_mean(self.ep_avg_qs_alpha_adv)
        mean_loss_beta_adv = self._window_mean(self.ep_avg_losses_beta_adv)
        mean_q_beta_adv = self._window_mean(self.ep_avg_qs_beta_adv)

        # Calculate time since last record
        time_since_last_record = np.round(time.time() - self.record_time, 3)
        self.record_time = time.time()

        # Print episode-level summary
        print(
            f"Episode {episode} | Step {step} | Epsilon {epsilon:.3f} | "
            f"Clean Rewards: Alpha {mean_reward_alpha}, Beta {mean_reward_beta} | "
            f"Adversarial Rewards: Alpha {mean_reward_alpha_adv}, Beta {mean_reward_beta_adv} | "
            f"Clean Losses: Alpha {mean_loss_alpha}, Beta {mean_loss_beta} | "
            f"Adversarial Losses: Alpha {mean_loss_alpha_adv}, Beta {mean_loss_beta_adv} | "
            f"Clean Qs: Alpha {mean_q_alpha}, Beta {mean_q_beta} | "
            f"Adversarial Qs: Alpha {mean_q_alpha_adv}, Beta {mean_q_beta_adv} | "
            f"Avg Episode Length: {mean_length} | "
            f"Time Delta: {time_since_last_record}s"
        )

        # Save to log file
        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d} | {step:8d} | {epsilon:10.3f} | "
                f"{mean_reward_alpha:15.3f} | {mean_reward_beta:15.3f} | "
                f"{mean_reward_alpha_adv:15.3f} | {mean_reward_beta_adv:15.3f} | "
                f"{mean_loss_alpha:15.3f} | {mean_loss_beta:15.3f} | "
                f"{mean_loss_alpha_adv:15.3f} | {mean_loss_beta_adv:15.3f} | "
                f"{mean_q_alpha:15.3f} | {mean_q_beta:15.3f} | "
                f"{mean_q_alpha_adv:15.3f} | {mean_q_beta_adv:15.3f} | "
                f"{mean_length:10.3f} | {time_since_last_record:10.3f}\n"
            )

## Agents Implementation

This section implements the core agent classes for the federated recommendation system.

In [None]:
class AgentAlpha(SlateQ):
    """Local SlateQ agent for environment slot 0 whose state includes a response history.

    The flat state vector is laid out as
    ``[user | doc features | doc ids | last num_contex engagement responses]``;
    ``state_dim`` is the sum of those four segment lengths.
    """

    def __init__(self, user_features, doc_features, num_of_candidates, slate_size, batch_size, num_contex,
                 capacity=2000):
        """Store the problem sizes and build the Q-network, replay memory and optimizer.

        Args:
            user_features: Length of the user segment of the state.
            doc_features: Number of features per candidate document.
            num_of_candidates: Number of candidate documents per step.
            slate_size: Number of documents recommended per step.
            batch_size: Batch size that ``update_q_net`` asserts on.
            num_contex: Number of past engagement responses kept in the state.
            capacity: Capacity passed to ``ReplayMemory``.
        """
        self.user_features = user_features
        self.doc_features = doc_features
        self.num_of_candidates = num_of_candidates
        self.slate_size = slate_size
        self.batch_size = batch_size
        self.num_contex = num_contex

        # user + (doc features + doc ids) + response history
        self.state_dim = user_features + (
            doc_features * num_of_candidates + num_of_candidates) + num_contex * slate_size

        self.action_dim = slate_size

        # Rolling window of the most recent engagement responses.
        self.response = deque(maxlen=num_contex)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = QNet(self.state_dim, self.num_of_candidates).to(self.device)
        self.replay = ReplayMemory(capacity, (self.state_dim,), (self.action_dim,))
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.01)
        self.loss_fn = torch.nn.SmoothL1Loss()
        self.gamma = 0.9

    def _flatten_observation(self, user, doc_fea, doc_id):
        """Concatenate one observation and the response history into a 1-D state tensor.

        Every piece is flattened with ``view(-1)``. Unlike ``squeeze()``, this never
        yields a 0-d tensor, so it also works when ``slate_size == 1``.
        """
        pieces = [
            torch.tensor(user, device=self.device).view(-1),
            torch.tensor(doc_fea, device=self.device).view(-1),
            torch.tensor(doc_id, device=self.device).to(torch.float).view(-1),
        ]
        pieces.extend(past_response.view(-1) for past_response in self.response)
        return torch.cat(pieces)

    def _push_transition(self, action, user, doc_id, doc_fea, click, engagement, reward, done):
        """Record the environment feedback in replay memory and advance ``self.state``.

        ``action`` is the slate that was played, as a 1-D tensor.
        """
        # The newest engagement enters the history before the next state is built.
        self.response.append(torch.tensor(engagement, device=self.device))
        next_state = self._flatten_observation(user, doc_fea, doc_id)

        self.replay.push(self.state.view(-1), action,
                         torch.tensor(reward, device=self.device).view(-1),
                         torch.tensor(click, device=self.device),
                         next_state,
                         torch.tensor(done, device=self.device).view(-1))
        self.state = next_state

    def compute_q_local_ini(self, env):
        """Start a new episode via ``env.env_ini(0)`` and return the online Q-values.

        The response history is refilled with zeros first, so the initial state has
        the full ``state_dim`` length.
        """
        for _ in range(self.num_contex):
            self.response.append(torch.zeros([self.slate_size], device=self.device))

        user, doc_id, doc_fea, click, engagement, reward, done = env.env_ini(0)
        state = self._flatten_observation(user, doc_fea, doc_id)

        assert state.shape == torch.Size([self.state_dim])
        self.state = state
        return self.net(state, "online")

    def compute_q_local(self):
        """Return the online network's Q-values for the current state (batch of one)."""
        return self.net.forward(self.state.view(1, -1), "online")

    def recommend(self, q_fed_alpha, env):
        """Play the greedy slate for the federated Q-values and store the transition.

        Returns:
            ``(done, reward)`` as reported by ``env.env_step``.
        """
        user_obs = self.state[:self.user_features]
        doc_obs = self.state[self.user_features:(self.user_features + self.num_of_candidates * self.doc_features)]
        s, s_no_click = super().score_documents_torch(user_obs, doc_obs)
        slate = super().select_slate_greedy(s_no_click, s, q_fed_alpha)

        # The environment receives both the federated Q-values and this agent's own
        # (detached) local Q-values.
        clean_q_values = self.net(self.state.view(1, -1), "online").detach()
        agent_q_values = q_fed_alpha

        user, doc_id, doc_fea, click, engagement, reward, done = env.env_step(
            slate.cpu().numpy().tolist(), 0, agent_q_values, clean_q_values)

        self._push_transition(slate.view(-1), user, doc_id, doc_fea, click, engagement, reward, done)
        return done, reward

    def recommend_random(self, env):
        """Play a uniformly random slate (exploration) and store the transition.

        Returns:
            ``(done, reward)`` as reported by ``env.env_step``.
        """
        nums = list(range(self.num_of_candidates))
        random.shuffle(nums)
        slate = nums[:self.slate_size]

        # Random recommendation: no Q-values are passed to env_step.
        user, doc_id, doc_fea, click, engagement, reward, done = env.env_step(slate, 0)

        self._push_transition(torch.tensor(slate, device=self.device),
                              user, doc_id, doc_fea, click, engagement, reward, done)
        return done, reward

    def compute_q_local_batch(self, ids):
        """Recall the transitions ``ids`` from replay and cache them on ``self``.

        Returns:
            ``(online Q of the batch states, target Q of the batch next states)``.
        """
        self.batch_states, self.batch_actions, self.batch_rewards, self.batch_clicks, \
        self.batch_next_states, self.batch_terminals = self.replay.recall(ids)
        return self.net.forward(self.batch_states, "online"), self.net.forward(self.batch_next_states, "target")

    def update_q_net(self, q, q_next, agent_fed):
        """Take one optimisation step on this agent's network and the federated network.

        Must be called after ``compute_q_local_batch``, which caches the batch.

        Args:
            q: Federated online Q-values, shape ``[batch_size, num_of_candidates]``.
            q_next: Federated Q-values for the next states, same shape.
            agent_fed: Federated agent whose optimizer is stepped with this one.

        Returns:
            ``(loss, q_next)`` where ``q_next`` is the value returned by
            ``compute_target_greedy_q``.
        """
        assert q.shape == torch.Size([self.batch_size, self.num_of_candidates])
        assert q_next.shape == torch.Size([self.batch_size, self.num_of_candidates])

        # Sanity check that the doc-id segment of the stored states has the expected width.
        doc_id_start = self.user_features + self.num_of_candidates * self.doc_features
        doc_id_end = doc_id_start + self.num_of_candidates

        doc_id = self.batch_states[:, doc_id_start:doc_id_end]

        assert doc_id.shape == torch.Size([self.batch_size, self.num_of_candidates])

        # action * click, summed over the slate, is used as the gather index.
        # NOTE(review): presumably the id of the clicked document, and 0 when
        # nothing was clicked -- confirm against the environment.
        selected_item = self.batch_actions * self.batch_clicks
        selected_item = selected_item.type(torch.int)
        assert selected_item.shape == torch.Size([self.batch_size, self.slate_size])
        selected_item = torch.sum(selected_item, dim=1, keepdim=True)

        q = torch.gather(q, 1, selected_item)
        q_next = super().compute_target_greedy_q(self.batch_rewards, self.gamma, q_next,
                                                 self.batch_next_states, self.batch_terminals)

        loss = self.loss_fn(q.view(self.batch_size, 1), q_next.view(self.batch_size, 1))
        agent_fed.optimizer.zero_grad()
        self.optimizer.zero_grad()
        # The graph is retained because AgentFed.learn reuses tensors from it
        # (AgentBeta's batch Q-values) for a second backward pass.
        loss.backward(retain_graph=True)
        agent_fed.optimizer.step()
        self.optimizer.step()

        return loss, q_next

class AgentBeta(SlateQ):
    """Local SlateQ agent for environment slot 1 (no response history in its state).

    The flat state vector is laid out as ``[user | doc features | doc ids]``.
    """

    def __init__(self, user_features, doc_features, num_of_candidates, slate_size, batch_size, capacity=2000):
        """Store the problem sizes and build the Q-network, replay memory and optimizer.

        Args:
            user_features: Length of the user segment of the state.
            doc_features: Number of features per candidate document.
            num_of_candidates: Number of candidate documents per step.
            slate_size: Number of documents recommended per step.
            batch_size: Batch size that ``update_q_net`` asserts on.
            capacity: Capacity passed to ``ReplayMemory``.
        """
        self.user_features = user_features
        self.doc_features = doc_features
        self.num_of_candidates = num_of_candidates
        self.slate_size = slate_size
        self.batch_size = batch_size

        # user + doc features + doc ids
        self.state_dim = user_features + (doc_features * num_of_candidates + num_of_candidates)

        self.action_dim = slate_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = QNet(self.state_dim, self.num_of_candidates).to(self.device)
        self.replay = ReplayMemory(capacity, (self.state_dim,), (self.action_dim,))
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.01)
        self.loss_fn = torch.nn.SmoothL1Loss()
        self.gamma = 0.9

    def _flatten_observation(self, user, doc_fea, doc_id):
        """Concatenate one environment observation into a 1-D state tensor."""
        return torch.cat([
            torch.tensor(user, device=self.device).view(-1),
            torch.tensor(doc_fea, device=self.device).view(-1),
            torch.tensor(doc_id, device=self.device).to(torch.float).view(-1)
        ])

    def _push_transition(self, action, user, doc_id, doc_fea, click, reward, done):
        """Record the environment feedback in replay memory and advance ``self.state``.

        ``action`` is the slate that was played, as a 1-D tensor.
        """
        next_state = self._flatten_observation(user, doc_fea, doc_id)

        self.replay.push(self.state.view(-1), action,
                         torch.tensor(reward, device=self.device).view(-1),
                         torch.tensor(click, device=self.device),
                         next_state,
                         torch.tensor(done, device=self.device).view(-1))
        self.state = next_state

    def compute_q_local_ini(self, env):
        """Start a new episode via ``env.env_ini(1)`` and return the online Q-values."""
        user, doc_id, doc_fea, click, engagement, reward, done = env.env_ini(1)
        state = self._flatten_observation(user, doc_fea, doc_id)

        assert state.shape == torch.Size([self.state_dim])
        self.state = state
        return self.net(state, "online")

    def compute_q_local(self):
        """Return the online network's Q-values for the current state (batch of one)."""
        return self.net.forward(self.state.view(1, -1), "online")

    def recommend(self, q_fed_beta, env):
        """Play the greedy slate for the federated Q-values and store the transition.

        Returns:
            ``(done, reward)`` as reported by ``env.env_step``.
        """
        user_obs = self.state[:self.user_features]
        doc_obs = self.state[self.user_features:(self.user_features + self.num_of_candidates * self.doc_features)]
        s, s_no_click = super().score_documents_torch(user_obs, doc_obs)
        slate = super().select_slate_greedy(s_no_click, s, q_fed_beta)

        # The environment receives both the federated Q-values and this agent's own
        # (detached) local Q-values.
        clean_q_values = self.net(self.state.view(1, -1), "online").detach()
        agent_q_values = q_fed_beta

        user, doc_id, doc_fea, click, engagement, reward, done = env.env_step(
            slate.cpu().numpy().tolist(), 1, agent_q_values, clean_q_values)

        self._push_transition(slate.view(-1), user, doc_id, doc_fea, click, reward, done)
        return done, reward

    def recommend_random(self, env):
        """Play a uniformly random slate (exploration) and store the transition.

        Returns:
            ``(done, reward)`` as reported by ``env.env_step``.
        """
        nums = list(range(self.num_of_candidates))
        random.shuffle(nums)
        slate = nums[:self.slate_size]

        # Random recommendation: no Q-values are passed to env_step.
        user, doc_id, doc_fea, click, engagement, reward, done = env.env_step(slate, 1)

        self._push_transition(torch.tensor(slate, device=self.device),
                              user, doc_id, doc_fea, click, reward, done)
        return done, reward

    def compute_q_local_batch(self, ids):
        """Recall the transitions ``ids`` from replay, cache states/actions, return online Q."""
        self.batch_states, self.batch_actions, _, _, _, _ = self.replay.recall(ids)
        return self.net.forward(self.batch_states, "online")

    def update_q_net(self, q_online, q_target, agent_fed):
        """Take one optimisation step on this agent's network and the federated network.

        Must be called after ``compute_q_local_batch``, which caches the batch states.

        Args:
            q_online: Federated online Q-values, shape ``[batch_size, num_of_candidates]``.
            q_target: Target Q-values; reduced here to their per-row mean.
            agent_fed: Federated agent whose optimizer is stepped with this one.

        Returns:
            The scalar loss tensor.
        """
        assert q_online.shape == torch.Size([self.batch_size, self.num_of_candidates])

        # Collapse the target to one value per sample so it matches the slate-level Q below.
        q_target = q_target.mean(dim=1, keepdim=True)

        user_obs = self.batch_states[:, :self.user_features]
        doc_obs_start = self.user_features
        doc_obs_end = doc_obs_start + self.num_of_candidates * self.doc_features
        doc_obs = self.batch_states[:, doc_obs_start:doc_obs_end]

        assert user_obs.shape == torch.Size([self.batch_size, self.user_features])
        # The doc segment holds doc_features values per candidate, so its width is
        # num_of_candidates * doc_features (not just num_of_candidates).
        assert doc_obs.shape == torch.Size([self.batch_size, self.num_of_candidates * self.doc_features])

        greedy_q_list = []

        for i in range(self.batch_size):
            s, s_no_click = super().score_documents_torch(user_obs[i], doc_obs[i])
            q = q_online[i]

            # Select a slate of slate_size items for this sample.
            slate = super().select_slate_greedy(s_no_click, s, q)
            assert slate.shape[0] == self.slate_size

            # Q-values of the selected candidates
            q_selected = torch.gather(q, 0, slate)

            # Probabilities for the selected slate
            p_selected = super().compute_probs_torch(slate, s, s_no_click)

            # Probability-weighted sum of the selected Q-values
            greedy_q_value = torch.sum(p_selected * q_selected)
            greedy_q_list.append(greedy_q_value)

        # Stack into a tensor of shape [batch_size, 1]
        greedy_q_values = torch.stack(greedy_q_list, dim=0).view(self.batch_size, 1)

        loss = self.loss_fn(greedy_q_values, q_target)

        # Optimize both the federated network and the local network.
        agent_fed.optimizer.zero_grad()
        self.optimizer.zero_grad()
        loss.backward()
        agent_fed.optimizer.step()
        self.optimizer.step()

        return loss

class AgentFed():
    """Coordinator that mixes the two local agents' Q-values through a shared MLP.

    It owns the exploration schedule and the step counter, and drives acting,
    target-network syncing and learning for both local agents.
    """

    def __init__(self, user_features, doc_features, num_of_candidates, slate_size, batch_size, capacity=2000):
        """Store the problem sizes and build the federated network and optimizer."""
        self.user_features = user_features
        self.doc_features = doc_features
        self.num_of_candidates = num_of_candidates
        self.slate_size = slate_size
        self.batch_size = batch_size
        self.capacity = capacity
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Epsilon-greedy schedule
        self.exploration_rate = 1
        self.exploration_rate_decay = 0.99995
        self.exploration_rate_min = 0

        # Step-count based scheduling
        self.burnin = 5000  # min. experiences before training [change back to 5000]
        self.learn_every = 3  # no. of experiences between updates to Q_online
        self.sync_every = 500  # no. of experiences between Q_target & Q_online sync
        self.curr_step = 0

        # The MLP takes both agents' Q-vectors side by side.
        self.net = MLPNet(num_of_candidates * 2, num_of_candidates).to(self.device)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.01)
        self.loss_fn = torch.nn.SmoothL1Loss()

        # Attack penalty weight λ
        self.lambda_attack = 0.1

    def sync(self, agent_alpha, agent_beta):
        """Copy online weights into the target network of the fed net and both agents."""
        for network in (self.net, agent_alpha.net, agent_beta.net):
            network.target.load_state_dict(network.online.state_dict())

    def _act_on(self, local_q_alpha, local_q_beta, agent_alpha, agent_beta, env):
        """Shared body of ``act_ini`` and ``act``.

        Computes both federated Q-vectors, lets each agent explore or exploit,
        decays epsilon and increments the step counter.
        """
        fed_q_alpha = self.net_forward(local_q_alpha, local_q_beta)
        fed_q_beta = self.net_forward(local_q_beta, local_q_alpha)

        # Alpha draws and acts first, then beta, each with an independent epsilon draw.
        if np.random.rand() < self.exploration_rate:
            done_alpha, reward_alpha = agent_alpha.recommend_random(env)
        else:
            done_alpha, reward_alpha = agent_alpha.recommend(fed_q_alpha, env)

        if np.random.rand() < self.exploration_rate:
            done_beta, reward_beta = agent_beta.recommend_random(env)
        else:
            done_beta, reward_beta = agent_beta.recommend(fed_q_beta, env)

        decayed_rate = self.exploration_rate * self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, decayed_rate)
        self.curr_step += 1

        return done_alpha, reward_alpha, done_beta, reward_beta

    def act_ini(self, agent_alpha, agent_beta, env):
        """Initialise both agents on ``env`` and take the first step.

        Returns:
            ``(done_alpha, reward_alpha, done_beta, reward_beta)``.
        """
        local_q_alpha = agent_alpha.compute_q_local_ini(env).view(1, -1)
        local_q_beta = agent_beta.compute_q_local_ini(env).view(1, -1)
        return self._act_on(local_q_alpha, local_q_beta, agent_alpha, agent_beta, env)

    def act(self, agent_alpha, agent_beta, env):
        """Take one step with both agents from their current states.

        Unlike ``act_ini``, epsilon is forced to 0 once it has decayed below 0.1.

        Returns:
            ``(done_alpha, reward_alpha, done_beta, reward_beta)``.
        """
        local_q_alpha = agent_alpha.compute_q_local().view(1, -1)
        local_q_beta = agent_beta.compute_q_local().view(1, -1)
        outcome = self._act_on(local_q_alpha, local_q_beta, agent_alpha, agent_beta, env)

        if self.exploration_rate < 0.1:
            self.exploration_rate = 0

        return outcome

    def net_forward(self, q_agent, q_other_agent):
        """Run the online MLP on ``[q_agent | q_other_agent]`` (concatenated on dim 1)."""
        return self.net(torch.cat([q_agent, q_other_agent], dim=1), "online")

    def learn(self, agent_alpha, agent_beta):
        """Sync targets when due, then (after burn-in, every ``learn_every`` steps) train.

        Returns:
            ``(mean alpha online Q, alpha loss, mean beta online Q, beta loss)``,
            or four ``None`` values when no update is performed on this step.
        """
        if self.curr_step % self.sync_every == 0:
            self.sync(agent_alpha, agent_beta)

        in_burnin = self.curr_step < self.burnin
        off_schedule = self.curr_step % self.learn_every != 0
        if in_burnin or off_schedule:
            return None, None, None, None

        # The same replay indices are used for both agents.
        ids = random.sample(range(len(agent_alpha.replay)), self.batch_size)

        alpha_online, alpha_target = agent_alpha.compute_q_local_batch(ids)
        beta_online = agent_beta.compute_q_local_batch(ids)

        fed_alpha_online = self.net_forward(alpha_online, beta_online)
        fed_alpha_target = self.net_forward(alpha_target, beta_online)

        # Alpha is updated first; this also steps the federated optimizer.
        loss_alpha, q_alpha_target = agent_alpha.update_q_net(fed_alpha_online, fed_alpha_target, self)

        # Alpha's network has changed, so its batch Q-values are recomputed for beta's input.
        alpha_online_refreshed, _ = agent_alpha.compute_q_local_batch(ids)

        fed_beta_online = self.net_forward(beta_online, alpha_online_refreshed)
        # No separate target exists for beta here: the detached online values serve as target.
        loss_beta = agent_beta.update_q_net(fed_beta_online, fed_beta_online.detach(), self)

        return alpha_online.detach().cpu().mean().item(), loss_alpha.detach().cpu(), \
               beta_online.detach().cpu().mean().item(), loss_beta.detach().cpu()

## Training Setup and Execution

This section includes the main training loops and experiment configurations.

### Standard Training Setup (run_fed.py)

In [None]:
import torch
import numpy as np
from pathlib import Path
import datetime

# Check if CUDA is available
use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

# Seeding is currently disabled, so runs are not reproducible; uncomment to fix the seeds.
#torch.manual_seed(0)
#np.random.seed(0)

# Model parameters
user_features = 1
doc_features = 1
num_candidates = 10
slate_size = 3
batch_size = 32
num_contexts = 5

# Setup save directory for logs and plots
save_dir = Path("checkpoints_fed_adv_manual") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True, exist_ok=True)

# Initialize logger
logger = Logger(save_dir)

# Initialize agents
agent_alpha = AgentAlpha(user_features, doc_features, num_candidates, slate_size, batch_size, num_contexts)
agent_beta = AgentBeta(user_features, doc_features, num_candidates, slate_size, batch_size)
agent_fed = AgentFed(user_features, doc_features, num_candidates, slate_size, batch_size)

# Initialize environment
env = RecsimEnv(num_candidates, slate_size, True, 42, 42)

# Training parameters
episodes = 5000
epsilon_adv = 0.01  # Magnitude for FGSM perturbation
episodes_until_perturb = 200  # number of clean-only episodes before attacks start

for episode in range(episodes):
    apply_attack = episode >= episodes_until_perturb
    # Reset the environment and initialize agent states
    agent_fed.act_ini(agent_alpha, agent_beta, env)

    while True:
        # ===== Clean Action and Learning Steps =====
        # Agents take actions based on clean states and receive feedback
        done_alpha_clean, reward_alpha_clean, done_beta_clean, reward_beta_clean = agent_fed.act(agent_alpha, agent_beta, env)
        # AgentFed.learn returns (q_alpha, loss_alpha, q_beta, loss_beta) -- Q first, then loss.
        q_alpha_clean, loss_alpha_clean, q_beta_clean, loss_beta_clean = agent_fed.learn(agent_alpha, agent_beta)

        # Log metrics for clean states
        logger.log_step(
            total_reward_alpha=reward_alpha_clean,
            total_reward_alpha_adv=None,
            loss_alpha_clean=loss_alpha_clean,
            loss_alpha_adv=None,
            q_alpha_clean=q_alpha_clean,
            q_alpha_adv=None,
            reward_beta_clean=reward_beta_clean,
            reward_beta_adv=None,
            loss_beta_clean=loss_beta_clean,
            loss_beta_adv=None,
            q_beta_clean=q_beta_clean,
            q_beta_adv=None,
        )

        # Stop before the adversarial step so no action is taken on a finished episode.
        if done_alpha_clean or done_beta_clean:
            break

        # ===== Adversarial Action and Learning Steps =====
        if apply_attack and agent_fed.curr_step >= agent_fed.burnin:
            # FGSM on Agent Alpha's state: step along the sign of the gradient of
            # -mean(Q), i.e. in the direction that lowers the mean local Q-value.
            agent_alpha.state.requires_grad = True
            q_alpha_local = agent_alpha.compute_q_local()
            fgsm_objective = -torch.mean(q_alpha_local)
            agent_alpha.optimizer.zero_grad()
            fgsm_objective.backward()
            state_grad = agent_alpha.state.grad
            adversarial_state_alpha = agent_alpha.state + epsilon_adv * state_grad.sign()
            # NOTE(review): this clamp also applies to the doc-id and response-history
            # entries of the state -- confirm that [-1, 1] is the intended range for them.
            adversarial_state_alpha = torch.clamp(adversarial_state_alpha, -1, 1).detach()

            # Replace the agent's state with the adversarial state
            agent_alpha.state = adversarial_state_alpha

            # Agents take actions based on adversarial states and receive feedback
            done_alpha_adv, reward_alpha_adv, done_beta_adv, reward_beta_adv = agent_fed.act(agent_alpha, agent_beta, env)
            # Same return order as above: (q_alpha, loss_alpha, q_beta, loss_beta).
            q_alpha_adv, loss_alpha_adv, q_beta_adv, loss_beta_adv = agent_fed.learn(agent_alpha, agent_beta)

            # Log metrics for adversarial states
            logger.log_step(
                total_reward_alpha=None,  # Already logged in clean step
                total_reward_alpha_adv=reward_alpha_adv,
                loss_alpha_clean=None,  # Already logged in clean step
                loss_alpha_adv=loss_alpha_adv,
                q_alpha_clean=None,  # Already logged in clean step
                q_alpha_adv=q_alpha_adv,
                reward_beta_clean=None,  # Beta is not attacked here
                reward_beta_adv=reward_beta_adv,
                loss_beta_clean=None,
                loss_beta_adv=loss_beta_adv,
                q_beta_clean=None,
                q_beta_adv=q_beta_adv,
            )

            # act() has already replaced the state with the next observation; make sure
            # it carries no autograd history into the next iteration.
            agent_alpha.state = agent_alpha.state.detach()
            agent_alpha.state.requires_grad = False

            # The adversarial step can also end the episode.
            if done_alpha_adv or done_beta_adv:
                break

    # Log episode metrics
    logger.log_episode()

    if episode % 20 == 0:
        logger.record(episode=episode, epsilon=agent_fed.exploration_rate, step=agent_fed.curr_step)

Using CUDA: True

Episode 0 | Step 61 | Epsilon 0.997 | Clean Rewards: Alpha 924.837, Beta 1157.484 | Adversarial Rewards: Alpha 0.0, Beta 0.0 | Clean Losses: Alpha 0, Beta 0 | Adversarial Losses: Alpha 0, Beta 0 | Clean Qs: Alpha 0, Beta 0 | Adversarial Qs: Alpha 0, Beta 0 | Avg Episode Length: 60.0 | Time Delta: 7.932s
Episode 20 | Step 1281 | Epsilon 0.938 | Clean Rewards: Alpha 868.916, Beta 901.451 | Adversarial Rewards: Alpha 0.0, Beta 0.0 | Clean Losses: Alpha 0, Beta 0 | Adversarial Losses: Alpha 0, Beta 0 | Clean Qs: Alpha 0, Beta 0 | Adversarial Qs: Alpha 0, Beta 0 | Avg Episode Length: 60.0 | Time Delta: 6.826s
Episode 40 | Step 2501 | Epsilon 0.882 | Clean Rewards: Alpha 900.601, Beta 924.275 | Adversarial Rewards: Alpha 0.0, Beta 0.0 | Clean Losses: Alpha 0, Beta 0 | Adversarial Losses: Alpha 0, Beta 0 | Clean Qs: Alpha 0, Beta 0 | Adversarial Qs: Alpha 0, Beta 0 | Avg Episode Length: 60.0 | Time Delta: 6.57s
Episode 60 | Step 3721 | Epsilon 0.830 | Clean Rewards: Alpha 92