In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math



class Conv1dBlock(nn.Module):
    """LayerNorm over channels followed by a left-padded (causal) 1-D conv.

    Inputs and outputs are channels-last, i.e. ``(batch, time, channels)``.
    The convolution itself uses no built-in padding; instead
    ``(kernel_size - 1) * dilation`` zeros are prepended on the time axis so
    that an output frame never depends on later input frames.

    Args:
        in_channels: Channel count of the input (also the LayerNorm width).
        out_channels: Channel count produced by the convolution.
        kernel_size: Temporal kernel width.
        dilation: Dilation of the convolution.
        stride: Stride of the convolution.
        groups: Number of convolution groups.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1):
        super().__init__()

        # Amount of left padding that keeps the convolution causal.
        self.causal_padding = dilation * (kernel_size - 1)
        # Normalizes over the trailing (channel) dimension.
        self.layernorm = nn.LayerNorm(in_channels)
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
        )

    def forward(self, x, padding_mask=None):
        """Apply mask -> LayerNorm -> causal conv.

        Args:
            x: Tensor expected to be of shape ``(batch, time, channels)``.
            padding_mask: Optional boolean mask broadcastable to ``x``;
                positions where it is True are zeroed before normalization.

        Returns:
            The convolved tensor, channels-last.
        """
        frames = x if padding_mask is None else x.masked_fill(padding_mask, 0)
        normed = self.layernorm(frames)

        # Conv1d wants (batch, channels, time); pad only on the left of time.
        channels_first = F.pad(normed.transpose(1, 2), (self.causal_padding, 0))
        return self.conv(channels_first).transpose(1, 2)


class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sine/cosine positional encoding added to a channels-last input.

    A table of shape ``(1, max_len, d_model)`` is precomputed once and stored
    as a buffer (so it follows ``.to(device)`` but is not a parameter). Even
    feature indices hold sines, odd feature indices hold cosines.

    Args:
        d_model: Feature width of the inputs. May be even or odd.
        max_len: Number of positions in the precomputed table.
    """

    def __init__(self, d_model, max_len=1024):
        super().__init__()

        # Create matrix of shape (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)  # even indices
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angles to match; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # odd indices

        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)  # register as buffer (not a parameter)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch, seq_len, d_model)
        Returns:
            Tensor of shape (batch, seq_len, d_model) with positional encoding added
        """
        return x + self.pe[:, :x.size(1)]
    
class CausalTransformer(nn.Module):
    """Stack of ``nn.TransformerEncoderLayer`` run with a causal attention mask.

    Sinusoidal positional encodings are added to the input, then every layer
    is applied with an upper-triangular mask so a position cannot attend to
    later positions.

    Args:
        d_model: Feature width of the inputs and of every layer.
        nhead: Number of attention heads per layer.
        num_layers: Number of encoder layers in the stack.
        dim_feedforward: Hidden width of each layer's feed-forward sublayer.
        dropout: Dropout probability passed to each layer.
        max_len: Number of positions in the positional-encoding table.
            Defaults to 2048, the value that was previously hard-coded.
    """

    def __init__(self, d_model, nhead, num_layers=2, dim_feedforward=1024, dropout=0.1, max_len=2048):
        super().__init__()

        self.pos_enc = SinusoidalPositionalEncoding(d_model, max_len=max_len)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
            )
            for _ in range(num_layers)
        ])

    def forward(self, x, padding_mask=None):
        """
        x: (batch, seq_len, d_model)
        padding_mask: (batch, seq_len) - True for padding tokens

        Returns a tensor of the same shape as ``x``.
        """
        seq_len = x.size(1)

        # Add sinusoidal positional encoding
        x = self.pos_enc(x)

        # Causal mask: True strictly above the diagonal, i.e. query i may not
        # attend to any key j > i.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()

        for layer in self.layers:
            x = layer(
                src=x,
                src_mask=causal_mask,
                src_key_padding_mask=padding_mask,
            )

        return x


class Discriminator(nn.Module):
    """Sequence discriminator: two causal conv blocks, a causal transformer
    layer, and a linear head producing one score per sequence.

    Args:
        in_channels: Channel count of the input features.
        hidden_dim: Width used by the conv blocks, the transformer and the
            projection head.
        kernel_size: Temporal kernel width of the second conv block (the
            first conv block is pointwise, kernel size 1).
        groups: Convolution groups of the first (pointwise) conv block.
    """

    def __init__(self, in_channels=256, hidden_dim=256, kernel_size=9, groups=1):
        super().__init__()

        # NOTE(review): these calls change process-wide scaled-dot-product
        # attention backend flags (math kernel only), not just this module's
        # behavior. Kept as-is because callers may rely on the side effect.
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)

        self.disc_layers = nn.ModuleList([
            Conv1dBlock(in_channels, hidden_dim, kernel_size=1, groups=groups),
            nn.GELU(),
            nn.Dropout(0.1),
            Conv1dBlock(hidden_dim, hidden_dim, kernel_size),
            nn.GELU(),
            nn.Dropout(0.1),
        ])

        self.decoder = CausalTransformer(d_model=hidden_dim, nhead=8, num_layers=1)
        self.proj = nn.Linear(hidden_dim, 1)

    def forward(self, x, padding_mask=None):
        """
        x: (batch, time, channels)
        padding_mask: optional (batch, time, 1) boolean tensor where True
            indicates a padded timestep. When omitted, no masking is applied.

        Returns a tensor of shape (batch,) with one score per sequence.
        """
        for layer in self.disc_layers:
            if isinstance(layer, Conv1dBlock):
                x = layer(x, padding_mask)  # Pass padding_mask only to Conv1dBlock
            else:
                x = layer(x)  # GELU & Dropout don't need padding_mask

        # The transformer expects a (batch, time) key-padding mask. Guard the
        # squeeze so the documented default of ``padding_mask=None`` works
        # instead of raising AttributeError.
        key_padding_mask = None if padding_mask is None else padding_mask.squeeze(-1)

        x = self.decoder(x, key_padding_mask)  # (batch, time, hidden_dim)

        # NOTE(review): every stage above is causal (left-padded convs and a
        # causal attention mask), so the features at time index 0 depend only
        # on the first input timestep. Confirm that pooling at index 0 is
        # intended; the last valid timestep or a masked mean over time would
        # let the score depend on the whole sequence.
        x = x[:, 0, :]  # (batch, hidden_dim)
        x = self.proj(x)  # (batch, hidden_dim) -> (batch, 1)
        x = x.squeeze(-1)
        return x
        
    
# Build a discriminator for 2048-channel inputs and print its structure.
discriminator = Discriminator(in_channels=2048)
print(discriminator)
# Report the total parameter count; the value is scaled to millions, so the
# message carries an explicit "M" suffix.
num_params = sum(p.numel() for p in discriminator.parameters())
print(f"Number of parameters: {num_params / 1e6:.2f}M")


Discriminator(
  (disc_layers): ModuleList(
    (0): Conv1dBlock(
      (layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (conv): Conv1d(2048, 256, kernel_size=(1,), stride=(1,))
    )
    (1): GELU(approximate='none')
    (2): Dropout(p=0.1, inplace=False)
    (3): Conv1dBlock(
      (layernorm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (conv): Conv1d(256, 256, kernel_size=(9,), stride=(1,))
    )
    (4): GELU(approximate='none')
    (5): Dropout(p=0.1, inplace=False)
  )
  (decoder): CausalTransformer(
    (pos_enc): SinusoidalPositionalEncoding()
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bia