In [20]:
import sys, pickle, time, math, random

sys.path.append("/usr/local/lib/python3.8/site-packages")

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

from ale_py import ALEInterface
from ale_py.roms import Skiing

In [21]:
sys.path.append("../utils/")

from base_models import *
#from models import *
from base_datasets import *
from stats_collector import StatsCollector

In [22]:
# Load the Skiing ROM into a standalone ALE interface.
# NOTE(review): `ale` is not referenced by any visible cell, and the environment actually
# used below is Breakout created through gym — this looks like leftover setup; confirm
# nothing else relies on it before removing.
ale = ALEInterface()
ale.loadROM(Skiing)

# Breakout environment used by the training loop. The loop unpacks four values from
# `env.step` and one from `env.reset`, i.e. it assumes the older gym API — TODO confirm
# against the installed gym version.
env = gym.make("ALE/Breakout-v5")

In [23]:
class CNNConfig:
    """Plain container for the hyper-parameters consumed by ``BasicCNN``.

    Args:
        out_channels_one: channels produced by the first convolution.
        out_channels_two: channels produced by the second convolution (also the input
            width of the first linear layer).
        hidden_layer_depth: width of the hidden fully-connected layer.
        output_depth: number of network outputs.
    """

    def __init__(self, out_channels_one=16, out_channels_two=32, hidden_layer_depth=256, output_depth=3):
        # Convolution widths.
        self.out_channels_one, self.out_channels_two = out_channels_one, out_channels_two
        # Fully-connected widths.
        self.hidden_layer_depth, self.output_depth = hidden_layer_depth, output_depth
        
class BasicCNN(nn.Module):
    """Small convolutional network mapping raw frames to one output per action.

    ``config`` supplies the layer widths (``out_channels_one``, ``out_channels_two``,
    ``hidden_layer_depth``, ``output_depth``).

    NOTE(review): the squeeze/AvgPool1d steps in ``forward`` only line up when the
    second convolution ends with spatial size (2, 1), which is what 210x160 frames
    produce (105x80 after pooling -> 25x19 -> 6x4 -> 2x1). Other frame sizes would
    need the layer sizes revisited.
    """

    def __init__(self, config):
        super().__init__()

        # Applied to a 4-D (batch, colour, height, width) tensor, MaxPool3d treats it as
        # an unbatched (C, D, H, W) volume: the depth-3 kernel collapses the colour axis
        # and the 2x2 kernel halves height and width, giving (batch, 1, H/2, W/2).
        self.preprocessing_pool = nn.MaxPool3d((3, 2, 2))

        self.conv_layer_one = nn.Conv2d(in_channels=1, out_channels=config.out_channels_one, stride=4, kernel_size=8)
        self.pooling_layer_one = nn.AvgPool2d(kernel_size=4)
        self.batch_norm_layer_one = nn.BatchNorm2d(config.out_channels_one)

        self.conv_layer_two = nn.Conv2d(in_channels=config.out_channels_one, out_channels=config.out_channels_two, stride=2, kernel_size=4)
        self.pooling_layer_two = nn.AvgPool1d(kernel_size=2)
        self.batch_norm_layer_two = nn.BatchNorm1d(config.out_channels_two)

        self.linear_layer_one = nn.Linear(config.out_channels_two, config.hidden_layer_depth)
        self.batch_norm_layer_three = nn.BatchNorm1d(config.hidden_layer_depth)
        self.linear_layer_two = nn.Linear(config.hidden_layer_depth, config.output_depth)

    def forward(self, images):
        """Return raw outputs of shape (batch, config.output_depth).

        Args:
            images: a tensor, used as-is (assumed channels-last (batch, height, width,
                colour) — TODO confirm), or any array-like single frame, which is
                squeezed and given a leading batch axis.
        """
        image_tensor = images if torch.is_tensor(images) else torch.tensor(images).squeeze().unsqueeze(0)
        # (B, H, W, C) -> (B, C, W, H) -> (B, C, H, W); cast to float for the conv stack.
        down_sampled_image_tensor = self.preprocessing_pool(image_tensor.transpose(3, 1).transpose(2, 3).float())

        output_one = self.batch_norm_layer_one(F.relu(self.pooling_layer_one(self.conv_layer_one(down_sampled_image_tensor))))
        # Drop the trailing width-1 axis so AvgPool1d can pool the remaining spatial axis,
        # then drop the resulting length-1 axis to get (batch, out_channels_two).
        output_two = self.batch_norm_layer_two(F.relu(self.pooling_layer_two(self.conv_layer_two(output_one).squeeze(-1))).squeeze(-1))
        output_three = self.batch_norm_layer_three(F.relu(self.linear_layer_one(output_two)))
        final_output = self.linear_layer_two(output_three)
        return final_output

    def act(self, images):
        """Return the index of the largest output for each image (greedy action)."""
        output = self(images)
        return torch.argmax(output, dim=1)

    def evaluate(self, images, actions=None):
        """Return per-image values: the output at ``actions`` if given, else the maximum.

        Args:
            images: passed straight to ``forward``.
            actions: optional action indices (tensor, list, or scalar). Row ``i`` of the
                result is read from ``output[i]`` at ``actions[i]``. A 1-D input gives a
                1-D result; a (B, k) input gives (B, k); a 0-d scalar is treated as a
                single-row batch.
        """
        output = self(images)
        # `is not None`: `!=` on a tensor is an (overloadable) comparison, not an identity test.
        if actions is None:
            return torch.max(output, dim=1).values

        action_index = torch.as_tensor(actions, dtype=torch.long, device=output.device)
        if action_index.dim() == 0:
            action_index = action_index.reshape(1)
        # One row index per leading entry of `actions`, broadcast across any trailing axes,
        # so output[row_index, action_index][i, ...] == output[i, actions[i, ...]].
        row_index = torch.arange(action_index.size(0), device=output.device)
        row_index = row_index.view(-1, *([1] * (action_index.dim() - 1)))
        return output[row_index, action_index]

In [24]:
class CuriousityModule:
    """Interface for intrinsic-reward ("curiosity") modules; every method is a no-op here."""

    def __init__(self):
        pass

    def calculate_curiousity_reward(self, images):
        """Return an intrinsic reward for ``images``; the base implementation returns None."""
        pass

    def update(self, images=None):
        """Train the module on ``images``; the base implementation does nothing.

        ``images`` is optional so zero-argument calls keep working while the signature
        matches the one subclasses override.
        """
        pass


class RandomNetworkDistillationCuriousityModule(CuriousityModule):
    """Curiosity bonus from a trainable ``source_model`` chasing a fixed ``target_model``.

    The reward for an image is the L2 distance between the two networks' outputs. Only
    the source network is optimised, and the target is always run under ``no_grad``, so
    the distance shrinks on inputs the source has been trained on.
    """

    def __init__(self, batch_mode=False):
        """Build both networks and the source optimizer.

        Args:
            batch_mode: selects the learning rate (1e-3 if True, else 1e-5). It is kept on
                the instance so ``load`` can rebuild the optimizer identically.
        """
        super().__init__()

        self.batch_mode = batch_mode
        curiousity_config = CNNConfig(8, 16, 32, 8)
        self.target_model = BasicCNN(curiousity_config)
        self.source_model = BasicCNN(curiousity_config)
        self.optimizer = self._make_optimizer()

        # Eval mode for both: BatchNorm then uses running statistics, which also lets
        # single-frame batches through.
        self.target_model.eval()
        self.source_model.eval()

    def _make_optimizer(self):
        """Create the AdamW optimizer over the current source model's parameters."""
        return optim.AdamW(self.source_model.parameters(), lr=(1e-3 if self.batch_mode else 1e-5))

    def calculate_curiousity_reward(self, images):
        """Return a 1-D tensor of per-image distances between source and target outputs.

        Args:
            images: a 4-D batch tensor, or a single array-like frame (squeezed and given
                a leading batch axis).
        Returns:
            Tensor of shape (batch,); it carries gradients w.r.t. the source model.
        """
        image_tensor = images if torch.is_tensor(images) else torch.tensor(images).squeeze().unsqueeze(0)
        with torch.no_grad():
            target = self.target_model(image_tensor)
        source = self.source_model(image_tensor)
        # Row-wise Euclidean distance (same as torch.dist per row), vectorised over the batch.
        return (source - target).norm(p=2, dim=1)

    def update(self, images):
        """Take one optimizer step pulling the source towards the target; return the pre-step distances."""
        dists = self.calculate_curiousity_reward(images)
        self.optimizer.zero_grad()
        # backward() needs a scalar: average over the batch. For a single image this is
        # the distance itself, so the single-frame gradient is unchanged.
        dists.mean().backward()
        self.optimizer.step()
        return dists

    def save(self, source_path, target_path):
        """Persist both networks with the project's ``save_model`` helper."""
        save_model(self.source_model, source_path)
        save_model(self.target_model, target_path)

    def load(self, source_path, target_path):
        """Restore both networks and rebuild the optimizer for the new source parameters."""
        self.source_model = load_model(source_path)
        self.target_model = load_model(target_path)
        self.optimizer = self._make_optimizer()
        self.target_model.eval()
        self.source_model.eval()

In [37]:
# --- Replay log and batching ---
MAX_DATASET_SIZE = 100_000        # size passed to SimpleLog / PrioritizedLog (presumably its capacity)
BATCH_SIZE = 64
USE_PRIORITY = False              # True -> PrioritizedLog, False -> SimpleLog

# --- Schedule ---
TOTAL_FRAMES = 50_000_000         # the training loop iterates over this many frames
NUM_EPOCHS = 100_000              # not a schedule driver: the loop runs over TOTAL_FRAMES
MODEL_UPDATE_FREQUENCY = 16       # frames between Q-network updates
TARGET_UPDATE_FREQUENCY = 250     # counted in Q-network updates, not frames
PRINT_FREQUENCY = 1_000           # frames between stats printouts

# --- Learning signal ---
DISCOUNT = 0.99
REGULAR_REWARD_COEFF = 1.0
CURIOUSITY_COEFF = 0.0            # 0.0 disables the curiosity bonus in the Q-learning target

# --- Checkpoints ---
SAVE_PATH = "model.p"
CURIOUSITY_SOURCE_SAVE_PATH = "curious-source-model.p"
CURIOUSITY_TARGET_SAVE_PATH = "curious-target-model.p"
RELOAD = True                     # start from the Q-network checkpoint at SAVE_PATH

In [38]:
# Build the online Q-network, the curiosity module, the replay log, the optimizer and the
# stats collectors, then create the target network from a checkpoint of the online one.
# output_depth=4: one output per action id; the training loop logs ids 0-3 as
# No-op / Fire / Right / Left.
base_config = CNNConfig(16, 32, 256, 4)
model = BasicCNN(base_config)
curiousity_module = RandomNetworkDistillationCuriousityModule()
if RELOAD:
    # Replace the freshly initialised network with the checkpoint at SAVE_PATH.
    model = load_model(SAVE_PATH)
    # NOTE(review): the curiosity module is NOT reloaded, so the `curiousity_module.save`
    # call at the end of this cell overwrites any existing curiosity checkpoints with
    # freshly initialised networks.
    #curiousity_module.load(CURIOUSITY_SOURCE_SAVE_PATH, CURIOUSITY_TARGET_SAVE_PATH)
dataset = PrioritizedLog(MAX_DATASET_SIZE) if USE_PRIORITY else SimpleLog(MAX_DATASET_SIZE)
# NOTE(review): `reduction` is never passed to `loss_func` below, so the loss always uses
# MSELoss's default 'mean' reduction regardless of USE_PRIORITY.
reduction = 'none' if USE_PRIORITY else 'mean'

# Built after the optional reload so it tracks the parameters of the model actually trained.
optimizer = optim.AdamW(model.parameters(), lr=1e-5)
loss_func = nn.MSELoss()
rollout_stats_collector = StatsCollector()
update_stats_collector = StatsCollector()

# Round-trip through the checkpoint file to obtain a separate object for the target
# network (presumably an independent copy of the weights — depends on save/load_model).
save_model(model, SAVE_PATH)
target_model = load_model(SAVE_PATH)
curiousity_module.save(CURIOUSITY_SOURCE_SAVE_PATH, CURIOUSITY_TARGET_SAVE_PATH)

In [32]:
# Training loop. Each frame: act greedily with `model`, log the transition, and take one
# curiosity-module training step on the new observation. Every MODEL_UPDATE_FREQUENCY
# frames (once more than BATCH_SIZE frames have elapsed) run one Q-learning update
# against `target_model`, which is re-synced every TARGET_UPDATE_FREQUENCY updates.
done = False
updates = 0
stuck_count = 0
# The target network only produces bootstrap values. Keep it in eval mode so its BatchNorm
# layers use running statistics instead of per-batch statistics (and do not modify those
# statistics while computing targets).
target_model.eval()
for frame in range(TOTAL_FRAMES):

    model.eval()

    if frame == 0 or done:
        observation = env.reset()
        done = False

    old_observation = observation
    with torch.no_grad():
        action = model.act(observation)
    action_id = action.item()

    observation, reward, done, info = env.step(action_id)
    dataset.add(old_observation, observation, action, reward, done)
    dist = curiousity_module.update(observation)

    rollout_stats_collector.add({"Reward": reward, "No-op": int(action_id == 0), "Fire": int(action_id == 1),
                                 "Right": int(action_id == 2), "Left": int(action_id == 3), "Curiousity": dist.item()})

    if frame > BATCH_SIZE and frame % MODEL_UPDATE_FREQUENCY == 0:
        model.train()

        current_observation_batch, next_observation_batch, action_batch, rewards_batch, incompletions = dataset.sample_batch(BATCH_SIZE)

        # The curiosity bonus is only a reward signal here, so build no autograd graph for it.
        with torch.no_grad():
            dists = curiousity_module.calculate_curiousity_reward(next_observation_batch)
        # Rescale the bonus so its batch mean equals the batch mean of |reward|. A zero
        # (or NaN) mean distance would otherwise raise / propagate; treat it as "no bonus".
        mean_dist = torch.mean(dists).item()
        curiousity_scale = torch.mean(torch.abs(rewards_batch)).item() / mean_dist if mean_dist > 0 else 0.0
        norm_dists = dists * curiousity_scale
        full_rewards = (REGULAR_REWARD_COEFF * rewards_batch) + (CURIOUSITY_COEFF * norm_dists)

        with torch.no_grad():
            # r + DISCOUNT * max_a Q_target(s', a), with the bootstrap term multiplied by
            # `incompletions` (presumably 0 for terminal transitions — confirm in the dataset).
            target = (DISCOUNT * target_model.evaluate(next_observation_batch).squeeze() * incompletions.squeeze()) + full_rewards

        optimizer.zero_grad()
        prediction = model.evaluate(current_observation_batch, action_batch.squeeze())
        loss = loss_func(prediction, target)
        loss.backward()
        optimizer.step()

        updates += 1
        update_stats_collector.add({"Loss": loss.item(), "Normal Rewards": torch.mean(rewards_batch).item(),
                                    "Curiousity Rewards": torch.mean(norm_dists).item(), "Min Curiousity": torch.min(norm_dists).item(),
                                    "Max Curiousity": torch.max(norm_dists).item(), "Total Reward": torch.mean(full_rewards).item(),
                                    "Max Full Reward": torch.max(full_rewards).item(), "Min Full Reward": torch.min(full_rewards).item()})

        # Sync the target network and checkpoint everything once per TARGET_UPDATE_FREQUENCY
        # updates. Checking right after the increment makes this fire exactly once per sync
        # point, rather than on every frame while the counter sits at a multiple.
        if updates % TARGET_UPDATE_FREQUENCY == 0:
            save_model(model, SAVE_PATH)
            target_model = load_model(SAVE_PATH)
            target_model.eval()
            curiousity_module.save(CURIOUSITY_SOURCE_SAVE_PATH, CURIOUSITY_TARGET_SAVE_PATH)

    # Print (and reset) the rolling stats periodically and on the final frame.
    if frame % PRINT_FREQUENCY == 0 or frame == TOTAL_FRAMES - 1:
        print(frame, "/", TOTAL_FRAMES)
        rollout_stats_collector.show()
        rollout_stats_collector.reset()
        update_stats_collector.show()
        update_stats_collector.reset()
        print()

0 / 50000000
Reward: 0.0, No-op: 0.0, Fire: 0.0, Right: 1.0, Left: 0.0, Curiousity: 6.11


1000 / 50000000
Reward: 0.01, No-op: 0.34, Fire: 0.55, Right: 0.11, Left: 0.0, Curiousity: 3.37
Loss: 7126.73, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.01, Max Curiousity: 0.01, Total Reward: 0.01, Max Full Reward: 0.38, Min Full Reward: 0.0

2000 / 50000000
Reward: 0.01, No-op: 0.3, Fire: 0.62, Right: 0.08, Left: 0.0, Curiousity: 0.24
Loss: 6078.02, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.02, Total Reward: 0.01, Max Full Reward: 0.56, Min Full Reward: 0.0

3000 / 50000000
Reward: 0.01, No-op: 0.3, Fire: 0.62, Right: 0.08, Left: 0.0, Curiousity: 0.2
Loss: 5785.64, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.02, Total Reward: 0.01, Max Full Reward: 0.53, Min Full Reward: 0.0

4000 / 50000000
Reward: 0.01, No-op: 0.34, Fire: 0.59, Right: 0.07, Left: 0.0, Curiousity: 0.22
Loss: 5699.99,

32000 / 50000000
Reward: 0.01, No-op: 0.56, Fire: 0.37, Right: 0.07, Left: 0.0, Curiousity: 0.12
Loss: 108.19, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.02, Total Reward: 0.01, Max Full Reward: 0.4, Min Full Reward: 0.0

33000 / 50000000
Reward: 0.01, No-op: 0.49, Fire: 0.43, Right: 0.08, Left: 0.0, Curiousity: 0.14
Loss: 93.88, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.02, Total Reward: 0.01, Max Full Reward: 0.44, Min Full Reward: 0.0

34000 / 50000000
Reward: 0.0, No-op: 0.69, Fire: 0.26, Right: 0.04, Left: 0.0, Curiousity: 0.08
Loss: 93.55, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.02, Total Reward: 0.01, Max Full Reward: 0.43, Min Full Reward: 0.0

35000 / 50000000
Reward: 0.0, No-op: 0.85, Fire: 0.13, Right: 0.02, Left: 0.0, Curiousity: 0.04
Loss: 88.46, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.02, Tot

64000 / 50000000
Reward: 0.0, No-op: 1.0, Fire: 0.0, Right: 0.0, Left: 0.0, Curiousity: 0.26
Loss: 51.13, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.02, Total Reward: 0.01, Max Full Reward: 0.4, Min Full Reward: 0.0

65000 / 50000000
Reward: 0.01, No-op: 0.58, Fire: 0.31, Right: 0.11, Left: 0.0, Curiousity: 0.11
Loss: 38.72, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.04, Total Reward: 0.01, Max Full Reward: 0.37, Min Full Reward: 0.0

66000 / 50000000
Reward: 0.01, No-op: 0.23, Fire: 0.6, Right: 0.17, Left: 0.0, Curiousity: 0.17
Loss: 36.68, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.04, Total Reward: 0.01, Max Full Reward: 0.3, Min Full Reward: 0.0

67000 / 50000000
Reward: 0.01, No-op: 0.33, Fire: 0.52, Right: 0.16, Left: 0.0, Curiousity: 0.16
Loss: 36.68, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.03, Total Re

96000 / 50000000
Reward: 0.01, No-op: 0.6, Fire: 0.32, Right: 0.08, Left: 0.0, Curiousity: 0.09
Loss: 23.1, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.05, Total Reward: 0.01, Max Full Reward: 0.32, Min Full Reward: 0.0

97000 / 50000000
Reward: 0.0, No-op: 0.87, Fire: 0.1, Right: 0.03, Left: 0.0, Curiousity: 0.03
Loss: 19.64, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.04, Total Reward: 0.01, Max Full Reward: 0.29, Min Full Reward: 0.0

98000 / 50000000
Reward: 0.0, No-op: 0.85, Fire: 0.12, Right: 0.03, Left: 0.0, Curiousity: 0.03
Loss: 19.33, Normal Rewards: 0.0, Curiousity Rewards: 0.0, Min Curiousity: 0.0, Max Curiousity: 0.05, Total Reward: 0.0, Max Full Reward: 0.29, Min Full Reward: 0.0

99000 / 50000000
Reward: 0.0, No-op: 0.87, Fire: 0.11, Right: 0.02, Left: 0.0, Curiousity: 0.03
Loss: 17.64, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.05, Total Rewa

127000 / 50000000
Reward: 0.0, No-op: 0.84, Fire: 0.12, Right: 0.03, Left: 0.0, Curiousity: 0.04
Loss: 28.27, Normal Rewards: 0.0, Curiousity Rewards: 0.0, Min Curiousity: 0.0, Max Curiousity: 0.05, Total Reward: 0.0, Max Full Reward: 0.24, Min Full Reward: 0.0

128000 / 50000000
Reward: 0.0, No-op: 0.91, Fire: 0.07, Right: 0.02, Left: 0.0, Curiousity: 0.02
Loss: 36.34, Normal Rewards: 0.0, Curiousity Rewards: 0.0, Min Curiousity: 0.0, Max Curiousity: 0.04, Total Reward: 0.0, Max Full Reward: 0.22, Min Full Reward: 0.0

129000 / 50000000
Reward: 0.01, No-op: 0.52, Fire: 0.37, Right: 0.11, Left: 0.0, Curiousity: 0.11
Loss: 62.1, Normal Rewards: 0.0, Curiousity Rewards: 0.0, Min Curiousity: 0.0, Max Curiousity: 0.06, Total Reward: 0.0, Max Full Reward: 0.29, Min Full Reward: 0.0

130000 / 50000000
Reward: 0.01, No-op: 0.34, Fire: 0.53, Right: 0.13, Left: 0.0, Curiousity: 0.13
Loss: 37.07, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.06, Total Rew

159000 / 50000000
Reward: 0.0, No-op: 0.86, Fire: 0.11, Right: 0.02, Left: 0.0, Curiousity: 0.02
Loss: 46.91, Normal Rewards: 0.0, Curiousity Rewards: 0.0, Min Curiousity: 0.0, Max Curiousity: 0.05, Total Reward: 0.0, Max Full Reward: 0.26, Min Full Reward: 0.0

160000 / 50000000
Reward: 0.0, No-op: 0.88, Fire: 0.1, Right: 0.02, Left: 0.0, Curiousity: 0.02
Loss: 33.61, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.1, Total Reward: 0.01, Max Full Reward: 0.33, Min Full Reward: 0.0

161000 / 50000000
Reward: 0.01, No-op: 0.53, Fire: 0.37, Right: 0.1, Left: 0.0, Curiousity: 0.08
Loss: 51.35, Normal Rewards: 0.01, Curiousity Rewards: 0.01, Min Curiousity: 0.0, Max Curiousity: 0.1, Total Reward: 0.01, Max Full Reward: 0.37, Min Full Reward: 0.0

162000 / 50000000
Reward: 0.01, No-op: 0.46, Fire: 0.43, Right: 0.11, Left: 0.0, Curiousity: 0.1
Loss: 89.75, Normal Rewards: 0.0, Curiousity Rewards: 0.0, Min Curiousity: 0.0, Max Curiousity: 0.06, Total Rew

KeyboardInterrupt: 

In [43]:
# Watch the current policy play one episode (capped at max_actions steps) in a window.
render_env = gym.make("ALE/Breakout-v5", render_mode='human')
max_actions = 1_000
model.eval()
observation = render_env.reset()
# Start a fresh episode flag: otherwise `done` is whatever the training loop left
# behind, and the loop below may not run at all.
done = False
steps = 0
try:
    while not done and steps < max_actions:
        if steps == 0:
            # Action 1 is logged as "Fire" in the training loop; forcing it on the first
            # step presumably serves the ball.
            action = torch.tensor([1])
        else:
            with torch.no_grad():
                action = model.act(observation)
        observation, reward, done, info = render_env.step(action.item())
        steps += 1
finally:
    # Release the rendering window even if the episode is interrupted.
    render_env.close()