In [1]:
import retro
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import random
from collections import deque, namedtuple

BUFFER_SIZE = 100_000  # refers to the number of states, actions, rewards, and resultant state stored each
BATCH_SIZE = 64        # refers to the sample size randomly picked out from the memory buffer 
GAMMA = 0.99           # discount rate
TAU = 1e-3             # soft update rate 
LR = 5e-4               
UPDATE_EVERY = 4        

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
env = retro.make('ContraForce-Nes', use_restricted_actions=retro.Actions.DISCRETE)

In [3]:
class SubModule1(nn.Module):
    def __init__(self):
        super(SubModule1, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 4, 2, 0)
        self.conv2 = nn.Conv2d(32, 64, 4, 2, 1)
        self.conv3 = nn.Conv2d(64, 128, 4, 2, 1)
        self.conv4 = nn.Conv2d(128, 32, 4, 2, 1)
        self.conv5 = nn.Conv2d(32, 1, 4, 2, 1)
    
    def forward(self, state):
        x = F.relu(self.conv1(state))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))

        return x.view(-1, 42)

class SubModule2(nn.Module):
    def __init__(self):
        super(SubModule2, self).__init__()
        self.lin1_1 = nn.Linear(42, 64)
        self.lin1_2 = nn.Linear(64, 64)

        self.lin2_1 = nn.Linear(42, 64)
        self.lin2_2 = nn.Linear(64, 64)

        self.lin3 = nn.Linear(128, 64)
        self.lin4 = nn.Linear(64, 32)
        self.lin5 = nn.Linear(32, 36)

    def forward(self, s1, s2):
        s1 = F.relu(self.lin1_1(s1))
        s1 = F.relu(self.lin1_2(s1))

        s2 = F.relu(self.lin2_1(s2))
        s2 = F.relu(self.lin2_2(s2))

        fv = torch.concat([s1, s2], axis=1)

        fv = F.relu(self.lin3(fv))
        fv = F.relu(self.lin4(fv))
        return F.softmax(self.lin5(fv), dim=1)
    
class ForwardDynamics(nn.Module):
    def __init__(self):
        super(ForwardDynamics, self).__init__()
        self.lin1_1 = nn.Linear(42, 64)

        self.lin1_2 = nn.Linear(36, 64)

        self.lin3 = nn.Linear(128, 256)
        self.lin4 = nn.Linear(256, 512)
        self.lin5 = nn.Linear(512, 42)
    
    def forward(self, feature_vector, action):
        fv = F.relu(self.lin1_1(feature_vector))
        a = F.relu(self.lin1_2(action))

        x = torch.concat([fv, a], axis=1)

        x = F.relu(self.lin3(x))
        x = F.relu(self.lin4(x))
        return self.lin5(x)

class ICM(nn.Module):
    def __init__(self):
        super(ICM, self).__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.m1 = SubModule1().to(self.device)
        self.m2 = SubModule2().to(self.device)
        self.fd = ForwardDynamics().to(self.device)

        self.m1_opt = optim.Adam(self.m1.parameters(), lr=1e-3)
        self.m2_opt = optim.Adam(self.m2.parameters(), lr=1e-3)
        self.fd_opt = optim.Adam(self.fd.parameters(), lr=1e-3)

    def feature_learn(self, state, action, state_next):
        self.m1_opt.zero_grad()        
        self.m2_opt.zero_grad()        

        feature_vector1 = self.m1(state)
        feature_vector2 = self.m1(state_next)

        pred_action = self.m2(feature_vector1, feature_vector2)

        loss = F.mse_loss(pred_action, action)
        loss.backward()

        self.m1_opt.step()
        self.m2_opt.step()
    
    def forward_learn(self, state, action, state_next, eta):
        self.fd_opt.zero_grad()

        with torch.no_grad():
            fv = self.m1(state)
            fv2 = self.m1(state_next)

        pred_fv = self.fd(fv, action)
        forward_loss = F.mse_loss(pred_fv, fv2) / 2 * eta
        forward_loss.backward()

        self.fd_opt.step()

        return forward_loss.detach()

    def state_to_feature(self, state):
        self.m1.eval()
        with torch.no_grad():
            fv = self.m1(state)
        self.m1.train()

        return fv
    

In [4]:
class QNetwork(nn.Module):
    def __init__(self, seed):
        super(QNetwork, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(42, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 256)
        self.fc4 = nn.Linear(256, 36)
        
    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)
        

class Agent():
    def __init__(self, state_size, action_size, seed):
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)

        self.icm = ICM().to(device)
        self.eta = 3e3

        self.qnetwork_local = QNetwork(seed).to(device)
        self.qnetwork_target = QNetwork(seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        self.memory = MemBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        self.t_step = 0
    
    def step(self, state, action, reward, next_state, done):
        self.memory.add(state, action, reward, next_state, done)
        
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0:
            if len(self.memory) > BATCH_SIZE:
                experiences = self.memory.sample()
                states, actions, rewards, next_states, dones = experiences
                
                self.learn(states, actions, rewards, next_states, dones, GAMMA)

                action = F.one_hot(torch.tensor([action]), num_classes=36).to(device, dtype=torch.float)
                state = torch.from_numpy(state).float().unsqueeze(0).to(device)
                next_state = torch.from_numpy(next_state).float().unsqueeze(0).to(device)
                self.icm.feature_learn(state, action, next_state)
                self.icm.forward_learn(state, action, next_state, self.eta)

    def act(self, state, eps=0.5):
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            state = self.icm.state_to_feature(state)
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy()), action_values.detach()
        else:
            random_int = random.choice(np.arange(36))
            random_array = F.one_hot(torch.tensor(random_int), num_classes=36).to(device, dtype=torch.float)
            return random_int, random_array
    
    def external_reward(self, state, action, next_state):
        action = action.view(-1, 36)
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        next_state = torch.from_numpy(next_state).float().unsqueeze(0).to(device)
        self.icm.feature_learn(state, action, next_state)
        ext_reward = self.icm.forward_learn(state, action, next_state, self.eta)
        
        return ext_reward.detach()

    def learn(self, states, actions, rewards, next_states, dones, gamma):
        next_states = self.icm.state_to_feature(next_states)
        states = self.icm.state_to_feature(states)
        q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
        q_targets = rewards + gamma * q_targets_next * (1 - dones)
        q_expected = self.qnetwork_local(states).gather(1, actions)
        
        loss = F.mse_loss(q_expected, q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)                     

    def soft_update(self, local_model, target_model, tau):
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)



class MemBuffer():
    def __init__(self, action_size, buffer_size, batch_size, seed):
        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)  
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        self.seed = random.seed(seed)
    
    def add(self, state, action, reward, next_state, done):
        state = np.expand_dims(state, axis=0)
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)
    
    def sample(self):
        experiences = random.sample(self.memory, k=self.batch_size)

        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
        next_states = torch.from_numpy(np.array([e.next_state for e in experiences if e is not None])).float().to(device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)

        print(next_states.shape)
        print(states.shape)
        print(actions.shape)
  
        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        return len(self.memory)

In [5]:
def dqn(n_episodes=2000, max_t=1000, eps_start=1.0, eps_end=0.05, eps_decay=0.7):
    scores = []                      
    rolling_mean = []
    scores_window = deque(maxlen=100) 
    eps = eps_start                    
    for i_episode in range(1, n_episodes+1):
        state = env.reset()
        state = np.moveaxis(state, -1, 0)
        score = 0
        for t in range(max_t):
            action, action_values = agent.act(state, eps)
            next_state, reward, done, _ = env.step(action)
            env.render(mode="human")
            next_state = np.moveaxis(next_state, -1, 0)
            ext_reward = agent.external_reward(state, action_values, next_state)
            reward += ext_reward.cpu()
            agent.step(state, action, reward, next_state, done)
            state = next_state
            score += reward
            eps = max(eps_end, eps_decay*eps) 
            if done:
                break 
        scores_window.append(score)      
        scores.append(score)             
        rolling_mean.append(np.mean(scores_window))

        # torch.save(agent.qnetwork_local.state_dict(), f'./saves/DQN_ICM_2/episode{i_episode}.pt')

        plt.figure(figsize=(15,8))
        plt.plot(scores)
        plt.plot(rolling_mean)
        plt.title('scores')
        plt.grid()
        plt.savefig('test.jpg')
        plt.close()

        print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)), end="")
        if i_episode % 100 == 0:
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)))
        # if np.mean(scores_window)>=200.0:
        #     print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(i_episode-100, np.mean(scores_window)))
        #     torch.save(agent.qnetwork_local.state_dict(), 'model.pth')
        #     break
    return scores

agent = Agent(state_size=8, action_size=4, seed=0)
scores = dqn()

torch.Size([64, 3, 224, 240])
torch.Size([64, 3, 224, 240])
torch.Size([64, 1])
torch.Size([64, 3, 224, 240])
torch.Size([64, 3, 224, 240])
torch.Size([64, 1])
torch.Size([64, 3, 224, 240])
torch.Size([64, 3, 224, 240])
torch.Size([64, 1])
torch.Size([64, 3, 224, 240])
torch.Size([64, 3, 224, 240])
torch.Size([64, 1])
torch.Size([64, 3, 224, 240])
torch.Size([64, 3, 224, 240])
torch.Size([64, 1])
torch.Size([64, 3, 224, 240])
torch.Size([64, 3, 224, 240])
torch.Size([64, 1])
torch.Size([64, 3, 224, 240])
torch.Size([64, 3, 224, 240])
torch.Size([64, 1])
torch.Size([64, 3, 224, 240])
torch.Size([64, 3, 224, 240])
torch.Size([64, 1])
torch.Size([64, 3, 224, 240])
torch.Size([64, 3, 224, 240])
torch.Size([64, 1])
torch.Size([64, 3, 224, 240])
torch.Size([64, 3, 224, 240])
torch.Size([64, 1])
torch.Size([64, 3, 224, 240])
torch.Size([64, 3, 224, 240])
torch.Size([64, 1])
torch.Size([64, 3, 224, 240])
torch.Size([64, 3, 224, 240])
torch.Size([64, 1])
torch.Size([64, 3, 224, 240])
torch.Size

In [None]:
dqn = QNetwork(22).to(device)
m1 = SubModule1().to(device)



obs = env.reset()
for _ in range(10):
    with torch.no_grad():
        state = torch.from_numpy(obs).float().unsqueeze(0).to(device)
        state = state.view(-1, 3, 224, 240)

        state = m1(state)
        action = torch.argmax(dqn(state))
    obs, rew, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()
env.close()