In [None]:
import itertools
import os
import random
from collections import deque

import gym
import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from baselines_wrappers.dummy_vec_env import DummyVecEnv
from baselines_wrappers.monitor import Monitor
from baselines_wrappers.subproc_vec_env import SubprocVecEnv
from pytorch_wrappers import make_atari_deepmind, BatchedPytorchFrameStack, PytorchLazyFrames

In [None]:
# Hyper-parameters and paths. "Loop step" = one iteration of the training loop,
# which advances all NUM_ENVS environments once (i.e. NUM_ENVS transitions).
GAMMA = 0.99                 # discount factor for future rewards
BATCH_SIZE = 32              # transitions sampled from the replay buffer per gradient step
BUFFER_SIZE = int(1e6)       # replay buffer capacity; once full, the oldest transitions are overwritten
MIN_REPLAY_SIZE = 50_000     # transitions collected with random actions before gradient descent starts
EPSILON_START = 1.0          # initial exploration rate
EPSILON_END = 0.1            # final exploration rate
EPSILON_DECAY = int(1e6)     # environment transitions (loop steps * NUM_ENVS) to anneal epsilon from START to END
NUM_ENVS = 4                 # environments stepped in parallel per loop step
TARGET_UPDATE_FREQ = 10_000 // NUM_ENVS  # loop steps between target-network syncs (~10k transitions)
LEARNING_RATE = 5e-5         # Adam learning rate
SAVE_INTERVAL = 10_000       # loop steps between checkpoints
MODELS_DIR = './saved_models'
SAVE_PATH = os.path.join(MODELS_DIR, 'atari_model.pth')
LOG_DIR = './logs/atari_vanilla'
LOG_INTERVAL = 1_000         # loop steps between console / TensorBoard log entries


In [None]:
# Prefer the first CUDA GPU visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
def init_weights(m):
    """Kaiming-uniform initialise the weight of a Linear or Conv2d module in place.

    Meant to be passed to ``nn.Module.apply``. Modules of any other type, and
    all biases, are left untouched.

    Args:
        m: a module visited by ``apply``.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')


def nature_cnn(observation_space, depths=(32, 64, 64), final_layer=512):
    """Build the three-conv-layer torso used by Nature DQN.

    Args:
        observation_space: space whose ``shape[0]`` is the number of input
            channels and whose ``sample()`` returns one observation; that sample
            is pushed through the conv stack to size the following Linear layer.
        depths: output channels of the three convolutional layers.
        final_layer: width of the ReLU-activated feature vector produced.

    Returns:
        ``nn.Sequential(conv_stack, Linear(n_flatten, final_layer), ReLU)``.
    """
    in_channels = observation_space.shape[0]

    # (out_channels, kernel_size, stride) for each convolution, in order.
    conv_specs = (
        (depths[0], 8, 4),
        (depths[1], 4, 2),
        (depths[2], 3, 1),
    )
    layers = []
    for out_channels, kernel, stride in conv_specs:
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride))
        layers.append(nn.ReLU())
        in_channels = out_channels
    layers.append(nn.Flatten())
    cnn = nn.Sequential(*layers)

    # The flattened width depends on the spatial size, so infer it with one dummy pass.
    with torch.no_grad():
        dummy = torch.as_tensor(observation_space.sample()[None]).float()
        n_flatten = cnn(dummy).shape[1]

    return nn.Sequential(cnn, nn.Linear(n_flatten, final_layer), nn.ReLU())

In [None]:
class Network(nn.Module):
    """Deep Q-network: ``nature_cnn`` feature extractor plus a linear Q-value head."""

    def __init__(self, env, device):
        """Build the network for ``env``.

        Args:
            env: environment whose ``observation_space`` is passed to ``nature_cnn``
                and whose ``action_space.n`` sets the number of Q-value outputs.
            device: torch device that ``act``/``compute_loss`` place input tensors on.
                The module itself is not moved here; call ``.to(device)`` separately.
        """
        super().__init__()

        self.num_actions = env.action_space.n
        self.device = device

        # Width of the feature vector produced by nature_cnn; the head must match it.
        feature_dim = 512
        conv_net = nature_cnn(env.observation_space, final_layer=feature_dim)

        self.net = nn.Sequential(conv_net, nn.Linear(feature_dim, self.num_actions))

    def forward(self, x):
        """Return Q-values for the batch ``x``: one output per action along dim 1."""
        return self.net(x)

    def act(self, obses, epsilon):
        """Pick an epsilon-greedy action for every observation in the batch.

        Args:
            obses: array-like batch of observations convertible by
                ``torch.as_tensor`` (the training loop stacks PytorchLazyFrames
                into an ndarray before calling this).
            epsilon: probability, drawn independently per observation, of
                replacing the greedy action with a uniformly random one.

        Returns:
            list[int]: one action index per observation.
        """
        obses_t = torch.as_tensor(obses, dtype=torch.float32, device=self.device)
        # Acting never back-propagates, so do not build an autograd graph.
        with torch.no_grad():
            q_values = self(obses_t)
        actions = torch.argmax(q_values, dim=1).tolist()

        for i in range(len(actions)):
            if random.random() <= epsilon:
                actions[i] = random.randint(0, self.num_actions - 1)

        return actions

    def save(self, path):
        """Write the state dict to ``path``, creating its parent directory if needed."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a state dict previously written by ``save``.

        Tensors are mapped onto ``self.device``, so a checkpoint written on a GPU
        machine can be loaded on a CPU-only one.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        self.load_state_dict(torch.load(path, map_location=self.device))

    def compute_loss(self, transitions, target_net):
        """Smooth-L1 (Huber) TD loss for a minibatch of transitions.

        Args:
            transitions: sequence of ``(obs, action, reward, done, new_obs)`` tuples,
                as stored in the replay buffer by the training loop.
            target_net: network that supplies the bootstrap values
                ``max_a' Q_target(new_obs, a')``.

        Returns:
            Scalar loss tensor; gradients flow only into ``self``.
        """
        obses = [t[0] for t in transitions]
        actions = np.asarray([t[1] for t in transitions])
        rews = np.asarray([t[2] for t in transitions])
        dones = np.asarray([t[3] for t in transitions])
        new_obses = [t[4] for t in transitions]

        # Materialise lazy frame stacks into a single ndarray; converting one
        # ndarray to a tensor is much cheaper than converting a list of arrays.
        if isinstance(obses[0], PytorchLazyFrames):
            obses = np.stack([o.get_frames() for o in obses])
            new_obses = np.stack([o.get_frames() for o in new_obses])
        else:
            obses = np.asarray(obses)
            new_obses = np.asarray(new_obses)

        obses_t = torch.as_tensor(obses, dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(actions, dtype=torch.int64, device=self.device).unsqueeze(-1)
        rews_t = torch.as_tensor(rews, dtype=torch.float32, device=self.device).unsqueeze(-1)
        dones_t = torch.as_tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(-1)
        new_obses_t = torch.as_tensor(new_obses, dtype=torch.float32, device=self.device)

        # Targets r + GAMMA * (1 - done) * max_a' Q_target(s', a') are constants
        # for this update: skip the autograd graph so target_net's parameters do
        # not accumulate gradients on every backward pass.
        with torch.no_grad():
            target_q_values = target_net(new_obses_t)
            max_target_q_values = target_q_values.max(dim=1, keepdim=True)[0]
            targets = rews_t + GAMMA * (1 - dones_t) * max_target_q_values

        # Q(s, a) of the actions that were actually taken.
        q_values = self(obses_t)
        action_q_values = torch.gather(input=q_values, dim=1, index=actions_t)

        return nn.functional.smooth_l1_loss(action_q_values, targets)
        

In [None]:
def make_env():
    """Create one Monitor-wrapped, DeepMind-preprocessed Breakout environment.

    Monitor supplies the per-episode summary the training loop reads from
    ``info['episode']``. ``allow_early_resets=True`` is needed because later
    cells call ``env.reset()`` without waiting for episodes to finish.
    """
    return Monitor(make_atari_deepmind('BreakoutNoFrameskip-v4', scale_values=True), allow_early_resets=True)


# Run the NUM_ENVS environments in worker subprocesses.
vec_env = SubprocVecEnv([make_env for _ in range(NUM_ENVS)])

# Stack the last k=4 frames per environment; observations come back as PytorchLazyFrames.
env = BatchedPytorchFrameStack(vec_env, k=4)

replay_buffer = deque(maxlen=BUFFER_SIZE)
# Summaries of the 100 most recent finished episodes, for running-average logs.
epinfos_buffer = deque([], maxlen=100)

episode_count = 0


In [None]:
# TensorBoard writer; event files for this run go under LOG_DIR.
summary_writer = SummaryWriter(log_dir=LOG_DIR)

In [None]:
online_net = Network(env, device)
target_net = Network(env, device)
# Only the online network needs explicit initialisation: target_net is
# overwritten with online_net's weights below.
online_net.apply(init_weights)

online_net = online_net.to(device)
target_net = target_net.to(device)

target_net.load_state_dict(online_net.state_dict())
# target_net is updated only by copying weights, never by the optimizer, so stop
# autograd from tracking it (otherwise its .grad tensors are filled every step).
target_net.requires_grad_(False)

optimizer = torch.optim.Adam(online_net.parameters(), lr=LEARNING_RATE)

In [None]:
# Initialize the replay buffer with transitions from uniformly random actions.
obses = env.reset()

# Every env.step() yields NUM_ENVS transitions, so stop on buffer size rather
# than iterating MIN_REPLAY_SIZE times (which collected NUM_ENVS times too many).
# Capping at BUFFER_SIZE guarantees termination because the deque cannot grow past it.
warmup_size = min(MIN_REPLAY_SIZE, BUFFER_SIZE)
while len(replay_buffer) < warmup_size:
    # One random action per parallel environment.
    actions = [env.action_space.sample() for _ in range(NUM_ENVS)]

    new_obses, rews, dones, _ = env.step(actions)
    for obs, action, rew, done, new_obs in zip(obses, actions, rews, dones, new_obses):
        replay_buffer.append((obs, action, rew, done, new_obs))

    # The vec env resets finished episodes itself, so no done-check is needed here.
    obses = new_obses


In [None]:
# Training loop: runs until interrupted, checkpointing every SAVE_INTERVAL steps.
os.makedirs(MODELS_DIR, exist_ok=True)  # online_net.save() writes into this directory

obses = env.reset()

for step in itertools.count():
    # Linear epsilon annealing over EPSILON_DECAY environment transitions
    # (each loop step advances NUM_ENVS environments).
    epsilon = np.interp(step * NUM_ENVS, [0, EPSILON_DECAY], [EPSILON_START, EPSILON_END])

    # Materialise lazy frame stacks into a single ndarray for the network.
    if isinstance(obses[0], PytorchLazyFrames):
        act_obses = np.stack([o.get_frames() for o in obses])
        actions = online_net.act(act_obses, epsilon)
    else:
        actions = online_net.act(obses, epsilon)

    # Exploration (random action with probability epsilon) happens inside act().
    new_obses, rews, dones, infos = env.step(actions)

    for obs, action, rew, done, new_obs, info in zip(obses, actions, rews, dones, new_obses, infos):
        replay_buffer.append((obs, action, rew, done, new_obs))
        if done:
            epinfos_buffer.append(info['episode'])
            episode_count += 1

    obses = new_obses

    # Gradient step on a uniformly sampled minibatch.
    transitions = random.sample(replay_buffer, BATCH_SIZE)
    loss = online_net.compute_loss(transitions, target_net)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Periodically sync the target network with the online network.
    if step % TARGET_UPDATE_FREQ == 0:
        target_net.load_state_dict(online_net.state_dict())

    # Logging
    if step % LOG_INTERVAL == 0:
        # epinfos_buffer holds episode summaries: 'r' is the episode reward and
        # 'l' its length. Until the first episode ends the buffer is empty and
        # np.mean([]) would give nan (plus a RuntimeWarning), so report 0 instead.
        if epinfos_buffer:
            rew_mean = np.mean([epinfo['r'] for epinfo in epinfos_buffer])
            len_mean = np.mean([epinfo['l'] for epinfo in epinfos_buffer])
        else:
            rew_mean = len_mean = 0.0
        print()
        print('Step', step)
        print('Avg rew', rew_mean)
        print('Avg ep len', len_mean)
        print('Episodes', episode_count)

        # tensorboard
        summary_writer.add_scalar('avg_rew', rew_mean, global_step=step)
        summary_writer.add_scalar('avg_ep_len', len_mean, global_step=step)
        summary_writer.add_scalar('episodes', episode_count, global_step=step)

    # saving
    if step % SAVE_INTERVAL == 0 and step != 0:
        print('Saving model')
        online_net.save(SAVE_PATH)


    
