# Imports and Settings


In [1]:
import gymnasium as gym
import math
import random
import numpy as np
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt

# Gym environment handle, created once when this cell runs.
env = gym.make("CartPole-v1")

# Matplotlib setup: IPython's display helper is imported only when the
# active backend name contains 'inline'.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Switch matplotlib to interactive mode.
plt.ion()

# Choose the compute device in priority order: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Replay Memory


In [2]:
# One stored experience, plus the bounded buffer that holds them.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory:
    """Fixed-capacity FIFO buffer of ``Transition`` records."""

    def __init__(self, capacity):
        # A maxlen-bounded deque drops its oldest entry automatically when a
        # new one is appended at full capacity.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Store one transition built from (state, action, next_state, reward)."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn without replacement."""
        stored = self.memory
        return random.sample(stored, batch_size)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

# Background Blockchain Environment (BCenv)


In [3]:
class BlockchainEnvironment:
    """Randomly simulated blockchain environment.

    The observation is a length-4 NumPy array ``[R, c, H, p_bar]``; every
    component is redrawn from NumPy's global random generator on each
    ``reset()`` and after each ``step()``.
    """

    def __init__(self):
        # Inclusive upper bounds for the three action components drawn by
        # sample_action().
        self.block_size_max = 10
        self.block_interval_max = 10
        self.shard_max = 5
        # Number of per-node samples averaged into each observation value.
        self.num_nodes = 10
        # Factors used by the latency and security checks in step().
        self.u = 0.9
        self.security_threshold = 0.8
        self.reset()

    def reset(self):
        """Draw a fresh observation, store it in ``self.state`` and return it."""
        rate = self.get_transmission_rate()
        power = self.get_computing_power()
        history = self.get_consensus_history()
        malicious = self.estimate_malicious_prob(history)
        self.state = np.array([rate, power, history, malicious])
        return self.state

    def get_transmission_rate(self):
        """Return the mean of ``num_nodes`` uniform draws from [0, 1)."""
        per_node = np.random.random(self.num_nodes)
        return per_node.mean()

    def get_computing_power(self):
        """Return the mean of ``num_nodes`` uniform draws from [0, 1)."""
        per_node = np.random.random(self.num_nodes)
        return per_node.mean()

    def get_consensus_history(self):
        """Return the fraction of ones among ``num_nodes`` random 0/1 draws."""
        outcomes = np.random.choice([0, 1], size=self.num_nodes)
        return outcomes.mean()

    def estimate_malicious_prob(self, H):
        """Return one minus the mean of ``H``."""
        return 1 - np.mean(H)

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done)``.

        ``action`` is unpacked as ``(B, T_I, K_star)``; only ``T_I`` and
        ``K_star`` influence the reward. The reward is 1 when both the
        latency and the security check pass and -1 otherwise. The next state
        is a fresh ``reset()`` draw and ``done`` is always False.

        NOTE(review): np.random.random() draws from [0, 1), so with the
        default ``u = 0.9`` the latency check ``T_I + draw <= u * T_I``
        cannot hold for any positive ``T_I``, and the security check needs
        ``K_star`` below ``security_threshold`` (0.8). Every action produced
        by sample_action() therefore yields -1 -- confirm whether these
        conditions are what was intended.
        """
        _block_size, interval, shards = action

        # Both draws are taken up front, in this order, before the checks.
        latency = interval + np.random.random()
        meets_latency = latency <= self.u * interval
        meets_security = shards < self.security_threshold * np.random.random()

        reward = 1 if (meets_latency and meets_security) else -1

        # No terminal state is defined, so the episode never ends here.
        return self.reset(), reward, False

    def sample_action(self):
        """Return a random ``[block_size, block_interval, shard_number]`` list."""
        upper_bounds = (self.block_size_max, self.block_interval_max, self.shard_max)
        # randint's upper bound is exclusive, hence the + 1.
        return [np.random.randint(1, bound + 1) for bound in upper_bounds]

# Deep Q-Network (DQN)


In [4]:
class DQN(nn.Module):
    """Three-layer fully connected network producing one value per action.

    Layout: ``state_size -> 64 -> 64 -> action_size`` with ReLU after the
    first two linear layers and no activation on the output.
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        hidden_units = 64
        self.fc1 = nn.Linear(state_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, action_size)

    def forward(self, x):
        """Return the output-layer values for input ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


# DQN Agent


In [6]:
class DQNAgent:
    """Epsilon-greedy DQN agent with a replay buffer and a target network.

    ``model`` is the online network that is trained; ``target_model`` is a
    separate copy used to compute bootstrap targets and is refreshed only by
    ``update_target_model()``.
    """

    def __init__(self, state_size, action_size, gamma=0.99, lr=1e-3):
        """Build the networks, optimizer and replay buffer.

        Args:
            state_size: Number of input features of the Q-network.
            action_size: Number of discrete actions (network outputs).
            gamma: Discount factor applied to the bootstrapped next value.
            lr: Learning rate of the Adam optimizer.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.memory = ReplayMemory(10000)
        self.model = DQN(state_size, action_size).to(device)
        self.target_model = DQN(state_size, action_size).to(device)
        # Start the target network as an exact copy of the online network;
        # otherwise the first targets come from unrelated random weights
        # until update_target_model() is first called.
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.gamma = gamma
        # Exploration schedule: epsilon decays multiplicatively after every
        # optimisation step and never goes below epsilon_min.
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995

    def select_action(self, state):
        """Pick an action index epsilon-greedily.

        Returns a ``(1, 1)`` long tensor on ``device``. ``state`` is passed
        to the online network and reduced with ``max(1)``, so it must be
        batched -- presumably shape ``(1, state_size)``; confirm against the
        training loop.
        """
        if random.random() > self.epsilon:
            # Exploit: index of the largest Q-value, no gradient tracking.
            with torch.no_grad():
                return self.model(state).max(1)[1].view(1, 1)
        # Explore: uniformly random action index.
        return torch.tensor(
            [[random.randrange(self.action_size)]], device=device, dtype=torch.long
        )

    def optimize_model(self, batch_size=64):
        """Run one gradient step on a random replay batch.

        Does nothing until the buffer holds at least ``batch_size``
        transitions. A stored ``next_state`` of ``None`` is treated as
        terminal and contributes a next-state value of zero.

        Assumes each stored state is a ``(1, state_size)`` tensor, each
        action a ``(1, 1)`` long tensor and each reward a shape-``(1,)``
        tensor, as required by the ``torch.cat`` / ``gather`` calls below --
        confirm against the code that pushes transitions.
        """
        if len(self.memory) < batch_size:
            return

        transitions = self.memory.sample(batch_size)
        # Transpose a list of Transitions into one Transition of tuples.
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Q(s, a) of the online network for the actions actually taken.
        state_action_values = self.model(state_batch).gather(1, action_batch)

        # max_a' Q_target(s', a'); stays 0 for terminal (None) next states.
        non_final_mask = torch.tensor(
            [s is not None for s in batch.next_state],
            device=device,
            dtype=torch.bool,
        )
        next_state_values = torch.zeros(batch_size, device=device)
        if non_final_mask.any():
            non_final_next_states = torch.cat(
                [s for s in batch.next_state if s is not None]
            )
            # Targets are constants for the loss, so skip graph construction.
            with torch.no_grad():
                next_state_values[non_final_mask] = self.target_model(
                    non_final_next_states
                ).max(1)[0]

        expected_state_action_values = reward_batch + (self.gamma * next_state_values)

        loss = F.mse_loss(state_action_values, expected_state_action_values.unsqueeze(1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay exploration, clamped so it never undershoots the floor.
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())


# Training

## Hyperparameters and utilities
