In [1]:
import torch
from torch import nn
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np 
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
from Buffers.ExperienceReplayBuffer import ExperienceReplay

<center> <h1> Constants 

In [2]:
SEED = 13 
GAMMA = 0.995
LR = 1e-4
STEPS = 1000
BATCH_SIZE = 64
EPISODES = 300

In [3]:
env = gym.make("Pendulum-v1",render_mode = "rgb_array")
env = RecordVideo(env,"../Results/PolicyBased",lambda x: x%25 == 0 and x != 0 , fps=15 )
actionNum = env.action_space.shape[0] 
stateNum = env.observation_space.shape[0]
print((actionNum,stateNum))
env.action_space.high[0]

  logger.warn(


(1, 3)


2.0

In [4]:
import time
env.reset()
for steps in range(1):
    action = env.action_space.sample()
    _,rewards,terminated,truncated,_=env.step(action)
    print(rewards)
    time.sleep(0.1)
env.close()

-2.4867359500448916


In [5]:
class Critic(nn.Module):
    def __init__(self,stateNum,h1,h2,h3,actionNum):
        super().__init__()
        self.fc1 = nn.Linear(stateNum + actionNum,h1)
        self.fc2 = nn.Linear(h1,h2)
        self.fc3 = nn.Linear(h2,h3)
        self.fc4 = nn.Linear(h3,1)
    def forward(self,state,action):
        x = torch.hstack((state,action))
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        Q = self.fc4(x)
        return Q

In [6]:
class Actor(nn.Module):
    def __init__(self,stateNum,h1,h2,h3,actionNum):
        super().__init__()
        self.fc1 = nn.Linear(stateNum,h1)
        self.fc2 = nn.Linear(h1,h2)
        self.fc3 = nn.Linear(h2,h3)
        self.fc4 = nn.Linear(h3,actionNum)

    def forward(self,x):

        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        A = torch.tanh(self.fc4(x))*env.action_space.high[0]
        return A

In [7]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
critic        = Critic(stateNum,400,200,100,actionNum).to(device)
actor         = Actor(stateNum,400,200,100,actionNum).to(device)
critic_target = Critic(stateNum,400,200,100,actionNum).to(device)
actor_target  = Actor(stateNum,400,200,100,actionNum).to(device)
critic_target.load_state_dict(critic.state_dict())
actor_target.load_state_dict(actor.state_dict())

critic_criterion = nn.MSELoss()
# actor_criterion = 
critic_optim =  Adam(critic.parameters(),1e-3)
actor_optim= Adam(actor.parameters(),1e-4)

buffer = ExperienceReplay(100000,device)



In [8]:
def soft_update(online,target,tau):
    online_dict = online.state_dict()
    target_dict = target.state_dict()
    for key in online_dict.keys():
        target_dict[key]= tau*online_dict[key] + (1-tau) * target_dict[key]
    target.load_state_dict(target_dict)

In [34]:
tau = 0.005
rewards = []

for episode in range(EPISODES):
    old_observation,info =env.reset(seed=episode)
    cumulative_reward = 0
    for step in range(STEPS):
        action = actor(torch.Tensor(old_observation).reshape(1,-1).to(device))

        exploration_noise = np.random.normal(0,0.2,env.action_space.shape[0]) * (1 - episode/EPISODES)
        action = np.clip(action.cpu().detach().numpy().squeeze() +exploration_noise,-2,2)
        old_observation=old_observation.squeeze()
        new_observation,reward,terminated,truncated,info = env.step(action)
        cumulative_reward+=reward
        done = terminated or truncated
        buffer.append(old_observation,action,reward,new_observation,done)
        rewards.append(reward)
        old_observation = new_observation
        t = time.time()
        if(buffer.size() >= BATCH_SIZE):
            old_state,old_action,reward,new_state,done=buffer.sample(BATCH_SIZE)
            reward = reward.reshape(-1,1)
            done = done.reshape(-1,1)   
            
            y_hat = critic(old_state,old_action)
            with torch.no_grad():
                y = reward + GAMMA * critic_target(new_state,actor_target(new_state)) * (1.0 - done)
            
            
            critic_loss =critic_criterion(y_hat,y)    
            critic_optim.zero_grad()
            critic_loss.backward()
            critic_optim.step()
            actor_optim.zero_grad()
            actor_loss = -critic(old_state,actor(old_state)).mean()
            actor_loss.backward()
            actor_optim.step()

            with torch.no_grad():
                soft_update(critic,critic_target,tau)
                soft_update(actor,actor_target,tau)
        if(truncated or terminated):
            break;
        # print(t-time.time())
    
    print(f"Episode: {episode} | Reward: {cumulative_reward},actor_loss: {actor_loss.detach().item()},critic_loss: {critic_loss.detach().item()}")    

Episode: 0 | Reward: -1342.2308469593422,actor_loss: 10.028579711914062,critic_loss: 0.45074382424354553
Episode: 1 | Reward: -1316.4180595511482,actor_loss: 14.830223083496094,critic_loss: 1.3838139772415161
Episode: 2 | Reward: -1664.31433653305,actor_loss: 18.672283172607422,critic_loss: 0.0877743512392044
Episode: 3 | Reward: -1732.3223938227325,actor_loss: 26.373443603515625,critic_loss: 0.9488617181777954
Episode: 4 | Reward: -1663.5815557269818,actor_loss: 29.62407684326172,critic_loss: 0.27288928627967834
Episode: 5 | Reward: -1649.0506992781493,actor_loss: 38.7570686340332,critic_loss: 16.6285400390625
Episode: 6 | Reward: -1406.5946312300384,actor_loss: 45.174339294433594,critic_loss: 0.4367789328098297
Episode: 7 | Reward: -1490.812884062304,actor_loss: 53.110267639160156,critic_loss: 0.542688250541687
Episode: 8 | Reward: -1359.9055779573364,actor_loss: 59.282325744628906,critic_loss: 3.0531716346740723
Episode: 9 | Reward: -1368.2323066764293,actor_loss: 62.6337776184082,c

In [38]:
env = gym.make("Pendulum-v1",render_mode="human")
old_observation,_ =env.reset()
for steps in range(STEPS):
    old_observation = torch.Tensor(old_observation).to(device)
    action = actor(old_observation).cpu().detach().numpy()
    old_observation,_,terminated,truncated,_=env.step(action)
    if(truncated or terminated):
        break
env.close()

In [36]:
env.close()