# Encoder Block Comparison: JeongWonNet vs CMUNeXt

인코더 블록 구조, 파라미터, FLOPs, Latency 비교

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import pandas as pd
from thop import profile

from models import JeongWonNet, CMUNeXt
from models.JeongWonNet import PRCM
from models.CMUNeXt import CMUNeXtBlock, conv_block

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {DEVICE}")



Device: cuda


## 1. Block Structure Comparison

In [2]:
print("="*70)
print("JeongWonNet Encoder Block Structure")
print("="*70)
print("""
Input (in_ch)
    │
    ├─► [1x1 Conv] Pointwise (in_ch → out_ch)  ← 채널 조정
    │
    ├─► [3x3 DWConv] Depthwise (out_ch → out_ch, groups=out_ch)
    │
    ├─► [GroupNorm + GELU]
    │
    └─► [PRCM] Channel Mixer (out_ch → out_ch)
    │
Output (out_ch)

Key: Depthwise Separable Conv + Channel Attention
""")

print("\n" + "="*70)
print("CMUNeXt Encoder Block Structure (CMUNeXtBlock)")
print("="*70)
print("""
Input (ch_in)
    │
    ├─► [Residual Block] x depth times:
    │       │
    │       ├─► Residual(
    │       │       [3x3 DWConv] (ch_in → ch_in, groups=ch_in)
    │       │       [GELU + BatchNorm]
    │       │   )
    │       │
    │       ├─► [1x1 Conv] Expansion (ch_in → ch_in * 4)  ← 4배 확장!
    │       ├─► [GELU + BatchNorm]
    │       │
    │       ├─► [1x1 Conv] Reduction (ch_in * 4 → ch_in)  ← 다시 축소
    │       └─► [GELU + BatchNorm]
    │
    └─► [conv_block] Standard 3x3 Conv (ch_in → ch_out)
    │
Output (ch_out)

Key: 4x Channel Expansion + Standard 3x3 Conv
""")

JeongWonNet Encoder Block Structure

Input (in_ch)
    │
    ├─► [1x1 Conv] Pointwise (in_ch → out_ch)  ← 채널 조정
    │
    ├─► [3x3 DWConv] Depthwise (out_ch → out_ch, groups=out_ch)
    │
    ├─► [GroupNorm + GELU]
    │
    └─► [PRCM] Channel Mixer (out_ch → out_ch)
    │
Output (out_ch)

Key: Depthwise Separable Conv + Channel Attention


CMUNeXt Encoder Block Structure (CMUNeXtBlock)

Input (ch_in)
    │
    ├─► [Residual Block] x depth times:
    │       │
    │       ├─► Residual(
    │       │       [3x3 DWConv] (ch_in → ch_in, groups=ch_in)
    │       │       [GELU + BatchNorm]
    │       │   )
    │       │
    │       ├─► [1x1 Conv] Expansion (ch_in → ch_in * 4)  ← 4배 확장!
    │       ├─► [GELU + BatchNorm]
    │       │
    │       ├─► [1x1 Conv] Reduction (ch_in * 4 → ch_in)  ← 다시 축소
    │       └─► [GELU + BatchNorm]
    │
    └─► [conv_block] Standard 3x3 Conv (ch_in → ch_out)
    │
Output (ch_out)

Key: 4x Channel Expansion + Standard 3x3 Conv



## 2. Create Isolated Encoder Blocks

In [3]:
# JeongWonNet style encoder block
def make_jwn_block(in_ch, out_ch):
    """JeongWonNet 스타일 인코더 블록 생성"""
    layers = []
    if in_ch != out_ch:
        layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False))
    layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, groups=out_ch, bias=False))
    num_groups = min(4, out_ch)
    while out_ch % num_groups != 0:
        num_groups -= 1
    layers.append(nn.GroupNorm(num_groups, out_ch))
    layers.append(nn.GELU())
    layers.append(PRCM(out_ch, num_basis=2))
    return nn.Sequential(*layers)

# Channel configurations
jwn_channels = [3, 6, 12, 18, 24, 32, 48]  # input → c_list
cmu_channels = [3, 8, 8, 16, 24, 32, 48]   # input → dims (stem: 3→8, enc1: 8→8)
cmu_depths = [1, 1, 1, 3, 1]
cmu_kernels = [3, 3, 7, 7, 7]

print("Channel Configurations:")
print(f"  JeongWonNet: {jwn_channels[0]} → {jwn_channels[1:]}")
print(f"  CMUNeXt:     {cmu_channels[0]} → {cmu_channels[1:]}")
print(f"  CMUNeXt depths: {cmu_depths}")

Channel Configurations:
  JeongWonNet: 3 → [6, 12, 18, 24, 32, 48]
  CMUNeXt:     3 → [8, 8, 16, 24, 32, 48]
  CMUNeXt depths: [1, 1, 1, 3, 1]


## 3. Parameter Count Comparison

In [4]:
def count_params(model):
    return sum(p.numel() for p in model.parameters())

def format_num(n):
    if n >= 1e6: return f"{n/1e6:.2f}M"
    if n >= 1e3: return f"{n/1e3:.2f}K"
    return str(n)

results = []

print("="*70)
print("Parameter Count per Encoder Block")
print("="*70)
print(f"{'Stage':<10} {'JWN (in→out)':<15} {'JWN Params':<12} {'CMU (in→out)':<15} {'CMU Params':<12} {'Ratio'}")
print("-"*70)

# JeongWonNet blocks
jwn_blocks = []
for i in range(6):
    in_ch = jwn_channels[i]
    out_ch = jwn_channels[i+1]
    block = make_jwn_block(in_ch, out_ch)
    jwn_blocks.append((f"{in_ch}→{out_ch}", count_params(block)))

# CMUNeXt blocks (stem + encoders)
cmu_blocks = []
# Stem: conv_block(3, 8)
stem = conv_block(3, 8)
cmu_blocks.append(("3→8 (stem)", count_params(stem)))

# Encoder blocks
for i in range(5):
    in_ch = cmu_channels[i+1]
    out_ch = cmu_channels[i+2]
    block = CMUNeXtBlock(in_ch, out_ch, depth=cmu_depths[i], k=cmu_kernels[i])
    cmu_blocks.append((f"{in_ch}→{out_ch} (d={cmu_depths[i]})", count_params(block)))

# Print comparison
for i in range(6):
    jwn_ch, jwn_p = jwn_blocks[i]
    cmu_ch, cmu_p = cmu_blocks[i]
    ratio = cmu_p / jwn_p if jwn_p > 0 else 0
    print(f"Enc{i+1:<5} {jwn_ch:<15} {format_num(jwn_p):<12} {cmu_ch:<15} {format_num(cmu_p):<12} {ratio:.1f}x")
    results.append({
        'Stage': f'Enc{i+1}',
        'JWN_channels': jwn_ch,
        'JWN_params': jwn_p,
        'CMU_channels': cmu_ch,
        'CMU_params': cmu_p,
        'Param_ratio': ratio
    })

print("-"*70)
jwn_total = sum(b[1] for b in jwn_blocks)
cmu_total = sum(b[1] for b in cmu_blocks)
print(f"{'Total':<10} {'':<15} {format_num(jwn_total):<12} {'':<15} {format_num(cmu_total):<12} {cmu_total/jwn_total:.1f}x")

Parameter Count per Encoder Block
Stage      JWN (in→out)    JWN Params   CMU (in→out)    CMU Params   Ratio
----------------------------------------------------------------------
Enc1     3→6             108          3→8 (stem)      240          2.2x
Enc2     6→12            252          8→8 (d=1)       1.33K        5.3x
Enc3     12→18           486          8→16 (d=1)      1.93K        4.0x
Enc4     18→24           792          16→24 (d=1)     6.65K        8.4x
Enc5     24→32           1.25K        24→32 (d=3)     25.66K       20.6x
Enc6     32→48           2.26K        32→48 (d=1)     24.30K       10.8x
----------------------------------------------------------------------
Total                      5.14K                        60.10K       11.7x


## 4. FLOPs Comparison

In [5]:
def measure_flops(model, input_tensor):
    model.eval()
    with torch.no_grad():
        flops, _ = profile(model, inputs=(input_tensor,), verbose=False)
    return flops

print("="*70)
print("FLOPs per Encoder Block (256x256 input)")
print("="*70)
print(f"{'Stage':<10} {'Resolution':<12} {'JWN FLOPs':<12} {'CMU FLOPs':<12} {'Ratio'}")
print("-"*70)

# Resolution at each stage (after pooling)
resolutions = [256, 128, 64, 32, 16, 8]

for i in range(6):
    res = resolutions[i]
    
    # JeongWonNet block
    in_ch = jwn_channels[i]
    out_ch = jwn_channels[i+1]
    jwn_block = make_jwn_block(in_ch, out_ch).to(DEVICE)
    jwn_input = torch.randn(1, in_ch, res, res).to(DEVICE)
    jwn_flops = measure_flops(jwn_block, jwn_input)
    
    # CMUNeXt block
    if i == 0:
        cmu_block = conv_block(3, 8).to(DEVICE)
        cmu_input = torch.randn(1, 3, res, res).to(DEVICE)
    else:
        in_ch_cmu = cmu_channels[i]
        out_ch_cmu = cmu_channels[i+1]
        cmu_block = CMUNeXtBlock(in_ch_cmu, out_ch_cmu, depth=cmu_depths[i-1], k=cmu_kernels[i-1]).to(DEVICE)
        cmu_input = torch.randn(1, in_ch_cmu, res, res).to(DEVICE)
    cmu_flops = measure_flops(cmu_block, cmu_input)
    
    ratio = cmu_flops / jwn_flops if jwn_flops > 0 else 0
    print(f"Enc{i+1:<5} {res}x{res:<8} {format_num(jwn_flops):<12} {format_num(cmu_flops):<12} {ratio:.1f}x")
    
    results[i]['JWN_flops'] = jwn_flops
    results[i]['CMU_flops'] = cmu_flops
    results[i]['FLOPs_ratio'] = ratio
    
    del jwn_block, cmu_block
    torch.cuda.empty_cache()

FLOPs per Encoder Block (256x256 input)
Stage      Resolution   JWN FLOPs    CMU FLOPs    Ratio
----------------------------------------------------------------------
Enc1     256x256      4.72M        16.25M       3.4x
Enc2     128x128      2.95M        22.68M       7.7x
Enc3     64x64       1.55M        8.16M        5.3x
Enc4     32x32       663.60K      6.93M        10.4x
Enc5     16x16       270.40K      6.69M        24.7x
Enc6     8x8        126.05K      1.57M        12.5x


## 5. Latency Comparison

In [6]:
NUM_WARMUP = 50
NUM_ITER = 100

def measure_latency(model, input_tensor):
    model.eval()
    # Warmup
    with torch.no_grad():
        for _ in range(NUM_WARMUP):
            _ = model(input_tensor)
    if DEVICE.type == 'cuda':
        torch.cuda.synchronize()
    
    # Measure
    times = []
    with torch.no_grad():
        for _ in range(NUM_ITER):
            if DEVICE.type == 'cuda':
                torch.cuda.synchronize()
            start = time.perf_counter()
            _ = model(input_tensor)
            if DEVICE.type == 'cuda':
                torch.cuda.synchronize()
            times.append((time.perf_counter() - start) * 1000)
    return np.mean(times), np.std(times)

print("="*70)
print("Latency per Encoder Block (256x256 input, GPU)")
print("="*70)
print(f"{'Stage':<10} {'Resolution':<12} {'JWN (ms)':<15} {'CMU (ms)':<15} {'Ratio'}")
print("-"*70)

for i in range(6):
    res = resolutions[i]
    
    # JeongWonNet block
    in_ch = jwn_channels[i]
    out_ch = jwn_channels[i+1]
    jwn_block = make_jwn_block(in_ch, out_ch).to(DEVICE)
    jwn_input = torch.randn(1, in_ch, res, res).to(DEVICE)
    jwn_mean, jwn_std = measure_latency(jwn_block, jwn_input)
    
    # CMUNeXt block
    if i == 0:
        cmu_block = conv_block(3, 8).to(DEVICE)
        cmu_input = torch.randn(1, 3, res, res).to(DEVICE)
    else:
        in_ch_cmu = cmu_channels[i]
        out_ch_cmu = cmu_channels[i+1]
        cmu_block = CMUNeXtBlock(in_ch_cmu, out_ch_cmu, depth=cmu_depths[i-1], k=cmu_kernels[i-1]).to(DEVICE)
        cmu_input = torch.randn(1, in_ch_cmu, res, res).to(DEVICE)
    cmu_mean, cmu_std = measure_latency(cmu_block, cmu_input)
    
    ratio = cmu_mean / jwn_mean if jwn_mean > 0 else 0
    print(f"Enc{i+1:<5} {res}x{res:<8} {jwn_mean:.3f}±{jwn_std:.3f}   {cmu_mean:.3f}±{cmu_std:.3f}   {ratio:.1f}x")
    
    results[i]['JWN_latency'] = jwn_mean
    results[i]['CMU_latency'] = cmu_mean
    results[i]['Latency_ratio'] = ratio
    
    del jwn_block, cmu_block
    torch.cuda.empty_cache()

Latency per Encoder Block (256x256 input, GPU)
Stage      Resolution   JWN (ms)        CMU (ms)        Ratio
----------------------------------------------------------------------
Enc1     256x256      0.120±0.028   0.065±0.027   0.5x
Enc2     128x128      0.070±0.003   0.108±0.018   1.6x
Enc3     64x64       0.067±0.024   0.092±0.026   1.4x
Enc4     32x32       0.057±0.019   0.088±0.004   1.5x
Enc5     16x16       0.058±0.024   0.205±0.035   3.6x
Enc6     8x8        0.058±0.019   0.088±0.004   1.5x


## 6. Why the Difference?

In [7]:
print("="*70)
print("Root Cause Analysis: Why CMUNeXt Encoder is Slower")
print("="*70)

# Example: Compare at 64x64, ch=24
print("\n[Example: 64x64 resolution, ~24 channels]\n")

# JeongWonNet: 18 → 24
print("JeongWonNet (18→24):")
print("  1x1 Conv:    18 × 24 × 1 × 1         = 432 params")
print("  3x3 DWConv:  24 × 1 × 3 × 3          = 216 params")
print("  PRCM:        2 × 24 + 2 × 24         = 96 params")
print("  GroupNorm:   24 × 2                  = 48 params")
print("  ─────────────────────────────────────────────")
print("  Total:                               ≈ 792 params")

print("\nCMUNeXt (24→32, depth=1, k=7):")
print("  [Residual DWConv] 24 × 7 × 7         = 1,176 params")
print("  [1x1 Expansion]   24 × 96            = 2,304 params  ← 4x expansion!")
print("  [1x1 Reduction]   96 × 24            = 2,304 params")
print("  [conv_block 3x3]  24 × 32 × 3 × 3    = 6,912 params  ← Standard conv!")
print("  [BatchNorms]      ...                ≈ 400 params")
print("  ─────────────────────────────────────────────")
print("  Total:                               ≈ 13,096 params")

print("\n" + "="*70)
print("Key Differences:")
print("="*70)
print("""
1. [4x Channel Expansion]
   - CMUNeXt: ch → ch*4 → ch (MLP-like bottleneck)
   - JeongWonNet: No expansion, direct channel mapping

2. [Standard 3x3 Conv in conv_block]
   - CMUNeXt conv_block: in_ch × out_ch × 3 × 3 params
   - JeongWonNet DWConv: out_ch × 3 × 3 params (groups=out_ch)
   - Ratio: in_ch times more parameters!

3. [Larger Kernel Size at Deep Layers]
   - CMUNeXt uses 7x7 DWConv at enc3, enc4, enc5
   - JeongWonNet uses 3x3 DWConv everywhere
   - 7x7 = 49 vs 3x3 = 9 (5.4x more ops per pixel)

4. [depth Parameter]
   - CMUNeXt enc4 has depth=3 (block repeated 3 times)
   - JeongWonNet: single pass through each block
""")

Root Cause Analysis: Why CMUNeXt Encoder is Slower

[Example: 64x64 resolution, ~24 channels]

JeongWonNet (18→24):
  1x1 Conv:    18 × 24 × 1 × 1         = 432 params
  3x3 DWConv:  24 × 1 × 3 × 3          = 216 params
  PRCM:        2 × 24 + 2 × 24         = 96 params
  GroupNorm:   24 × 2                  = 48 params
  ─────────────────────────────────────────────
  Total:                               ≈ 792 params

CMUNeXt (24→32, depth=1, k=7):
  [Residual DWConv] 24 × 7 × 7         = 1,176 params
  [1x1 Expansion]   24 × 96            = 2,304 params  ← 4x expansion!
  [1x1 Reduction]   96 × 24            = 2,304 params
  [conv_block 3x3]  24 × 32 × 3 × 3    = 6,912 params  ← Standard conv!
  [BatchNorms]      ...                ≈ 400 params
  ─────────────────────────────────────────────
  Total:                               ≈ 13,096 params

Key Differences:

1. [4x Channel Expansion]
   - CMUNeXt: ch → ch*4 → ch (MLP-like bottleneck)
   - JeongWonNet: No expansion, direct chann

In [8]:
# Summary DataFrame
df = pd.DataFrame(results)
print("\n" + "="*70)
print("Summary Table")
print("="*70)
print(df.to_string(index=False))


Summary Table
Stage JWN_channels  JWN_params CMU_channels  CMU_params  Param_ratio  JWN_flops  CMU_flops  FLOPs_ratio  JWN_latency  CMU_latency  Latency_ratio
 Enc1          3→6         108   3→8 (stem)         240     2.222222  4718604.0 16252928.0     3.444436     0.119967     0.065277       0.544130
 Enc2         6→12         252    8→8 (d=1)        1328     5.269841  2949144.0 22675456.0     7.688826     0.069971     0.108497       1.550609
 Enc3        12→18         486   8→16 (d=1)        1928     3.967078  1548324.0  8159232.0     5.269719     0.067416     0.092177       1.367279
 Enc4        18→24         792  16→24 (d=1)        6648     8.393939   663600.0  6930432.0    10.443689     0.056683     0.087713       1.547437
 Enc5        24→32        1248  24→32 (d=3)       25656    20.557692   270400.0  6686720.0    24.728994     0.057781     0.205492       3.556374
 Enc6        32→48        2256  32→48 (d=1)       24304    10.773050   126048.0  1570816.0    12.462046     0.05842