# PRCM Expansion Ablation Study
## Comparison of Expansion Strategies After GAP

Comparing:

### Baseline (Traditional Conv Block)
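0. **InvertedBottleneck** - Direct Transform (MobileNetV2-style inverted bottleneck, with the depthwise conv placed first as in ConvNeXt)
   - Structure: `7x7 DW Conv -> BN -> ReLU -> 1x1 Expand -> BN -> ReLU -> 1x1 Project -> BN`
   - Notes: conventional convolution block that processes spatial information directly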

### PRCM Variants (Channel Recalibration)
1. **PRCM (baseline)** - Modulator
   - Structure: `GAP -> basis projection -> fuser -> sigmoid -> x * w`
   - Notes: channel recalibration using a low-rank basis

2. **PRCM_SE** - Modulator
   - Structure: `GAP -> expand(C*r) -> ReLU -> shrink(C) -> sigmoid -> x * w`
   - Notes: SE-style expansion; since HW=1 after GAP, a larger expansion adds little compute

3. **PRCM_AdaptivePool** - **NOT a Modulator** (Direct Transform)
   - Structure: `AdaptivePool(NxN) -> 1x1 Conv expand -> ReLU -> 1x1 Conv shrink -> Upsample -> out`
   - Notes:
     - Preserves spatial information (pool sizes 1x1, 2x2, 4x4, 8x8, 16x16, 32x32, etc.)
     - Uses 1x1 Conv → **parameter count is the same regardless of pool_size**
     - Direct feature transformation (NOT a modulator)

4. **PRCM_BasisExpand** - Modulator
   - Structure: `GAP -> basis(C->B) -> expand(B*r) -> ReLU -> shrink(C) -> sigmoid -> x * w`
   - Notes: compresses with the basis first, then expands; parameter-efficient

5. **PRCM_SE_Affine** - Modulator (scale + shift)
   - Structure: `GAP -> expand -> ReLU -> [scale_proj, shift_proj] -> x * alpha + beta`
   - Notes: SE + affine transform (applies both scale and shift)
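
The claim that expansion is cheap once HW=1 can be checked with a rough multiply-accumulate (MAC) count. The sketch below is a dependency-free estimate, not a profiler measurement; the helper names `mlp_macs` and `conv1x1_pair_macs` are hypothetical and only count the two linear layers, ignoring pooling, activations, and the final `x * w`.

```python
def mlp_macs(channels: int, expansion: int) -> int:
    """MACs for expand (C -> C*r) plus shrink (C*r -> C) applied to one pooled vector."""
    hidden = channels * expansion
    return channels * hidden + hidden * channels


def conv1x1_pair_macs(channels: int, expansion: int, height: int, width: int) -> int:
    """The same two layers applied as 1x1 convs at every spatial position."""
    return mlp_macs(channels, expansion) * height * width


C, H, W = 64, 64, 64
for r in (1, 4, 16):
    pooled = mlp_macs(C, r)
    dense = conv1x1_pair_macs(C, r, H, W)
    # The ratio is always H * W, because pooling removes the spatial factor.
    print(f"r={r:<2}  post-GAP: {pooled:>9,} MACs   full-res: {dense:>13,} MACs   ratio: {dense // pooled}x")
```

Under these assumptions the post-GAP MLP costs `H * W` times less than the same layers at full resolution, which is why the modulator variants can afford large expansion ratios.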

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
from collections import OrderedDict

## 1. PRCM Variants

In [2]:
# ============================================
# PRCM (Baseline)
# ============================================

class PRCM(nn.Module):
    """Pattern Recalibration Module (Baseline)
    
    Structure: GAP -> basis projection -> fuser -> sigmoid
    """
    def __init__(self, channels, num_basis=8, dropout_rate=0.5):
        super().__init__()
        self.basis = nn.Parameter(torch.randn(num_basis, channels))
        self.fuser = nn.Linear(num_basis, channels, bias=False)
        self.coeff_dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
    
    def forward(self, x):
        ctx = x.mean(dim=[2, 3])  # [B, C]
        coeff = ctx @ self.basis.t()  # [B, num_basis]
        coeff = self.coeff_dropout(coeff)
        w = self.fuser(coeff).sigmoid().unsqueeze(-1).unsqueeze(-1)
        return x * w

In [3]:
# ============================================
# PRCM_SE (SE-style Expansion)
# ============================================

class PRCM_SE(nn.Module):
    """PRCM with SE-style expansion
    
    Structure: GAP -> expand(C*r) -> ReLU -> shrink(C) -> sigmoid
    
    Because HW=1 after GAP, a large expansion ratio adds little compute.
    """
    def __init__(self, channels, expansion=4, dropout_rate=0.5):
        super().__init__()
        hidden = channels * expansion
        
        self.expand = nn.Linear(channels, hidden, bias=False)
        self.act = nn.ReLU(inplace=True)
        self.shrink = nn.Linear(hidden, channels, bias=False)
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
    
    def forward(self, x):
        ctx = x.mean(dim=[2, 3])  # [B, C]
        ctx = self.expand(ctx)    # [B, C*r]
        ctx = self.act(ctx)
        ctx = self.dropout(ctx)
        w = self.shrink(ctx).sigmoid().unsqueeze(-1).unsqueeze(-1)  # [B, C, 1, 1]
        return x * w

In [4]:
# ============================================
# PRCM_AdaptivePool (Spatial Info Preserved - Direct Transform)
# ============================================

class PRCM_AdaptivePool(nn.Module):
    """PRCM with Adaptive Pooling (preserve spatial info)
    
    Structure: AdaptivePool(pool_size) -> 1x1 Conv expand -> ReLU -> 1x1 Conv shrink -> upsample
    
    Uses fixed-size adaptive pooling instead of GAP to retain part of the spatial information.
    Uses 1x1 convs (instead of flattening).
    Not a modulator: transforms the features directly.
    """
    def __init__(self, channels, pool_size=2, expansion=2, dropout_rate=0.5):
        super().__init__()
        self.pool_size = pool_size
        hidden = channels * expansion
        
        # 1x1 convs (instead of flattening)
        self.conv1 = nn.Conv2d(channels, hidden, 1, bias=False)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(hidden, channels, 1, bias=False)
        self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
    
    def forward(self, x):
        B, C, H, W = x.shape
        
        # Adaptive pool to fixed size
        if H >= self.pool_size and W >= self.pool_size:
            out = F.adaptive_avg_pool2d(x, (self.pool_size, self.pool_size))  # [B, C, ps, ps]
        else:
            # If the input is smaller than pool_size, use it as-is
            out = x
        
        out = self.conv1(out)   # [B, hidden, ps, ps]
        out = self.act(out)
        out = self.dropout(out)
        out = self.conv2(out)   # [B, C, ps, ps]
        
        # Upsample back to original size
        out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=True)
        
        # NOT a modulator - direct feature transformation
        return out
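
To make the pooling step in `PRCM_AdaptivePool` more concrete, here is a one-dimensional, pure-Python sketch of how adaptive average pooling can split an input into a fixed number of bins. It assumes the floor/ceil bin rule that PyTorch's adaptive pooling is generally understood to use; `adaptive_bins` and `adaptive_avg_pool_1d` are hypothetical helpers written for illustration, not part of the notebook.

```python
def adaptive_bins(in_size: int, out_size: int) -> list:
    """Half-open [start, end) input ranges averaged into each output cell.

    Assumed rule: start = floor(i * in / out), end = ceil((i + 1) * in / out).
    Integer arithmetic is used to avoid floating-point rounding.
    """
    bins = []
    for i in range(out_size):
        start = (i * in_size) // out_size
        end = ((i + 1) * in_size + out_size - 1) // out_size
        bins.append((start, end))
    return bins


def adaptive_avg_pool_1d(values, out_size):
    """Average each bin of a 1D sequence."""
    return [sum(values[s:e]) / (e - s) for s, e in adaptive_bins(len(values), out_size)]


print(adaptive_bins(64, 4))   # evenly divisible: four bins of 16 elements
print(adaptive_bins(7, 3))    # not divisible: neighbouring bins overlap
print(adaptive_avg_pool_1d([1, 2, 3, 4, 5, 6, 7, 8], 2))
```

The 2D case applies the same rule independently along height and width, so a 64x64 map pooled to 4x4 averages sixteen non-overlapping 16x16 tiles.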

In [5]:
# ============================================
# PRCM_BasisExpand (Basis + Expansion)
# ============================================

class PRCM_BasisExpand(nn.Module):
    """PRCM with Basis Projection + Expansion
    
    Structure: GAP -> basis(C->B) -> expand(B->B*r) -> ReLU -> shrink(B*r->C) -> sigmoid
    
    Compresses with the basis first, then expands.
    """
    def __init__(self, channels, num_basis=8, expansion=4, dropout_rate=0.5):
        super().__init__()
        self.basis = nn.Parameter(torch.randn(num_basis, channels))
        hidden = num_basis * expansion
        
        self.expand = nn.Linear(num_basis, hidden, bias=False)
        self.act = nn.ReLU(inplace=True)
        self.shrink = nn.Linear(hidden, channels, bias=False)
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
    
    def forward(self, x):
        ctx = x.mean(dim=[2, 3])  # [B, C]
        coeff = ctx @ self.basis.t()  # [B, num_basis]
        coeff = self.expand(coeff)  # [B, num_basis * expansion]
        coeff = self.act(coeff)
        coeff = self.dropout(coeff)
        w = self.shrink(coeff).sigmoid().unsqueeze(-1).unsqueeze(-1)
        return x * w

In [None]:
# ============================================
# PRCM_SE_Affine (SE + Shift)
# ============================================

class PRCM_SE_Affine(nn.Module):
    """PRCM SE-style with Affine (scale + shift)
    
    Structure: GAP -> expand -> ReLU -> [scale_proj, shift_proj] -> affine
    """
    def __init__(self, channels, expansion=4, dropout_rate=0.5):
        super().__init__()
        hidden = channels * expansion
        
        self.expand = nn.Linear(channels, hidden, bias=False)
        self.act = nn.ReLU(inplace=True)
        self.scale_proj = nn.Linear(hidden, channels, bias=False)
        self.shift_proj = nn.Linear(hidden, channels, bias=False)
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
    
    def forward(self, x):
        ctx = x.mean(dim=[2, 3])
        ctx = self.expand(ctx)
        ctx = self.act(ctx)
        ctx = self.dropout(ctx)
        
        alpha = self.scale_proj(ctx).sigmoid().unsqueeze(-1).unsqueeze(-1)
        beta = self.shift_proj(ctx).unsqueeze(-1).unsqueeze(-1)
        return x * alpha + beta


# ============================================
# InvertedBottleneck (Baseline Block)
# ============================================

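class InvertedBottleneck(nn.Module):
    """Inverted Bottleneck Block (baseline)
    
    Structure: 7x7 DW Conv -> BN -> ReLU -> 1x1 Expand -> BN -> ReLU -> 1x1 Project -> BN
    
    Baseline block for comparison with the PRCM variants.
    A conventional convolution design that processes spatial information directly.
    Note: the depthwise conv comes before the expansion (as in ConvNeXt), whereas
    MobileNetV2 orders its block expand -> depthwise -> project.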
    """
    def __init__(self, channels, kernel_size=7, expansion=2, dropout_rate=0.5):
        super().__init__()
        hidden = channels * expansion
        
        # 7x7 Depthwise Conv
        self.dw_conv = nn.Conv2d(
            channels, channels, kernel_size, 
            padding=kernel_size // 2, groups=channels, bias=False
        )
        self.bn1 = nn.BatchNorm2d(channels)
        
        # 1x1 Expand
        self.expand = nn.Conv2d(channels, hidden, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        self.act = nn.ReLU(inplace=True)
        
        # 1x1 Project
        self.project = nn.Conv2d(hidden, channels, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(channels)
        
        self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
    
    def forward(self, x):
        # 7x7 DW
        out = self.dw_conv(x)
        out = self.bn1(out)
        out = self.act(out)
        
        # 1x1 Expand
        out = self.expand(out)
        out = self.bn2(out)
        out = self.act(out)
        out = self.dropout(out)
        
        # 1x1 Project
        out = self.project(out)
        out = self.bn3(out)
        
        return out


class InvertedBottleneckRes(nn.Module):
    """Inverted Bottleneck with Residual Connection
    
    Structure: 7x7 DW Conv -> 1x1 Expand -> ReLU -> 1x1 Project + Residual
    """
    def __init__(self, channels, kernel_size=7, expansion=2, dropout_rate=0.5):
        super().__init__()
        hidden = channels * expansion
        
        # 7x7 Depthwise Conv
        self.dw_conv = nn.Conv2d(
            channels, channels, kernel_size, 
            padding=kernel_size // 2, groups=channels, bias=False
        )
        self.bn1 = nn.BatchNorm2d(channels)
        
        # 1x1 Expand
        self.expand = nn.Conv2d(channels, hidden, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        self.act = nn.ReLU(inplace=True)
        
        # 1x1 Project
        self.project = nn.Conv2d(hidden, channels, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(channels)
        
        self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
    
    def forward(self, x):
        identity = x
        
        # 7x7 DW
        out = self.dw_conv(x)
        out = self.bn1(out)
        out = self.act(out)
        
        # 1x1 Expand
        out = self.expand(out)
        out = self.bn2(out)
        out = self.act(out)
        out = self.dropout(out)
        
        # 1x1 Project + Residual
        out = self.project(out)
        out = self.bn3(out)
        out = out + identity
        
        return out

## 2. Measurement Functions

In [7]:
def count_params(model):
    return sum(p.numel() for p in model.parameters())

def measure_latency_gpu(model, input_tensor, warmup=100, repeat=500):
    if not torch.cuda.is_available():
        return None, None
    
    model.eval().cuda()
    input_tensor = input_tensor.cuda()
    
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(input_tensor)
            torch.cuda.synchronize()
    
    times = []
    with torch.no_grad():
        for _ in range(repeat):
            torch.cuda.synchronize()
            start = time.perf_counter()
            _ = model(input_tensor)
            torch.cuda.synchronize()
            times.append((time.perf_counter() - start) * 1000)
    
    return np.mean(times), np.std(times)

def measure_latency_cpu(model, input_tensor, warmup=20, repeat=100):
    model.eval().cpu()
    input_tensor = input_tensor.cpu()
    
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(input_tensor)
    
    times = []
    with torch.no_grad():
        for _ in range(repeat):
            start = time.perf_counter()
            _ = model(input_tensor)
            times.append((time.perf_counter() - start) * 1000)
    
    return np.mean(times), np.std(times)
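
The helpers above report mean and standard deviation, and a mean is sensitive to a few slow iterations. As a complementary sketch, the same warmup-then-repeat pattern can be written for any callable with only the standard library, adding the median. `time_callable` is a hypothetical name; for GPU work the callable itself would need to synchronize, as `measure_latency_gpu` does.

```python
import statistics
import time


def time_callable(fn, warmup=20, repeat=100):
    """Return (mean_ms, std_ms, median_ms) over `repeat` timed calls of fn()."""
    for _ in range(warmup):          # untimed warmup calls
        fn()
    samples = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000)
    # pstdev is the population standard deviation, matching np.std's default.
    return statistics.mean(samples), statistics.pstdev(samples), statistics.median(samples)


mean_ms, std_ms, median_ms = time_callable(lambda: sum(range(10_000)))
print(f"mean={mean_ms:.4f} ms  std={std_ms:.4f} ms  median={median_ms:.4f} ms")
```

When the mean and median disagree noticeably, the run probably contains outliers and is worth repeating.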

## 3. Architecture Comparison

In [None]:
print("=" * 80)
print("PRCM Variants Architecture")
print("=" * 80)

print("""
┌─────────────────────────────────────────────────────────────────────────────┐
│ InvertedBottleneck (Baseline) - Direct Transform                            │
├─────────────────────────────────────────────────────────────────────────────┤
│   x [B,C,H,W] -> 7x7 DW Conv -> BN -> ReLU                                  │
│   -> 1x1 Expand(C*r) -> BN -> ReLU                                          │
│   -> 1x1 Project(C) -> BN -> out                                            │
│                                                                             │
│   Params: C*7*7 + C*C*r + C*r*C + BNs = C*49 + 2*C^2*r + BNs                │
│   (processes spatial information directly; conventional convolution)        │
└─────────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────────┐
│ PRCM (Baseline) - Modulator                                                 │
├─────────────────────────────────────────────────────────────────────────────┤
│   x [B,C,H,W] -> GAP -> ctx [B,C]                                           │
│   ctx @ basis.T -> coeff [B, num_basis]                                     │
│   fuser(coeff) -> w [B, C] -> sigmoid -> x * w                              │
│                                                                             │
│   Params: num_basis * C + num_basis * C = 2 * num_basis * C                 │
└─────────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────────┐
│ PRCM_SE (SE-style Expansion) - Modulator                                    │
├─────────────────────────────────────────────────────────────────────────────┤
│   x [B,C,H,W] -> GAP -> ctx [B,C]                                           │
│   expand(C -> C*r) -> ReLU -> shrink(C*r -> C) -> sigmoid -> x * w          │
│                                                                             │
│   Params: C * C*r + C*r * C = 2 * C^2 * r                                   │
│   (HW=1 after GAP, so a large expansion adds little compute)                │
└─────────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────────┐
│ PRCM_AdaptivePool (Spatial Info) - **NOT a Modulator**                      │
├─────────────────────────────────────────────────────────────────────────────┤
│   x [B,C,H,W] -> AdaptivePool(ps) -> [B,C,ps,ps]                            │
│   1x1 Conv expand -> ReLU -> 1x1 Conv shrink -> Upsample -> out             │
│                                                                             │
│   **NOT x * w, direct feature transformation**                              │
│                                                                             │
│   Params: C * hidden + hidden * C = 2 * C * C*exp                           │
│   (keeps spatial information; efficient thanks to 1x1 convs)                │
└─────────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────────┐
│ PRCM_BasisExpand (Basis + Expansion) - Modulator                            │
├─────────────────────────────────────────────────────────────────────────────┤
│   x [B,C,H,W] -> GAP -> ctx [B,C]                                           │
│   ctx @ basis.T -> coeff [B, num_basis]                                     │
│   expand(B -> B*r) -> ReLU -> shrink(B*r -> C) -> sigmoid -> x * w          │
│                                                                             │
│   Params: num_basis * C + num_basis * B*r + B*r * C                         │
│   (compress with the basis, then expand: mid-sized)                         │
└─────────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────────┐
│ PRCM_SE_Affine (SE + Shift) - Modulator (scale + shift)                     │
├─────────────────────────────────────────────────────────────────────────────┤
│   Same as PRCM_SE but with scale + shift output                             │
│   -> scale_proj(C), shift_proj(C) -> x * alpha + beta                       │
│                                                                             │
│   Params: C * C*r + C*r * C * 2 (scale + shift)                             │
└─────────────────────────────────────────────────────────────────────────────┘
""")

## 4. Parameter Comparison

In [None]:
channels = 64
num_basis = 8

print("=" * 80)
print(f"Parameter Comparison (channels={channels})")
print("=" * 80)

variants = OrderedDict([
    # Baseline Conv Block
    ("InvertedBottleneck (7x7, exp=2)", InvertedBottleneck(channels, kernel_size=7, expansion=2)),
    ("InvertedBottleneck (7x7, exp=4)", InvertedBottleneck(channels, kernel_size=7, expansion=4)),
    ("InvertedBottleneckRes (7x7, exp=2)", InvertedBottleneckRes(channels, kernel_size=7, expansion=2)),
    # PRCM variants
    ("PRCM (baseline)", PRCM(channels, num_basis=8)),
    ("PRCM_SE (exp=2)", PRCM_SE(channels, expansion=2)),
    ("PRCM_SE (exp=4)", PRCM_SE(channels, expansion=4)),
    ("PRCM_AdaptivePool (2x2, exp=2)", PRCM_AdaptivePool(channels, pool_size=2, expansion=2)),
    ("PRCM_AdaptivePool (4x4, exp=2)", PRCM_AdaptivePool(channels, pool_size=4, expansion=2)),
    ("PRCM_AdaptivePool (8x8, exp=2)", PRCM_AdaptivePool(channels, pool_size=8, expansion=2)),
    ("PRCM_BasisExpand (B=8, exp=4)", PRCM_BasisExpand(channels, num_basis=8, expansion=4)),
    ("PRCM_SE_Affine (exp=4)", PRCM_SE_Affine(channels, expansion=4)),
])

baseline_params = count_params(variants["PRCM (baseline)"])
invbottleneck_params = count_params(variants["InvertedBottleneck (7x7, exp=2)"])

print(f"\n{'Variant':<40} {'Params':>12} {'vs PRCM':>12} {'vs InvBN':>12}")
print("-" * 80)

for name, module in variants.items():
    params = count_params(module)
    ratio_prcm = params / baseline_params
    ratio_inv = params / invbottleneck_params
    print(f"{name:<40} {params:>12,} {ratio_prcm:>11.1f}x {ratio_inv:>11.2f}x")

print("\n[Note]")
print("- InvertedBottleneck: 7x7 DW + 1x1 expand + 1x1 project + BNs")
print("- PRCM variants: operate on the pooled vector after GAP; the basis variants have the fewest parameters")
print("- PRCM_AdaptivePool: uses 1x1 Conv, so the parameter count is the same for every pool_size")
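
The parameter formulas listed in the architecture comparison can be evaluated directly as a cross-check on `count_params`. The sketch below is plain arithmetic under the assumption that every layer is bias-free and each `BatchNorm2d` contributes a weight and a bias per channel, as in the class definitions above; the function names are hypothetical.

```python
def prcm_params(C, B):
    return B * C + B * C                        # basis + fuser

def prcm_se_params(C, r):
    return C * (C * r) + (C * r) * C            # expand + shrink

def adaptive_pool_params(C, r):
    return C * (C * r) + (C * r) * C            # two 1x1 convs; no pool_size term

def basis_expand_params(C, B, r):
    return B * C + B * (B * r) + (B * r) * C    # basis + expand + shrink

def se_affine_params(C, r):
    return C * (C * r) + 2 * (C * r) * C        # expand + scale_proj + shift_proj

def inverted_bottleneck_params(C, r, k=7):
    convs = C * k * k + C * (C * r) + (C * r) * C
    bns = 2 * C + 2 * (C * r) + 2 * C           # weight + bias for each of the three BNs
    return convs + bns

C, B = 64, 8
rows = [
    ("InvertedBottleneck (7x7, exp=2)", inverted_bottleneck_params(C, 2)),
    ("PRCM (B=8)", prcm_params(C, B)),
    ("PRCM_SE (exp=4)", prcm_se_params(C, 4)),
    ("PRCM_AdaptivePool (exp=2)", adaptive_pool_params(C, 2)),
    ("PRCM_BasisExpand (B=8, exp=4)", basis_expand_params(C, B, 4)),
    ("PRCM_SE_Affine (exp=4)", se_affine_params(C, 4)),
]
for name, n in rows:
    print(f"{name:<34} {n:>10,}")
```

At `C=64` the two basis-based variants come out far smaller than the SE-style ones, while `PRCM_SE` with `exp=4` is larger than the `exp=2` convolutional baseline.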

## 5. GPU Latency Comparison

In [None]:
channels = 64
resolution = 64

x = torch.randn(1, channels, resolution, resolution)

print("=" * 85)
print(f"GPU Latency Comparison (input: 1x{channels}x{resolution}x{resolution})")
print("=" * 85)

variants = OrderedDict([
    # Baseline Conv Block
    ("InvertedBottleneck (7x7, exp=2)", InvertedBottleneck(channels, kernel_size=7, expansion=2, dropout_rate=0)),
    ("InvertedBottleneck (7x7, exp=4)", InvertedBottleneck(channels, kernel_size=7, expansion=4, dropout_rate=0)),
    ("InvertedBottleneckRes (7x7, exp=2)", InvertedBottleneckRes(channels, kernel_size=7, expansion=2, dropout_rate=0)),
    # PRCM variants
    ("PRCM (baseline)", PRCM(channels, num_basis=8, dropout_rate=0)),
    ("PRCM_SE (exp=4)", PRCM_SE(channels, expansion=4, dropout_rate=0)),
    ("PRCM_AdaptivePool (2x2)", PRCM_AdaptivePool(channels, pool_size=2, expansion=2, dropout_rate=0)),
    ("PRCM_AdaptivePool (4x4)", PRCM_AdaptivePool(channels, pool_size=4, expansion=2, dropout_rate=0)),
    ("PRCM_AdaptivePool (8x8)", PRCM_AdaptivePool(channels, pool_size=8, expansion=2, dropout_rate=0)),
    ("PRCM_BasisExpand (exp=4)", PRCM_BasisExpand(channels, num_basis=8, expansion=4, dropout_rate=0)),
    ("PRCM_SE_Affine (exp=4)", PRCM_SE_Affine(channels, expansion=4, dropout_rate=0)),
])

print(f"\n{'Variant':<40} {'Mean (ms)':>12} {'Std (ms)':>12} {'Params':>12}")
print("-" * 80)

for name, module in variants.items():
    mean, std = measure_latency_gpu(module, x)
    params = count_params(module)
    if mean is not None:
        print(f"{name:<40} {mean:>12.4f} {std:>12.4f} {params:>12,}")

## 6. CPU Latency Comparison

In [None]:
print("=" * 85)
print(f"CPU Latency Comparison (input: 1x{channels}x{resolution}x{resolution})")
print("=" * 85)

print(f"\n{'Variant':<40} {'Mean (ms)':>12} {'Std (ms)':>12}")
print("-" * 67)

for name, module in variants.items():
    mean, std = measure_latency_cpu(module, x)
    print(f"{name:<40} {mean:>12.4f} {std:>12.4f}")

## 7. Multi-Channel Comparison

In [None]:
channel_list = [24, 48, 64, 96, 128, 192]
resolution = 64

print("=" * 130)
print("GPU Latency vs Channel Size")
print("=" * 130)

results = {name: [] for name in ["InvBN(2)", "InvBN(4)", "PRCM", "PRCM_SE(4)", "AdaPool(4x4)", "BasisExp"]}

for ch in channel_list:
    x = torch.randn(1, ch, resolution, resolution)
    
    m = InvertedBottleneck(ch, kernel_size=7, expansion=2, dropout_rate=0)
    mean, _ = measure_latency_gpu(m, x)
    results["InvBN(2)"].append(mean)
    
    m = InvertedBottleneck(ch, kernel_size=7, expansion=4, dropout_rate=0)
    mean, _ = measure_latency_gpu(m, x)
    results["InvBN(4)"].append(mean)
    
    m = PRCM(ch, num_basis=8, dropout_rate=0)
    mean, _ = measure_latency_gpu(m, x)
    results["PRCM"].append(mean)
    
    m = PRCM_SE(ch, expansion=4, dropout_rate=0)
    mean, _ = measure_latency_gpu(m, x)
    results["PRCM_SE(4)"].append(mean)
    
    m = PRCM_AdaptivePool(ch, pool_size=4, expansion=2, dropout_rate=0)
    mean, _ = measure_latency_gpu(m, x)
    results["AdaPool(4x4)"].append(mean)
    
    m = PRCM_BasisExpand(ch, num_basis=8, expansion=4, dropout_rate=0)
    mean, _ = measure_latency_gpu(m, x)
    results["BasisExp"].append(mean)

print(f"\n{'Ch':<8}", end="")
for name in results.keys():
    print(f"{name:<20}", end="")
print()
print("-" * 130)

for i, ch in enumerate(channel_list):
    print(f"{ch:<8}", end="")
    for name in results.keys():
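        value = results[name][i]
        # Print a placeholder when no GPU measurement exists so columns stay aligned
        cell = f"{value:.4f} ms" if value is not None else "n/a"
        print(f"{cell:<20}", end="")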
    print()

print("\n[Note] InvBN = InvertedBottleneck (7x7 DW + 1x1 expand + 1x1 project)")

## 8. Expansion Ratio Ablation

In [13]:
channels = 64
resolution = 64
expansion_ratios = [1, 2, 4, 8, 16]

print("=" * 80)
print(f"PRCM_SE Expansion Ratio Ablation (channels={channels})")
print("=" * 80)

x = torch.randn(1, channels, resolution, resolution)

print(f"\n{'Expansion':<12} {'Params':>12} {'GPU (ms)':>12} {'CPU (ms)':>12} {'Hidden Dim':>12}")
print("-" * 65)

for exp in expansion_ratios:
    m = PRCM_SE(channels, expansion=exp, dropout_rate=0)
    params = count_params(m)
    gpu_mean, _ = measure_latency_gpu(m, x)
    cpu_mean, _ = measure_latency_cpu(m, x)
    hidden = channels * exp
    
    gpu_cell = f"{gpu_mean:.4f}" if gpu_mean is not None else "n/a"  # None when CUDA is unavailable
    print(f"{exp:<12} {params:>12,} {gpu_cell:>12} {cpu_mean:>12.4f} {hidden:>12}")

print("\n[Insight] With HW=1 after GAP, a larger expansion ratio adds almost no latency (GPU is flat; CPU timings are noisy)")

PRCM_SE Expansion Ratio Ablation (channels=64)

Expansion          Params     GPU (ms)     CPU (ms)   Hidden Dim
-----------------------------------------------------------------
1                   8,192       0.0288       0.0284           64
2                  16,384       0.0321       0.7587          128
4                  32,768       0.0303       0.5601          256
8                  65,536       0.0294       0.4122          512
16                131,072       0.0318       0.1196         1024

[Insight] With HW=1 after GAP, a larger expansion ratio adds almost no latency (GPU is flat; CPU timings are noisy)
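
One way to read the recorded table is to compare how much each column varies across expansion ratios. The snippet below only re-uses the numbers printed above (copied by hand into dictionaries) and computes the slowest-to-fastest ratio; `spread` is a hypothetical helper, and a single run like this is not a statistical test.

```python
# Values copied from the recorded table above (keys are expansion ratios).
params = {1: 8_192, 2: 16_384, 4: 32_768, 8: 65_536, 16: 131_072}
gpu_ms = {1: 0.0288, 2: 0.0321, 4: 0.0303, 8: 0.0294, 16: 0.0318}
cpu_ms = {1: 0.0284, 2: 0.7587, 4: 0.5601, 8: 0.4122, 16: 0.1196}


def spread(values):
    """Ratio of the slowest to the fastest measurement."""
    values = list(values)
    return max(values) / min(values)


print(f"params grow {params[16] // params[1]}x from r=1 to r=16")
print(f"GPU slowest/fastest: {spread(gpu_ms.values()):.2f}x")
print(f"CPU slowest/fastest: {spread(cpu_ms.values()):.2f}x")
```

The GPU column stays within a narrow band while parameters grow 16x, which supports the insight. The CPU column varies far more and does not increase monotonically with the expansion ratio, so it likely reflects measurement noise rather than a real trend.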


## 9. AdaptivePool Size Ablation

In [14]:
channels = 64
resolution = 64
pool_sizes = [1, 2, 4, 8, 16, 32]

print("=" * 80)
print(f"PRCM_AdaptivePool Size Ablation (channels={channels}, expansion=2)")
print("=" * 80)
print("(Parameter count is independent of pool_size because the convs are 1x1)")

x = torch.randn(1, channels, resolution, resolution)

print(f"\n{'Pool Size':<12} {'Params':>12} {'GPU (ms)':>12} {'CPU (ms)':>12} {'Spatial Info':>14}")
print("-" * 70)

for ps in pool_sizes:
    m = PRCM_AdaptivePool(channels, pool_size=ps, expansion=2, dropout_rate=0)
    params = count_params(m)
    gpu_mean, _ = measure_latency_gpu(m, x)
    cpu_mean, _ = measure_latency_cpu(m, x)
    spatial_info = f"{ps}x{ps} = {ps*ps}"
    
    label = f"{ps}x{ps}"  # pad the whole label so two-digit sizes stay aligned
    gpu_cell = f"{gpu_mean:.4f}" if gpu_mean is not None else "n/a"  # None when CUDA is unavailable
    print(f"{label:<12} {params:>12,} {gpu_cell:>12} {cpu_mean:>12.4f} {spatial_info:>14}")

print("\n[Key Insight]")
print("- Params are independent of pool_size because the convs are 1x1 (all identical)")
print("- A larger pool_size keeps more spatial information")
print("- Latency changes only slightly with pool size: a larger pool means more conv positions but a smaller upsampling factor")

PRCM_AdaptivePool Size Ablation (channels=64, expansion=2)
(Parameter count is independent of pool_size because the convs are 1x1)

Pool Size          Params     GPU (ms)     CPU (ms)   Spatial Info
----------------------------------------------------------------------
1x1                16,384       0.0430       1.1623        1x1 = 1
2x2                16,384       0.0619       0.1901        2x2 = 4
4x4                16,384       0.0446       0.1715       4x4 = 16
8x8                16,384       0.0443       0.1749       8x8 = 64
16x16              16,384       0.0446       0.2253    16x16 = 256
32x32              16,384       0.0445       0.3310   32x32 = 1024

[Key Insight]
- Params are independent of pool_size because the convs are 1x1 (all identical)
- A larger pool_size keeps more spatial information
- Latency changes only slightly with pool size: a larger pool means more conv positions but a smaller upsampling factor
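
Although the parameter count is constant, the work done by the two 1x1 convs still scales with the number of pooled positions. A rough sketch of that relationship, assuming one MAC per weight per position and ignoring the pooling and bilinear upsampling steps (`adaptive_pool_cost` is a hypothetical helper):

```python
def adaptive_pool_cost(channels, expansion, pool_size):
    """Return (params, conv_macs) for the two 1x1 convs on a pool_size x pool_size map."""
    hidden = channels * expansion
    params = channels * hidden + hidden * channels
    conv_macs = params * pool_size * pool_size   # each weight is used once per position
    return params, conv_macs


for ps in (1, 2, 4, 8, 16, 32):
    params, macs = adaptive_pool_cost(64, 2, ps)
    print(f"{ps}x{ps}: params={params:,}  conv MACs={macs:,}")
```

This is consistent with the CPU column rising from 8x8 to 32x32 in the table above, while the GPU column stays flat, plausibly because such small workloads are dominated by kernel launch overhead.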


## 10. Multi-Resolution (Small Feature Maps)

In [None]:
channels = 128
resolutions = [4, 8, 16, 32, 64]

print("=" * 130)
print(f"GPU Latency at Different Resolutions (channels={channels})")
print("=" * 130)
print("(Deep layers have small resolution; this compares the efficiency of each block there)")

results = {name: [] for name in ["InvBN(2)", "InvBN(4)", "PRCM", "PRCM_SE(4)", "AdaPool(4x4)"]}

for res in resolutions:
    x = torch.randn(1, channels, res, res)
    
    m = InvertedBottleneck(channels, kernel_size=7, expansion=2, dropout_rate=0)
    mean, _ = measure_latency_gpu(m, x)
    results["InvBN(2)"].append(mean)
    
    m = InvertedBottleneck(channels, kernel_size=7, expansion=4, dropout_rate=0)
    mean, _ = measure_latency_gpu(m, x)
    results["InvBN(4)"].append(mean)
    
    m = PRCM(channels, num_basis=8, dropout_rate=0)
    mean, _ = measure_latency_gpu(m, x)
    results["PRCM"].append(mean)
    
    m = PRCM_SE(channels, expansion=4, dropout_rate=0)
    mean, _ = measure_latency_gpu(m, x)
    results["PRCM_SE(4)"].append(mean)
    
    m = PRCM_AdaptivePool(channels, pool_size=4, expansion=2, dropout_rate=0)
    mean, _ = measure_latency_gpu(m, x)
    results["AdaPool(4x4)"].append(mean)

print(f"\n{'Res':<10}", end="")
for name in results.keys():
    print(f"{name:<22}", end="")
print()
print("-" * 120)

for i, res in enumerate(resolutions):
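    label = f"{res}x{res}"  # pad the whole label so two-digit sizes stay aligned
    print(f"{label:<10}", end="")
    for name in results.keys():
        value = results[name][i]
        cell = f"{value:.4f} ms" if value is not None else "n/a"
        print(f"{cell:<22}", end="")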
    print()

print("\n[Key Insight]")
print("- InvertedBottleneck: latency drops sharply as resolution shrinks (less 7x7 conv work)")
print("- PRCM variants: GAP-based, so less sensitive to resolution changes")
print("- In deep layers the latency gap between InvBN and PRCM narrows")
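
The expected trend can be estimated before running the benchmark by counting operations. The sketch below is a rough model, not a measurement: it counts conv MACs for the inverted bottleneck (BatchNorm and ReLU ignored) and, for `PRCM_SE`, one operation per element for the GAP sum and for the final scaling plus the MLP MACs. Both function names are hypothetical.

```python
def inverted_bottleneck_macs(C, r, H, W, k=7):
    """Conv MACs only: depthwise k x k, 1x1 expand, 1x1 project at every position."""
    per_position = C * k * k + C * (C * r) + (C * r) * C
    return per_position * H * W


def prcm_se_macs(C, r, H, W):
    """GAP sum + two-layer MLP on the pooled vector + channel-wise scaling."""
    return C * H * W + 2 * C * (C * r) + C * H * W


C = 128
for res in (4, 8, 16, 32, 64):
    inv = inverted_bottleneck_macs(C, 2, res, res)
    se = prcm_se_macs(C, 4, res, res)
    print(f"{res:>2}x{res:<2}  InvBN(2): {inv:>12,}  PRCM_SE(4): {se:>10,}  ratio: {inv / se:6.1f}x")
```

Under this model the cost ratio falls from roughly 250x at 64x64 to under 10x at 4x4, in line with the note that the two block types converge in deep layers. Measured latency at such small sizes is usually dominated by fixed overheads, so the real gap may be smaller still.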

## 11. Summary

In [None]:
print("=" * 80)
print("SUMMARY")
print("=" * 80)
print("""
┌─────────────────────────────────────────────────────────────────────────────┐
│ Key Findings                                                                │
├─────────────────────────────────────────────────────────────────────────────┤
│ 1. InvertedBottleneck vs PRCM variants                                      │
│    - InvBN: 7x7 DW conv handles spatial info directly; resolution-sensitive │
│    - PRCM: GAP-based; compute after GAP is independent of resolution        │
│    - In deep layers (small resolution) their latency gap narrows            │
│                                                                             │
│ 2. HW=1 after GAP, so a larger expansion barely increases latency           │
│    - Even PRCM_SE with expansion=16 is close to the baseline                │
│                                                                             │
│ 3. PRCM_AdaptivePool (1x1 Conv version)                                     │
│    - Uses 1x1 Conv → parameter count is the same for every pool_size        │
│    - NOT a modulator - direct feature transformation (not x * w)            │
│    - A larger pool_size keeps more spatial information                      │
└─────────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────────┐
│ Block Type Comparison                                                       │
├─────────────────────────────────────────────────────────────────────────────┤
│ InvertedBottleneck (Baseline Conv):                                         │
│   - 7x7 DW Conv → 1x1 Expand → ReLU → 1x1 Project                           │
│   - Processes spatial information directly (spatial convolution)            │
│   - Compute grows with resolution                                           │
│                                                                             │
│ PRCM variants (Channel Recalibration):                                      │
│   - Processing after GAP is independent of resolution                       │
│   - Modulator (x * w) or Direct Transform                                   │
│   - Parameter-efficient (especially the basis variants)                     │
└─────────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────────┐
│ PRCM_AdaptivePool Pool Size Selection Guide                                 │
├─────────────────────────────────────────────────────────────────────────────┤
│ Pool Size   Spatial info   Suitable input resolution                        │
│ ─────────────────────────────────────────────────────────────               │
│   1x1       none           Same pooling as GAP (constant output map)        │
│   2x2       4 pixels       Any resolution (minimal spatial info)            │
│   4x4       16 pixels      ≥ 8x8 input (moderate spatial info)              │
│   8x8       64 pixels      ≥ 16x16 input (more spatial info)                │
│                                                                             │
│ * Parameter count is identical for all (thanks to 1x1 Conv)                 │
└─────────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────────┐
│ Recommendations                                                             │
├─────────────────────────────────────────────────────────────────────────────┤
│ - Full spatial processing:   InvertedBottleneck (7x7 DW baseline)           │
│ - Shallow layers (large HW): PRCM baseline or AdaptivePool(8x8, 16x16)      │
│ - Deep layers (small HW):    PRCM_SE(exp=8~16) or AdaptivePool(2x2, 4x4)    │
│ - Parameter budget:          PRCM_BasisExpand (basis, then expansion)       │
│ - Spatial info + transform:  PRCM_AdaptivePool (1x1 Conv, NOT modulator)    │
│ - Affine transform:          PRCM_SE_Affine (scale + shift)                 │
└─────────────────────────────────────────────────────────────────────────────┘
""")