In [101]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
# Fix the RNG seed so the random stand-in input (and later weight inits) are reproducible.
torch.manual_seed(0)
# Random stand-in for the lip-embedding features, shape 1 x 2 x 50 x 512.
# Per the layout comment further down, 50 = seconds * fps frames and 512 = lips
# embedding dim; the meaning of the size-2 axis is not stated here — TODO confirm.
vf = torch.rand(1, 2, 50, 512)

In [3]:
vf.size()

torch.Size([1, 2, 50, 512])

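Target layout: (1, 50, 512 * 2), i.e. one 1024-dim vector per frame, formed by joining the two 512-dim embeddings of that frame.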

In [4]:
# Swap axes 1 and 2: (1, 2, 50, 512) -> (1, 50, 2, 512).
# Rebinds `vf`, so running this cell twice swaps the axes back — run it once, in order.
vf = torch.transpose(vf, 1, 2)

In [5]:
vf.size()

torch.Size([1, 50, 2, 512])

In [6]:
# Merge the last two axes, then move features to the channel axis for Conv1d:
# (1, 50, 2, 512) -> (1, 50, 1024) -> (1, 1024, 50). Rebinds `vf`; run once, in order.
flattened = torch.flatten(vf, start_dim=2)
vf = torch.transpose(flattened, 1, 2)

In [7]:
# Layout: (batch_size, 2 * lips_embedding_dim, seconds * fps) — channels-first, as nn.Conv1d expects.
vf.size()

torch.Size([1, 1024, 50])

In [94]:
class VideoEncoderBlock(nn.Module):
    """Conv1d -> ReLU -> BatchNorm1d over channels-first sequences.

    Expects input of shape (batch, in_channels, length), the layout ``nn.Conv1d``
    uses, and returns (batch, out_channels, length) when ``kernel_size`` is odd.

    Args:
        in_channels: number of input channels.
        kernel_size: temporal kernel size (an int, or a 1-tuple as accepted by
            ``nn.Conv1d``). Use an odd size to keep the sequence length unchanged.
        out_channels: number of output channels; ``None`` means ``in_channels``.
        use_separable_depthwise: if True, convolve with ``groups=in_channels``
            (one filter per channel). This is only the depthwise half of a
            depthwise-separable convolution — no pointwise 1x1 conv follows — and
            ``out_channels`` must then be a multiple of ``in_channels``.
        padding: explicit Conv1d padding; ``None`` (default) uses
            ``kernel_size // 2``, which preserves length for odd kernels.
    """

    def __init__(self, in_channels, kernel_size, out_channels=None,
                 use_separable_depthwise=True, padding=None):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        groups = in_channels if use_separable_depthwise else 1
        if padding is None:
            # A fixed padding of 1 only preserves length for kernel_size == 3;
            # derive it from the kernel so any odd kernel keeps the length.
            k = kernel_size if isinstance(kernel_size, int) else kernel_size[0]
            padding = k // 2
        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, groups=groups, padding=padding)
        self.relu = nn.ReLU()
        # Normalization is applied after the activation (conv -> relu -> norm).
        self.norm = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """Apply conv, ReLU, then batch norm to ``x`` of shape (batch, C, length)."""
        x = self.conv(x)
        x = self.relu(x)
        return self.norm(x)

In [96]:
# Five depthwise blocks that keep 1024 channels, followed by one full
# (groups=1) block that projects 1024 -> 256 channels.
video_encoder = nn.Sequential(
    *(VideoEncoderBlock(1024, 3, use_separable_depthwise=True) for _ in range(5)),
    VideoEncoderBlock(1024, 3, out_channels=256, use_separable_depthwise=False),
)

In [97]:
# Forward pass. The encoder is still in default training mode (no .eval() call),
# so BatchNorm1d normalizes with batch statistics and updates its running stats.
out = video_encoder(vf)

In [98]:
out.size()

torch.Size([1, 256, 50])

In [99]:
from modelsummary import summary

In [100]:
# Per-layer table from the third-party `modelsummary` package. It reports input
# shapes, so it presumably runs a forward pass on `vf`.
# Each depthwise Conv1d has 1024 * 3 weights + 1024 biases = 4,096 params.
summary(video_encoder, vf)

-----------------------------------------------------------------------
             Layer (type)                Input Shape         Param #
      VideoEncoderBlock-1             [-1, 1024, 50]               0
                 Conv1d-2             [-1, 1024, 50]           4,096
                   ReLU-3             [-1, 1024, 50]               0
            BatchNorm1d-4             [-1, 1024, 50]           2,048
      VideoEncoderBlock-5             [-1, 1024, 50]               0
                 Conv1d-6             [-1, 1024, 50]           4,096
                   ReLU-7             [-1, 1024, 50]               0
            BatchNorm1d-8             [-1, 1024, 50]           2,048
      VideoEncoderBlock-9             [-1, 1024, 50]               0
                Conv1d-10             [-1, 1024, 50]           4,096
                  ReLU-11             [-1, 1024, 50]               0
           BatchNorm1d-12             [-1, 1024, 50]           2,048
     VideoEncoderBlock-13      

In [93]:
# NOTE(review): stale duplicate of the summary cell above. Its execution count (93)
# precedes the current VideoEncoderBlock definition (94), and its Conv1d count
# 3,146,752 = 1024 * 1024 * 3 + 1024 matches a groups=1 (non-depthwise) conv from an
# earlier version of the class. Re-run or delete this cell.
summary(video_encoder, vf)

-----------------------------------------------------------------------
             Layer (type)                Input Shape         Param #
      VideoEncoderBlock-1             [-1, 1024, 50]               0
                 Conv1d-2             [-1, 1024, 50]       3,146,752
                   ReLU-3             [-1, 1024, 50]               0
            BatchNorm1d-4             [-1, 1024, 50]           2,048
      VideoEncoderBlock-5             [-1, 1024, 50]               0
                 Conv1d-6             [-1, 1024, 50]       3,146,752
                   ReLU-7             [-1, 1024, 50]               0
            BatchNorm1d-8             [-1, 1024, 50]           2,048
      VideoEncoderBlock-9             [-1, 1024, 50]               0
                Conv1d-10             [-1, 1024, 50]       3,146,752
                  ReLU-11             [-1, 1024, 50]               0
           BatchNorm1d-12             [-1, 1024, 50]           2,048
     VideoEncoderBlock-13      

In [102]:
out.size()

torch.Size([1, 256, 50])

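Audio shape: 1 × num_channels × 16000 (batch × channels × length). Note: the interpolation below targets 15999 steps, so confirm which length the audio encoder actually produces.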

In [107]:
# Linearly resample the video features along the last (time) axis from 50 steps
# to 15999 so they line up with the audio features `ae`.
# NOTE(review): rebinds `out`; the audio note mentions 16000 — confirm 15999 is intended.
out = F.interpolate(input=out, size=(15999,), mode="linear", align_corners=False)

In [108]:
out.shape

torch.Size([1, 256, 15999])

In [109]:
# Random placeholder for the audio-encoder output: 1 x 256 channels x 15999 steps.
ae = torch.rand((1, 256, 15999))

In [111]:
# Stack audio (first) and video features along the channel axis: 256 + 256 = 512.
concat = torch.cat((ae, out), 1)

In [112]:
concat.size()

torch.Size([1, 512, 15999])

In [115]:
# Pointwise (kernel 1) fusion layer: 512 concatenated channels -> 256, length unchanged.
proj = nn.Conv1d(512, 256, 1)

In [116]:
proj(concat).size()

torch.Size([1, 256, 15999])