In [1]:
import torch
from torch import nn
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np 
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
from Buffers.ExperienceReplayBuffer import ExperienceReplay

<center> <h1> Constants

In [2]:
SEED = 13 
GAMMA = 0.995
STEPS = 1000
EPISODES = 300
BATCH_SIZE = 64

CRITIC_LR = 1e-3
ACTOR_LR = 1e-4
POLICY_DELAY = 2
BUFFER_SIZE = 100000

#soft update parameter
TAU = 0.005

#actionNoise parameter
NOISE_STD = 0.2
NOISE_MEAN = 0

In [3]:
env = gym.make("Pendulum-v1",render_mode = "rgb_array")
# env = RecordVideo(env,"../Results/PolicyBased",lambda x: x%25 == 0 and x != 0 , fps=15 )
actionNum = env.action_space.shape[0] 
stateNum = env.observation_space.shape[0]
print((actionNum,stateNum))
actionMagnitude = env.action_space.high[0]


(1, 3)


In [4]:
import time
env.reset()
for steps in range(1):
    action = env.action_space.sample()
    _,rewards,terminated,truncated,_=env.step(action)
    print(rewards)
    time.sleep(0.1)
env.close()

-5.992865763509047


In [5]:
class Critic(nn.Module):
    def __init__(self,stateNum,h1,h2,h3,actionNum):
        super().__init__()
        self.q1 = nn.Sequential(
            nn.Linear(stateNum + actionNum,h1),
            nn.ReLU(),
            nn.Linear(h1,h2),
            nn.ReLU(),
            nn.Linear(h2,h3),
            nn.ReLU(),
            nn.Linear(h3,1),
        )
        self.q2 = nn.Sequential(
            nn.Linear(stateNum + actionNum,h1),
            nn.ReLU(),
            nn.Linear(h1,h2),
            nn.ReLU(),
            nn.Linear(h2,h3),
            nn.ReLU(),
            nn.Linear(h3,1),
        )

    def forward(self,state,action):
        x = torch.hstack((state,action))
        q1 = self.q1(x)
        q2 = self.q2(x)
        return q1,q2

In [6]:
class Actor(nn.Module):
    def __init__(self,stateNum,h1,h2,h3,actionNum):
        super().__init__()
        self.fc1 = nn.Linear(stateNum,h1)
        self.fc2 = nn.Linear(h1,h2)
        self.fc3 = nn.Linear(h2,h3)
        self.fc4 = nn.Linear(h3,actionNum)

    def forward(self,x):

        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        A = torch.tanh(self.fc4(x))* actionMagnitude
        return A

In [7]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
critic        = Critic(stateNum,400,200,100,actionNum).to(device)
actor         = Actor(stateNum,400,200,100,actionNum).to(device)
critic_target = Critic(stateNum,400,200,100,actionNum).to(device)
actor_target  = Actor(stateNum,400,200,100,actionNum).to(device)
critic_target.load_state_dict(critic.state_dict())
actor_target.load_state_dict(actor.state_dict())

critic_criterion = nn.MSELoss()
critic_optim =  Adam(critic.parameters(),CRITIC_LR)
actor_optim= Adam(actor.parameters(),ACTOR_LR)

buffer = ExperienceReplay(BUFFER_SIZE,device)

In [8]:
def soft_update(online,target,tau):
    online_dict = online.state_dict()
    target_dict = target.state_dict()
    for key in online_dict.keys():
        target_dict[key]= tau*online_dict[key] + (1-tau) * target_dict[key]
    target.load_state_dict(target_dict)

In [9]:

rewards = []
numSteps=0
update_steps = 0
for episode in range(EPISODES):
    old_observation,info =env.reset(seed=episode)
    cumulative_reward = 0

    for step in range(STEPS):
        numSteps += 1
        action = actor(torch.Tensor(old_observation).reshape(1,-1).to(device)).cpu().detach().numpy().squeeze()
        exploration_noise = np.random.normal(NOISE_MEAN,NOISE_STD,env.action_space.shape[0])
        noisyAction = np.clip(action.squeeze() + exploration_noise,-2,2)
        old_observation = old_observation.squeeze()
        new_observation,reward,terminated,truncated,info = env.step(noisyAction)
        cumulative_reward+=reward
        done = terminated or truncated
        buffer.append(old_observation,noisyAction,reward,new_observation,done)
        rewards.append(reward)
        old_observation = new_observation
        # t = time.time()
        if(buffer.size() >= BATCH_SIZE):
            old_state,old_action,reward,new_state,done=buffer.sample(BATCH_SIZE)
            reward = reward.reshape(-1,1)
            done = done.reshape(-1,1)   
            old_action = old_action.reshape(-1,actionNum)

            y_hat1,y_hat2 = critic(old_state,old_action)
            #TODO fix this after downloading torch
            with torch.no_grad():
                targetAction = actor_target(new_state)
                targetActionNoise = torch.clip(torch.normal(NOISE_MEAN,NOISE_STD,targetAction.shape,device=device),-0.2,0.2)
                noisyTargetAction = targetAction + targetActionNoise
                clippedNoisyAction = torch.clip(noisyTargetAction,-actionMagnitude,actionMagnitude)
                q1,q2 = critic_target(new_state,clippedNoisyAction)
                y = reward + GAMMA * torch.min(q1,q2) * (1.0 - done)

            critic_loss1 =critic_criterion(y_hat1,y)
            critic_loss2 =critic_criterion(y_hat2,y)
            critic_loss=critic_loss1 + critic_loss2
            critic_optim.zero_grad()
            critic_loss.backward()
            critic_optim.step()
            update_steps+=1
            #delayed Update
            if update_steps % POLICY_DELAY == 0 :

                #freeze critic
                for p in critic.parameters():
                    p.requires_grad = False

                actor_optim.zero_grad()
                q1,q2 = critic(old_state,actor(old_state))
                actor_loss = -q1.mean()
                actor_loss.backward()
                actor_optim.step()

                #freeze critic
                for p in critic.parameters():
                    p.requires_grad = True

                with torch.no_grad():
                    soft_update(critic, critic_target, TAU)
                    soft_update(actor, actor_target, TAU)
        if(truncated or terminated):
            break;
        # print(t-time.time())

    print(f"Episode: {episode} | Reward: {cumulative_reward},actor_loss: ")# {actor_loss.detach().item()},critic_loss: {critic_loss.detach().item()}
    print(numSteps)

Episode: 0 | Reward: -1313.588210311692,actor_loss: 
200
Episode: 1 | Reward: -1087.1692665500848,actor_loss: 
400
Episode: 2 | Reward: -1646.1273994193343,actor_loss: 
600
Episode: 3 | Reward: -1738.0199421379762,actor_loss: 
800
Episode: 4 | Reward: -1806.3368187992767,actor_loss: 
1000
Episode: 5 | Reward: -1560.8232207621038,actor_loss: 
1200
Episode: 6 | Reward: -1384.8174479507672,actor_loss: 
1400
Episode: 7 | Reward: -1484.519321890306,actor_loss: 
1600
Episode: 8 | Reward: -1376.9777516744387,actor_loss: 
1800
Episode: 9 | Reward: -1576.070241399909,actor_loss: 
2000
Episode: 10 | Reward: -1559.6669885883591,actor_loss: 
2200
Episode: 11 | Reward: -1544.2008874274943,actor_loss: 
2400
Episode: 12 | Reward: -1454.711626683674,actor_loss: 
2600
Episode: 13 | Reward: -1563.666258884555,actor_loss: 
2800
Episode: 14 | Reward: -1523.954386298587,actor_loss: 
3000
Episode: 15 | Reward: -1435.9676854813122,actor_loss: 
3200
Episode: 16 | Reward: -1258.5964268640694,actor_loss: 
3400


KeyboardInterrupt: 

In [10]:
env = gym.make("Pendulum-v1",render_mode="human")
old_observation,_ =env.reset()
for steps in range(STEPS):
    old_observation = torch.Tensor(old_observation).to(device)
    action = actor(old_observation).cpu().detach().numpy()
    old_observation,_,terminated,truncated,_=env.step(action)
    if(truncated or terminated):
        break
env.close()

In [36]:
env.close()