In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import numpy as np

In [21]:
class CustomLayer(nn.Module):
    """One fully-connected layer of a distribution regression network (DRN).

    Every node carries a discretised probability distribution. The layer maps
    ``n_lower`` input distributions with ``q_lower`` bins each to ``n_upper``
    output distributions with ``q_upper`` bins each.

    Args:
        n_lower: number of nodes in the lower (input) layer.
        n_upper: number of nodes in the upper (output) layer.
        q_lower: number of bins of each lower-layer distribution.
        q_upper: number of bins of each upper-layer distribution.
        weight_init: half-width of the uniform interval used to initialise
            ``weights``, ``ba`` and ``bq``.
        weight_init_method: stored for future use; only uniform
            initialisation is implemented at the moment.
    """

    def __init__(self,
        n_lower : int, n_upper : int,
        q_lower : int, q_upper : int,
        weight_init=0.1,
        weight_init_method='uniform'
    ):
        super().__init__()
        self.n_lower = n_lower
        self.n_upper = n_upper
        self.q_lower = q_lower
        self.q_upper = q_upper
        # Honour the caller-supplied value instead of a hard-coded 0.1.
        self.weight_init = weight_init
        self.weight_init_method = weight_init_method

        # D is a fixed (non-trainable) kernel. Registering it as a buffer makes it
        # follow .to(device) / .cuda(); persistent=False keeps it out of state_dict
        # because it is fully determined by the constructor arguments.
        self.register_buffer(
            'D', self.initD(q_lower, q_upper, n_lower, n_upper), persistent=False
        )
        self.init_weights()

    def init_weights(self):
        """Create the trainable parameters of the layer.

        ``weights`` is (n_upper, n_lower); ``ba``, ``bq``, ``lama`` and ``lamq``
        are (n_upper, 1). ``lama`` / ``lamq`` are drawn from U(0, 1), the others
        from U(-weight_init, weight_init).
        """
        # TODO: only uniform initialisation is implemented; `weight_init_method`
        # ('normal', 'glorot_normal', ...) is not consulted yet.
        w = self.weight_init
        self.weights = nn.Parameter(init.uniform_(torch.empty(self.n_upper, self.n_lower), a=-w, b=w))
        self.ba = nn.Parameter(init.uniform_(torch.empty(self.n_upper, 1), a=-w, b=w))
        self.bq = nn.Parameter(init.uniform_(torch.empty(self.n_upper, 1), a=-w, b=w))
        self.lama = nn.Parameter(init.uniform_(torch.empty(self.n_upper, 1), a=0, b=1))
        self.lamq = nn.Parameter(init.uniform_(torch.empty(self.n_upper, 1), a=0, b=1))

    def initD(self, q_lower, q_upper, n_lower, n_upper):
        """Build the fixed kernel ``D[s1, s0] = exp(-(s0/q_lower - s1/q_upper)**2)``.

        Returns:
            float32 tensor of shape (q_upper, q_lower, n_upper, n_lower): the
            (q_upper, q_lower) kernel repeated for every (upper, lower) node pair.
        """
        # Vectorised replacement for the double Python loop over bins.
        s0 = np.arange(q_lower, dtype=np.float64) / q_lower   # (q_lower,)
        s1 = np.arange(q_upper, dtype=np.float64) / q_upper   # (q_upper,)
        D_np = np.exp(-((s0[np.newaxis, :] - s1[:, np.newaxis]) ** 2))  # (q_upper, q_lower)

        D_tensor = torch.tensor(D_np.reshape((q_upper, q_lower, 1, 1)), dtype=torch.float32)
        return torch.tile(D_tensor, [1, 1, n_upper, n_lower])

    # returns log(exp(B)) which is B
    def cal_logexp_bias(self, q):
        """Return the bias exponent B for ``q`` bins, shape (n_upper, q).

        ``B = -(bq * (s/q - lamq)**2 + ba * |s/q - lama|)`` for ``s = 0..q-1``.
        """
        # Bin positions in [0, 1), created on the parameters' device/dtype: 1 x q
        s0 = torch.arange(q, device=self.bq.device, dtype=self.bq.dtype).reshape(1, q) / q

        # (1 x q) broadcast against (nu x 1) -> nu x q
        quadratic = self.bq * torch.pow(s0 - self.lamq, 2)
        absolute = self.ba * torch.abs(s0 - self.lama)
        return -(quadratic + absolute)

    def forward(self, P):
        """Propagate a batch of lower-layer distributions through the layer.

        Args:
            P: tensor that can be reshaped to (batch, n_lower, q_lower); the
               batch size is inferred, so it may vary between calls.

        Returns:
            Tensor of shape (batch, n_upper, q_upper), normalised so that the
            last axis sums to 1.
        """
        P = P.reshape(-1, self.n_lower, self.q_lower)                 # bs x nl x ql
        # D: qu x ql x nu x nl, weights: nu x nl (broadcast on trailing dims)
        T = torch.pow(self.D, self.weights).permute(2, 3, 0, 1)       # nu x nl x qu x ql
        # Contract over the lower bins; broadcasting replaces explicit tiling of P.
        Pw_unclipped = torch.einsum('jklm,ikm->ijkl', T, P)           # bs x nu x nl x qu
        # clip Pw by value to prevent zeros when weight is large
        Pw = torch.clamp(Pw_unclipped, 1e-15, 1e+15)

        # Work in log space: the product over lower nodes underflows as the
        # number of neighbours grows, so sum the logs instead.
        logsum = torch.log(Pw).sum(dim=2)                             # bs x nu x qu
        # log(exp(B)) = B, so the bias is simply added in log space.
        logsumB = logsum + self.cal_logexp_bias(self.q_upper)         # bs x nu x qu
        # softmax == exp(x - max(x)) / sum(exp(x - max(x))): the max-subtraction
        # for numerical stability and the normalisation in a single call.
        return torch.softmax(logsumB, dim=2)


In [22]:
class DRN(nn.Module):
    """Distribution regression network: a stack of ``CustomLayer`` modules.

    Args:
        in_features: number of input nodes (input distributions).
        num_layers: number of hidden layers. ``0`` connects the input directly
            to the output with a single layer.
        num_nodes: number of nodes in each hidden layer.
        out_features: number of output nodes (output distributions).
        q: number of bins of the input and output distributions.
        hidden_q: number of bins of the hidden-layer distributions.

    Raises:
        ValueError: if ``num_layers`` is negative.
    """

    def __init__(self,
        in_features: int = 1,
        num_layers: int = 1,
        num_nodes: int = 5,
        out_features: int = 1,
        q: int = 100,
        hidden_q: int = 10
    ):
        super().__init__()

        if num_layers < 0:
            raise ValueError(f"num_layers must be >= 0, got {num_layers}")

        if num_layers == 0:
            # No hidden layer: map input distributions straight to the output.
            # (Attribute name kept as `layer1` so existing checkpoints still load.)
            self.layer1 = CustomLayer(in_features, out_features, q, q)
        else:
            self.layer_1 = CustomLayer(in_features, num_nodes, q, hidden_q)
            # Hidden layers 2..num_layers inclusive. The upper bound must be
            # num_layers + 1, otherwise num_layers=2 builds the same network
            # as num_layers=1.
            for layer in range(2, num_layers + 1):
                setattr(self, f'layer_{layer}', CustomLayer(num_nodes, num_nodes, hidden_q, hidden_q))
            self.final_layer = CustomLayer(num_nodes, out_features, hidden_q, q)

    def forward(self, x):
        """Apply every layer in registration order and return the final output."""
        yout = x
        # children() yields sub-modules in the order they were assigned in
        # __init__, which is the intended forward order.
        for layer in self.children():
            yout = layer(yout)
        return yout


In [None]:
# NOTE(review): `input_data` is not defined in any cell visible here — confirm it is
# created earlier in the notebook; otherwise this cell raises NameError on a fresh
# kernel (Restart & Run All).
print(input_data)