In [41]:
import os
import gym
import math
import random
import numpy as np
from collections import namedtuple
from itertools import count
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import pkg_resources
import time



pkg_resources.get_distribution("gym").version






'0.26.2'

In [42]:
# import wandb

# wandb.init(project="DQN_yc930401", entity="xdvisch")

import pygame
pygame.init()
pygame.display.list_modes()

[(3840, 2160),
 (2560, 1600),
 (2560, 1440),
 (2048, 1536),
 (1920, 1440),
 (1920, 1200),
 (1920, 1080),
 (1768, 992),
 (1680, 1050),
 (1600, 1200),
 (1600, 1024),
 (1440, 900),
 (1366, 768),
 (1280, 1024),
 (1280, 960),
 (1280, 800),
 (1280, 720),
 (1176, 664),
 (1152, 864),
 (1024, 768),
 (800, 600),
 (720, 480),
 (640, 480)]

In [43]:
from gym.wrappers.human_rendering import HumanRendering

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# env = gym.make('CartPole-v1',  render_mode='rgb_array').unwrapped
env = gym.make('CartPole-v1',  render_mode='rgb_array')
wrapped = HumanRendering(env)
wrapped.reset()



resize = T.Compose([T.ToPILImage(),
                    T.Resize(40, interpolation=Image.CUBIC),
                    T.ToTensor()])
# This is based on the code from gym.
screen_width = 600
BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 0.4
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10
device

  T.Resize(40, interpolation=Image.CUBIC),


device(type='cuda')

In [44]:
def get_cart_location():
    world_width = env.x_threshold * 2
    scale = screen_width / world_width
    return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART


def get_screen():
    screen = env.render().transpose((2, 0, 1))  # transpose into torch order (CHW)
    # Strip off the top and bottom of the screen
    screen = screen[:, 160:320]
    view_width = 320
    cart_location = get_cart_location()
    
    if cart_location < view_width // 2:
        slice_range = slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
        slice_range = slice(-view_width, None)
    else:
        slice_range = slice(cart_location - view_width // 2,
                            cart_location + view_width // 2)
    # Strip off the edges, so that we have a square image centered on a cart
    screen = screen[:, :, slice_range]
    # Convert to float, rescare, convert to torch tensor (this doesn't require a copy)
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    # Resize, and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0).to(device)

In [45]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


class DQN(nn.Module):

    def __init__(self):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)
        self.head = nn.Linear(448, 2)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))

In [46]:
steps_done = 0

def select_action(state):
    global steps_done
    sample = random.random()

    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    # wandb.log({"eps_threshold": eps_threshold})
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)




def optimize_model(policy_net, optimizer, memory):
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for detailed explanation).
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=device, dtype=torch.uint8)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the columns of actions taken
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = policy_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
    expected_state_action_values = ((next_state_values * GAMMA) + reward_batch).unsqueeze(1)

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

In [47]:


# Define the custom x axis metric
# wandb.define_metric("episode")

# Define which metrics to plot against that x-axis
# wandb.define_metric("duration", step_metric='episode')

def trainIters(policy_net, n_iters=60):
    optimizer = optim.Adam(policy_net.parameters(), lr=0.0001, weight_decay=1e-5)
    memory = ReplayMemory(10000)
    for iter in range(n_iters):
        # Initialize the environment and state
        env.reset()
        wrapped.reset()
        print("spel beindigd!")
        # time.sleep(2)
        last_screen = get_screen()
        current_screen = get_screen()
        state = current_screen - last_screen
        losses = []
        for t in count():
            env.render()
            wrapped._render_frame()
            # print(f"stap {t} in huidige episode")
            # time.sleep(5)
            action = select_action(state)
            
            _, reward, done, _, _ = env.step(action.item())
            
            reward = torch.tensor([reward], device=device)

            # Observe new state
            last_screen = current_screen
            
            current_screen = get_screen()
            if not done:
                next_state = current_screen - last_screen
            else:
                next_state = None

            # Store the transition in memory
            memory.push(state, action, next_state, reward)
            # Move to the next state
            state = next_state

            # Perform one step of the optimization (on the target network)
            optimize_model(policy_net, optimizer, memory)
            if done:
                
                # log_dict = {
                # "episode": iter + 1,
                
                # "duration": t
                # }
                # wandb.log(log_dict)
                # print('Iteration: {}, Score: {}'.format(iter + 1, t))
                
                break

    torch.save(policy_net, 'model/policy_net.pkl')
    # print('Complete')
    wrapped.render()
    wrapped.close()

    env.render()
    env.close()



In [48]:
if __name__ == "__main__":
    if os.path.exists('model/policy_net.pkl'):
        policy_net = torch.load('model/policy_net.pkl')
        print('Model loaded')
    else:
        policy_net = DQN().to(device)
    trainIters(policy_net, n_iters=50)

Model loaded
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!


  next_state_values[non_final_mask] = policy_net(non_final_next_states).max(1)[0].detach()


wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!
wrapper geinitialiseerd!


KeyboardInterrupt: 

: 