In [34]:
import torch

class Downsample(torch.nn.Module):
    """Layer-normalise a ``B x T x C`` sequence, then shorten it in time with a
    strided, grouped 1-D convolution.

    Args:
        input_dim: Size of the incoming channel axis; also the LayerNorm width.
        output_dim: Number of channels produced by the convolution.
        kernel_size: Temporal width of the convolution kernel. The padding is
            ``kernel_size // 2`` on each side.
        stride: Temporal stride of the convolution.
        groups: Number of convolution groups (128 by default, the value that
            used to be hard-coded). ``Conv1d`` requires it to divide both
            ``input_dim`` and ``output_dim``.
    """

    def __init__(self, input_dim=768, output_dim=2048, kernel_size=9, stride=2, groups=128):
        super().__init__()

        self.norm = torch.nn.LayerNorm(input_dim)
        # With an odd kernel this padding maps T input frames to ceil(T / stride)
        # output frames.
        self.conv = torch.nn.Conv1d(
            input_dim,
            output_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=groups,
        )

    def forward(self, x):  # B x T x input_dim
        """Normalise over the channel axis, then convolve along the time axis.

        Args:
            x: Tensor laid out as ``B x T x input_dim``.

        Returns:
            Tensor laid out as ``B x T_out x output_dim``.
        """
        # LayerNorm works on the last axis, so it runs before the transpose.
        # The result is turned into a strided view right afterwards, so forcing
        # it to be contiguous first bought nothing.
        x = self.norm(x)

        # Conv1d expects channels-first input: B x C x T.
        x = x.transpose(1, 2)
        x = self.conv(x)
        x = x.transpose(1, 2)

        return x  # B x T_out x output_dim

In [31]:
# Instantiate the module with its default hyper-parameters and show how many
# trainable parameters it holds, in millions. Re-run this cell after any change
# to Downsample so the displayed figure matches the class definition.
d = Downsample()
sum(t.numel() for t in filter(lambda t: t.requires_grad, d.parameters())) / 1e6

0.114176

In [33]:
class Upsample(torch.nn.Module):
    """Layer-normalise a ``B x T x C`` sequence, then lengthen it in time with a
    strided, grouped 1-D transposed convolution.

    Args:
        input_dim: Size of the incoming channel axis; also the LayerNorm width.
        output_dim: Number of channels produced by the transposed convolution.
        kernel_size: Temporal width of the kernel. The padding is
            ``(kernel_size - 1) // 2``.
        stride: Temporal stride; ``output_padding`` is set to ``stride - 1``.
        groups: Number of convolution groups (128 by default, the value that
            used to be hard-coded). ``ConvTranspose1d`` requires it to divide
            both ``input_dim`` and ``output_dim``.
    """

    def __init__(self, input_dim=2048, output_dim=256, kernel_size=9, stride=2, groups=128):
        super().__init__()

        self.norm = torch.nn.LayerNorm(input_dim)
        # With an odd kernel, this padding / output_padding pair maps T input
        # frames to exactly T * stride output frames.
        self.conv = torch.nn.ConvTranspose1d(
            input_dim,
            output_dim,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            output_padding=stride - 1,
            groups=groups,
        )

    def forward(self, x):  # B x T x input_dim
        """Normalise over the channel axis, then up-convolve along the time axis.

        Args:
            x: Tensor laid out as ``B x T x input_dim``.

        Returns:
            Tensor laid out as ``B x T_out x output_dim``.
        """
        # LayerNorm works on the last axis, so it runs before the transpose.
        x = self.norm(x)

        # ConvTranspose1d expects channels-first input: B x C x T.
        x = x.transpose(1, 2)
        x = self.conv(x)
        x = x.transpose(1, 2)

        return x  # B x T_out x output_dim

# Instantiate the module with its default hyper-parameters and show how many
# trainable parameters it holds, in millions.
u = Upsample()
sum(t.numel() for t in filter(lambda t: t.requires_grad, u.parameters())) / 1e6

0.041216