## References

1) https://github.com/labmlai/annotated_deep_learning_paper_implementations

2) https://github.com/labmlai/labml

## Prepare libraries

In [5]:
import torch
from torch import nn
import torch.nn

from typing import Any,Tuple

## Prepare requirements

In [6]:
class Module(torch.nn.Module):
    r"""
    Wraps ``torch.nn.Module`` to overload ``__call__`` instead of
    ``forward`` for better type checking.

    A subclass defines ``__call__``. When the subclass is created, that method
    is moved to ``forward`` and removed from the subclass. Calling an instance
    therefore still goes through ``torch.nn.Module.__call__``, which
    dispatches to ``forward``.

    `PyTorch Github issue for clarification <https://github.com/pytorch/pytorch/issues/44605>`_
    """

    def _forward_unimplemented(self, *input: Any) -> None:
        # To stop PyTorch from giving abstract methods warning
        pass

    def __init_subclass__(cls, **kwargs):
        # Chain to the next class in the MRO first. Otherwise the
        # __init_subclass__ hooks of cooperative mixins never run, and
        # unexpected class keywords are silently swallowed.
        super().__init_subclass__(**kwargs)

        # Only look at a __call__ defined directly on this class, not an
        # inherited one.
        call = cls.__dict__.get('__call__', None)
        if call is None:
            return

        # Re-home the user's __call__ as forward(). Deleting it lets
        # instance calls fall back to torch.nn.Module.__call__.
        setattr(cls, 'forward', call)
        delattr(cls, '__call__')

    @property
    def device(self):
        """Return the device of the module's first parameter.

        Raises:
            RuntimeError: if the module has no parameters.
        """
        try:
            sample_param = next(self.parameters())
        except StopIteration:
            raise RuntimeError(f"Unable to determine"
                               f" device of {self.__class__.__name__}") from None
        return sample_param.device

## Define the PonderNet module

In [7]:
class ParityPonderGRU(nn.Module):
    """PonderNet-style recurrent model built on a single GRU cell.

    The same input ``x`` is fed to the GRU cell at every step. At each step
    the hidden state is mapped to a prediction ``y_n`` and a halting
    probability ``lambda_n``. A halting step is also sampled per sample.

    Args:
        n_elems: number of features in each input vector (GRU input size).
        n_hidden: size of the GRU hidden state.
        max_steps: maximum number of pondering steps. At the last step the
            halting probability is forced to 1, so every sample halts.
    """

    def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
        # nn.Module.__init__ must run before any submodule is assigned.
        super().__init__()

        self.max_steps = max_steps
        self.n_hidden = n_hidden

        # Recurrent cell that repeatedly updates the hidden state from x.
        self.gru = nn.GRUCell(n_elems, n_hidden)

        # Hidden state -> one raw (un-activated) prediction per sample.
        self.output_layer = nn.Linear(n_hidden, 1)

        # Hidden state -> halting probability lambda_n in (0, 1).
        self.lambda_layer = nn.Linear(n_hidden, 1)
        self.lambda_prob = nn.Sigmoid()

        # When True, forward() stops early once every sample has halted.
        self.is_halt = False

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run up to ``max_steps`` pondering steps.

        Args:
            x: input batch of shape ``(batch_size, n_elems)``, as required by
                the GRU cell.

        Returns:
            p: halting probabilities, shape ``(steps, batch_size)``.
            y: per-step predictions, shape ``(steps, batch_size)``.
            p_m: probability at each sample's sampled halting step, shape
                ``(batch_size,)``.
            y_m: prediction at each sample's sampled halting step, shape
                ``(batch_size,)``.

            ``steps`` equals ``max_steps`` unless ``is_halt`` is set and all
            samples halt earlier.
        """
        batch_size = x.shape[0]

        # The first hidden state comes from one GRU update on a zero state.
        h = x.new_zeros((batch_size, self.n_hidden))
        h = self.gru(x, h)

        p = []
        y = []

        # Probability of not having halted before the current step.
        un_halted_prob = h.new_ones((batch_size,))

        # 1 for samples whose halting step has already been sampled.
        halted = h.new_zeros((batch_size,))

        p_m = h.new_zeros((batch_size,))
        y_m = h.new_zeros((batch_size,))

        for n in range(1, self.max_steps + 1):
            if n == self.max_steps:
                # Force halting at the last step so the p_n sum to 1.
                lambda_n = h.new_ones(batch_size)
            else:
                lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]

            y_n = self.output_layer(h)[:, 0]

            # p_n = lambda_n * prod_{j<n} (1 - lambda_j)
            p_n = un_halted_prob * lambda_n
            un_halted_prob = un_halted_prob * (1 - lambda_n)

            # Sample halting; samples that already halted cannot halt again.
            halt = torch.bernoulli(lambda_n) * (1 - halted)

            p.append(p_n)
            y.append(y_n)

            # Latch p_n / y_n for samples that halt at this step.
            p_m = p_m * (1 - halt) + p_n * halt
            y_m = y_m * (1 - halt) + y_n * halt

            halted = halted + halt

            h = self.gru(x, h)

            if self.is_halt and halted.sum() == batch_size:
                break

        return torch.stack(p), torch.stack(y), p_m, y_m

## Define the losses

In [8]:
class ReconstructionLoss(nn.Module):
    """Reconstruction loss weighted by the per-step halting probabilities.

    Computes ``sum_n mean(p[n] * loss_func(y_hat[n], y))`` over the step
    dimension (dim 0) of ``p``.

    Args:
        loss_func: loss applied to each step's predictions. NOTE(review): it
            presumably returns unreduced per-sample losses (e.g.
            ``reduction='none'``) so they can be weighted by ``p[n]``;
            confirm at the call site.
    """

    def __init__(self, loss_func: nn.Module):
        super().__init__()
        self.loss_func = loss_func

    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
        weighted_terms = (
            (p[step] * self.loss_func(y_hat[step], y)).mean()
            for step in range(p.shape[0])
        )
        # Start from a zero tensor that shares p's dtype/device, so an empty
        # p still yields a tensor.
        return sum(weighted_terms, p.new_tensor(0.))

In [9]:
class RegularizationLoss(nn.Module):
    """KL-divergence regularizer toward a geometric halting prior.

    The prior is ``p_g[k] = lambda_p * (1 - lambda_p) ** k`` for
    ``k = 0 .. max_steps - 1``.

    Args:
        lambda_p: per-step halting probability of the geometric prior.
        max_steps: number of prior entries precomputed. ``forward`` slices
            this table, so inputs with more than ``max_steps`` steps will
            fail in ``expand_as``.
    """

    def __init__(self, lambda_p: float, max_steps: int = 1_000):
        super().__init__()

        p_g = torch.zeros((max_steps,))

        # Probability of not having halted before step k (Python float).
        not_halted = 1.

        for k in range(max_steps):
            p_g[k] = not_halted * lambda_p
            not_halted = not_halted * (1 - lambda_p)

        # Frozen parameter: it moves with the module on .to()/.cuda(), but
        # it has requires_grad=False so its grad is never populated.
        self.p_g = nn.Parameter(p_g, requires_grad=False)

        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, p: torch.Tensor):
        """Compute KL(p_g || p), averaged over the batch.

        Args:
            p: halting probabilities with steps on dim 0 and batch on dim 1,
                i.e. shape ``(steps, batch)``. Entries must be positive
                because ``p.log()`` is taken.

        Returns:
            Scalar loss tensor.
        """
        # -> (batch, steps), matching the layout expected by KLDivLoss's batchmean.
        p = p.transpose(0, 1)

        # Prior for the first `steps` steps, broadcast across the batch.
        p_g = self.p_g[None, :p.shape[1]].expand_as(p)

        # KLDivLoss(input, target) takes log-probabilities as input and
        # computes target * (log target - input), i.e. KL(p_g || p).
        return self.kl_div(p.log(), p_g)