In [5]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: cvqluu
repo: https://github.com/cvqluu/TDNN
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class TDNN(nn.Module):
    
    def __init__(
                    self, 
                    input_dim=23, 
                    output_dim=512,
                    context_size=5,
                    stride=1,
                    dilation=1,
                    batch_norm=False,
                    dropout_p=0.2
                ):
        '''
        TDNN as defined by https://www.danielpovey.com/files/2015_interspeech_multisplice.pdf

        Affine transformation not applied globally to all frames but smaller windows with local context

        batch_norm: True to include batch normalisation after the non linearity
        
        Context size and dilation determine the frames selected
        (although context size is not really defined in the traditional sense)
        For example:
            context size 5 and dilation 1 is equivalent to [-2,-1,0,1,2]
            context size 3 and dilation 2 is equivalent to [-2, 0, 2]
            context size 1 and dilation 1 is equivalent to [0]
        '''
        super(TDNN, self).__init__()
        self.context_size = context_size
        self.stride = stride
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dilation = dilation
        self.dropout_p = dropout_p
        self.batch_norm = batch_norm
      
        self.kernel = nn.Linear(input_dim*context_size, output_dim)
        self.nonlinearity = nn.ReLU()
        if self.batch_norm:
            self.bn = nn.BatchNorm1d(output_dim)
        if self.dropout_p:
            self.drop = nn.Dropout(p=self.dropout_p)
        
    def forward(self, x):
        '''
        input: size (batch, seq_len, input_features)
        outpu: size (batch, new_seq_len, output_features)
        '''
        
        _, _, d = x.shape
        assert (d == self.input_dim), 'Input dimension was wrong. Expected ({}), got ({})'.format(self.input_dim, d)
        x = x.unsqueeze(1)

        # Unfold input into smaller temporal contexts
        x = F.unfold(
                        x, 
                        (self.context_size, self.input_dim), 
                        stride=(1,self.input_dim), 
                        dilation=(self.dilation,1)
                    )

        # N, output_dim*context_size, new_t = x.shape
        x = x.transpose(1,2)
        x = self.kernel(x.float())
        x = self.nonlinearity(x)
        
        if self.dropout_p:
            x = self.drop(x)

        if self.batch_norm:
            x = x.transpose(1,2)
            x = self.bn(x)
            x = x.transpose(1,2)

        return x

In [6]:
tdnn = TDNN()

In [10]:
torch.rand(10, 100, 23)

tensor([[[0.4404, 0.2207, 0.8458,  ..., 0.4892, 0.4366, 0.1023],
         [0.7678, 0.7187, 0.8508,  ..., 0.0429, 0.3646, 0.1251],
         [0.5968, 0.8466, 0.3945,  ..., 0.7244, 0.5966, 0.8312],
         ...,
         [0.1986, 0.8062, 0.8846,  ..., 0.3881, 0.3388, 0.5566],
         [0.4716, 0.9867, 0.6578,  ..., 0.0413, 0.8325, 0.4968],
         [0.4873, 0.4935, 0.7286,  ..., 0.1536, 0.3410, 0.0715]],

        [[0.1184, 0.2546, 0.9143,  ..., 0.6843, 0.5807, 0.1920],
         [0.5552, 0.0059, 0.2573,  ..., 0.3597, 0.2361, 0.3734],
         [0.2397, 0.4276, 0.0714,  ..., 0.1629, 0.0315, 0.3427],
         ...,
         [0.5957, 0.9486, 0.0627,  ..., 0.2708, 0.7658, 0.2871],
         [0.3384, 0.7126, 0.5000,  ..., 0.0678, 0.3347, 0.1003],
         [0.4883, 0.3739, 0.5029,  ..., 0.6770, 0.9791, 0.5804]],

        [[0.7110, 0.7222, 0.8490,  ..., 0.3204, 0.4833, 0.6952],
         [0.1114, 0.9234, 0.6911,  ..., 0.7575, 0.3362, 0.0637],
         [0.9070, 0.0801, 0.9693,  ..., 0.4854, 0.7429, 0.

In [9]:
tdnn(torch.rand(1, 100, 23))

tensor([[[0.0000, 0.0000, 0.0365,  ..., 0.8219, 0.3764, 0.1712],
         [0.1364, 0.0000, 0.0247,  ..., 0.5945, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.7244, 0.3020, 0.1773],
         ...,
         [0.2313, 0.0000, 0.0000,  ..., 0.0000, 0.1881, 0.0000],
         [0.0000, 0.0000, 0.0536,  ..., 0.0000, 0.4088, 0.1005],
         [0.2669, 0.0000, 0.3220,  ..., 0.5542, 0.1776, 0.0000]]],
       grad_fn=<MulBackward0>)

In [3]:
print(tdnn)

TDNN(
  (kernel): Linear(in_features=115, out_features=512, bias=True)
  (nonlinearity): ReLU()
  (drop): Dropout(p=0.2, inplace=False)
)
