In [None]:
'''
Helpful docs: 
    + Main article: https://arxiv.org/pdf/2010.13110.pdf
    + Xavier initialization: https://365datascience.com/tutorials/machine-learning-tutorials/what-is-xavier-initialization/
    + Pytorch 2.0 documentation
'''

In [2]:
import torch 
import numpy as np 
import torch.nn as nn 
import torch.nn.functional as F
from torch.nn import Parameter as P 
from torch.autograd import Variable
import math 
print(torch.cuda.is_available())

True


In [97]:
''' Model '''
# initializer 
def xavier_init(layer: nn.Linear): 
    nn.init.xavier_uniform_(layer.weight)
    nn.init.constant_(layer.bias, 0) 
    return layer
class AttentionLayer: 
    def __init__(self, feature_dim, weight_dim, device): 
        super(AttentionLayer, self).__init__() 
        self.device = device 

        self.weight_q = xavier_init(nn.Linear(feature_dim, weight_dim)).to('cuda:0')
        self.weight_k = xavier_init(nn.Linear(feature_dim, weight_dim)).to('cuda:0')
        self.weight_v = xavier_init(nn.Linear(feature_dim, weight_dim)).to('cuda:0')

    def forward(self, x): 
        # suppose x here is already flatten
        '''
        inference: 
        :param x: [num_agent, num_target, feature_dim]
        :return z: [num_agent, num_target, weight_dim]
        '''

        q = torch.tanh(self.weight_q(x))
        k = torch.tanh(self.weight_k(x))
        v = torch.tanh(self.weight_v(x))
        # here we have to permute the size of k in order to perform batch matrix multiplication
        # F.softmax(..., dim=2): After performing the batch matrix multiplication, the resulting tensor is passed to F.softmax, 
        # which applies the softmax function along dimension 2 of the tensor. 
        # Softmax ensures that the values along dimension 2 sum up to 1, effectively transforming them into a probability distribution.

        print(q.dim(), k.dim(), v.dim())
        z = torch.bmm(F.log_softmax(torch.bmm(q, k.permute(0, 2, 1)), dim = 2), v) 
        # z = torch.bmm(F.log_softmax(torch.bmm(q, k), dim = 2), v) 

        global_feature = torch.sum(z, dim = 1) 
        return z, global_feature

class NoisyLinear(nn.Linear): 
    def __init__(self, in_features, out_features, sigma_init = 0.017, bias = True): 
        super(NoisyLinear, self).__init__(in_features, out_features, bias = True) 
        self.sigma_init = sigma_init
        self.sigma_weight = P(torch.cuda.FloatTensor(out_features, in_features))
        self.sigma_bias = P(torch.cuda.FloatTensor(out_features))
        self.register_buffer('epsilon_weight', torch.zeros(out_features, in_features))
        self.register_buffer('epsilon_bias', torch.zeros(out_features))
        self.reset_parameters()
    
    def reset_parameters(self) -> None:
        if hasattr(self, 'sigma_weight'):
            nn.init.uniform(self.weight, -math.sqrt(3 / self.in_features), math.sqrt(3 / self.in_features))
            nn.init.uniform(self.bias, -math.sqrt(3 / self.in_features), math.sqrt(3 / self.in_features))
            nn.init.constant(self.sigma_weight, self.sigma_init)
            nn.init.constant(self.sigma_bias, self.sigma_init)

    def forward(self, x): 
        return F.linear(x, self.weight + self.sigma_weight * Variable(self.epsilon_weight), self.bias * Variable(self.epsilon_bias))
    
    def sample_noise(self): 
        self.epsilon_weight = torch.randn(self.out_features, self.in_features) 
        self.epsilon_bias = torch.randn(self.out_features)

    def remove_noise(self): 
        self.epsilon_weight = torch.zeros(self.out_features, self.in_features) 
        self.epsilon_bias = torch.zeros(self.out_features) 

def sample_action(mu_multi, sigma_multi, device, test = False): 
    logit = mu_multi
    prob = F.softmax(logit, dim = 1)
    log_prob = F.log_softmax(logit, dim = 1) 
    entropy = -(log_prob * prob).sum(-1, keepdim=True) 
    if test: 
        action = prob[1][-1].data
        action_env = action.cpu().numpy()
    
    else: 
        action = prob.multinomial(1).data 
        log_prob = log_prob.gather(1, Variable(action))
        action_env = action.squeeze(0) 
    
    return action_env, entropy, log_prob
class ValueNet(nn.Module): 
    def __init__(self, input_dim, head_name, num = 1): 
        # super(ValueNet, self).__init__(input_dim, num)
        super(ValueNet, self).__init__()
        if 'ns' in head_name: 
            self.noise = True 
            self.critic_linear = NoisyLinear(input_dim, num)
        else: 
            self.noise = False
            # init layers
            self.critic_linear = nn.Linear(input_dim, num, device=torch.device('cuda'))
            x = self.critic_linear.weight.data.size()
            x = torch.randn(x).to('cuda:0')
            x *= 0.1/torch.sqrt((x**2).sum(1, keepdim = True))
            self.critic_linear.weight.data = x
            self.critic_linear.bias.data.fill_(0)

    def forward(self, x): 
        return self.critic_linear(x)

    def sample_noise(self): 
        if self.noise: 
            self.critic_linear.sample_noise()
    
    # idk what does this function do
    def remove_noise(self): 
        pass

class PolicyNet(nn.Module): 
    def __init__(self, input_dim, action_space, head_name , device): 
        super(PolicyNet, self).__init__()
        self.head_name = head_name
        self.device = device 
        num_ouputs = action_space.n

        if 'ns' in head_name: 
            self.noise = True 
            self.action_linear = NoisyLinear(input_dim, num_ouputs, sigma_init=0.017) 
        
        else: 
            self.noise = False 
            self.action_linear = nn.Linear(input_dim, num_ouputs, device = torch.device('cuda')) 
            x = self.action_linear.weight.data.size()
            x = torch.randn(x).to('cuda:0')
            x *= 0.1/torch.sqrt((x**2).sum(1, keepdim = True))
            self.action_linear.weight.data = x
            self.action_linear.bias.data.fill_(0) 

    def forward(self, x, test=False): 
        x = x.to(self.device)  # Move x to the same device as action_linear
        x = F.relu(self.action_linear(x))
        sigma = torch.ones_like(x)
        action, entropy, log_prob = sample_action(x, sigma, self.device, test)
        return action, entropy, log_prob
    
    def sample_noise(self): 
        if self.noise: 
            self.action_linear.sample_noise()
            # self.action_linear2.sample_noise()
    
    def remove_noise(self): 
        pass
class A3C_Single(nn.Module): # single vision tracking 
    def __init__(self, obs_space, action_space, args, device = torch.device('cuda')): 
        super(A3C_Single, self).__init__() 
        self.n = len(obs_space) 
        # obs_dim = obs_space[0].shape[1]

        lstm_out = args.lstm_out 
        head_name = args.model 
        self.head_name = head_name

        self.encoder = AttentionLayer(self.n, lstm_out, device)
        self.critic = ValueNet(lstm_out, head_name)
        # print("Head Name:", head_name)
        self.actor = PolicyNet(lstm_out, action_space, head_name, device)
        
        self.train()
        self.device = device

    def forward(self, inputs, test = False): 
        data = Variable(inputs.to(self.device)).requires_grad_(True)
        _, feature = self.encoder.forward(data)
        actions, entropies, log_probs = self.actor.forward(feature, test) 
        values = self.critic.forward(feature) 
        return values, actions, entropies, log_probs
    def sample_noise(self): 
        self.actor.sample_noise()
        self.actor.sample_noise()
class A3C_Multi(nn.Module): 
    def __init__(self): 
        pass

In [98]:
'''debugging'''
def build_model(obs_space, action_space, args, device): 
    name = args.model 

    if 'single' in name: 
        model = A3C_Single(obs_space, action_space, args, device)
    elif 'multi' in name: 
        model = A3C_Multi()
    
    model.train()
    return model 

# args = {'model': 'single', 'lstm_out': 5}
class arg:  
    model = 'single'
    lstm_out = 5
args = arg() 
obs_space = torch.randn((3, 5))
obs_space.requires_grad = True 
class actionspace: 
    n = 5
action_space = actionspace()
device = torch.device('cuda:0')
model = build_model(obs_space, action_space, args, device)
inputs = torch.tensor([[[1., 2.,3.]], [[1.,2.,3.]]]).to('cuda:0')
model.forward(inputs, test = True)

3 3 3


(tensor([[0.],
         [0.]], device='cuda:0', grad_fn=<AddmmBackward0>),
 array(0.2, dtype=float32),
 tensor([[1.6094],
         [1.6094]], device='cuda:0', grad_fn=<NegBackward0>),
 tensor([[-1.6094, -1.6094, -1.6094, -1.6094, -1.6094],
         [-1.6094, -1.6094, -1.6094, -1.6094, -1.6094]], device='cuda:0',
        grad_fn=<LogSoftmaxBackward0>))