In [0]:
import random
from collections import deque

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.optim import Adam

import torch.nn.functional as F

import copy

# Number of transitions sampled from the replay buffer per DQN update
# (training also waits until the buffer holds at least this many).
BATCH_SIZE: int = 128
# Multiplier applied to the RND prediction error (after clamping it to 1.0)
# to form the intrinsic exploration bonus.
BETA: int = 150


class RNDModel(nn.Module):
    """Random Network Distillation pair: a frozen random target and a trainable predictor.

    * ``target``: MLP ``input_size -> 512 -> 512 -> 512 -> 512`` with ReLU between
      layers; its parameters are frozen (``requires_grad=False``).
    * ``predictor``: a single ``Linear(input_size, 512)``.

    Every ``nn.Linear`` (and ``nn.Conv2d``, should one ever be added) receives an
    orthogonal weight init with gain ``sqrt(2)`` and a zero bias.

    Args:
        input_size: number of features in one observation vector.
    """

    def __init__(self, input_size):
        super().__init__()

        width = 512
        self.target = nn.Sequential(
            nn.Linear(input_size, width), nn.ReLU(),
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, width),
        )
        self.predictor = nn.Sequential(nn.Linear(input_size, width))

        # Re-initialise every weighted layer in module-registration order
        # (target layers first, then the predictor).
        gain = np.sqrt(2)
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                init.orthogonal_(layer.weight, gain)
                init.zeros_(layer.bias)

        # The target network is a fixed random function: never trained.
        self.target.requires_grad_(False)

    def forward(self, next_obs):
        """Return ``(predictor_features, target_features)`` for ``next_obs``."""
        return self.predictor(next_obs), self.target(next_obs)


class Exploration:
    """Intrinsic-reward module based on Random Network Distillation (RND).

    Holds an ``RNDModel`` and an Adam optimizer over its parameters. The
    exploration reward is the squared error between predictor and (frozen)
    target features; ``update`` trains the predictor to shrink that error on
    the states it is given.

    Args:
        state_dim: size of an observation vector (input size of the RND nets).
        action_dim: number of discrete actions (stored only; the RND nets do not use it).
    """

    def __init__(self, state_dim, action_dim):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.distillery = RNDModel(state_dim)
        self.lr = 0.0001
        self.opt = Adam(self.distillery.parameters(), lr=self.lr)

    def get_exploration_reward(self, states, actions, next_states):
        """Per-sample RND prediction error for ``next_states``.

        ``states`` and ``actions`` are accepted for interface compatibility but
        are not used.

        Returns:
            For a 2-D ``next_states`` of shape ``(batch, state_dim)``, a tensor of
            shape ``(batch, 1)`` holding the summed squared feature difference per
            row. Computed under ``torch.no_grad()``: it is only a reward signal,
            so no autograd graph is built or kept.
        """
        with torch.no_grad():
            pred, target = self.distillery(next_states)
            return ((pred - target) ** 2).sum(dim=1).view(-1, 1)

    def update(self, state):
        """Take one Adam step on the predictor using ``state``.

        Returns:
            The (pre-step) summed squared prediction error as a Python float.
        """
        pred, target = self.distillery(state)
        loss = ((pred - target) ** 2).sum()
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return loss.item()


class Agent:
    """DQN agent with an RND exploration bonus and action-sequence replay.

    As implemented below:
      * ``act`` first replays a stored action sequence (``self.path`` while
        ``self.go`` is set, otherwise ``self.best_path``) and falls back to an
        epsilon-greedy policy over ``self.model`` once that sequence is exhausted.
      * ``update`` records the episode's actions, remembers the action prefix
        that reached the state with the highest RND bonus, adds the bonus to the
        environment reward, stores the transition, and trains the Q-network from
        a replay buffer with a target network re-synced every 400 updates.

    Args:
        state_dim: dimensionality of the state space.
        action_dim: number of available discrete actions.
    """

    def __init__(self, state_dim, action_dim):
        self.state_dim = state_dim  # dimensionality of the state space
        self.action_dim = action_dim  # count of available actions
        self.exploration = Exploration(state_dim, action_dim)
        self.memory = deque(maxlen=10000)  # replay buffer of (s, a, s', r, done)
        self.path = []  # actions recorded in the current episode
        self.model = self.build_model()
        self.target_model = copy.deepcopy(self.model)
        self.lr = 0.001
        self.eps = 0.9  # epsilon-greedy rate; multiplied by 0.99 on every update()
        self.gamma = 0.9
        self.opt = Adam(self.model.parameters(), lr=self.lr, weight_decay=0.00001)
        self.go = False  # True -> act() replays self.path instead of self.best_path
        self.pos = 0  # replay cursor into self.path
        self.pos_best = 0  # replay cursor into self.best_path
        self.step = 0  # number of update() calls; drives target-network sync
        self.cur = 0  # value subtracted from the next shaped reward (see update)
        self.prev_ac = 0  # last action passed to update()
        self.best_path = []  # action prefix replayed at the start of an episode
        self.tmp_path = []  # prefix that reached the most surprising state so far
        self.best_surprise = 0  # largest bonus seen since the last reset (decays)
        # Only prefixes shorter than this are remembered as replay candidates.
        self.max_replay_len = 120
        # Probability of repeating the previous action when exploring. The
        # original code hard-wired this check as ``np.random.rand() < 0`` (i.e.
        # disabled); the random draw is still consumed so RNG streams match.
        self.sticky_prob = 0.0

    def build_model(self):
        """Q-network: MLP ``state_dim -> 64 -> 64 -> action_dim`` with ReLU."""
        model = nn.Sequential(nn.Linear(self.state_dim, 64),
                              nn.ReLU(),
                              nn.Linear(64, 64),
                              nn.ReLU(),
                              nn.Linear(64, self.action_dim))
        return model

    def act_model(self, state):
        """Epsilon-greedy action from the Q-network.

        With probability ``self.eps`` explore: repeat ``self.prev_ac`` with
        probability ``self.sticky_prob``, otherwise pick uniformly. Otherwise
        return the argmax of ``self.model``'s output for ``state`` (a numpy array).

        Returns:
            The chosen action (a Python ``int`` unless ``prev_ac`` is repeated).
        """
        if np.random.rand() < self.eps:
            if np.random.rand() < self.sticky_prob:
                return self.prev_ac
            return int(np.random.randint(0, self.action_dim))
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            q_values = self.model(torch.from_numpy(state).float())
        return int(torch.argmax(q_values).item())

    def act(self, state):
        """Next action: replay a stored sequence if one is active, else ``act_model``."""
        if self.go:
            if self.pos >= len(self.path):
                # Successful-episode replay finished: resume normal acting.
                self.go = False
                return self.act_model(state)
            action = self.path[self.pos]
            self.pos += 1
            return action
        if self.pos_best >= len(self.best_path):
            return self.act_model(state)
        action = self.best_path[self.pos_best]
        self.pos_best += 1
        return action

    def _reset_replay(self):
        """End-of-episode reset: replay the best prefix next and zero the cursors."""
        self.best_path = copy.deepcopy(self.tmp_path)
        self.pos = 0
        self.best_surprise = 0
        self.pos_best = 0

    def update(self, transition):
        """Record one transition, shape its reward, and (once warmed up) train.

        Args:
            transition: ``(state, action, next_state, reward, done)``; the states
                must be numpy arrays (they are passed through ``torch.from_numpy``).
        """
        self.step += 1
        self.eps *= 0.99
        self.best_surprise *= 0.999
        state, action, next_state, reward, done = transition
        self.prev_ac = action

        if not self.go:
            self.path.append(action)
            if done:
                self._reset_replay()
                if reward <= 0:
                    # Non-positive terminal reward: forget this episode's actions.
                    self.path.clear()
                elif reward > 0:
                    # Positive terminal reward: replay this episode's actions next.
                    self.go = True
        elif done:
            self._reset_replay()

        # Intrinsic bonus: RND error clamped to 1.0, scaled by BETA; the RND
        # predictor is then trained on the same state.
        next_obs = torch.from_numpy(next_state).float().unsqueeze(0)
        bonus = self.exploration.get_exploration_reward(state, action, next_obs)
        ext = BETA * bonus.detach().clamp(max=1.0).item()
        self.exploration.update(next_obs)

        # Remember the prefix that reached the most surprising non-terminal state.
        if self.best_surprise < ext and not done and len(self.path) < self.max_replay_len:
            self.best_surprise = ext
            self.tmp_path = copy.deepcopy(self.path)

        reward += self.gamma * ext - self.cur
        # NOTE(review): ``self.cur`` is non-zero only right after a terminal step,
        # so the subtraction above only touches the first transition of the next
        # episode. If potential-based shaping (gamma*phi(s') - phi(s)) was
        # intended this looks inverted (``(1 - int(done)) * ext``) -- confirm.
        self.cur = int(done) * ext
        self.memory.append((state, action, next_state, reward, done))
        if len(self.memory) < BATCH_SIZE:
            return
        self._train_step()

    def _train_step(self):
        """One DQN gradient step on a random minibatch; sync the target net every 400 updates."""
        states, actions, next_states, rewards, dones = zip(*random.sample(self.memory, BATCH_SIZE))
        # Stack through numpy first: building a tensor from a tuple of ndarrays is slow.
        states = torch.tensor(np.asarray(states), dtype=torch.float)
        actions = torch.tensor(np.asarray(actions))
        next_states = torch.tensor(np.asarray(next_states), dtype=torch.float)
        rewards = torch.tensor(rewards, dtype=torch.float)
        dones = torch.tensor(dones).float()

        with torch.no_grad():
            next_q = self.target_model(next_states).max(dim=1).values
        targets = rewards + self.gamma * next_q * (1 - dones)
        # squeeze(1) (not a bare squeeze) keeps a batch dimension even for batch size 1.
        q_taken = self.model(states).gather(1, actions.view(-1, 1).long()).squeeze(1)
        loss = F.smooth_l1_loss(q_taken, targets)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        if self.step % 400 == 0:
            self.target_model = copy.deepcopy(self.model)

    def reset(self):
        """Per-episode hook; a no-op (episode state is reset inside ``update`` on ``done``)."""
        pass


In [0]:
from gym import make

# Train on MountainCar and count episodes that reach the goal before the
# time limit. RNGs are seeded before the agent is built so weight init,
# epsilon-greedy draws and replay sampling are reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = make("MountainCar-v0")
algo = Agent(state_dim=2, action_dim=3)
episodes = 150
visit_count = 0
# Episode cap from the env spec; falls back to 200, the limit the original check assumed.
max_steps = getattr(env.spec, "max_episode_steps", None) or 200

for i in range(episodes):
    state = env.reset()
    steps = 0
    done = False
    info = {}
    while not done:
        action = algo.act(state)
        next_state, reward, done, info = env.step(action)
        steps += 1
        algo.update((state, action, next_state, reward, done))
        state = next_state
    print(i)
    # Goal reached iff the episode was not cut off by the TimeLimit wrapper.
    # This also counts a goal reached exactly on the last allowed step, which
    # a plain `steps < max_steps` check would miss.
    truncated = info.get("TimeLimit.truncated", steps >= max_steps)
    if not truncated:
        visit_count += 1
        print("Visited target state at episode", i)
print()
print("Total visit count:", visit_count)

In [0]:
!pip install wandb

In [0]:
# Weights & Biases login. No active wandb.init / wandb.log call appears in the
# earlier cells, so this is only needed if experiment logging is enabled.
import wandb
wandb.login()