In [1]:
import torch


def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, dtype=None):
    '''
        Unroll an LSTM cell over a padded batch, advancing each sequence only
        for its own number of valid timesteps.

        Inputs:
            cell: torch.nn.LSTMCell instance
            inputs: (batch_size, max_timestep, input_size)
            sequence_length: (batch_size,) valid length of every sequence;
                None means every sequence spans all max_timestep steps
            initial_state: optional tuple of (hidden_state, cell_state), each
                (batch_size, hidden_size), given in the same row order as inputs;
                None lets the cell start from its default (zero) state
            dtype: accepted for backward compatibility and unused; the output
                always takes the dtype of inputs
        Outputs:
            outputs: (batch_size, max_timestep, hidden_size); rows at or beyond a
                sequence's length are left as zeros
            final_states(NOT SUPPORT): a tuple of (hidden_state, cell_state) which are the final states of the cell
        Raises:
            ValueError: if sequence_length does not hold exactly batch_size
                entries, or if any entry exceeds max_timestep
    '''
    batch_size, max_timestep = inputs.shape[0], inputs.shape[1]
    device = inputs.device

    # Work with plain Python ints so the per-step bookkeeping never touches tensors.
    if sequence_length is None:
        lengths = [max_timestep] * batch_size
    else:
        lengths = [int(length) for length in torch.as_tensor(sequence_length).reshape(-1).tolist()]
        if len(lengths) != batch_size:
            raise ValueError('sequence_length has %d entries but the batch has %d sequences'
                             % (len(lengths), batch_size))
        if any(length > max_timestep for length in lengths):
            raise ValueError('sequence_length entries must not exceed max_timestep (%d)' % max_timestep)

    # Longest sequence first: at any step the still-active sequences then form
    # a prefix of the batch, so one slice selects them.
    sorted_indices = sorted(range(batch_size), key=lambda i: lengths[i], reverse=True)
    sorted_lengths = [lengths[i] for i in sorted_indices]
    index = torch.tensor(sorted_indices, dtype=torch.long, device=device)
    sorted_inputs = inputs[index]

    state = None
    if initial_state is not None:
        # Reorder the caller's states to follow the sorted batch.
        init_h, init_c = initial_state
        state = (init_h[index], init_c[index])

    sorted_outputs = inputs.new_zeros((batch_size, max_timestep, cell.hidden_size))

    longest = sorted_lengths[0] if sorted_lengths else 0
    for step in range(longest):
        curr_batch_size = sum(length > step for length in sorted_lengths)
        curr_inputs = sorted_inputs[:curr_batch_size, step, :]  # (curr_batch_size, input_size)

        if state is not None:
            # Drop the states of sequences that ended before this step.
            state = (state[0][:curr_batch_size], state[1][:curr_batch_size])
        state = cell(curr_inputs, state)  # each (curr_batch_size, hidden_size)

        sorted_outputs[:curr_batch_size, step, :] = state[0]

    # sort back
    outputs = torch.zeros_like(sorted_outputs)
    outputs[index] = sorted_outputs

    return outputs



In [2]:
# Toy padded batch: 4 sequences, at most 5 timesteps, 2 features per step.
inputs = torch.randn(4, 5, 2)
sequence_length = torch.LongTensor([2, 3, 4, 1])

# Zero out every timestep at or beyond each sequence's own length (the padding).
timesteps = torch.arange(inputs.shape[1]).unsqueeze(0)      # (1, max_timestep)
padding_mask = timesteps >= sequence_length.unsqueeze(1)    # (batch_size, max_timestep)
inputs.masked_fill_(padding_mask.unsqueeze(-1), 0)

lstm = torch.nn.LSTMCell(input_size=2, hidden_size=4)

print('inputs:', inputs.shape, inputs, sep='\n')
print('-' * 30, 'sequence_length:', sequence_length, sep='\n')

inputs:
torch.Size([4, 5, 2])
tensor([[[-0.8641,  1.3578],
         [ 0.3086,  0.9989],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],

        [[-0.3247, -0.5936],
         [ 0.4447, -0.7139],
         [ 0.1393, -0.2321],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],

        [[-0.9856, -1.9107],
         [-0.7021,  1.3406],
         [-1.4402,  1.8186],
         [-1.5619,  1.6430],
         [ 0.0000,  0.0000]],

        [[-0.7612, -0.6508],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]]])
------------------------------
sequence_length:
tensor([2, 3, 4, 1])


In [3]:
# Batched unroll of the cell over the padded inputs; timesteps past each
# sequence's length stay zero in the result.
outputs = dynamic_rnn(cell=lstm, inputs=inputs, sequence_length=sequence_length)

print(outputs.shape, outputs, sep='\n')

torch.Size([4, 5, 4])
tensor([[[-0.0921,  0.3016,  0.1301,  0.1123],
         [-0.0995,  0.3855,  0.0897,  0.1728],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.0688,  0.1002, -0.0037, -0.0198],
         [ 0.0235,  0.0934, -0.0508, -0.0198],
         [-0.0286,  0.1563, -0.0523,  0.0247],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1092,  0.0302, -0.0278, -0.1158],
         [-0.1288,  0.3310,  0.0633,  0.0567],
         [-0.1283,  0.4139,  0.1793,  0.1755],
         [-0.1458,  0.4177,  0.2437,  0.2406],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1176,  0.1075,  0.0113, -0.0338],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]], grad_fn=<PutBackward>)


In [4]:
# Reference check: unroll the cell over each sequence on its own (no batching,
# no padding) and confirm that dynamic_rnn produced the same hidden states.
for i in range(inputs.shape[0]):
    length = int(sequence_length[i])
    curr_inputs = inputs[i, :length, :].unsqueeze(0)  # (1, length, input_size)
    # Kept under its own name so the batched `outputs` from dynamic_rnn is not overwritten.
    reference_outputs = torch.zeros((1, length, lstm.hidden_size), dtype=torch.float)
    print(curr_inputs.shape)
    state = None  # passing None makes LSTMCell start from zero hidden/cell states
    for step in range(length):
        state = lstm(curr_inputs[:, step, :], state)
        reference_outputs[:, step, :] = state[0]

    print(reference_outputs)

    # Turn the eyeball comparison into a checked one.
    assert torch.allclose(outputs[i, :length], reference_outputs[0], atol=1e-5), \
        'dynamic_rnn disagrees with the per-sequence unroll for sequence %d' % i
    assert bool((outputs[i, length:] == 0).all()), \
        'dynamic_rnn wrote into padded timesteps of sequence %d' % i

torch.Size([1, 2, 2])
tensor([[[-0.0921,  0.3016,  0.1301,  0.1123],
         [-0.0995,  0.3855,  0.0897,  0.1728]]], grad_fn=<CopySlices>)
torch.Size([1, 3, 2])
tensor([[[-0.0688,  0.1002, -0.0037, -0.0198],
         [ 0.0235,  0.0934, -0.0508, -0.0198],
         [-0.0286,  0.1563, -0.0523,  0.0247]]], grad_fn=<CopySlices>)
torch.Size([1, 4, 2])
tensor([[[-0.1092,  0.0302, -0.0278, -0.1158],
         [-0.1288,  0.3310,  0.0633,  0.0567],
         [-0.1283,  0.4139,  0.1793,  0.1755],
         [-0.1458,  0.4177,  0.2437,  0.2406]]], grad_fn=<CopySlices>)
torch.Size([1, 1, 2])
tensor([[[-0.1176,  0.1075,  0.0113, -0.0338]]], grad_fn=<CopySlices>)
