# Block Latency Comparison
## JeongWonNet Block vs ConvNeXt Block

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
from tqdm import tqdm

## 1. Block Definitions

In [2]:
# ============================================
# JeongWonNet Block (RepConv + PRCM)
# ============================================

class RepConv(nn.Module):
    """Multi-branch convolution that can be collapsed into a single conv.

    Before fusion, ``forward`` sums up to three branches, each followed by its
    own BatchNorm: a ``kernel_size`` x ``kernel_size`` conv, a 1x1 conv (only
    when ``kernel_size > 1``) and a BatchNorm-only shortcut (only when
    ``stride == 1`` and ``in_channels == out_channels``).  After
    ``switch_to_deploy()`` the branches are replaced by one biased conv stored
    as ``fused_conv``.  The result is passed through ReLU unless
    ``use_activation`` is False.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, groups=1, use_activation=True):
        super().__init__()
        # Kept so that the fused conv can be rebuilt with the same geometry.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.use_identity = (in_channels == out_channels) and (stride == 1)

        # Main k x k branch.
        self.conv_kxk = nn.Conv2d(in_channels, out_channels, kernel_size,
                                  stride, padding, groups=groups, bias=False)
        self.bn_kxk = nn.BatchNorm2d(out_channels)

        # Optional 1x1 branch (redundant when the main branch is already 1x1).
        has_pointwise = kernel_size > 1
        self.conv_1x1 = (
            nn.Conv2d(in_channels, out_channels, 1,
                      stride, 0, groups=groups, bias=False)
            if has_pointwise else None
        )
        if has_pointwise:
            self.bn_1x1 = nn.BatchNorm2d(out_channels)

        # Optional BatchNorm-only shortcut branch.
        if self.use_identity:
            self.bn_identity = nn.BatchNorm2d(out_channels)

        if use_activation:
            self.activation = nn.ReLU(inplace=True)
        else:
            self.activation = nn.Identity()

    def forward(self, x):
        """Apply the fused conv if present, otherwise the sum of all branches."""
        fused = getattr(self, 'fused_conv', None)
        if fused is not None:
            return self.activation(fused(x))

        total = self.bn_kxk(self.conv_kxk(x))
        if self.conv_1x1 is not None:
            total += self.bn_1x1(self.conv_1x1(x))
        if self.use_identity:
            total += self.bn_identity(x)
        return self.activation(total)

    def switch_to_deploy(self):
        """Fold every branch into ``fused_conv`` and drop the branch modules.

        Each branch is merged with its BatchNorm using the BatchNorm running
        statistics.  Calling this a second time is a no-op.
        """
        if hasattr(self, 'fused_conv'):
            return

        has_pointwise = self.conv_1x1 is not None

        weight, bias = self._fuse_bn_tensor(self.conv_kxk, self.bn_kxk)

        if has_pointwise:
            weight_1x1, bias_1x1 = self._fuse_bn_tensor(self.conv_1x1, self.bn_1x1)
            weight += self._pad_1x1_to_kxk(weight_1x1)
            bias += bias_1x1

        if self.use_identity:
            # conv=None asks for the identity kernel of the shortcut branch.
            weight_id, bias_id = self._fuse_bn_tensor(None, self.bn_identity)
            weight += weight_id
            bias += bias_id

        self.fused_conv = nn.Conv2d(
            self.in_channels, self.out_channels, self.kernel_size,
            self.stride, self.padding, groups=self.groups, bias=True
        )
        self.fused_conv.weight.data = weight
        self.fused_conv.bias.data = bias

        # Remove the now-unused training-time branches.
        obsolete = ['conv_kxk', 'bn_kxk']
        if has_pointwise:
            obsolete += ['conv_1x1', 'bn_1x1']
        if hasattr(self, 'bn_identity'):
            obsolete.append('bn_identity')
        for attr_name in obsolete:
            delattr(self, attr_name)

    def _fuse_bn_tensor(self, conv, bn):
        """Return ``(kernel, bias)`` of ``conv`` with ``bn`` folded in.

        When ``conv`` is None an identity kernel (a single 1 at the spatial
        centre, one per output channel) is used in its place.
        """
        if conv is not None:
            kernel = conv.weight
        else:
            per_group = self.in_channels // self.groups
            centre = self.kernel_size // 2
            kernel = torch.zeros((self.in_channels, per_group,
                                  self.kernel_size, self.kernel_size),
                                 dtype=bn.weight.dtype, device=bn.weight.device)
            out_idx = torch.arange(self.in_channels, device=bn.weight.device)
            kernel[out_idx, out_idx % per_group, centre, centre] = 1

        std = torch.sqrt(bn.running_var + bn.eps)
        scale = (bn.weight / std).reshape(-1, 1, 1, 1)
        fused_bias = bn.bias - bn.running_mean * bn.weight / std
        return kernel * scale, fused_bias

    def _pad_1x1_to_kxk(self, kernel_1x1):
        """Zero-pad a 1x1 kernel by ``kernel_size // 2`` on every spatial side."""
        if self.kernel_size == 1:
            return kernel_1x1
        border = self.kernel_size // 2
        return F.pad(kernel_1x1, [border] * 4)


class PRCM(nn.Module):
    """Pattern Recalibration Module: per-channel gating from pooled context.

    The spatial mean of the input is projected onto ``num_basis`` learned
    basis vectors, mapped back to ``channels`` values by a bias-free linear
    layer, passed through a sigmoid and multiplied onto the input.
    """

    def __init__(self, channels, num_basis=8):
        super().__init__()
        self.basis = nn.Parameter(torch.randn(num_basis, channels))
        self.fuser = nn.Linear(num_basis, channels, bias=False)

    def forward(self, x):
        """Return ``x`` rescaled channel-wise by the computed gate."""
        pooled = x.mean(dim=[2, 3])                    # average over dims 2 and 3
        coefficients = torch.matmul(pooled, self.basis.t())
        gate = torch.sigmoid(self.fuser(coefficients))
        # Add two trailing singleton dims so the gate broadcasts over x.
        return x * gate[:, :, None, None]


class JeongWonBlock(nn.Module):
    """JeongWonNet block: optional 1x1 conv, depthwise RepConv, then PRCM.

    The 1x1 conv is only inserted when ``in_ch != out_ch``; the RepConv always
    works on ``out_ch`` channels with ``groups=out_ch``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=7, num_basis=8):
        super().__init__()
        # Channel projection is only needed when the widths differ.
        if in_ch != out_ch:
            stages = [nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)]
        else:
            stages = []
        stages.append(
            RepConv(out_ch, out_ch, kernel_size=kernel_size,
                    padding=kernel_size // 2, groups=out_ch)
        )
        self.conv = nn.Sequential(*stages)
        self.prcm = PRCM(out_ch, num_basis=num_basis)

    def forward(self, x):
        """Run the conv stack and recalibrate its output with PRCM."""
        features = self.conv(x)
        return self.prcm(features)

    def switch_to_deploy(self):
        """Fuse every RepConv found inside ``self.conv``."""
        # Walk only self.conv (not self.modules()); the isinstance filter
        # skips the Sequential container and all non-RepConv layers.
        rep_layers = [m for m in self.conv.modules() if isinstance(m, RepConv)]
        for rep in rep_layers:
            rep.switch_to_deploy()

In [3]:
# ============================================
# ConvNeXt Block
# ============================================

class LayerNorm2d(nn.Module):
    """Layer normalization over dim 1 of a channels-first (B, C, H, W) tensor.

    Each position is normalized with the mean and (biased) variance taken
    along dim 1, then scaled by ``weight`` and shifted by ``bias``.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        """Return the normalized, affine-transformed tensor."""
        centered = x - x.mean(1, keepdim=True)
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # Reshape the per-channel parameters to (C, 1, 1) for broadcasting.
        gamma = self.weight.view(-1, 1, 1)
        beta = self.bias.view(-1, 1, 1)
        return gamma * normalized + beta


class ConvNeXtBlock(nn.Module):
    """ConvNeXt block: depthwise conv -> LayerNorm2d -> 1x1 -> GELU -> 1x1.

    The hidden width between the two 1x1 convs is ``channels * expansion``;
    the block output is added to its input.
    """

    def __init__(self, channels, expansion=4, kernel_size=7):
        super().__init__()
        hidden = channels * expansion

        self.dwconv = nn.Conv2d(channels, channels, kernel_size,
                                padding=kernel_size // 2, groups=channels)
        self.norm = LayerNorm2d(channels)
        self.pwconv1 = nn.Conv2d(channels, hidden, 1)   # expand
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(hidden, channels, 1)   # project back

    def forward(self, x):
        """Apply the five layers in sequence and add the residual input."""
        pipeline = (self.dwconv, self.norm, self.pwconv1, self.act, self.pwconv2)
        out = x
        for layer in pipeline:
            out = layer(out)
        return out + x


class ConvNeXtBlockV2(nn.Module):
    """ConvNeXt V2 block: the ConvNeXt layer sequence plus a GRN step.

    GRN (Global Response Normalization) is applied between the GELU and the
    projecting 1x1 conv, using the learnable ``grn_gamma`` / ``grn_beta``
    parameters (both initialised to zero).
    """

    def __init__(self, channels, expansion=4, kernel_size=7):
        super().__init__()
        hidden = channels * expansion

        self.dwconv = nn.Conv2d(channels, channels, kernel_size,
                                padding=kernel_size // 2, groups=channels)
        self.norm = LayerNorm2d(channels)
        self.pwconv1 = nn.Conv2d(channels, hidden, 1)
        self.act = nn.GELU()
        # GRN scale and shift, one value per hidden channel.
        grn_shape = (1, hidden, 1, 1)
        self.grn_gamma = nn.Parameter(torch.zeros(*grn_shape))
        self.grn_beta = nn.Parameter(torch.zeros(*grn_shape))
        self.pwconv2 = nn.Conv2d(hidden, channels, 1)

    def forward(self, x):
        """Run the block and add the residual input."""
        shortcut = x
        feat = self.act(self.pwconv1(self.norm(self.dwconv(x))))

        # GRN: L2 norm over dims 2-3, divided by its mean over dim 1.
        spatial_norm = torch.norm(feat, p=2, dim=(2, 3), keepdim=True)
        relative = spatial_norm / (spatial_norm.mean(dim=1, keepdim=True) + 1e-6)
        feat = feat + self.grn_gamma * (feat * relative) + self.grn_beta

        return self.pwconv2(feat) + shortcut

## 2. Latency Measurement Functions

In [4]:
def count_params(model):
    """Return the total number of elements across ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

def measure_latency_cpu(model, input_tensor, warmup=10, repeat=100):
    """Time forward passes of ``model`` on the CPU.

    The model is switched to eval mode and moved to the CPU (in place); the
    input is moved to the CPU as well.  ``warmup`` untimed passes are run
    first, then ``repeat`` timed passes, all under ``torch.no_grad()``.

    Returns:
        ``(mean, std)`` of the per-pass wall-clock time in milliseconds.
    """
    model.eval()
    model.cpu()
    cpu_input = input_tensor.cpu()

    with torch.no_grad():
        # Untimed warm-up passes.
        for _ in range(warmup):
            model(cpu_input)

        # Timed passes.
        elapsed_ms = []
        for _ in range(repeat):
            tic = time.perf_counter()
            model(cpu_input)
            toc = time.perf_counter()
            elapsed_ms.append((toc - tic) * 1000)

    return np.mean(elapsed_ms), np.std(elapsed_ms)

def measure_latency_gpu(model, input_tensor, warmup=50, repeat=200):
    """Time forward passes of ``model`` on the GPU.

    The model is switched to eval mode and moved to CUDA (in place).  Every
    timed pass is bracketed by ``torch.cuda.synchronize()`` so the measured
    interval covers the whole pass.

    Returns:
        ``(mean, std)`` of the per-pass time in milliseconds, or
        ``(None, None)`` (after printing a notice) when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print("CUDA not available, skipping GPU test")
        return None, None

    model.eval()
    model.cuda()
    gpu_input = input_tensor.cuda()

    with torch.no_grad():
        # Untimed warm-up passes.
        for _ in range(warmup):
            model(gpu_input)
            torch.cuda.synchronize()

        # Timed passes.
        samples_ms = []
        for _ in range(repeat):
            torch.cuda.synchronize()
            tic = time.perf_counter()
            model(gpu_input)
            torch.cuda.synchronize()
            toc = time.perf_counter()
            samples_ms.append((toc - tic) * 1000)

    return np.mean(samples_ms), np.std(samples_ms)

def measure_layer_latency_gpu(layer, input_tensor, warmup=50, repeat=200):
    """Time forward passes of a single ``layer`` on the GPU.

    The layer is switched to eval mode and moved to CUDA (in place); each
    timed pass is bracketed by ``torch.cuda.synchronize()``.

    Returns:
        Mean per-pass time in milliseconds, or ``None`` when CUDA is
        unavailable.
    """
    if not torch.cuda.is_available():
        return None

    layer.eval()
    layer.cuda()
    gpu_input = input_tensor.cuda()

    with torch.no_grad():
        # Untimed warm-up passes.
        for _ in range(warmup):
            layer(gpu_input)
            torch.cuda.synchronize()

        # Timed passes.
        samples_ms = []
        for _ in range(repeat):
            torch.cuda.synchronize()
            tic = time.perf_counter()
            layer(gpu_input)
            torch.cuda.synchronize()
            toc = time.perf_counter()
            samples_ms.append((toc - tic) * 1000)

    return np.mean(samples_ms)

## 3. Single Block Comparison

In [5]:
# Benchmark configuration for the single-block comparison
channels, batch_size, resolution = 64, 1, 64  # resolution = feature map size

# One instance of each block under test
jeongwon_block = JeongWonBlock(channels, channels, kernel_size=7, num_basis=8)
convnext_block = ConvNeXtBlock(channels, expansion=4, kernel_size=7)
convnext_v2_block = ConvNeXtBlockV2(channels, expansion=4, kernel_size=7)

# Random input feature map
x = torch.randn(batch_size, channels, resolution, resolution)

rule = "=" * 60
print(rule)
print("Single Block Comparison")
print(f"Input: ({batch_size}, {channels}, {resolution}, {resolution})")
print(rule)

Single Block Comparison
Input: (1, 64, 64, 64)


In [6]:
# Parameter count of each block (labels padded to a common 18-char column)
print("\n[Parameters]")
for label, candidate in (
    ("JeongWonBlock:", jeongwon_block),
    ("ConvNeXtBlock:", convnext_block),
    ("ConvNeXtBlockV2:", convnext_v2_block),
):
    print(f"{label:<18}{count_params(candidate):,}")


[Parameters]
JeongWonBlock:    4,608
ConvNeXtBlock:    36,416
ConvNeXtBlockV2:  36,928


In [7]:
# Layer-wise Latency Breakdown (this cell requires a CUDA device)
print("\n[Layer-wise GPU Latency Breakdown]")
print("=" * 60)

channels = 64
resolution = 64
x = torch.randn(1, channels, resolution, resolution).cuda()

# =====================
# JeongWonBlock Layers
# =====================
# NOTE: "train mode" in the label means the un-fused, multi-branch RepConv;
# the block itself is put in eval() for timing.
print("\n--- JeongWonBlock (train mode) ---")
block = JeongWonBlock(channels, channels, kernel_size=7, num_basis=8).cuda().eval()

# First RepConv inside block.conv (None when there is none)
repconv = next((m for m in block.conv.modules() if isinstance(m, RepConv)), None)

if repconv is not None:
    # conv_kxk + bn
    t = measure_layer_latency_gpu(nn.Sequential(repconv.conv_kxk, repconv.bn_kxk), x)
    print(f"  conv_kxk + bn_kxk (7x7 DW):  {t:.4f} ms")

    # conv_1x1 + bn
    if repconv.conv_1x1 is not None:
        t = measure_layer_latency_gpu(nn.Sequential(repconv.conv_1x1, repconv.bn_1x1), x)
        print(f"  conv_1x1 + bn_1x1:           {t:.4f} ms")

    # bn_identity
    if repconv.use_identity:
        t = measure_layer_latency_gpu(repconv.bn_identity, x)
        print(f"  bn_identity:                 {t:.4f} ms")

    # activation: RepConv uses nn.ReLU(inplace=True), which would overwrite
    # the shared input x used by all later measurements, so time it on a copy.
    t = measure_layer_latency_gpu(repconv.activation, x.clone())
    print(f"  ReLU:                        {t:.4f} ms")

# PRCM
t = measure_layer_latency_gpu(block.prcm, x)
print(f"  PRCM:                        {t:.4f} ms")

# Total: end-to-end latency of the whole block
total, _ = measure_latency_gpu(block, x)
print(f"  --------------------------")
print(f"  Total:                       {total:.4f} ms")

# =====================
# JeongWonBlock Deploy
# =====================
print("\n--- JeongWonBlock (deploy mode) ---")
block_deploy = JeongWonBlock(channels, channels, kernel_size=7, num_basis=8)
block_deploy.switch_to_deploy()
block_deploy = block_deploy.cuda().eval()

# Time the first fused RepConv (fused conv followed by its activation)
fused_rep = next(
    (m for m in block_deploy.conv.modules()
     if isinstance(m, RepConv) and hasattr(m, 'fused_conv')),
    None,
)
if fused_rep is not None:
    t = measure_layer_latency_gpu(
        nn.Sequential(fused_rep.fused_conv, fused_rep.activation), x
    )
    print(f"  fused_conv + ReLU (7x7 DW): {t:.4f} ms")

t = measure_layer_latency_gpu(block_deploy.prcm, x)
print(f"  PRCM:                        {t:.4f} ms")

# End-to-end latency of the fused block
total, _ = measure_latency_gpu(block_deploy, x)
print(f"  --------------------------")
print(f"  Total:                       {total:.4f} ms")

# =====================
# ConvNeXtBlock Layers
# =====================
print("\n--- ConvNeXtBlock ---")
block = ConvNeXtBlock(channels, expansion=4, kernel_size=7).cuda().eval()

# Input with the 4x-expanded channel count, for the layers after pwconv1
x_expanded = torch.randn(1, channels * 4, resolution, resolution).cuda()

# (label, layer, input) in the order the block applies them
convnext_layers = [
    ("dwconv (7x7 DW):", block.dwconv, x),
    ("LayerNorm2d:", block.norm, x),
    ("pwconv1 (1x1, expand 4x):", block.pwconv1, x),
    ("GELU:", block.act, x_expanded),
    ("pwconv2 (1x1, project):", block.pwconv2, x_expanded),
]
for label, layer, layer_input in convnext_layers:
    t = measure_layer_latency_gpu(layer, layer_input)
    # Two leading spaces + label padded to 29 chars = 31-char label column
    print(f"  {label:<29}{t:.4f} ms")

# End-to-end latency of the whole block
total, _ = measure_latency_gpu(block, x)
print(f"  --------------------------")
print(f"  Total:                       {total:.4f} ms")

# =====================
# ConvNeXtBlockV2 Layers
# =====================
print("\n--- ConvNeXtBlockV2 ---")
block = ConvNeXtBlockV2(channels, expansion=4, kernel_size=7).cuda().eval()


class GRN(nn.Module):
    """Stand-alone copy of the GRN step of ConvNeXtBlockV2.forward, for timing."""

    def __init__(self, gamma, beta):
        super().__init__()
        self.gamma = gamma
        self.beta = beta

    def forward(self, x):
        spatial_norm = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
        relative = spatial_norm / (spatial_norm.mean(dim=1, keepdim=True) + 1e-6)
        return x + self.gamma * (x * relative) + self.beta


# GRN wrapper sharing the block's own gamma/beta parameters
grn = GRN(block.grn_gamma, block.grn_beta).cuda().eval()

# (label, layer, input) in the order the block applies them
v2_layers = [
    ("dwconv (7x7 DW):", block.dwconv, x),
    ("LayerNorm2d:", block.norm, x),
    ("pwconv1 (1x1, expand 4x):", block.pwconv1, x),
    ("GELU:", block.act, x_expanded),
    ("GRN:", grn, x_expanded),
    ("pwconv2 (1x1, project):", block.pwconv2, x_expanded),
]
for label, layer, layer_input in v2_layers:
    t = measure_layer_latency_gpu(layer, layer_input)
    # Two leading spaces + label padded to 29 chars = 31-char label column
    print(f"  {label:<29}{t:.4f} ms")

# End-to-end latency of the whole block
total, _ = measure_latency_gpu(block, x)
print(f"  --------------------------")
print(f"  Total:                       {total:.4f} ms")


[Layer-wise GPU Latency Breakdown]

--- JeongWonBlock (train mode) ---
  conv_kxk + bn_kxk (7x7 DW):  0.0406 ms
  conv_1x1 + bn_1x1:           0.0181 ms
  bn_identity:                 0.0124 ms
  ReLU:                        0.0068 ms
  PRCM:                        0.0259 ms
  --------------------------
  Total:                       0.0850 ms

--- JeongWonBlock (deploy mode) ---
  fused_conv + ReLU (7x7 DW): 0.0371 ms
  PRCM:                        0.0248 ms
  --------------------------
  Total:                       0.0597 ms

--- ConvNeXtBlock ---
  dwconv (7x7 DW):             0.0341 ms
  LayerNorm2d:                 0.0389 ms
  pwconv1 (1x1, expand 4x):    0.0271 ms
  GELU:                        0.0145 ms
  pwconv2 (1x1, project):      0.0214 ms
  --------------------------
  Total:                       0.1180 ms

--- ConvNeXtBlockV2 ---
  dwconv (7x7 DW):             0.0341 ms
  LayerNorm2d:                 0.0390 ms
  pwconv1 (1x1, expand 4x):    0.0275 ms
  GELU:            

In [8]:
# CPU Latency
print("\n[CPU Latency]")

# Deploy-mode twin: same weights as jeongwon_block, with its branches fused
jeongwon_block_deploy = JeongWonBlock(channels, channels, kernel_size=7, num_basis=8)
jeongwon_block_deploy.load_state_dict(jeongwon_block.state_dict())
jeongwon_block_deploy.switch_to_deploy()

# Measured in this order; labels are padded to a common 26-char column
cpu_candidates = (
    ("JeongWonBlock (train):", jeongwon_block),
    ("JeongWonBlock (deploy):", jeongwon_block_deploy),
    ("ConvNeXtBlock:", convnext_block),
    ("ConvNeXtBlockV2:", convnext_v2_block),
)
for label, candidate in cpu_candidates:
    mean, std = measure_latency_cpu(candidate, x)
    print(f"{label:<26}{mean:.3f} ± {std:.3f} ms")


[CPU Latency]
JeongWonBlock (train):    0.268 ± 0.011 ms
JeongWonBlock (deploy):   0.147 ± 0.009 ms
ConvNeXtBlock:            0.686 ± 0.026 ms
ConvNeXtBlockV2:          0.816 ± 0.049 ms


In [9]:
# GPU Latency
# measure_latency_gpu returns (None, None) when CUDA is unavailable, so each
# result is printed only when a measurement was actually taken.  The check is
# `is not None` rather than truthiness so a 0.0 mean would still be reported.
print("\n[GPU Latency]")

# Un-fused block (freshly initialised; only latency matters here)
jeongwon_block = JeongWonBlock(channels, channels, kernel_size=7, num_basis=8)
mean, std = measure_latency_gpu(jeongwon_block, x)
if mean is not None:
    print(f"JeongWonBlock (train):    {mean:.3f} ± {std:.3f} ms")

# Fused (deploy-mode) block
jeongwon_block_deploy = JeongWonBlock(channels, channels, kernel_size=7, num_basis=8)
jeongwon_block_deploy.switch_to_deploy()
mean, std = measure_latency_gpu(jeongwon_block_deploy, x)
if mean is not None:
    print(f"JeongWonBlock (deploy):   {mean:.3f} ± {std:.3f} ms")

mean, std = measure_latency_gpu(convnext_block, x)
if mean is not None:
    print(f"ConvNeXtBlock:            {mean:.3f} ± {std:.3f} ms")

mean, std = measure_latency_gpu(convnext_v2_block, x)
if mean is not None:
    print(f"ConvNeXtBlockV2:          {mean:.3f} ± {std:.3f} ms")


[GPU Latency]
JeongWonBlock (train):    0.088 ± 0.028 ms
JeongWonBlock (deploy):   0.060 ± 0.001 ms
ConvNeXtBlock:            0.118 ± 0.002 ms
ConvNeXtBlockV2:          0.187 ± 0.002 ms


## 4. Multi-Scale Comparison (Different Resolutions)

In [10]:
# Deploy-mode JeongWonBlock vs ConvNeXt blocks at several feature-map sizes
resolutions = [32, 64, 128, 256]
channels = 64
batch_size = 1

print("\n" + "=" * 60)
print("Multi-Resolution GPU Latency Comparison")
print("=" * 60)

# Column name -> list of mean latencies (ms), one entry per resolution
results = {"JeongWon (deploy)": [], "ConvNeXt": [], "ConvNeXtV2": []}

for res in resolutions:
    x = torch.randn(batch_size, channels, res, res)

    # JeongWonBlock (deploy)
    block = JeongWonBlock(channels, channels, kernel_size=7, num_basis=8)
    block.switch_to_deploy()
    mean, _ = measure_latency_gpu(block, x)
    results["JeongWon (deploy)"].append(mean)

    # ConvNeXt
    block = ConvNeXtBlock(channels, expansion=4, kernel_size=7)
    mean, _ = measure_latency_gpu(block, x)
    results["ConvNeXt"].append(mean)

    # ConvNeXtV2
    block = ConvNeXtBlockV2(channels, expansion=4, kernel_size=7)
    mean, _ = measure_latency_gpu(block, x)
    results["ConvNeXtV2"].append(mean)

# Print results: a 12-char row label followed by 20-char columns
print(f"\n{'Resolution':<12}", end="")
for name in results:
    print(f"{name:<20}", end="")
print()
print("-" * 72)

for i, res in enumerate(resolutions):
    # Pad the whole "RxR" label; padding only the second number shifts the
    # columns whenever the number of digits changes.
    row_label = f"{res}x{res}"
    print(f"{row_label:<12}", end="")
    for timings in results.values():
        value = timings[i]
        # measure_latency_gpu yields None when CUDA is unavailable
        cell = "n/a" if value is None else f"{value:.3f} ms"
        print(f"{cell:<20}", end="")
    print()


Multi-Resolution GPU Latency Comparison

Resolution  JeongWon (deploy)   ConvNeXt            ConvNeXtV2          
------------------------------------------------------------------------
32x32        0.036 ms            0.077 ms            0.104 ms            
64x64        0.062 ms            0.122 ms            0.188 ms            
128x128       0.130 ms            0.361 ms            0.587 ms            
256x256       0.451 ms            1.313 ms            2.098 ms            


## 5. Multi-Channel Comparison

In [11]:
# Deploy-mode JeongWonBlock vs ConvNeXt blocks at several channel widths
channel_list = [24, 48, 64, 96, 128, 192]
resolution = 64
batch_size = 1

print("\n" + "=" * 60)
print("Multi-Channel GPU Latency Comparison (64x64 resolution)")
print("=" * 60)

# Column name -> list of mean latencies (ms), one entry per channel width
results = {"JeongWon (deploy)": [], "ConvNeXt": [], "ConvNeXtV2": []}

for ch in channel_list:
    x = torch.randn(batch_size, ch, resolution, resolution)

    # JeongWonBlock (deploy)
    block = JeongWonBlock(ch, ch, kernel_size=7, num_basis=8)
    block.switch_to_deploy()
    mean, _ = measure_latency_gpu(block, x)
    results["JeongWon (deploy)"].append(mean)

    # ConvNeXt
    block = ConvNeXtBlock(ch, expansion=4, kernel_size=7)
    mean, _ = measure_latency_gpu(block, x)
    results["ConvNeXt"].append(mean)

    # ConvNeXtV2
    block = ConvNeXtBlockV2(ch, expansion=4, kernel_size=7)
    mean, _ = measure_latency_gpu(block, x)
    results["ConvNeXtV2"].append(mean)

# Print results: a 12-char row label followed by 20-char columns
print(f"\n{'Channels':<12}", end="")
for name in results:
    print(f"{name:<20}", end="")
print()
print("-" * 72)

for i, ch in enumerate(channel_list):
    print(f"{ch:<12}", end="")
    for timings in results.values():
        value = timings[i]
        # measure_latency_gpu yields None when CUDA is unavailable; padding
        # the whole cell keeps columns aligned for values of any width.
        cell = "n/a" if value is None else f"{value:.3f} ms"
        print(f"{cell:<20}", end="")
    print()


Multi-Channel GPU Latency Comparison (64x64 resolution)

Channels    JeongWon (deploy)   ConvNeXt            ConvNeXtV2          
------------------------------------------------------------------------
24          0.042 ms            0.072 ms            0.104 ms            
48          0.047 ms            0.098 ms            0.154 ms            
64          0.060 ms            0.116 ms            0.185 ms            
96          0.067 ms            0.168 ms            0.269 ms            
128         0.083 ms            0.215 ms            0.339 ms            
192         0.107 ms            0.328 ms            0.501 ms            


## 6. Full Model Comparison (Stacked Blocks)

In [12]:
class StackedJeongWonBlocks(nn.Module):
    """Encoder made of consecutive JeongWonBlocks, each followed by 2x max-pool.

    Args:
        c_list: Output channel count of each stage, in order.  The first stage
            takes 3 input channels.  The default is a tuple (not a list) so
            the default value cannot be mutated between instantiations.
    """

    def __init__(self, c_list=(24, 48, 64, 96, 128, 192)):
        super().__init__()
        self.blocks = nn.ModuleList()
        in_ch = 3
        for out_ch in c_list:
            self.blocks.append(JeongWonBlock(in_ch, out_ch, kernel_size=7, num_basis=8))
            in_ch = out_ch

    def forward(self, x):
        """Apply every block, halving the spatial size after each one."""
        for block in self.blocks:
            x = F.max_pool2d(block(x), 2)
        return x

    def switch_to_deploy(self):
        """Switch every contained block to its fused deploy form."""
        for block in self.blocks:
            block.switch_to_deploy()


class StackedConvNeXtBlocks(nn.Module):
    """Encoder made of a 3x3 stem, ConvNeXtBlocks and stride-2 downsample convs.

    Args:
        c_list: Channel count of each stage, in order; must be non-empty
            because the stem maps 3 channels to ``c_list[0]``.  A 2x2 stride-2
            conv is inserted between consecutive stages (none after the last).
            The default is a tuple (not a list) so the default value cannot
            be mutated between instantiations.
    """

    def __init__(self, c_list=(24, 48, 64, 96, 128, 192)):
        super().__init__()
        self.stem = nn.Conv2d(3, c_list[0], 3, padding=1)
        self.blocks = nn.ModuleList()
        last_stage = len(c_list) - 1
        for i, ch in enumerate(c_list):
            self.blocks.append(ConvNeXtBlock(ch, expansion=4, kernel_size=7))
            if i < last_stage:
                # Downsample and change width for the next stage
                self.blocks.append(nn.Conv2d(ch, c_list[i + 1], 2, stride=2))

    def forward(self, x):
        """Apply the stem, then every block / downsample layer in order."""
        x = self.stem(x)
        for block in self.blocks:
            x = block(x)
        return x


# Compare the two stacked encoders on one 256x256 RGB-sized input.
# measure_latency_gpu returns (None, None) when CUDA is unavailable, so each
# latency line is printed only when a measurement was taken; the check is
# `is not None` rather than truthiness so a 0.0 mean would still be reported.
print("\n" + "=" * 60)
print("Full Encoder Comparison (6 stages)")
print("=" * 60)

x = torch.randn(1, 3, 256, 256)

# JeongWon Encoder (un-fused)
encoder_jw = StackedJeongWonBlocks()
print(f"\nJeongWon Encoder Params: {count_params(encoder_jw):,}")
mean, std = measure_latency_gpu(encoder_jw, x)
if mean is not None:
    print(f"  GPU Latency (train):  {mean:.3f} ± {std:.3f} ms")

# JeongWon Encoder (fused / deploy)
encoder_jw_deploy = StackedJeongWonBlocks()
encoder_jw_deploy.switch_to_deploy()
mean, std = measure_latency_gpu(encoder_jw_deploy, x)
if mean is not None:
    print(f"  GPU Latency (deploy): {mean:.3f} ± {std:.3f} ms")

# ConvNeXt Encoder
encoder_cn = StackedConvNeXtBlocks()
print(f"\nConvNeXt Encoder Params: {count_params(encoder_cn):,}")
mean, std = measure_latency_gpu(encoder_cn, x)
if mean is not None:
    print(f"  GPU Latency: {mean:.3f} ± {std:.3f} ms")


Full Encoder Comparison (6 stages)

JeongWon Encoder Params: 87,048
  GPU Latency (train):  0.889 ± 0.011 ms
  GPU Latency (deploy): 0.498 ± 0.003 ms

ConvNeXt Encoder Params: 777,112
  GPU Latency: 1.267 ± 0.009 ms


## 7. Summary

In [13]:
# Print the closing summary of the comparison
divider = "=" * 60
summary_text = """
JeongWonBlock:
  - Structure: 1x1 Conv + RepConv 7x7 DW + PRCM
  - Pros: Lightweight, Re-parameterizable (deploy mode faster)
  - PRCM: Global context with minimal overhead

ConvNeXtBlock:
  - Structure: DW 7x7 -> LN -> 1x1 (expand) -> GELU -> 1x1 (project)
  - Pros: Strong representational power
  - Cons: More parameters (4x expansion), LayerNorm overhead

ConvNeXtBlockV2:
  - Adds GRN (Global Response Normalization)
  - Additional overhead from GRN computation

Key Insights:
  1. JeongWon deploy mode fuses RepConv branches -> faster inference
  2. ConvNeXt has 4x channel expansion -> more FLOPs
  3. PRCM is lightweight compared to full expansion+projection
"""

print(f"\n{divider}")
print("SUMMARY")
print(divider)
print(summary_text)


SUMMARY

JeongWonBlock:
  - Structure: 1x1 Conv + RepConv 7x7 DW + PRCM
  - Pros: Lightweight, Re-parameterizable (deploy mode faster)
  - PRCM: Global context with minimal overhead

ConvNeXtBlock:
  - Structure: DW 7x7 -> LN -> 1x1 (expand) -> GELU -> 1x1 (project)
  - Pros: Strong representational power
  - Cons: More parameters (4x expansion), LayerNorm overhead

ConvNeXtBlockV2:
  - Adds GRN (Global Response Normalization)
  - Additional overhead from GRN computation

Key Insights:
  1. JeongWon deploy mode fuses RepConv branches -> faster inference
  2. ConvNeXt has 4x channel expansion -> more FLOPs
  3. PRCM is lightweight compared to full expansion+projection

