# Parameter Count and FLOPs Analysis: DNAFormer vs Compact BiGRU Model

This notebook provides a comprehensive module-by-module analysis of:
1. **DNAFormer** (exact architecture from the paper and GitHub code)
2. **Compact BiGRU Model** (your implementation)

Both models are analyzed on each of the following datasets: Srinivasavaradhan, Grass, Erlich, BinnedNanoporeTwoFlowcells, and BinnedTestIllumina.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from collections import OrderedDict
from typing import Dict, Tuple, List

# Show which PyTorch version this kernel is running.
print("PyTorch version: {}".format(torch.__version__))

PyTorch version: 2.5.1


## Dataset Configurations

Each dataset has a different sequence length, which affects both the parameter counts and the FLOPs.

In [2]:
# Per-dataset sequence settings (from the DNAFormer paper and your code).
# Each entry holds: label_len, max_deviation, index_length, description.
DATASET_CONFIGS = dict(
    Srinivasavaradhan=dict(
        label_len=110,
        max_deviation=10,
        index_length=0,  # No index filtering for public datasets
        description="Srinivasavaradhan et al. - Twist, MinION (high error)",
    ),
    Grass=dict(
        label_len=117,
        max_deviation=11,
        index_length=0,
        description="Grass et al. (2015) - CustomArray, Illumina miSeq",
    ),
    Erlich=dict(
        label_len=152,
        max_deviation=10,
        index_length=0,
        description="Erlich et al. (2017) - DNA Fountain, Illumina miSeq",
    ),
    BinnedNanoporeTwoFlowcells=dict(
        label_len=128,
        max_deviation=4,
        index_length=12,  # DNAFormer filters index
        description="DNAformer Nanopore Two Flowcells Combined",
    ),
    BinnedTestIllumina=dict(
        label_len=128,
        max_deviation=4,
        index_length=12,  # DNAFormer filters index
        description="DNAformer Test Illumina - Twist + miSeq",
    ),
)

# Human-readable summary: a header, a rule, then one paragraph per dataset
# (the trailing empty string yields the blank separator line).
print("Dataset Configurations:", "=" * 80, sep="\n")
for name, cfg in DATASET_CONFIGS.items():
    print(
        f"{name}:",
        f"  Label length: {cfg['label_len']}",
        f"  Max deviation: {cfg['max_deviation']}",
        f"  Index length: {cfg['index_length']}",
        "",
        sep="\n",
    )

Dataset Configurations:
Srinivasavaradhan:
  Label length: 110
  Max deviation: 10
  Index length: 0

Grass:
  Label length: 117
  Max deviation: 11
  Index length: 0

Erlich:
  Label length: 152
  Max deviation: 10
  Index length: 0

BinnedNanoporeTwoFlowcells:
  Label length: 128
  Max deviation: 4
  Index length: 12

BinnedTestIllumina:
  Label length: 128
  Max deviation: 4
  Index length: 12



---
# PART 1: DNAFormer Architecture (Exact from Paper/Code)

## DNAFormer Hyperparameters (from supplementary material):
- `n_head = 32`
- `num_layers = 12`
- `d_model = 1024`
- `alignment_filters = 128`
- `dim_feedforward = 2048`
- `output_ch = 4`
- `enc_filters = 4` (one-hot DNA encoding)
- Kernel sizes: {1, 3, 5, 7}
- Siamese architecture (2 branches, shared weights)

In [3]:
# ==============================================================================
# DNAFormer: Depthwise Separable Conv1D (EXACT FROM DNAFormer CODE)
# ==============================================================================

class depthwise_separable_conv_1d(nn.Module):
    """
    Depthwise-separable 1D convolution: a grouped (one group per input channel)
    convolution followed by a 1x1 convolution that mixes channels.

    Args:
        in_ch: number of input channels; also the group count of the depthwise conv.
        out_ch: number of channels produced by the pointwise conv.
        kernels_per_layer: depthwise filters applied to each input channel.
        kernel_size, stride, padding: forwarded to the depthwise conv only.

    Parameter count:
    - depthwise: in_ch * kernels_per_layer * kernel_size (weights) + in_ch * kernels_per_layer (bias)
    - pointwise: in_ch * kernels_per_layer * out_ch (weights) + out_ch (bias)
    """
    def __init__(self, in_ch, out_ch, kernels_per_layer=1, kernel_size=3, stride=1, padding=0):
        super().__init__()
        expanded_ch = in_ch * kernels_per_layer
        self.depthwise = nn.Conv1d(
            in_channels=in_ch,
            out_channels=expanded_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_ch,  # one group per input channel => depthwise
        )
        self.pointwise = nn.Conv1d(in_channels=expanded_ch, out_channels=out_ch, kernel_size=1)

    def forward(self, x):
        """Apply the depthwise conv, then the pointwise conv."""
        return self.pointwise(self.depthwise(x))

In [4]:
# ==============================================================================
# DNAFormer: Double Conv1D Block (EXACT FROM DNAFormer CODE)
# ==============================================================================

class double_conv1D(nn.Module):
    """
    Two identical stages of: depthwise-separable conv -> LayerNorm -> GELU -> Dropout.

    The first stage maps in_ch -> out_ch channels, the second out_ch -> out_ch.
    LayerNorm is built with normalized_shape=seq_len, so it normalizes over the
    last (sequence) axis; the convolutions must therefore keep that axis at
    length seq_len (stride/padding chosen accordingly by the caller).

    Args:
        in_ch: input channels of the first stage.
        out_ch: output channels of both stages.
        seq_len: length of the last axis, used as the LayerNorm shape.
        padding, kernel_size, stride: forwarded to both separable convolutions.
        p_dropout: dropout probability after each activation.
    """
    def __init__(self, in_ch, out_ch, seq_len, padding=0, kernel_size=3, stride=1, p_dropout=0):
        super().__init__()
        conv_kwargs = dict(kernels_per_layer=1, kernel_size=kernel_size,
                           stride=stride, padding=padding)

        # Built in the order conv, norm, act, dropout (x2) so the Sequential
        # indices 0..7 match the layer order above.
        stages = []
        for stage_in_ch in (in_ch, out_ch):
            stages.extend([
                depthwise_separable_conv_1d(stage_in_ch, out_ch, **conv_kwargs),
                nn.LayerNorm(seq_len, elementwise_affine=True),
                nn.GELU(),
                nn.Dropout(p_dropout),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv stages on x."""
        return self.conv(x)

In [5]:
# ==============================================================================
# DNAFormer: Linear Block (EXACT FROM DNAFormer CODE)
# ==============================================================================

class linear_block(nn.Module):
    """
    Three-layer MLP applied along the last axis, changing its length from
    input_len to output_len.

    Structure: Linear -> LayerNorm -> GELU -> Dropout
            -> Linear -> LayerNorm -> GELU -> Dropout
            -> Linear

    Only the first Linear changes the length (input_len -> output_len); the
    other two are output_len -> output_len.

    Args:
        input_len: size of the last axis of the input.
        output_len: size of the last axis of the output.
        p_dropout: dropout probability after each of the two activations.
    """
    def __init__(self, input_len, output_len, p_dropout=0):
        super().__init__()
        # Stage 1: resize the last axis.
        self.fc_1 = nn.Linear(input_len, output_len)
        self.norm_1 = nn.LayerNorm(output_len, elementwise_affine=True)
        self.act_1 = nn.GELU()
        self.dout_1 = nn.Dropout(p_dropout)
        # Stage 2: same-width hidden layer.
        self.fc_2 = nn.Linear(output_len, output_len)
        self.norm_2 = nn.LayerNorm(output_len, elementwise_affine=True)
        self.act_2 = nn.GELU()
        self.dout_2 = nn.Dropout(p_dropout)
        # Final projection, no norm/activation afterwards.
        self.fc_3 = nn.Linear(output_len, output_len)

    def forward(self, x):
        """Apply the three stages in order and return the projected tensor."""
        hidden = self.dout_1(self.act_1(self.norm_1(self.fc_1(x))))
        hidden = self.dout_2(self.act_2(self.norm_2(self.fc_2(hidden))))
        return self.fc_3(hidden)

In [6]:
# ==============================================================================
# DNAFormer: Alignment Module (EXACT FROM DNAFormer CODE)
# ==============================================================================

class alignement_module(nn.Module):
    """
    DNAFormer alignment module: per-read feature extraction.

    - Four parallel double_conv1D branches with kernel sizes {1, 3, 5, 7};
      padding is kernel_size // 2, so every branch keeps the sequence length.
    - Each branch outputs alignment_filters // 4 channels; the branches are
      concatenated on the channel axis and passed through a linear_block that
      keeps the length at noisy_copies_length.
    - Every read of every cluster is processed independently: the batch and
      cluster axes are merged before the convolutions and split afterwards.

    Args:
        enc_filters: channels of each encoded read.
        alignment_filters: total output channels across the four branches.
        noisy_copies_length: length of each read along the sequence axis.
        p_dropout: dropout probability used inside the sub-blocks.
    """
    def __init__(self, enc_filters, alignment_filters, noisy_copies_length, p_dropout=0):
        super().__init__()

        branch_ch = alignment_filters // 4

        # Registers conv_block_1 .. conv_block_4, in that order.
        for idx, kernel in enumerate((1, 3, 5, 7), start=1):
            branch = double_conv1D(enc_filters, branch_ch, seq_len=noisy_copies_length,
                                   kernel_size=kernel, padding=kernel // 2,
                                   p_dropout=p_dropout)
            setattr(self, f"conv_block_{idx}", branch)

        self.linear_block = linear_block(input_len=noisy_copies_length,
                                         output_len=noisy_copies_length,
                                         p_dropout=p_dropout)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch, cluster, emb, seq).

        Returns:
            Tensor of shape (batch, cluster, channels, seq), where channels is
            the concatenated branch width.
        """
        batch, cluster, emb, seq = x.shape

        # Fold the cluster axis into the batch axis: one row per read.
        reads = x.view(batch * cluster, emb, seq)

        branches = (self.conv_block_1, self.conv_block_2,
                    self.conv_block_3, self.conv_block_4)
        features = torch.cat([branch(reads) for branch in branches], dim=1)
        features = self.linear_block(features)

        # Restore the (batch, cluster) layout.
        return features.view(batch, cluster, -1, seq)

In [7]:
# ==============================================================================
# DNAFormer: Embedding Module (EXACT FROM DNAFormer CODE)
# ==============================================================================

class embedding_module(nn.Module):
    """
    DNAFormer embedding module: cluster-level feature extraction.

    - forward() first sums over the cluster axis (Non-Coherent Integration),
      so this module owns the NCI step.
    - Four parallel double_conv1D branches with kernel sizes {1, 3, 5, 7},
      each producing d_model // 4 channels, are concatenated on the channel axis.
    - A linear_block then maps the sequence axis from noisy_copies_length
      to label_length.

    Args:
        alignment_filters: input channels (output width of the alignment module).
        d_model: total output channels across the four branches.
        noisy_copies_length: input sequence length.
        label_length: output sequence length.
        p_dropout: dropout probability used inside the sub-blocks.
    """
    def __init__(self, alignment_filters, d_model, noisy_copies_length, label_length, p_dropout=0):
        super().__init__()

        self.label_length = label_length
        branch_ch = d_model // 4

        # Registers conv_block_1 .. conv_block_4, in that order; padding of
        # kernel // 2 keeps the sequence length unchanged in every branch.
        for idx, kernel in enumerate((1, 3, 5, 7), start=1):
            branch = double_conv1D(alignment_filters, branch_ch, seq_len=noisy_copies_length,
                                   kernel_size=kernel, padding=kernel // 2,
                                   p_dropout=p_dropout)
            setattr(self, f"conv_block_{idx}", branch)

        # Resizes the sequence axis: noisy_copies_length -> label_length.
        self.linear_block = linear_block(input_len=noisy_copies_length,
                                         output_len=label_length,
                                         p_dropout=p_dropout)

    def forward(self, x):
        """
        Args:
            x: per-read features with the cluster on axis 1.

        Returns:
            Cluster-level features whose last axis has length label_length.
        """
        # NCI: collapse the cluster axis by summation.
        pooled = x.sum(dim=1)

        branches = (self.conv_block_1, self.conv_block_2,
                    self.conv_block_3, self.conv_block_4)
        features = torch.cat([branch(pooled) for branch in branches], dim=1)
        return self.linear_block(features)

In [8]:
# ==============================================================================
# DNAFormer: Output Module (EXACT FROM DNAFormer CODE)
# ==============================================================================

class output_module(nn.Module):
    """
    DNAFormer output head: three stacked 1x1 convolutions with no
    non-linearity in between.

    Channel flow: d_model -> d_model -> d_model -> output_ch.

    Args:
        d_model: channels of the incoming feature map.
        output_ch: channels of the prediction.
    """
    def __init__(self, d_model, output_ch):
        super().__init__()

        self.conv_1 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.conv_2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.conv_3 = nn.Conv1d(d_model, output_ch, kernel_size=1)

    def forward(self, x):
        """Apply conv_1, conv_2 and conv_3 in sequence."""
        return self.conv_3(self.conv_2(self.conv_1(x)))

In [9]:
# ==============================================================================
# DNAFormer: Fusion Module (EXACT FROM DNAFormer CODE)
# ==============================================================================

class fusion_module(nn.Module):
    """
    DNAFormer fusion module: merges the two halves of the batch into one
    prediction.

    - Two learnable weight vectors of length label_length (initialised to 1)
      scale the left and right predictions position by position.
    - Three 1x1 convolutions (output_ch -> output_ch) refine the fused result.

    Args:
        output_ch: channels of the incoming predictions.
        label_length: length of the sequence axis.
    """
    def __init__(self, output_ch, label_length):
        super().__init__()

        self.pred_fusion_left = nn.Parameter(torch.ones(label_length))
        self.pred_fusion_right = nn.Parameter(torch.ones(label_length))

        self.conv_1 = nn.Conv1d(output_ch, output_ch, kernel_size=1)
        self.conv_2 = nn.Conv1d(output_ch, output_ch, kernel_size=1)
        self.conv_3 = nn.Conv1d(output_ch, output_ch, kernel_size=1)

    def forward(self, x):
        """
        Args:
            x: 3-D tensor whose first axis stacks the left predictions (first
               half) followed by the right predictions (second half).

        Returns:
            (fused, x_left, x_right): the refined fused prediction, the left
            half, and the right half reversed along the last axis.
        """
        half = x.shape[0] // 2
        x_left = x[:half]
        # The right half is reversed along the sequence axis before fusing.
        x_right = torch.flip(x[half:], dims=(-1,))

        fused = (x_left * self.pred_fusion_left + x_right * self.pred_fusion_right) / 2
        for conv in (self.conv_1, self.conv_2, self.conv_3):
            fused = conv(fused)

        return fused, x_left, x_right

In [10]:
# ==============================================================================
# DNAFormer: Complete Model (EXACT FROM DNAFormer CODE)
# ==============================================================================

class DNAFormer(nn.Module):
    """
    Complete DNAFormer model as assembled in this notebook.

    Pipeline (in forward order):
    1. Alignment module  - per-read multi-kernel convs + linear block
    2. Embedding module  - NCI (sum over the cluster axis, no parameters),
                           then cluster-level multi-kernel convs + linear block
    3. Transformer encoder - num_layers layers, n_head heads, GELU feed-forward
    4. Output module     - three 1x1 convolutions down to output_ch
    5. Fusion module     - learnable position weights + three 1x1 convolutions

    The two Siamese branches share all weights: the fusion module treats the
    first half of the batch as the left branch and the second half (reversed
    along the sequence axis) as the right branch.
    NOTE(review): this implies callers stack both branches in one batch with an
    even batch size - confirm against the data pipeline.

    Hyperparameters quoted in this notebook for DNAFormer:
    n_head=32, num_layers=12, d_model=1024, alignment_filters=128,
    dim_feedforward=2048, output_ch=4, enc_filters=4.
    """
    def __init__(self, enc_filters, alignment_filters, d_model, n_head, 
                 num_layers, dim_feedforward, output_ch, noisy_copies_length, 
                 label_length, p_dropout=0):
        super().__init__()

        # Sequence lengths kept for reference by callers.
        self.noisy_copies_length = noisy_copies_length
        self.label_length = label_length

        # Submodules are created in pipeline order.
        self.alignement = alignement_module(
            enc_filters=enc_filters,
            alignment_filters=alignment_filters,
            noisy_copies_length=noisy_copies_length,
            p_dropout=p_dropout,
        )

        self.embedding = embedding_module(
            alignment_filters=alignment_filters,
            d_model=d_model,
            noisy_copies_length=noisy_copies_length,
            label_length=label_length,
            p_dropout=p_dropout,
        )

        # batch_first=False: the encoder consumes (seq, batch, d_model).
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model,
                n_head,
                dim_feedforward=dim_feedforward,
                dropout=p_dropout,
                activation='gelu',
                batch_first=False,
            ),
            num_layers=num_layers,
        )

        self.output_module = output_module(d_model, output_ch)
        self.fusion = fusion_module(output_ch, label_length)

    def forward(self, x):
        """
        Args:
            x: encoded reads of shape (batch, cluster, enc_filters, noisy_copies_length).

        Returns:
            dict with keys 'pred' (fused prediction), 'pred_left' and
            'pred_right' (the two branch predictions returned by the fusion module).
        """
        per_read = self.alignement(x)
        embedded = self.embedding(per_read)  # NCI sum happens inside

        # (batch, d_model, seq) -> (seq, batch, d_model) for the encoder, then back.
        tokens = embedded.permute(2, 0, 1)
        encoded = self.encoder(tokens).permute(1, 2, 0)

        logits = self.output_module(encoded)
        fused, left, right = self.fusion(logits)

        return {'pred': fused, 'pred_left': left, 'pred_right': right}

---
# PART 2: Your Compact BiGRU Model (Exact from working-BinnedNanoporeTwoFlowcells.py)

## Compact Model Hyperparameters (from your code):
- `embed_dim = 300`
- `alignment_filters = 128`
- `embedding_filters = 500`
- `gru_hidden = 300`
- `gru_layers = 2`
- Kernel sizes: {1, 3, 5} (only 3 kernels vs DNAFormer's 4)
- **No Siamese architecture** (single branch)
- **No Fusion module**
- **BiGRU instead of Transformer**

In [11]:
# ==============================================================================
# Compact Model: Depthwise Separable Conv1d (FROM YOUR CODE)
# ==============================================================================

class DepthwiseSeparableConv1d(nn.Module):
    """
    Depthwise-separable 1D convolution.

    A per-channel convolution (groups == in_channels) of the given kernel size
    is followed by a 1x1 convolution that maps in_channels -> out_channels.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        # One filter per input channel; channel count is unchanged here.
        self.depthwise = nn.Conv1d(in_channels, in_channels, kernel_size,
                                   padding=padding, groups=in_channels)
        # 1x1 conv mixes the channels and sets the output width.
        self.pointwise = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x):
        """Apply the depthwise conv followed by the pointwise conv."""
        return self.pointwise(self.depthwise(x))

In [12]:
# ==============================================================================
# Compact Model: Multi-Kernel Conv Block (FROM YOUR CODE)
# ==============================================================================

class MultiKernelConvBlock(nn.Module):
    """
    Three parallel depthwise-separable convolutions with kernel sizes
    {1, 3, 5}, each followed by LayerNorm and GELU, concatenated on the
    channel axis and passed through dropout.

    out_channels is split into three groups; the first two get
    out_channels // 3 channels and the last one absorbs the remainder, so the
    concatenated width is exactly out_channels. Each LayerNorm normalizes over
    both the channel and sequence axes of its branch ([channels, seq_len]).
    """

    def __init__(self, in_channels, out_channels, seq_len, dropout=0.1):
        super().__init__()

        base = out_channels // 3
        widths = (base, base, out_channels - 2 * base)  # last branch takes the remainder

        # Padding of kernel_size // 2 keeps the sequence length in every branch.
        self.conv1 = DepthwiseSeparableConv1d(in_channels, widths[0], kernel_size=1)
        self.conv3 = DepthwiseSeparableConv1d(in_channels, widths[1], kernel_size=3, padding=1)
        self.conv5 = DepthwiseSeparableConv1d(in_channels, widths[2], kernel_size=5, padding=2)

        self.norm1 = nn.LayerNorm([widths[0], seq_len])
        self.norm2 = nn.LayerNorm([widths[1], seq_len])
        self.norm3 = nn.LayerNorm([widths[2], seq_len])

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the three branches on x and return their concatenation after dropout."""
        branches = (
            (self.conv1, self.norm1),
            (self.conv3, self.norm2),
            (self.conv5, self.norm3),
        )
        activated = [F.gelu(norm(conv(x))) for conv, norm in branches]
        return self.dropout(torch.cat(activated, dim=1))

In [13]:
# ==============================================================================
# Compact Model: Alignment Module (FROM YOUR CODE)
# ==============================================================================

class AlignmentModule(nn.Module):
    """
    Per-read alignment stage of the compact model.

    Two MultiKernelConvBlocks are applied to every read independently
    (embed_dim -> out_channels -> out_channels) before the reads of a cluster
    are combined. The batch and cluster axes are merged for the convolutions
    and restored afterwards.
    """

    def __init__(self, embed_dim, out_channels, seq_len, dropout=0.1):
        super().__init__()
        self.conv_block1 = MultiKernelConvBlock(embed_dim, out_channels, seq_len, dropout)
        self.conv_block2 = MultiKernelConvBlock(out_channels, out_channels, seq_len, dropout)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch, cluster_size, embed_dim, seq_len).

        Returns:
            Tensor of shape (batch, cluster_size, channels, seq_len).
        """
        batch, cluster, emb, seq = x.shape

        # One row per read so the conv blocks see each read on its own.
        reads = x.view(batch * cluster, emb, seq)
        features = self.conv_block2(self.conv_block1(reads))

        # Split the merged axis back into (batch, cluster).
        return features.view(batch, cluster, -1, seq)

In [14]:
# ==============================================================================
# Compact Model: Embedding Module (FROM YOUR CODE)
# ==============================================================================

class EmbeddingModule(nn.Module):
    """
    Cluster-level stage of the compact model, applied after NCI.

    One MultiKernelConvBlock (in_channels -> out_channels) is followed by a
    single Linear layer that resizes the sequence axis from in_len to out_len.
    """

    def __init__(self, in_channels, out_channels, in_len, out_len, dropout=0.1):
        super().__init__()
        self.conv_block = MultiKernelConvBlock(in_channels, out_channels, in_len, dropout)

        # Acts on the sequence axis: in_len -> out_len.
        self.linear = nn.Linear(in_len, out_len)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch, channels, seq_len).

        Returns:
            Tensor of shape (batch, out_channels, out_len).
        """
        features = self.conv_block(x)  # (B, out_channels, in_len)
        batch, channels, seq_len = features.shape

        # Flatten to one row per (sample, channel), project, then unflatten.
        rows = features.reshape(batch * channels, seq_len)   # (B*C, in_len)
        projected = self.linear(rows)                        # (B*C, out_len)
        return projected.reshape(batch, channels, -1)        # (B, C, out_len)

In [15]:
# ==============================================================================
# Compact Model: Complete Model (FROM YOUR CODE)
# ==============================================================================

class ImprovedDNAReconstructionModel(nn.Module):
    """
    Compact DNA reconstruction model with a DNAFormer-style front end and a
    BiGRU sequence model.

    Pipeline (in forward order):
    1. nn.Embedding lookup for every base of every read
    2. AlignmentModule - per-read multi-kernel convolutions
    3. NCI (Non-Coherent Integration) - sum over the cluster axis, no parameters
    4. EmbeddingModule - cluster-level convolutions + length projection
    5. Bidirectional GRU
    6. Dropout + a single Linear layer producing per-position logits

    Differences from the DNAFormer class in this notebook:
    - single branch, no fusion module
    - 3 kernel sizes {1, 3, 5} instead of 4 {1, 3, 5, 7}
    - BiGRU instead of a Transformer encoder
    - lighter alignment and embedding modules

    NOTE(review): an earlier version of this docstring quoted "~5-8M parameters
    (vs DNAFormer's ~100M)"; verify with count_parameters() before citing.
    """

    def __init__(self, vocab_size, label_seq_len, max_read_len, padding_idx,
                 embed_dim=300, alignment_filters=128, embedding_filters=500,
                 gru_hidden=300, gru_layers=2, dropout=0.1):
        super().__init__()

        self.label_seq_len = label_seq_len
        self.max_read_len = max_read_len

        # 1. Token embedding.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)

        # 2. Per-read alignment.
        self.alignment = AlignmentModule(embed_dim=embed_dim,
                                         out_channels=alignment_filters,
                                         seq_len=max_read_len,
                                         dropout=dropout)

        # 3. NCI is a plain sum in forward() - nothing to construct.

        # 4. Cluster-level features, resized from max_read_len to label_seq_len.
        self.embedding_module = EmbeddingModule(in_channels=alignment_filters,
                                                out_channels=embedding_filters,
                                                in_len=max_read_len,
                                                out_len=label_seq_len,
                                                dropout=dropout)

        # 5. BiGRU; dropout is passed only when there is more than one layer.
        inter_layer_dropout = dropout if gru_layers > 1 else 0
        self.gru = nn.GRU(input_size=embedding_filters,
                          hidden_size=gru_hidden,
                          num_layers=gru_layers,
                          batch_first=True,
                          bidirectional=True,
                          dropout=inter_layer_dropout)

        self.dropout = nn.Dropout(dropout)

        # 6. Both GRU directions are concatenated, hence 2 * gru_hidden inputs.
        self.fc_out = nn.Linear(2 * gru_hidden, vocab_size)

    def forward(self, cluster_batch):
        """
        Args:
            cluster_batch: (batch_size, max_cluster_size, max_read_len)

        Returns:
            logits: (batch_size, label_seq_len, vocab_size)
        """
        # (B, cluster, read_len) -> (B, cluster, read_len, embed_dim)
        tokens = self.embedding(cluster_batch)
        # Channels-first per read: (B, cluster, embed_dim, read_len)
        tokens = tokens.transpose(2, 3)

        # (B, cluster, alignment_filters, read_len)
        aligned = self.alignment(tokens)

        # NCI: (B, alignment_filters, read_len)
        pooled = aligned.sum(dim=1)

        # (B, embedding_filters, label_seq_len)
        cluster_features = self.embedding_module(pooled)

        # The GRU is batch_first: (B, label_seq_len, embedding_filters)
        gru_in = cluster_features.transpose(1, 2)
        gru_out, _ = self.gru(gru_in)  # (B, label_seq_len, gru_hidden*2)

        # (B, label_seq_len, vocab_size)
        return self.fc_out(self.dropout(gru_out))

---
# PART 3: Parameter Counting Functions

In [16]:
def count_parameters(model):
    """Return the number of scalar elements in all trainable parameters of model.

    Parameters with requires_grad=False are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def count_parameters_by_module(model):
    """Return trainable-parameter counts per direct child module of model.

    The result is an OrderedDict mapping each name from model.named_children()
    to the trainable parameter count of that child, followed by a final
    'TOTAL' entry holding the count for the whole model.
    """
    per_module = OrderedDict(
        (child_name, sum(p.numel() for p in child.parameters() if p.requires_grad))
        for child_name, child in model.named_children()
    )
    per_module['TOTAL'] = count_parameters(model)
    return per_module


def format_params(n):
    """Format a parameter count with a K / M / B suffix and two decimals.

    Args:
        n: non-negative count (int or float).

    Returns:
        str: e.g. "1.50K", "2.50M", "3.00B"; values below 1000 are returned
        via str(n) without a suffix.

    A value that would round up to "1000.00" of its unit is shown in the next
    larger unit instead (999_999 -> "1.00M", not "1000.00K"). There is no unit
    above "B", so very large values keep the "B" suffix.
    """
    units = ((1e9, "B"), (1e6, "M"), (1e3, "K"))  # largest first

    for idx, (scale, suffix) in enumerate(units):
        if n >= scale:
            mantissa = f"{n / scale:.2f}"
            # Two-decimal rounding can push the mantissa to 1000.00; promote
            # to the next larger unit when one exists.
            if idx > 0 and float(mantissa) >= 1000:
                bigger_scale, bigger_suffix = units[idx - 1]
                return f"{n / bigger_scale:.2f}{bigger_suffix}"
            return f"{mantissa}{suffix}"

    return str(n)

---
# PART 4: FLOPs Counting Functions

In [17]:
def count_dnaformer_flops(enc_filters, alignment_filters, d_model, n_head,
                          num_layers, dim_feedforward, output_ch,
                          noisy_copies_length, label_length, 
                          batch_size=1, max_cluster_size=16):
    """
    Estimate forward-pass FLOPs for DNAFormer, broken down by module.

    Convention: FLOPs = 2 * MACs (multiply-accumulate operations). Element-wise
    layers use fixed per-element costs: 5 for LayerNorm and softmax, 8 for GELU.
    Bias additions are not counted.

    The alignment module is charged once per read
    (batch_size * max_cluster_size); every later module once per batch item.

    NOTE(review): the fusion module splits its input batch into two halves, so
    confirm whether batch_size here is meant to include both Siamese branches.

    Returns:
        OrderedDict mapping component names to FLOP counts, including one
        '*_Module_Total' / 'Transformer_Total' entry per module and a final
        'GRAND_TOTAL' that sums those five totals.
    """
    NORM_OPS = 5       # LayerNorm, per element
    SOFTMAX_OPS = 5    # softmax, per element
    GELU_OPS = 8       # GELU, per element
    KERNEL_SIZES = (1, 3, 5, 7)

    flops = OrderedDict()
    num_reads = batch_size * max_cluster_size

    def separable_conv(in_ch, out_ch, kernel_size, length):
        """Depthwise (in_ch * k * L MACs) plus pointwise (in_ch * out_ch * L MACs)."""
        macs = in_ch * kernel_size * length + in_ch * out_ch * length
        return 2 * macs

    def double_conv(in_ch, out_ch, kernel_size, length):
        """Two separable convs, each followed by LayerNorm and GELU."""
        convs = (separable_conv(in_ch, out_ch, kernel_size, length)
                 + separable_conv(out_ch, out_ch, kernel_size, length))
        elementwise = 2 * (NORM_OPS + GELU_OPS) * out_ch * length
        return convs + elementwise

    def linear_stack(in_len, out_len, channels):
        """Three Linear layers along the sequence axis plus two LayerNorm/GELU pairs."""
        resize = 2 * channels * in_len * out_len
        square = 2 * channels * out_len * out_len  # fc_2 and fc_3 have this shape
        elementwise = 2 * (NORM_OPS + GELU_OPS) * channels * out_len
        return resize + 2 * square + elementwise

    # ---------------------------------------------------------------- 1. alignment
    align_branch_ch = alignment_filters // 4
    align_convs = [
        double_conv(enc_filters, align_branch_ch, k, noisy_copies_length) * num_reads
        for k in KERNEL_SIZES
    ]
    align_linear = linear_stack(noisy_copies_length, noisy_copies_length,
                                alignment_filters) * num_reads

    flops.update(zip(
        ('Alignment_Conv_K1', 'Alignment_Conv_K3', 'Alignment_Conv_K5', 'Alignment_Conv_K7'),
        align_convs,
    ))
    flops['Alignment_Linear'] = align_linear
    flops['Alignment_Module_Total'] = sum(align_convs) + align_linear

    # ---------------------------------------------------------------- 2. embedding (after NCI)
    emb_branch_ch = d_model // 4
    emb_convs = [
        double_conv(alignment_filters, emb_branch_ch, k, noisy_copies_length) * batch_size
        for k in KERNEL_SIZES
    ]
    emb_linear = linear_stack(noisy_copies_length, label_length, d_model) * batch_size

    flops.update(zip(
        ('Embedding_Conv_K1', 'Embedding_Conv_K3', 'Embedding_Conv_K5', 'Embedding_Conv_K7'),
        emb_convs,
    ))
    flops['Embedding_Linear'] = emb_linear
    flops['Embedding_Module_Total'] = sum(emb_convs) + emb_linear

    # ---------------------------------------------------------------- 3. transformer encoder
    seq_len = label_length
    head_dim = d_model // n_head
    tokens = batch_size * seq_len
    score_entries = batch_size * n_head * seq_len * seq_len  # attention matrix elements

    square_proj = 2 * tokens * d_model * d_model   # one d_model x d_model projection
    qkv_proj = 3 * square_proj                     # Q, K and V
    attn_scores = 2 * score_entries * head_dim     # Q @ K^T
    softmax = SOFTMAX_OPS * score_entries
    attn_out = 2 * score_entries * head_dim        # scores @ V
    out_proj = square_proj
    attn_total = qkv_proj + attn_scores + softmax + attn_out + out_proj

    # Feed-forward: d_model -> dim_feedforward -> d_model (two equal-cost matmuls).
    ffn_total = 2 * (2 * tokens * d_model * dim_feedforward)

    # Two LayerNorms per encoder layer.
    layer_norms = 2 * NORM_OPS * tokens * d_model

    transformer_per_layer = attn_total + ffn_total + layer_norms

    flops['Transformer_Attention_per_layer'] = attn_total
    flops['Transformer_FFN_per_layer'] = ffn_total
    flops['Transformer_LayerNorm_per_layer'] = layer_norms
    flops['Transformer_per_layer_Total'] = transformer_per_layer
    flops['Transformer_Total'] = transformer_per_layer * num_layers

    # ---------------------------------------------------------------- 4. output module
    positions = batch_size * label_length
    wide_conv = 2 * positions * d_model * d_model      # d_model -> d_model, 1x1
    narrow_conv = 2 * positions * d_model * output_ch  # d_model -> output_ch, 1x1

    flops['Output_Conv1'] = wide_conv
    flops['Output_Conv2'] = wide_conv
    flops['Output_Conv3'] = narrow_conv
    flops['Output_Module_Total'] = 2 * wide_conv + narrow_conv

    # ---------------------------------------------------------------- 5. fusion module
    fusion_conv = 2 * positions * output_ch * output_ch
    fusion_elemwise = positions * output_ch * 4  # mul, mul, add, div

    flops['Fusion_Conv1'] = fusion_conv
    flops['Fusion_Conv2'] = fusion_conv
    flops['Fusion_Conv3'] = fusion_conv
    flops['Fusion_Elementwise'] = fusion_elemwise
    flops['Fusion_Module_Total'] = 3 * fusion_conv + fusion_elemwise

    # ---------------------------------------------------------------- grand total
    module_totals = ('Alignment_Module_Total', 'Embedding_Module_Total',
                     'Transformer_Total', 'Output_Module_Total',
                     'Fusion_Module_Total')
    flops['GRAND_TOTAL'] = sum(flops[key] for key in module_totals)

    return flops

In [18]:
def count_compact_flops(vocab_size, embed_dim, alignment_filters, 
                        embedding_filters, gru_hidden, gru_layers,
                        max_read_len, label_seq_len,
                        batch_size=1, max_cluster_size=16):
    """
    Analytical forward-pass FLOP estimate for the Compact BiGRU model.

    Args:
        vocab_size: Width of the final linear projection.
        embed_dim: Channel count entering the first alignment block.
        alignment_filters: Output channels of both alignment blocks.
        embedding_filters: Output channels of the embedding conv block
            (also the input size of the first GRU layer).
        gru_hidden: Hidden size of each GRU direction.
        gru_layers: Number of stacked bidirectional GRU layers.
        max_read_len: Sequence length used by the convolutional stages.
        label_seq_len: Number of timesteps for the GRU and output projection.
        batch_size: Number of clusters processed together.
        max_cluster_size: Reads per cluster; alignment cost scales with
            batch_size * max_cluster_size.

    Returns:
        OrderedDict mapping stage names to FLOP counts, ending with
        'GRAND_TOTAL'.
    """
    total_reads = batch_size * max_cluster_size

    def _separable_conv_cost(c_in, c_out, k, length):
        # Depthwise term (c_in * k) plus pointwise term (c_in * c_out) per
        # position, doubled.
        return 2 * (c_in * k * length + c_in * c_out * length)

    def _multi_kernel_cost(c_in, c_out, length):
        # Output channels are split over kernel sizes 1 / 3 / 5; the last
        # branch absorbs the remainder of the integer division.
        third = c_out // 3
        widths_by_kernel = {1: third, 3: third, 5: c_out - third - third}

        conv_cost = sum(
            _separable_conv_cost(c_in, width, k, length)
            for k, width in widths_by_kernel.items()
        )
        # LayerNorm: 5 ops per element of every branch output
        norm_cost = 5 * sum(width * length for width in widths_by_kernel.values())
        # GELU: 8 ops per element of the concatenated output
        gelu_cost = 8 * c_out * length
        return conv_cost + norm_cost + gelu_cost

    flops = OrderedDict()

    # -- 1. Embedding layer: a table lookup, counted as 0 FLOPs ---------------
    flops['Embedding_Layer'] = 0

    # -- 2. Alignment module: two multi-kernel blocks applied to every read ---
    first_block = total_reads * _multi_kernel_cost(embed_dim, alignment_filters, max_read_len)
    second_block = total_reads * _multi_kernel_cost(alignment_filters, alignment_filters, max_read_len)
    flops['Alignment_Block1'] = first_block
    flops['Alignment_Block2'] = second_block
    flops['Alignment_Module_Total'] = first_block + second_block

    # -- 3. NCI: summation over the reads of each cluster ---------------------
    flops['NCI'] = total_reads * alignment_filters * max_read_len

    # -- 4. Embedding module: one multi-kernel block + length projection ------
    conv_block = batch_size * _multi_kernel_cost(alignment_filters, embedding_filters, max_read_len)
    length_projection = 2 * batch_size * embedding_filters * max_read_len * label_seq_len
    flops['Embedding_ConvBlock'] = conv_block
    flops['Embedding_Linear'] = length_projection
    flops['Embedding_Module_Total'] = conv_block + length_projection

    # -- 5. BiGRU -------------------------------------------------------------
    # Per timestep, layer and direction: 3 gates, each costing
    # 2 * (input_size * hidden + hidden * hidden), plus ~10 ops per hidden
    # unit per gate for the sigmoid/tanh activations.
    bigru_cost = 0
    for depth in range(gru_layers):
        # Layers after the first consume the concatenated bidirectional output.
        layer_input = embedding_filters if depth == 0 else 2 * gru_hidden
        gate_matmuls = 3 * 2 * (layer_input * gru_hidden + gru_hidden * gru_hidden)
        gate_activations = 3 * 10 * gru_hidden
        both_directions = 2 * (gate_matmuls + gate_activations)
        bigru_cost += batch_size * label_seq_len * both_directions
    flops['BiGRU_Total'] = bigru_cost

    # -- 6. Output projection: (2 * gru_hidden) -> vocab_size per timestep ----
    flops['Output_Linear'] = 2 * batch_size * label_seq_len * (2 * gru_hidden) * vocab_size

    # -- Grand total over the top-level stages --------------------------------
    top_level_stages = ('Embedding_Layer', 'Alignment_Module_Total', 'NCI',
                        'Embedding_Module_Total', 'BiGRU_Total', 'Output_Linear')
    flops['GRAND_TOTAL'] = sum(flops[stage] for stage in top_level_stages)

    return flops

---
# PART 5: Run Analysis for All Datasets

In [19]:
# DNAFormer fixed hyperparameters (from paper)
DNAFORMER_CONFIG = dict(
    enc_filters=4,             # One-hot encoding of DNA (A, C, G, T)
    alignment_filters=128,
    d_model=1024,
    n_head=32,
    num_layers=12,
    dim_feedforward=2048,
    output_ch=4,
    p_dropout=0,
)

# Your compact model hyperparameters (from your code)
COMPACT_CONFIG = dict(
    vocab_size=5,              # N, A, C, G, T
    embed_dim=300,
    alignment_filters=128,
    embedding_filters=500,
    gru_hidden=300,
    gru_layers=2,
    dropout=0.1,
    padding_idx=0,
)

MAX_CLUSTER_SIZE = 16  # From DNAFormer paper

# Echo both configurations so the stored output records what was analysed.
for heading, hyperparams in (("DNAFormer Config:", DNAFORMER_CONFIG),
                             ("\nCompact Model Config:", COMPACT_CONFIG)):
    print(heading)
    for name, value in hyperparams.items():
        print(f"  {name}: {value}")

DNAFormer Config:
  enc_filters: 4
  alignment_filters: 128
  d_model: 1024
  n_head: 32
  num_layers: 12
  dim_feedforward: 2048
  output_ch: 4
  p_dropout: 0

Compact Model Config:
  vocab_size: 5
  embed_dim: 300
  alignment_filters: 128
  embedding_filters: 500
  gru_hidden: 300
  gru_layers: 2
  dropout: 0.1
  padding_idx: 0


In [20]:
# Result containers are (re)created here so re-running the cell starts clean.
results_summary = []
dnaformer_detailed_results = {}
compact_detailed_results = {}

# Hyperparameter names forwarded to each model constructor / FLOP estimator.
DNAFORMER_MODEL_KEYS = ('enc_filters', 'alignment_filters', 'd_model', 'n_head',
                        'num_layers', 'dim_feedforward', 'output_ch', 'p_dropout')
DNAFORMER_FLOP_KEYS = ('enc_filters', 'alignment_filters', 'd_model', 'n_head',
                       'num_layers', 'dim_feedforward', 'output_ch')
COMPACT_MODEL_KEYS = ('vocab_size', 'padding_idx', 'embed_dim', 'alignment_filters',
                      'embedding_filters', 'gru_hidden', 'gru_layers', 'dropout')
COMPACT_FLOP_KEYS = ('vocab_size', 'embed_dim', 'alignment_filters',
                     'embedding_filters', 'gru_hidden', 'gru_layers')

dnaformer_model_kwargs = {key: DNAFORMER_CONFIG[key] for key in DNAFORMER_MODEL_KEYS}
dnaformer_flop_kwargs = {key: DNAFORMER_CONFIG[key] for key in DNAFORMER_FLOP_KEYS}
compact_model_kwargs = {key: COMPACT_CONFIG[key] for key in COMPACT_MODEL_KEYS}
compact_flop_kwargs = {key: COMPACT_CONFIG[key] for key in COMPACT_FLOP_KEYS}

wide_rule = "=" * 100
section_rule = "=" * 80

print(wide_rule)
print("PARAMETER AND FLOPS ANALYSIS")
print(wide_rule)

for dataset_name, dataset_cfg in DATASET_CONFIGS.items():
    print(f"\n{section_rule}")
    print(f"DATASET: {dataset_name}")
    print(section_rule)
    print(f"Description: {dataset_cfg['description']}")

    # ---- Derived sequence lengths -------------------------------------------
    # A positive index_length is removed from both the label and the noisy
    # copies; max_read_len is computed from the untrimmed label length.
    index_trim = max(dataset_cfg.get('index_length', 0), 0)
    padded_len = dataset_cfg['label_len'] + dataset_cfg['max_deviation']

    label_length = dataset_cfg['label_len'] - index_trim
    noisy_copies_length = padded_len - index_trim
    max_read_len = padded_len + 8  # Extra buffer

    print("\nDerived parameters:")
    print(f"  label_length (output): {label_length}")
    print(f"  noisy_copies_length (input): {noisy_copies_length}")
    print(f"  max_read_len: {max_read_len}")

    # ---- DNAFormer ----------------------------------------------------------
    print("\n--- DNAFormer ---")

    dnaformer_model = DNAFormer(
        noisy_copies_length=noisy_copies_length,
        label_length=label_length,
        **dnaformer_model_kwargs,
    )
    dnaformer_params = count_parameters_by_module(dnaformer_model)
    dnaformer_flops = count_dnaformer_flops(
        noisy_copies_length=noisy_copies_length,
        label_length=label_length,
        batch_size=1,
        max_cluster_size=MAX_CLUSTER_SIZE,
        **dnaformer_flop_kwargs,
    )
    dnaformer_detailed_results[dataset_name] = {
        'params': dnaformer_params,
        'flops': dnaformer_flops
    }

    dnaformer_n_params = dnaformer_params['TOTAL']
    dnaformer_n_flops = dnaformer_flops['GRAND_TOTAL']
    print(f"  Total Parameters: {dnaformer_n_params:,} ({format_params(dnaformer_n_params)})")
    print(f"  Total FLOPs:      {dnaformer_n_flops:,} ({format_params(dnaformer_n_flops)})")

    # ---- Compact BiGRU ------------------------------------------------------
    print("\n--- Compact BiGRU (Your Model) ---")

    compact_model = ImprovedDNAReconstructionModel(
        label_seq_len=label_length,
        max_read_len=max_read_len,
        **compact_model_kwargs,
    )
    compact_params = count_parameters_by_module(compact_model)
    compact_flops = count_compact_flops(
        max_read_len=max_read_len,
        label_seq_len=label_length,
        batch_size=1,
        max_cluster_size=MAX_CLUSTER_SIZE,
        **compact_flop_kwargs,
    )
    compact_detailed_results[dataset_name] = {
        'params': compact_params,
        'flops': compact_flops
    }

    compact_n_params = compact_params['TOTAL']
    compact_n_flops = compact_flops['GRAND_TOTAL']
    print(f"  Total Parameters: {compact_n_params:,} ({format_params(compact_n_params)})")
    print(f"  Total FLOPs:      {compact_n_flops:,} ({format_params(compact_n_flops)})")

    # ---- Comparison: percentage saved relative to DNAFormer -----------------
    param_reduction = (1 - compact_n_params / dnaformer_n_params) * 100
    flop_reduction = (1 - compact_n_flops / dnaformer_n_flops) * 100

    print("\n--- Comparison ---")
    print(f"  Parameter Reduction: {param_reduction:.1f}%")
    print(f"  FLOP Reduction:      {flop_reduction:.1f}%")

    # ---- One summary record per dataset (key order = CSV column order) ------
    results_summary.append({
        'Dataset': dataset_name,
        'Label_Length': label_length,
        'Noisy_Copies_Len': noisy_copies_length,
        'Max_Read_Len': max_read_len,
        'DNAFormer_Params': dnaformer_n_params,
        'DNAFormer_Params_M': dnaformer_n_params / 1e6,
        'DNAFormer_FLOPs': dnaformer_n_flops,
        'DNAFormer_FLOPs_G': dnaformer_n_flops / 1e9,
        'Compact_Params': compact_n_params,
        'Compact_Params_M': compact_n_params / 1e6,
        'Compact_FLOPs': compact_n_flops,
        'Compact_FLOPs_G': compact_n_flops / 1e9,
        'Param_Reduction_%': param_reduction,
        'FLOP_Reduction_%': flop_reduction,
    })

PARAMETER AND FLOPS ANALYSIS

DATASET: Srinivasavaradhan
Description: Srinivasavaradhan et al. - Twist, MinION (high error)

Derived parameters:
  label_length (output): 110
  noisy_copies_length (input): 120
  max_read_len: 128

--- DNAFormer ---
  Total Parameters: 103,396,622 (103.40M)
  Total FLOPs:      23,627,898,400 (23.63B)

--- Compact BiGRU (Your Model) ---




  Total Parameters: 3,405,643 (3.41M)
  Total FLOPs:      956,661,856 (956.66M)

--- Comparison ---
  Parameter Reduction: 96.7%
  FLOP Reduction:      96.0%

DATASET: Grass
Description: Grass et al. (2015) - CustomArray, Illumina miSeq

Derived parameters:
  label_length (output): 117
  noisy_copies_length (input): 128
  max_read_len: 136

--- DNAFormer ---
  Total Parameters: 103,407,903 (103.41M)
  Total FLOPs:      25,192,109,744 (25.19B)

--- Compact BiGRU (Your Model) ---
  Total Parameters: 3,419,578 (3.42M)
  Total FLOPs:      1,018,175,472 (1.02B)

--- Comparison ---
  Parameter Reduction: 96.7%
  FLOP Reduction:      96.0%

DATASET: Erlich
Description: Erlich et al. (2017) - DNA Fountain, Illumina miSeq

Derived parameters:
  label_length (output): 152
  noisy_copies_length (input): 162
  max_read_len: 170

--- DNAFormer ---
  Total Parameters: 103,467,602 (103.47M)
  Total FLOPs:      33,088,512,640 (33.09B)

--- Compact BiGRU (Your Model) ---
  Total Parameters: 3,480,949 (

---
# PART 6: Summary Tables

In [21]:
# One row per dataset, built from the records collected by the analysis loop
df_summary = pd.DataFrame(results_summary)

summary_rule = "=" * 100
print(summary_rule)
print("SUMMARY TABLE")
print(summary_rule)
print()

# Source column -> header shown in the printed table (order = display order)
column_labels = {
    'Dataset': 'Dataset',
    'Label_Length': 'Label Len',
    'DNAFormer_Params_M': 'DNAFormer (M)',
    'Compact_Params_M': 'Compact (M)',
    'Param_Reduction_%': 'Param Red. %',
    'DNAFormer_FLOPs_G': 'DNAFormer (G)',
    'Compact_FLOPs_G': 'Compact (G)',
    'FLOP_Reduction_%': 'FLOP Red. %',
}
display_cols = list(column_labels)

# Decimal places per displayed column; columns not listed are left untouched
display_decimals = {
    'DNAFormer (M)': 2,
    'Compact (M)': 2,
    'Param Red. %': 1,
    'DNAFormer (G)': 2,
    'Compact (G)': 3,
    'FLOP Red. %': 1,
}

df_display = (
    df_summary[display_cols]
    .rename(columns=column_labels)
    .round(display_decimals)
)

print(df_display.to_string(index=False))

SUMMARY TABLE

                   Dataset  Label Len  DNAFormer (M)  Compact (M)  Param Red. %  DNAFormer (G)  Compact (G)  FLOP Red. %
         Srinivasavaradhan        110         103.40         3.41          96.7          23.63        0.957         96.0
                     Grass        117         103.41         3.42          96.7          25.19        1.018         96.0
                    Erlich        152         103.47         3.48          96.6          33.09        1.314         96.0
BinnedNanoporeTwoFlowcells        116         103.40         3.43          96.7          24.94        1.021         95.9
        BinnedTestIllumina        116         103.40         3.43          96.7          24.94        1.021         95.9


In [22]:
# DNAFormer module-by-module parameters
print("\n" + "=" * 100)
print("DNAFormer MODULE-BY-MODULE PARAMETERS")
print("=" * 100)

# Collect the per-module parameter counts of every dataset; a module missing
# from the counts falls back to 0.
# NOTE(review): the 'alignement' spelling is kept as-is because the stored
# output shows non-zero values for it; presumably it mirrors the module name
# in the DNAFormer model — confirm before "fixing" it.
dnaformer_module_data = {}
for dataset_name in DATASET_CONFIGS.keys():
    params = dnaformer_detailed_results[dataset_name]['params']
    dnaformer_module_data[dataset_name] = {
        'Alignment': params.get('alignement', 0),
        'Embedding': params.get('embedding', 0),
        'Transformer': params.get('encoder', 0),
        'Output': params.get('output_module', 0),
        'Fusion': params.get('fusion', 0),
        'TOTAL': params.get('TOTAL', 0)
    }

# Datasets as rows, modules as columns
df_dnaformer_modules = pd.DataFrame(dnaformer_module_data).T

# DataFrame.applymap is deprecated since pandas 2.1 in favour of
# DataFrame.map; use the new name when available, the old one otherwise.
format_cells = (df_dnaformer_modules.map if hasattr(pd.DataFrame, "map")
                else df_dnaformer_modules.applymap)
df_dnaformer_modules = format_cells(lambda x: f"{x/1e6:.2f}M" if x > 0 else "0")
print(df_dnaformer_modules.to_string())


DNAFormer MODULE-BY-MODULE PARAMETERS
                           Alignment Embedding Transformer Output Fusion    TOTAL
Srinivasavaradhan              0.05M     0.44M     100.80M  2.10M  0.00M  103.40M
Grass                          0.06M     0.45M     100.80M  2.10M  0.00M  103.41M
Erlich                         0.09M     0.48M     100.80M  2.10M  0.00M  103.47M
BinnedNanoporeTwoFlowcells     0.05M     0.45M     100.80M  2.10M  0.00M  103.40M
BinnedTestIllumina             0.05M     0.45M     100.80M  2.10M  0.00M  103.40M


In [23]:
# Compact model module-by-module parameters
print("\n" + "=" * 100)
print("COMPACT BiGRU MODULE-BY-MODULE PARAMETERS")
print("=" * 100)

# Collect the per-module parameter counts of every dataset; a module missing
# from the counts falls back to 0.
compact_module_data = {}
for dataset_name in DATASET_CONFIGS.keys():
    params = compact_detailed_results[dataset_name]['params']
    compact_module_data[dataset_name] = {
        'Embedding_Layer': params.get('embedding', 0),
        'Alignment': params.get('alignment', 0),
        'Embedding_Module': params.get('embedding_module', 0),
        'BiGRU': params.get('gru', 0),
        'Output': params.get('fc_out', 0),
        'TOTAL': params.get('TOTAL', 0)
    }

# Datasets as rows, modules as columns
df_compact_modules = pd.DataFrame(compact_module_data).T


def _format_param_count(x):
    """Render a parameter count as 'x.xxM', 'x.xK' or '0'."""
    # >= (not >) so that exactly 1,000,000 prints as "1.00M", not "1000.0K".
    if x >= 1e6:
        return f"{x/1e6:.2f}M"
    if x > 0:
        return f"{x/1e3:.1f}K"
    return "0"


# DataFrame.applymap is deprecated since pandas 2.1 in favour of
# DataFrame.map; use the new name when available, the old one otherwise.
format_cells = (df_compact_modules.map if hasattr(pd.DataFrame, "map")
                else df_compact_modules.applymap)
df_compact_modules = format_cells(_format_param_count)
print(df_compact_modules.to_string())


COMPACT BiGRU MODULE-BY-MODULE PARAMETERS
                           Embedding_Layer Alignment Embedding_Module  BiGRU Output  TOTAL
Srinivasavaradhan                     1.5K    125.7K           208.2K  3.07M   3.0K  3.41M
Grass                                 1.5K    129.8K           218.1K  3.07M   3.0K  3.42M
Erlich                                1.5K    147.2K           262.0K  3.07M   3.0K  3.48M
BinnedNanoporeTwoFlowcells            1.5K    131.9K           222.4K  3.07M   3.0K  3.43M
BinnedTestIllumina                    1.5K    131.9K           222.4K  3.07M   3.0K  3.43M


In [24]:
# DNAFormer module-by-module FLOPs
print("\n" + "=" * 100)
print("DNAFormer MODULE-BY-MODULE FLOPs")
print("=" * 100)

# Collect the per-module FLOP totals of every dataset; a missing entry
# falls back to 0.
dnaformer_flop_data = {}
for dataset_name in DATASET_CONFIGS.keys():
    flops = dnaformer_detailed_results[dataset_name]['flops']
    dnaformer_flop_data[dataset_name] = {
        'Alignment': flops.get('Alignment_Module_Total', 0),
        'Embedding': flops.get('Embedding_Module_Total', 0),
        'Transformer': flops.get('Transformer_Total', 0),
        'Output': flops.get('Output_Module_Total', 0),
        'Fusion': flops.get('Fusion_Module_Total', 0),
        'TOTAL': flops.get('GRAND_TOTAL', 0)
    }

# Datasets as rows, modules as columns
df_dnaformer_flops = pd.DataFrame(dnaformer_flop_data).T

# DataFrame.applymap is deprecated since pandas 2.1 in favour of
# DataFrame.map; use the new name when available, the old one otherwise.
format_cells = (df_dnaformer_flops.map if hasattr(pd.DataFrame, "map")
                else df_dnaformer_flops.applymap)
df_dnaformer_flops = format_cells(lambda x: f"{x/1e9:.2f}G")
print(df_dnaformer_flops.to_string())


DNAFormer MODULE-BY-MODULE FLOPs
                           Alignment Embedding Transformer Output Fusion   TOTAL
Srinivasavaradhan              0.21G     0.18G      22.78G  0.46G  0.00G  23.63G
Grass                          0.24G     0.20G      24.27G  0.49G  0.00G  25.19G
Erlich                         0.37G     0.28G      31.80G  0.64G  0.00G  33.09G
BinnedNanoporeTwoFlowcells     0.21G     0.19G      24.06G  0.49G  0.00G  24.94G
BinnedTestIllumina             0.21G     0.19G      24.06G  0.49G  0.00G  24.94G


In [25]:
# Compact model module-by-module FLOPs
print("\n" + "=" * 100)
print("COMPACT BiGRU MODULE-BY-MODULE FLOPs")
print("=" * 100)

# Collect the per-stage FLOP totals of every dataset; a missing entry
# falls back to 0.
compact_flop_data = {}
for dataset_name in DATASET_CONFIGS.keys():
    flops = compact_detailed_results[dataset_name]['flops']
    compact_flop_data[dataset_name] = {
        'Embedding': flops.get('Embedding_Layer', 0),
        'Alignment': flops.get('Alignment_Module_Total', 0),
        'NCI': flops.get('NCI', 0),
        'Embedding_Mod': flops.get('Embedding_Module_Total', 0),
        'BiGRU': flops.get('BiGRU_Total', 0),
        'Output': flops.get('Output_Linear', 0),
        'TOTAL': flops.get('GRAND_TOTAL', 0)
    }

# Datasets as rows, stages as columns
df_compact_flops = pd.DataFrame(compact_flop_data).T

# DataFrame.applymap is deprecated since pandas 2.1 in favour of
# DataFrame.map; use the new name when available, the old one otherwise.
format_cells = (df_compact_flops.map if hasattr(pd.DataFrame, "map")
                else df_compact_flops.applymap)
df_compact_flops = format_cells(lambda x: f"{x/1e9:.4f}G" if x > 0 else "0")
print(df_compact_flops.to_string())


COMPACT BiGRU MODULE-BY-MODULE FLOPs
                           Embedding Alignment      NCI Embedding_Mod    BiGRU   Output    TOTAL
Srinivasavaradhan                  0   0.2470G  0.0003G       0.0316G  0.6772G  0.0007G  0.9567G
Grass                              0   0.2624G  0.0003G       0.0345G  0.7203G  0.0007G  1.0182G
Erlich                             0   0.3280G  0.0003G       0.0491G  0.9357G  0.0009G  1.3141G
BinnedNanoporeTwoFlowcells         0   0.2701G  0.0003G       0.0354G  0.7141G  0.0007G  1.0206G
BinnedTestIllumina                 0   0.2701G  0.0003G       0.0354G  0.7141G  0.0007G  1.0206G


---
# PART 7: Architectural Differences Summary

In [26]:
print("\n" + "=" * 100)
print("ARCHITECTURAL DIFFERENCES SUMMARY")
print("=" * 100)

# The parameter figures below are hard-coded text; they are kept in line with
# the counts this notebook computes (DNAFormer ~103.4M total, of which the
# Transformer encoder is ~100.8M; Compact BiGRU ~3.4-3.5M). Update them by
# hand if the model configurations change.
print("""
╔══════════════════════════════════╦════════════════════════════╦════════════════════════════╗
║           Feature                ║       DNAFormer            ║     Compact BiGRU          ║
╠══════════════════════════════════╬════════════════════════════╬════════════════════════════╣
║ Sequence Modeling                ║ 12-layer Transformer       ║ 2-layer BiGRU              ║
║ Attention Heads                  ║ 32                         ║ N/A                        ║
║ Model Dimension (d_model)        ║ 1024                       ║ 500 (embedding_filters)    ║
║ Feedforward Dimension            ║ 2048                       ║ N/A                        ║
║ Alignment Kernels                ║ 4 kernels {1,3,5,7}        ║ 3 kernels {1,3,5}          ║
║ Alignment Structure              ║ 4 double_conv1D + linear   ║ 2 multi-kernel blocks      ║
║ Embedding Structure              ║ 4 double_conv1D + linear   ║ 1 multi-kernel + linear    ║
║ Siamese Architecture             ║ YES (2 branches)           ║ NO (single branch)         ║
║ Fusion Module                    ║ YES (vectors + 3 conv)     ║ NO                         ║
║ Output Module                    ║ 3 Conv1x1 layers           ║ 1 Linear layer             ║
║ Input Encoding                   ║ One-hot (4 channels)       ║ nn.Embedding (300 dim)     ║
║ Attention Complexity             ║ O(L²)                      ║ O(L) linear                ║
║ ~Parameters                      ║ ~103M                      ║ ~3.4-3.5M                  ║
╚══════════════════════════════════╩════════════════════════════╩════════════════════════════╝

KEY ADVANTAGES OF COMPACT MODEL:
═══════════════════════════════════
1. NO Transformer     → Removes ~101M params (largest component in DNAFormer)
2. NO Siamese         → Single forward pass 
3. NO Fusion Module   → Simpler output pipeline
4. Fewer Kernels      → 3 instead of 4 kernel sizes
5. Simpler Modules    → Lighter alignment and embedding modules
6. Linear Complexity  → O(L) vs O(L²) for self-attention
7. nn.Embedding       → Learnable embeddings instead of fixed one-hot

MODULES REMOVED/SIMPLIFIED:
═══════════════════════════════════
- Transformer Encoder (12 layers, 32 heads)         → BiGRU (2 layers)
- Fusion Module (learnable vectors + 3 Conv1x1)     → REMOVED
- Kernel size 7 branch                              → REMOVED
- Second double_conv1D in each kernel branch        → Single conv
- Linear block in alignment (3 linear layers)       → REMOVED
- Output module (3 Conv1x1)                         → Single Linear
""")


ARCHITECTURAL DIFFERENCES SUMMARY

╔══════════════════════════════════╦════════════════════════════╦════════════════════════════╗
║           Feature                ║       DNAFormer            ║     Compact BiGRU          ║
╠══════════════════════════════════╬════════════════════════════╬════════════════════════════╣
║ Sequence Modeling                ║ 12-layer Transformer       ║ 2-layer BiGRU              ║
║ Attention Heads                  ║ 32                         ║ N/A                        ║
║ Model Dimension (d_model)        ║ 1024                       ║ 500 (embedding_filters)    ║
║ Feedforward Dimension            ║ 2048                       ║ N/A                        ║
║ Alignment Kernels                ║ 4 kernels {1,3,5,7}        ║ 3 kernels {1,3,5}          ║
║ Alignment Structure              ║ 4 double_conv1D + linear   ║ 2 multi-kernel blocks      ║
║ Embedding Structure              ║ 4 double_conv1D + linear   ║ 1 multi-kernel + linear    ║
║ Siamese Arch

In [27]:
# Final statistics: averages over all datasets, then a per-dataset recap
stats_rule = "=" * 100
print("\n" + stats_rule)
print("FINAL STATISTICS")
print(stats_rule)

# These two averages are read again by the LaTeX table cell further down.
avg_param_reduction = df_summary['Param_Reduction_%'].mean()
avg_flop_reduction = df_summary['FLOP_Reduction_%'].mean()

print(f"\nAverage Parameter Reduction: {avg_param_reduction:.1f}%")
print(f"Average FLOP Reduction:      {avg_flop_reduction:.1f}%")

recap_rule = "-" * 80
print("\n" + recap_rule)
print("Per-Dataset Summary:")
print(recap_rule)

for row in results_summary:
    print(
        f"\n{row['Dataset']}:\n"
        f"  DNAFormer: {row['DNAFormer_Params_M']:.2f}M params, {row['DNAFormer_FLOPs_G']:.2f}G FLOPs\n"
        f"  Compact:   {row['Compact_Params_M']:.2f}M params, {row['Compact_FLOPs_G']:.4f}G FLOPs\n"
        f"  Reduction: {row['Param_Reduction_%']:.1f}% params, {row['FLOP_Reduction_%']:.1f}% FLOPs"
    )


FINAL STATISTICS

Average Parameter Reduction: 96.7%
Average FLOP Reduction:      96.0%

--------------------------------------------------------------------------------
Per-Dataset Summary:
--------------------------------------------------------------------------------

Srinivasavaradhan:
  DNAFormer: 103.40M params, 23.63G FLOPs
  Compact:   3.41M params, 0.9567G FLOPs
  Reduction: 96.7% params, 96.0% FLOPs

Grass:
  DNAFormer: 103.41M params, 25.19G FLOPs
  Compact:   3.42M params, 1.0182G FLOPs
  Reduction: 96.7% params, 96.0% FLOPs

Erlich:
  DNAFormer: 103.47M params, 33.09G FLOPs
  Compact:   3.48M params, 1.3141G FLOPs
  Reduction: 96.6% params, 96.0% FLOPs

BinnedNanoporeTwoFlowcells:
  DNAFormer: 103.40M params, 24.94G FLOPs
  Compact:   3.43M params, 1.0206G FLOPs
  Reduction: 96.7% params, 95.9% FLOPs

BinnedTestIllumina:
  DNAFormer: 103.40M params, 24.94G FLOPs
  Compact:   3.43M params, 1.0206G FLOPs
  Reduction: 96.7% params, 95.9% FLOPs


In [28]:
# Export to CSV for paper
# Write the full summary table next to the notebook; the file name is defined
# once so the confirmation message cannot drift from the actual path.
summary_csv_path = 'param_flops_comparison.csv'
df_summary.to_csv(summary_csv_path, index=False)
print(f"\nResults saved to '{summary_csv_path}'")


Results saved to 'param_flops_comparison.csv'


---
# PART 8: LaTeX Table Generation for Paper

In [29]:
latex_rule = "=" * 100
print("\n" + latex_rule)
print("LaTeX TABLE FOR PAPER")
print(latex_rule)

# Static preamble: two-level header (parameters / FLOPs / reduction).
table_header = r"""
\begin{table}[t]
\centering
\caption{Parameter Count and FLOPs Comparison: DNAFormer vs. Compact BiGRU}
\label{tab:param_flops}
\begin{tabular}{lcccccc}
\toprule
\multirow{2}{*}{Dataset} & \multicolumn{2}{c}{Parameters (M)} & \multicolumn{2}{c}{FLOPs (G)} & \multicolumn{2}{c}{Reduction (\%)} \\
\cmidrule(lr){2-3} \cmidrule(lr){4-5} \cmidrule(lr){6-7}
 & DNAFormer & Ours & DNAFormer & Ours & Params & FLOPs \\
\midrule
"""

# One table row per dataset; each ends with the LaTeX row terminator "\\".
dataset_rows = [
    f"{row['Dataset']} & {row['DNAFormer_Params_M']:.2f} & {row['Compact_Params_M']:.2f} & "
    f"{row['DNAFormer_FLOPs_G']:.2f} & {row['Compact_FLOPs_G']:.3f} & "
    f"{row['Param_Reduction_%']:.1f} & {row['FLOP_Reduction_%']:.1f} \\\\\n"
    for row in results_summary
]

# Closing row with the averages computed in the final-statistics cell.
average_row = (
    "\\midrule\n"
    "Average & --- & --- & --- & --- & "
    f"{avg_param_reduction:.1f} & {avg_flop_reduction:.1f} \\\\"
)

table_footer = r"""
\bottomrule
\end{tabular}
\end{table}
"""

latex_table = table_header + "".join(dataset_rows) + average_row + table_footer

print(latex_table)


LaTeX TABLE FOR PAPER

\begin{table}[t]
\centering
\caption{Parameter Count and FLOPs Comparison: DNAFormer vs. Compact BiGRU}
\label{tab:param_flops}
\begin{tabular}{lcccccc}
\toprule
\multirow{2}{*}{Dataset} & \multicolumn{2}{c}{Parameters (M)} & \multicolumn{2}{c}{FLOPs (G)} & \multicolumn{2}{c}{Reduction (\%)} \\
\cmidrule(lr){2-3} \cmidrule(lr){4-5} \cmidrule(lr){6-7}
 & DNAFormer & Ours & DNAFormer & Ours & Params & FLOPs \\
\midrule
Srinivasavaradhan & 103.40 & 3.41 & 23.63 & 0.957 & 96.7 & 96.0 \\
Grass & 103.41 & 3.42 & 25.19 & 1.018 & 96.7 & 96.0 \\
Erlich & 103.47 & 3.48 & 33.09 & 1.314 & 96.6 & 96.0 \\
BinnedNanoporeTwoFlowcells & 103.40 & 3.43 & 24.94 & 1.021 & 96.7 & 95.9 \\
BinnedTestIllumina & 103.40 & 3.43 & 24.94 & 1.021 & 96.7 & 95.9 \\
\midrule
Average & --- & --- & --- & --- & 96.7 & 96.0 \\
\bottomrule
\end{tabular}
\end{table}



In [30]:
# Closing banner marking the end of the notebook run
closing_rule = "=" * 100
print(f"\n{closing_rule}\nANALYSIS COMPLETE!\n{closing_rule}")


ANALYSIS COMPLETE!
