<a href="https://colab.research.google.com/github/sugangnb/ai-research/blob/main/tabnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
# Fail fast with an actionable message when the notebook is not attached to a TPU runtime.
# `.get` is used so that a missing variable triggers the assertion message below
# instead of a bare KeyError that hides the hint.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [None]:
# Remove any previous checkout first so that re-running this cell does not
# fail because the target directory already exists, then clone the dataset repo.
!rm -rf kicked_dataset/
!git clone https://github.com/calibertytz/kicked_dataset.git

Cloning into 'kicked_dataset'...
remote: Enumerating objects: 3, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 25 (delta 0), reused 0 (delta 0), pack-reused 22[K
Unpacking objects: 100% (25/25), done.


In [None]:
# Colab form field: which PyTorch/XLA build the setup script should install.
VERSION = "nightly"  #@param ["1.5" , "20200325", "nightly"]
# Download the PyTorch/XLA environment setup script and run it for the selected version.
# NOTE(review): the script is fetched from the `master` branch and "nightly" is unpinned,
# so results may differ between runs — consider pinning both.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5116  100  5116    0     0  23906      0 --:--:-- --:--:-- --:--:-- 23906
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Uninstalling torch-1.8.0a0:
  Successfully uninstalled torch-1.8.0a0
Uninstalling torchvision-0.9.0a0+d0063f3:
  Successfully uninstalled torchvision-0.9.0a0+d0063f3
Copying gs://tpu-pytorch/wheels/torch-nightly-cp36-cp36m-linux_x86_64.whl...

Operation completed over 1 objects/122.1 MiB.                                    
Copying gs://tpu-pytorch/wheels/torch_xla-nightly-cp36-cp36m-linux_x86_64.whl...
- [1 files][131.0 MiB/131.0 MiB]                                                
Operation completed over 1 objects/131.0 MiB.                                    
Copying gs://tpu-pytorch/wheels/torchvision-nightly-cp36-cp36m-linux_x86_64.whl...
/ [1 files][  4.8 MiB/

In [None]:
import pandas as pd
import numpy as np

In [None]:
# Read the train / test tables that ship with the cloned dataset repository.
df_train = pd.read_csv('kicked_dataset/df_train_final.csv')
df_test = pd.read_csv('kicked_dataset/df_test_final.csv')

# Split each table into a feature matrix (every column except 'label')
# and the matching label vector, both as NumPy arrays.
X, y = df_train.drop(columns=['label']).values, df_train['label'].values
X_test, y_test = df_test.drop(columns=['label']).values, df_test['label'].values

# Attention

In [None]:
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F
import torch
import torch_xla.core.xla_model as xm

class EntmaxBisectFunction(Function):
    """Autograd function computing alpha-entmax along one dimension via bisection.

    The forward pass searches for the threshold ``tau`` such that
    ``p = clamp(X * (alpha - 1) - tau, 0) ** (1 / (alpha - 1))`` sums to one
    along ``dim``. The backward pass returns gradients with respect to ``X``
    and, when requested by autograd, with respect to ``alpha``.
    """

    @classmethod
    def _gp(cls, x, alpha):
        # x ** (alpha - 1); used to bracket the threshold.
        return x ** (alpha - 1)

    @classmethod
    def _gp_inv(cls, y, alpha):
        # Inverse of `_gp`: y ** (1 / (alpha - 1)).
        return y ** (1 / (alpha - 1))

    @classmethod
    def _p(cls, X, alpha):
        # Un-normalised probabilities for already shifted scores `X`.
        return cls._gp_inv(torch.clamp(X, min=0), alpha)

    @classmethod
    def forward(cls, ctx, X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True):
        """Compute alpha-entmax of ``X`` along ``dim``.

        Parameters
        ----------
        X : torch.Tensor
            Input scores.
        alpha : float or torch.Tensor
            Entmax exponent (> 1); must be expandable to the shape of ``X``
            with ``dim`` reduced to size 1.
        dim : int
            Dimension to normalise over.
        n_iter : int
            Number of bisection iterations.
        ensure_sum_one : bool
            Whether to divide the result by its sum along ``dim``.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``X``.
        """
        device = X.device
        if isinstance(alpha, torch.Tensor):
            # Move / cast an existing tensor instead of re-wrapping it with
            # torch.tensor(), which copies the data and emits a warning.
            alpha = alpha.to(device=device, dtype=X.dtype)
        else:
            alpha = torch.tensor(alpha, dtype=X.dtype, device=device)

        alpha_shape = list(X.shape)
        alpha_shape[dim] = 1
        alpha = alpha.expand(*alpha_shape)
        ctx.alpha = alpha
        ctx.dim = dim
        d = X.shape[dim]
        X = X * (alpha - 1)

        max_val, _ = X.max(dim=dim, keepdim=True)

        # Bracket for the threshold: tau_lo yields sum >= 1, tau_hi yields sum <= 1.
        tau_lo = max_val - cls._gp(1, alpha)
        tau_hi = max_val - cls._gp(1 / d, alpha)
        # Initialise p_m from the lower bracket so it is defined even if n_iter == 0.
        p_m = cls._p(X - tau_lo, alpha)
        f_lo = p_m.sum(dim) - 1
        dm = tau_hi - tau_lo
        for it in range(n_iter):

            dm /= 2
            tau_m = tau_lo + dm
            p_m = cls._p(X - tau_m, alpha)
            f_m = p_m.sum(dim) - 1

            # Keep the half-interval whose lower end has the same sign as f_lo.
            mask = (f_m * f_lo >= 0).unsqueeze(dim)
            tau_lo = torch.where(mask, tau_m, tau_lo)

        if ensure_sum_one:
            p_m /= p_m.sum(dim=dim).unsqueeze(dim=dim)

        # Return the local result unchanged. Gathering across replicas here would
        # concatenate along `dim`, so the output would no longer match the shape of X
        # (nor the gradient shape expected in backward).
        ctx.save_for_backward(p_m)
        return p_m

    @classmethod
    def backward(cls, ctx, dY):
        """Return gradients for (X, alpha, dim, n_iter, ensure_sum_one)."""
        Y, = ctx.saved_tensors

        gppr = torch.where(Y > 0, Y ** (2 - ctx.alpha), Y.new_zeros(1))

        dX = dY * gppr
        q = dX.sum(ctx.dim) / gppr.sum(ctx.dim)
        q = q.unsqueeze(ctx.dim)
        dX -= q * gppr

        d_alpha = None
        if ctx.needs_input_grad[1]:

            # alpha gradient computation
            # d_alpha = (partial_y / partial_alpha) * dY
            # NOTE: ensure alpha is not close to 1
            # since there is an indetermination

            # shannon terms
            S = torch.where(Y > 0, - Y * torch.log(Y), Y.new_zeros(1))

            # shannon entropy
            ent = S.sum(ctx.dim).unsqueeze(ctx.dim)
            Y_skewed = gppr / gppr.sum(ctx.dim).unsqueeze(ctx.dim)

            d_alpha = dY * (Y - Y_skewed) / ((ctx.alpha - 1) ** 2)
            d_alpha += dY * (S - Y_skewed * ent) / (ctx.alpha - 1)
            d_alpha = d_alpha.sum(ctx.dim).unsqueeze(ctx.dim)

        return dX, d_alpha, None, None, None


# slightly more efficient special case for sparsemax
class SparsemaxBisectFunction(EntmaxBisectFunction):
    """Sparsemax (alpha = 2) special case of the bisection-based entmax function.

    For alpha = 2 the power maps used by the parent class are identities, so
    they are overridden with cheaper implementations.
    """

    @classmethod
    def _gp(cls, x, alpha):
        return x

    @classmethod
    def _gp_inv(cls, y, alpha):
        return y

    @classmethod
    def _p(cls, x, alpha):
        return torch.clamp(x, min=0)

    @classmethod
    def forward(cls, ctx, X, dim=-1, n_iter=50, ensure_sum_one=True):
        """Compute sparsemax of ``X`` along ``dim``.

        ``n_iter`` and ``ensure_sum_one`` are forwarded to the parent
        implementation. ``n_iter=None`` falls back to 50 iterations so that
        callers passing ``None`` keep working.
        """
        if n_iter is None:
            n_iter = 50
        return super().forward(
            ctx, X, alpha=2, dim=dim, n_iter=n_iter, ensure_sum_one=ensure_sum_one
        )

    @classmethod
    def backward(cls, ctx, dY):
        """Return gradients for (X, dim, n_iter, ensure_sum_one)."""
        Y, = ctx.saved_tensors
        # Indicator of the support of Y.
        gppr = (Y > 0).to(dtype=dY.dtype)
        dX = dY * gppr
        q = dX.sum(ctx.dim) / gppr.sum(ctx.dim)
        q = q.unsqueeze(ctx.dim)
        dX -= q * gppr
        return dX, None, None, None


def entmax_bisect(X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True):
    """Apply alpha-entmax, a sparse alternative to softmax, along one dimension.

    The mapping solves

        max_p <x, p> - H_a(p)    s.t.    p >= 0, sum(p) == 1,

    where H_a is the Tsallis alpha-entropy. The solution is located with a
    bisection (binary search) root finder. All arguments are handed to
    ``EntmaxBisectFunction.apply`` positionally.

    Parameters
    ----------
    X : torch.Tensor
        Input scores.
    alpha : float or torch.Tensor
        Entmax exponent(s), expected to be > 1. A scalar or python float is
        shared by all rows; a tensor must have, or be expandable to, shape
        alpha.shape[j] == (X.shape[j] if j != dim else 1).
        alpha=2 corresponds to sparsemax; alpha=1 corresponds to softmax, but
        evaluating it through this routine is likely unstable.
    dim : int
        Dimension over which to normalise.
    n_iter : int
        Number of bisection iterations. For float32, 24 iterations should
        suffice for machine precision.
    ensure_sum_one : bool
        Whether to divide the result by its sum. If false, the result may sum
        to a value close to, but not exactly, 1.

    Returns
    -------
    P : torch.Tensor
        The projection of X, intended to satisfy P.sum(dim=dim) == 1.
    """
    call_args = (X, alpha, dim, n_iter, ensure_sum_one)
    return EntmaxBisectFunction.apply(*call_args)


def sparsemax_bisect(X, dim=-1, n_iter=50, ensure_sum_one=True):
    """Apply sparsemax, a sparse alternative to softmax, using bisection.

    The mapping solves the projection

        min_p ||x - p||_2   s.t.    p >= 0, sum(p) == 1.

    All arguments are handed to ``SparsemaxBisectFunction.apply`` positionally.

    Parameters
    ----------
    X : torch.Tensor
        Input scores.
    dim : int
        Dimension over which to normalise.
    n_iter : int
        Number of bisection iterations. For float32, 24 iterations should
        suffice for machine precision.
    ensure_sum_one : bool
        Whether to divide the result by its sum. If false, the result may sum
        to a value close to, but not exactly, 1.

    Note: this function is documented as not yet supporting normalisation along
    anything except the last dimension; use transposes and views to achieve
    more general behaviour.

    Returns
    -------
    P : torch.Tensor
        The projection of X, intended to satisfy P.sum(dim=dim) == 1.
    """
    call_args = (X, dim, n_iter, ensure_sum_one)
    return SparsemaxBisectFunction.apply(*call_args)


class SparsemaxBisect(nn.Module):
    """Module wrapper around :func:`sparsemax_bisect`."""

    def __init__(self, dim=-1, n_iter=None):
        """Store the sparsemax configuration.

        Sparsemax solves the projection
            min_p ||x - p||_2   s.t.    p >= 0, sum(p) == 1.

        Parameters
        ----------
        dim : int
            Dimension over which to normalise.
        n_iter : int
            Number of bisection iterations. For float32, 24 iterations should
            suffice for machine precision.
        """
        super().__init__()
        self.dim, self.n_iter = dim, n_iter

    def forward(self, X):
        # Delegate to the functional form with the stored settings.
        settings = {"dim": self.dim, "n_iter": self.n_iter}
        return sparsemax_bisect(X, **settings)


class EntmaxBisect(nn.Module):
    """Module wrapper around :func:`entmax_bisect`."""

    def __init__(self, alpha=1.8, dim=-1, n_iter=50):
        """alpha-entmax: normalizing sparse map (a la softmax) via bisection.
        Solves the optimization problem:
            max_p <x, p> - H_a(p)    s.t.    p >= 0, sum(p) == 1.
        where H_a(p) is the Tsallis alpha-entropy with custom alpha >= 1,
        using a bisection (root finding, binary search) algorithm.
        Parameters
        ----------
        alpha : float or torch.Tensor
            Tensor of alpha parameters (> 1) to use. If scalar
            or python float, the same value is used for all rows, otherwise,
            it must have shape (or be expandable to)
            alpha.shape[j] == (X.shape[j] if j != dim else 1)
            A value of alpha=2 corresponds to sparsemax; alpha=1 corresponds
            to softmax (but computing it this way is likely unstable).
            An ``nn.Parameter`` is registered as a parameter of this module.
        dim : int
            The dimension along which to apply alpha-entmax.
        n_iter : int
            Number of bisection iterations. For float32, 24 iterations should
            suffice for machine precision.
        """
        # Initialise the Module machinery first: assigning an nn.Parameter
        # before Module.__init__() has run raises an AttributeError.
        super().__init__()
        self.dim = dim
        self.n_iter = n_iter
        self.alpha = alpha

    def forward(self, X):
        return entmax_bisect(
            X, alpha=self.alpha, dim=self.dim, n_iter=self.n_iter
        )

In [None]:
import math


class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention' with an alpha-entmax normaliser
    in place of softmax.
    """

    def __init__(self, alpha=1.5):
        """
        :param alpha: initial value of the entmax exponent
        """
        super(Attention, self).__init__()
        # Register alpha as a real parameter. The previous
        # `alpha * nn.Parameter(...)` produced an unregistered non-leaf tensor:
        # it was not moved by `.to(device)`, was never seen by the optimizer, and
        # its one-off autograd graph had to be retained across backward passes.
        # NOTE(review): alpha is now trainable; entmax is documented as unstable
        # as alpha approaches 1, so consider clamping if it drifts.
        self.alpha = nn.Parameter(torch.full([1], float(alpha)))
        self.activation = EntmaxBisect(dim=-1)
        # Share the same Parameter object with the activation module. It is
        # assigned after construction so that it is registered there as well.
        self.activation.alpha = self.alpha

    def forward(self, query, key, value, mask=None, dropout=None):
        """
        :param query, key, value: tensors whose last two dimensions are used
            for the matmul / transpose below
        :param mask: optional tensor; positions where ``mask == 0`` get a
            score of -1e9
        :param dropout: optional callable applied to the attention weights
        :return: tuple (weighted values, attention weights)
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) \
                 / math.sqrt(query.size(-1))

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        p_attn = self.activation(scores)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value), p_attn

class MultiHeadedAttention(nn.Module):
    """
    Multi-head attention built from a model size and a number of heads.
    """

    def __init__(self, h, d_model, dropout=0.1):
        """
        :param h: number of attention heads; must divide d_model
        :param d_model: model (hidden) size
        :param dropout: dropout probability applied to the attention weights
        """
        super().__init__()
        assert d_model % h == 0

        # d_v is assumed to always equal d_k.
        self.d_k = d_model // h
        self.h = h

        # Three input projections (query, key, value) followed by the output projection.
        projections = [nn.Linear(d_model, d_model) for _ in range(3)]
        self.linear_layers = nn.ModuleList(projections)
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        def split_heads(layer, tensor):
            # Project, then reshape d_model => h x d_k and put the head axis second.
            projected = layer(tensor)
            return projected.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

        # 1) Linear projections for query, key and value, in that order.
        query, key, value = (
            split_heads(layer, tensor)
            for layer, tensor in zip(self.linear_layers, (query, key, value))
        )

        # 2) Attention over all heads at once; the weights are not needed here.
        context, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) Merge the heads back ("concat" via a view) and apply the final linear layer.
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(batch_size, -1, self.h * self.d_k)

        return self.output_linear(merged)


# utils

In [None]:
import torch.nn as nn
import torch

class PositionwiseFeedForward(nn.Module):
    "Two-layer feed-forward network (FFN equation) with GELU and dropout."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        # expand -> GELU -> dropout -> project back
        hidden = self.w_1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.w_2(hidden)


class LayerNorm(nn.Module):
    "Layer normalisation over the last dimension with a learnable gain and bias."

    def __init__(self, features, eps=1e-12):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias
        self.eps = eps

    def forward(self, x):
        # Centre on the last dimension, then scale by (std + eps).
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2

class SublayerConnection(nn.Module):
    """
    Residual wrapper around a sub-layer, combined with layer normalisation.
    For simplicity the normalisation is applied to the sub-layer input
    (norm first) rather than to its output.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Return x plus the dropped-out output of `sublayer` applied to the normalised x."
        normalised = self.norm(x)
        branch = self.dropout(sublayer(normalised))
        return x + branch

# transformer

In [None]:
class TransformerBlock(nn.Module):
    """
    Bidirectional Encoder = Transformer (self-attention)
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        """
        :param hidden: hidden size of transformer
        :param attn_heads: head sizes of multi-head attention
        :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size
        :param dropout: dropout rate
        """

        super().__init__()
        # Pass the configured dropout rate through; previously the attention
        # module silently fell back to its own default regardless of `dropout`.
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """
        :param x: input tensor fed to self-attention as query, key and value
        :param mask: attention mask forwarded to the attention module (may be None)
        :return: tensor produced by the attention and feed-forward sub-layers, after dropout
        """
        # Call the module itself rather than `.forward` so registered hooks still run.
        x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)

# model

In [None]:
import torch.nn.functional as F
import torch_xla.core.xla_model as xm


class FeedForwardBlock(nn.Module):
    '''
    Residual feed-forward block using a ReZero-style gate: the branch output is
    scaled by a learnable scalar `resweight` that starts at 0, so the block is
    the identity at initialisation.

    :param in_dim: input feature size
    :param out_dim: output feature size of the branch; it is added to the
        input, so it must be broadcast-compatible with `in_dim`
    :param dim_feedforward: hidden size of the branch
    :param dropout: dropout probability used inside and after the branch
    :param activation: 'relu' or 'gelu'
    :raises ValueError: if `activation` is not one of the supported names
    '''

    def __init__(self, in_dim, out_dim, dim_feedforward=2048, dropout=0.1, activation='relu'):
        super(FeedForwardBlock, self).__init__()
        # Validate up front: previously an unknown name left `self.activation`
        # unset and only failed later, inside forward(), with an AttributeError.
        activations = {"relu": F.relu, "gelu": F.gelu}
        if activation not in activations:
            raise ValueError(
                "activation should be one of {}, got {!r}".format(sorted(activations), activation)
            )

        self.linear1 = nn.Linear(in_dim, dim_feedforward)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, out_dim)
        self.resweight = nn.Parameter(torch.Tensor([0]))
        self.activation = activations[activation]

    def forward(self, src):
        # Branch: linear -> activation -> dropout -> linear, gated by resweight.
        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
        src2 = src2 * self.resweight
        return src + self.dropout2(src2)


class Encoder(nn.Module):
    """Stack of transformer blocks that alternates unshared and shared attention.

    Even-numbered layers (0, 2, ...) use their own TransformerBlock followed by
    their own FeedForwardBlock. Odd-numbered layers reuse a single shared
    TransformerBlock, followed by a per-layer FeedForwardBlock.

    :param input_dim: size of the input features
    :param output_dim: size of the output produced by the final linear layer
    :param hidden_dim: hidden size used throughout the stack
    :param n_layers: total number of layers (even and odd together)
    :param attn_heads: number of attention heads per transformer block
    :param dropout: dropout rate passed to the transformer blocks
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 hidden_dim,
                 n_layers,
                 attn_heads,
                 dropout=0.1):

        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.attn_heads = attn_heads
        self.dropout = dropout

        # Odd layers use the shared block; even layers use unshared blocks. With an
        # odd n_layers there is one more even layer than odd layers, so the unshared
        # lists need the ceiling (the old floor caused an IndexError in forward()).
        n_shared = self.n_layers // 2
        n_unshared = (self.n_layers + 1) // 2

        self.embedding = nn.Linear(input_dim, hidden_dim)  # TODO: add attentive embedding

        self.shared_base_block = TransformerBlock(self.hidden_dim, self.attn_heads, self.hidden_dim * 4, self.dropout)

        self.shared_layer_tail = nn.ModuleList(
            [FeedForwardBlock(in_dim=self.hidden_dim, out_dim=self.hidden_dim) for _ in range(n_shared)])

        self.no_shared_layer_head = nn.ModuleList(
            TransformerBlock(self.hidden_dim, self.attn_heads, self.hidden_dim * 4, self.dropout)
            for _ in range(n_unshared))
        self.no_shared_layer_tail = nn.ModuleList(
            FeedForwardBlock(in_dim=self.hidden_dim, out_dim=self.hidden_dim)
            for _ in range(n_unshared))

        self.final_layer = nn.Sequential(nn.Sigmoid(),
                                         nn.Linear(self.hidden_dim, self.output_dim))

    def forward(self, x):
        """Embed `x`, run it through all layers and return the final layer's output."""
        x = self.embedding(x)
        for i in range(self.n_layers):
            idx = i // 2
            if i % 2 != 0:
                x = self.shared_base_block(x, mask=None)
                x = self.shared_layer_tail[idx](x)
            else:
                x = self.no_shared_layer_head[idx](x, mask=None)
                x = self.no_shared_layer_tail[idx](x)
        # Evaluate the head exactly once. The former debug prints ran it two extra
        # times per call and dumped the full output tensor on every batch.
        return self.final_layer(x)

# RAdam + Lookahead optimizers

In [None]:
import math
import torch
from torch.optim.optimizer import Optimizer, required
from collections import defaultdict
from itertools import chain
import warnings

class Lookahead(Optimizer):
    """Lookahead wrapper around another optimizer.

    The wrapped ("fast") optimizer takes every step. Each param group keeps a
    counter; whenever it is 0 at step time the "slow" copy of each parameter is
    moved a fraction `alpha` towards the fast value and written back into the
    parameter. The counter wraps to 0 after `k` steps.

    :param optimizer: the inner optimizer to wrap
    :param k: number of fast steps between slow-weight updates
    :param alpha: interpolation factor for the slow weights
    """

    def __init__(self, optimizer, k=5, alpha=0.5):
        # NOTE: Optimizer.__init__ is intentionally not called; this wrapper shares
        # the inner optimizer's param_groups and keeps its own slow-weight state.
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        self.param_groups = self.optimizer.param_groups
        # Expose the inner defaults: inherited Optimizer methods read `self.defaults`,
        # which would otherwise not exist because the base __init__ is skipped.
        self.defaults = self.optimizer.defaults
        self.state = defaultdict(dict)
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group["counter"] = 0

    def update(self, group):
        """Interpolate slow weights towards the fast weights and copy them back."""
        for fast in group["params"]:
            param_state = self.state[fast]
            if "slow_param" not in param_state:
                # First visit: the slow weights start as a copy of the fast weights.
                param_state["slow_param"] = torch.zeros_like(fast.data)
                param_state["slow_param"].copy_(fast.data)
            slow = param_state["slow_param"]
            slow += (fast.data - slow) * self.alpha
            fast.data.copy_(slow)

    def update_lookahead(self):
        """Force a slow-weight update for every param group."""
        for group in self.param_groups:
            self.update(group)

    def zero_grad(self, *args, **kwargs):
        """Clear gradients through the wrapped optimizer.

        The param groups are shared with the inner optimizer, so delegating is
        equivalent and does not depend on base-class state that this wrapper
        never initialises.
        """
        return self.optimizer.zero_grad(*args, **kwargs)

    def step(self, closure=None):
        """Take one inner-optimizer step, then update slow weights when due.

        :return: whatever the inner optimizer's step returns
        """
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            if group["counter"] == 0:
                self.update(group)
            group["counter"] += 1
            if group["counter"] >= self.k:
                group["counter"] = 0
        return loss

    def state_dict(self):
        """Return the inner (fast) state, the slow state and the param groups."""
        fast_state_dict = self.optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "fast_state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict):
        """Restore slow and fast state from a dict produced by `state_dict`."""
        slow_state_dict = {
            "state": state_dict["slow_state"],
            "param_groups": state_dict["param_groups"],
        }
        fast_state_dict = {
            "state": state_dict["fast_state"],
            "param_groups": state_dict["param_groups"],
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.optimizer.load_state_dict(fast_state_dict)
        self.fast_state = self.optimizer.state

    def add_param_group(self, param_group):
        """Add a param group to the inner optimizer with a fresh counter."""
        param_group["counter"] = 0
        self.optimizer.add_param_group(param_group)


class RAdam(Optimizer):
    """Rectified Adam optimizer.

    :param params: iterable of parameters or list of param-group dicts
    :param lr: learning rate
    :param betas: coefficients for the running averages of the gradient and its square
    :param eps: term added to the denominator
    :param weight_decay: weight decay factor (applied scaled by `lr`)
    :param degenerated_to_sgd: if True, use a momentum-SGD style update while the
        rectification term is not yet usable (N_sma < 5); otherwise skip the update
    :raises ValueError: if lr, eps or betas are outside their valid range
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        # Groups that override betas get their own step-size cache, because the
        # cached values depend on the betas.
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimisation step.

        :param closure: optional callable that re-evaluates the model and returns the loss
        :return: the closure's loss, or None when no closure is given
        :raises RuntimeError: if a gradient is sparse
        """

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Keyword `value=` / `alpha=` forms replace the deprecated
                # positional-scalar overloads of addcmul_ / add_ / addcdiv_.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                # Cache (step, N_sma, step_size) so parameters at the same step reuse it.
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        step_size = -1
                    buffered[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)

        return loss

# train

In [None]:

import time
from torch.utils.data import Dataset

class TorchDataset(Dataset):
    """
    Dataset view over NumPy arrays.

    Parameters
    ----------
    X : 2D array
        The input matrix
    y : 1D array
        Target
    """

    def __init__(self, x, y):
        # Insert a singleton axis at position 1: each sample becomes (1, n_features).
        self.x = np.expand_dims(x, 1)
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        # Return the (features, target) pair stored at `index`.
        return self.x[index], self.y[index]

Single-process version (one TPU core)

In [None]:
import torch_xla.distributed.parallel_loader as pl

def run(params):
    """Train the Encoder on a single XLA device and periodically report test accuracy.

    Reads the module-level arrays X, y, X_test and y_test.

    :param params: dict with keys "batch_size", "input_dim", "output_dim",
        "hidden_dim", "n_layers", "attn_heads" and "num_epochs";
        "learning_rate" is optional and defaults to 1e-3
    """
    device = xm.xla_device()
    batch_size = params["batch_size"]
    input_dim = params["input_dim"]
    output_dim = params["output_dim"]
    hidden_dim = params["hidden_dim"]
    n_layers = params["n_layers"]
    attn_heads = params["attn_heads"]
    num_epochs = params["num_epochs"]
    learning_rate = params.get("learning_rate", 1e-3)

    # Build each dataset once and share it between its sampler and its loader.
    train_dataset = TorchDataset(X.astype(np.float32), y)
    test_dataset = TorchDataset(X_test.astype(np.float32), y_test)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=torch.utils.data.RandomSampler(train_dataset))

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        sampler=torch.utils.data.RandomSampler(test_dataset))

    model = Encoder(input_dim, output_dim, hidden_dim, n_layers, attn_heads).to(device).train()

    # Loss and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))  # Any optimizer
    lookahead = Lookahead(optimizer, k=5, alpha=0.5)  # Initialize Lookahead
    total_step = len(train_loader)

    train_start = time.time()
    for epoch in range(num_epochs):
        # Iterate the plain loader: this single-device path has no ParallelLoader
        # (the previous `para_train_loader` name was undefined here).
        for i, (inputs, labels) in enumerate(train_loader):
            # Move tensors to the configured device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(inputs)
            outputs = outputs.squeeze(1)
            loss = loss_fn(outputs, labels)

            # Backward and optimize. Gradients are cleared through the inner
            # optimizer, which shares its param groups with the Lookahead wrapper.
            optimizer.zero_grad()
            # NOTE(review): retain_graph=True is kept for compatibility with modules
            # that reuse a graph built at construction time; drop it once none do.
            loss.backward(retain_graph=True)
            lookahead.step()
            # Without a ParallelLoader nothing cuts the lazily traced XLA graph,
            # so mark the step boundary explicitly.
            xm.mark_step()

            if (i + 1) % 5 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
                # Evaluate with dropout disabled, then switch back to training mode.
                model.eval()
                with torch.no_grad():
                    correct = 0
                    total = 0
                    for test_inputs, test_labels in test_loader:
                        test_inputs = test_inputs.to(device)
                        test_labels = test_labels.to(device)
                        test_outputs = model(test_inputs).squeeze(1)
                        _, predicted = torch.max(test_outputs, 1)
                        total += test_labels.size(0)
                        correct += (predicted == test_labels).sum().item()

                    print('Accuracy: {} %'.format(100 * correct / total))
                model.train()

    elapsed_train_time = time.time() - train_start
    # The single-process path has no process index to report.
    print("Finished training. Train time was:", elapsed_train_time)


In [None]:
# Hyper-parameters for the single-process training run.
params = dict(
    batch_size=2048 * 2,
    num_workers=8,
    input_dim=93,
    hidden_dim=256,
    output_dim=2,
    attn_heads=1,
    dim_feedforward=2048,
    n_layers=12,
    learning_rate=1e-3,
    num_epochs=2,
)
run(params)



self.final_layer(x)


Multi-process version (8 TPU cores via `xmp.spawn`; currently commented out)

In [None]:
# import torch_xla.distributed.parallel_loader as pl

# def map_fn(index, params):
#     device = xm.xla_device()
#     batch_size = params["batch_size"]
#     input_dim = params["input_dim"]
#     output_dim = params["output_dim"]
#     hidden_dim = params["hidden_dim"]
#     n_layers = params["n_layers"]
#     attn_heads = params["attn_heads"]
#     num_epochs = params["num_epochs"]



#     train_sampler = torch.utils.data.distributed.DistributedSampler(
#     TorchDataset(X.astype(np.float32),y),
#     num_replicas=xm.xrt_world_size(),
#     rank=xm.get_ordinal(),
#     shuffle=True)
    
#     test_sampler = torch.utils.data.distributed.DistributedSampler(
#     TorchDataset(X_test.astype(np.float32),y_test),
#     num_replicas=xm.xrt_world_size(),
#     rank=xm.get_ordinal(),
#     shuffle=False)
  
#     train_loader = torch.utils.data.DataLoader(
#       TorchDataset(X.astype(np.float32),y),
#       batch_size=params['batch_size'],
#       sampler=train_sampler,
#       num_workers=params['num_workers'],
#       drop_last=True)

#     test_loader = torch.utils.data.DataLoader(
#       TorchDataset(X_test.astype(np.float32),y_test),
#       batch_size=params['batch_size'],
#       sampler=test_sampler,
#       num_workers=params['num_workers'],
#       drop_last=True)


#     model = Encoder(input_dim, output_dim, hidden_dim, n_layers, attn_heads).to(device).train()

#     # Loss and optimizer
#     loss_fn = nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))  # Any optimizer
#     lookahead = Lookahead(optimizer, k=5, alpha=0.5)  # Initialize Lookahead
#     total_step = len(train_loader)

#     train_start = time.time()
#     for epoch in range(num_epochs):
#         para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
#         for i, (inputs, labels) in enumerate(para_train_loader):
#             # Move tensors to the configured device
#             inputs = inputs.to(device)
#             labels = labels.to(device)

#             # Forward pass
#             outputs = model(inputs)
#             print("outputs is ",outputs)
#             print(outputs.shape)
#             outputs = outputs.squeeze(1)
#             loss = loss_fn(outputs, labels)

#             # Backward and optimize
#             lookahead.zero_grad()
#             loss.backward(retain_graph=True)  # Self-defined loss function
#             lookahead.step()

#             '''
#             optimizer.zero_grad()
#             loss.backward(retain_graph=True) 
#             optimizer.step()
#             '''
#             if (i + 1) % 5 == 0:
#                 print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
#                       .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
#                 with torch.no_grad():
#                     correct = 0
#                     total = 0
#                     for inputs, labels in test_loader:
#                         inputs = inputs.to(device)
#                         labels = labels.to(device)
#                         outputs = model(inputs)
#                         outputs = outputs.squeeze(1)
#                         _, predicted = torch.max(outputs.data, 1)
#                         total += labels.size(0)
#                         correct += (predicted == labels).sum().item()

#                     print('Accuracy: {} %'.format(100 * correct / total))
    
#     elapsed_train_time = time.time() - train_start
#     print("Process", index, "finished training. Train time was:", elapsed_train_time)


In [None]:

# import torch_xla.distributed.xla_multiprocessing as xmp

# params = {}
# params['batch_size'] = 2048 * 2
# params['num_workers'] = 8
# params['input_dim'] = 93
# params['hidden_dim'] = 256
# params['output_dim'] = 2
# params['attn_heads'] = 1
# params['dim_feedforward'] = 2048
# params['n_layers'] = 12
# params['learning_rate'] = 1e-3
# params['num_epochs'] = 2

# xmp.spawn(map_fn, args=(params,), nprocs=8, start_method='fork')