In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, namedtuple
from PIL import Image
import random
import math
from itertools import count
import numpy as np
from IPython.display import clear_output
import matplotlib.pyplot as plt

# Run every tensor/model on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the CartPole-v1 environment in 'rgb_array' render mode and keep its
# `.unwrapped` form, then reset it so it has an initial state.
# NOTE(review): `.unwrapped` presumably bypasses the wrappers gym.make adds
# (including any episode time limit) — confirm this is intended.
world = gym.make('CartPole-v1', render_mode = 'rgb_array').unwrapped
world.reset()

In [22]:
class Environment():
    """Turns rendered CartPole frames into cropped, down-scaled image tensors.

    Args:
        world: Environment used whenever a method is called without one.
        screen_width: Width, in pixels, used to map the cart's x position
            onto the rendered frame.
        device: Torch device the returned screen tensors are moved to.
    """

    def __init__(self, world, screen_width, device):
        # Shrink the shorter image edge to 40 px, then convert back to a tensor.
        # The InterpolationMode enum replaces the deprecated PIL integer constant.
        self.resizer = T.Compose([
            T.ToPILImage(),
            T.Resize(40, interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor()
        ])
        self.world = world
        self.screen_width = screen_width
        self.device = device

    @staticmethod
    def _base_env(world):
        """Return the innermost environment so its attributes can be read directly."""
        return getattr(world, 'unwrapped', world)

    def get_cart_location(self, world=None):
        """Return the cart's horizontal position in pixels.

        Args:
            world: Environment (wrapped or not) exposing `x_threshold` and
                `state`; defaults to the environment given at construction.

        Returns:
            int: Pixel column of the cart centre within a `screen_width` frame.
        """
        # Read physics attributes from the base env: forwarding them through
        # wrappers is deprecated in Gymnasium.
        base = self._base_env(self.world if world is None else world)
        world_width = base.x_threshold * 2
        scale = self.screen_width / world_width
        return int(base.state[0] * scale + self.screen_width / 2.0)

    def get_screen(self, world=None):
        """Render the environment and return a cart-centred image tensor.

        Accepts both a single frame (`render_mode='rgb_array'`) and a list of
        frames (`render_mode='rgb_array_list'`), using the most recent frame.

        Args:
            world: Environment to render; defaults to the one given at
                construction.

        Returns:
            torch.Tensor: Float tensor with a leading batch dimension of 1,
            located on `self.device`.

        Raises:
            ValueError: If `render()` returns None (no render mode configured).
        """
        world = self.world if world is None else world
        frame = world.render()
        if frame is None:
            raise ValueError("world.render() returned None; create the environment with render_mode='rgb_array'")
        if isinstance(frame, (list, tuple)):
            # List-style render modes hand back every collected frame.
            frame = frame[-1]

        # HWC -> CHW, then keep only rows 160..319 of the frame.
        screen = np.asarray(frame).transpose((2, 0, 1))
        screen = screen[:, 160:320]
        view_width = 320
        cart_location = self.get_cart_location(world)

        # Choose a 320-px wide window centred on the cart, clamped to the frame edges.
        if cart_location < view_width // 2:
            slice_range = slice(view_width)
        elif cart_location > (self.screen_width - view_width // 2):
            slice_range = slice(-view_width, None)
        else:
            slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2)

        screen = screen[:, :, slice_range]
        # Scale pixel values to [0, 1] floats before resizing.
        screen = torch.from_numpy(np.ascontiguousarray(screen, dtype=np.float32)) / 255
        return self.resizer(screen).unsqueeze(0).to(self.device).contiguous()

# Build the screen extractor and grab one initial frame; `screen` is reused
# by a later cell as a sample network input.
# NOTE(review): 600 is presumably the pixel width of the rendered frame — confirm.
environment = Environment(world, 600, device)
screen = environment.get_screen(world)
# screen.shape

In [None]:
class DQN(nn.Module):
	"""Convolutional Q-network: one Q-value per action from an RGB screen.

	Three stride-2 5x5 convolutions (each followed by batch norm and ReLU)
	feed a single linear head.

	Args:
		h: Height of the input image in pixels.
		w: Width of the input image in pixels.
		outputs: Number of actions, i.e. Q-values produced per input.

	The defaults (40 x 80 input, 2 actions) give a 448-feature linear head.
	"""

	def __init__(self, h=40, w=80, outputs=2):
		super(DQN, self).__init__()
		self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
		self.bn1 = nn.BatchNorm2d(16)
		self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
		self.bn2 = nn.BatchNorm2d(32)
		self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
		self.bn3 = nn.BatchNorm2d(32)

		def conv2d_size_out(size, kernel_size=5, stride=2):
			# Output length of an un-padded convolution along one axis.
			return (size - (kernel_size - 1) - 1) // stride + 1

		# Derive the flattened feature count from the input size instead of
		# hard-coding it, so other screen sizes work too.
		conv_h = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
		conv_w = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
		self.head = nn.Linear(conv_h * conv_w * 32, outputs)

	def forward(self, x):
		"""Return Q-values of shape (batch, outputs) for a batch of images `x`."""
		x = F.relu(self.bn1(self.conv1(x)))
		x = F.relu(self.bn2(self.conv2(x)))
		x = F.relu(self.bn3(self.conv3(x)))
		# Flatten everything except the batch dimension for the linear head.
		return self.head(torch.flatten(x, 1))
	
# Smoke test: run one forward pass on the sample `screen` and display the
# output shape to check the network wiring.
dqn = DQN().to(device)
dqn(screen).shape

In [None]:
# One stored step of experience: (state, action, next_state, reward).
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
	"""Fixed-capacity buffer of `Transition`s; the oldest entry is dropped when full."""

	def __init__(self, capacity):
		# A bounded deque discards from the left once `capacity` is reached.
		self.memory = deque([], capacity)

	def push(self, *args):
		"""Build a `Transition` from the positional fields and store it."""
		entry = Transition(*args)
		self.memory.append(entry)

	def sample(self, batch_size):
		"""Return `batch_size` distinct stored transitions chosen at random."""
		picked = random.sample(self.memory, batch_size)
		return picked

	def __len__(self):
		"""Number of transitions currently held."""
		held = self.memory
		return len(held)
	
# memory = ReplayMemory(10000)
# memory.push(1, 2, 3, 4)
# memory.sample(5)
# memory.__len__()

In [None]:
class Agent():
	"""DQN agent with a policy network, a target network and a replay memory.

	Args:
		device: Torch device holding both networks and all training tensors.
	"""

	def __init__(self, device):
		self.device = device
		self.policy_net = DQN().to(self.device)
		self.target_net = DQN().to(self.device)
		# The target network is only ever a frozen copy of the policy network:
		# start it in sync and keep it in eval mode so its BatchNorm layers use
		# the copied running statistics instead of per-batch statistics.
		self.target_net.load_state_dict(self.policy_net.state_dict())
		self.target_net.eval()
		self.optimizer = optim.Adam(self.policy_net.parameters())
		self.criterion = nn.SmoothL1Loss()
		self.memory = ReplayMemory(100000)
		self.steps_done = 0

		# Exploration rate decays exponentially from EPSILON_START towards
		# EPSILON_END as `steps_done` grows, on a scale of EPSILON_DECAY steps.
		self.EPSILON_END = 0.05
		self.EPSILON_START = 0.9
		self.EPSILON_DECAY = 500
		self.GAMMA = 0.99

		self.BATCH_SIZE = 128

	def remember(self, *args):
		"""Store one transition (state, action, next_state, reward) in the replay memory."""
		self.memory.push(*args)

	def select_action(self, state):
		"""Pick an action for `state` with an epsilon-greedy policy.

		Returns:
			torch.Tensor: Long tensor of shape [1, 1] holding the action index.
		"""
		sample = random.random()
		epsilon_threshold = self.EPSILON_END + (self.EPSILON_START - self.EPSILON_END) * math.exp(-1. * self.steps_done / self.EPSILON_DECAY)
		self.steps_done += 1
		if sample < epsilon_threshold:
			# Explore: uniformly random action.
			return torch.tensor([[random.randrange(2)]], device=self.device, dtype=torch.long) # [1, 1]
		# Exploit: action with the highest predicted Q-value.
		with torch.no_grad():
			return self.policy_net(state).max(1)[1].view(1, 1) # [1, 1]

	def optimize_model(self):
		"""Run one gradient step on a random replay batch.

		Does nothing until the memory holds at least BATCH_SIZE transitions.
		Transitions whose `next_state` is None are treated as terminal and
		contribute a next-state value of 0.
		"""
		if len(self.memory) < self.BATCH_SIZE:
			return

		transitions = self.memory.sample(self.BATCH_SIZE)
		# Turn a list of Transitions into one Transition of tuples (batch per field).
		batch = Transition(*zip(*transitions))

		# Boolean mask: indexing with uint8 masks is deprecated in PyTorch.
		non_final_mask = torch.tensor(
			[s is not None for s in batch.next_state],
			device=self.device, dtype=torch.bool)

		state_batch = torch.cat(batch.state)
		action_batch = torch.cat(batch.action)
		reward_batch = torch.cat(batch.reward)

		# Q(s, a) for the actions that were actually taken.
		state_action_values = self.policy_net(state_batch).gather(1, action_batch)

		next_state_values = torch.zeros(self.BATCH_SIZE, device=self.device)
		# Skip the target forward pass when every sampled transition is terminal;
		# torch.cat would raise on an empty list.
		if non_final_mask.any():
			non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
			with torch.no_grad():
				next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0]

		expected_state_action_values = (next_state_values * self.GAMMA) + reward_batch

		loss = self.criterion(state_action_values, expected_state_action_values.unsqueeze(1))
		self.optimizer.zero_grad()
		loss.backward()
		# Clamp each gradient element to [-1, 1]; parameters without a gradient are skipped.
		torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 1)
		self.optimizer.step()

# The agent trained by the loop in the next cell.
agent = Agent(device)

In [None]:
NUM_EPISODES = 500
TARGET_UPDATE = 10        # episodes between target-network syncs
plot_frequency = 10       # episodes between progress plots
# `world` is used unwrapped, so no time limit ends an episode; cap the length
# here so a well-trained agent cannot loop forever.
MAX_EPISODE_STEPS = 500

durations = []
episode_rewards = []

for i in range(NUM_EPISODES):
	world.reset()

	current_screen = environment.get_screen(world)
	episode_reward = 0

	for t in count():
		action = agent.select_action(current_screen)
		_, reward, terminated, truncated, _ = world.step(action.item())
		truncated = truncated or (t + 1 >= MAX_EPISODE_STEPS)
		episode_reward += reward
		done = terminated or truncated

		# A terminal state has no successor: store None so optimize_model gives
		# it a next-state value of 0 instead of bootstrapping past the end of
		# the episode. Truncated episodes keep their next screen.
		next_screen = None if terminated else environment.get_screen(world)
		reward = torch.tensor([reward], device=device, dtype=torch.float32)

		agent.remember(current_screen, action, next_screen, reward)
		current_screen = next_screen

		agent.optimize_model()

		if done:
			durations.append(t + 1)
			break

	episode_rewards.append(episode_reward)
	if i % TARGET_UPDATE == 0:
		# Refresh the target network with the current policy weights.
		agent.target_net.load_state_dict(agent.policy_net.state_dict())

	if i % plot_frequency == 0:
		clear_output(wait=True)
		print('Episode: ', i, 'Mean reward: ', np.mean(episode_rewards[-plot_frequency:]))
		plt.plot(episode_rewards)
		plt.xlabel('Episode')
		plt.ylabel('Total reward')
		plt.title('Training progress')
		plt.show()


In [None]:
# Persist the trained policy network's weights so a later cell can reload
# them without retraining.
torch.save(agent.policy_net.state_dict(), 'model_weights.pth')
# agent.policy_net.load_state_dict(torch.load('model_weights.pth'))


In [None]:
# %matplotlib inline

In [None]:
import matplotlib.animation as animation

In [None]:
model = DQN().to(device)

# Load the trained model weights. `map_location` remaps the stored tensors to
# the current device, so weights saved on a GPU still load on a CPU-only machine.
model.load_state_dict(torch.load('model_weights.pth', map_location=device))
# Evaluation mode: BatchNorm layers use their stored running statistics.
model.eval()

In [23]:
import gymnasium as gym
from gymnasium.utils.save_video import save_video

# Roll out the trained policy and save one video per finished episode.
#
# get_screen() calls env.render() on every step. With the default
# "rgb_array_list" wrapper each render() call empties the frame buffer, so only
# the final frame would reach save_video(). Wrapping explicitly with
# pop_frames=False keeps every frame of the current episode; reset_clean=True
# still clears the buffer when a new episode starts.
env = gym.wrappers.RenderCollection(
   gym.make("CartPole-v1", render_mode="rgb_array"),
   pop_frames=False,
   reset_clean=True,
)
_ = env.reset()
step_starting_index = 0
episode_index = 0
for step_index in range(199):
   screen = environment.get_screen(env)
   # Inference only: no gradients are needed to pick the greedy action.
   with torch.no_grad():
      model_output = model(screen)
   action = model_output.max(1)[1].item()
   _, _, terminated, truncated, _ = env.step(action)

   if terminated or truncated:
      save_video(
         env.render(),
         "videos",
         fps=env.metadata["render_fps"],
         step_starting_index=step_starting_index,
         episode_index=episode_index
      )
      step_starting_index = step_index + 1
      episode_index += 1
      env.reset()
env.close()

  logger.warn(
  logger.warn(


Moviepy - Building video f:\OUT\Github\Human-level control through deep reinforcement learning\videos/rl-video-episode-0.mp4.
Moviepy - Writing video f:\OUT\Github\Human-level control through deep reinforcement learning\videos/rl-video-episode-0.mp4



                                                  

Moviepy - Done !
Moviepy - video ready f:\OUT\Github\Human-level control through deep reinforcement learning\videos/rl-video-episode-0.mp4




Moviepy - Building video f:\OUT\Github\Human-level control through deep reinforcement learning\videos/rl-video-episode-1.mp4.
Moviepy - Writing video f:\OUT\Github\Human-level control through deep reinforcement learning\videos/rl-video-episode-1.mp4



                                                  

Moviepy - Done !
Moviepy - video ready f:\OUT\Github\Human-level control through deep reinforcement learning\videos/rl-video-episode-1.mp4


