# In [2]:  (Jupyter cell marker left over from notebook export)
import torch
import torch.nn as nn
import torch.nn.functional as F

# Swish activation
def swish(x):
    """Apply the Swish (a.k.a. SiLU) activation elementwise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate

# Feed Forward Module
class FeedForwardModule(nn.Module):
    """Pre-norm position-wise feed-forward sub-layer.

    Pipeline: LayerNorm -> Linear(dim -> dim * expansion) -> swish -> dropout
    -> Linear(dim * expansion -> dim) -> dropout.

    Only the transformed branch is returned; the input is not added back here.

    Args:
        dim: size of the last axis of the input (normalized by LayerNorm).
        expansion: width multiplier of the hidden layer.
        dropout: dropout probability applied after each linear stage.
    """

    def __init__(self, dim, expansion=4, dropout=0.1):
        super().__init__()
        hidden_dim = dim * expansion
        self.layernorm = nn.LayerNorm(dim)
        # Submodule creation order is kept stable so parameter init and
        # state_dict key order do not change.
        self.linear1 = nn.Linear(dim, hidden_dim)
        self.activation = swish
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_dim, dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        expanded = self.dropout1(self.activation(self.linear1(self.layernorm(x))))
        return self.dropout2(self.linear2(expanded))

# Multi-Head Self-Attention Module
class MultiHeadSelfAttentionModule(nn.Module):
    """Pre-norm multi-head self-attention sub-layer.

    Applies LayerNorm, then ``nn.MultiheadAttention`` with query = key = value,
    then dropout. Only the attention branch is returned; the input is not added
    back here.

    Args:
        dim: embedding size (last axis of the input); ``nn.MultiheadAttention``
            requires it to be divisible by ``heads``.
        heads: number of attention heads.
        dropout: dropout probability used inside attention and on its output.

    NOTE(review): no positional encoding is applied in this module.
    """

    def __init__(self, dim, heads=8, dropout=0.1):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim)
        # batch_first=True -> inputs/outputs are (batch, time, dim).
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        """Self-attend over ``x``.

        Args:
            x: tensor of shape (batch, time, dim).
            key_padding_mask: optional (batch, time) bool tensor in which True
                marks padded frames that must not be attended to. ``None``
                (the default) attends to every frame, as before.

        Returns:
            Tensor of shape (batch, time, dim).
        """
        x_norm = self.layernorm(x)
        # The attention weights were always discarded, so don't request them:
        # this skips building the head-averaged weight map and lets PyTorch
        # use its fused attention path.
        attn_out, _ = self.attn(
            x_norm,
            x_norm,
            x_norm,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.dropout(attn_out)

# Convolution Module
class ConvolutionModule(nn.Module):
    """Conformer convolution sub-layer.

    Pipeline: LayerNorm -> pointwise Conv1d (dim -> 2*dim) -> GLU -> depthwise
    Conv1d -> BatchNorm1d -> swish -> pointwise Conv1d -> dropout.

    Input and output are (batch, time, dim). The output has exactly the same
    time length as the input for both odd and even ``kernel_size``, so it can
    be added back to the input as a residual.

    Args:
        dim: channel size (last axis of the input).
        kernel_size: depthwise convolution kernel length.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, dim, kernel_size=32, dropout=0.1):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim)
        # Doubles the channels so the GLU (which halves dim=1) returns ``dim``.
        self.pointwise_conv1 = nn.Conv1d(dim, 2 * dim, kernel_size=1)
        self.glu = nn.GLU(dim=1)
        # groups=dim -> depthwise. With padding=kernel_size // 2 an odd kernel
        # preserves the length, but an even kernel yields one extra frame
        # (T + 2*(k//2) - k + 1 = T + 1); forward() trims it off.
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
        self.batchnorm = nn.BatchNorm1d(dim)
        self.activation = swish
        self.pointwise_conv2 = nn.Conv1d(dim, dim, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        seq_len = x.size(1)
        x = self.layernorm(x)
        x = x.transpose(1, 2)  # (B, T, D) -> (B, D, T) for Conv1d
        x = self.pointwise_conv1(x)
        x = self.glu(x)
        x = self.depthwise_conv(x)
        # Drop the surplus trailing frame produced by even kernel sizes; a
        # no-op for odd kernels. Without this, the residual add fails.
        x = x[..., :seq_len]
        x = self.batchnorm(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)
        x = self.dropout(x)
        return x.transpose(1, 2)

# Conformer Block
class ConformerBlock(nn.Module):
    """One Conformer block.

    Computes, in order, with residual connections:
    half-step feed-forward, self-attention, convolution, second half-step
    feed-forward, followed by a final LayerNorm.

    Args:
        dim: model width.
        heads: attention heads.
        ff_expansion: hidden-width multiplier of both feed-forward modules.
        conv_kernel: depthwise convolution kernel length.
        dropout: dropout probability shared by all sub-layers.
    """

    def __init__(self, dim, heads=8, ff_expansion=4, conv_kernel=32, dropout=0.1):
        super().__init__()
        # Creation order is preserved so parameter init and state_dict order match.
        self.ffn1 = FeedForwardModule(dim, ff_expansion, dropout)
        self.mhsa = MultiHeadSelfAttentionModule(dim, heads, dropout)
        self.conv = ConvolutionModule(dim, conv_kernel, dropout)
        self.ffn2 = FeedForwardModule(dim, ff_expansion, dropout)
        self.layernorm = nn.LayerNorm(dim)

    def forward(self, x):
        # The two feed-forward residuals are scaled by 0.5 and sandwich the
        # attention and convolution residuals.
        x = x + self.ffn1(x) * 0.5
        x = x + self.mhsa(x)
        x = x + self.conv(x)
        x = x + self.ffn2(x) * 0.5
        out = self.layernorm(x)
        return out

# Subsampling Module (Conv2D based)
class ConvSubsampling(nn.Module):
    """Two stride-2 Conv2d layers (about 4x reduction of time and frequency) plus a projection.

    Args:
        in_channels: input channels of the first Conv2d. A 3-D input
            (batch, time, freq) is treated as one channel, so this must be 1
            for that case. A 4-D input (batch, channels, time, freq) is used
            as-is and its channel count must equal ``in_channels``.
        out_dim: number of conv channels and size of the output features.
        input_dim: number of frequency bins (last input axis). Sizes the
            projection layer. Defaults to 80, the value previously hard-coded.

    Output shape: (batch, time', out_dim), where each conv maps a length n to
    (n - 1) // 2 + 1.
    """

    def __init__(self, in_channels=1, out_dim=256, input_dim=80):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        # Conv2d(kernel=3, stride=2, padding=1) maps n -> (n + 2 - 3) // 2 + 1
        # = (n - 1) // 2 + 1 (i.e. ceil(n / 2)). The old ``80 // 4`` only
        # matched when the bin count was a multiple of 4.
        freq = input_dim
        for _ in range(2):
            freq = (freq - 1) // 2 + 1
        self.linear = nn.Linear(out_dim * freq, out_dim)

    def forward(self, x):
        if x.dim() == 3:
            x = x.unsqueeze(1)  # (B, T, F) -> (B, 1, T, F)
        x = self.conv(x)  # (B, C, T', F')
        b, c, t, f = x.size()
        # Fold channels and frequency into one feature axis per time step.
        x = x.transpose(1, 2).contiguous().view(b, t, c * f)
        return self.linear(x)

# Conformer Encoder
class ConformerEncoder(nn.Module):
    """Conv2d subsampling front-end followed by a stack of Conformer blocks.

    Args:
        input_dim: number of feature bins on the last input axis. Used to size
            the projection after subsampling. Previously this argument was
            ignored, and any value other than 80 failed in forward.
        num_layers: number of Conformer blocks.
        model_dim: model width.
        heads: attention heads per block.
        ff_expansion: feed-forward hidden-width multiplier.
        conv_kernel: depthwise convolution kernel length.
        dropout: dropout probability.
    """

    def __init__(self, input_dim=80, num_layers=4, model_dim=256, heads=8, ff_expansion=4, conv_kernel=32, dropout=0.1):
        super().__init__()
        self.subsampling = ConvSubsampling(out_dim=model_dim)
        self._fit_subsampling_projection(input_dim, model_dim)
        self.blocks = nn.ModuleList([
            ConformerBlock(model_dim, heads, ff_expansion, conv_kernel, dropout) for _ in range(num_layers)
        ])

    def _fit_subsampling_projection(self, input_dim, model_dim):
        """Ensure the subsampling projection accepts ``input_dim`` feature bins."""
        # Probe the conv stack with one dummy frame to get the real flattened
        # (channels * freq) size it produces for this input_dim.
        with torch.no_grad():
            probe = self.subsampling.conv(torch.zeros(1, 1, 1, input_dim))
        _, channels, _, freq = probe.shape
        in_features = channels * freq
        # Rebuild only on mismatch, so the default path (and its RNG draws)
        # is left untouched.
        if self.subsampling.linear.in_features != in_features:
            self.subsampling.linear = nn.Linear(in_features, model_dim)

    def forward(self, x):
        """Encode features ``x`` (presumably (batch, time, input_dim)) into (batch, time', model_dim)."""
        x = self.subsampling(x)
        for block in self.blocks:
            x = block(x)
        return x

# Decoder
class LSTMDecoder(nn.Module):
    """Single-layer LSTM followed by a linear projection to ``vocab_size`` logits.

    Input is (batch, time, input_dim) because the LSTM uses batch_first=True.
    Output is (batch, time, vocab_size).
    """

    def __init__(self, input_dim=256, hidden_dim=512, vocab_size=1000):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=1, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x):
        hidden_seq = self.lstm(x)[0]  # discard the final (h_n, c_n) state
        return self.fc(hidden_seq)

# Full Conformer Transducer Model
class ConformerTransducer(nn.Module):
    """Conformer encoder followed by an LSTM + linear head emitting per-frame vocabulary logits.

    The decoder's hidden size is twice ``model_dim``.

    NOTE(review): despite the name, there is no label-conditioned prediction
    network or joint network here. The LSTM consumes encoder frames only, so
    this is not an RNN-Transducer as written.
    """

    def __init__(self, input_dim=80, encoder_layers=4, model_dim=256, heads=8, vocab_size=1000):
        super().__init__()
        self.encoder = ConformerEncoder(
            input_dim=input_dim,
            num_layers=encoder_layers,
            model_dim=model_dim,
            heads=heads,
        )
        decoder_hidden = 2 * model_dim
        self.decoder = LSTMDecoder(input_dim=model_dim, hidden_dim=decoder_hidden, vocab_size=vocab_size)

    def forward(self, x):
        return self.decoder(self.encoder(x))
