In [125]:
import gymnasium as gym
import ale_py
import torch
import numpy as np
import random
import copy
from collections import deque
import multiprocessing as mp
import matplotlib.pyplot as plt

### Setting up variables and testing environnement

In [126]:
RANDOM_STATE = 42
torch.seed = RANDOM_STATE
torch.manual_seed(torch.seed)
np.random.seed(torch.seed)
random.seed(torch.seed)

GAMMA = 0.99
BETA = 0.2
LAMBDA = 0.4
LR = 0.0001

NUM_FEAT_SPACE = 256

GRADIENT_CLIPPING_ICM = 0.5
GRADIENT_CLIPPING_A2C = 0.5

BATCH_SIZE = 200
MAX_MOVES = 218

ENV_NAME = 'ALE/Bowling-ram-v5'

In [127]:
env = gym.make(ENV_NAME)

state, _ = env.reset()
print(f"Initial state: \n{state}")
print(f"Observation space: \n{env.observation_space}")
print(f"Action space: {env.action_space}")

a = env.action_space.sample()
event = env.step(a)
print('Output from applying action {} on environment:\nstate:'.format(a) \
      + '{}\nreward: {}\ndone: {}\ntruncated: {}\ninfo: {}'.format(*event))

Initial state: 
[ 71 255   0   0   0   0   0   0   0   0   0   0 255   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   8  15   0   0   0 170   0
   1   0   0   0   1  12 184 247   0   0   6  16  19  13  22  16  10  25
  19  13   7   7   5   5   3   3   3   1   1   1   1   0   0   0   0   0
   0   0   0   0   0 136 216 132  38  88   0   0   1 255   0 255 128 255
   0   0   0   0   0   0   0   2   2   0   8   8   0  34  34   0 136 136
   0  34  34   0   8   8   0   2   2   0   0   0   0   0   0   0   0   0
  66 243]
Observation space: 
Box(0, 255, (128,), uint8)
Action space: Discrete(6)
Output from applying action 0 on environment:
state:[ 75 255   0   0   0   0   4   4   4   4   4   4 255   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   8  15   0   2   0 170   0
   1   0   0   0   1  12 184 247   0   0   0  16  19  13  22  16  10  25
  19  13   7   7   5   5   3   3   3   1   1   1   1   0   0   0   0   0
   0   0   0   0   0 136 216 132  38  88   0  

## **Models**

### **FeatureEncoderNet**

In [128]:
class FeatureEncoderNet(torch.nn.Module):
    """ Network for feature encoding

        In: [s_t]
            Current state (i.e. pixels) -> 1 channel image is needed

        Out: phi(s_t)
            Current state transformed into feature space

    """
    def __init__(self, env):
        super(FeatureEncoderNet, self).__init__()

        self.env = env
        # Set the environment seed
        self.env.seed = torch.seed

        # constants
        self.numStateSpace = self.env.observation_space.shape[0]
        self.numHidden1 = NUM_FEAT_SPACE

        # layers
        self.lstm = torch.nn.LSTMCell(input_size=self.numStateSpace, hidden_size=self.numHidden1)

          
    def reset_lstm(self, x):
        self.h_t1 = self.c_t1 = torch.zeros(x, self.numHidden1).cuda() if torch.cuda.is_available() else torch.zeros(x, self.numHidden1)

    def forward(self, x):
        self.h_t1, self.c_t1 = self.lstm(x, (self.h_t1, self.c_t1)) # h_t1 is the output
        return self.h_t1


## **Inverse Net**

In [129]:
class InverseNet(torch.nn.Module):
    """ Network for the inverse dynamics

        In: torch.cat((phi(s_t), phi(s_{t+1}), 1)
            Current and next states transformed into the feature space, 
            denoted by phi().

        Out: \hat{a}_t
            Predicted action

    """
    def __init__(self, env):
        super(InverseNet, self).__init__()

        self.env = env
        # Set the environment seed
        self.env.seed = torch.seed

        # constants
        self.numHidden1 = 256
        self.numFeatSpace = NUM_FEAT_SPACE
        self.numActionSpace = self.env.action_space.n

        # layers
        #self.conv = ConvBlock()
        self.layer1 = torch.nn.Linear(self.numFeatSpace * 2, self.numHidden1)
        self.layer2 = torch.nn.Linear(self.numHidden1, self.numActionSpace)

    def forward(self, x):
        return self.layer2(self.layer1(x))

## **Forward Net**

In [130]:
class ForwardNet(torch.nn.Module):
    """ Network for the forward dynamics

    In: torch.cat((phi(s_t), a_t), 1)
        Current state transformed into the feature space, 
        denoted by phi() and current action

    Out: \hat{phi(s_{t+1})}
        Predicted next state (in feature space)

    """
    def __init__(self, env):
        super(ForwardNet, self).__init__()

        self.env = env
        # Set the environment seed
        self.env.seed = torch.seed

        # constants
        self.numHidden = 256
        self.numFeatSpace = NUM_FEAT_SPACE 
        self.numActionSpace = self.env.action_space.n

        # layers
        #self.conv = ConvBlock()
        self.layer1 = torch.nn.Linear(self.numFeatSpace + self.numActionSpace, self.numHidden)
        self.layer2 = torch.nn.Linear(self.numHidden, self.numFeatSpace)

    def forward(self, x):
        return self.layer2(self.layer1(x))

## **ICM Net**

In [131]:
class ICMNet(torch.nn.Module):
    def __init__(self, env):
        super(ICMNet, self).__init__()

        self.env = env
        # Set the environment seed
        self.env.seed = torch.seed

        # networks
        self.featureEncoderNet = FeatureEncoderNet(self.env)
        self.forwardNet = ForwardNet(self.env)
        self.inverseNet = InverseNet(self.env)

    def forward(self, s_t, s_t1, a_t):
        """
            s_t : current state
            s_t1: next state

            phi_t: current encoded state
            phi_t1: next encoded state

            a_t: current action
        """

        phi_t = self.featureEncoderNet(s_t)
        phi_t1 = self.featureEncoderNet(s_t1)

        # forward dynamics
        # predict next encoded state
        forward = torch.cat((phi_t, a_t), 1) # concatenate next to each other
        phi_t1_pred =  self.forwardNet(forward)

        # inverse dynamics
        # predict the action between s_t and s_t1
        inverse = torch.cat((phi_t, phi_t1), 1)
        a_t_pred = self.inverseNet(inverse)

        return phi_t1, phi_t1_pred, a_t_pred

## **A2C Net**

In [132]:
class A2CNet(torch.nn.Module):
    def __init__(self, env):
        super(A2CNet, self).__init__()
        # Setting up the environment
        self.env = env
        # Set the environment seed
        self.env.seed = torch.seed
        # Store the size of the action and observation space
        self.numFeatSpace = NUM_FEAT_SPACE
        self.numActionSpace = self.env.action_space.n

        # Model for the feature encoder
        self.featureEncoderNet = FeatureEncoderNet(env)

        # Create a model for the actor - policy
        self.actorNet = torch.nn.Sequential(
            torch.nn.Linear(self.numFeatSpace, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, self.numActionSpace),
            torch.nn.Softmax(dim=-1)
        )

        # Create a model for the critic - value
        self.criticNet = torch.nn.Sequential(
            torch.nn.Linear(self.numFeatSpace, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 1),
        )

    def forward(self, s_t):
        """
            s_t : current state
           
            phi_t: current encoded state
        """
        phi_t = self.featureEncoderNet(s_t)

        policy = self.actorNet(phi_t)
        value = self.criticNet(phi_t)

        return policy, value

## **Final Agent**

In [133]:
class Agent(torch.nn.Module):
    def __init__(self, env, num_epoch, num_steps=MAX_MOVES, batch_size=BATCH_SIZE):
        super().__init__()

        # constants
        self.is_cuda = torch.cuda.is_available()
        self.env = env
        self.env.seed = torch.seed
        self.numActionSpace = self.env.action_space.n
        self.num_epoch = num_epoch
        self.maxMoves = num_steps
        self.batchSize = batch_size
        
        self.clear_log_lists()

        # networks
        self.icmNet = ICMNet(self.env)
        self.a2cNet = A2CNet(self.env)

        if self.is_cuda:
            self.icmNet.cuda()
            self.a2cNet.cuda()

        # loss
        self.criticLoss = torch.nn.MSELoss()

        # optimizer
        self.optimizer = torch.optim.Adam(list(self.icmNet.parameters()) + list(self.a2cNet.parameters()), lr=LR)
        
    def clear_log_lists(self):
        self.rewards = []
        self.states = []
        self.actions = []
        self.policies = []
        self.values = []  
            

      
    def play(self):
        """
            s_t : current state
            s_t1: next state

            phi_t: current encoded state
            phi_t1: next encoded state

            a_t: current action
        """
                
        # reset all logger lists
        self.clear_log_lists()
        
        self.a2cNet.featureEncoderNet.reset_lstm(1)

        # play one game
        state, _  = self.env.reset()
        self.states.append(torch.from_numpy(state))

        done = False

        self.score = 0

        for step in range(self.maxMoves):
            policy, value = self.a2cNet(torch.from_numpy(
                        state).float().unsqueeze(0)) # select action from the policy
            
            action = np.random.choice(self.numActionSpace, p=policy[0].detach().numpy())
            
            # interact with the environment
            nextState, reward, done, truncated,info = self.env.step(action)

            self.score += reward

            _, nextValue = self.a2cNet(torch.from_numpy(nextState).float().unsqueeze(0))

            self.actions.append(action)
            self.policies.append(policy)
            self.values.append(value) # nextValue - value ?
            self.states.append(torch.from_numpy(nextState))
            self.rewards.append(reward)
            self.replay.append((state, reward, nextValue.item()))
            
            if done:
              break

            state = nextState



    def normalize(self, data):
      return (data - data.mean()) / (data.std() + 10e-9)
    
      
      
    def a2c_loss(self):      
      # Policy Loss
      # Calculate the advantage
      advantages = torch.Tensor(list(self.rewards)).float() + torch.pow(GAMMA, torch.arange(
          len(self.values)).float()) * torch.Tensor(list(self.values)).float()

      # Store the state info as a batch of states
      stateBatch = torch.stack([s.float() for s in self.states])[0:-1]
      
      # Store the action info as a batch of actions    
      actionBatch = torch.Tensor(list(self.actions))
      # Feed the state batch to the actor model to calculate the probs of actions for each state in the batch
      self.a2cNet.featureEncoderNet.reset_lstm(stateBatch.shape[0])
      policy, _ = self.a2cNet(stateBatch)
      # Gets the probs of actions actually performed for each state
      probs = policy.gather(
          dim=1, index=actionBatch.long().unsqueeze(dim=1)).squeeze()

      # Policy Loss
      actorLoss = - 1 * torch.log(probs) * advantages

      if (len(self.replay) > self.batchSize):
        # Select a set of random indices to be chosen from the replay buffer 
        indices = np.random.choice(len(self.replay), size=self.batchSize)
        # Extract the experiences from the replay buffer
        replay_ = np.asarray(self.replay, dtype=object)[indices, :]  
        # Create a state batch with the excted experiences
        stateBatch = torch.stack([torch.from_numpy(s).float()
                                  for s in replay_[:, 0]])
        # Calculate the value for the extracted states      
        self.a2cNet.featureEncoderNet.reset_lstm(stateBatch.shape[0])                                  
        _, value = self.a2cNet(stateBatch) 
        # Critic Loss
        criticLoss = torch.nn.MSELoss()(value.squeeze(), torch.Tensor(list(replay_[:, 1] + GAMMA * replay_[:, 2])).float())
        
      else:
        criticLoss = torch.tensor(0)
      # return the a2c loss
      # which is the sum of the actor (policy) and critic (advantage) losses
      loss = actorLoss.sum() + criticLoss.sum()
      
      return loss
        

        
    def train(self):
        """
            s_t : current state
            s_t1: next state

            phi_t: current encoded state
            phi_t1: next encoded state

            a_t: current action
        """
        pass
    
        # Stores the info from the training process
        info = {
            'epochs': [],
            'scores': [],  
            'losses': [],
        }
        # Stores the total reward or scores per epoch
        scores = []

        # The replay buffer for Q-learning with experience replay
        self.replay = deque(maxlen=218)

        for epoch in range(self.num_epoch):
          
          # play an episode
          self.play()
          
          statesT = torch.stack([s.float() for s in self.states[0:-1]]) # last value is not needed here
          statesT1 = torch.stack([s.float() for s in self.states[1:]])
          
          actionT = torch.FloatTensor(self.actions)
          # convert the action tensor into one-hot
          actionT1_hot = torch.zeros(actionT.shape[0],self.numActionSpace).scatter_(1, actionT.long().view(-1,1),1)
          
          if self.is_cuda:
            statesT = statesT.cuda()
            statesT1 = statesT1.cuda()
            actionT = actionT.cuda()
            actionT1_hot = actionT1_hot.cuda()
          
          # reset LSTM hidden states
          #self.a2cNet.feat_enc_net.reset_lstm(s_t.shape[0])  # specify size 

          # call the ICM model         
          self.icmNet.featureEncoderNet.reset_lstm(statesT.shape[0])
          phi_t1, phi_t1_pred, a_t_pred = self.icmNet(statesT, statesT1, actionT1_hot)


          # calculate losses
          self.optimizer.zero_grad()
          
          # forward loss
          # discrepancy between the predicted and actual next states
          lossForward = torch.nn.functional.mse_loss(phi_t1_pred, phi_t1)
          
          # inverse loss
          # cross entropy between the predicted and actual actions
          lossInverse = torch.nn.functional.cross_entropy(a_t_pred, actionT.long().view(-1))
          
          # a2c loss
          # loss of the policy (how good can we choose the proper action)
          # and the advantage function (how good is the estimate of the value 
          # of the current state)
          lossA2C = self.a2c_loss()

          
          # compose losses
          loss = BETA * lossForward + (1-BETA) * lossInverse + LAMBDA * lossA2C

          loss.backward()
          torch.nn.utils.clip_grad_norm_(self.a2cNet.parameters(), GRADIENT_CLIPPING_A2C)
          torch.nn.utils.clip_grad_norm_(self.icmNet.parameters(), GRADIENT_CLIPPING_ICM)
          self.optimizer.step()
          
          # Store the total score for the episode
          scores.append(self.score)

          if epoch % np.round(self.num_epoch/10) == 0:
                print('episode: {:d}, loss: {}, score: {:.2f}'.format(epoch, loss, scores[epoch]))
          info["epochs"].append(epoch)
          info["scores"].append(scores[epoch])
          info["losses"].append(loss.item())
        return info

      

    def test(self, render=True):
          np.random.seed(42)

          self.a2cNet.featureEncoderNet.reset_lstm(1)
          
          # Reset the environment for testing
          state, _ = self.env.reset()
          # Flag to determine if an episode has ended
          done = False
          # Maximum number of moves allowed while testing
          maxMoves = self.maxMoves
          # Stores the total score received during an episode
          score = 0
          # Continues an episode until it ends or the maximim number of allowed moves has expired
          while not done and maxMoves > 0:            
              # Decrement the maximim number of allowed moves per play in an episode
              maxMoves -= 1
              # If render is true, renders the game to screen
              if render:
                  self.env.render()
              # Calculates the probs. for the actions given a state
              policy, value = self.a2cNet(torch.from_numpy(
                        state).float().unsqueeze(0)) # select action from the policy
            
              action = np.random.choice(self.numActionSpace, p=policy[0].detach().numpy())
            
              # Executes an action in the environment
              state, reward, done, truncated, _ = self.env.step(action)
              # Stores the reward
              score += reward
          # Print the rewards received         
          print('reward: {}'.format(score))



    def plot(self, info_):
        # Sort the scores by epoch
        #info_.sort(axis=1)
        # Extract the epochs and respective scores
        epochs, scores, loss = info_["epochs"], info_["scores"], info_["losses"]
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

        ax1.plot(epochs, scores)
        ax1.set_title('Scores')
        ax1.set_xlabel('episode')
        ax1.set_ylabel('score')

        ax2.plot(epochs, loss)
        ax2.set_title('Losses')
        ax2.set_xlabel('episode')
        ax2.set_ylabel('loss')
        plt.show()

## **Training the agent**

In [None]:
agent = Agent(env=env, num_epoch=500)

info = agent.train()

# Plot the scores
agent.plot(info)

agent.test(render=False)

episode: 0, loss: 3.4890923500061035, score: 0.00
episode: 50, loss: 79.44393920898438, score: 3.00


KeyboardInterrupt: 

: 