# Skip RNN

In [20]:
import torch
import torch.nn as nn
from torch.nn.init import xavier_normal_
from torch.autograd import Function

    
class STEFunction(Function):
    @staticmethod
    def forward(cls,x):
        return x.round()

    @staticmethod
    def backward(cls,grad):
        return grad


class STELayer(nn.Module):
    def __init__(self):
        super(STELayer, self).__init__()

    def forward(self, x):
        binarizer = STEFunction.apply
        return binarizer(x)


class SkipGRUCell(nn.Module):
    def __init__(self, ic, hc):
        super(SkipGRUCell, self).__init__()
        self.ste = STELayer()
        self.cell = nn.GRUCell(ic, hc)
        self.linear = nn.Linear(hc, 1)

        xavier_normal_(self.linear.weight)
        self.linear.bias.data.fill_(1)

    def forward(self, x, u, h, skip=False, delta_u=None):
        # x: (bs, ic)
        # u: (bs, 1)
        # h: (bs, hc)
        # skip: [False or True] * bs
        # delta_u: [skip=True -> (1) / skip=False -> None] * bs

        bs = x.shape[0]
        binarized_u = self.ste(u)                # (bs, 1)

        skip_idx = [i for i, cur_skip in enumerate(skip) if cur_skip]
        skip_num = len(skip_idx)
        no_skip = [not cur_skip for cur_skip in skip]

        if skip_num > 0:
            # (skip_num, ic), (skip_num, 1), (skip_num, hc)
            x_s, u_s, h_s = x[skip], u[skip], h[skip]
            binarized_u_s = binarized_u[skip]        # (skip_num, 1)

            # (skip_num, 1)
            delta_u_s = [cur_delta_u for cur_skip,
                         cur_delta_u in zip(skip, delta_u) if cur_skip]
            delta_u_s = torch.stack(delta_u_s)

            # computing skipped parts
            new_h_s = h_s * (1 - binarized_u_s)        # (skip_num, hc)
            new_u_s = torch.clamp(u_s + delta_u_s, 0, 1) * \
                (1 - binarized_u_s)  # (skip_num, 1)

        if skip_num < bs:
            # (bs-skip_num, ic), (bs-skip_num, 1), (bs-skip_num, hc)
            x_n, u_n, h_n = x[no_skip], u[no_skip], h[no_skip]
            binarized_u_n = binarized_u[no_skip]  # (bs-skip_num, 1)

            # computing non-skipped parts
            new_h_n = self.cell(x_n, h_n)  # (bs-skip_num, hc)
            new_h_n = new_h_n * binarized_u_n            # (bs-skip_num, hc)
            delta_u_n = torch.sigmoid(self.linear(new_h_n))        # (bs-skip_num, 1)
            new_u_n = delta_u_n * binarized_u_n                    # (bs-skip_num, 1)

        # merging skipped and non-skipped parts back
        if 0 < skip_num < bs:
            idx = torch.full((bs,), -1, dtype=torch.long)
            idx[skip_idx] = torch.arange(0, len(skip_idx), dtype=torch.long)
            idx[idx==-1] = torch.arange(len(skip_idx), bs, dtype=torch.long)

            new_u = torch.cat([new_u_s, new_u_n], 0)[idx]        # (bs, 1)
            new_h = torch.cat([new_h_s, new_h_n], 0)[idx]        # (bs, hc)
            delta_u = torch.cat([delta_u_s, delta_u_n], 0)[idx]    # (bs, 1)

        # no need to merge when skip doesn't exist
        elif skip_num == 0:
            new_u = new_u_n
            new_h = new_h_n
            delta_u = delta_u_n

        # no need to merge when everything is skip
        elif skip_num == bs:
            new_u = new_u_s
            new_h = new_h_s
            delta_u = delta_u_s

        n_skips_after = (0.5 / new_u).ceil() - 1  # (bs, 1)
        return binarized_u, new_u, (new_h,), delta_u, n_skips_after


class SkipGRUCellNoSkip(nn.Module):
    def __init__(self, ic, hc):
        super(SkipGRUCellNoSkip, self).__init__()
        self.ste = STELayer()
        self.cell = nn.GRUCell(ic, hc)
        self.linear = nn.Linear(hc, 1)
        
        xavier_normal_(self.linear.weight)
        self.linear.bias.data.fill_(1)

    def forward(self, x, u, h):
        # x: (bs, ic)
        # u: (bs, 1)
        # h: (bs, hc)

        # computing the states
        binarized_u = self.ste(u)                # (bs, 1)
        new_h = self.cell(x, h)  # (bs, hc)
        new_h = new_h * binarized_u + (1 - binarized_u) * h      # (bs, hc)
        delta_u = torch.sigmoid(self.linear(new_h))        # (bs, 1)
        new_u = delta_u * binarized_u + \
            torch.clamp(u + delta_u, 0, 1) * (1 - binarized_u)  # (bs, 1)

        return binarized_u, new_u, new_h, delta_u

class SkipGRU(nn.Module):
    def __init__(self, ic, hc, layer_num=2, return_total_u=False, learn_init=False, no_skip=False,batch_first = False):
        super(SkipGRU, self).__init__()
        self.ic = ic
        self.hc = hc
        self.layer_num = layer_num
        self.return_total_u = return_total_u
        self.no_skip = no_skip
        self.batch_first = batch_first

        if no_skip:
            cur_cell = SkipGRUCellNoSkip
        else:
            cur_cell = SkipGRUCell

        self.cells = nn.ModuleList([cur_cell(ic, hc)])
        for _ in range(self.layer_num - 1):
            cell = cur_cell(hc, hc)
            self.cells.append(cell)

        self.hiddens = self.init_hiddens(learn_init)
        print("hidden : {}".format(self.hiddens.shape))

    def init_hiddens(self, learn_init):
        if learn_init:
            h = nn.Parameter(torch.randn(self.layer_num, 1, self.hc))
        else:
            h = nn.Parameter(torch.zeros(self.layer_num, 1, self.hc), requires_grad=False)
        return h

    def forward(self, x, hiddens=None):
        device = x.device
        if self.batch_first : 
            x = torch.permute(x,(1,0,2))
        
        x_len, bs, _ = x.shape    # (x_len, bs, ic)

        if hiddens is None:
            h = self.hiddens
            h = h.repeat(1, bs, 1)
        else:
            h = hiddens
        u = torch.ones(self.layer_num, bs, 1).to(device)            # (l, bs, 1)

        hs = []
        lstm_input = x             # (x_len, bs, ic)

        skip = [False] * bs
        delta_u = [None] * bs

        binarized_us = []

        for i in range(self.layer_num):
            cur_hs = []
            cur_h = h[i].unsqueeze(0)  # (1, bs, hc)
            cur_u = u[i]               # (bs, 1)

            for j in range(x_len):
                if self.no_skip:
                    # (bs, 1), ((bs, hc), (bs, hc)), (bs, 1), (bs, 1)
                    binarized_u, cur_u, cur_h, delta_u = self.cells[i](
                        lstm_input[j], cur_u, cur_h[0])
                    binarized_us.append(binarized_u)
                else:
                    # (bs, 1), ((bs, hc), (bs, hc)), (bs, 1), (bs, 1)
                    binarized_u, cur_u, cur_h, delta_u, n_skips_after = self.cells[i](
                        lstm_input[j], cur_u, cur_h[0], skip, delta_u)
                    binarized_us.append(binarized_u)
                    skip = (n_skips_after[:, 0] > 0).tolist()

                # (1, bs, hc) / (1, bs, hc)
                cur_h = cur_h[0].unsqueeze(0)
                cur_hs.append(cur_h)

            # (x_len, bs, hc)
            lstm_input = torch.cat(cur_hs, dim=0)
            hs.append(cur_h)

        # (bs, seq * layer_num)
        total_u = torch.cat(binarized_us, 1)
        # (x_len, bs, hc)
        out = lstm_input
        # (l, bs, hc)
        hs = torch.cat(hs, dim=0)
        
        if self.batch_first : 
            out = torch.permute(out,(1,0,2))

        if self.return_total_u:
            return out, (hs,), total_u
        return out, (hs,)

In [22]:
import time
from tqdm.auto import tqdm


ic = 256
hc = 256
B = 64
L = 40

h = torch.rand(2,B,ic)

m1 = SkipGRU(ic, hc, layer_num=2, learn_init=False, no_skip=False,batch_first=True)

x = torch.rand(B,L,ic)
print(x.shape)

y = m1(x,h)
print(y[0].shape)
print(y[1][0].shape)

hidden : torch.Size([2, 1, 256])
torch.Size([64, 40, 256])
torch.Size([64, 40, 256])
torch.Size([2, 64, 256])


In [7]:
import time
from tqdm.auto import tqdm


ic = 256
hc = 256
B = 64
L = 40

m1 = SkipGRU(ic, hc, layer_num=2, return_total_u=True, learn_init=False, no_skip=False,batch_first=True)
m2 = torch.nn.GRU(ic,hc,num_layers =2,batch_first=True)

x = torch.rand(B,L,ic)

print(x.shape)

N = 10

x = x.to("cuda")
m1 = m1.to("cuda")
m2 = m2.to("cuda")

tic =  time.time()
for i in tqdm(range(N)):
    y = m1(x)[0]
toc =  time.time()

print((toc - tic)/N)

tic =  time.time()
for i in tqdm(range(N)):
    y = m2(x)[0]    
toc =  time.time()
print((toc - tic)/N)

torch.Size([64, 40, 256])


100%|█████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 19.86it/s]


0.05060796737670899


100%|████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 820.72it/s]

0.0014890909194946289





# Skip RNN
https://github.com/gitabcworld/skiprnn_pytorch/

In [75]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.init import xavier_uniform_
import math
import numpy as np

class RNNCellBase(nn.Module):
    __constants__ = ['input_size', 'hidden_size', 'bias']

    input_size: int
    hidden_size: int
    bias: bool
    # WARNING: bias_ih and bias_hh purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.weight_ih = nn.ParameterList([torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs)])
        self.weight_hh = nn.ParameterList([torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs)])
        if bias:
            self.bias_ih = nn.ParameterList([torch.empty(num_chunks * hidden_size, **factory_kwargs)])
            self.bias_hh = nn.ParameterList([torch.empty(num_chunks * hidden_size, **factory_kwargs)])
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

        self.reset_parameters()

    def extra_repr(self) -> str:
        s = '{input_size}, {hidden_size}'
        if 'bias' in self.__dict__ and self.bias is not True:
            s += ', bias={bias}'
        if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
            s += ', nonlinearity={nonlinearity}'
        return s.format(**self.__dict__)

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)


class CCellBase(RNNCellBase):

    def __init__(self, cell, learnable_elements, input_size, hidden_size, num_layers = 1, num_chunks=3,
                    bias=True, batch_first = False, activation=F.tanh, layer_norm=False):
        print(input_size)
        print(hidden_size)
        super(CCellBase, self).__init__(input_size, hidden_size, bias,num_chunks )
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.batch_first = batch_first
        self.cell = cell
        self.num_layers = num_layers
        self.weight_ih = nn.ParameterList([])
        self.weight_hh = nn.ParameterList([])
        self.bias_ih = nn.ParameterList([])
        self.bias_hh = nn.ParameterList([])

        for i in np.arange(self.num_layers):
            if i == 0:
                weight_ih = Parameter(xavier_uniform_(torch.Tensor(learnable_elements * hidden_size, input_size)))
            else:
                weight_ih = Parameter(xavier_uniform_(torch.Tensor(learnable_elements * hidden_size, hidden_size)))
            weight_hh = Parameter(xavier_uniform_(torch.Tensor(learnable_elements * hidden_size, hidden_size)))
            self.weight_ih.append(weight_ih)
            self.weight_hh.append(weight_hh)
            if bias:
                bias_ih = Parameter(torch.zeros(learnable_elements * hidden_size))
                bias_hh = Parameter(torch.zeros(learnable_elements * hidden_size))
                self.bias_ih.append(bias_ih)
                self.bias_hh.append(bias_hh)
            else:
                self.register_parameter('bias_ih_' + str(i), None)
                self.register_parameter('bias_hh_' + str(i), None)
        self.weight_ih = nn.ParameterList(self.weight_ih)
        self.weight_hh = nn.ParameterList(self.weight_hh)
        if self.bias_ih:
            self.bias_ih = nn.ParameterList(self.bias_ih)
            self.bias_hh = nn.ParameterList(self.bias_hh)

        self.activation = activation
        self.layer_norm = layer_norm
        self.lst_bnorm_rnn = None 
        
class CCellBaseSkipGRU(CCellBase):

    def __init__(self, cell, learnable_elements, input_size, hidden_size, num_layers = 1,
                    bias=True, batch_first = False, activation=F.tanh, layer_norm=False):
        super(CCellBaseSkipGRU, self).__init__(cell, learnable_elements, input_size, hidden_size, num_layers,
                                               bias, batch_first, activation, layer_norm)
        self.weight_uh = Parameter(xavier_uniform_(torch.Tensor(1, hidden_size)))
        if bias:
            self.bias_uh = Parameter(torch.ones(1))
        else:
            self.register_parameter('bias_uh', None)

    def forward(self, input, hx = None):
        if len(input.shape) == 3:
            if self.batch_first:
                input = input.transpose(0,1)
            sequence_length, batch_size, input_size = input.shape
        else:
            sequence_length = 1
            batch_size, input_size = input.shape

        if hx is None:
            hx = self.init_hidden(batch_size)
            if input.is_cuda:
                if self.num_layers == 1:
                    hx = tuple([x.cuda() for x in hx])
                else:
                    hx = [tuple([j.cuda() if j is not None else None for j in i]) for i in hx]

        """  Deprecated ? 
        if len(input.shape) == 3:
            self.check_forward_input(input[0])
            if self.num_layers > 1:
                self.check_forward_hidden(input[0], hx[0][0], '[0]')
            lse:
                self.check_forward_hidden(input[0], hx[0], '[0]')
        else:
            self.check_forward_input(input)
            if self.num_layers > 1:
                self.check_forward_hidden(input, hx[0][0], '[0]')
            else:
                self.check_forward_hidden(input, hx[0], '[0]')
        """
    
        # Initialize batchnorm layers
        if self.layer_norm and self.lst_bnorm_rnn is None:
            self.lst_bnorm_rnn = []
            for i in np.arange(self.num_layers):
                # Create gain and bias for input_gate, new_input, forget_gate, output_gate
                lst_bnorm_rnn_tmp = torch.nn.ModuleList([nn.BatchNorm1d(self.hidden_size) for i in np.arange(2)])
                if input.is_cuda:
                    lst_bnorm_rnn_tmp = lst_bnorm_rnn_tmp.cuda()
                self.lst_bnorm_rnn.append(lst_bnorm_rnn_tmp)
            self.lst_bnorm_rnn = torch.nn.ModuleList(self.lst_bnorm_rnn)

        lst_output = []
        lst_update_gate = []
        for t in np.arange(sequence_length):
            output, hx = self.cell(
                input[t], hx, self.num_layers,
                self.weight_ih, self.weight_hh, self.weight_uh,
                self.bias_ih, self.bias_hh, self.bias_uh,
                activation=self.activation,
                lst_layer_norm=self.lst_bnorm_rnn
            )
            new_h, update_gate = output
            lst_output.append(new_h)
            lst_update_gate.append(update_gate)
        output = torch.stack(lst_output)
        update_gate = torch.stack(lst_update_gate)
        if self.batch_first:
            output = output.transpose(0, 1)
            update_gate = update_gate.transpose(0, 1)
        return output, hx, update_gate

class BinaryLayer(Function):
    def forward(self, input):
        return input.round()
 
    def backward(self, grad_output):
        return grad_output    
    
def MultiSkipGRUCell(input, state, num_layers, w_ih, w_hh, w_uh,b_ih=None, b_hh=None, b_uh=None,
                  activation=F.tanh, lst_layer_norm=None):

    _ , update_prob_prev, cum_update_prob_prev = state[-1]
    cell_input = input
    state_candidates = []

    for idx in np.arange(num_layers):

        h_prev, _, _ = state[idx]

        gi = F.linear(cell_input, w_ih[idx], b_ih[idx])
        gh = F.linear(h_prev, w_hh[idx], b_hh[idx])
        i_r, i_i, i_n = gi.chunk(3, 1)
        h_r, h_i, h_n = gh.chunk(3, 1)

        resetgate_tmp = i_r + h_r
        inputgate_tmp = i_i + h_i
        if lst_layer_norm:
            resetgate_tmp = lst_layer_norm[idx][0](resetgate_tmp.contiguous())
            inputgate_tmp = lst_layer_norm[idx][1](inputgate_tmp.contiguous())

        resetgate = F.sigmoid(resetgate_tmp)
        inputgate = F.sigmoid(inputgate_tmp)

        newgate = activation(i_n + resetgate * h_n)
        new_h_tilde = newgate + inputgate * (h_prev - newgate)

        state_candidates.append(new_h_tilde)
        cell_input = new_h_tilde

    # Compute value for the update prob
    new_update_prob_tilde = F.sigmoid(F.linear(state_candidates[-1], w_uh, b_uh))

    # Compute value for the update gate
    cum_update_prob = cum_update_prob_prev + torch.min(update_prob_prev, 1. - cum_update_prob_prev)
    # round
    bn = BinaryLayer()
    update_gate = bn(cum_update_prob)
    # Apply update gate
    new_states = []
    for idx in np.arange(num_layers - 1):
        new_h = update_gate * state_candidates[idx] + (1. - update_gate) * state[idx][0]
        new_states.append((new_h,None,None))
    new_h = update_gate * state_candidates[-1] + (1. - update_gate) * state[-1][0]

    new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev
    new_cum_update_prob = update_gate * 0. + (1. - update_gate) * cum_update_prob
    new_states.append((new_h, new_update_prob, new_cum_update_prob))
    new_output = (new_h, update_gate)

    return new_output, new_states    
    
    
class CMultiSkipGRUCell(CCellBaseSkipGRU):
    def __init__(self, *args, **kwargs):
        super(CMultiSkipGRUCell, self).__init__(cell=MultiSkipGRUCell, learnable_elements=3, *args, **kwargs)

    def init_hidden(self, batch_size):
        initial_states = []
        for i in np.arange(self.num_layers):
            initial_h = Variable(torch.randn(batch_size, self.hidden_size))
            if i == self.num_layers - 1: #last layer
                initial_update_prob = Variable(torch.ones(batch_size, 1),requires_grad=False)
                initial_cum_update_prob = Variable(torch.zeros(batch_size, 1),requires_grad=False)
            else:
                initial_update_prob = None
                initial_cum_update_prob = None
            initial_states.append((initial_h,initial_update_prob,initial_cum_update_prob))
        return initial_states

In [None]:




import time
from tqdm.auto import tqdm


ic = 256
hc = 256
B = 64
L = 40

m1 = CMultiSkipGRUCell(input_size=ic, hidden_size=hc, batch_first=True, num_layers=2)
m2 = torch.nn.GRU(ic,hc,num_layers =2,batch_first=True)

x = torch.rand(B,L,ic)

print(x.shape)

N = 10

x = x.to("cuda")
m1 = m1.to("cuda")
m2 = m2.to("cuda")

tic =  time.time()
for i in tqdm(range(N)):
    y = m1(x)[0]
toc =  time.time()

print((toc - tic)/N)

tic =  time.time()
for i in tqdm(range(N)):
    y = m2(x)[0]    
toc =  time.time()
print((toc - tic)/N)

256
256
torch.Size([64, 40, 256])


  0%|                                                                             | 0/10 [00:00<?, ?it/s]

> [0;32m/tmp/ipykernel_3224396/2675326793.py[0m(203)[0;36mMultiSkipGRUCell[0;34m()[0m
[0;32m    201 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    202 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 203 [0;31m        [0mnewgate[0m [0;34m=[0m [0mactivation[0m[0;34m([0m[0mi_n[0m [0;34m+[0m [0mresetgate[0m [0;34m*[0m [0mh_n[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    204 [0;31m        [0mnew_h_tilde[0m [0;34m=[0m [0mnewgate[0m [0;34m+[0m [0minputgate[0m [0;34m*[0m [0;34m([0m[0mh_prev[0m [0;34m-[0m [0mnewgate[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    205 [0;31m[0;34m[0m[0m
[0m


ipdb>  activation


False


ipdb>  F.tanh


<function tanh at 0x7f91ab5df700>


In [36]:
class cellModule(nn.Module):

    def __init__(self, cells, model):
        super(cellModule, self).__init__()
        self.model = model
        self.rnn = cells
        self.d1 = nn.Linear(FLAGS['rnn_cells'],OUTPUT_SIZE)

    def forward(self, input, hx=None):
        if hx is not None:
            output = self.rnn(input, hx)
        else:
            output = self.rnn(input)
        output, hx, updated_state = split_rnn_outputs(self.model, output)
        output = self.d1(output[:,-1,:]) # Get the last output of the sequence
        return output, hx, updated_state