In [1]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import matplotlib.pyplot as plt
from CSTREnv import cstr_env

In [2]:
# Set up the environment
env = cstr_env(order=2)
state_size = env.observation_space.shape[0]
action_size = 4*4

# Hyperparameters
learning_rate = 0.001
discount_factor = 0.99
batch_size = 64
tau = 0.001
epsilon_decay = 0.995
epsilon_min = 0.01
memory_size = int(1e6)
update_target_network_steps = 1000
u1 = np.linspace(-1, 1, 4)
u2 = np.linspace(-0.0167, 0.0167, 4)
XX, YY = np.meshgrid(u1, u2)
XX, YY = XX.flatten(), YY.flatten()

In [3]:
# Replay buffer
class ReplayBuffer:
    def __init__(self, size):
        self.memory = deque(maxlen=size)
    
    def add(self, experience):
        self.memory.append(experience)
    
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)

In [4]:
# Neural network for Q-learning
class QNetwork(nn.Module):
    def __init__(self, state_size, action_size):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, action_size)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

In [5]:
# Initialize networks and optimizer
q_network = QNetwork(state_size, action_size)
target_q_network = QNetwork(state_size, action_size)
target_q_network.load_state_dict(q_network.state_dict())
target_q_network.eval()

optimizer = optim.Adam(q_network.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()

In [6]:
def get_action(state, epsilon):
    if np.random.rand() < epsilon:
        return random.randint(0, action_size - 1)
    state = torch.FloatTensor(state).unsqueeze(0)
    q_values = q_network(state)
    return torch.argmax(q_values).item()

def train_model():
    if len(memory) < batch_size:
        return

    minibatch = memory.sample(batch_size)
    states = torch.FloatTensor([e[0] for e in minibatch])
    actions = torch.FloatTensor([e[1] for e in minibatch])
    rewards = torch.FloatTensor([e[2] for e in minibatch])
    next_states = torch.FloatTensor([e[3] for e in minibatch])
    dones = torch.FloatTensor([e[4] for e in minibatch])

    # Compute target Q-values
    next_q_values = target_q_network(next_states).detach().max(1)[0]
    target_q_values = rewards + (1 - dones) * discount_factor * next_q_values

    # Compute predicted Q-values
    predicted_q_values = q_network(states).gather(1, actions.unsqueeze(1).to(torch.int64)).squeeze()

    loss = loss_fn(predicted_q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [7]:
# Main loop
num_episodes = 2000
steps = 0
epsilon = 1.0
memory = ReplayBuffer(memory_size)

for e in range(num_episodes):
    state = env.reset()
    total_reward = 0
    done = False
    step = 0
    
    while not done:
        action = get_action(state, epsilon)
        u = [XX[action], YY[action]]
        next_state, reward, done = env.step(u)
        total_reward += reward

        memory.add((state, action, reward, next_state, done))
        state = next_state
        train_model()

        steps += 1
        step += 1
        if steps % update_target_network_steps == 0:
            target_q_network.load_state_dict(q_network.state_dict())
        if step == 500 or done:
            break

    if epsilon > epsilon_min:
        epsilon *= epsilon_decay

    print(f"Episode: {e+1}/{num_episodes}, Reward: {total_reward}, done: {done}")
    print(f"state: {state}, action: {u}")

env.close()

  states = torch.FloatTensor([e[0] for e in minibatch])


Episode: 1/2000, Reward: -269.3275707655591, done: False
state: [-0.04255057  0.03617751], action: [-1.0, -0.005566666666666666]
Episode: 2/2000, Reward: -282.6987037532213, done: False
state: [0.03791183 0.06788833], action: [1.0, 0.0167]
Episode: 3/2000, Reward: -278.1738421161644, done: False
state: [-0.03782443 -0.02717348], action: [1.0, -0.005566666666666666]
Episode: 4/2000, Reward: -266.6502809484551, done: False
state: [-0.03316798 -0.02338236], action: [0.33333333333333326, -0.005566666666666666]
Episode: 5/2000, Reward: -278.2260729852215, done: False
state: [0.15584933 0.19163795], action: [0.33333333333333326, -0.0167]
Episode: 6/2000, Reward: -255.9943126999232, done: False
state: [0.05358835 0.1469392 ], action: [-0.33333333333333337, -0.0167]
Episode: 7/2000, Reward: -271.97816067540896, done: False
state: [-0.05889323 -0.11649244], action: [1.0, 0.005566666666666668]
Episode: 8/2000, Reward: -260.4572607129579, done: False
state: [0.03045894 0.04978906], action: [-0.33