In [None]:
## try adding memory to these agents 

import random 
import numpy as np 
import gym
import torch
import torch.nn as nn
import torch.optim as optim 

class Object:
    """Empty attribute container; ``ReplayBuffer.sample`` attaches its batch tensors to one."""

class ReplayBuffer():
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions.

    Every field is converted to a torch tensor and kept in one of five parallel
    lists (``state_list``, ``action_list``, ``reward_list``, ``next_state_list``,
    ``done_list``). Once ``capacity`` transitions are held, each new one
    overwrites the oldest slot (ring buffer). This makes ``add`` O(1) instead of
    re-copying every list on each insertion. List order therefore reflects slot
    position, not insertion order.
    """

    def __init__(self, capacity=10000):
        """Create an empty buffer holding at most ``capacity`` transitions.

        Raises:
            ValueError: if ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError(f"capacity must be at least 1, got {capacity}")
        self.capacity = capacity
        self.clear()

    def __len__(self):
        return self.n

    @staticmethod
    def _to_tensor(value):
        """Return a tensor copy of ``value`` that is detached from any autograd graph."""
        if isinstance(value, torch.Tensor):
            # torch.tensor(<Tensor>) emits a copy-construct UserWarning. detach().clone()
            # is the recommended copy and drops any graph attached to e.g. actor outputs.
            return value.detach().clone()
        return torch.tensor(value)

    def _lists(self):
        """Return the five per-field storage lists, in transition-field order."""
        return (self.state_list, self.action_list, self.reward_list,
                self.next_state_list, self.done_list)

    def add(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest one when the buffer is full."""
        fields = [self._to_tensor(v) for v in (state, action, reward, next_state, done)]
        if self.n < self.capacity:
            for store, item in zip(self._lists(), fields):
                store.append(item)
            self.n += 1
        else:
            # Full: self._next points at the oldest transition; replace it in place.
            for store, item in zip(self._lists(), fields):
                store[self._next] = item
        self._next = (self._next + 1) % self.capacity

    def sample(self, batch_size=32):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            An ``Object`` with attributes ``state``, ``action`` and ``next_state``,
            each holding the stacked tensors of that field. It also has ``reward``
            and ``done``, which are stacked and then reshaped to ``[-1, 1]`` columns.

        Raises:
            ValueError: if the buffer is empty.
        """
        if self.n == 0:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        idxs = [random.randrange(self.n) for _ in range(batch_size)]
        out = Object()  ## transitions
        out.state = torch.stack([self.state_list[i] for i in idxs])
        out.action = torch.stack([self.action_list[i] for i in idxs])
        out.reward = torch.stack([self.reward_list[i] for i in idxs]).reshape([-1, 1])
        out.next_state = torch.stack([self.next_state_list[i] for i in idxs])
        out.done = torch.stack([self.done_list[i] for i in idxs]).reshape([-1, 1])
        return out

    def clear(self):
        """Drop every stored transition."""
        self.n = 0  # number of transitions currently stored
        self._next = 0  # ring-buffer slot that add() writes once the buffer is full
        self.state_list = []
        self.action_list = []
        self.reward_list = []
        self.next_state_list = []
        self.done_list = []

# Define the actor and critic networks 
class SSRAgent(nn.Module):
    """Shared ``nn.Module`` base for the actor and critic, reserving ``ssr_*`` memory slots.

    Only ``ssr_rank`` is taken from the constructor. The remaining ``ssr_*``
    attributes start as ``None``, and nothing in this class fills them yet.
    """

    def __init__(self, ssr_rank=2):
        super().__init__()
        self.ssr_rank = ssr_rank
        # Placeholder state. The names suggest a low-rank-plus-diagonal memory
        # model, but it is not implemented yet (TODO: confirm intended design).
        self.ssr_low_rank_matrix = self.ssr_diagonal = self.ssr_model_dimension = None

    def memorize(self):
        """Placeholder for the memory update; currently a no-op returning ``None``."""
        # TODO: implement
        return None

class Actor(SSRAgent):
    """Policy network mapping a state to a tanh-squashed output in [-1, 1].

    Layers: state_dim -> 256 -> 128 -> action_dim, with ReLU after each hidden layer.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Layers are created in fc1, fc2, fc3 order so weight init draws the RNG the same way.
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, action_dim)

    def forward(self, state):
        hidden = state
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return torch.tanh(self.fc3(hidden))

class Critic(SSRAgent):
    """Value network mapping a (state, action) pair to a single scalar estimate.

    State and action are concatenated along dim 1, so both must be at least
    2-D with features on dim 1. The trunk is (state_dim + action_dim) -> 256 ->
    128 -> 1, with ReLU after each hidden layer and a linear output.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, state, action):
        joint = torch.cat((state, action), dim=1)
        h = torch.relu(self.fc1(joint))
        h = torch.relu(self.fc2(h))
        return self.fc3(h)

def _env_reset(environment):
    """Reset ``environment`` and return only the observation (old and new gym APIs)."""
    out = environment.reset()
    # gym>=0.26 returns (obs, info); older versions return obs alone.
    return out[0] if isinstance(out, tuple) else out


def _env_step(environment, action):
    """Step ``environment``; return (obs, reward, terminated, truncated) for old and new gym APIs."""
    out = environment.step(action)
    if len(out) == 5:  # gym>=0.26: (obs, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, _ = out
        return obs, reward, bool(terminated), bool(truncated)
    obs, reward, done, info = out
    # Older gym's TimeLimit wrapper marks time-limit cut-offs in info.
    truncated = bool(info.get('TimeLimit.truncated', False))
    return obs, reward, bool(done) and not truncated, truncated


def _action_to_int(action):
    """Map the actor's tanh output in [-1, 1] to a Bernoulli draw in {0, 1}."""
    action_p = action.item() * .5 + .5
    return np.random.binomial(1, action_p)


# Create the actor and critic networks
actor = Actor(state_dim=4, action_dim=1)
critic = Critic(state_dim=4, action_dim=1)
target_actor = Actor(state_dim=4, action_dim=1)
target_critic = Critic(state_dim=4, action_dim=1)

# Define the optimizers
actor_optimizer = optim.Adam(actor.parameters(), lr=1e-5)
critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)

# Create the replay buffer
replay_buffer = ReplayBuffer(capacity=100000)

# Define the environment
env = gym.make('CartPole-v1')

reward_list = []  # evaluation return of each episode
p_list = []  # DEBUG: raw actor outputs seen during evaluation
TAU = -1  ## 0.05 -- currently unused: targets are hard-copied once per episode
GAMMA = 0.99
MIN_BUFFER_SIZE = 1000  # start gradient updates once the buffer holds more than this
BATCH_SIZE = 256

# Train the agent
for episode in range(1000):
    state = _env_reset(env)
    # Hard target update: copy the online weights into the target networks.
    target_actor.load_state_dict(actor.state_dict())
    target_critic.load_state_dict(critic.state_dict())

    for t in range(1000):
        with torch.no_grad():  # acting needs no autograd graph
            action = actor(torch.tensor(state))
        # Exploration: the probability of a uniform random action decays linearly to 0 over 50 episodes.
        if np.random.binomial(1, max(0, 50 - episode) / 50) > 0:
            action = torch.tensor(np.random.uniform(low=-1., high=1.)).reshape([1])

        action_int = _action_to_int(action)  ## must be 0 or 1
        next_state, reward, terminated, truncated = _env_step(env, action_int)

        # Store only true terminations as `done` so time-limit cut-offs still bootstrap.
        replay_buffer.add(state, action, reward, next_state, terminated)

        if len(replay_buffer) > MIN_BUFFER_SIZE:
            # Sample a batch of transitions from the replay buffer
            transitions = replay_buffer.sample(batch_size=BATCH_SIZE)

            # TD target r + GAMMA * (1 - done) * Q'(s', mu'(s')); no gradient flows into it.
            with torch.no_grad():
                next_Q = target_critic(transitions.next_state, target_actor(transitions.next_state))
                target_Q = (1 - transitions.done.int()) * next_Q * GAMMA + transitions.reward

            # Critic update: mean squared TD error.
            current_Q = critic(transitions.state, transitions.action)
            critic_loss = torch.mean((target_Q - current_Q).pow(2))
            critic_optimizer.zero_grad()
            critic_loss.backward()
            critic_optimizer.step()

            # Actor update: maximise the critic's value of the actor's own actions.
            # This backward also writes grads into the critic's parameters; they are
            # discarded by critic_optimizer.zero_grad() before the next critic step.
            actor_loss = -torch.mean(critic(transitions.state, actor(transitions.state)))
            actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_optimizer.step()

        state = next_state

        if terminated or truncated:
            break

    # Evaluate the agent (policy sampling only, no extra exploration)
    episode_reward = 0
    state = _env_reset(env)

    for t in range(1000):
        with torch.no_grad():
            action = actor(torch.tensor(state))

        action_int = _action_to_int(action)
        next_state, reward, terminated, truncated = _env_step(env, action_int)
        p_list.append(action.item())  # DEBUG

        episode_reward += reward
        state = next_state

        if terminated or truncated:
            break

    reward_list.append(episode_reward)
    print(f'Episode {episode}: {episode_reward}, len(replay_buffer): {len(replay_buffer)}')