# Import & settings


In [25]:
import gymnasium as gym
from gymnasium import Env
import numpy as np
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
import random
from keras.src import Sequential, optimizers
from keras.src.layers import Dense
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


# Matplotlib setup: the IPython display helpers are only needed (and only
# imported) when the inline notebook backend is active.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Interactive mode so figures refresh while training is running.
plt.ion()

# Pick the torch device in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Blockchain Environment


In [26]:
class BlockchainEnvironment(Env):
    """Toy simulator of a sharded blockchain, driven by discrete actions.

    The observation is ``(num_shards, block_size, block_interval,
    malicious_ratio)``. An action is an index into ``action_space``; the
    selected tuple is added to the shard count and the block size. An episode
    ends once the transaction budget (``total_transactions``) is used up.

    NOTE: ``reset()`` returns only the state and ``step()`` returns
    ``(next_state, reward, done)``; callers in this file rely on these
    shapes, which differ from the current Gymnasium ``Env`` signatures.
    """

    def __init__(self, num_shards, block_size, block_interval, malicious_ratio):
        """Create the environment and put it into its initial state.

        The four arguments are stored, but ``reset()`` (called at the end of
        the constructor) immediately replaces them with the fixed start
        configuration and a freshly sampled malicious ratio.
        """
        # Limits and constants are defined first because reset() depends on
        # the transaction budget being available.
        # Each action is (shard increment, block-size increment), as consumed by step().
        self.action_space = [(1, 4), (4, 8), (8, 16)]
        self.max_shards = 10      # Maximum number of shards allowed
        self.max_block_size = 8    # Maximum block size (MB)
        self.max_block_interval = 16  # Maximum block interval (seconds)
        self.tps_baseline = 10  # TPS baseline value
        self.initial_transactions = 100000  # Transaction budget of one episode
        self.total_transactions = self.initial_transactions  # Remaining budget

        self.num_shards = num_shards      # Number of shards in the system
        self.block_size = block_size      # Size of each block (in MB)
        self.block_interval = block_interval  # Block interval (in seconds)
        self.malicious_ratio = malicious_ratio  # Ratio of malicious nodes
        self.state = self.reset()  # Initial state

    def reset(self):
        """Reset the environment to an initial state and return that state.

        Restores one shard, block size 1, block interval 1, samples a new
        malicious ratio in [0, 0.3] and refills the transaction budget so a
        new episode does not start out already finished.
        """
        self.num_shards = 1
        self.block_size = 1
        self.block_interval = 1
        self.malicious_ratio = random.uniform(0.0, 0.3)  # Randomly set a malicious ratio (up to 30%)
        self.total_transactions = self.initial_transactions
        self.state = (self.num_shards, self.block_size, self.block_interval, self.malicious_ratio)
        return self.state

    def step(self, action):
        """Apply the action with index ``action`` and advance the simulation.

        Returns:
            A ``(next_state, reward, done)`` tuple where ``next_state`` is the
            list ``[num_shards, block_size, block_interval, malicious_ratio]``
            and ``done`` is True once the transaction budget is exhausted by
            this step.
        """
        shard_change, block_size_change = self.action_space[action]

        # Apply the increments and clamp to the configured limits.
        self.num_shards = min(self.max_shards, max(1, self.num_shards + shard_change))
        self.block_size = min(self.max_block_size, max(1, self.block_size + block_size_change))
        # NOTE(review): no action component modifies the block interval; it is
        # only clamped here. Confirm whether it was meant to be adjustable.
        self.block_interval = min(self.max_block_interval, max(1, self.block_interval))

        # Simulate the environment dynamics based on the new configuration
        tps = self.calculate_tps(self.num_shards, self.block_size, self.block_interval)
        latency = self.calculate_latency(self.block_size, self.num_shards)
        security = self.calculate_security(self.num_shards, self.malicious_ratio)

        # Compute the reward (consider TPS, security, and latency)
        reward = self.compute_reward(tps, latency, security)

        # Consume the budget first, then test it, so the step that processes
        # the last transactions is the one that reports done=True.
        self.total_transactions -= tps * self.block_interval
        done = self.total_transactions <= 0

        # Next state is the updated blockchain configuration
        next_state = [self.num_shards, self.block_size, self.block_interval, self.malicious_ratio]
        self.state = tuple(next_state)

        return (next_state, reward, done)

    def calculate_tps(self, num_shards, block_size, block_interval):
        """Return throughput: shards * (block_size / block_interval) * baseline."""
        return num_shards * (block_size / block_interval) * self.tps_baseline

    def calculate_latency(self, block_size, num_shards):
        """Return the latency estimate for the given block size and shard count.

        NOTE(review): ``num_shards`` appears in numerator and denominator, so
        the result only depends on ``block_size`` — confirm this is intended.
        """
        return (block_size * num_shards) / (self.tps_baseline * num_shards)

    def calculate_security(self, num_shards, malicious_ratio):
        """Return the security score for the given malicious-node ratio.

        The score is 0 when the ratio exceeds 1/3 and ``1 - malicious_ratio``
        otherwise. ``num_shards`` is accepted but not used.
        """
        if malicious_ratio > (1 / 3):
            return 0  # Security failure
        # Otherwise the score decreases linearly with the malicious ratio
        return 1 - malicious_ratio

    def compute_reward(self, tps, latency, security):
        """Return the reward ``tps * security - latency``."""
        # High TPS and high security raise the reward, latency lowers it
        return tps * security - latency

    def render(self):
        """Print the current configuration and the remaining transaction budget."""
        print(f"Shards: {self.num_shards}, Block Size: {self.block_size}MB, Block Interval: {self.block_interval}s")
        print(f"Malicious Ratio: {self.malicious_ratio}, Total Transactions: {self.total_transactions}")

In [27]:
# Shared environment instance used by the training loop below.
env = BlockchainEnvironment(
    num_shards=1,
    block_size=1,
    block_interval=1,
    malicious_ratio=0,
)

In [28]:
# Put the environment into its initial configuration and keep the state.
state = env.reset()

# Disabled smoke test: random rollout until the environment reports done.
# NOTE: env.step() indexes env.action_space with its argument, so it needs an
# action *index*, not one of the tuples stored in that list.
# total_step = 0

# done = False
# while not done:
#     action = random.randrange(len(env.action_space))  # Replace with DQN action later
#     next_state, reward, done = env.step(action)
#     total_step += 1
#     env.render()
#     print(f"Reward: {reward}\n")

# print(f"Total steps: {total_step}\n")

# Replay Memory


In [29]:
# One replay-buffer entry. The field order matches the positional order in
# which DQNSBAgent.remember() stores values (state, action, reward,
# next_state, done), so access by name and access by index agree.
Transition = namedtuple(
    'Transition',
    ('state', 'action', 'reward', 'next_state', 'done'),
)


class ReplayMemory:
    """Fixed-capacity buffer of transitions; the oldest entries are dropped first."""

    def __init__(self, capacity):
        # A bounded deque discards from the left once `capacity` is reached.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Build a Transition from the given values and store it."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return `batch_size` stored transitions drawn without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

# DQNSB Agent


In [30]:
class DQNSBAgent:
    """DQN agent with a main Q-network, a target Q-network and a replay buffer.

    Actions are chosen epsilon-greedily from the main network. ``replay``
    trains the main network on a sampled minibatch, using the target network
    for the bootstrap value, and then decays epsilon.
    """

    def __init__(self, state_size, action_size, memory_capacity, gamma=0.99, epsilon=1.0, epsilon_decay=0.98, epsilon_min=0.01, alpha=0.1):
        """Store the hyperparameters and build both networks.

        Args:
            state_size: Number of features in one state vector.
            action_size: Number of discrete actions (network output size).
            memory_capacity: Maximum number of transitions kept for replay.
            gamma: Discount factor.
            epsilon: Initial exploration rate.
            epsilon_decay: Multiplicative decay applied after each training step.
            epsilon_min: Lower bound for epsilon.
            alpha: Learning rate of the Adam optimizer.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.memory = ReplayMemory(memory_capacity)
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate
        self.epsilon_decay = epsilon_decay  # Decay rate for epsilon
        self.epsilon_min = epsilon_min  # Minimum epsilon
        self.alpha = alpha  # Learning rate
        self.main_q_network = self.build_q_network()
        self.target_q_network = self.build_q_network()
        # Start with identical weights in both networks.
        self.update_target_network()

    def build_q_network(self):
        """Build and compile one Q-network (main and target share this architecture)."""
        model = Sequential()
        model.add(Dense(32, input_dim=self.state_size, activation='relu'))
        model.add(Dense(32, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(optimizer=optimizers.Adam(learning_rate=self.alpha), loss='mse')
        return model

    def update_target_network(self):
        """Copy the main network's weights into the target network."""
        self.target_q_network.set_weights(self.main_q_network.get_weights())

    def act(self, state):
        """Return an action index chosen epsilon-greedily for ``state``."""
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)  # Random action
        # predict() expects a batch, so present the state as a single row.
        state = np.reshape(state, [1, self.state_size])
        q_values = self.main_q_network.predict(state, verbose=0)
        return int(np.argmax(q_values[0]))  # Action with the highest Q-value

    def remember(self, state, action, reward, next_state, done):
        """Store one transition in the replay buffer."""
        self.memory.push(state, action, reward, next_state, done)

    def get_batch_from_buffer(self, batch_size):
        """Sample ``batch_size`` transitions and return them as per-field batches.

        Returns:
            ``(state_batch, action_batch, reward_batch, next_state_batch,
            terminal_batch)``; the state batches are arrays reshaped to
            ``(batch_size, state_size)``, ``action_batch`` is an array, and
            the reward/terminal batches are lists.
        """
        exp_batch = self.memory.sample(batch_size=batch_size)
        # Transitions are stored positionally as
        # (state, action, reward, next_state, done) by remember().
        states, actions, rewards, next_states, terminals = zip(*exp_batch)
        state_batch = np.array(states).reshape(batch_size, self.state_size)
        action_batch = np.array(actions)
        reward_batch = list(rewards)
        next_state_batch = np.array(next_states).reshape(batch_size, self.state_size)
        terminal_batch = list(terminals)
        return state_batch, action_batch, reward_batch, next_state_batch, terminal_batch

    def compute_targets(self, q_values, action_batch, reward_batch, max_next_q, terminal_batch):
        """Return a copy of ``q_values`` with the Bellman target written per sample.

        For sample ``i`` only the entry of the action that was taken is
        replaced: by ``reward`` for terminal transitions and by
        ``reward + gamma * max_next_q`` otherwise.
        """
        targets = np.array(q_values, copy=True)
        rewards = np.asarray(reward_batch, dtype=np.float64)
        not_terminal = 1.0 - np.asarray(terminal_batch, dtype=np.float64)
        bellman = rewards + self.gamma * np.asarray(max_next_q, dtype=np.float64) * not_terminal
        rows = np.arange(len(rewards))
        targets[rows, np.asarray(action_batch, dtype=np.intp)] = bellman
        return targets

    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay`` without going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def replay(self, batch_size):
        """Train the main network on one minibatch sampled from the replay buffer.

        Does nothing until the buffer holds at least ``batch_size`` transitions.
        """
        if len(self.memory) < batch_size:
            return

        state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = self.get_batch_from_buffer(batch_size=batch_size)

        # Current Q-values from the main network
        q_values = self.main_q_network.predict(state_batch, verbose=0)

        # Max Q-value of the successor state, taken from the target network
        next_q_values = self.target_q_network.predict(next_state_batch, verbose=0)
        max_next_q = np.amax(next_q_values, axis=1)

        targets = self.compute_targets(q_values, action_batch, reward_batch, max_next_q, terminal_batch)

        self.main_q_network.fit(state_batch, targets, verbose=0)

        # Shift gradually from exploration to exploitation as training proceeds.
        self.decay_epsilon()

    def load(self, name):
        """Load main-network weights from the file ``name``."""
        self.main_q_network.load_weights(name)

    def save(self, name):
        """Save main-network weights to the file ``name``."""
        self.main_q_network.save_weights(name)

# Training

## Hyperparameters and utilities


In [31]:
# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the optimizer
# NOTE(review): only BATCH_SIZE is consumed by the training loop. The agent is
# built with its own defaults (gamma, epsilon, epsilon_decay, epsilon_min,
# alpha), so GAMMA, EPS_*, TAU and LR currently have no effect.
BATCH_SIZE = 64
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Define the number of epochs (episodes) and the maximum steps per episode
epochs = 500
max_steps = 1000

# Example usage
state_size = 4  # num_shards, block_size, block_interval, malicious_ratio
# One Q-value per entry of the environment's action list, so every action
# can be valued and selected by the agent.
action_size = len(env.action_space)

agent = DQNSBAgent(state_size, action_size, memory_capacity=20000)

# Number of steps each finished episode lasted (consumed by plot_durations).
episode_durations = []

def plot_durations(show_result=False):
    """Draw episode durations on figure 1, plus a 100-episode running mean.

    With ``show_result`` False the figure is cleared first and titled
    'Training...'; with True it is titled 'Result'. Under the inline IPython
    backend the figure is also pushed to the notebook output.
    """
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)

    if not show_result:
        plt.clf()
    plt.title('Result' if show_result else 'Training...')

    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())

    # Running mean over windows of 100 episodes, left-padded with zeros so it
    # lines up with the episode axis.
    window = 100
    if len(durations_t) >= window:
        windowed_means = durations_t.unfold(0, window, 1).mean(1).view(-1)
        padded_means = torch.cat((torch.zeros(window - 1), windowed_means))
        plt.plot(padded_means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated

    if not is_ipython:
        return
    display.display(plt.gcf())
    if not show_result:
        display.clear_output(wait=True)

## Training loop


In [None]:
# Total number of environment steps taken across all epochs.
total_time_step = 0

for epoch in range(epochs):
    # Each epoch is one episode starting from a freshly reset environment.
    state = env.reset()
    state = np.reshape(state, [1, state_size])
    for step in range(max_steps):
        # Agent selects an action
        action = agent.act(state=state)
        # Execute the action in the environment and observe the outcome
        next_state, reward, done = env.step(action=action)
        next_state = np.reshape(next_state, [1, state_size])
        total_time_step += 1

        # Store the experience in replay memory
        agent.remember(state=state, action=action, reward=reward, next_state=next_state, done=done)

        # Move to the next state
        state = next_state
        # Perform replay to train the agent
        agent.replay(BATCH_SIZE)

        # Check if the episode has ended
        if done:
            episode_durations.append(step + 1)
            plot_durations()
            break

    # Update the target network periodically
    if epoch % 10 == 0:
        agent.update_target_network()

    print(f"Epoch: {epoch}, Epsilon: {agent.epsilon:.4f}")

# Save weights. Keras 3 `save_weights` only accepts file names ending in
# ".weights.h5"; older Keras versions accept this name as an HDF5 file too.
agent.save("DQNSB_agent.weights.h5")

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()