In [3]:
import ale_py # necessary for gym enviornment creation
import gym
import time
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import kornia as K
import numpy as np
import einops
from einops.layers.torch import Rearrange
from supersuit import frame_stack_v1, resize_v0, clip_reward_v0



In [4]:
FRAMES_TO_STACK = 4  # frames stacked into one observation (frame_stack_v1)
FRAMES_TO_SKIP = 4   # passed to the environment as `frameskip`
PAD_SIZE = 4         # replication padding applied before the random crop

# TODO Seed torch and the gym environment too (only the NumPy generator is seeded)
# TODO Terminal on loss of life
# TODO compare Adam with LARS optimizer

# Side lengths the observations are resized / cropped to.
X_DIM = 84
Y_DIM = 84

# Fixed seed so that draws from `rng` are reproducible after Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

In [5]:
def preprocess(img):
    """Convert a channels-last NumPy image (or batch) to a channels-first tensor.

    The array is wrapped as a torch tensor and divided by 255, then the
    channel axis is moved in front of the spatial axes. No batch axis is
    added to a single image.

    Args:
        img: NumPy array laid out as ``(h, w, c)`` or ``(b, h, w, c)``.
            Presumably 8-bit pixel data, given the division by 255 —
            TODO confirm.

    Returns:
        Tensor laid out as ``(c, h, w)`` or ``(b, c, h, w)`` respectively.
    """
    scaled = T.from_numpy(img) / 255
    #TODO do this after transformation
    #https://theaisummer.com/self-supervised-representation-learning-computer-vision/
    #TODO norm to running mean of channel

    # Three axes means there is no batch axis; anything else is treated as batched.
    is_single_image = len(scaled.shape) == 3
    pattern = 'h w c -> c h w' if is_single_image else 'b h w c -> b c h w'
    return einops.rearrange(scaled, pattern)

In [6]:
def plot_state(state, big=False):
    """Show the frames of the first entry of ``state`` as grayscale images.

    Args:
        state: Object with a ``shape`` attribute that can be indexed as
            ``state[0][i]`` to obtain a 2-D image. Assumed to be laid out as
            (batch, frames, height, width) — TODO confirm against callers.
        big: If True, draw only ``state[0][0]`` on the whole figure.
            Otherwise draw every frame of ``state[0]`` side by side.
    """
    fig = plt.figure(figsize=(36, 18))
    if big:
        plt.imshow(state[0][0], cmap='gray')
        return

    # Size the grid from the data itself: using a fixed column count fails
    # as soon as `state` carries a different number of frames than that count.
    n_frames = state.shape[1]
    for i in range(n_frames):
        fig.add_subplot(1, n_frames, i + 1)
        plt.imshow(state[0][i], cmap='gray')

In [7]:
def create_env(name='MsPacman', render=None):
    """Create a wrapped grayscale ALE environment.

    The base environment ``'ALE/<name>-v5'`` is wrapped, in order, with
    reward clipping to [-1, 1], resizing to ``X_DIM`` x ``Y_DIM`` with linear
    interpolation, and stacking of ``FRAMES_TO_STACK`` frames.

    Args:
        name: Game name inserted into the environment id.
        render: Value forwarded as ``render_mode``
            (None | 'human' | 'rgb_array').

    Returns:
        The wrapped environment.
    """
    env_id = 'ALE/' + name + '-v5'
    ale_options = dict(
        obs_type='grayscale',            # ram | rgb | grayscale
        frameskip=FRAMES_TO_SKIP,        # frame skip
        mode=0,                          # game mode, see Machado et al. 2018
        difficulty=0,                    # game difficulty, see Machado et al. 2018
        repeat_action_probability=0.25,  # sticky action probability
        full_action_space=True,          # use all actions
        render_mode=render,              # None | human | rgb_array
    )
    base_env = gym.make(env_id, **ale_options)

    clipped_env = clip_reward_v0(base_env, lower_bound=-1, upper_bound=1)
    resized_env = resize_v0(clipped_env, X_DIM, Y_DIM, linear_interp=True)
    return frame_stack_v1(resized_env, FRAMES_TO_STACK)

In [8]:
def run_episode(env, early_break=False):
    """Play one episode with randomly sampled actions and collect the states.

    The environment is always closed before this function returns, including
    when stepping or preprocessing raises, so ``env`` cannot be reused
    afterwards.

    Args:
        env: Environment whose ``step`` returns a 4-tuple
            ``(observation, reward, done, info)``.
        early_break: If True, stop after 100 steps.

    Returns:
        A list with the preprocessed observation of every step. If
        ``early_break`` is True and the episode reaches 100 steps, only the
        preprocessed observation of the 100th step is returned (a single
        tensor, not a list).
    """
    early_break_step = 100
    states = []

    try:
        # The observation returned by reset() is not recorded; only the
        # observations produced by step() are.
        env.reset()
        done = False
        step = 0

        while not done:
            action = env.action_space.sample()
            obs, _reward, done, _ = env.step(action)

            state = preprocess(obs)
            states.append(state)

            step += 1
            if early_break and step == early_break_step:
                return state

        return states
    finally:
        # Release the emulator even if the rollout fails part-way through.
        env.close()

In [9]:
class ContrastiveLearner(nn.Module):
    """Convolutional encoder with a fully connected projection head.

    ``forward`` returns the flattened convolutional features only;
    ``project`` additionally applies the projection head ``fc``.

    Args:
        stacked_frames: Number of input channels of the first convolution.
        hidden_dim: Width of the hidden layer of the projection head.
        out_dim: Output size of the projection head.
        input_size: ``(height, width)`` of the input images. It is used only
            to size the first linear layer. The default of 84x84 gives
            64 * 7 * 7 = 3136 flattened features.
    """

    def __init__(self, stacked_frames, hidden_dim=1024, out_dim=15, input_size=(84, 84)):
        #TODO spectral normalization (?)
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(stacked_frames, 32, 8, stride=4),
            nn.ELU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ELU(),
            nn.Conv2d(64, 64, 3),
            nn.ELU(),
            Rearrange('b c h w -> b (c h w)')
        )

        # Infer the flattened feature size by pushing an all-zero dummy batch
        # through the conv stack, so other resolutions work without editing
        # a hard-coded constant.
        with T.no_grad():
            probe = T.zeros(1, stacked_frames, *input_size)
            conv_out_dim = self.conv(probe).shape[1]

        self.fc = nn.Sequential(
            nn.Linear(conv_out_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        """Return the flattened convolutional features of ``x`` (b, c, h, w)."""
        return self.conv(x)

    def project(self, x):
        """Encode ``x`` with ``forward`` and apply the projection head."""
        x = self.forward(x)
        x = self.fc(x)

        return x

In [10]:
class DataAugment(nn.Module):
    """Random brightness change followed by a random shift.

    The shift is replication padding by ``pad_size`` followed by a random crop
    to ``(x_dim, y_dim)``. Both augmentation steps run with gradients disabled.

    Args:
        x_dim: First entry of the crop size passed to ``RandomCrop``.
        y_dim: Second entry of the crop size passed to ``RandomCrop``.
        pad_size: Padding handed to ``nn.ReplicationPad2d``.
        rng: Object providing ``uniform(low, high)``, used to draw the
            brightness change.
        brightness_clip: The brightness change is drawn from
            ``[-brightness_clip, brightness_clip]``.
    """

    def __init__(self, x_dim, y_dim, pad_size, rng, brightness_clip=0.2):
        super().__init__()

        self.rng = rng
        self.clip = brightness_clip

        pad = nn.ReplicationPad2d(pad_size)
        crop = K.augmentation.RandomCrop(size=(x_dim, y_dim))
        self.random_shift = nn.Sequential(pad, crop)

    @T.no_grad()
    def random_brightness(self, x):
        """Adjust the brightness of ``x`` by a single value drawn per call."""
        delta = self.rng.uniform(-self.clip, self.clip)
        return K.enhance.adjust_brightness(x, delta)

    @T.no_grad()
    def forward(self, x):
        """Apply the brightness change first, then the random shift."""
        return self.random_shift(self.random_brightness(x))

In [11]:
def prep_states(states, cutoff=False, augment=None, encoder=None, max_states=256):
    """Stack a list of state tensors, augment the batch and encode it.

    Args:
        states: Sequence of equally shaped tensors, one per state.
        cutoff: If True, keep only the first ``max_states`` entries.
        augment: Callable applied to the stacked batch. Defaults to the
            notebook-level ``dm`` object.
        encoder: Callable applied to the augmented batch. Defaults to the
            notebook-level ``cl`` object.
        max_states: Number of states kept when ``cutoff`` is True.

    Returns:
        The output of ``encoder(augment(batch))``.
    """
    # Explicit arguments take priority; the globals are looked up only when
    # needed, so the function also works before `dm` / `cl` are defined.
    augment = dm if augment is None else augment
    encoder = cl if encoder is None else encoder

    if cutoff:
        states = states[:max_states]

    batch = T.stack(states)
    return encoder(augment(batch))

In [12]:
class ParticleReward():
    """Particle-based reward from the distances to each state's nearest neighbours.

    For every state in a batch, the distances to its ``top_k`` closest states
    in the same batch are averaged, optionally divided by a running mean of
    all such distances seen so far, and mapped through ``log(c + value)``.
    """
    #TODO test original implementation
    def __init__(self, top_k=16):
        self.mean = 0           # running mean of the nearest-neighbour distances
        self.samples_done = 0   # number of distances folded into `self.mean`
        self.c = 1              # constant added inside the logarithm
        self.top_k = top_k      # number of nearest neighbours per state

    def calculate_reward(self, states, normalize=True):
        """Return one reward per state in ``states``.

        Args:
            states: Batch of state vectors, assumed to be laid out as
                (n_states, dim) — TODO confirm against callers.
            normalize: If True, divide the neighbour distances by the running
                mean before averaging. The running mean is updated either way.

        Returns:
            Tensor with one reward per state.
        """
        # Pairwise Euclidean distances between all states. The diagonal holds
        # each state's zero distance to itself.
        particle_volumes = T.norm(states.unsqueeze(1) - states.unsqueeze(0), dim=-1)

        # topk raises when k exceeds the size of the dimension, so batches
        # with fewer than `top_k` states use all available neighbours.
        k = min(self.top_k, particle_volumes.size(1))
        top_k_rewards, _ = particle_volumes.topk(k, sorted=True, largest=False, dim=1)

        #TODO test normalization
        self.update_mean_estimate(top_k_rewards.reshape(-1, 1))

        if normalize:
            running_mean = T.as_tensor(self.mean)
            # A zero mean means every distance seen so far was zero; dividing
            # would turn the rewards into NaN, so normalisation is skipped.
            if bool((running_mean > 0).all()):
                top_k_rewards = top_k_rewards / running_mean

        top_k_rewards = top_k_rewards.mean(dim=1)
        particle_rewards = T.log(self.c + top_k_rewards)

        return particle_rewards

    def update_mean_estimate(self, x):
        """Fold the rows of ``x`` into the running mean ``self.mean``.

        Args:
            x: Tensor whose first dimension indexes the samples.
        """
        # The running statistic must not keep autograd history, otherwise the
        # graphs of all previous batches stay alive through `self.mean`.
        x = x.detach()
        batch_size = x.size(0)
        if batch_size == 0:
            return  # nothing to add; also avoids dividing by zero samples

        self.samples_done += batch_size
        difference = x.mean(dim=0) - self.mean
        self.mean = self.mean + difference * batch_size / self.samples_done

In [15]:
# Build the environment, the augmentation module and the encoder.
# `dm` and `cl` are read as globals by prep_states.
env = create_env()
dm = DataAugment(X_DIM, Y_DIM, PAD_SIZE, rng)
cl = ContrastiveLearner(FRAMES_TO_STACK)

# Roll out one random-action episode; run_episode closes `env` when it returns.
states = run_episode(env)
# Keep the first 256 states, then augment and encode them. From here on
# `states` holds the encoder output, not the list of observations.
states = prep_states(states, cutoff=True)

In [None]:
# Particle-based reward with the default neighbourhood size (top_k=16).
reward_function = ParticleReward()

In [None]:
# Compute the rewards for the encoded states and display the shape of the result.
reward_function.calculate_reward(states).shape