## Reference

https://github.com/labmlai/annotated_deep_learning_paper_implementations

## Load libraries

In [1]:
!pip install labml

Collecting labml
  Downloading labml-0.4.132-py3-none-any.whl (121 kB)
[?25l[K     |██▊                             | 10 kB 26.2 MB/s eta 0:00:01[K     |█████▍                          | 20 kB 22.8 MB/s eta 0:00:01[K     |████████                        | 30 kB 17.8 MB/s eta 0:00:01[K     |██████████▉                     | 40 kB 14.9 MB/s eta 0:00:01[K     |█████████████▌                  | 51 kB 5.4 MB/s eta 0:00:01[K     |████████████████▏               | 61 kB 5.9 MB/s eta 0:00:01[K     |███████████████████             | 71 kB 5.3 MB/s eta 0:00:01[K     |█████████████████████▋          | 81 kB 6.0 MB/s eta 0:00:01[K     |████████████████████████▎       | 92 kB 5.9 MB/s eta 0:00:01[K     |███████████████████████████     | 102 kB 5.2 MB/s eta 0:00:01[K     |█████████████████████████████▊  | 112 kB 5.2 MB/s eta 0:00:01[K     |████████████████████████████████| 121 kB 5.2 MB/s 
Collecting gitpython
  Downloading GitPython-3.1.18-py3-none-any.whl (170 kB)
[K     

In [2]:
!pip install labml_helpers

Collecting labml_helpers
  Downloading labml_helpers-0.4.81-py3-none-any.whl (18 kB)
Installing collected packages: labml-helpers
Successfully installed labml-helpers-0.4.81


In [3]:
!pip install labml_nn

Collecting labml_nn
  Downloading labml_nn-0.4.109-py3-none-any.whl (274 kB)
[?25l[K     |█▏                              | 10 kB 27.1 MB/s eta 0:00:01[K     |██▍                             | 20 kB 31.5 MB/s eta 0:00:01[K     |███▋                            | 30 kB 19.7 MB/s eta 0:00:01[K     |████▊                           | 40 kB 12.2 MB/s eta 0:00:01[K     |██████                          | 51 kB 5.9 MB/s eta 0:00:01[K     |███████▏                        | 61 kB 6.4 MB/s eta 0:00:01[K     |████████▍                       | 71 kB 5.9 MB/s eta 0:00:01[K     |█████████▌                      | 81 kB 6.6 MB/s eta 0:00:01[K     |██████████▊                     | 92 kB 5.0 MB/s eta 0:00:01[K     |████████████                    | 102 kB 5.4 MB/s eta 0:00:01[K     |█████████████▏                  | 112 kB 5.4 MB/s eta 0:00:01[K     |██████████████▎                 | 122 kB 5.4 MB/s eta 0:00:01[K     |███████████████▌                | 133 kB 5.4 MB/s eta 0:00:0

In [4]:
from typing import Any,Tuple

from abc import ABC

import dataclasses

import torch
from torch import nn
from torch.utils.data import DataLoader,Dataset

from labml import tracker, experiment
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, ReconstructionLoss, RegularizationLoss

## Prepare metrics

In [5]:
@dataclasses.dataclass
class AccuracyState:
    """
    Running counters used by the accuracy metrics.

    * `samples` - how many items have been scored so far
    * `correct` - how many of those items were predicted correctly
    """
    samples: int = 0
    correct: int = 0

    def reset(self):
        """Set both counters back to zero (mutates this object in place)."""
        self.correct = 0
        self.samples = 0

In [6]:
class StateModule:
    """
    Base interface for objects that keep per-epoch state.

    Every hook below raises `NotImplementedError` here and must be overridden
    by a subclass. Instances are presumably listed in a training loop's
    `state_modules` (see `Configs.init`), which drives these hooks; confirm
    against the labml_helpers trainer.
    """

    def __init__(self):
        pass

    def create_state(self) -> Any:
        """Create and return a fresh state object for this module."""
        raise NotImplementedError

    def set_state(self, data: Any):
        """Attach a state object (as produced by `create_state`) to this module."""
        raise NotImplementedError

    def on_epoch_start(self):
        """Hook intended to run at the start of an epoch."""
        raise NotImplementedError

    def on_epoch_end(self):
        """Hook intended to run at the end of an epoch."""
        raise NotImplementedError

In [7]:
class Metric(StateModule, ABC):
    """
    Abstract base for metrics that keep per-epoch state.

    `track` does nothing here; concrete metrics override it to report a value.
    """

    def track(self):
        """Report the current metric value (a no-op in this base class)."""
        return None

In [8]:
class Accuracy(Metric):
    """
    Classification accuracy accumulated over an epoch.

    `__call__` treats the last dimension of `output` as class scores, takes the
    arg-max as the prediction and compares it with `target`. Positions whose
    target equals `ignore_index` are excluded from both the correct count and
    the sample count. The counters live in `self.data`, which is only assigned
    by `set_state`, so a state must be attached before the metric is called.
    """
    # State object attached by `set_state`
    data: AccuracyState

    def __init__(self, ignore_index: int = -1):
        """
        * `ignore_index` is the target value marking positions to skip
        """
        super().__init__()
        self.ignore_index = ignore_index

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """
        Add one batch to the running counts.

        * `output` - class scores in the last dimension; leading dimensions are flattened
        * `target` - class indices, one per flattened leading position of `output`
        """
        # `reshape` (unlike `view`) also accepts non-contiguous tensors
        output = output.reshape(-1, output.shape[-1])
        target = target.reshape(-1)
        pred = output.argmax(dim=-1)
        # Score only the positions whose target is not the ignore marker
        keep = target != self.ignore_index
        self.data.correct += (pred[keep] == target[keep]).sum().item()
        self.data.samples += keep.sum().item()

    def create_state(self) -> AccuracyState:
        """Return fresh, zeroed counters."""
        return AccuracyState()

    def set_state(self, data: Any):
        """Attach the counters that subsequent calls will update."""
        self.data = data

    def on_epoch_start(self):
        """Zero the counters at the beginning of an epoch."""
        self.data.reset()

    def on_epoch_end(self):
        """Report the epoch's accuracy."""
        self.track()

    def track(self):
        """Send `correct / samples` to the labml tracker; skipped when nothing was scored."""
        if self.data.samples == 0:
            return
        tracker.add("accuracy.", self.data.correct / self.data.samples)

In [9]:
class AccuracyDirect(Accuracy):
    """
    Accuracy for outputs that already are predictions (no arg-max step).

    Each element of `output` is compared directly with the matching element of
    `target`; `ignore_index` is not applied. PyTorch type promotion lets a
    boolean prediction tensor be compared against integer 0/1 targets.
    """
    # State object attached by `set_state`
    data: AccuracyState

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """
        Add one batch of predictions to the running counts.

        * `output` - predictions, flattened to 1-D
        * `target` - expected values, flattened to 1-D

        Raises `ValueError` if the two tensors have different element counts
        (otherwise a single-element tensor would silently broadcast).
        """
        # `reshape` (unlike `view`) also accepts non-contiguous tensors
        output = output.reshape(-1)
        target = target.reshape(-1)
        if output.numel() != target.numel():
            raise ValueError(
                f"output has {output.numel()} elements but target has {target.numel()}")
        self.data.correct += output.eq(target).sum().item()
        self.data.samples += len(target)

## Parity task dataset

**Parity task**

Each input vector has 64 elements. A random number of them (between 1 and 64) are set to 1 or −1 at random, and the rest are set to 0.
The target is 1 if the vector contains an odd number of 1s and 0 if it contains an even number.
(To keep training fast, this notebook uses 8-element vectors by default; see `n_elems` in `Configs`.)

In [10]:
class ParityDataset(Dataset):
    """
    ### Parity dataset

    Synthetic samples are generated on the fly: every `__getitem__` call draws a
    fresh random vector (the index is ignored), so reading the same index twice
    generally gives different samples.
    """

    def __init__(self, n_samples: int, n_elems: int = 64):
        """
        * `n_samples` is the number of samples reported as the dataset length
        * `n_elems` is the length of each input vector
        """
        self.n_elems = n_elems
        self.n_samples = n_samples

    def __len__(self):
        """
        Number of samples in one pass over the dataset
        """
        return self.n_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draw one random sample `(vector, parity)`.
        """
        # How many positions are non-zero: uniform in [1, n_elems]
        count = int(torch.randint(1, self.n_elems + 1, (1,)))
        # Random signs: 0/1 mapped to -1/+1
        signs = torch.randint(0, 2, (count,)) * 2 - 1

        vec = torch.zeros((self.n_elems,))
        vec[:count] = signs
        # Scatter the non-zero entries to random positions
        shuffled = vec[torch.randperm(self.n_elems)]

        # Parity of the number of +1 entries
        parity = (shuffled == 1.).sum() % 2
        return shuffled, parity

## Define the configuration

In [11]:
class Configs(SimpleTrainValidConfigs):
    """
    Configuration and per-batch step for training PonderNet on the parity task,
    built on labml's
     [simple training loop](https://docs.labml.ai/api/helpers.html#labml_helpers.train_valid.SimpleTrainValidConfigs)

    `init` builds the model, both losses and the data loaders; `step` is called by
    the trainer for each batch and, in training mode, also updates the weights.
    """

    # Number of epochs
    epochs: int = 5
    # Number of batches per epoch (the training set holds `batch_size * n_batches` samples)
    n_batches: int = 500
    # Batch size
    batch_size: int = 128

    # Model
    model: ParityPonderGRU

    # $L_{Rec}$
    loss_rec: ReconstructionLoss
    # $L_{Reg}$
    loss_reg: RegularizationLoss

    # The number of elements in the input vector.
    # *We keep it low for demonstration; otherwise, training takes a lot of time.
    # Although the parity task seems simple, figuring out the pattern by looking at samples
    # is quite hard.*
    n_elems: int = 8
    # Number of units in the hidden layer (state)
    n_hidden: int = 64
    # Maximum number of steps $N$
    max_steps: int = 20

    # $\lambda_p$ for the geometric distribution $p_G(\lambda_p)$
    lambda_p: float = 0.2
    # Regularization loss $L_{Reg}$ coefficient $\beta$
    beta: float = 0.01

    # Gradient clipping by norm
    grad_norm_clip: float = 1.0

    # Training and validation loaders
    train_loader: DataLoader
    valid_loader: DataLoader

    # Accuracy calculator.
    # NOTE(review): this is a class attribute, so every `Configs` instance shares
    # one metric object - confirm that is intended.
    accuracy = AccuracyDirect()

    def init(self):
        """
        Build the model, losses and data loaders, and register the metrics.
        """
        # Print indicators to screen
        tracker.set_scalar('loss.*', True)
        tracker.set_scalar('loss_reg.*', True)
        tracker.set_scalar('accuracy.*', True)
        tracker.set_scalar('steps.*', True)

        # We need to set the metrics to calculate them for the epoch for training and validation
        self.state_modules = [self.accuracy]

        # Initialize the model
        self.model = ParityPonderGRU(self.n_elems, self.n_hidden, self.max_steps).to(self.device)
        # $L_{Rec}$ - per-element BCE on logits (`reduction='none'`); presumably
        # ReconstructionLoss does its own weighting/reduction - confirm in labml_nn
        self.loss_rec = ReconstructionLoss(nn.BCEWithLogitsLoss(reduction='none')).to(self.device)
        # $L_{Reg}$
        self.loss_reg = RegularizationLoss(self.lambda_p, self.max_steps).to(self.device)

        # Training and validation loaders. Samples are generated randomly on the fly,
        # so each epoch sees new data; validation is a fixed 32 batches per epoch.
        self.train_loader = DataLoader(ParityDataset(self.batch_size * self.n_batches, self.n_elems),
                                       batch_size=self.batch_size)
        self.valid_loader = DataLoader(ParityDataset(self.batch_size * 32, self.n_elems),
                                       batch_size=self.batch_size)

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        This method gets called by the trainer for each batch.

        Computes and logs the losses, expected step count and accuracy; in training
        mode it also backpropagates, clips gradients and takes an optimizer step.
        """
        # Set the model mode
        self.model.train(self.mode.is_train)

        # Get the input and labels and move them to the model's device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Increment step in training mode (the global step counts samples, not batches)
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Run the model (`p_sampled` is not used below)
        p, y_hat, p_sampled, y_hat_sampled = self.model(data)

        # Calculate the reconstruction loss (BCE needs float targets)
        loss_rec = self.loss_rec(p, y_hat, target.to(torch.float))
        tracker.add("loss.", loss_rec)

        # Calculate the regularization loss
        loss_reg = self.loss_reg(p)
        tracker.add("loss_reg.", loss_reg)

        # $L = L_{Rec} + \beta L_{Reg}$
        loss = loss_rec + self.beta * loss_reg

        # Calculate the expected number of steps taken: sum over n of n * p[n-1].
        # `p` is indexed step-first here; assumes p is (max_steps, batch) - TODO confirm
        # against ParityPonderGRU. The result is one expected value per sample.
        steps = torch.arange(1, p.shape[0] + 1, device=p.device)
        expected_steps = (p * steps[:, None]).sum(dim=0)
        tracker.add("steps.", expected_steps)

        # Call accuracy metric. `y_hat_sampled > 0` thresholds at 0, giving boolean
        # predictions - presumably these are logits, as `loss_rec` uses BCEWithLogitsLoss.
        self.accuracy(y_hat_sampled > 0, target)

        if self.mode.is_train:
            # Compute gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Optimizer
            self.optimizer.step()
            # Clear gradients so the next batch starts from zero
            self.optimizer.zero_grad()
            # Flush the values added to the tracker for this step
            tracker.save()

In [12]:
# Seed PyTorch's global RNG so that model initialisation and the on-the-fly
# parity samples repeat from run to run (GPU kernels may still add some
# nondeterminism).
SEED = 0
torch.manual_seed(SEED)

# Create the labml experiment
experiment.create(name='ponder_net')

conf = Configs()
# Override the optimizer defaults: Adam with learning rate 3e-4
experiment.configs(conf, {
    'optimizer.optimizer': 'Adam',
    'optimizer.learning_rate': 0.0003,
})

# Start the experiment and run the training/validation loop
with experiment.start():
    conf.run()