In [210]:
import gym
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as f
import torch.optim as optim


env = gym.make('MountainCar-v0')

In [138]:
action_space = env.action_space.n
obs_space = env.observation_space.shape[0]
weights = np.random.randn(obs_space, action_space)

In [305]:
ALPHA = 0.1
GAMMA = 0.9
EPSILON = 0.5


def get_polynomial_encoding(state):
    s1, s2 = state
    return np.array([1, s1, s2, s1*s2])


class LinearFunctionApproximator(nn.Module):
    def __init__(self, state_size, action_size):
        super(LinearFunctionApproximator, self).__init__()
        init_tensor = nn.init.normal_(torch.empty(state_size, action_size))
        self.params = nn.parameter.Parameter(init_tensor, requires_grad=True)
        self.loss_func = nn.MSELoss()
        self.optimizer = optim.SGD(self.parameters(), lr=ALPHA)
        
    def forward(self, state):
        return torch.matmul(self.params.T, state)
    
    def compute_loss(self, inp, target):
        self.optimizer.zero_grad()
        loss = self.loss_func(inp, target)
        loss.backward()
        self.optimizer.step()
        return loss.detach()


class SARSA:
    def __init__(self, env, epsilon, approximator, test=False):
        self.env = env
        self.episodes = []
        self.rewards = []
        self.test = test
        self.epsilon = epsilon
        self.current_eps = self.epsilon
        self.approximator = approximator
        self.action_space = []
        
    def get_greedy_action(self, state):
        with torch.no_grad():
            return self.approximator(state).argmax().item()
        
    def get_epsilon_greedy_action(self, state):
        sample = np.random.uniform()
        action = None
        if sample < self.current_eps:
            action = self.get_greedy_action(state)
        else:
            action = env.action_space.sample()
        self.action_space.append(action)
        return action
        
    def decay_epsilon(self, ep, n_ep):
        self.current_eps = self.epsilon * 1/(1+np.exp(16*ep/n_ep - 9))
        
    def test_agent(self, n_episodes):
        for ep in n_episodes:
            done = False
            state = self.env.reset()
            while not done:
                action = self.get_greedy_action(state)
                state, reward, done, _ = self.env.step(action)
            env.close()
    
    def optimize_agent(self, n_episodes):
        for ep in range(n_episodes):
            state = torch.tensor(env.reset(), dtype=torch.float)
            done = False
            ep_reward = 0
            while not done:
                action = self.get_epsilon_greedy_action(state)
                state_action = self.approximator(state)
                
                next_state, reward, done, _ = env.step(action)
                next_state = torch.tensor(next_state, dtype=torch.float)
                
                next_action = self.get_epsilon_greedy_action(next_state)
                with torch.no_grad():
                    next_state_action = self.approximator(next_state)
                    
                inp = state_action[action]
                target = reward + next_state_action[next_action] * GAMMA
                self.approximator.compute_loss(inp, target)
                state, action = next_state, next_action
                
                ep_reward += reward
            self.episodes.append(ep)
            self.rewards.append(ep_reward)
            self.decay_epsilon(ep, n_episodes)
                
    def play_agent(self):
        done = False
        state = torch.tensor(self.env.reset(), dtype=torch.float)
        while not done:
            action = self.get_greedy_action(state)
            state, reward, done, _ = self.env.step(action)
            state = torch.tensor(state, dtype=torch.float)
            self.env.render()
        self.env.close()

In [306]:
lin = LinearFunctionApproximator(2, 3)
sarsa = SARSA(env, EPSILON, lin)

In [307]:
sarsa.optimize_agent(1000)

In [301]:
from collections import Counter

In [310]:
Counter(sarsa.action_space)

Counter({2: 135024, 0: 108634, 1: 156342})

In [308]:
sarsa.play_agent()

In [309]:
sarsa.approximator.params

Parameter containing:
tensor([[18.6097, 18.6131, 18.5350],
        [ 9.1566, 11.1306, 10.4940]], requires_grad=True)