# RNN models
> Various RNNs and dynamical models.

In [None]:
#| default_exp recurrent.models

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from wafer.basics import *

## Utils

In [None]:
#| export
class Select(nn.Module):
    "Pick one element out of an indexable input (e.g. the output tuple of a recurrent layer)."
    def __init__(self, idx=0):
        super().__init__()
        # Key/position that is handed to the input's `__getitem__` in `forward`.
        self.idx = idx

    def forward(self, x):
        "Return `x[self.idx]`."
        picked = x[self.idx]
        return picked

## Basic

In [None]:
#| export
from wafer.init import default_init, lambda_init

In [None]:
#| export
class SimpleRNN(nn.Module):
    """A RNN with its output mapped through a dense layer.

    Input:  x, shape (N, L, D_in) or (N, D_in)
    Output: outp, shape (N, L, D_out) or (N, D_out)

    NOTE(review): a 2D input is forwarded to `nn.RNN` unchanged; PyTorch reads a 2D
    tensor as an unbatched sequence (L, D_in) rather than (N, D_in) -- confirm intent.

    Initialization (`init`):
        'uniform': input and recurrent weights ~ U(-1/sqrt(nh), 1/sqrt(nh)).
        'irnn':    recurrent weights = identity, input weights ~ N(0, 1/nh).
        'np-rnn':  recurrent weights = R^T R / nh (R standard normal) rescaled to a
                   largest eigenvalue of 1; input weights ~ N(0, 1/nh) times a fixed gain.
        anything else (including 'normal'): all weights ~ N(0, 1/nh).
    Biases of the recurrent layer are always zeroed.
    """
    def __init__(self, ni, nh, no, num_layers: int=2,
                 actn: str='tanh',   # nonlinearity, ['tanh', 'relu']
                 out_actn: Union[str, nn.Module]=None, # output activation
                 init: str='normal', # initialization method ['uniform', 'normal', 'irnn', 'np-rnn']
                ):
        super().__init__()
        self.recurrent = nn.RNN(ni, nh, num_layers=num_layers, batch_first=True, nonlinearity=actn)
        self.dense = nn.Linear(nh, no)
        self.out_actn = get_actn(out_actn)
        # Initialize
        self._init(self.recurrent, nh, init)
        default_init(self.dense)

    @staticmethod
    @torch.no_grad()
    def _init(m, nh, init):
        "Initialize the parameters of the recurrent module `m` in place according to `init`."
        bound = float(1 / np.sqrt(nh))
        for n,p in m.named_parameters():
            if 'bias' in n:
                nn.init.zeros_(p)
            elif 'weight_ih' in n:
                if init == 'uniform':
                    nn.init.uniform_(p, a=-bound, b=bound)
                else:
                    nn.init.normal_(p, std=bound)
                    if init == 'np-rnn':
                        # Extra gain applied to the input weights for the 'np-rnn' scheme.
                        p.mul_(float(np.sqrt(2) * np.exp(1.2 / (max(nh, 6) - 2.4))))
            elif 'weight_hh' in n:
                if init == 'uniform':
                    nn.init.uniform_(p, a=-bound, b=bound)
                elif init == 'irnn':
                    nn.init.eye_(p)
                elif init == 'np-rnn':
                    nn.init.normal_(p)
                    gram = (p.T @ p) / nh
                    # `gram` is symmetric, so the symmetric eigen-solver is used: it always
                    # returns real eigenvalues and needs no round-trip through NumPy.
                    top = torch.linalg.eigvalsh(gram).max()
                    p.copy_(gram / top)
                else:
                    nn.init.normal_(p, std=bound)

    def forward(self, x):
        "Run the recurrent stack, then map every hidden output through the dense layer and output activation."
        outp,_ = self.recurrent(x) # shape(N, L, nh)
        return self.out_actn(self.dense(outp))

In [None]:
#| export
class LSTM(nn.Module):
    """Custom LSTM, allowing for different activations. Assuming `batch_first=True` and `bidirectional=False`.

    Inputs: x, h0_c0;
        x: shape (N, L, D_in) or (N, D_in). A 2D input is treated as a single token per sample.
        h0_c0: (h0, c0), Union[list, tuple], optional, default zeros, each of shape (num_layers, N, hidden_size).

    Outputs: output, (hn, cn);
        output: shape (N, L, hidden_size) or (N, hidden_size), outputs of the last layer for each token.
        hn: shape (num_layers, N, hidden_size), final hidden state.
        cn: shape (num_layers, N, hidden_size), final cell state.

    Initialization: recurrent (hidden-*) weights are orthogonal with gain `recurrent_init_gain`,
    input-* weights are Xavier-uniform with gain `init_gain`, all biases are zero except the
    input-forget bias, which is one when `unit_forget_bias` is set.
    """
    def __init__(self, ni, nh, num_layers=1, actn: Union[str, nn.Module]='tanh', gate_actn: Union[str, nn.Module]='sigmoid',
                 dropout=0.0, unit_forget_bias=True, init_gain=1/np.sqrt(3), recurrent_init_gain=1.):
        super().__init__()
        self.nh,self.num_layers = nh,num_layers
        self.actn = get_actn(actn, nn.Tanh())
        self.gate_actn = get_actn(gate_actn, nn.Sigmoid())
        self.dropout = nn.Dropout(p=dropout)

        ws = []
        for i in range(num_layers):
            n_in = ni if i == 0 else nh  # only the first layer sees the raw input
            ws.append(nn.ModuleDict({
                'ii': nn.Linear(n_in, nh),  # input-input weight
                'if': nn.Linear(n_in, nh),  # input-forget weight
                'io': nn.Linear(n_in, nh),  # input-output weight
                'ic': nn.Linear(n_in, nh),  # input-cell weight
                'hi': nn.Linear(nh, nh),    # hidden-input weight
                'hf': nn.Linear(nh, nh),    # hidden-forget weight
                'ho': nn.Linear(nh, nh),    # hidden-output weight
                'hc': nn.Linear(nh, nh),    # hidden-cell weight
            }))
        self.ws = nn.ModuleList(ws)

        # Initialize
        for n,p in self.named_parameters():
            # hidden/recurrent weight
            if '.h' in n:
                if 'bias' in n: nn.init.zeros_(p)
                else: nn.init.orthogonal_(p, recurrent_init_gain)
            # non-recurrent weight
            else:
                if 'bias' in n:
                    if '.if' in n and unit_forget_bias: nn.init.ones_(p)
                    else: nn.init.zeros_(p)
                else: nn.init.xavier_uniform_(p, init_gain)

    def _forward_single(self, x, h0_c0: list=None):
        "Forward pass of a single token `x` of shape (N, D); returns the stacked per-layer states `(hs, cs)`."
        if h0_c0 is None:
            # `new_zeros` follows the dtype *and* device of `x`, so the default state also
            # works for models converted with `.double()` / `.half()`.
            h0 = x.new_zeros(self.num_layers, x.shape[0], self.nh)
            c0 = x.new_zeros(self.num_layers, x.shape[0], self.nh)
        else:
            h0,c0 = h0_c0
            assert (h0.shape[-1] == c0.shape[-1] == self.nh) and (h0.shape[1] == c0.shape[1] == x.shape[0])

        hs, cs = [], []
        for i in range(self.num_layers):
            w = self.ws[i]
            h, c = h0[i], c0[i]
            i_gate = self.gate_actn(w['ii'](x) + w['hi'](h))
            f_gate = self.gate_actn(w['if'](x) + w['hf'](h))
            o_gate = self.gate_actn(w['io'](x) + w['ho'](h))
            c_gate = self.actn(w['ic'](x) + w['hc'](h))

            c_new = (f_gate * c) + (i_gate * c_gate)
            h_new = o_gate * self.actn(c_new)
            cs.append(c_new)
            hs.append(h_new)
            # Dropout acts on what is fed to the next layer only; the stored state is not dropped.
            x = self.dropout(h_new)

        return (torch.stack(hs, 0), torch.stack(cs, 0))

    def forward(self, x, h0_c0=None):
        # Input `x`, shape (N, L, D) or (N, D); `h0_c0` = (h0, c0), each shape (num_layers, N, nh).
        assert x.ndim in [2,3], f"Expect 2D or 3D input, but got {x.ndim}D."
        reshape = x.ndim == 2
        if reshape: x = x.unsqueeze(1)

        # Only the running state is kept; per-step outputs are the last layer's hidden state.
        state, outputs = h0_c0, []
        for xi in torch.permute(x, [1,0,2]):
            state = self._forward_single(xi, state)
            outputs.append(state[0][-1])
        output = torch.stack(outputs, 1)
        hn, cn = state
        if reshape:
            return output[:,0,:], (hn, cn)
        return output, (hn, cn)

## Input convex

In [None]:
#| export
class FICNN(nn.Module):
    """Fully input-convex NN. Refer to [ICNN](https://arxiv.org/abs/1609.07152).

    Layer recursion (`g` = hidden activation, `W_z >= 0`, `W_z` has no bias):
        z_1     = g(W_y[0] x)
        z_{k+1} = g(W_z[k-1] z_k + W_y[k] x)
        output  = out_actn(W_z[-1] z_last + W_y[-1] x)
    With a single layer (no `W_z`), the output is `out_actn(W_y[0] x)`.
    """
    def __init__(self, ni: int,                        # input size
                 nh: Union[int, list[int]],            # hidden size
                 no: int,                              # output size
                 num_layer: int=2,                     # number of layers (include the output layer), if `nh` is a list, then the `num_layer = len(nh) + 1`
                 actn: Union[str, nn.Module]='relu',   # hidden activation
                 out_actn: Union[str, nn.Module]=None, # output activation
                 init_gain: float=1.                   # weight initialization gain
                ):
        super().__init__()
        # `list(nh)` lets any sequence of hidden sizes (e.g. a tuple) be concatenated with `[no]` below.
        nhs = [nh] * (num_layer - 1) if isinstance(nh, int) else list(nh)
        # One input "passthrough" layer per layer of the network (hidden layers + output layer).
        self.w_y = nn.ModuleList([nn.Linear(ni, n) for n in nhs + [no]])
        # Layer-to-layer weights; these are the ones kept nonnegative.
        self.w_z = nn.ModuleList([nn.Linear(i, j, bias=False) for i,j in zip(nhs, nhs[1:] + [no])])
        self.actn = get_actn(actn, nn.ReLU())
        self.out_actn = get_actn(out_actn)

        lambda_init(self, lambda w,b: (nn.init.xavier_normal_(w, init_gain), nn.init.zeros_(b)))
        self.weight_constraint()

    def weight_constraint(self):
        "Apply nonnegative weight constraint to the layer-to-layer weights `w_z` (in place)."
        with torch.no_grad():
            for n,p in self.w_z.named_parameters():
                if 'weight' in n:
                    p.clamp_min_(0)

    def forward(self, x):
        z = self.w_y[0](x)
        if len(self.w_z) == 0: return self.out_actn(z)
        # The first hidden layer needs its activation too; without it a 2-layer network
        # collapses to an affine map of `x`.
        z = self.actn(z)
        for wz,wy in zip(self.w_z[:-1], self.w_y[1:-1]):
            z = self.actn(wz(z) + wy(x))
        return self.out_actn(self.w_z[-1](z) + self.w_y[-1](x))

In [None]:
#| export
class ICRNN(nn.Module):
    """Input convex RNN. Refer to [Optimal Control Via Neural Networks: A Convex Approach](https://arxiv.org/abs/1805.11835).

    Inputs:  x, shape (N,L,D_in)
    Outputs: outputs, shape (N,L,D_out)

    Per step `t` (with `x_0 = 0`, `z_0 = 0`, and `x` replaced by `[x, -x]` when `expand_inp`):
        z_t = actn(U x_t + W z_{t-1} + D2 x_{t-1})
        y_t = out_actn(V z_t + D1 z_{t-1} + D3 x_t)
    """
    def __init__(self, ni: int,                        # input size
                 nh: int,                              # hidden size
                 no: int,                              # output size
                 actn: Union[str, nn.Module]='relu',   # hidden activation
                 out_actn: Union[str, nn.Module]=None, # output activation
                 expand_inp: bool=True,                # expand the input to [x, -x]
                 init_gain: float=1.                   # weight initialization gain
                ):
        super().__init__()
        ni = ni * 2 if expand_inp else ni
        self.expand_inp,self.nh = expand_inp,nh
        self.actn = get_actn(actn, nn.ReLU())
        self.out_actn = get_actn(out_actn)
        self.U = nn.Linear(ni, nh)   # current input  -> hidden
        self.V = nn.Linear(nh, no)   # current hidden -> output
        self.W = nn.Linear(nh, nh)   # previous hidden -> hidden
        self.D1 = nn.Linear(nh, no)  # previous hidden -> output
        self.D2 = nn.Linear(ni, nh)  # previous input  -> hidden
        self.D3 = nn.Linear(ni, no)  # current input   -> output

        lambda_init(self, lambda w,b: (nn.init.xavier_normal_(w, init_gain), nn.init.zeros_(b)))
        self.weight_constraint()

    def weight_constraint(self):
        "Apply nonnegative weight constraint to every parameter whose name contains 'weight' (in place)."
        with torch.no_grad():
            for n,p in self.named_parameters():
                if 'weight' in n:
                    p.clamp_min_(0)

    def forward(self, x):
        if self.expand_inp: x = torch.cat([x, -x], dim=-1)
        # `new_zeros` matches the dtype and device of `x`, so the zero initial input/state
        # also work for models that are not float32.
        x_prev = x.new_zeros(x.shape[0], x.shape[-1])
        z_prev = x.new_zeros(x.shape[0], self.nh)
        ys = []
        for x_curr in x.permute(1,0,2):
            z_curr = self.actn(self.U(x_curr) + self.W(z_prev) + self.D2(x_prev))
            ys.append(self.out_actn(self.V(z_curr) + self.D1(z_prev) + self.D3(x_curr)))
            x_prev, z_prev = x_curr, z_curr
        return torch.stack(ys, 1)

In [None]:
#| export
class ICLSTMCell(nn.Module):
    """
    Modified LSTM for ICLSTM. From [ICLSTM](https://arxiv.org/abs/2311.07202).
    Assuming `batch_first=True` and `bidirectional=False`.

    Each layer has one shared input weight `wi` and one shared hidden weight `wh` (both without
    bias); every gate rescales the two shared projections element-wise with its own scaling
    vectors and adds its own bias.

    Inputs: x, h0_c0;
        x: shape (N, L, D_in) or (N, D_in). A 2D input is treated as a single token per sample.
        h0_c0: (h0, c0), Union[list, tuple], optional, default zeros, each of shape (num_layers, N, hidden_size).

    Outputs: output, (hn, cn);
        output: shape (N, L, hidden_size) or (N, hidden_size), outputs of the last layer for each token.
        hn: shape (num_layers, N, hidden_size), final hidden state.
        cn: shape (num_layers, N, hidden_size), final cell state.

    Initialization: `wi` Xavier-uniform (gain `init_gain`), `wh` orthogonal (gain
    `recurrent_init_gain`), scaling vectors ~ U(0, 1), biases zero except the forget-gate bias,
    which is one when `unit_forget_bias` is set.
    """
    def __init__(self, ni, nh, num_layers=1, actn: Union[str, nn.Module]='tanh', gate_actn: Union[str, nn.Module]='sigmoid',
                 dropout=0.0, unit_forget_bias=True, init_gain=1/np.sqrt(3), recurrent_init_gain=1.):
        super().__init__()
        self.nh,self.num_layers = nh,num_layers
        self.actn = get_actn(actn, nn.Tanh())
        self.gate_actn = get_actn(gate_actn, nn.Sigmoid())
        self.dropout = nn.Dropout(p=dropout)

        ws = []
        ps = []
        for i in range(num_layers):
            ws.append(nn.ModuleDict({
                'wi': nn.Linear(ni if i == 0 else nh, nh, bias=False),  # input base weight
                'wh': nn.Linear(nh, nh, bias=False),                    # hidden base weight
            }))
            ps.append(nn.ParameterDict({
                'bi': nn.Parameter(torch.empty(nh)),                    # input gate bias
                'bf': nn.Parameter(torch.empty(nh)),                    # forget gate bias
                'bo': nn.Parameter(torch.empty(nh)),                    # output gate bias
                'bc': nn.Parameter(torch.empty(nh)),                    # cell gate bias
                'sii': nn.Parameter(torch.empty(nh)),                   # input-input scaling
                'sif': nn.Parameter(torch.empty(nh)),                   # input-forget scaling
                'sio': nn.Parameter(torch.empty(nh)),                   # input-output scaling
                'sic': nn.Parameter(torch.empty(nh)),                   # input-cell scaling
                'shi': nn.Parameter(torch.empty(nh)),                   # hidden-input scaling
                'shf': nn.Parameter(torch.empty(nh)),                   # hidden-forget scaling
                'sho': nn.Parameter(torch.empty(nh)),                   # hidden-output scaling
                'shc': nn.Parameter(torch.empty(nh))                    # hidden-cell scaling
            }))
        self.ws = nn.ModuleList(ws)
        self.ps = nn.ParameterList(ps)

        # Initialize
        for n,p in self.named_parameters():
            # input base weight
            if '.wi' in n: nn.init.xavier_uniform_(p, init_gain)
            # hidden base weight
            elif '.wh' in n: nn.init.orthogonal_(p, recurrent_init_gain)
            else:
                # bias
                if '.b' in n:
                    # `unit_forget_bias` targets the *forget* gate bias ('bf'), not the input gate bias.
                    if '.bf' in n and unit_forget_bias: nn.init.ones_(p)
                    else: nn.init.zeros_(p)
                # scaling
                else: nn.init.uniform_(p)

    def _forward_single(self, x, h0_c0: list=None):
        "Forward pass of a single token `x` of shape (N, D); returns the stacked per-layer states `(hs, cs)`."
        if h0_c0 is None:
            # `new_zeros` follows the dtype *and* device of `x`, so the default state also
            # works for models converted with `.double()` / `.half()`.
            h0 = x.new_zeros(self.num_layers, x.shape[0], self.nh)
            c0 = x.new_zeros(self.num_layers, x.shape[0], self.nh)
        else:
            h0,c0 = h0_c0
            assert (h0.shape[-1] == c0.shape[-1] == self.nh) and (h0.shape[1] == c0.shape[1] == x.shape[0])

        hs, cs = [], []
        for i in range(self.num_layers):
            h, c = h0[i], c0[i]
            w, s = self.ws[i], self.ps[i]
            # The two base projections are shared by all four gates: compute them once.
            xw, hw = w['wi'](x), w['wh'](h)
            i_gate = self.gate_actn(xw * s['sii'] + hw * s['shi'] + s['bi'])
            f_gate = self.gate_actn(xw * s['sif'] + hw * s['shf'] + s['bf'])
            o_gate = self.gate_actn(xw * s['sio'] + hw * s['sho'] + s['bo'])
            c_gate = self.actn(xw * s['sic'] + hw * s['shc'] + s['bc'])

            c_new = (f_gate * c) + (i_gate * c_gate)
            h_new = o_gate * self.actn(c_new)
            cs.append(c_new)
            hs.append(h_new)
            # Dropout acts on what is fed to the next layer only; the stored state is not dropped.
            x = self.dropout(h_new)

        return (torch.stack(hs, 0), torch.stack(cs, 0))

    def forward(self, x, h0_c0=None):
        # Input `x`, shape (N, L, D) or (N, D); `h0_c0` = (h0, c0), each shape (num_layers, N, nh).
        assert x.ndim in [2,3], f"Expect 2D or 3D input, but got {x.ndim}D."
        reshape = x.ndim == 2
        if reshape: x = x.unsqueeze(1)

        # Only the running state is kept; per-step outputs are the last layer's hidden state.
        state, outputs = h0_c0, []
        for xi in torch.permute(x, [1,0,2]):
            state = self._forward_single(xi, state)
            outputs.append(state[0][-1])
        output = torch.stack(outputs, 1)
        hn, cn = state
        if reshape:
            return output[:,0,:], (hn, cn)
        return output, (hn, cn)

In [None]:
#| export
class ICLSTM(nn.Module):
    """Input convex LSTM. From [ICLSTM](https://arxiv.org/abs/2311.07202).
    
    A L-ICLSTM is a stack of [ICLSTMCell, Linear, ReLU] repeated L times. The output is mapped through another Linear layer.
    """
    def __init__(self, ni, nh, no, num_layers=2,
                 expand_inp=True, do_wi_nonneg=True, actn='relu', gate_actn='relu', out_actn=None, **kwargs):
        super().__init__()
        ni = ni * 2 if expand_inp else ni
        self.expand_inp = expand_inp
        self.actn = get_actn(actn, nn.ReLU())
        self.out_actn = get_actn(out_actn)
        self.blocks = nn.ModuleList([self._mk_block(ni, nh, actn, gate_actn, **kwargs) for _ in range(num_layers)])
        self.dense_out = nn.Linear(ni, no)
        default_init(self.dense_out)
        # Nonnegative weights
        self.do_wi_nonneg = do_wi_nonneg
        self.weight_constraint()

    def _mk_block(self, ni, nh, actn, gate_actn, **kwargs):
        m = nn.Sequential(ICLSTMCell(ni, nh, 1, actn=actn, gate_actn=gate_actn, **kwargs),
                          Select(),
                          nn.Linear(nh, ni))
        default_init(m[2], normal=False)
        return m
        
    def weight_constraint(self):
        "Apply nonnegative weight constriant."
        with torch.no_grad():
            for n,p in self.named_parameters():
                if '.wi' in n and not self.do_wi_nonneg: continue
                if '.bias' in n: continue
                p.clamp_min_(0)

    def forward(self, x):
        if self.expand_inp: x = torch.cat([x, -x], dim=-1)
        x0 = x.clone()
        for b in self.blocks:
            x = self.actn(b(x)) + x0
        return self.out_actn(self.dense_out(x))

In [None]:
#| export
class WeightConstraintCB(Callback):
    "Callback that calls `weight_constraint()` on the learner's model in `after_step`."
    order = 10

    def after_step(self):
        "Re-apply the model's own weight constraint."
        model = self.learner.model
        model.weight_constraint()

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()