In [1]:
import gym
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import matplotlib.pyplot as plt
from CSTREnv import cstr_env

In [2]:
# Neural network for NAF
class NAFNetwork(nn.Module):
    def __init__(self, state_size, action_size):
        super(NAFNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_value = nn.Linear(256, 1)
        self.fc_mu = nn.Linear(256, action_size, bias=False)
        self.fc_l = nn.Linear(256, action_size * action_size)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        value = self.fc_value(x)
        mu = torch.tanh(self.fc_mu(x))*torch.tensor([1, 0.0167])
        l = self.fc_l(x)

        l_matrix = l.view(-1, action_size, action_size)
        l_matrix = torch.tril(l_matrix, -1) + torch.diag_embed(torch.exp(torch.diagonal(l_matrix, dim1=-2, dim2=-1)))
        p_matrix = torch.bmm(l_matrix, l_matrix.transpose(2, 1))

        return value, mu, p_matrix

    def q_value(self, state, action):
        value, mu, p_matrix = self.forward(state)
        action_diff = action - mu
        advantage = -0.5 * torch.bmm(action_diff.unsqueeze(1), torch.bmm(p_matrix, action_diff.unsqueeze(2))).squeeze(2)
        q_value = value + advantage
        return q_value

In [3]:
def train_model(epoch):
    if epoch < 50:
        return 

    minibatch = memory.sample(batch_size)
    states = torch.FloatTensor([e[0] for e in minibatch])
    actions = torch.FloatTensor([e[1] for e in minibatch])
    rewards = torch.FloatTensor([e[2] for e in minibatch])
    next_states = torch.FloatTensor([e[3] for e in minibatch])
    dones = torch.FloatTensor([e[4] for e in minibatch])

    q_values = naf_network.q_value(states, actions)
    next_actions = target_naf_network(next_states)[1]
    next_q_values = target_naf_network.q_value(next_states, next_actions)
    target_q_values = rewards.unsqueeze(1) + (1 - dones).unsqueeze(1) * discount_factor * next_q_values

    loss = loss_fn(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    for target_param, param in zip(target_naf_network.parameters(), naf_network.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

    return loss

In [4]:
# Replay buffer
class ReplayBuffer:
    def __init__(self, size):
        self.memory = deque(maxlen=size)
    
    def add(self, experience):
        self.memory.append(experience)
    
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)
    
def get_action(e, state, epsilon):
    if np.random.rand() < epsilon:
        return env.sontag(state)
    state = torch.FloatTensor(state).unsqueeze(0)
    _, mu, _ = naf_network(state)
    return mu.detach().numpy()[0]

def plot_state(states, epoch, filename='./figure2/state_at_epoch_'):
    figure_state = np.array(states)
    time = np.linspace(0, len(figure_state), len(figure_state))

    plt.figure()
    plt.plot(time, figure_state[:, 0], label='State 1')
    plt.plot(time, figure_state[:, 1], label='State 2')

    plt.xlabel('Time')
    plt.ylabel('State Value')
    plt.title('State Values Over Time')
    plt.legend()
    
    plt.savefig(filename+f'{epoch}')
    plt.close() 

In [5]:
# Hyperparameters
learning_rate = 1e-4
discount_factor = 0.99
batch_size = 256
tau = 0.001
epsilon_decay = 0.995
epsilon_min = 0.01
memory_size = 1000000
num_episodes = 3000
episode_length = num_episodes

In [None]:
memory = ReplayBuffer(memory_size)
epsilon = 1.0
# Set up the environment
env = cstr_env(order=1)
state_size = env.observation_space.shape[0]
action_size = env.action_space.shape[0]

# Initialize networks and optimizer
naf_network = NAFNetwork(state_size, action_size)
target_naf_network = NAFNetwork(state_size, action_size)
target_naf_network.load_state_dict(naf_network.state_dict())
target_naf_network.eval()

optimizer = optim.Adam(naf_network.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()

for e in range(num_episodes):
    states_list = []
    state = env.reset(e)
    total_reward = 0
    done = False
    step = 0
    while not done:
        step += 1
        action = get_action(e, state, epsilon)
        next_state, reward, done = env.step(action)
        total_reward += reward

        memory.add((state, action, reward, next_state, done))
        state = next_state
        states_list.append(state)
        loss = train_model(e)
        if step == 1000:
            break
            
    if epsilon > epsilon_min:
        epsilon *= epsilon_decay

    if e%5 == 0:
        plot_state(states_list, e)
    print(f"Episode: {e+1}/{num_episodes}, Reward: {total_reward}, done: {done}")
    print(f"state: {state}, action: {action}, loss: {loss}")

env.close()

Episode: 1/3000, Reward: -5.682317432610404, done: True
state: [ 0.00045985 -0.00919278], action: [-0.00088771  0.        ], loss: None
Episode: 2/3000, Reward: -5.683502981904332, done: True
state: [ 0.00046099 -0.00921771], action: [-0.00088903  0.        ], loss: None
Episode: 3/3000, Reward: -5.684329110255615, done: True
state: [ 0.00046173 -0.00925611], action: [-0.00085705  0.        ], loss: None
Episode: 4/3000, Reward: -5.6840365237122965, done: True
state: [ 0.0003819  -0.00961781], action: [0.00051848 0.        ], loss: None
Episode: 5/3000, Reward: -5.685062188390357, done: True
state: [ 0.00041522 -0.00943051], action: [3.7460573e-08 0.0000000e+00], loss: None
Episode: 6/3000, Reward: -5.682901601132929, done: True
state: [ 0.00043382 -0.00942651], action: [-7.3842806e-05  0.0000000e+00], loss: None
Episode: 7/3000, Reward: -5.6856218150024596, done: True
state: [ 0.00045978 -0.00923548], action: [-0.0008224  0.       ], loss: None
Episode: 8/3000, Reward: -5.686920863157

  states = torch.FloatTensor([e[0] for e in minibatch])


Episode: 51/3000, Reward: -5.697804946007159, done: False
state: [ 0.00047954 -0.01511557], action: [7.7906307e-03 1.1287560e-06], loss: 5.440303084469633e-06
Episode: 52/3000, Reward: -5.688273172407396, done: False
state: [ 0.00060438 -0.02015514], action: [1.3538405e-02 6.6164589e-06], loss: 2.9359673590079183e-06
Episode: 53/3000, Reward: -5.747881299508355, done: False
state: [ 0.00017552 -0.01850214], action: [-0.02601359 -0.00287534], loss: 1.3305698303156532e-06
Episode: 54/3000, Reward: -5.743828785311742, done: False
state: [ 0.00025797 -0.01411122], action: [-0.02448335 -0.00371303], loss: 7.929840535325638e-07
Episode: 55/3000, Reward: -5.726282134912097, done: False
state: [ 0.00034632 -0.01325198], action: [-0.02318714 -0.00435577], loss: 4.82688847114332e-07
Episode: 56/3000, Reward: -5.715641372808681, done: False
state: [ 0.00052301 -0.0126873 ], action: [0.00036337 0.        ], loss: 3.0846405252304976e-07
Episode: 57/3000, Reward: -5.705671387921439, done: False
stat