# Block Architecture Ablation Study
## DWBlock vs STMBlock vs FastSTMBlock

Comparing:
- **DWBlock** (JeongWonNet77_Rep256Basis8S24Drop): Full channel processing + PRCM
- **STMBlock** (JeongWonNet_STMShuffle_NoStem): Split-Transform-Merge + AffinePRCM + Channel Shuffle
- **FastSTMBlock** (JeongWonNet_STMNoConcatShuffle_NoStemMS): Slicing + SimplePRCM + 1x1 Mixing

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
from collections import OrderedDict

## 1. Module Definitions

In [2]:
# ============================================
# RepConv (Re-parameterizable Conv)
# ============================================

class RepConv(nn.Module):
    """Multi-branch convolution that can be folded into a single conv.

    Before ``switch_to_deploy`` the block sums up to three branches and then
    applies the activation:

    * ``conv_kxk`` + ``bn_kxk`` (always present),
    * ``conv_1x1`` + ``bn_1x1`` (only when ``kernel_size > 1``),
    * ``bn_identity`` applied to the input (only when ``stride == 1`` and
      ``in_channels == out_channels``).

    After ``switch_to_deploy`` the branches are replaced by one biased
    ``nn.Conv2d`` stored as ``fused_conv``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Side length of the main (square) kernel.
        stride: Stride shared by the k x k and 1 x 1 convolutions.
        padding: Padding of the k x k convolution (the 1 x 1 branch uses 0).
        groups: Group count shared by both convolutions.
        use_activation: If true, finish with an in-place ReLU; otherwise
            with ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, groups=1, use_activation=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        # The BN-only shortcut needs input and output to have the same shape.
        self.use_identity = stride == 1 and in_channels == out_channels

        # Main k x k branch.
        self.conv_kxk = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, groups=groups, bias=False,
        )
        self.bn_kxk = nn.BatchNorm2d(out_channels)

        # Parallel 1 x 1 branch; pointless when the main kernel is already 1 x 1.
        if kernel_size > 1:
            self.conv_1x1 = nn.Conv2d(
                in_channels, out_channels, 1,
                stride=stride, padding=0, groups=groups, bias=False,
            )
            self.bn_1x1 = nn.BatchNorm2d(out_channels)
        else:
            self.conv_1x1 = None

        if self.use_identity:
            self.bn_identity = nn.BatchNorm2d(out_channels)

        if use_activation:
            self.activation = nn.ReLU(inplace=True)
        else:
            self.activation = nn.Identity()

    def forward(self, x):
        """Run the fused conv if present, otherwise sum the separate branches."""
        fused = getattr(self, 'fused_conv', None)
        if fused is not None:
            return self.activation(fused(x))

        total = self.bn_kxk(self.conv_kxk(x))
        if self.conv_1x1 is not None:
            total = total + self.bn_1x1(self.conv_1x1(x))
        if self.use_identity:
            total = total + self.bn_identity(x)
        return self.activation(total)

    def switch_to_deploy(self):
        """Fold all branches into ``fused_conv`` and drop the branch modules.

        Calling it again after fusion is a no-op.
        """
        if hasattr(self, 'fused_conv'):
            return

        has_1x1 = self.conv_1x1 is not None

        # Accumulate the BN-folded kernel/bias of every branch.
        weight, offset = self._fuse_bn_tensor(self.conv_kxk, self.bn_kxk)
        if has_1x1:
            weight_1x1, offset_1x1 = self._fuse_bn_tensor(self.conv_1x1, self.bn_1x1)
            weight = weight + self._pad_1x1_to_kxk(weight_1x1)
            offset = offset + offset_1x1
        if self.use_identity:
            weight_id, offset_id = self._fuse_bn_tensor(None, self.bn_identity)
            weight = weight + weight_id
            offset = offset + offset_id

        self.fused_conv = nn.Conv2d(
            self.in_channels, self.out_channels, self.kernel_size,
            stride=self.stride, padding=self.padding,
            groups=self.groups, bias=True,
        )
        self.fused_conv.weight.data = weight
        self.fused_conv.bias.data = offset

        # Remove the now-redundant training-time branches.
        stale = ['conv_kxk', 'bn_kxk']
        if has_1x1:
            stale += ['conv_1x1', 'bn_1x1']
        if hasattr(self, 'bn_identity'):
            stale.append('bn_identity')
        for attr_name in stale:
            delattr(self, attr_name)

    def _fuse_bn_tensor(self, conv, bn):
        """Return ``(kernel, bias)`` of ``bn(conv(.))`` using BN running stats.

        With ``conv=None`` the convolution is taken to be the identity,
        expressed as a kernel with a single 1 at the spatial centre.
        """
        if conv is not None:
            kernel = conv.weight
        else:
            per_group = self.in_channels // self.groups
            centre = self.kernel_size // 2
            kernel = torch.zeros(
                (self.in_channels, per_group, self.kernel_size, self.kernel_size),
                dtype=bn.weight.dtype, device=bn.weight.device,
            )
            # Each output channel passes through its own input channel
            # (indexed within its group).
            for out_idx in range(self.in_channels):
                kernel[out_idx, out_idx % per_group, centre, centre] = 1

        std = torch.sqrt(bn.running_var + bn.eps)
        scale = (bn.weight / std).reshape(-1, 1, 1, 1)
        fused_kernel = kernel * scale
        fused_bias = bn.bias - bn.running_mean * bn.weight / std
        return fused_kernel, fused_bias

    def _pad_1x1_to_kxk(self, kernel_1x1):
        """Zero-pad a 1 x 1 kernel so it sits at the centre of a k x k kernel."""
        if self.kernel_size == 1:
            return kernel_1x1
        border = self.kernel_size // 2
        return F.pad(kernel_1x1, [border] * 4)

In [3]:
# ============================================
# PRCM (Scale Only)
# ============================================

class PRCM(nn.Module):
    """Pattern Recalibration Module (scale only).

    The input is averaged over its two spatial dimensions, projected onto
    ``num_basis`` learned basis vectors, passed through dropout, mapped back
    to one logit per channel by ``fuser``, and the sigmoid of those logits
    multiplies the input channel-wise.

    Args:
        channels: Number of channels of the input feature map.
        num_basis: Number of learned basis vectors.
        dropout_rate: Dropout probability on the basis coefficients; values
            ``<= 0`` disable dropout.
    """

    def __init__(self, channels, num_basis=8, dropout_rate=0.5):
        super().__init__()
        self.basis = nn.Parameter(torch.randn(num_basis, channels))
        self.fuser = nn.Linear(num_basis, channels, bias=False)
        if dropout_rate > 0:
            self.coeff_dropout = nn.Dropout(dropout_rate)
        else:
            self.coeff_dropout = nn.Identity()

    def forward(self, x):
        """Return ``x`` scaled per sample and channel by a sigmoid gate."""
        pooled = x.mean(dim=(2, 3))
        coeff = self.coeff_dropout(pooled @ self.basis.t())
        gate = self.fuser(coeff).sigmoid()
        # Broadcast the (batch, channel) gate over both spatial dimensions.
        return x * gate[:, :, None, None]


# ============================================
# AffinePRCM (Scale + Shift)
# ============================================

class AffinePRCM(nn.Module):
    """Affine-modulation PRCM (scale and shift).

    Basis coefficients are computed from the spatially averaged input and
    passed through dropout. ``scale_proj`` turns them into a sigmoid scale
    and ``shift_proj`` into an additive shift, both per sample and channel.

    Args:
        channels: Number of channels of the input feature map.
        num_basis: Number of learned basis vectors.
        dropout_rate: Dropout probability on the basis coefficients; values
            ``<= 0`` disable dropout.
    """

    def __init__(self, channels, num_basis=8, dropout_rate=0.5):
        super().__init__()
        self.basis = nn.Parameter(torch.randn(num_basis, channels))
        self.scale_proj = nn.Linear(num_basis, channels, bias=False)
        self.shift_proj = nn.Linear(num_basis, channels, bias=False)
        if dropout_rate > 0:
            self.coeff_dropout = nn.Dropout(dropout_rate)
        else:
            self.coeff_dropout = nn.Identity()

    def forward(self, x):
        """Return ``x * alpha + beta`` with per-sample, per-channel terms."""
        pooled = x.mean(dim=(2, 3))
        coeff = self.coeff_dropout(pooled @ self.basis.t())
        alpha = self.scale_proj(coeff).sigmoid()
        beta = self.shift_proj(coeff)
        # Broadcast both (batch, channel) terms over the spatial dimensions.
        return x * alpha[:, :, None, None] + beta[:, :, None, None]

In [4]:
# ============================================
# DWBlock (from JeongWonNet77_Rep256Basis8S24Drop)
# ============================================

class DWBlock(nn.Module):
    """Full-width block: optional 1x1 conv, depthwise RepConv, then PRCM.

    Structure:
        [1x1 Conv] -> RepConv (depthwise, ``kernel_size``) -> PRCM

    - Every channel goes through the depthwise convolution (no split).
    - The 1x1 conv is only inserted when ``in_ch != out_ch``.
    - PRCM applies a multiplicative (scale-only) recalibration.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Depthwise kernel size; padding is ``kernel_size // 2``.
        num_basis: Number of PRCM basis vectors.
        dropout_rate: Dropout rate forwarded to PRCM.
    """

    def __init__(self, in_ch, out_ch, kernel_size=7, num_basis=8, dropout_rate=0.5):
        super().__init__()
        # Channel projection is needed only when the widths differ.
        pointwise = (
            [nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)]
            if in_ch != out_ch else []
        )
        depthwise = RepConv(
            out_ch, out_ch,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=out_ch,
        )
        self.conv = nn.Sequential(*pointwise, depthwise)
        self.prcm = PRCM(out_ch, num_basis=num_basis, dropout_rate=dropout_rate)

    def forward(self, x):
        """Apply the convolution stack followed by PRCM."""
        features = self.conv(x)
        return self.prcm(features)

    def switch_to_deploy(self):
        """Fuse every RepConv found inside ``self.conv``."""
        rep_layers = filter(lambda m: isinstance(m, RepConv), self.conv.modules())
        for layer in rep_layers:
            layer.switch_to_deploy()


# ============================================
# SplitTransformMergeBlock (from JeongWonNet_STMShuffle_NoStem)
# ============================================

class STMBlock(nn.Module):
    """Split-Transform-Merge block (ShuffleNet V2 inspired).

    Structure:
        Input   -> split along channels into [passive, active]
        Passive -> 1x1 conv (only when the half-widths differ)
        Active  -> 1x1 conv (only when the half-widths differ)
                   -> depthwise RepConv -> AffinePRCM
        Output  -> concat [passive, active] -> channel shuffle (2 groups)

    - Only the active half runs the depthwise convolution.
    - AffinePRCM applies scale and shift.
    - The shuffle interleaves channels of the two halves.

    Args:
        in_channels: Number of input channels; must be even.
        out_channels: Number of output channels; must be even.
        kernel_size: Depthwise kernel size; padding is ``kernel_size // 2``.
        num_basis: Number of AffinePRCM basis vectors.
        dropout_rate: Dropout rate forwarded to AffinePRCM.
    """

    def __init__(self, in_channels, out_channels, kernel_size=7, num_basis=8, dropout_rate=0.5):
        super().__init__()

        assert in_channels % 2 == 0, f"in_channels must be even, got {in_channels}"
        assert out_channels % 2 == 0, f"out_channels must be even, got {out_channels}"

        self.split_channels = in_channels // 2
        self.out_split_channels = out_channels // 2

        needs_projection = self.split_channels != self.out_split_channels

        def make_adjust():
            # Both branches use the same kind of width adapter.
            if needs_projection:
                return nn.Conv2d(
                    self.split_channels, self.out_split_channels,
                    kernel_size=1, bias=False,
                )
            return nn.Identity()

        # Passive branch: width adjustment only.
        self.passive_adjust = make_adjust()

        # Active branch: width adjustment, depthwise conv, recalibration.
        self.pw_conv = make_adjust()
        self.dw_repconv = RepConv(
            self.out_split_channels, self.out_split_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=self.out_split_channels,
        )
        self.affine_prcm = AffinePRCM(
            self.out_split_channels,
            num_basis=num_basis,
            dropout_rate=dropout_rate,
        )

    def channel_shuffle(self, x, groups):
        """Interleave the channels of ``groups`` equally sized channel groups."""
        batch, total, height, width = x.shape
        grouped = x.view(batch, groups, total // groups, height, width)
        swapped = grouped.transpose(1, 2).contiguous()
        return swapped.view(batch, -1, height, width)

    def forward(self, x):
        """Split, transform each half, concatenate and shuffle."""
        passive_in, active_in = x.chunk(2, dim=1)

        passive_out = self.passive_adjust(passive_in)
        active_out = self.affine_prcm(self.dw_repconv(self.pw_conv(active_in)))

        merged = torch.cat([passive_out, active_out], dim=1)
        return self.channel_shuffle(merged, 2)

    def switch_to_deploy(self):
        """Fuse the depthwise RepConv of the active branch."""
        rep = self.dw_repconv
        if isinstance(rep, RepConv):
            rep.switch_to_deploy()

In [None]:
# ============================================
# SimplePRCM (Scale Only, Optimized)
# ============================================

class SimplePRCM(nn.Module):
    """Scale-only PRCM used by FastSTMBlock (no shift term).

    Performs the same computation as PRCM; the projection layer is named
    ``scale_proj`` here.

    Args:
        channels: Number of channels of the input feature map.
        num_basis: Number of learned basis vectors.
        dropout_rate: Dropout probability on the basis coefficients; values
            ``<= 0`` disable dropout.
    """

    def __init__(self, channels, num_basis=8, dropout_rate=0.5):
        super().__init__()
        self.basis = nn.Parameter(torch.randn(num_basis, channels))
        self.scale_proj = nn.Linear(num_basis, channels, bias=False)
        if dropout_rate > 0:
            self.coeff_dropout = nn.Dropout(dropout_rate)
        else:
            self.coeff_dropout = nn.Identity()

    def forward(self, x):
        """Return ``x`` scaled per sample and channel by a sigmoid gate."""
        pooled = x.mean(dim=(2, 3))
        coeff = self.coeff_dropout(pooled @ self.basis.t())
        alpha = self.scale_proj(coeff).sigmoid()
        # Broadcast the (batch, channel) gate over both spatial dimensions.
        return x * alpha[:, :, None, None]


# ============================================
# FastSTMBlock (from JeongWonNet_STMNoConcatShuffle_NoStemMS)
# ============================================

class FastSTMBlock(nn.Module):
    """Fast Split-Transform-Merge block (FasterNet style).

    Differences from STMBlock:
        1. The halves are taken by basic slicing instead of torch.chunk.
        2. SimplePRCM (no shift) replaces AffinePRCM.
        3. A 1x1 conv + BN + ReLU mixes channels instead of a channel shuffle.

    Structure:
        Input -> [1x1 conv + BN + ReLU, only if in_channels != out_channels]
        Active  (first out_channels // 2 channels): depthwise RepConv -> SimplePRCM
        Passive (remaining channels): passed through unchanged
        Concat [active, passive] -> 1x1 conv -> BN -> ReLU

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Depthwise kernel size; padding is ``kernel_size // 2``.
        num_basis: Number of SimplePRCM basis vectors.
        dropout_rate: Dropout rate forwarded to SimplePRCM.
    """

    def __init__(self, in_channels, out_channels, kernel_size=7, num_basis=8, dropout_rate=0.5):
        super().__init__()

        # Channel alignment (skipped when the widths already match).
        if in_channels == out_channels:
            self.align_conv = None
        else:
            self.align_conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
            self.align_bn = nn.BatchNorm2d(out_channels)

        self.out_channels = out_channels
        self.dim_conv = out_channels // 2

        # Active branch operates on the first ``dim_conv`` channels only.
        active_width = self.dim_conv
        self.partial_conv = RepConv(
            active_width, active_width,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=active_width,
        )
        self.partial_prcm = SimplePRCM(
            active_width, num_basis=num_basis, dropout_rate=dropout_rate
        )

        # Channel mixing across both halves.
        self.mix_conv = nn.Conv2d(out_channels, out_channels, 1, bias=False)
        self.mix_bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Align channels, transform the active slice, then mix all channels."""
        if self.align_conv is not None:
            x = self.act(self.align_bn(self.align_conv(x)))

        split_at = self.dim_conv
        active = x[:, :split_at]
        passive = x[:, split_at:]

        # Only the active slice is convolved and recalibrated.
        active = self.partial_prcm(self.partial_conv(active))

        merged = torch.cat([active, passive], dim=1)
        return self.act(self.mix_bn(self.mix_conv(merged)))

    def switch_to_deploy(self):
        """Fuse the depthwise RepConv of the active branch."""
        rep = self.partial_conv
        if isinstance(rep, RepConv):
            rep.switch_to_deploy()

## 2. Latency Measurement Functions

In [5]:
def count_params(model):
    """Return the total number of elements over ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

def measure_latency_gpu(model, input_tensor, warmup=100, repeat=500):
    """Time forward passes of ``model`` on the GPU.

    The model is put in eval mode and moved to CUDA (in place), the input is
    copied to CUDA, ``warmup`` untimed passes are run, and then ``repeat``
    passes are timed with ``time.perf_counter``. The device is synchronized
    before starting and after finishing each timed pass.

    Args:
        model: Module to benchmark.
        input_tensor: Input passed to ``model``.
        warmup: Number of untimed warm-up passes.
        repeat: Number of timed passes.

    Returns:
        ``(mean_ms, std_ms)`` over the timed passes, or ``(None, None)``
        (after printing a notice) when CUDA is not available.
    """
    if not torch.cuda.is_available():
        print("CUDA not available")
        return None, None

    model.eval()
    model.cuda()
    input_tensor = input_tensor.cuda()

    samples_ms = []
    with torch.no_grad():
        for _ in range(warmup):
            model(input_tensor)
            torch.cuda.synchronize()

        for _ in range(repeat):
            # Drain pending GPU work so it is not billed to this pass.
            torch.cuda.synchronize()
            tic = time.perf_counter()
            model(input_tensor)
            torch.cuda.synchronize()
            toc = time.perf_counter()
            samples_ms.append((toc - tic) * 1000)

    return np.mean(samples_ms), np.std(samples_ms)

def measure_latency_cpu(model, input_tensor, warmup=20, repeat=100):
    """Time forward passes of ``model`` on the CPU.

    The model is put in eval mode and moved to the CPU (in place), the input
    is moved to the CPU, ``warmup`` untimed passes are run, and then
    ``repeat`` passes are timed with ``time.perf_counter``.

    Args:
        model: Module to benchmark.
        input_tensor: Input passed to ``model``.
        warmup: Number of untimed warm-up passes.
        repeat: Number of timed passes.

    Returns:
        ``(mean_ms, std_ms)`` over the timed passes.
    """
    model.eval()
    model.cpu()
    input_tensor = input_tensor.cpu()

    samples_ms = []
    with torch.no_grad():
        for _ in range(warmup):
            model(input_tensor)

        for _ in range(repeat):
            tic = time.perf_counter()
            model(input_tensor)
            toc = time.perf_counter()
            samples_ms.append((toc - tic) * 1000)

    return np.mean(samples_ms), np.std(samples_ms)

## 3. Architecture Comparison

In [None]:
# Text-only cell: prints ASCII diagrams of the three block structures followed
# by a summary table of their differences. No modules are built or timed here.
print("=" * 75)
print("Block Architecture Comparison")
print("=" * 75)

print("""
┌─────────────────────────────────────────────────────────────────────────┐
│ DWBlock (JeongWonNet77_Rep256Basis8S24Drop)                             │
├─────────────────────────────────────────────────────────────────────────┤
│   Input (C channels)                                                    │
│      │                                                                  │
│      ▼                                                                  │
│   [1x1 Conv] (if in_ch != out_ch)                                       │
│      │                                                                  │
│      ▼                                                                  │
│   RepConv 7x7 DW (full C channels)                                      │
│      │                                                                  │
│      ▼                                                                  │
│   PRCM (scale only: x * α)                                              │
│      │                                                                  │
│      ▼                                                                  │
│   Output                                                                │
└─────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────┐
│ STMBlock (JeongWonNet_STMShuffle_NoStem)                                │
├─────────────────────────────────────────────────────────────────────────┤
│   Input (C channels)                                                    │
│      │                                                                  │
│      ▼                                                                  │
│   torch.chunk (Split C/2 + C/2)                                         │
│      │                                                                  │
│   ┌──┴──┐                                                               │
│   │     │                                                               │
│   ▼     ▼                                                               │
│ Passive Active                                                          │
│ (C/2)   (C/2)                                                           │
│   │       │                                                             │
│   ▼       ▼                                                             │
│ 1x1     1x1 Conv -> RepConv 7x7 DW -> AffinePRCM (x*α+β)                │
│ Conv      │                                                             │
│   │       │                                                             │
│   └───┬───┘                                                             │
│       │                                                                 │
│       ▼                                                                 │
│   torch.cat -> Channel Shuffle                                          │
│       │                                                                 │
│       ▼                                                                 │
│   Output                                                                │
└─────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────┐
│ FastSTMBlock (JeongWonNet_STMNoConcatShuffle_NoStemMS)                  │
├─────────────────────────────────────────────────────────────────────────┤
│   Input (C channels)                                                    │
│      │                                                                  │
│      ▼                                                                  │
│   [1x1 Conv + BN + ReLU] (channel alignment)                            │
│      │                                                                  │
│      ▼                                                                  │
│   Slicing (zero-copy view, faster than chunk)                           │
│      │                                                                  │
│   ┌──┴──┐                                                               │
│   │     │                                                               │
│   ▼     ▼                                                               │
│ Active Passive                                                          │
│ (C/2)   (C/2)                                                           │
│   │       │                                                             │
│   ▼       │                                                             │
│ RepConv   │ (Identity)                                                  │
│ 7x7 DW    │                                                             │
│   │       │                                                             │
│   ▼       │                                                             │
│ SimplePRCM│ (scale only: x*α, no shift)                                 │
│   │       │                                                             │
│   └───┬───┘                                                             │
│       │                                                                 │
│       ▼                                                                 │
│   torch.cat -> 1x1 Conv + BN + ReLU (Mixing)                            │
│       │                                                                 │
│       ▼                                                                 │
│   Output                                                                │
└─────────────────────────────────────────────────────────────────────────┘

Key Differences:
  ┌────────────────┬──────────────────┬──────────────────┬──────────────────┐
  │                │ DWBlock          │ STMBlock         │ FastSTMBlock     │
  ├────────────────┼──────────────────┼──────────────────┼──────────────────┤
  │ Split          │ None (full C)    │ torch.chunk      │ Slicing (view)   │
  │ Active Ch.     │ C                │ C/2              │ C/2              │
  │ Passive        │ None             │ 1x1 Conv         │ Identity         │
  │ PRCM Type      │ PRCM (scale)     │ AffinePRCM (+β)  │ SimplePRCM       │
  │ Mixing         │ None             │ Channel Shuffle  │ 1x1 Conv + BN    │
  └────────────────┴──────────────────┴──────────────────┴──────────────────┘
""")

## 4. Parameter Comparison

In [None]:
# Compare trainable-parameter counts of the three blocks at equal width.
channels = 64
num_basis = 8

rule = "=" * 75
print(rule)
print(f"Parameter Comparison (in_ch={channels}, out_ch={channels})")
print(rule)

dw_block = DWBlock(channels, channels, kernel_size=7, num_basis=num_basis)
stm_block = STMBlock(channels, channels, kernel_size=7, num_basis=num_basis)
fast_block = FastSTMBlock(channels, channels, kernel_size=7, num_basis=num_basis)

dw_params = count_params(dw_block)
stm_params = count_params(stm_block)
fast_params = count_params(fast_block)

print(f"\n{'Block':<30} {'Parameters':>15} {'vs DWBlock':>15}")
print("-" * 62)
print(f"{'DWBlock':<30} {dw_params:>15,} {'baseline':>15}")
for label, n_params in (("STMBlock", stm_params), ("FastSTMBlock", fast_params)):
    # Relative size against the DWBlock baseline, in percent.
    delta_pct = (n_params / dw_params) * 100 - 100
    print(f"{label:<30} {n_params:>15,} {delta_pct:>+14.1f}%")

# Hand-computed breakdown of where the FastSTMBlock parameters come from.
half_ch = channels // 2
mix_conv_count = channels * channels
mix_bn_count = channels * 2
repconv_estimate = half_ch * 49 + half_ch * 6
prcm_count = num_basis * half_ch * 2

print("\n[FastSTMBlock Parameter Breakdown]")
print(f"  mix_conv (1x1): {channels}*{channels} = {mix_conv_count:,}")
print(f"  mix_bn: {channels}*2 = {mix_bn_count:,}")
print(f"  RepConv 7x7 DW ({half_ch}ch): ~{repconv_estimate:,}")
print(f"  SimplePRCM ({half_ch}ch): {num_basis}*{half_ch}*2 = {prcm_count:,}")

## 5. GPU Latency Comparison

In [None]:
# GPU latency of each block, un-fused ("train") and fused ("deploy"),
# reported relative to the un-fused DWBlock.
channels = 64
resolution = 64
batch_size = 1

x = torch.randn(batch_size, channels, resolution, resolution)

rule = "=" * 80
print(rule)
print("GPU Latency Comparison")
print(f"Input: ({batch_size}, {channels}, {resolution}, {resolution})")
print(rule)

blocks = OrderedDict()
blocks["DWBlock (train)"] = DWBlock(channels, channels)
blocks["STMBlock (train)"] = STMBlock(channels, channels)
blocks["FastSTMBlock (train)"] = FastSTMBlock(channels, channels)


def _deployed(block_cls):
    """Build a block of ``block_cls`` and fuse its RepConv immediately."""
    fused_block = block_cls(channels, channels)
    fused_block.switch_to_deploy()
    return fused_block


# Deploy mode
dw_deploy = blocks["DWBlock (deploy)"] = _deployed(DWBlock)
stm_deploy = blocks["STMBlock (deploy)"] = _deployed(STMBlock)
fast_deploy = blocks["FastSTMBlock (deploy)"] = _deployed(FastSTMBlock)

print(f"\n{'Block':<25} {'Mean (ms)':>12} {'Std (ms)':>12} {'vs DWBlock':>15}")
print("-" * 68)

dw_latency = None
for name, block in blocks.items():
    mean, std = measure_latency_gpu(block, x)
    if mean is not None:
        if "DWBlock (train)" in name:
            # First entry: becomes the reference for all later rows.
            dw_latency = mean
            diff_str = "baseline"
        else:
            diff = ((mean - dw_latency) / dw_latency) * 100
            diff_str = f"{diff:+.1f}%"
        print(f"{name:<25} {mean:>12.4f} {std:>12.4f} {diff_str:>15}")

## 6. CPU Latency Comparison

In [None]:
# CPU latency of each block, un-fused ("train") and fused ("deploy"),
# reported relative to the un-fused DWBlock.
channels = 64
resolution = 64
batch_size = 1

x = torch.randn(batch_size, channels, resolution, resolution)

rule = "=" * 80
print(rule)
print("CPU Latency Comparison")
print(f"Input: ({batch_size}, {channels}, {resolution}, {resolution})")
print(rule)

blocks = OrderedDict()
blocks["DWBlock (train)"] = DWBlock(channels, channels)
blocks["STMBlock (train)"] = STMBlock(channels, channels)
blocks["FastSTMBlock (train)"] = FastSTMBlock(channels, channels)


def _deployed(block_cls):
    """Build a block of ``block_cls`` and fuse its RepConv immediately."""
    fused_block = block_cls(channels, channels)
    fused_block.switch_to_deploy()
    return fused_block


# Deploy mode
dw_deploy = blocks["DWBlock (deploy)"] = _deployed(DWBlock)
stm_deploy = blocks["STMBlock (deploy)"] = _deployed(STMBlock)
fast_deploy = blocks["FastSTMBlock (deploy)"] = _deployed(FastSTMBlock)

print(f"\n{'Block':<25} {'Mean (ms)':>12} {'Std (ms)':>12} {'vs DWBlock':>15}")
print("-" * 68)

dw_latency = None
for name, block in blocks.items():
    mean, std = measure_latency_cpu(block, x)
    is_baseline = "DWBlock (train)" in name
    if is_baseline:
        # First entry: becomes the reference for all later rows.
        dw_latency = mean
        diff_str = "baseline"
    else:
        diff = ((mean - dw_latency) / dw_latency) * 100
        diff_str = f"{diff:+.1f}%"
    print(f"{name:<25} {mean:>12.4f} {std:>12.4f} {diff_str:>15}")

## 7. Multi-Channel Comparison

In [None]:
# GPU latency and parameter count of each fused block across channel widths.
channel_list = [24, 48, 64, 96, 128, 192]
resolution = 64


def _fmt_ms(value):
    """Format a latency in ms; ``None`` (no CUDA measurement) becomes "N/A".

    ``measure_latency_gpu`` returns ``None`` when CUDA is unavailable, and
    formatting ``None`` with ``.4f`` raises ``TypeError``.
    """
    return "N/A" if value is None else f"{value:.4f} ms"


print("=" * 95)
print("GPU Latency vs Channel Size (deploy mode, resolution=64)")
print("=" * 95)

results_gpu = {"DWBlock": [], "STMBlock": [], "FastSTMBlock": []}
params = {"DWBlock": [], "STMBlock": [], "FastSTMBlock": []}

block_types = (("DWBlock", DWBlock), ("STMBlock", STMBlock), ("FastSTMBlock", FastSTMBlock))

for ch in channel_list:
    x = torch.randn(1, ch, resolution, resolution)

    # Same order as the result tables: DWBlock, STMBlock, FastSTMBlock.
    for block_name, block_cls in block_types:
        block = block_cls(ch, ch)
        block.switch_to_deploy()
        params[block_name].append(count_params(block))
        mean, _ = measure_latency_gpu(block, x)
        results_gpu[block_name].append(mean)

print(f"\n{'Ch':<6} {'DWBlock':<12} {'STMBlock':<12} {'FastSTM':<12} {'DW Params':<10} {'STM Params':<10} {'Fast Params':<10}")
print("-" * 78)

for i, ch in enumerate(channel_list):
    print(f"{ch:<6} {_fmt_ms(results_gpu['DWBlock'][i])}   {_fmt_ms(results_gpu['STMBlock'][i])}   {_fmt_ms(results_gpu['FastSTMBlock'][i])}   {params['DWBlock'][i]:<10,} {params['STMBlock'][i]:<10,} {params['FastSTMBlock'][i]:<10,}")

In [None]:
# CPU latency of each fused block across the widths in ``channel_list``
# (``channel_list`` and ``resolution`` are defined by an earlier cell).
rule = "=" * 80
print("\n" + rule)
print("CPU Latency vs Channel Size (deploy mode, resolution=64)")
print(rule)

results_cpu = {"DWBlock": [], "STMBlock": [], "FastSTMBlock": []}

block_types = (("DWBlock", DWBlock), ("STMBlock", STMBlock), ("FastSTMBlock", FastSTMBlock))

for ch in channel_list:
    x = torch.randn(1, ch, resolution, resolution)

    # Same order as the result table: DWBlock, STMBlock, FastSTMBlock.
    for block_name, block_cls in block_types:
        block = block_cls(ch, ch)
        block.switch_to_deploy()
        mean, _ = measure_latency_cpu(block, x)
        results_cpu[block_name].append(mean)

print(f"\n{'Channels':<10} {'DWBlock':<15} {'STMBlock':<15} {'FastSTMBlock':<15}")
print("-" * 58)

gap = " " * 6
rows = zip(channel_list, results_cpu["DWBlock"], results_cpu["STMBlock"], results_cpu["FastSTMBlock"])
for ch, dw_ms, stm_ms, fast_ms in rows:
    print(f"{ch:<10} {dw_ms:.4f} ms{gap} {stm_ms:.4f} ms{gap} {fast_ms:.4f} ms")

## 8. Multi-Resolution Comparison

In [None]:
# GPU latency of each fused block across input resolutions at fixed width.
resolutions = [16, 32, 64, 128, 256]
channels = 64


def _fmt_ms(value):
    """Format a latency in ms; ``None`` (no CUDA measurement) becomes "N/A".

    ``measure_latency_gpu`` returns ``None`` when CUDA is unavailable, and
    formatting ``None`` with ``.4f`` raises ``TypeError``.
    """
    return "N/A" if value is None else f"{value:.4f} ms"


print("=" * 80)
print("GPU Latency vs Resolution (deploy mode, channels=64)")
print("=" * 80)

results_gpu = {"DWBlock": [], "STMBlock": [], "FastSTMBlock": []}

block_types = (("DWBlock", DWBlock), ("STMBlock", STMBlock), ("FastSTMBlock", FastSTMBlock))

for res in resolutions:
    x = torch.randn(1, channels, res, res)

    # Same order as the result table: DWBlock, STMBlock, FastSTMBlock.
    for block_name, block_cls in block_types:
        block = block_cls(channels, channels)
        block.switch_to_deploy()
        mean, _ = measure_latency_gpu(block, x)
        results_gpu[block_name].append(mean)

print(f"\n{'Resolution':<12} {'DWBlock':<15} {'STMBlock':<15} {'FastSTMBlock':<15}")
print("-" * 60)

for i, res in enumerate(resolutions):
    print(f"{res}x{res:<10} {_fmt_ms(results_gpu['DWBlock'][i])}{'':<6} {_fmt_ms(results_gpu['STMBlock'][i])}{'':<6} {_fmt_ms(results_gpu['FastSTMBlock'][i])}")

In [None]:
# CPU latency of each fused block across the sizes in ``resolutions``
# (``resolutions`` and ``channels`` are defined by an earlier cell).
rule = "=" * 80
print("\n" + rule)
print("CPU Latency vs Resolution (deploy mode, channels=64)")
print(rule)

results_cpu = {"DWBlock": [], "STMBlock": [], "FastSTMBlock": []}

block_types = (("DWBlock", DWBlock), ("STMBlock", STMBlock), ("FastSTMBlock", FastSTMBlock))

for res in resolutions:
    x = torch.randn(1, channels, res, res)

    # Same order as the result table: DWBlock, STMBlock, FastSTMBlock.
    for block_name, block_cls in block_types:
        block = block_cls(channels, channels)
        block.switch_to_deploy()
        mean, _ = measure_latency_cpu(block, x)
        results_cpu[block_name].append(mean)

print(f"\n{'Resolution':<12} {'DWBlock':<15} {'STMBlock':<15} {'FastSTMBlock':<15}")
print("-" * 60)

gap = " " * 6
rows = zip(resolutions, results_cpu["DWBlock"], results_cpu["STMBlock"], results_cpu["FastSTMBlock"])
for res, dw_ms, stm_ms, fast_ms in rows:
    print(f"{res}x{res:<10} {dw_ms:.4f} ms{gap} {stm_ms:.4f} ms{gap} {fast_ms:.4f} ms")

## 9. Channel Expansion Comparison (in_ch != out_ch)

In [None]:
# GPU latency of each fused block when the output is wider than the input.
channel_pairs = [(24, 48), (48, 64), (64, 96), (96, 128), (128, 192)]
resolution = 64


def _fmt_ms(value):
    """Format a latency in ms; ``None`` (no CUDA measurement) becomes "N/A".

    ``measure_latency_gpu`` returns ``None`` when CUDA is unavailable, and
    formatting ``None`` with ``.4f`` raises ``TypeError``.
    """
    return "N/A" if value is None else f"{value:.4f} ms"


print("=" * 85)
print("GPU Latency with Channel Expansion (deploy mode)")
print("=" * 85)

results = {"DWBlock": [], "STMBlock": [], "FastSTMBlock": []}

block_types = (("DWBlock", DWBlock), ("STMBlock", STMBlock), ("FastSTMBlock", FastSTMBlock))

for in_ch, out_ch in channel_pairs:
    x = torch.randn(1, in_ch, resolution, resolution)

    # Same order as the result table: DWBlock, STMBlock, FastSTMBlock.
    for block_name, block_cls in block_types:
        block = block_cls(in_ch, out_ch)
        block.switch_to_deploy()
        mean, _ = measure_latency_gpu(block, x)
        results[block_name].append(mean)

print(f"\n{'in->out':<15} {'DWBlock':<15} {'STMBlock':<15} {'FastSTMBlock':<15}")
print("-" * 62)

for i, (in_ch, out_ch) in enumerate(channel_pairs):
    print(f"{in_ch}->{out_ch:<10} {_fmt_ms(results['DWBlock'][i])}{'':<6} {_fmt_ms(results['STMBlock'][i])}{'':<6} {_fmt_ms(results['FastSTMBlock'][i])}")

## 10. Component-wise Latency Breakdown

In [None]:
def measure_component(module, x, warmup=50, repeat=200):
    """Return the mean GPU latency (ms) of one module, or None without CUDA.

    With CUDA available, the module is moved to CUDA and put in eval mode,
    the input is copied to CUDA, ``warmup`` untimed passes are run, and then
    ``repeat`` passes are timed with device synchronization around each one.

    Args:
        module: Module to benchmark.
        x: Input passed to ``module``.
        warmup: Number of untimed warm-up passes.
        repeat: Number of timed passes.
    """
    if not torch.cuda.is_available():
        return None

    module = module.cuda().eval()
    x = x.cuda()

    samples_ms = []
    with torch.no_grad():
        for _ in range(warmup):
            module(x)
            torch.cuda.synchronize()

        for _ in range(repeat):
            # Drain pending GPU work so it is not billed to this pass.
            torch.cuda.synchronize()
            tic = time.perf_counter()
            module(x)
            torch.cuda.synchronize()
            toc = time.perf_counter()
            samples_ms.append((toc - tic) * 1000)

    return np.mean(samples_ms)

channels = 64
half_ch = channels // 2
resolution = 64

x_full = torch.randn(1, channels, resolution, resolution)
x_half = torch.randn(1, half_ch, resolution, resolution)

print("=" * 75)
print(f"Component-wise GPU Latency (channels={channels}, resolution={resolution})")
print("=" * 75)

print("\n--- DWBlock Components ---")
repconv_full = RepConv(channels, channels, kernel_size=7, padding=3, groups=channels)
repconv_full.switch_to_deploy()
t = measure_component(repconv_full, x_full)
print(f"  RepConv 7x7 DW ({channels}ch, deploy): {t:.4f} ms")

prcm = PRCM(channels, num_basis=8, dropout_rate=0)
t = measure_component(prcm, x_full)
print(f"  PRCM ({channels}ch):                   {t:.4f} ms")

print("\n--- STMBlock Components ---")

class ChunkOp(nn.Module):
    """Module wrapper around a two-way split along dim 1, so it can be timed."""

    def forward(self, x):
        # Returns a tuple of the two pieces produced by splitting dim 1.
        return x.chunk(2, dim=1)
# Cost of the two-way channel split on the full-width input.
t = measure_component(ChunkOp(), x_full)
print(f"  torch.chunk (split):                 {t:.4f} ms")

# Same depthwise 7x7 RepConv as in the DWBlock section, but built and timed at
# half the channel count (groups == half_ch), again after switch_to_deploy().
repconv_half = RepConv(half_ch, half_ch, kernel_size=7, padding=3, groups=half_ch)
repconv_half.switch_to_deploy()
t = measure_component(repconv_half, x_half)
print(f"  RepConv 7x7 DW ({half_ch}ch, deploy):  {t:.4f} ms")

# Half-width AffinePRCM, timed on the half-width input.
affine_prcm = AffinePRCM(half_ch, num_basis=8, dropout_rate=0)
t = measure_component(affine_prcm, x_half)
print(f"  AffinePRCM ({half_ch}ch):              {t:.4f} ms")

class ConcatOp(nn.Module):
    """Split the input in two along dim 1, then concatenate the halves back.

    The split is only there to produce two tensors for ``torch.cat`` to join.
    """

    def forward(self, x):
        first, second = x.chunk(2, dim=1)
        rejoined = torch.cat((first, second), dim=1)
        return rejoined
# ConcatOp splits x_full before concatenating, so this figure covers the
# chunk call as well as the torch.cat itself.
t = measure_component(ConcatOp(), x_full)
print(f"  torch.cat (concat):                  {t:.4f} ms")

class ShuffleOp(nn.Module):
    """Two-group channel shuffle of a 4-D tensor, wrapped as a module for timing.

    Dim 1 is viewed as (2, C // 2), the two axes are swapped, and the result
    is flattened back to C channels, which interleaves the two halves.
    """

    def forward(self, x):
        batch, ch, height, width = x.shape
        grouped = x.view(batch, 2, ch // 2, height, width)
        # .contiguous() materialises the transposed layout so the final view is valid.
        interleaved = grouped.transpose(1, 2).contiguous()
        return interleaved.view(batch, ch, height, width)
# Cost of the view / transpose / contiguous channel shuffle on the full-width input.
t = measure_component(ShuffleOp(), x_full)
print(f"  Channel Shuffle:                     {t:.4f} ms")

# Section header for the FastSTMBlock component timings that follow.
print("\n--- FastSTMBlock Components ---")

# Slicing (zero-copy view)
class SliceOp(nn.Module):
    """Split a 4-D tensor along dim 1 at a fixed index using plain slicing.

    Args:
        dim: Index on axis 1 where the split happens. The first returned
            piece holds ``[:dim]`` and the second holds ``[dim:]``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        split_at = self.dim
        head = x[:, :split_at, :, :]
        tail = x[:, split_at:, :, :]
        return head, tail
# Cost of splitting the full-width input at half_ch via plain slicing.
t = measure_component(SliceOp(half_ch), x_full)
print(f"  Slicing (zero-copy view):            {t:.4f} ms")

# Half-width SimplePRCM, timed on the half-width input.
simple_prcm = SimplePRCM(half_ch, num_basis=8, dropout_rate=0)
t = measure_component(simple_prcm, x_half)
print(f"  SimplePRCM ({half_ch}ch):              {t:.4f} ms")

# 1x1 Conv Mixing
# Built here as a standalone Conv-BN-ReLU stack with freshly initialised
# weights. measure_component puts it in eval mode before timing.
mix_block = nn.Sequential(
    nn.Conv2d(channels, channels, 1, bias=False),
    nn.BatchNorm2d(channels),
    nn.ReLU(inplace=True)
)
t = measure_component(mix_block, x_full)
print(f"  1x1 Conv + BN + ReLU (mixing):       {t:.4f} ms")

## 11. Batch Size Scaling

In [None]:
# Sweep the batch size at a fixed channel count and resolution, timing each
# block type in deploy mode with measure_latency_gpu.
batch_sizes = [1, 2, 4, 8, 16]
channels = 64
resolution = 64

banner = "=" * 80
print(banner)
print("GPU Latency vs Batch Size (deploy mode)")
print(banner)

# Insertion order fixes both the measurement order and the column order.
block_types = {"DWBlock": DWBlock, "STMBlock": STMBlock, "FastSTMBlock": FastSTMBlock}
results = {block_name: [] for block_name in block_types}

for bs in batch_sizes:
    # One input per batch size, shared by all three blocks.
    x = torch.randn(bs, channels, resolution, resolution)

    for block_name, block_cls in block_types.items():
        block = block_cls(channels, channels)
        block.switch_to_deploy()
        mean, _ = measure_latency_gpu(block, x)
        results[block_name].append(mean)

print(f"\n{'Batch':<10} {'DWBlock':<15} {'STMBlock':<15} {'FastSTMBlock':<15}")
print("-" * 58)

for bs, dw_ms, stm_ms, fast_ms in zip(
    batch_sizes, results["DWBlock"], results["STMBlock"], results["FastSTMBlock"]
):
    print(f"{bs:<10} {dw_ms:.4f} ms{'':<6} {stm_ms:.4f} ms{'':<6} {fast_ms:.4f} ms")

## 12. Summary

In [None]:
# Static, hand-written recap: nothing below is computed from the timings
# measured in the earlier cells.
# NOTE(review): "Key Findings" calls DWBlock the highest-latency block, while
# the trade-offs table lists its GPU latency as "Medium" and its CPU latency
# as "Lowest". Reconcile the text against the measured numbers.
rule = "=" * 75
print(rule, "ABLATION SUMMARY", rule, sep="\n")
print("""
┌─────────────────────────────────────────────────────────────────────────┐
│ DWBlock (JeongWonNet77_Rep256Basis8S24Drop)                             │
├─────────────────────────────────────────────────────────────────────────┤
│ Structure: [1x1 Conv] -> RepConv 7x7 DW (full C) -> PRCM                │
│ - All channels go through depthwise convolution                         │
│ - PRCM: scale only (x * α)                                              │
│ - Most parameters (larger RepConv, full-channel PRCM)                   │
└─────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────┐
│ STMBlock (JeongWonNet_STMShuffle_NoStem)                                │
├─────────────────────────────────────────────────────────────────────────┤
│ Structure: chunk -> [Passive | Active] -> cat -> Channel Shuffle        │
│ - Only C/2 channels go through heavy computation                        │
│ - AffinePRCM: scale + shift (x * α + β)                                 │
│ - Channel shuffle for information mixing                                │
│ - Fewer parameters than DWBlock                                         │
└─────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────┐
│ FastSTMBlock (JeongWonNet_STMNoConcatShuffle_NoStemMS)                  │
├─────────────────────────────────────────────────────────────────────────┤
│ Structure: slice -> [Active | Passive] -> cat -> 1x1 Conv Mix           │
│ - Slicing instead of chunk (zero-copy view)                             │
│ - SimplePRCM: scale only (no shift overhead)                            │
│ - 1x1 Conv + BN + ReLU for mixing (more learnable than shuffle)         │
│ - More parameters due to 1x1 mixing conv (C*C params)                   │
└─────────────────────────────────────────────────────────────────────────┘

Key Findings:
  1. DWBlock: Highest latency (full channel processing)
  2. STMBlock: Lower latency, fewer params (half channel + shuffle)
  3. FastSTMBlock: More params (1x1 mixing), but potentially faster on GPU
     - 1x1 Conv is highly optimized on GPU (cuBLAS GEMM)
     - Shuffle requires memory reordering

Trade-offs:
  ┌────────────────┬──────────────┬──────────────┬──────────────┐
  │ Metric         │ DWBlock      │ STMBlock     │ FastSTMBlock │
  ├────────────────┼──────────────┼──────────────┼──────────────┤
  │ Parameters     │ Highest      │ Lowest       │ Medium       │
  │ GPU Latency    │ Medium       │ Medium       │ Depends      │
  │ CPU Latency    │ Lowest       │ Medium       │ Highest      │
  │ Expressiveness │ Full channel │ Half + shift │ Half + learn │
  │ GPU-friendly   │ Yes          │ Yes          │ Best (GEMM)  │
  └────────────────┴──────────────┴──────────────┴──────────────┘

Recommendations:
  - GPU deployment: FastSTMBlock (1x1 conv well-optimized)
  - CPU deployment: DWBlock (simpler, no extra mixing layer)
  - Low-param budget: STMBlock (shuffle is parameter-free)
  - Best accuracy: Needs experimental validation
""")