In [1]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import matplotlib.pyplot as plt
from CSTREnv import cstr_env

In [2]:
# Neural network for NAF
class NAFNetwork(nn.Module):
    def __init__(self, state_size, action_size):
        super(NAFNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_value = nn.Linear(256, 1)
        self.fc_mu = nn.Linear(256, action_size)
        self.fc_l = nn.Linear(256, action_size * action_size)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        value = self.fc_value(x)
        mu = torch.tanh(self.fc_mu(x))*torch.tensor([1, 0.167])
        l = self.fc_l(x)

        l_matrix = l.view(-1, action_size, action_size)
        l_matrix = torch.tril(l_matrix, -1) + torch.diag_embed(torch.exp(torch.diagonal(l_matrix, dim1=-2, dim2=-1)))
        p_matrix = torch.bmm(l_matrix, l_matrix.transpose(2, 1))

        return value, mu, p_matrix

    def q_value(self, state, action):
        value, mu, p_matrix = self.forward(state)
        action_diff = action - mu
        advantage = -0.5 * torch.bmm(action_diff.unsqueeze(1), torch.bmm(p_matrix, action_diff.unsqueeze(2))).squeeze(2)
        q_value = value + advantage
        return q_value

In [3]:
# Replay buffer
class ReplayBuffer:
    def __init__(self, size):
        self.memory = deque(maxlen=size)
    
    def add(self, experience):
        self.memory.append(experience)
    
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)

In [4]:
def train_model(epoch):
    if epoch < 50:
        return 

    minibatch = memory.sample(batch_size)
    states = torch.FloatTensor([e[0] for e in minibatch])
    actions = torch.FloatTensor([e[1] for e in minibatch])
    rewards = torch.FloatTensor([e[2] for e in minibatch])
    next_states = torch.FloatTensor([e[3] for e in minibatch])
    dones = torch.FloatTensor([e[4] for e in minibatch])

    q_values = naf_network.q_value(states, actions)
    next_actions = target_naf_network(next_states)[1]
    next_q_values = target_naf_network.q_value(next_states, next_actions)
    target_q_values = rewards.unsqueeze(1) + (1 - dones).unsqueeze(1) * discount_factor * next_q_values

    loss = loss_fn(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    for target_param, param in zip(target_naf_network.parameters(), naf_network.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

In [5]:
def get_action(state, epsilon):
    if np.random.rand() < epsilon:
        return env.action_space.sample()
    state = torch.FloatTensor(state).unsqueeze(0)
    _, mu, _ = naf_network(state)
    return mu.detach().numpy()[0]

In [6]:
# Hyperparameters
learning_rate = 1e-3
discount_factor = 0.99
batch_size = 256
tau = 0.001
epsilon_decay = 0.995
epsilon_min = 0.01
memory_size = 1000000
num_episodes = 3000
episode_length = num_episodes

In [None]:
memory = ReplayBuffer(memory_size)
epsilon = 1.0
# Set up the environment
env = cstr_env(order=1)
state_size = env.observation_space.shape[0]
action_size = env.action_space.shape[0]

# Initialize networks and optimizer
naf_network = NAFNetwork(state_size, action_size)
target_naf_network = NAFNetwork(state_size, action_size)
target_naf_network.load_state_dict(naf_network.state_dict())
target_naf_network.eval()

optimizer = optim.Adam(naf_network.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()

for e in range(num_episodes):
    state = env.reset()
    total_reward = 0
    done = False
    step = 0
    while not done:
        step += 1
        action = get_action(state, epsilon)
        next_state, reward, done = env.step(action)
        total_reward += reward

        memory.add((state, action, reward, next_state, done))
        state = next_state
        train_model(e)
        if step == 300:
            break
            
    if epsilon > epsilon_min:
        epsilon *= epsilon_decay

    print(f"Episode: {e+1}/{num_episodes}, Reward: {total_reward}, done: {done}")
    print(f"state: {state}, action: {action}")

env.close()

Episode: 1/3000, Reward: -12.33353226999207, done: False
state: [  0.12322117 -11.22304723], action: [-0.735866   -0.08041313]
Episode: 2/3000, Reward: -6.630148423835987, done: False
state: [ 0.10490079 -6.00923733], action: [0.7513824  0.06663113]
Episode: 3/3000, Reward: -13.466535160302872, done: False
state: [  0.1093568  -12.41413409], action: [-0.539509   -0.01864239]
Episode: 4/3000, Reward: -12.733531425899884, done: False
state: [  0.12022271 -11.89198817], action: [0.05764126 0.10464754]
Episode: 5/3000, Reward: -3.4171713515843583, done: False
state: [ 0.07763018 -0.7778555 ], action: [-0.8969545  -0.15452193]
Episode: 6/3000, Reward: -12.9601971457035, done: False
state: [  0.1200199  -11.86461866], action: [-0.60460764  0.0614411 ]
Episode: 7/3000, Reward: -10.295431334740979, done: False
state: [ 0.11998705 -9.68283004], action: [ 0.44213656 -0.04603828]
Episode: 8/3000, Reward: -12.416522313798108, done: False
state: [  0.12731271 -11.19483099], action: [ 0.15747249 -0.

  states = torch.FloatTensor([e[0] for e in minibatch])


Episode: 51/3000, Reward: -6.554611512628095, done: False
state: [ 0.10614925 -6.43374768], action: [ 0.9866632  -0.07201547]
Episode: 52/3000, Reward: -5.0946024256550775, done: False
state: [ 0.08005541 -4.55694003], action: [-0.26343143  0.08051584]
Episode: 53/3000, Reward: -5.113306455647635, done: False
state: [ 0.09822194 -3.61725901], action: [-0.9388875   0.01176811]
Episode: 54/3000, Reward: -5.971054302680004, done: False
state: [ 0.111601  -5.2452245], action: [-0.9600574   0.11023979]
Episode: 55/3000, Reward: -16.510499666474768, done: False
state: [  0.12540351 -14.23357564], action: [ 0.27113354 -0.01803673]
Episode: 56/3000, Reward: -11.088267074095997, done: False
state: [  0.09420266 -10.93522325], action: [ 0.06608845 -0.01484043]
Episode: 57/3000, Reward: -16.05736620795416, done: False
state: [  0.12377315 -13.83868786], action: [0.8266247  0.06982926]
Episode: 58/3000, Reward: -15.555090771356035, done: False
state: [  0.1671146  -12.85796134], action: [ 0.210973