In [1]:
import math
import os
import random
from collections import namedtuple, deque
from itertools import count

import gymnasium as gym
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from tqdm import tqdm, trange

# Matplotlib setup: IPython's display helpers are only imported when the
# active backend is an inline one.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
	from IPython import display

# Switch matplotlib to interactive mode.
plt.ion()

# Pick the compute device and report it: the GPU model name when CUDA is
# available, otherwise the plain device string.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(torch.cuda.get_device_name() if device == "cuda" else device)

NVIDIA GeForce RTX 3090


In [2]:
def make_env(env_id, seed, idx, run_name):
	"""Create and return one Gymnasium environment for ``env_id``.

	The environment itself is returned (not a factory), so callers that need
	a zero-argument factory wrap this call in a ``lambda``.

	Args:
		env_id: Environment id passed to ``gym.make``.
		seed: Accepted for interface compatibility; currently unused.
		idx: Accepted for interface compatibility; currently unused.
		run_name: Accepted for interface compatibility; currently unused.

	Returns:
		The environment built by ``gym.make(env_id)``, with no wrappers.
	"""
	# Wrappers that were tried and are currently disabled:
	#   gym.wrappers.RecordVideo(env, 'videos/pusher_ppo')   (only for idx == 0)
	#   gym.wrappers.ClipAction(env)
	#   gym.wrappers.NormalizeObservation(env)
	#   gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
	#   gym.wrappers.NormalizeReward(env, gamma=0.99)
	#   gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
	return gym.make(env_id)

In [3]:
class diag_gaussian_distribution():
	"""Gaussian action distribution with a diagonal covariance.

	Each action dimension is an independent Normal whose mean comes from a
	network head and whose log standard deviation is a free parameter shared
	across states. Per-dimension log-probabilities and entropies are summed
	over the last (action) axis, so every method accepts both unbatched
	``(action_dim,)`` and batched ``(..., action_dim)`` inputs.
	"""

	def __init__(self, action_dim):
		super().__init__()
		self.action_dim = action_dim
		self.mean_actions = None
		self.log_std = None
		# Set by proba_distribution(); None until then.
		self.distribution = None

	def proba_distribution_net(self, latent_dim, log_std_init=0.0):
		"""Create the trainable pieces that parameterise the distribution.

		Args:
			latent_dim: Size of the feature vector fed to the mean head.
			log_std_init: Initial value for every entry of the log-std vector.

		Returns:
			``(mean_actions, log_std)``: an ``nn.Linear(latent_dim, action_dim)``
			and an ``nn.Parameter`` of shape ``(action_dim,)``.
		"""
		mean_actions = nn.Linear(latent_dim, self.action_dim)
		log_std = nn.Parameter(torch.ones(self.action_dim)*log_std_init, requires_grad=True)
		return mean_actions, log_std

	def proba_distribution(self, mean_actions, log_std):
		"""Build the underlying Normal from means and log-stds; returns ``self``."""
		# Broadcast the state-independent std to the shape of the means.
		action_std = torch.ones_like(mean_actions) * log_std.exp()
		self.distribution = torch.distributions.normal.Normal(mean_actions, action_std)
		return self

	def log_prob(self, actions):
		"""Joint log-probability of ``actions``, summed over the action axis.

		Reducing over ``dim=-1`` (rather than branching on ``len``) keeps the
		batch axis for a batch of one and also works for unbatched input.
		"""
		return self.distribution.log_prob(actions).sum(dim=-1)

	def entropy(self):
		"""Entropy of the joint distribution, summed over the action axis."""
		return self.distribution.entropy().sum(dim=-1)

	def sample(self):
		"""Draw a reparameterised sample (``rsample``) from the distribution."""
		return self.distribution.rsample()

	def mode(self):
		"""Return the distribution mean, i.e. the deterministic action."""
		return self.distribution.mean

	def actions_from_params(self, mean_actions, log_std, deterministic=False):
		"""Rebuild the distribution and return an action from it.

		Returns the mean when ``deterministic`` is true, otherwise a sample.
		"""
		self.proba_distribution(mean_actions, log_std)
		if deterministic:
			return self.mode()
		return self.sample()

	def log_prob_from_params(self, mean_actions, log_std):
		"""Sample actions from the given parameters and return ``(actions, log_prob)``."""
		actions = self.actions_from_params(mean_actions, log_std)
		log_prob = self.log_prob(actions)
		return actions, log_prob

In [4]:
# Width of each hidden layer in the actor and critic networks.
hidden_dim = 64

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
	"""Initialise ``layer`` in place and hand it back.

	Args:
		layer: Module exposing ``weight`` and ``bias`` tensors (e.g. ``nn.Linear``).
		std: Gain applied to the orthogonal weight initialisation.
		bias_const: Constant written into every bias entry.

	Returns:
		The same ``layer`` object, so the call can be nested inside a model
		definition.
	"""
	nn.init.orthogonal_(layer.weight, gain=std)
	nn.init.constant_(layer.bias, val=bias_const)
	return layer

class Agent(nn.Module):
	"""Actor-critic network pair for continuous-action PPO.

	The critic maps an observation to a scalar value. The actor maps an
	observation to a latent vector, from which a linear head produces the
	action means; a state-independent log-std parameter completes the
	diagonal Gaussian policy.
	"""

	def __init__(self, envs):
		"""Build both networks from the vector env's single-env space shapes."""
		super().__init__()
		obs_dim = np.array(envs.single_observation_space.shape).prod()
		act_dim = np.array(envs.single_action_space.shape).prod()

		self.state_dist = diag_gaussian_distribution(action_dim=act_dim)

		# Value network: two tanh hidden layers, then a scalar output.
		self.critic = nn.Sequential(
			layer_init(nn.Linear(obs_dim, hidden_dim)),
			nn.Tanh(),
			layer_init(nn.Linear(hidden_dim, hidden_dim)),
			nn.Tanh(),
			layer_init(nn.Linear(hidden_dim, 1), std=1.0),
		)

		# Policy trunk: same shape as the critic but without an output layer.
		self.actor = nn.Sequential(
			layer_init(nn.Linear(obs_dim, hidden_dim)),
			nn.Tanh(),
			layer_init(nn.Linear(hidden_dim, hidden_dim)),
			nn.Tanh(),
		)

		# Mean head and log-std parameter come from the distribution helper;
		# the head is then re-initialised with a small gain.
		mean_head, self.actor_logstd = self.state_dist.proba_distribution_net(latent_dim=hidden_dim)
		self.actor_mean = layer_init(mean_head, std=0.01)

	def get_value(self, x):
		"""Return the critic's value estimate for observation(s) ``x``."""
		return self.critic(x)

	def get_action_and_value(self, x, pre_action=None):
		"""Evaluate the policy and the critic on ``x``.

		Args:
			x: Observation tensor.
			pre_action: Optional previously taken action(s). When given, the
				returned action and log-probability refer to ``pre_action``
				instead of a freshly sampled action.

		Returns:
			``(action, log_prob, entropy, value)``.
		"""
		latent = self.actor(x)
		mean = self.actor_mean(latent)
		# A fresh action is always sampled, even when pre_action is supplied;
		# this also (re)builds the distribution used below.
		sampled, sampled_log_prob = self.state_dist.log_prob_from_params(
			mean_actions=mean, log_std=self.actor_logstd
		)
		entropy = self.state_dist.entropy()
		value = self.critic(x)

		if pre_action is not None:
			return pre_action, self.state_dist.log_prob(pre_action), entropy, value
		return sampled, sampled_log_prob, entropy, value

In [5]:
# Run configuration.
seed = 1337
num_envs = 1
num_steps = 2048  # rollout length per environment between updates

# Seed every RNG the notebook draws from so a Restart & Run All is
# repeatable: Python's `random`, NumPy (used for minibatch shuffling) and
# PyTorch (weight initialisation and action sampling).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# `i` is bound as a default argument so each factory keeps its own index;
# a plain closure would see only the final value of `i` once the vector env
# calls the factories.
envs = gym.vector.SyncVectorEnv([
    (lambda i=i: make_env('Pusher-v4', seed+i, i, 'the first')) for i in range(num_envs)
])

agent = Agent(envs).to(device)
optimizer = optim.Adam(agent.parameters(), lr=2.5e-4, eps=1e-5)

In [6]:
# Rollout buffers, one row per step and one column per environment,
# allocated directly on the training device.
obs = torch.zeros((num_steps, num_envs) + envs.single_observation_space.shape, device=device)
actions = torch.zeros((num_steps, num_envs) + envs.single_action_space.shape, device=device)
logprobs = torch.zeros((num_steps, num_envs), device=device)
rewards = torch.zeros((num_steps, num_envs), device=device)
dones = torch.zeros((num_steps, num_envs), device=device)
values = torch.zeros((num_steps, num_envs), device=device)

# PPO hyperparameters.
gae_lambda = 1.0
total_timesteps = 4000000 #(768 * 2)
gamma = 0.99
update_epochs = 10
mini_bs = 32 # number of mini batches
clip_coef = 0.2
ent_coef = 0.0 # default: 0.01
vf_coef = 0.5
bs = int(num_envs * num_steps)  # transitions collected per update
minibatch_size = bs // mini_bs # 64
print(f'minibatch_size: {minibatch_size}')

# The training loop writes checkpoints into this directory; create it up
# front so the first torch.save does not fail on a fresh checkout.
os.makedirs('models/pusher_ppo_diag', exist_ok=True)

global_step = 0
# Seed the environments on the first reset so rollouts are repeatable.
next_obs = torch.tensor(envs.reset(seed=seed)[0], dtype=torch.float32).to(device)
next_done = torch.zeros(num_envs).to(device)
num_updates = total_timesteps // bs

# Sentinel below any reward seen so far; raised by the training loop.
highest_score = -999

# PPO training loop. Each iteration (1) collects `num_steps` transitions per
# environment with the current policy, (2) computes GAE advantages and
# returns, and (3) runs `update_epochs` passes of clipped-surrogate updates
# over shuffled minibatches.
for update in (loop := trange(1, num_updates + 1)): 
	# NOTE(review): this initial value is overwritten on the first step below.
	score = np.zeros(num_envs)
	for step in range(0, num_steps):
		# Rollout phase: act in the environment first, learn afterwards.
		global_step += 1 * num_envs 
		# Store the observation/done flag the action is conditioned on.
		obs[step] = next_obs 
		dones[step] = next_done 

		# No gradients are needed while collecting data.
		with torch.no_grad(): 
			action, logprob, _, value = agent.get_action_and_value(next_obs) 
			values[step] = value.flatten()
		actions[step] = action 
		logprobs[step] = logprob 

		next_obs, reward, terminated, truncated, info = envs.step(action.cpu().numpy())
		# Termination and truncation are merged into a single done flag, so a
		# truncated episode is treated like a terminal one when bootstrapping.
		done = np.bitwise_or(terminated, truncated).astype(np.float32)
		next_obs, next_done = torch.tensor(next_obs, dtype=torch.float32).to(device), torch.tensor(done).to(device) 
		
		# NOTE(review): `score` holds the reward of this single step, so
		# `highest_score` tracks the best per-step reward (not an episode
		# return) and the "best" checkpoint is keyed on that -- confirm intent.
		score = reward.tolist()

		rewards[step] = torch.tensor(reward).to(device).view(-1) 
		
		if max(score) > highest_score: 
			highest_score = max(score)
			torch.save(agent.state_dict(), f'models/pusher_ppo_diag/best_performing.pth')
	
	# Advantage phase: generalised advantage estimation, computed backwards
	# through the rollout.
	with torch.no_grad(): 
		# Value of the observation that follows the last stored step.
		next_value = agent.get_value(next_obs).reshape(1, -1)
		advantages = torch.zeros_like(rewards).to(device)
		last_gae_lam = 0 
		for t in reversed(range(num_steps)):
			if t == num_steps - 1: 
				# Last step: its successor is not in the buffers, so use the
				# pending next_done / next_value instead.
				next_non_terminal = 1.0 - next_done 
				next_values = next_value 
			else: # Otherwise the successor is already stored at t + 1.
				next_non_terminal = 1.0 - dones[t + 1]
				next_values = values[t + 1]
			# next_non_terminal is 0 when the successor starts a new episode,
			# which cuts both the bootstrap term and the recursion.
			# delta_t = r_t + gamma * V(s_{t+1}) * next_non_terminal - V(s_t)
			delta = rewards[t] + gamma * next_values * next_non_terminal - values[t]
			# A_t = delta_t + gamma * lambda * next_non_terminal * A_{t+1}
			advantages[t] = last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
		# Value targets for the critic.
		returns = advantages + values 
		
	# Flatten the (step, env) axes into a single batch axis.
	b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) 
	b_logprobs = logprobs.reshape(-1) 
	b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
	b_advantages = advantages.reshape(-1) 
	b_returns = returns.reshape(-1) 
	# NOTE(review): b_values is only referenced by the commented-out value
	# clipping below, and clipfracs is appended to but never read in this cell.
	b_values = values.reshape(-1) 
	b_inds = np.arange(bs) 
	clipfracs = list() 

	# Optimisation phase.
	for epoch in range(update_epochs): 
		# Fresh random minibatch order each epoch (NumPy global RNG).
		np.random.shuffle(b_inds)
		for start in range(0, bs, minibatch_size): 
			end = start + minibatch_size
			mb_inds = b_inds[start:end]
			# Re-evaluate the stored actions under the current policy.
			_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
			logratio = newlogprob - b_logprobs[mb_inds]
			ratio = logratio.exp() # probability ratio: current policy / rollout policy

			with torch.no_grad(): 
				# old_approx_kl = (-logprob).mean() 
				# approx_kl = ((ratio - 1) - logratio).mean() 
				# Fraction of samples whose ratio falls outside the clip range.
				clipfracs += [((ratio - 1.0).abs() > clip_coef).float().mean().item()]

			mb_advantages = b_advantages[mb_inds]
			# Normalise advantages within the minibatch.
			mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

			# Clipped surrogate objective: elementwise minimum of the
			# unclipped and clipped terms.
			pg_loss1 = mb_advantages * ratio 
			pg_loss2 = mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
			pg_loss = -torch.min(pg_loss1, pg_loss2).mean()
			# Negated because the optimiser minimises, while the surrogate
			# objective is to be maximised.

			newvalue = newvalue.view(-1)
			# newvalue = b_values[mb_inds] + torch.clamp(newvalue - b_values[mb_inds], -clip_coef, clip_coef)
			# Unclipped mean-squared error between value targets and predictions.
			v_loss = F.mse_loss(b_returns[mb_inds], newvalue)

			entropy_loss = entropy.mean() 
			loss = pg_loss - ent_coef * entropy_loss + vf_coef * v_loss
			# The entropy term rewards a more spread-out (exploratory) policy;
			# ent_coef sets how strongly it is weighted.

			optimizer.zero_grad() 
			loss.backward() 
			# Clip the global gradient norm to 0.5 before stepping.
			nn.utils.clip_grad_norm_(agent.parameters(), 0.5) 
			optimizer.step()

			loop.set_postfix(highest_score=highest_score, zscore=score, value_loss=v_loss.item(), policy_gradient_loss=pg_loss.item(), loss=loss.item()) 
	# Periodic checkpoint every 64 updates.
	if update % 64 == 0: 
		torch.save(agent.state_dict(), f'models/pusher_ppo_diag/{update}.pth')

# Final checkpoint once all updates are done.
torch.save(agent.state_dict(), f'models/pusher_ppo_diag/final.pth')

minibatch_size: 64


100%|██████████| 1953/1953 [4:08:22<00:00,  7.63s/it, highest_score=-.0848, loss=2.58, policy_gradient_loss=-.0355, value_loss=5.23, zscore=[-0.1018188363237815]]         


In [7]:
# starts with 1.22 