# In[2]:
import torch
import torch.nn as nn
import random


def normc_initializer(std=1.0):
    """Build an in-place "normc" weight initializer.

    The returned callable fills its tensor argument with N(0, 1) samples and
    then rescales it so that the L2 norm taken along dim 1 equals ``std``.
    For an ``nn.Linear`` weight of shape (out, in), that is each output
    unit's incoming weight vector.

    Args:
        std: Target L2 norm along dim 1 after initialization.

    Returns:
        A callable that takes a tensor with at least 2 dims, modifies it in
        place and returns ``None``.
    """

    def _init(weight):
        values = weight.data
        values.normal_(0, 1)
        row_norms = values.pow(2).sum(1, keepdim=True).sqrt()
        values.mul_(std / row_norms)

    return _init

def process_obs(obs):
    """Keep the observation features for the two pipes that currently matter.

    ``obs`` is split into four equal consecutive chunks ``a, b, c, d`` (three
    values each for a 12-element observation). If the first pipe's right edge
    (``obs[0]`` plus the pipe width) is still right of ``bird_pos``, the
    chunks ``(a, b, d)`` are returned. Otherwise ``(b, c, d)`` is returned.

    Presumably ``a``/``b``/``c`` describe pipes 1/2/3 and ``d`` the bird;
    TODO confirm against the environment's observation spec.

    Args:
        obs: 1-D tensor, expected to have 12 elements. ``process_obs_batch``
            requires 12 columns. The length is not validated here.

    Returns:
        1-D tensor of three concatenated chunks (9 elements for a 12-element
        input), with the same dtype and device as ``obs``.
    """
    pipe_w = 1 / 6
    bird_w = 25 / 230
    # NOTE(review): presumably the bird's left edge (x position minus its
    # width) in the same normalized units as obs[0]; confirm.
    bird_pos = 0.305556 - bird_w
    first_pipe_right = obs[0] + pipe_w
    a, b, c, d = torch.chunk(obs, 4)
    if bird_pos < first_pipe_right:
        # Still heading to the first pipe: keep chunks a and b.
        return torch.cat((a, b, d), 0)
    # First pipe passed: slide the window forward to chunks b and c.
    return torch.cat((b, c, d), 0)
    
def process_obs_batch(obs_batch):
    """Apply ``process_obs`` to every row of a batch of raw observations.

    Args:
        obs_batch: 2-D tensor of shape (batch, 12).

    Returns:
        Tensor of shape (batch, 9). Rows are copied into a fresh
        ``torch.zeros`` buffer, so the result has torch's default dtype and
        device rather than those of ``obs_batch``.

    Raises:
        ValueError: if ``obs_batch`` is not 2-D with exactly 12 columns.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if obs_batch.dim() != 2 or obs_batch.shape[1] != 12:
        raise ValueError(
            f"expected obs_batch of shape (batch, 12), got {tuple(obs_batch.shape)}"
        )
    processed = torch.zeros((obs_batch.shape[0], 9))
    for row_idx, row in enumerate(obs_batch):
        processed[row_idx] = process_obs(row)
    return processed

class SlimFC(nn.Module):
    """A fully connected layer, optionally followed by a ReLU.

    The layers are held in an ``nn.Sequential`` stored as ``self._model``:
    an ``nn.Linear`` and, when ``activation_fn`` is truthy, an ``nn.ReLU``.

    Args:
        in_size: Number of input features.
        out_size: Number of output features.
        initializer: Optional callable applied in place to the linear
            layer's weight tensor (e.g. one returned by ``normc_initializer``).
        activation_fn: If truthy, append a ReLU after the linear layer. Any
            truthy value selects ReLU; the value itself is not used as an
            activation.
        use_bias: Whether the linear layer has a bias term.
        bias_init: Constant the bias is filled with when a bias exists.
    """

    def __init__(self,
                 in_size,
                 out_size,
                 initializer=None,
                 activation_fn=True,
                 use_bias=True,
                 bias_init=0.0):
        super(SlimFC, self).__init__()
        linear = nn.Linear(in_size, out_size, bias=use_bias)
        if initializer:
            initializer(linear.weight)
        # Test the created parameter instead of ``use_bias is True``: a truthy
        # non-bool (e.g. 1) still makes nn.Linear create a bias, and that bias
        # must receive ``bias_init`` too.
        if linear.bias is not None:
            nn.init.constant_(linear.bias, bias_init)
        layers = [linear]
        if activation_fn:
            layers.append(nn.ReLU())
        self._model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the linear layer (and the ReLU, if enabled) to ``x``."""
        return self._model(x)


def build_one_mlp(input_size, output_size, hidden_size=256):
    """Assemble a three-layer MLP from ``SlimFC`` blocks.

    Layout: ``input_size -> hidden_size -> hidden_size -> output_size``.
    The two hidden layers are followed by ReLU and use
    ``normc_initializer(1.0)``. The output layer is linear and uses
    ``normc_initializer(0.01)``.

    Returns:
        ``nn.Sequential`` holding the three ``SlimFC`` layers (indices 0-2).
    """
    hidden_layers = [
        SlimFC(
            in_size=fan_in,
            out_size=hidden_size,
            initializer=normc_initializer(1.0),
            activation_fn=True,
        )
        for fan_in in (input_size, hidden_size)
    ]
    # Small-norm output weights keep the initial predictions close to zero.
    head = SlimFC(
        in_size=hidden_size,
        out_size=output_size,
        initializer=normc_initializer(0.01),
        activation_fn=False,
    )
    return nn.Sequential(*hidden_layers, head)


class DQN(nn.Module):
    """Q-network scoring (observation, action) pairs for a two-action task.

    ``self.qtable`` is ``build_one_mlp(10, 1)``. Its input is the 9 features
    from ``process_obs`` followed by the action as one scalar, and it outputs
    a single Q-value. Actions are 0 and 1.
    """

    def __init__(self):
        super(DQN, self).__init__()
        self.qtable = build_one_mlp(10, 1)

    def _reference_param(self):
        """Return one ``qtable`` parameter, used to read the network's dtype and device."""
        return next(self.qtable.parameters())

    def forward(self, input_obs, input_action):
        """Return Q-values for a batch of (observation, action) pairs.

        Args:
            input_obs: raw observations of shape (batch, 12).
            input_action: ``batch`` actions; ``view(-1, 1)`` must reshape
                them to (batch, 1).

        Returns:
            Tensor of shape (batch, 1).
        """
        param = self._reference_param()
        obs = process_obs_batch(input_obs).to(dtype=param.dtype, device=param.device)
        # Integer actions were already promoted by torch.cat; the explicit cast
        # also keeps e.g. float64 actions from changing the input dtype away
        # from the parameters' dtype.
        actions = input_action.view(-1, 1).to(dtype=obs.dtype, device=obs.device)
        return self.qtable(torch.cat((obs, actions), 1))

    def _q_values(self, input_obs):
        """Return ``(Q(obs, 0), Q(obs, 1))`` for one raw observation, each of shape (1,)."""
        param = self._reference_param()
        # Match the network's dtype/device so float64 or off-device observations work.
        obs = process_obs(input_obs).to(dtype=param.dtype, device=param.device)
        q_zero = self.qtable(torch.cat((obs, obs.new_tensor([0])), 0))
        q_one = self.qtable(torch.cat((obs, obs.new_tensor([1])), 0))
        return q_zero, q_one

    def compute_action(self, input_obs):
        """Return the greedy action (0 or 1) for one raw observation; ties go to 1."""
        # Acting needs no gradients, so skip building an autograd graph.
        with torch.no_grad():
            q_zero, q_one = self._q_values(input_obs)
        return 0 if q_zero > q_one else 1

    def compute_value(self, input_obs):
        """Return ``max_a Q(obs, a)`` for one raw observation.

        The result is the larger Q-value tensor (shape (1,); ties return
        ``Q(obs, 1)``). It is not detached and still carries gradients.
        """
        # NOTE(review): if this feeds a TD target, callers presumably need to
        # detach it or call it under torch.no_grad(); confirm at call sites.
        q_zero, q_one = self._q_values(input_obs)
        return q_zero if q_zero > q_one else q_one

    def explore(self, input_obs, epsilon):
        """Epsilon-greedy action selection.

        With probability ``epsilon``, return a uniformly random action from
        {0, 1}. Otherwise return ``compute_action(input_obs)``.
        """
        # random.random() is uniform on [0, 1), so the random branch fires
        # with probability exactly epsilon: always at epsilon >= 1 and never
        # at epsilon <= 0.
        if random.random() < epsilon:
            return random.choice([0, 1])
        return self.compute_action(input_obs)
    
    
    
       