# dynamic_rnn

### Define functions

In [1]:
import torch


def dynamic_rnn(cell, inputs, sequence_length):
    '''
        Run an LSTM cell over a padded batch of variable-length sequences.

        Sequences are processed longest-first so that, at every timestep, the
        sequences that are still active form a contiguous prefix of the batch
        and the cell is applied to that prefix only.

        Inputs:
            cell: torch.nn.LSTMCell instance (its `hidden_size` is read)
            inputs: (batch_size, max_timestep, input_size)
            sequence_length: (batch_size,) valid length of each sequence;
                a 1-D tensor or any sequence of integers
        Outputs:
            outputs: (batch_size, max_timestep, hidden_size), in the original
                batch order; positions at or beyond a sequence's length are zero
            final_states (NOT SUPPORTED): the final (hidden_state, cell_state)
                of the cell is not returned
        Raises:
            ValueError: if the number of lengths differs from the batch size
    '''
    batch_size, max_timestep = inputs.shape[0], inputs.shape[1]

    # Convert once to plain ints: keeps sorting and the per-step comparisons
    # cheap and makes tensors and Python lists interchangeable.
    lengths = [int(length) for length in sequence_length]
    if len(lengths) != batch_size:
        raise ValueError(
            'sequence_length has {} entries but inputs has batch size {}'.format(
                len(lengths), batch_size))

    # Allocate directly on the target device instead of CPU-then-copy.
    outputs = torch.zeros((batch_size, max_timestep, cell.hidden_size),
                          dtype=inputs.dtype, device=inputs.device)
    if batch_size == 0:
        # Nothing to run; avoids indexing into an empty length list below.
        return outputs

    # Longest sequence first; sorted() is stable, so ties keep batch order.
    sorted_indices = sorted(range(batch_size), key=lengths.__getitem__, reverse=True)
    sorted_lengths = [lengths[i] for i in sorted_indices]
    sorted_inputs = inputs[sorted_indices]
    sorted_outputs = torch.zeros_like(outputs)

    state = None  # (h, c); None lets the cell start from zero states
    for step in range(sorted_lengths[0]):
        # Number of sequences that still have a valid input at this step.
        curr_batch_size = sum(length > step for length in sorted_lengths)
        curr_inputs = sorted_inputs[:curr_batch_size, step, :]  # (curr_batch_size, input_size)

        if state is not None:
            # Drop the states of sequences that ended at the previous step.
            state = (state[0][:curr_batch_size], state[1][:curr_batch_size])
        h, c = cell(curr_inputs, state)  # each (curr_batch_size, hidden_size)
        state = (h, c)

        sorted_outputs[:curr_batch_size, step, :] = h

    # Scatter the rows back to the caller's original batch order.
    outputs[sorted_indices] = sorted_outputs

    return outputs



## Experiments

### 1. Fixed-lengths

If all input sequences have the same (fixed) length, we recommend using `torch.nn.LSTM` rather than `torch.nn.LSTMCell` with `dynamic_rnn`.

### 2. Variable-lengths

In [2]:
# Toy padded batch: 4 sequences, 5 timesteps, 2 features per step.
inputs = torch.randn(4, 5, 2)
sequence_length = torch.tensor([2, 3, 4, 1], dtype=torch.long)

# Zero out every timestep at or beyond each sequence's valid length.
for i, valid in enumerate(sequence_length.tolist()):
    inputs[i, valid:] = 0

lstm = torch.nn.LSTMCell(input_size=2, hidden_size=4)

print('inputs:', inputs.shape, inputs, sep='\n')

print('-'*30, 'sequence_length:', sequence_length, sep='\n')

inputs:
torch.Size([4, 5, 2])
tensor([[[-0.5350, -0.0709],
         [-0.4732, -0.1139],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],

        [[ 0.9036, -0.5725],
         [-0.8133,  0.2226],
         [-0.0077,  1.3141],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],

        [[-1.4904, -1.1200],
         [-0.1805, -0.1135],
         [-1.2188,  1.8980],
         [ 0.5987,  0.8136],
         [ 0.0000,  0.0000]],

        [[-1.0816, -1.6631],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]]])
------------------------------
sequence_length:
tensor([2, 3, 4, 1])


In [3]:
# Run the LSTM cell over the padded batch with dynamic_rnn and show the result.
outputs = dynamic_rnn(cell=lstm, inputs=inputs, sequence_length=sequence_length)

print(outputs.shape, outputs, sep='\n')

torch.Size([4, 5, 4])
tensor([[[-0.0603,  0.1638,  0.0383, -0.2045],
         [-0.1093,  0.2398,  0.0594, -0.2780],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1622,  0.0282,  0.1451, -0.0276],
         [-0.1598,  0.1900,  0.0993, -0.2576],
         [-0.2170,  0.1980,  0.2763, -0.1226],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0436,  0.2228, -0.0259, -0.3050],
         [-0.0582,  0.2379,  0.0424, -0.2694],
         [-0.0855,  0.3073,  0.1373, -0.3327],
         [-0.1962,  0.2263,  0.3077, -0.1008],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0332,  0.1801, -0.0299, -0.2301],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]], grad_fn=<PutBackward>)


In [4]:
# Reference run: feed each sequence through the same cell on its own
# (batch of one, valid timesteps only) and print the hidden states so they
# can be compared with the dynamic_rnn output printed earlier.
for i, length in enumerate(sequence_length):
    curr_inputs = inputs[i, :length, :].unsqueeze(0)
    outputs = torch.zeros((1, length, lstm.hidden_size), dtype=torch.float)
    print(curr_inputs.shape)
    h, c = None, None
    for step in range(length):
        # First step has no previous state; passing None starts from zeros.
        state = None if (h is None or c is None) else (h, c)
        h, c = lstm(curr_inputs[:, step, :], state)
        outputs[:, step, :] = h

    print(outputs)

torch.Size([1, 2, 2])
tensor([[[-0.0603,  0.1638,  0.0383, -0.2045],
         [-0.1093,  0.2398,  0.0594, -0.2780]]], grad_fn=<CopySlices>)
torch.Size([1, 3, 2])
tensor([[[-0.1622,  0.0282,  0.1451, -0.0276],
         [-0.1598,  0.1900,  0.0993, -0.2576],
         [-0.2170,  0.1980,  0.2763, -0.1226]]], grad_fn=<CopySlices>)
torch.Size([1, 4, 2])
tensor([[[ 0.0436,  0.2228, -0.0259, -0.3050],
         [-0.0582,  0.2379,  0.0424, -0.2694],
         [-0.0855,  0.3073,  0.1373, -0.3327],
         [-0.1962,  0.2263,  0.3077, -0.1008]]], grad_fn=<CopySlices>)
torch.Size([1, 1, 2])
tensor([[[ 0.0332,  0.1801, -0.0299, -0.2301]]], grad_fn=<CopySlices>)
