In this notebook, Pu attempts to implement the TabNet model in PyTorch, following the tutorial at https://towardsdatascience.com/implementing-tabnet-in-pytorch-fc977c383279

In [1]:
import torch
import torch.nn as nn

In [2]:
class GBN(nn.Module):
    """Ghost Batch Normalization.

    Splits the batch into "virtual" batches and normalises each one separately
    with a single shared ``BatchNorm1d``, so batch statistics are computed over
    small groups even when the real batch is large.

    Args:
        inp: number of features (channels) passed to ``BatchNorm1d``.
        vbs: virtual batch size; the batch is split into ``max(1, N // vbs)``
            balanced chunks, so every chunk has at least ``vbs`` rows (or all
            ``N`` rows when ``N < vbs``).
        momentum: momentum of the running statistics in ``BatchNorm1d``.
    """

    def __init__(self, inp, vbs=128, momentum=0.01):
        super().__init__()
        self.bn = nn.BatchNorm1d(inp, momentum=momentum)
        self.vbs = vbs

    def forward(self, x):
        # In eval mode BatchNorm normalises each row with its running statistics,
        # so splitting into chunks would not change the result.
        if not self.training:
            return self.bn(x)

        # At least one chunk: N // vbs is 0 when the batch is smaller than vbs.
        n_chunks = max(1, x.size(0) // self.vbs)
        # tensor_split balances chunk sizes (they differ by at most one row), so
        # no tiny trailing chunk is produced. BatchNorm cannot train on a
        # single-row chunk, which torch.chunk could emit.
        chunks = torch.tensor_split(x, n_chunks, dim=0)
        return torch.cat([self.bn(chunk) for chunk in chunks], dim=0)

In [12]:
class Sparsemax(nn.Module):
    """Sparsemax activation (Martins & Astudillo, 2016).

    Projects each vector along ``dim`` onto the probability simplex. Outputs
    are non-negative and sum to 1 along ``dim``, and entries below a
    data-dependent threshold are exactly zero.

    Gradients come from autograd through sort/cumsum/gather/clamp. A
    hand-written ``backward`` on an ``nn.Module`` would never be called, so
    none is defined.

    Args:
        dim: dimension to normalise over; ``None`` means the last dimension.
    """

    def __init__(self, dim=None):
        super(Sparsemax, self).__init__()
        self.dim = -1 if dim is None else dim

    def forward(self, input):
        # Normalise a possibly negative self.dim to a non-negative index.
        number_of_dims = input.dim()
        dim = range(number_of_dims)[self.dim]

        # Move the target dimension last, then flatten the rest so each row of
        # input_2d is one vector to project.
        input_transposed = input.transpose(dim, -1)
        original_size = input_transposed.size()
        input_2d = input_transposed.contiguous().view(-1, input_transposed.size(-1))
        number_of_feats = input_2d.size(1)

        # Sort each row in descending order. The running sum minus 1 is the
        # numerator of every candidate threshold.
        input_sorted, _ = torch.sort(input_2d, descending=True, dim=1)
        input_cumsum = input_sorted.cumsum(dim=1) - 1

        # Support size k: how many sorted entries z_(j) satisfy
        # z_(j) > (cumsum_j - 1) / j for j = 1..number_of_feats.
        # j = 1 always holds, so k >= 1 and the gather index below is valid.
        arange = torch.arange(1, number_of_feats + 1, device=input.device)
        counts = (input_sorted > input_cumsum / arange).sum(dim=1).unsqueeze(1)

        # Threshold tau = (cumsum_k - 1) / k for each row.
        threshold = input_cumsum.gather(1, counts - 1) / counts.to(input.dtype)

        # Shift by tau and clamp negatives to zero (the source of sparsity),
        # then undo the reshape and the transpose.
        output = torch.clamp(input_2d - threshold, min=0)
        return output.view(*original_size).transpose(dim, -1).contiguous()

# Sanity check: a uniform input spreads probability evenly (0.2 per entry).
Sparsemax()(torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]))

tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])

In [None]:
class AttentionTransformer(nn.Module):
    """TabNet attentive transformer: produces a sparse feature-selection mask.

    Computes ``sparsemax(priors * GBN(Linear(processed_feat)))``.

    Args:
        input_dim: width of the last dimension of ``processed_feat``.
        output_dim: width of the produced mask. ``priors`` must be
            multiplicable with a tensor whose last dimension is ``output_dim``.
        vbs: virtual batch size forwarded to ``GBN``.
        momentum: BatchNorm momentum forwarded to ``GBN``.
    """

    def __init__(self, input_dim, output_dim, vbs=128, momentum=0.02):
        super(AttentionTransformer, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim, bias=False)
        self.bn = GBN(output_dim, vbs=vbs, momentum=momentum)
        # Registered under the name forward() uses. It was previously stored as
        # ``self.max``, which left ``self.sparsemax`` undefined.
        self.sparsemax = Sparsemax()

    def forward(self, priors, processed_feat):
        x = self.bn(self.fc(processed_feat))
        # Scale the logits by the prior so features already used heavily in
        # earlier steps are less likely to be selected again.
        x = torch.mul(x, priors)
        return self.sparsemax(x)

In [None]:
import torch.nn.functional as F

class FeatureTransformer(nn.Module):
    """Stack of ``Linear -> GBN -> ReLU`` blocks mapping ``input_dim`` to ``output_dim``.

    The first Linear maps ``input_dim -> output_dim``. Every following Linear
    maps ``output_dim -> output_dim``, so the stack works for any
    ``n_shared`` / ``n_individual`` combination.

    NOTE(review): despite the name, ``shared_layers`` are created per instance
    and are not shared across decision steps as in the TabNet paper. Sharing
    them would require passing one module list to every step.

    Args:
        input_dim: width of the input features.
        output_dim: width of every hidden layer and of the output.
        n_shared: number of blocks in ``shared_layers``.
        n_individual: number of blocks in ``individual_layers``.
        vbs: virtual batch size forwarded to ``GBN``.
        momentum: BatchNorm momentum forwarded to ``GBN``.
    """

    def __init__(self, input_dim, output_dim, n_shared=2, n_individual=2, vbs=128, momentum=0.02):
        super(FeatureTransformer, self).__init__()

        # Width of the tensor entering the next Linear layer.
        in_dim = input_dim

        # Layers are stored as alternating [Linear, GBN, Linear, GBN, ...].
        self.shared_layers = nn.ModuleList()
        for _ in range(n_shared):
            self.shared_layers.append(nn.Linear(in_dim, output_dim, bias=False))
            # GBN normalises the Linear output, so it must be output_dim wide.
            self.shared_layers.append(GBN(output_dim, vbs=vbs, momentum=momentum))
            in_dim = output_dim

        self.individual_layers = nn.ModuleList()
        for _ in range(n_individual):
            self.individual_layers.append(nn.Linear(in_dim, output_dim, bias=False))
            self.individual_layers.append(GBN(output_dim, vbs=vbs, momentum=momentum))
            in_dim = output_dim

    def forward(self, x):
        for layers in (self.shared_layers, self.individual_layers):
            # Pair each Linear with the GBN that follows it. The activation is
            # applied after normalisation, not between Linear and BN.
            for fc, bn in zip(layers[0::2], layers[1::2]):
                x = F.relu(bn(fc(x)))
        return x

In [None]:
class TabNet(nn.Module):
    """Simplified TabNet (Arik & Pfister, 2019) with a single output.

    At every decision step, an attentive transformer turns the previously
    processed features and the running prior into a sparse mask over the input
    features. That step's feature transformer is applied to the masked input,
    and the ReLU of its output is summed across steps. The sum is mapped to one
    output with a Linear layer.

    Simplifications vs. the paper: the same processed features feed both the
    decision output and the next attention step (no n_d / n_a split), and no
    sparsity regulariser is computed.

    Args:
        input_dim: number of input features (width of ``x``).
        output_dim: width of the per-step feature representation.
        n_steps: number of decision steps.
        vbs: virtual batch size forwarded to every ``GBN``.
        momentum: BatchNorm momentum forwarded to every ``GBN``.
        gamma: relaxation factor in the prior update
            ``prior <- prior * (gamma - mask)``. With ``gamma = 1`` a feature
            fully selected once gets zero prior afterwards; larger values allow
            re-use.
    """

    def __init__(self, input_dim, output_dim, n_steps, vbs=128, momentum=0.02, gamma=1.3):
        super(TabNet, self).__init__()
        self.n_steps = n_steps
        self.gamma = gamma
        # Processes the raw input into the features the first attention step
        # conditions on.
        self.initial_transformer = FeatureTransformer(input_dim, output_dim, vbs=vbs, momentum=momentum)
        self.attention_transformers = nn.ModuleList()
        self.feature_transformers = nn.ModuleList()
        for _ in range(n_steps):
            # Maps processed features (output_dim) to a mask over the input
            # features (input_dim).
            self.attention_transformers.append(
                AttentionTransformer(output_dim, input_dim, vbs=vbs, momentum=momentum)
            )
            # Maps the masked input (input_dim) to processed features (output_dim).
            self.feature_transformers.append(
                FeatureTransformer(input_dim, output_dim, vbs=vbs, momentum=momentum)
            )
        self.final_mapping = nn.Linear(output_dim, 1, bias=False)

    def forward(self, x):
        # Assumes x is (batch, input_dim); the GBN layers use BatchNorm1d on
        # this layout. Every feature starts with full prior.
        prior = torch.ones_like(x)
        processed = self.initial_transformer(x)
        aggregated = torch.zeros_like(processed)
        for attention, transformer in zip(self.attention_transformers, self.feature_transformers):
            mask = attention(prior, processed)
            # Down-weight features this step selected so later steps are nudged
            # toward other features.
            prior = prior * (self.gamma - mask)
            processed = transformer(x * mask)
            aggregated = aggregated + torch.relu(processed)
        return self.final_mapping(aggregated)