# Setup

## Packages

In [0]:
# Utils
import numpy as np
from pdb import set_trace
from PIL import Image


# NN
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import gym


## Constants

In [0]:
# Loss mixing weight used in ICMAgent.train:
#   loss = BETA * forward_loss + (1 - BETA) * inverse_loss - LAMBDA * r_cum
BETA = 0.2
# Scale applied to the cumulated extrinsic reward in the same loss expression.
LAMBDA = 0.1
# Learning rate.
# NOTE(review): not referenced by the code visible in this file (the Adam
# optimizer in ICMAgent is built with its default learning rate) - confirm.
LR = 0.1

# Number of epochs / steps per epoch handed to ICMAgent.train.
NUM_EPOCH = 10
NUM_STEP = 100

# Models

In [0]:
class ConvBlock(nn.Module):
    """Four stacked Conv2d layers, each followed by a LeakyReLU.

    Every layer uses 32 filters, a 3x3 kernel, stride 2 and padding 1.

    Args:
        ch_in: number of channels of the input image (default: 1).

    The forward pass returns a 1-D tensor: the activation of the last layer
    flattened over *all* dimensions, the batch dimension included.
    """

    def __init__(self, ch_in=1):
        super().__init__()

        # hyper-parameters shared by all four convolutions
        self.num_filter = 32
        self.size = 3
        self.stride = 2
        self.pad = self.size // 2

        shared = dict(kernel_size=self.size, stride=self.stride, padding=self.pad)

        # layers - created in a fixed order so parameter initialisation
        # consumes the random number generator exactly as before
        self.conv1 = nn.Conv2d(ch_in, self.num_filter, **shared)
        self.conv2 = nn.Conv2d(self.num_filter, self.num_filter, **shared)
        self.conv3 = nn.Conv2d(self.num_filter, self.num_filter, **shared)
        self.conv4 = nn.Conv2d(self.num_filter, self.num_filter, **shared)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.leaky_relu(conv(x))

        # flatten every dimension (batch too) into one vector
        return x.flatten()


class FeatureEncoderNet(nn.Module):
    """ Network for feature encoding

        In: [s_t]
            Current state (i.e. pixels) -> 1 channel image is needed

        Out: phi(s_t)
            Current state transformed into feature space

        Args:
            in_size: input size of the LSTM, i.e. the length of the flattened
                ConvBlock output.
            is_a3c: accepted for API compatibility; currently ignored (see the
                note in ``__init__``).

        The forward pass feeds the flattened convolutional features through
        the LSTM as a single sequence of length 1 and batch size 1, so it
        only works when ``x.size(0) == 1``. It returns a tensor of shape
        ``(1, h1)``.
    """
    def __init__(self, in_size, is_a3c=True):
        super(FeatureEncoderNet, self).__init__()
        # constants
        self.in_size = in_size
        self.h1 = 256
        self.num_layers = 1
        self.num_directions = 1
        # indicates whether the LSTM is needed
        # NOTE(review): the ``is_a3c`` argument is deliberately NOT used here.
        # ICMNet passes ``is_a3c=False`` but its heads expect 256-dimensional
        # features, which only the LSTM path produces; honouring the flag
        # would break that caller. Fix both sides together.
        self.is_a3c = True

        # layers
        self.conv = ConvBlock()
        if self.is_a3c:
            self.lstm = nn.LSTM(input_size=self.in_size, hidden_size=self.h1, batch_first=True)

    def forward(self, x):
        batch_size = x.size(0)
        feats = self.conv(x)

        if not self.is_a3c:
            return feats

        # single time step, single sample: (batch=1, seq=1, features)
        feats = feats.view(1, 1, -1)

        # Create the initial hidden/cell state on the same device (and with
        # the same dtype) as the features instead of guessing from
        # torch.cuda.is_available(); the latter breaks a CPU model running on
        # a machine that happens to have a GPU.
        state_shape = (self.num_layers * self.num_directions, batch_size, self.h1)
        h_0 = torch.zeros(state_shape, device=feats.device, dtype=feats.dtype)
        c_0 = torch.zeros_like(h_0)

        # nn.LSTM returns (output, (h_n, c_n)); only the output is needed
        output, _ = self.lstm(feats, (h_0, c_0))

        # output of the last (here: only) time step
        return output[:, -1, :]


class InverseNet(nn.Module):
    r""" Network for the inverse dynamics

        In: torch.cat((phi(s_t), phi(s_{t+1})), 1)
            Current and next states transformed into the feature space,
            denoted by phi().

        Out: \hat{a}_t
            Predicted action (one unnormalised score per action)

        Args:
            num_actions: size of the output layer.

        The input must have ``2 * feat_size`` (= 512) features in its last
        dimension.
    """
    # The docstring above is a raw string: "\hat" in a normal string literal
    # is an invalid escape sequence and triggers a warning at compile time.

    def __init__(self, num_actions):
        super(InverseNet, self).__init__()

        # constants
        self.feat_size = 256
        self.fc_hidden = 256
        self.num_actions = num_actions

        # layers
        self.fc1 = nn.Linear(self.feat_size*2, self.fc_hidden)
        self.fc2 = nn.Linear(self.fc_hidden, self.num_actions)

    def forward(self, x):
        # NOTE(review): there is no non-linearity between fc1 and fc2, so the
        # two layers compose to a single affine map - confirm this is intended.
        return self.fc2(self.fc1(x))


class ForwardNet(nn.Module):
    r""" Network for the forward dynamics

    In: torch.cat((phi(s_t), a_t), 1)
        Current state transformed into the feature space,
        denoted by phi() and current action

    Out: \hat{phi(s_{t+1})}
        Predicted next state (in feature space), 256 features

    Args:
        in_size: number of input features of the first linear layer.

    """
    # The docstring above is a raw string: "\hat" in a normal string literal
    # is an invalid escape sequence and triggers a warning at compile time.

    def __init__(self, in_size):

        super(ForwardNet, self).__init__()

        # constants
        self.in_size = in_size
        self.fc_hidden = 256
        self.out_size = 256

        # layers
        self.fc1 = nn.Linear(self.in_size, self.fc_hidden)
        self.fc2 = nn.Linear(self.fc_hidden, self.out_size)

    def forward(self, x):
        # NOTE(review): there is no non-linearity between fc1 and fc2, so the
        # two layers compose to a single affine map - confirm this is intended.
        return self.fc2(self.fc1(x))


class AdversarialHead(nn.Module):
    """Forward- and inverse-dynamics models bundled into one module.

    Args:
        feat_size: length of an encoded state; together with ``num_actions``
            it determines the input size of the forward model.
        num_actions: length of the action vector / number of action scores.

    NOTE(review): ``feat_size`` is not forwarded to InverseNet, which uses its
    own fixed feature size - the two only agree for the value InverseNet
    hard-codes.
    """

    def __init__(self, feat_size, num_actions):
        super().__init__()

        # constants
        self.feat_size = feat_size
        self.num_actions = num_actions

        # networks (forward model first, then inverse model)
        self.fwd_net = ForwardNet(feat_size + num_actions)
        self.inv_net = InverseNet(num_actions)

    def forward(self, phi_t, phi_t1, a_t):
        """Run both dynamics models.

        phi_t:  current encoded state
        phi_t1: next encoded state
        a_t:    current action

        Returns ``(phi_t1_hat, a_t_hat)``: the forward model's output for
        ``[phi_t, a_t]`` and the inverse model's output for
        ``[phi_t, phi_t1]``, each concatenated along dimension 1.
        """
        # forward dynamics: state and action side by side -> next state
        state_and_action = torch.cat((phi_t, a_t), dim=1)
        phi_t1_hat = self.fwd_net(state_and_action)

        # inverse dynamics: both states side by side -> action
        state_pair = torch.cat((phi_t, phi_t1), dim=1)
        a_t_hat = self.inv_net(state_pair)

        return phi_t1_hat, a_t_hat


class ICMNet(nn.Module):
    """State encoder plus dynamics heads.

    Args:
        num_actions: forwarded to both AdversarialHead instances.
        in_size: forwarded to FeatureEncoderNet (default: 288).
        feat_size: forwarded to both AdversarialHead instances (default: 256).
    """

    def __init__(self, num_actions, in_size=288, feat_size=256):
        super().__init__()

        # constants
        self.in_size = in_size
        self.feat_size = feat_size
        self.num_actions = num_actions

        # networks
        self.feat_enc_net = FeatureEncoderNet(in_size, is_a3c=False)
        # head whose goal is to minimize the prediction error
        self.pred_net = AdversarialHead(feat_size, num_actions)
        # head whose goal is to maximize the prediction error (i.e. predict
        # states which can contain new information); it is constructed here
        # but not used in forward() yet
        self.policy_net = AdversarialHead(feat_size, num_actions)

    def forward(self, s_t, s_t1, a_t):
        """Encode both states and run the prediction head.

        s_t:  current state
        s_t1: next state
        a_t:  current action

        Returns ``(phi_t1, phi_t1_pred, a_t_pred)``: the encoded next state
        and the two outputs of ``pred_net``.
        """
        # encode the states with the shared encoder (current first, then next)
        phi_t, phi_t1 = (self.feat_enc_net(state) for state in (s_t, s_t1))

        # only the prediction head is evaluated; policy_net is not called
        phi_t1_pred, a_t_pred = self.pred_net(phi_t, phi_t1, a_t)

        return phi_t1, phi_t1_pred, a_t_pred


class A3CNet(nn.Module):
    """Actor-critic network on top of a FeatureEncoderNet.

    Args:
        num_actions: output size of the actor layer.
        in_size: forwarded to FeatureEncoderNet (default: 288).
    """

    def __init__(self, num_actions, in_size=288):
        super().__init__()

        # constants
        self.in_size = in_size
        self.num_actions = num_actions

        # networks
        self.feat_enc_net = FeatureEncoderNet(in_size)
        encoded_size = self.feat_enc_net.h1
        # actor: estimates what to do (one score per action)
        self.actor = nn.Linear(encoded_size, num_actions)
        # critic: estimates how good the current state is (single value)
        self.critic = nn.Linear(encoded_size, 1)

    def forward(self, s_t):
        """Return ``(policy, value)`` for the current state ``s_t``.

        The state is encoded once (phi_t) and the encoding is fed to both
        the actor and the critic layer.
        """
        phi_t = self.feat_enc_net(s_t)
        return self.actor(phi_t), self.critic(phi_t)


        




# Agent

In [0]:
class ICMAgent(nn.Module):
    """Agent holding an A3CNet (action selection) and an ICMNet (curiosity).

    Args:
        num_actions: number of discrete actions of the environment.
        in_size: forwarded to both networks (default: 288).

    NOTE(review): ``train`` below shadows ``nn.Module.train(mode)``. Calling
    ``.eval()`` on the agent therefore fails, because ``nn.Module.eval``
    calls ``self.train(False)``. The name is kept because callers use it.
    """

    def __init__(self, num_actions, in_size=288):
        super().__init__()

        # constants
        self.in_size = in_size
        self.num_actions = num_actions
        self.is_cuda = torch.cuda.is_available()

        # extrinsic reward summed over every step taken so far
        self.cum_r = 0

        # networks
        self.icm = ICMNet(self.num_actions, self.in_size)
        self.a3c = A3CNet(self.num_actions, self.in_size)

        if self.is_cuda:
            self.icm.cuda()
            self.a3c.cuda()

        # optimizer (Adam with its default hyper-parameters)
        self.optimizer = optim.Adam( list(self.icm.parameters()) + list(self.a3c.parameters()) )

    def get_action(self, s_t):
        """Sample an action for state ``s_t``.

        Returns:
            ``(a_t, value, policy)`` where ``a_t`` is a NumPy array with one
            sampled action index per row of the policy output, ``value`` is
            the critic output as a squeezed NumPy array and ``policy`` is the
            detached actor output.
        """
        policy, value = self.a3c(s_t) # use A3C to get policy and value
        action_prob = F.softmax(policy, dim=-1).detach().cpu().numpy()
        a_t = self.sel_rnd_idx(action_prob)

        return a_t, value.detach().cpu().numpy().squeeze(), policy.detach()

    @staticmethod
    def sel_rnd_idx(p, axis=1):
        """Draw one index per row from the probabilities ``p``.

        A uniform random threshold is compared against the cumulative sum
        along ``axis``; the first position exceeding it is returned.
        """
        r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis) # insert a new dim with a random value
        return (p.cumsum(axis=axis) > r).argmax(axis=axis)

    def cumulate_reward(self, r):
        """Add ``r`` to the running reward sum and return the new total."""
        self.cum_r += r
        return self.cum_r

    # functions
    def pix2tensor(self, pix):
        """Convert a raw frame into a ``(1, 1, 42, 42)`` float tensor.

        The frame is converted to a PIL image, reduced to one grayscale
        channel, resized to 42x42 and turned into a tensor with a leading
        batch dimension. The tensor is moved to the GPU only when CUDA is
        available (previously ``.cuda()`` was called unconditionally, which
        fails on CPU-only machines).
        """
        im2tensor = transforms.Compose([transforms.ToPILImage(),
                                        transforms.Grayscale(1),
                                        transforms.Resize((42,42)),
                                        transforms.ToTensor()])

        frame = torch.unsqueeze(im2tensor(pix), 0)
        return frame.cuda() if self.is_cuda else frame

    def train(self, env_name, num_epoch, num_steps):
        """Interact with a gym environment and optimise the ICM losses.

            s_t : current state
            s_t1: next state

            phi_t: current encoded state
            phi_t1: next encoded state

            a_t: current action

        Args:
            env_name: id passed to ``gym.make``.
            num_epoch: number of episodes (the environment is reset at the
                start of each).
            num_steps: maximum number of steps per episode; an episode also
                stops as soon as the environment reports ``done``.

        Per step the loss is
        ``BETA*loss_int + (1-BETA)*loss_inv - LAMBDA*r_cum`` with
        ``loss_int`` the MSE between predicted and actual encoded next state
        and ``loss_inv`` the cross-entropy between predicted and taken
        action.

        NOTE(review): ``r_cum`` is a plain number, so the reward term does
        not contribute any gradient, and nothing in the loss depends on the
        A3C parameters (actions are sampled from detached outputs).
        """
        env = gym.make(env_name)

        try:
            for epoch in range(num_epoch):
                s_t = self.pix2tensor(env.reset())

                for step in range(num_steps):
                    # select action from the policy; get_action returns
                    # (action, value, policy) in that order
                    a_t, value, policy = self.get_action(s_t)

                    # interact with the environment; a_t holds exactly one
                    # index, which is handed over as a plain int
                    s_t1, r, done, _ = env.step(int(a_t.item()))
                    r_cum = self.cumulate_reward(r)
                    s_t1 = self.pix2tensor(s_t1)

                    # action index and its one-hot encoding, created directly
                    # on the device the states live on
                    device = s_t.device
                    a_t = torch.as_tensor(a_t, dtype=torch.long, device=device)
                    a_t_1_hot = torch.zeros(1, self.num_actions, device=device).scatter_(1, a_t.view(-1, 1), 1)

                    # call the ICM model
                    phi_t1, phi_t1_pred, a_t_pred = self.icm(s_t, s_t1, a_t_1_hot)

                    # calculate losses
                    loss_int = F.mse_loss(phi_t1_pred, phi_t1)
                    loss_inv = F.cross_entropy(a_t_pred, a_t)

                    # compose losses
                    loss = BETA*loss_int + (1-BETA)*loss_inv - LAMBDA*r_cum

                    print("Epoch: {}, step: {}, loss {}".format(epoch, step, loss) )

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    # stepping a finished episode is not valid; start the
                    # next epoch (which resets the environment) instead
                    if done:
                        break

                    s_t = s_t1 # the current next state will be the new current state
        finally:
            # release the environment even if training is interrupted
            env.close()


# Train

In [0]:


# objects
# One environment id is used both to size the agent's action space and to
# train it; previously the agent was sized from 'MsPacman-v0' while training
# ran on 'MontezumaRevenge-v0', whose action space differs.
ENV_NAME = 'MontezumaRevenge-v0'  # alternative: 'MsPacman-v0'
env = gym.make(ENV_NAME)
agent = ICMAgent(env.action_space.n)

# only move the agent to the GPU when one is actually available
if torch.cuda.is_available():
    agent.cuda()
agent.train(ENV_NAME, NUM_EPOCH, NUM_STEP)
