# PyTorch Training Short:


---
First, let us check that the runtime is set up correctly (GPU, packages, and RAM).





In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change b type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

In [None]:
# Download the PyTorch/XLA environment-setup script and run it to install the
# nightly build plus the libomp5 / libopenblas-dev apt packages.
# NOTE(review): presumably only needed on TPU runtimes (see the commented
# `tpu_cores=8` Trainer cell further down); TODO confirm before running on GPU/CPU.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
!pip install pytorch-lightning==1.1.0rc1 sklearn

In [None]:
# Load the TensorBoard notebook extension; it provides the `%tensorboard`
# magic used (currently commented out) in the logging cells further down.
%load_ext tensorboard

In [None]:
from psutil import virtual_memory

# Total physical memory of the runtime in decimal gigabytes.
ram_gb = virtual_memory().total / 1e9
print(f'Your runtime has {ram_gb:.1f} gigabytes of available RAM\n')

# 20 GB is the threshold used here to call a runtime "high-RAM".
if ram_gb >= 20:
  print('You are using a high-RAM runtime!')
else:
  for line in ('To enable a high-RAM runtime, select the Runtime > "Change runtime type"',
               'menu, and then select High-RAM in the Runtime shape dropdown. Then, ',
               're-execute this cell.'):
    print(line)

In [None]:
# Print the current working directory.
!pwd

### Create a simple dataset to test the correctness of our approach

In [None]:
import numpy as np
import torch

# Seed torch's global RNG so the synthetic data (and every later random draw in
# the notebook) is identical across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Nearly linearly separable toy problem: y = 1[x0 + 2*x1 + noise > 0].
x = torch.randn(100000, 2)
noise = torch.randn(100000,)
y = ((1.0*x[:,0]+2.0*x[:,1]+noise)>0).type(torch.int64)

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Baseline: plain logistic regression on the first 50000 rows, evaluated on the rest.
x_np, y_np = x.numpy(), y.numpy()
x_train, x_test = x_np[:50000, :], x_np[50000:, :]
y_train, y_test = y_np[:50000], y_np[50000:]
log_reg = LogisticRegression().fit(x_train, y_train)
y_pred = log_reg.predict(x_test)
print(accuracy_score(y_test, y_pred))

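### Now create an "evil" dataset

The label depends on two latent variables only through the non-linear interaction $\cos(1.5\,x_1)\,x_2^2$ (plus noise). The other 60 columns are distractors: independent noise plus a small multiple (0.01 or 0.1) of $x_1$ or $x_2$.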

In [None]:
# Two latent drivers; the label sees them only through cos(1.5 * x_1) * x_2^2.
x_1 = torch.randn(100000)
x_2 = torch.randn(100000)
x_useful = torch.cos(1.5 * x_1) * x_2.pow(2)
# 15 distractor columns per driver at two signal strengths (0.01 and 0.1).
x_1_rest_small = torch.randn(100000, 15) + 0.01 * x_1[:, None]
x_1_rest_large = torch.randn(100000, 15) + 0.1 * x_1[:, None]
x_2_rest_small = torch.randn(100000, 15) + 0.01 * x_2[:, None]
x_2_rest_large = torch.randn(100000, 15) + 0.1 * x_2[:, None]
x = torch.cat(
    (x_1.unsqueeze(1), x_2.unsqueeze(1),
     x_1_rest_small, x_1_rest_large, x_2_rest_small, x_2_rest_large),
    dim=1,
)
y = (10 * x_useful + 5 * torch.randn(100000)).gt(0.0).long()

### First, an oracle baseline: logistic regression on the true signal `x_useful`

In [None]:
# Split every array at row 50000 into train / test halves.
y_train, y_test = np.split(y.numpy(), [50000])
x_train, x_test = np.split(x.numpy(), [50000])
oracle_train, oracle_test = np.split(x_useful.numpy(), [50000])
# "Oracle" model: logistic regression on the single informative feature.
log_reg_2 = LogisticRegression()
log_reg_2.fit(oracle_train.reshape(-1, 1), y_train)
y_pred = log_reg_2.predict(oracle_test.reshape(-1, 1))
print(accuracy_score(y_test, y_pred))

### What if the oracle feature is not available? (logistic regression on the 62 raw columns)

In [None]:
y_train, y_test = np.split(y.numpy(), [50000])
x_train, x_test = np.split(x.numpy(), [50000])
# Same model, but it only sees the raw columns (no direct access to x_useful).
log_reg_3 = LogisticRegression().fit(x_train, y_train)
y_pred = log_reg_3.predict(x_test)
accuracy_score(y_test, y_pred)

### Let us run a basic PyTorch example to check that the code is correct

In [None]:
# Regenerate the linear toy problem and split it 50/50 into torch train/test tensors.
x = torch.randn(100000, 2)
noise = torch.randn(100000)
y = (1.0 * x[:, 0] + 2.0 * x[:, 1] + noise).gt(0).to(torch.int64)
x_train, x_test = x.split(50000)
y_train, y_test = y.split(50000)

In [None]:
from torch.utils.data import Dataset, DataLoader

class MyDataSet(Dataset):
    """Map-style dataset pairing row ``i`` of ``x`` with ``y[i]``."""

    def __init__(self, x, y):
        super().__init__()
        self.x, self.y = x, y
        # number of samples, cached once and returned by __len__
        self.len = x.shape[0]

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        features = self.x[idx, :]
        return features, self.y[idx]

In [None]:
# Wrap the train/test tensors; only the training loader shuffles.
train_dataset, test_dataset = MyDataSet(x_train, y_train), MyDataSet(x_test, y_test)
train_dataloader = DataLoader(train_dataset, batch_size=128, num_workers=6, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=128, num_workers=6)

In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F


def mish(input):
    """Mish activation, element-wise: ``input * tanh(softplus(input))``."""
    return F.softplus(input).tanh().mul(input)

class Mish(nn.Module):
    """Parameter-free ``nn.Module`` wrapper around the functional ``mish``.

    No custom ``__init__`` is needed: ``nn.Module.__init__`` is inherited as-is.
    """

    def forward(self, input):
        """Apply ``mish`` element-wise to ``input``."""
        return mish(input)

class MLPLayer(nn.Module):
    """One MLP block: Linear -> Mish -> Dropout, optional scaled skip, then LayerNorm.

    With ``res_coef != 0`` the output is ``LayerNorm(res_coef * x + h)``, so ``x``
    and ``h`` must have matching last dimensions; otherwise it is ``LayerNorm(h)``.
    """

    def __init__(self, dim_in, dim_out, res_coef = 0, dropout_p = 0.1):
        super().__init__()
        # submodules registered in the order linear, activation, dropout, ln
        self.linear = nn.Linear(dim_in, dim_out)
        self.res_coef = res_coef
        self.activation = Mish()
        self.dropout = nn.Dropout(dropout_p)
        self.ln = nn.LayerNorm(dim_out)

    def forward(self, x):
        hidden = self.dropout(self.activation(self.linear(x)))
        if self.res_coef != 0:
            hidden = self.res_coef * x + hidden
        return self.ln(hidden)

       
class MyNetwork(nn.Module):
    """MLP binary classifier producing probabilities.

    Architecture: MLPLayer(dim_in -> dim), then ``n_layers`` x MLPLayer(dim -> dim)
    with residual coefficient ``res_coef``, then Linear(dim -> 1) and a sigmoid.

    Args:
        dim_in: number of input features.
        dim: hidden width.
        res_coef: residual coefficient of the hidden layers (0 disables the skip).
        dropout_p: dropout probability, applied in every MLPLayer, including the
            input projection.
        n_layers: number of hidden dim -> dim layers.

    ``forward`` returns a tensor of shape ``(batch,)`` for 2-D input.
    """

    def __init__(self, dim_in, dim, res_coef=0.5, dropout_p = 0.1, n_layers = 10):
        super().__init__()
        self.mlp = nn.ModuleList()
        # The input projection cannot be residual (dim_in != dim in general), but it
        # must honour the requested dropout like every other layer; previously it
        # always used MLPLayer's default of 0.1, even when dropout_p=0.
        self.first_linear = MLPLayer(dim_in, dim, dropout_p=dropout_p)
        self.n_layers = n_layers
        for _ in range(n_layers):
            self.mlp.append(MLPLayer(dim, dim, res_coef, dropout_p))
        self.final = nn.Linear(dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.first_linear(x)
        for layer in self.mlp:
            x = layer(x)
        x = self.sigmoid(self.final(x))
        # Drop only the size-1 output axis: a bare squeeze() turns a batch of one
        # into a 0-d tensor, which no longer matches the (1,) target in BCELoss.
        return x.squeeze(-1)

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.metrics import Accuracy
class TrainingModule(pl.LightningModule):
    """LightningModule training ``MyNetwork`` as a binary classifier with BCE loss.

    Constructor arguments are forwarded unchanged to ``MyNetwork``.
    Logs "Training loss"/"Training acc" and "Validation loss"/"Validation acc".
    """

    def __init__(self, dim_in, dim, res_coef=0, dropout_p=0, n_layers=10):
        super().__init__()
        self.backbone = MyNetwork(dim_in, dim, res_coef, dropout_p, n_layers)
        self.loss = nn.BCELoss()
        self.accuracy = Accuracy()

    def forward(self, x):
        return self.backbone(x)

    def _loss_and_acc(self, batch):
        """Score one ``(features, integer labels)`` batch; return ``(loss, acc)``."""
        features, labels = batch
        probs = self.backbone(features)
        loss = self.loss(probs, labels.type(torch.float32))
        return loss, self.accuracy(probs, labels)

    def validation_step(self, batch, batch_idx):
        loss, acc = self._loss_and_acc(batch)
        self.log("Validation loss", loss)
        self.log("Validation acc", acc)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._loss_and_acc(batch)
        self.log("Training loss", loss)
        self.log("Training acc", acc)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

import os
class CheckpointEveryNSteps(pl.Callback):
    """Save a checkpoint whenever ``trainer.global_step`` is a multiple of ``save_step_frequency``.

    Files are written next to the trainer's checkpoint callback as
    ``epoch=<epoch>_step=<step>.ckpt``.
    """

    def __init__(self, save_step_frequency):
        self.save_step_frequency = save_step_frequency

    def on_batch_end(self, trainer: pl.Trainer, _):
        epoch, step = trainer.current_epoch, trainer.global_step
        if step % self.save_step_frequency:
            return
        ckpt_name = f"epoch={epoch}_step={step}.ckpt"
        trainer.save_checkpoint(os.path.join(trainer.checkpoint_callback.dirpath, ckpt_name))

In [None]:
# from pytorch_lightning import loggers as pl_loggers

# tb_logger = pl_loggers.TensorBoardLogger('logs/')
# save_by_steps = CheckpointEveryNSteps(100)
# training_module = TrainingModule(2, 10, 0.5, 0.1, 2)
# trainer = pl.Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=100, val_check_interval=0.25, logger=tb_logger)
# trainer.fit(training_module, train_dataloader, test_dataloader)

In [None]:
# %tensorboard --logdir logs


### Now let us use the evil dataset

In [None]:
import torch
# Regenerate the "evil" dataset: the label depends on x_1 and x_2 only through
# cos(1.5 * x_1) * x_2**2; the other 60 columns are weakly correlated distractors.
x_1 = torch.randn(100000)
x_2 = torch.randn(100000)
x_useful = torch.cos(1.5*x_1)*(x_2**2)
x_1_rest_small = torch.randn(100000, 15)+ 0.01*x_1.unsqueeze(1)
x_1_rest_large = torch.randn(100000, 15) + 0.1*x_1.unsqueeze(1)
x_2_rest_small = torch.randn(100000, 15)+ 0.01*x_2.unsqueeze(1)
x_2_rest_large = torch.randn(100000, 15) + 0.1*x_2.unsqueeze(1)
x = torch.cat([x_1[:, None], x_2[:, None], x_1_rest_small, x_1_rest_large, x_2_rest_small, x_2_rest_large], dim=1)
y = ((10*x_useful) + 5*torch.randn(100000) >0.0).type(torch.int64)

x_train, x_test = x[:50000, :], x[50000:, :]
y_train, y_test = y[:50000], y[50000:]
train_dataset = MyDataSet(x_train, y_train)
test_dataset = MyDataSet(x_test, y_test)
# Reshuffle the training set every epoch, as the first dataloader cell does;
# evaluation order does not matter, so the test loader stays sequential.
train_dataloader = DataLoader(train_dataset, batch_size = 32, shuffle=True, num_workers=6)
test_dataloader = DataLoader(test_dataset, batch_size = 128, num_workers=6)

In [None]:
# from pytorch_lightning import loggers as pl_loggers

# tb_logger = pl_loggers.TensorBoardLogger('logs/')

# training_module = TrainingModule(62, 32, 0.5, 0.1, 20)
# trainer = pl.Trainer(max_epochs=20, tpu_cores=8, progress_bar_refresh_rate=100, val_check_interval=0.5, logger=tb_logger)
# trainer.fit(training_module, train_dataloader, test_dataloader)

In [None]:
# %tensorboard --logdir logs

Now let us see how LightGBM performs on the same data.

In [None]:
# import lightgbm as lgb

# x_train_np, x_test_np = x_train.numpy(), x_test.numpy()
# y_train_np, y_test_np = y_train.numpy(), y_test.numpy()

# train_dataset = lgb.Dataset(x_train_np, y_train_np)
# test_dataset = lgb.Dataset(x_test_np, y_test_np)

In [None]:
# params = {'num_leaves': 31, 'objective': 'binary', 'feature_fraction':0.8, 'bagging_fraction':0.8, 'metric':'binary_error'}
# num_round=2000
# eval_list = [train_dataset, test_dataset]
# lgb_model = lgb.train(params, train_dataset, num_round, valid_sets=eval_list)

How can we improve (rescue) our deep learning model?

1. Discretize the features
2. Variable selection

In [None]:
# coding = 'utf-8'
import bisect

import numpy as np
import pandas as pd
import tqdm

def encode_label(x):
    """Map each distinct value of ``x`` to a contiguous integer code ``0..k-1``.

    Values are keyed by ``str()``. For numeric arrays the codes follow numeric
    order (e.g. -3, -1, 2, 10 -> 0, 1, 2, 3). For anything else (strings,
    objects, booleans) they follow the lexicographic order of the string keys.

    Parameters
    ----------
    x : array-like
        Values to encode; any shape accepted by ``np.unique`` / ``np.vectorize``.

    Returns
    -------
    numpy.ndarray of int with the same shape as ``x``.
    """
    values = np.asarray(x)
    # np.vectorize hands each element to the lambda after casting the array to
    # object dtype, so build the keys from the same object-cast values. Otherwise
    # e.g. str(np.float32(0.1)) != str(float(np.float32(0.1))) and the lookup fails.
    uniques = np.unique(values).astype(object)
    if np.issubdtype(values.dtype, np.number):
        # np.unique already sorts numerically; a plain string sort would place
        # '10' before '2' and scramble the order of discretized bin ids.
        # dict.fromkeys preserves that order and collapses repeated 'nan' keys.
        ordered = list(dict.fromkeys(str(item) for item in uniques))
    else:
        ordered = sorted({str(item) for item in uniques})
    kv = {key: code for code, key in enumerate(ordered)}
    # otypes=[int] lets np.vectorize accept empty input instead of raising.
    vfunc = np.vectorize(lambda v: kv[str(v)], otypes=[int])
    return vfunc(values)

def encode_label_mat(x):
    """Apply ``encode_label`` independently to every column of the 2-D array ``x``.

    Returns an int array of the same shape, holding per-column codes.
    """
    _, ncol = x.shape
    encoded = np.empty_like(x, dtype=int)
    for j, column in enumerate(x.T):
        encoded[:, j] = encode_label(column)
    return encoded

def impute_nan(x, method='median'):
    """Replace null entries of each column of the 2-D array ``x`` with that column's median.

    The median is computed over the non-null, finite entries of the column.
    Infinite values are left in place. ``x`` itself is not modified.

    Parameters
    ----------
    x : numpy.ndarray, 2-D
    method : str
        Only 'median' is supported.

    Returns
    -------
    numpy.ndarray with the same shape and dtype as ``x``.

    Raises
    ------
    NotImplementedError
        If ``method`` is not 'median'.
    """
    if method != 'median':
        raise NotImplementedError()
    _, ncol = x.shape
    result = np.empty_like(x)

    for col in range(ncol):
        column = x[:, col]
        missing = pd.isnull(column)
        usable = ~missing & (column != np.inf) & (column != -np.inf)
        impute_value = np.median(column[usable])
        # Vectorised replacement: faster than a per-element Python callback and,
        # unlike np.vectorize without otypes, works on zero-row input.
        result[:, col] = np.where(missing, impute_value, column)
    return result


def get_uniform_interval(minimum, maximum, nbins):
    """Return ``nbins + 1`` evenly spaced boundaries from ``minimum`` to ``maximum``.

    The end points are returned exactly as given; the interior points are
    ``minimum + k * step`` for ``k = 1 .. nbins - 1``.
    """
    step_size = (float(maximum - minimum)) / nbins
    inner = [minimum + step_size * k for k in range(1, nbins)]
    return [minimum, *inner, maximum]


def get_interval_v2(x, sorted_intervals):
    """Return the index ``i`` of the bin ``[b_i, b_{i+1})`` that contains ``x``.

    ``sorted_intervals`` is an ascending list of boundaries ``b_0 .. b_k``. The
    last bin is open-ended, so any ``x >= b_k`` maps to ``k``.
    ``sorted_intervals`` is not modified.

    Returns
    -------
    int or None
        -1 if ``x`` is null (NaN/None), -2 if ``x`` is +inf, -3 if ``x`` is -inf.
        None if ``x`` lies below the first boundary or the list is empty.
    """
    if pd.isnull(x):
        return -1
    if x == np.inf:
        return -2
    if x == -np.inf:
        return -3
    if not sorted_intervals or x < sorted_intervals[0]:
        return None
    # bisect_right gives the first boundary strictly greater than x, so one less
    # is the last boundary <= x. No +inf sentinel needs to be appended to (and
    # thereby mutate) the caller's list, as the old linear scan did.
    return bisect.bisect_right(sorted_intervals, x) - 1


def get_quantile_interval(data, nbins):
    """Return ``nbins + 1`` quantile boundaries (levels 0, 1/nbins, ..., 1) of ``data``.

    Null and infinite entries are excluded before the quantiles are computed.
    """
    levels = get_uniform_interval(0, 1, nbins)
    finite_mask = ~pd.isnull(data) & (data != np.inf) & (data != -np.inf)
    boundaries = np.quantile(data[finite_mask], levels)
    return list(boundaries)


def discretize(x, nbins=20):
    """Quantile-bin every column of the 2-D array ``x`` into integer codes.

    Per column: compute up to ``nbins + 1`` distinct quantile boundaries, assign
    each value its bin via ``get_interval_v2``, then compress the bin ids to
    contiguous codes with ``encode_label``.

    Returns
    -------
    (codes, centroids)
        ``codes`` is an int64 array shaped like ``x``. ``centroids[c]`` lists the
        midpoints of consecutive boundaries of column ``c``.
    """
    _, ncol = x.shape
    encoded = np.empty_like(x)
    centroids_per_col = list()
    for col in range(ncol):
        column = x[:, col]
        bounds = sorted(set(get_quantile_interval(column, nbins)))
        # midpoints are taken before binning, from the deduplicated boundaries
        midpoints = [0.5 * (lo + hi) for lo, hi in zip(bounds[:-1], bounds[1:])]
        bin_ids = np.vectorize(lambda v: get_interval_v2(v, bounds))(column)
        encoded[:, col] = encode_label(bin_ids)
        centroids_per_col.append(midpoints)
    return encoded.astype(np.int64), centroids_per_col

def get_var_type(df):
    """Group the column names of ``df`` by prefix.

    Returns a dict with keys 'continuous' (names starting with 'continuous_'),
    'discrete' (names starting with 'discrete_') and 'other' (everything else).
    Each value is a list in the original column order.
    """
    groups = {'continuous': [], 'discrete': [], 'other': []}
    for name in df.columns:
        if name.startswith('continuous_'):
            groups['continuous'].append(name)
        elif name.startswith('discrete_'):
            groups['discrete'].append(name)
        else:
            groups['other'].append(name)
    return groups


def get_cont_var(df):
    """Names of the 'continuous_'-prefixed columns of ``df``, in column order."""
    return get_var_type(df)['continuous']


def get_dis_var(df):
    """Names of the 'discrete_'-prefixed columns of ``df``, in column order."""
    return get_var_type(df)['discrete']

def drop_const_var(data):
    """Return a deep copy of ``data`` without its constant columns.

    A column counts as constant when it has at most one distinct non-null value,
    which includes all-null columns. ``data`` itself is left untouched.
    """
    constant_cols = [col for col in data.columns
                     if len(data.loc[data[col].notna(), col].unique()) <= 1]
    return data.copy(deep=True).drop(columns=constant_cols)

In [None]:
x_train_np, x_test_np = x_train.numpy(), x_test.numpy()
y_train_np, y_test_np = y_train.numpy(), y_test.numpy()
# Discretize train and test together so both share the same quantile bins.
x = np.concatenate([x_train_np, x_test_np])
x_dis, centroids = discretize(x)
# Split back at the actual train/test boundary instead of a hard-coded row count.
n_train = x_train_np.shape[0]
x_dis_train = x_dis[:n_train, :]
x_dis_test = x_dis[n_train:, :]

In [None]:
class TabDataset(Dataset):
    """Map-style dataset over integer-coded features and float targets.

    Parameters
    ----------
    x : numpy.ndarray
        Feature codes; stored as an int64 tensor and indexed as ``x[idx, :]``.
    y : numpy.ndarray
        Targets; stored as float32 with singleton non-sample axes removed,
        e.g. (n, 1) -> (n,).
    """

    def __init__(self, x, y):
        super().__init__()
        self.x = torch.from_numpy(x).type(torch.int64)
        y_t = torch.from_numpy(y).type(torch.float32)
        # Drop singleton trailing axes but never the sample axis: a bare
        # .squeeze() turns a one-sample target into a 0-d tensor that cannot be indexed.
        if y_t.dim() > 1:
            y_t = y_t.reshape(y_t.shape[0], *[d for d in y_t.shape[1:] if d != 1])
        self.y = y_t

    def __getitem__(self, idx):
        return self.x[idx, :], self.y[idx]

    def __len__(self):
        return self.x.shape[0]


### Create Embedding Factories

In [None]:
# einops provides `rearrange`, used to flatten the per-column embeddings in TrainingModuleV2.
!pip install einops

In [None]:
class EmbeddingFactory(nn.Module):
    """One ``nn.Embedding`` table per column of an integer-coded feature matrix.

    Parameters
    ----------
    x : array of non-negative integer codes, shape (n_rows, n_cols)
        Used only to size each table (largest code in the column + 1).
    dim_out : int
        Embedding width.

    ``forward`` maps a LongTensor of shape (batch, n_cols) to (batch, dim_out, n_cols).
    """

    def __init__(self, x, dim_out):
        super().__init__()
        self.dim_out = dim_out
        self.module_list = nn.ModuleList(
            [nn.Embedding(self._table_size(x[:, col]), dim_out) for col in range(x.shape[1])])

    @staticmethod
    def _table_size(codes):
        """Number of rows needed so every code in ``codes`` is a valid index."""
        if len(codes) == 0:
            return 0
        # Sizing by the count of unique values breaks as soon as codes are not
        # contiguous (e.g. {0, 2} gives 2 rows, so index 2 is out of range);
        # max + 1 is identical for contiguous codes and safe otherwise.
        return int(codes.max()) + 1

    def forward(self, x):
        result = [self.module_list[col](x[:, col]).unsqueeze(2) for col in range(x.shape[1])]
        return torch.cat(result, dim=2)

In [None]:
from einops import rearrange, reduce, repeat
# (A stray `x_dis_test.shape` expression was removed: it was not the cell's last
# line, so it displayed nothing.)
# Shuffle the training data every epoch; evaluation order does not matter.
train_dataloader = DataLoader(TabDataset(x_dis_train, y_train_np), batch_size = 32, shuffle=True, num_workers=6)
test_dataloader = DataLoader(TabDataset(x_dis_test, y_test_np), batch_size = 128, num_workers=6)

class TrainingModuleV2(pl.LightningModule):
    """LightningModule: per-column embeddings of discretized features -> MyNetwork.

    Parameters
    ----------
    x : integer-coded feature matrix (e.g. the output of ``discretize``), used to
        size the embedding tables and the backbone input width ``n_cols * dim_emb``.
    dim_emb : embedding width per column.
    dim_mlp, res_coef, dropout_p, n_layers : forwarded to ``MyNetwork``.
    """

    def __init__(self, x, dim_emb, dim_mlp, res_coef=0, dropout_p=0, n_layers=10):
        super().__init__()
        self.embedding = EmbeddingFactory(x, dim_emb)
        self.backbone = MyNetwork(x.shape[1]*dim_emb, dim_mlp, res_coef, dropout_p, n_layers)
        self.loss = nn.BCELoss()
        self.accuracy = Accuracy()

    def forward(self, x):
        # Flatten (batch, dim_emb, n_cols) -> (batch, dim_emb * n_cols), the width the
        # backbone was built for. Previously forward skipped this step and failed
        # with a shape mismatch, while the training/validation steps did flatten.
        x = self.embedding(x)
        x = rearrange(x, "b h e -> b (h e)")
        return self.backbone(x)

    def _loss_and_acc(self, batch):
        """Return ``(BCE loss, accuracy)`` for one ``(codes, float targets)`` batch."""
        x, y = batch
        probs = self(x)
        loss = self.loss(probs, y)
        # BCELoss needs float targets; pass integer labels to the metric.
        acc = self.accuracy(probs, y.long())
        return loss, acc

    def validation_step(self, batch, batch_idx):
        loss, acc = self._loss_and_acc(batch)
        self.log("Validation loss", loss)
        self.log("Validation acc", acc)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._loss_and_acc(batch)
        self.log("Training loss", loss)
        self.log("Training acc", acc)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


In [None]:
# from pytorch_lightning import loggers as pl_loggers

# tb_logger = pl_loggers.TensorBoardLogger('logs/')
# training_module = TrainingModuleV2(x_dis, 16, 64, 0.5, 0.1, 10)
# trainer = pl.Trainer(max_epochs=10, gpus=1, progress_bar_refresh_rate=100, val_check_interval=0.5, logger=tb_logger)
# trainer.fit(training_module, train_dataloader, test_dataloader)

In [None]:
# %tensorboard --logdir logs

### Let us try TabNet

In [None]:
# NOTE(review): same install as the earlier einops cell; redundant if that cell already ran.
!pip install einops

In [None]:
import torch

# Regenerate the evil dataset and its discretized (quantile-binned) version.
x_1, x_2 = torch.randn(100000), torch.randn(100000)
x_useful = torch.cos(1.5 * x_1) * x_2.pow(2)
x_1_rest_small = torch.randn(100000, 15) + 0.01 * x_1[:, None]
x_1_rest_large = torch.randn(100000, 15) + 0.1 * x_1[:, None]
x_2_rest_small = torch.randn(100000, 15) + 0.01 * x_2[:, None]
x_2_rest_large = torch.randn(100000, 15) + 0.1 * x_2[:, None]
x = torch.cat((x_1[:, None], x_2[:, None],
               x_1_rest_small, x_1_rest_large, x_2_rest_small, x_2_rest_large), dim=1)
y = (10 * x_useful + 5 * torch.randn(100000)).gt(0.0).long()

# 50/50 train/test split, then discretize both halves with shared bins.
x_train, x_test = x.split(50000)
y_train, y_test = y.split(50000)
x_train_np, x_test_np = x_train.numpy(), x_test.numpy()
y_train_np, y_test_np = y_train.numpy(), y_test.numpy()
x = np.concatenate([x_train_np, x_test_np])
x_dis, centroids = discretize(x)
x_dis_train, x_dis_test = np.split(x_dis, [50000])

In [None]:
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F

import torch

"""
Other possible implementations:
https://github.com/KrisKorrel/sparsemax-pytorch/blob/master/sparsemax.py
https://github.com/msobroza/SparsemaxPytorch/blob/master/mnist/sparsemax.py
https://github.com/vene/sparse-structured-attention/blob/master/pytorch/torchsparseattn/sparsemax.py
"""


# credits to Yandex https://github.com/Qwicen/node/blob/master/lib/nn_utils.py
def _make_ix_like(input, dim=0):
    """Return 1..d (d = ``input.size(dim)``) shaped to broadcast against ``input`` along ``dim``."""
    d = input.size(dim)
    rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype)
    view = [1] * input.dim()
    view[0] = -1
    return rho.view(view).transpose(0, dim)


class SparsemaxFunction(Function):
    """
    An implementation of sparsemax (Martins & Astudillo, 2016). See
    :cite:`DBLP:journals/corr/MartinsA16` for detailed description.
    By Ben Peters and Vlad Niculae
    """

    @staticmethod
    def forward(ctx, input, dim=-1):
        """sparsemax: normalizing sparse transform (a la softmax)

        Parameters
        ----------
        ctx : torch.autograd.function._ContextMethodMixin
        input : torch.Tensor
            any shape; not modified
        dim : int
            dimension along which to apply sparsemax

        Returns
        -------
        output : torch.Tensor
            same shape as input

        """
        ctx.dim = dim
        max_val, _ = input.max(dim=dim, keepdim=True)
        # Out of place on purpose: `input -= max_val` silently overwrote the
        # caller's tensor.
        input = input - max_val  # same numerical stability trick as for softmax
        tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
        output = torch.clamp(input - tau, min=0)
        ctx.save_for_backward(supp_size, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        supp_size, output = ctx.saved_tensors
        dim = ctx.dim
        grad_input = grad_output.clone()
        grad_input[output == 0] = 0

        # supp_size keeps a singleton `dim`; squeeze it there *before* dividing so
        # the quotient has the shape of grad_input.sum(dim). Dividing first, e.g.
        # (B,) / (B, 1), broadcasts to (B, B) and yields a wrong-shape gradient.
        v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze(dim)
        v_hat = v_hat.unsqueeze(dim)
        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
        return grad_input, None

    @staticmethod
    def _threshold_and_support(input, dim=-1):
        """Sparsemax building block: compute the threshold

        Parameters
        ----------
        input: torch.Tensor
            any dimension
        dim : int
            dimension along which to apply the sparsemax

        Returns
        -------
        tau : torch.Tensor
            the threshold value
        support_size : torch.Tensor

        """

        input_srt, _ = torch.sort(input, descending=True, dim=dim)
        input_cumsum = input_srt.cumsum(dim) - 1
        rhos = _make_ix_like(input, dim)
        support = rhos * input_srt > input_cumsum

        support_size = support.sum(dim=dim).unsqueeze(dim)
        tau = input_cumsum.gather(dim, support_size - 1)
        tau /= support_size.to(input.dtype)
        return tau, support_size


sparsemax = SparsemaxFunction.apply


class Sparsemax(nn.Module):
    """``nn.Module`` wrapper applying ``sparsemax`` along ``dim``."""

    def __init__(self, dim=-1):
        # initialise the Module machinery before assigning attributes
        super(Sparsemax, self).__init__()
        self.dim = dim

    def forward(self, input):
        return sparsemax(input, self.dim)


class Entmax15Function(Function):
    """Autograd function for 1.5-entmax with a closed-form backward pass.

    Forward: shift by the max along ``dim``, halve, then
    ``output = clamp(z - tau*, 0) ** 2``, with ``tau*`` from
    ``Entmax15._threshold_and_support``.
    """

    @staticmethod
    def forward(ctx, input, dim=-1):
        ctx.dim = dim
        max_val, _ = input.max(dim=dim, keepdim=True)
        input = input - max_val  # same numerical stability trick as for softmax
        input = input / 2  # divide by 2 to solve actual Entmax
        tau_star, _ = Entmax15._threshold_and_support(input, dim)
        output = torch.clamp(input - tau_star, min=0) ** 2
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        Y, = ctx.saved_tensors
        gppr = Y.sqrt()  # = 1 / g'' (Y)
        dX = grad_output * gppr
        q = dX.sum(ctx.dim) / gppr.sum(ctx.dim)
        q = q.unsqueeze(ctx.dim)
        dX -= q * gppr
        return dX, None


class Entmax15(nn.Module):
    """``nn.Module`` applying 1.5-entmax along ``dim`` via ``Entmax15Function``.

    Plain autograd through the sort/sqrt threshold would produce NaN gradients
    (sqrt at 0), so the custom Function supplies the closed-form backward.
    """

    def __init__(self, dim=-1):
        super().__init__()
        self.dim=dim

    @staticmethod
    def _threshold_and_support(input, dim=-1):
        """Return ``(tau_star, support_size)`` for the already shifted and halved ``input``."""
        Xsrt, _ = torch.sort(input, descending=True, dim=dim)

        rho = _make_ix_like(input, dim)
        mean = Xsrt.cumsum(dim) / rho
        mean_sq = (Xsrt ** 2).cumsum(dim) / rho
        ss = rho * (mean_sq - mean ** 2)
        delta = (1 - ss) / rho

        delta_nz = torch.clamp(delta, 0)
        tau = mean - torch.sqrt(delta_nz)

        support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim)
        tau_star = tau.gather(dim, support_size - 1)
        return tau_star, support_size

    def forward(self, input):
        return Entmax15Function.apply(input, self.dim)

    def backward(self, output, grad):
        """Closed-form 1.5-entmax input gradient from a forward ``output`` and upstream ``grad``.

        Not invoked by autograd (``Entmax15Function.backward`` is); kept as a helper.
        """
        gppr = output.sqrt()  # = 1 / g'' (Y)
        dX = grad * gppr
        q = dX.sum(self.dim) / gppr.sum(self.dim)
        q = q.unsqueeze(self.dim)
        dX = dX - q * gppr
        return dX, None



In [None]:
import torch
from torch.nn import Linear, BatchNorm1d, ReLU
import numpy as np
import math


def initialize_non_glu(module, input_dim, output_dim):
    """Xavier-normal init of ``module.weight`` with gain sqrt((in + out) / sqrt(4 * in)).

    Only the weight is touched; any bias keeps its existing initialisation.
    Returns None.
    """
    fan_sum = input_dim + output_dim
    gain = np.sqrt(fan_sum / np.sqrt(4 * input_dim))
    torch.nn.init.xavier_normal_(module.weight, gain=gain)


def initialize_glu(module, input_dim, output_dim):
    """Xavier-normal init of ``module.weight`` with gain sqrt((in + out) / sqrt(in)).

    Only the weight is touched; any bias keeps its existing initialisation.
    Returns None.
    """
    fan_sum = input_dim + output_dim
    gain = np.sqrt(fan_sum / np.sqrt(input_dim))
    torch.nn.init.xavier_normal_(module.weight, gain=gain)


class GBN(torch.nn.Module):
    """
        Ghost Batch Normalization
        https://arxiv.org/abs/1705.08741

    Splits the batch into ceil(batch / virtual_batch_size) chunks (via
    ``Tensor.chunk``) and normalises each chunk with one shared BatchNorm1d.
    """

    def __init__(self, input_dim, virtual_batch_size=128, momentum=0.01):
        super().__init__()
        self.input_dim = input_dim
        self.virtual_batch_size = virtual_batch_size
        self.bn = BatchNorm1d(input_dim, momentum=momentum)

    def forward(self, x):
        n_chunks = int(np.ceil(x.shape[0] / self.virtual_batch_size))
        return torch.cat([self.bn(chunk) for chunk in x.chunk(n_chunks, 0)], dim=0)


class TabNet(torch.nn.Module):
    # TabNet encoder without embedding layers: n_steps rounds of
    # (attentive mask -> masked feature transformer). The ReLU'd first n_d
    # channels of each round are summed and mapped linearly to output_dim.
    def __init__(self, input_dim, output_dim,
                 n_d=64, n_a=64,
                 n_steps=5, gamma=1.3,
                 n_independent=2, n_shared=2, epsilon=1e-15,
                 virtual_batch_size=128, momentum=0.02,
                 mask_type="sparsemax"):
        """
        Defines main part of the TabNet network without the embedding layers.

        Parameters
        ----------
        input_dim : int
            Number of features
        output_dim : int or list of int for multi task classification
            Dimension of network output
            examples : one for regression, 2 for binary classification etc...
        n_d : int
            Dimension of the prediction layer (usually between 4 and 64)
        n_a : int
            Dimension of the attention layer (usually between 4 and 64)
        n_steps : int
            Number of successive steps in the network (usually between 3 and 10)
        gamma : float
            Float above 1, scaling factor for attention updates (usually between 1.0 to 2.0)
        n_independent : int
            Number of independent GLU layers in each GLU block (default 2)
        n_shared : int
            Number of shared GLU layers; their Linear weights are reused by the
            initial splitter and by every step (default 2)
        epsilon : float
            Avoid log(0), this should be kept very low
        virtual_batch_size : int
            Batch size for Ghost Batch Normalization
        momentum : float
            Float value between 0 and 1 used for momentum in the batch norms of the
            feature / attentive transformers (the input BatchNorm1d uses a fixed 0.01)
        mask_type : str
            Either "sparsemax" or "entmax" : this is the masking function to use
        """
        super(TabNet, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.is_multi_task = isinstance(output_dim, list)
        self.n_d = n_d
        self.n_a = n_a
        self.n_steps = n_steps
        self.gamma = gamma
        self.epsilon = epsilon
        self.n_independent = n_independent
        self.n_shared = n_shared
        self.virtual_batch_size = virtual_batch_size
        self.mask_type = mask_type
        self.initial_bn = BatchNorm1d(self.input_dim, momentum=0.01)

        # Linear layers shared by the initial splitter and every step's FeatTransformer.
        # Each outputs 2*(n_d + n_a); presumably GLU_Layer gates that down to n_d + n_a
        # (GLU_Layer.forward is not visible here — TODO confirm).
        if self.n_shared > 0:
            shared_feat_transform = torch.nn.ModuleList()
            for i in range(self.n_shared):
                if i == 0:
                    shared_feat_transform.append(Linear(self.input_dim,
                                                        2*(n_d + n_a),
                                                        bias=False))
                else:
                    shared_feat_transform.append(Linear(n_d + n_a, 2*(n_d + n_a), bias=False))

        else:
            shared_feat_transform = None

        self.initial_splitter = FeatTransformer(self.input_dim, n_d+n_a, shared_feat_transform,
                                                n_glu_independent=self.n_independent,
                                                virtual_batch_size=self.virtual_batch_size,
                                                momentum=momentum)

        self.feat_transformers = torch.nn.ModuleList()
        self.att_transformers = torch.nn.ModuleList()

        # One (feature transformer, attentive transformer) pair per decision step.
        for step in range(n_steps):
            transformer = FeatTransformer(self.input_dim, n_d+n_a, shared_feat_transform,
                                          n_glu_independent=self.n_independent,
                                          virtual_batch_size=self.virtual_batch_size,
                                          momentum=momentum)
            attention = AttentiveTransformer(n_a, self.input_dim,
                                             virtual_batch_size=self.virtual_batch_size,
                                             momentum=momentum,
                                             mask_type=self.mask_type)
            self.feat_transformers.append(transformer)
            self.att_transformers.append(attention)

        # Output heads: one bias-free Linear per task, or a single one.
        if self.is_multi_task:
            self.multi_task_mappings = torch.nn.ModuleList()
            for task_dim in output_dim:
                task_mapping = Linear(n_d, task_dim, bias=False)
                initialize_non_glu(task_mapping, n_d, task_dim)
                self.multi_task_mappings.append(task_mapping)
        else:
            self.final_mapping = Linear(n_d, output_dim, bias=False)
            initialize_non_glu(self.final_mapping, n_d, output_dim)

    def forward(self, x):
        # Returns (out, M_loss).
        # - out: the linear mapping of the summed step outputs (a list of per-task
        #   outputs when output_dim is a list).
        # - M_loss: the per-step average of mean(sum(M * log(M + eps), dim=1)),
        #   i.e. the negative entropy of the attention masks.
        # Assumes x is a float tensor of shape (batch, input_dim) — TODO confirm.
        res = 0
        x = self.initial_bn(x)

        # prior starts at 1 for every feature and is multiplied by (gamma - M) after each step
        prior = torch.ones(x.shape, device=x.device)
        M_loss = 0
        # channels after the first n_d feed the first attentive transformer
        att = self.initial_splitter(x)[:, self.n_d:]

        for step in range(self.n_steps):
            M = self.att_transformers[step](prior, att)
            M_loss += torch.mean(torch.sum(torch.mul(M, torch.log(M+self.epsilon)),
                                           dim=1))
            # update prior
            prior = torch.mul(self.gamma - M, prior)
            # output
            masked_x = torch.mul(M, x)
            out = self.feat_transformers[step](masked_x)
            d = ReLU()(out[:, :self.n_d])
            res = torch.add(res, d)
            # update attention
            att = out[:, self.n_d:]

        M_loss /= self.n_steps

        if self.is_multi_task:
            # Result will be in list format
            out = []
            for task_mapping in self.multi_task_mappings:
                out.append(task_mapping(res))
        else:
            out = self.final_mapping(res)
        return out, M_loss

    def forward_masks(self, x):
        # Same pass as forward, but returns (M_explain, masks) for interpretation:
        # - masks[step] is that step's attention mask;
        # - M_explain sums every mask weighted, per sample, by the step's total
        #   ReLU'd decision output.
        x = self.initial_bn(x)

        prior = torch.ones(x.shape, device=x.device)
        M_explain = torch.zeros(x.shape, device=x.device)
        att = self.initial_splitter(x)[:, self.n_d:]
        masks = {}

        for step in range(self.n_steps):
            M = self.att_transformers[step](prior, att)
            masks[step] = M
            # update prior
            prior = torch.mul(self.gamma - M, prior)
            # output
            masked_x = torch.mul(M, x)
            out = self.feat_transformers[step](masked_x)
            d = ReLU()(out[:, :self.n_d])
            # explain
            step_importance = torch.sum(d, dim=1)
            M_explain += torch.mul(M, step_importance.unsqueeze(dim=1))
            # update attention
            att = out[:, self.n_d:]

        return M_explain, masks

class AttentiveTransformer(torch.nn.Module):
    def __init__(self, input_dim, output_dim,
                 virtual_batch_size=128,
                 momentum=0.02,
                 mask_type="entmax"):
        """
        Initialize an attention transformer.

        Produces a feature-selection mask:
        Linear (no bias) -> ghost batch norm -> multiply by priors -> sparsemax/entmax.

        Parameters
        ----------
        input_dim : int
            Input size
        output_dim : int
            Output size (TabNet passes its number of input features here)
        virtual_batch_size : int
            Batch size for Ghost Batch Normalization
        momentum : float
            Float value between 0 and 1 which will be used for momentum in batch norm
        mask_type : str
            Either "sparsemax" or "entmax" : this is the masking function to use

        Raises
        ------
        NotImplementedError
            If mask_type is neither "sparsemax" nor "entmax".
        """
        super(AttentiveTransformer, self).__init__()
        self.fc = Linear(input_dim, output_dim, bias=False)
        initialize_non_glu(self.fc, input_dim, output_dim)
        self.bn = GBN(output_dim, virtual_batch_size=virtual_batch_size,
                      momentum=momentum)

        if mask_type == "sparsemax":
            # Sparsemax
            self.selector = Sparsemax(dim=-1)
        elif mask_type == "entmax":
            # Entmax
            self.selector = Entmax15(dim=-1)
        else:
            # The two literals are concatenated, so the first needs a trailing
            # space (it previously rendered as "sparsemaxor entmax").
            raise NotImplementedError("Please choose either sparsemax " +
                                      "or entmax as masktype")

    def forward(self, priors, processed_feat):
        """Return ``selector(bn(fc(processed_feat)) * priors)``.

        ``priors`` is multiplied element-wise with the normalised projection, so it
        must broadcast against shape (batch, output_dim) — assumed, TODO confirm.
        """
        x = self.fc(processed_feat)
        x = self.bn(x)
        x = torch.mul(x, priors)
        x = self.selector(x)
        return x


class FeatTransformer(torch.nn.Module):
    def __init__(self, input_dim, output_dim, shared_layers, n_glu_independent,
                 virtual_batch_size=128, momentum=0.02):
        """
        Initialize a feature transformer.

        The input passes through the shared GLU block (built from
        ``shared_layers``) and then a step-specific GLU block of
        ``n_glu_independent`` layers. A missing block is an identity.

        Parameters
        ----------
        input_dim : int
            Input size
        output_dim : int
            Output size
        shared_layers : torch.nn.ModuleList or None
            The shared block that should be common to every step
        n_glu_independent : int
            Number of independent GLU layers
        virtual_batch_size : int
            Batch size for Ghost Batch Normalization within GLU block(s)
        momentum : float
            Float value between 0 and 1 which will be used for momentum in batch norm
        """
        # The docstring must come first in the body; placed after super().__init__()
        # it was just an unused string expression and __init__.__doc__ stayed None.
        super(FeatTransformer, self).__init__()

        params = {
            'n_glu': n_glu_independent,
            'virtual_batch_size': virtual_batch_size,
            'momentum': momentum
        }

        if shared_layers is None:
            # no shared layers
            self.shared = torch.nn.Identity()
            is_first = True
        else:
            self.shared = GLU_Block(input_dim, output_dim,
                                    first=True,
                                    shared_layers=shared_layers,
                                    n_glu=len(shared_layers),
                                    virtual_batch_size=virtual_batch_size,
                                    momentum=momentum)
            is_first = False

        if n_glu_independent == 0:
            # no independent layers
            self.specifics = torch.nn.Identity()
        else:
            # when a shared block ran first, its output (output_dim) feeds this block
            spec_input_dim = input_dim if is_first else output_dim
            self.specifics = GLU_Block(spec_input_dim, output_dim,
                                       first=is_first,
                                       **params)

    def forward(self, x):
        x = self.shared(x)
        x = self.specifics(x)
        return x


class GLU_Block(torch.nn.Module):
    """
    Stack of ``n_glu`` GLU layers joined by scaled residual connections.

    Parameters
    ----------
    input_dim : int
        Input width of the first GLU layer.
    output_dim : int
        Output width of every GLU layer (and input width of the later ones).
    n_glu : int
        Number of GLU layers used in ``forward`` (at least one layer is always built).
    first : bool
        If True, the first layer's output replaces the input with no residual
        add or scaling (so ``input_dim`` may differ from ``output_dim``).
    shared_layers : torch.nn.ModuleList or None
        When truthy, layer ``i`` reuses ``shared_layers[i]`` as its ``fc``.
    virtual_batch_size, momentum :
        Forwarded to every ``GLU_Layer``.
    """

    def __init__(self, input_dim, output_dim, n_glu=2, first=False, shared_layers=None,
                 virtual_batch_size=128, momentum=0.02):
        super(GLU_Block, self).__init__()
        self.first = first
        self.shared_layers = shared_layers
        self.n_glu = n_glu

        def shared_fc(index):
            # Reuse the caller-provided projection when a shared stack exists.
            return shared_layers[index] if shared_layers else None

        layers = [GLU_Layer(input_dim, output_dim, fc=shared_fc(0),
                            virtual_batch_size=virtual_batch_size,
                            momentum=momentum)]
        layers.extend(
            GLU_Layer(output_dim, output_dim, fc=shared_fc(index),
                      virtual_batch_size=virtual_batch_size,
                      momentum=momentum)
            for index in range(1, n_glu)
        )
        self.glu_layers = torch.nn.ModuleList(layers)

    def forward(self, x):
        """Run the GLU stack; each residual step is ``(x + layer(x)) * sqrt(0.5)``."""
        residual_scale = math.sqrt(0.5)
        start = 0
        if self.first:
            # The first layer of the block has no residual/scale multiplication.
            x = self.glu_layers[0](x)
            start = 1
        for layer in self.glu_layers[start:self.n_glu]:
            x = (x + layer(x)) * residual_scale
        return x


class GLU_Layer(torch.nn.Module):
    """
    Gated linear unit over a batch-normalised projection.

    The input is projected to ``2 * output_dim`` features by ``fc`` and
    normalised by ``GBN``. The first half is then gated by the sigmoid of the
    second half: ``value * sigmoid(gate)``.

    Parameters
    ----------
    input_dim : int
        Input width.
    output_dim : int
        Output width (half of the projection width).
    fc : torch.nn.Module or None
        Projection to reuse (e.g. a shared layer); when falsy a bias-free
        ``Linear(input_dim, 2 * output_dim)`` is created.
    virtual_batch_size, momentum :
        Forwarded to ``GBN``.
    """

    def __init__(self, input_dim, output_dim, fc=None,
                 virtual_batch_size=128, momentum=0.02):
        super(GLU_Layer, self).__init__()

        self.output_dim = output_dim
        projection_dim = 2 * output_dim
        self.fc = fc if fc else Linear(input_dim, projection_dim, bias=False)
        # NOTE(review): initialize_glu is applied even when `fc` was passed in
        # (shared layer) — presumably re-initialising it; confirm intended.
        initialize_glu(self.fc, input_dim, projection_dim)

        self.bn = GBN(projection_dim, virtual_batch_size=virtual_batch_size,
                      momentum=momentum)

    def forward(self, x):
        """Return ``value * sigmoid(gate)`` of shape (batch, output_dim)."""
        projected = self.bn(self.fc(x))
        value = projected[:, :self.output_dim]
        gate = projected[:, self.output_dim:]
        return value * torch.sigmoid(gate)




In [None]:
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl

class TabDataset(Dataset):
    """Map-style dataset over in-memory numpy arrays.

    Features are cast to int64 (presumably discretised levels for the
    embedding layers — TODO confirm). Labels are cast to float32 with
    singleton dimensions squeezed away.

    Items are ``(row_of_x, label)`` pairs.
    """

    def __init__(self, x, y):
        super().__init__()
        self.x = torch.from_numpy(x).to(torch.int64)
        self.y = torch.from_numpy(y).to(torch.float32).squeeze()

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, idx):
        features = self.x[idx, :]
        label = self.y[idx]
        return features, label
# Shuffle the training rows every epoch (evaluation order does not matter).
# NOTE(review): on TPU, Lightning may swap in its own distributed sampler.
train_dataloader = DataLoader(TabDataset(x_dis_train, y_train_np), batch_size=32,
                              shuffle=True, num_workers=6)
test_dataloader = DataLoader(TabDataset(x_dis_test, y_test_np), batch_size=128,
                             num_workers=6)
class TrainingModuleV2(pl.LightningModule):
    """Lightning module: embeddings -> flatten -> TabNet -> sigmoid, trained with BCE.

    Parameters
    ----------
    x : array-like
        Feature matrix handed to ``EmbeddingFactory``. ``x.shape[1]`` times
        ``dim_emb`` is the TabNet input width.
    dim_emb : int
        Embedding size per feature column.
    dim_out : int
        TabNet output width. The BCE/accuracy code below assumes 1 (binary target).
    penalty : float, default 1e-3
        Weight of the auxiliary ``m_loss`` returned by the backbone. It is
        subtracted from the BCE loss in ``training_step``.
    **kwargs
        Forwarded to ``TabNet`` (e.g. ``n_steps``, ``n_independent``, ``n_shared``).
    """

    def __init__(self, x, dim_emb, dim_out, penalty=1e-3, **kwargs):
        super().__init__()
        self.penalty = penalty
        self.embedding = EmbeddingFactory(x, dim_emb)
        self.backbone = TabNet(x.shape[1] * dim_emb, dim_out, **kwargs)
        # Spelled `torch.nn` because the `nn` alias is only imported in a later cell.
        self.sigmoid = torch.nn.Sigmoid()
        self.loss = torch.nn.BCELoss()
        self.accuracy = Accuracy()

    def forward(self, x):
        """Embed ``x``, flatten to (batch, n * e) and run the backbone.

        Returns the backbone output unchanged; the steps below unpack it as
        ``(logits, m_loss)``.
        """
        x = self.embedding(x)
        # Same flatten as the train/val path; previously forward() skipped it.
        x = rearrange(x, 'b n e -> b (n e)')
        return self.backbone(x)

    def _predict(self, x):
        """Return ``(probabilities, m_loss)`` for a batch of raw features."""
        logits, m_loss = self(x)
        # Squeeze only the trailing output axis so a batch of size 1 keeps its batch dim.
        return self.sigmoid(logits.squeeze(-1)), m_loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        probs, _ = self._predict(x)
        loss = self.loss(probs, y.type(torch.float32))
        # BCE needs float targets; Accuracy gets integer targets (newer
        # Lightning/torchmetrics versions reject float targets).
        acc = self.accuracy(probs, y.long())
        self.log("Validation loss", loss)
        self.log("Validation acc", acc)
        return loss, acc

    def training_step(self, batch, batch_idx):
        x, y = batch
        probs, m_loss = self._predict(x)
        # Sign kept from the original: penalty * m_loss is subtracted.
        # NOTE(review): confirm against the backbone's m_loss sign convention.
        loss = self.loss(probs, y.type(torch.float32)) - self.penalty * m_loss
        acc = self.accuracy(probs, y.long())
        self.log("Training loss", loss)
        self.log("Training acc", acc)

        return loss

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 2e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=2e-3)
        return optimizer


# Discrete features -> 32-dim embeddings -> TabNet with a single output unit.
training_module = TrainingModuleV2(
    x_dis, dim_emb=32, dim_out=1,
    n_steps=2, n_independent=4, n_shared=4,
)
trainer = pl.Trainer(
    max_epochs=10,
    tpu_cores=8,
    progress_bar_refresh_rate=100,
    val_check_interval=0.5,  # run validation twice per training epoch
)
trainer.fit(training_module, train_dataloader, test_dataloader)

In [None]:
%tensorboard --logdir logs

### Embedding with Distance Information

In [None]:
import numpy as np
import torch
import torch.nn as nn
# Was `import torch.functional as F`, which is not the torch.nn.functional API.
import torch.nn.functional as F


class EntityEmbeddingLayer(nn.Module):
    """Soft embedding of a scalar feature weighted by closeness to centroids.

    Each input value is compared with every centroid. The inverse absolute
    distances are softmax-normalised into weights, and the output is the
    weighted sum of the rows of an ``nn.Embedding(num_level, emdedding_dim)``
    table.

    Parameters
    ----------
    num_level : int
        Number of embedding rows; expected to equal the number of centroids.
    emdedding_dim : int
        Size of each embedding vector (misspelled name kept for compatibility).
    centroid : array-like or tensor
        Centroid values, presumably 1-D of length ``num_level`` (TODO confirm).
        They are stored as a float32 buffer so ``.to(device)`` moves them with
        the module and they are saved in ``state_dict``.

    Notes
    -----
    ``forward`` relies on a module-level constant ``EPS`` to avoid division by zero.
    """

    def __init__(self, num_level, emdedding_dim, centroid):
        super(EntityEmbeddingLayer, self).__init__()
        self.embedding = nn.Embedding(num_level, emdedding_dim)
        # A registered buffer (unlike a plain tensor attribute) follows the
        # module across devices; as_tensor(...).clone() copies without warnings.
        self.register_buffer(
            'centroid',
            torch.as_tensor(centroid, dtype=torch.float32).detach().clone(),
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Map a 1-D batch ``x`` to embeddings of shape (batch, emdedding_dim)."""
        # (batch, 1) against the centroids broadcasts to (batch, num_level).
        # Cast to the centroid dtype so int or float64 inputs don't break the
        # float32 matmul below.
        x = x[:, None].to(self.centroid.dtype)
        d = 1.0 / ((x - self.centroid).abs() + EPS)
        w = self.softmax(d)
        return torch.mm(w, self.embedding.weight)

In [None]:
# 20 embedding rows of size 4, built from the first entry of `centroids`
# (assumes centroids[0] holds 20 centroid values — TODO confirm).
embedding = EntityEmbeddingLayer(num_level=20, emdedding_dim=4, centroid=centroids[0])