# 🚀 2강 (가벼운 설정): **MobileNet/EfficientNet 전이학습 (timm + Hugging Face Hub)**

- **목표**: 더 가벼운 백본(MobileNetV3-Small, EfficientNet-B0)으로 CPU에서도 쾌적한 전이학습 데모
- **가정**: `timm`이 Hugging Face Hub에서 가중치를 받아옵니다(`hf_hub:` 프리픽스 사용)

**실행 모드**
- 기본: **Feature Extraction** (백본 freeze) → CPU OK
- 옵션: **Fine-tuning** (일부 블록 unfreeze) → GPU 권장


## 0. 환경 설정

In [None]:
# !pip install -q timm
import os, numpy as np, random, matplotlib.pyplot as plt, torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets as tvdatasets, transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import timm

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with one value.

    Args:
        seed: Seed applied to every generator (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # The CUDA generators are only seeded when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)

## 1. 데이터 준비 (가벼운 설정 적용)

- **Subset**: 학습 데이터의 10% (CPU에서도 빠르게; 테스트셋은 전체 사용)
- **Batch size**: 16
- **Epochs**: 2 (데모용)


In [None]:
# Load CIFAR-10 (PIL images); both splits are downloaded into `root` if missing
root='./data'
train_full = tvdatasets.CIFAR10(root=root, train=True, download=True)
test_set   = tvdatasets.CIFAR10(root=root, train=False, download=True)
# Class names in label-index order, reused for plots and reports.
class_names = train_full.classes

def to_numpy_list(tv_dataset):
    """Convert an iterable of (image, label) pairs into two parallel lists.

    Args:
        tv_dataset: Iterable yielding (image, label) pairs.

    Returns:
        Tuple (images, labels): a list with each image converted through
        ``np.array`` and a list with the labels in the same order.
    """
    # Walk the dataset exactly once, then split the pairs into two lists.
    pairs = [(np.array(picture), target) for picture, target in tv_dataset]
    arrays = [array for array, _ in pairs]
    targets = [target for _, target in pairs]
    return arrays, targets

# Materialize both splits as in-memory lists of image arrays and labels.
images_train, labels_train = to_numpy_list(train_full)
images_test,  labels_test  = to_numpy_list(test_set)

# Keep a 10% subset of the training data, picked via a seeded random permutation
# (the test split is left untouched).
SUBSET=0.1
idx = np.random.RandomState(42).permutation(len(images_train))
sel = idx[:int(len(images_train)*SUBSET)]
images_train = [images_train[i] for i in sel]
labels_train = [labels_train[i] for i in sel]

# Stratified 80/20 split of the subset into training and validation parts.
tr_imgs, va_imgs, tr_lbls, va_lbls = train_test_split(
    images_train, labels_train, test_size=0.2, stratify=labels_train, random_state=42
)

print('Train/Val/Test sizes:', len(tr_imgs), len(va_imgs), len(images_test))

## 2. 변환(Transform) — 모델별 입력 규격 맞추기

In [None]:
# Use timm's default input size / preprocessing
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from PIL import Image

# MobileNetV3-Small & EfficientNet-B0 repo IDs on HF Hub via timm
mobilenet_repo = 'timm/mobilenetv3_small_100.lamb_in1k'
efficient_repo = 'timm/efficientnet_b0.ra_in1k'

# Build a temporary model only to read its data config (the real model created
# later is expected to take the same input)
# NOTE(review): the config is always resolved from the MobileNet repo, even when
# EfficientNet is selected later — confirm both checkpoints share preprocessing.
tmp_model = timm.create_model(f'hf_hub:{mobilenet_repo}', pretrained=True, num_classes=10)
config = resolve_data_config({}, model=tmp_model)  # build the transform from the model's default settings
transform = create_transform(**config)

# The last entry of `input_size` is used as the image size; (3,224,224) is the
# fallback when the config has no such key.
IMG_SIZE = config.get('input_size', (3,224,224))[-1]
print('Resolved input size:', IMG_SIZE)

class NumpyCIFARDataset(torch.utils.data.Dataset):
    """Dataset over in-memory image arrays and their labels.

    On access, the array at the requested index is turned into a PIL image
    and, if a transform was given, passed through that transform.
    """

    def __init__(self, images, labels, transform=None):
        """Store the image sequence, the label sequence and the transform."""
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the (image, label) pair at position ``idx``."""
        picture = Image.fromarray(self.images[idx])
        processed = self.transform(picture) if self.transform else picture
        return processed, self.labels[idx]

# Wrap each split in a dataset; all three share the same transform.
train_ds = NumpyCIFARDataset(tr_imgs, tr_lbls, transform=transform)
val_ds   = NumpyCIFARDataset(va_imgs, va_lbls, transform=transform)
test_ds  = NumpyCIFARDataset(images_test, labels_test, transform=transform)

# Only the training loader shuffles; num_workers=0 loads data in the main process.
BATCH=16
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_ds,  batch_size=BATCH, shuffle=False, num_workers=0)

## 3. 백본 선택 & 전이학습 설정
- 아래 스위치로 **MobileNetV3-Small** ↔ **EfficientNet-B0** 를 간단히 교체
- **Feature Extraction**: 백본 파라미터를 `requires_grad=False`로 동결하고 head만 학습합니다


In [None]:
BACKBONE = 'mobilenet'  # 'mobilenet' or 'efficientnet'

# Resolve the switch through an explicit table so that a misspelled value fails
# loudly instead of silently selecting EfficientNet.
_BACKBONE_REPOS = {'mobilenet': mobilenet_repo, 'efficientnet': efficient_repo}
if BACKBONE not in _BACKBONE_REPOS:
    raise ValueError(
        f'Unknown BACKBONE {BACKBONE!r}; expected one of {sorted(_BACKBONE_REPOS)}'
    )
repo_id = _BACKBONE_REPOS[BACKBONE]

# Pretrained weights from the HF Hub with a fresh 10-class classification head.
model = timm.create_model(f'hf_hub:{repo_id}', pretrained=True, num_classes=10)
model.to(device)

# Feature extraction: freeze everything except the classification head.
# A parameter counts as part of the head only when its top-level module is one
# of the head attribute names below. Comparing the first dotted component
# (instead of searching the whole name for a substring) keeps backbone layers
# whose names merely contain these words, e.g. `conv_head.weight` or
# `blocks.0.mlp.fc1.weight`, frozen.
HEAD_MODULES = ('classifier', 'fc', 'head')
for name, p in model.named_parameters():
    p.requires_grad = name.split('.', 1)[0] in HEAD_MODULES

def count_trainable(m):
    """Return the total element count of the parameters of `m` that require gradients."""
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
print('Trainable params:', count_trainable(model))

## 4. 학습 루프 (Feature Extraction, CPU OK)

In [None]:
# Cross-entropy loss on the model outputs.
criterion = nn.CrossEntropyLoss()
# The optimizer only receives parameters that still have requires_grad=True.
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-4)

def train_epoch(model, loader, *, opt=None, loss_fn=None, dev=None):
    """Run one optimization pass over `loader`.

    Args:
        model: Module to train; it is switched to training mode here.
        loader: Iterable of (inputs, labels) batches.
        opt: Optimizer to step. Defaults to the module-level `optimizer`.
        loss_fn: Loss callable. Defaults to the module-level `criterion`.
        dev: Device the batches are moved to. Defaults to the module-level
            `device`.

    Returns:
        Tuple (loss averaged per sample, fraction of correct predictions).

    Raises:
        ZeroDivisionError: If `loader` yields no samples.
    """
    # Defaults are looked up at call time, so rebinding the module-level
    # optimizer (e.g. before fine-tuning) is picked up by later calls.
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    # NOTE: train() puts every submodule into training mode, including
    # BatchNorm/Dropout layers whose parameters are frozen.
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for x, y in loader:
        # as_tensor reuses labels that are already a tensor; torch.tensor(y)
        # would copy them and emit a "copy construct" UserWarning per batch.
        x, y = x.to(dev), torch.as_tensor(y).to(dev)
        opt.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        opt.step()
        batch_size = x.size(0)
        # The loss is a batch mean, so weight it by the batch size before summing.
        loss_sum += loss.item() * batch_size
        correct += (out.argmax(1) == y).sum().item()
        seen += batch_size
    return loss_sum / seen, correct / seen

@torch.no_grad()
def eval_epoch(model, loader, *, loss_fn=None, dev=None):
    """Evaluate `model` over `loader` without computing gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode here.
        loader: Iterable of (inputs, labels) batches.
        loss_fn: Loss callable. Defaults to the module-level `criterion`.
        dev: Device the batches are moved to. Defaults to the module-level
            `device`.

    Returns:
        Tuple (loss averaged per sample, fraction of correct predictions).

    Raises:
        ZeroDivisionError: If `loader` yields no samples.
    """
    # Defaults are looked up at call time from the module-level names.
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    for x, y in loader:
        # as_tensor reuses labels that are already a tensor; torch.tensor(y)
        # would copy them and emit a "copy construct" UserWarning per batch.
        x, y = x.to(dev), torch.as_tensor(y).to(dev)
        out = model(x)
        loss = loss_fn(out, y)
        batch_size = x.size(0)
        # The loss is a batch mean, so weight it by the batch size before summing.
        loss_sum += loss.item() * batch_size
        correct += (out.argmax(1) == y).sum().item()
        seen += batch_size
    return loss_sum / seen, correct / seen

EPOCHS = 2  # light configuration
hist = {'tr_acc': [], 'va_acc': [], 'tr_loss': [], 'va_loss': []}
for ep in range(1, EPOCHS + 1):
    tr_l, tr_a = train_epoch(model, train_loader)
    va_l, va_a = eval_epoch(model, val_loader)
    # Append this epoch's four metrics to their history lists.
    for key, value in zip(('tr_loss', 'va_loss', 'tr_acc', 'va_acc'),
                          (tr_l, va_l, tr_a, va_a)):
        hist[key].append(value)
    print(f'[Ep {ep}/{EPOCHS}] train={tr_a:.3f}/{tr_l:.3f}  val={va_a:.3f}/{va_l:.3f}')

In [None]:
# Learning curves: one figure for accuracy, then one for loss
curve_specs = (
    ('tr_acc', 'train_acc', 'va_acc', 'val_acc', 'Accuracy (Feature Extraction)'),
    ('tr_loss', 'train_loss', 'va_loss', 'val_loss', 'Loss (Feature Extraction)'),
)
for train_key, train_label, val_key, val_label, title in curve_specs:
    plt.figure(figsize=(6, 4))
    plt.plot(hist[train_key], label=train_label)
    plt.plot(hist[val_key], label=val_label)
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()

## 5. 평가 + 혼동행렬

In [None]:
@torch.no_grad()
def preds_and_labels(model, loader):
    """Collect the labels and the argmax predictions for every batch in `loader`.

    The model is switched to eval mode and inputs are moved to the
    module-level `device` before the forward pass.

    Returns:
        Tuple (labels, predictions) of concatenated NumPy arrays.
    """
    model.eval()
    truth_chunks, pred_chunks = [], []
    for inputs, targets in loader:
        scores = model(inputs.to(device))
        # Predictions go back to the CPU before conversion to NumPy.
        pred_chunks.append(scores.argmax(1).cpu().numpy())
        truth_chunks.append(np.array(targets))
    return np.concatenate(truth_chunks), np.concatenate(pred_chunks)

# Evaluate on the test loader.
y_true, y_pred = preds_and_labels(model, test_loader)
acc = accuracy_score(y_true, y_pred)
print(f'[TEST] Accuracy: {acc:.4f}')

# Class indices are derived from class_names instead of a hard-coded 10.
class_ids = list(range(len(class_names)))

cm = confusion_matrix(y_true, y_pred, labels=class_ids)
plt.figure(figsize=(6,5))
plt.imshow(cm, interpolation='nearest'); plt.title(f'Confusion Matrix ({BACKBONE})'); plt.colorbar()
plt.xticks(class_ids, class_names, rotation=45); plt.yticks(class_ids, class_names)
plt.xlabel('Pred'); plt.ylabel('True')
# tight_layout runs after the axis labels are set so that they are part of the layout.
plt.tight_layout(); plt.show()

# Passing `labels` keeps the report aligned with target_names even when a
# class never occurs in y_true or y_pred.
print('\n[Classification Report]\n',
      classification_report(y_true, y_pred, labels=class_ids, target_names=class_names))

## 6. (선택) Fine-tuning — GPU 권장

- 백본 일부 블록을 `requires_grad=True`로 풀어서 추가 학습
- CPU에서는 매우 느리므로 **GPU에서만 실행**을 권장합니다.


In [None]:
DO_FINETUNE = False  # switch to True when running on a GPU
if DO_FINETUNE:
    # Example: unfreeze the parameters of the last stages of EfficientNet/MobileNet
    # (selected by substring match on the parameter name)
    for name, p in model.named_parameters():
        if any(k in name for k in ['blocks.5','blocks.6','stages.3','stages.4']):
            p.requires_grad = True

    # Rebuild the optimizer so that it covers the newly unfrozen parameters, with
    # a smaller learning rate than the feature-extraction run. train_epoch uses
    # the module-level `optimizer`, so rebinding the name here takes effect.
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)
    FT_EPOCHS=2
    for ep in range(1, FT_EPOCHS+1):
        tr_l,tr_a = train_epoch(model, train_loader)
        va_l,va_a = eval_epoch(model, val_loader)
        print(f'[FT Ep {ep}/{FT_EPOCHS}] train={tr_a:.3f}/{tr_l:.3f}  val={va_a:.3f}/{va_l:.3f}')

## 7. 저장/로딩

In [None]:
# Save only the weights (state_dict); the file name is built once and reused.
ckpt_path = f'{BACKBONE}_fe_cifar10_light.pt'
torch.save(model.state_dict(), ckpt_path)
print('Saved:', ckpt_path)

# Loading example: rebuild the same architecture, then restore the saved weights.
# The head must have the same structure as the saved model (num_classes=10).
# pretrained=False skips loading the ImageNet weights, which load_state_dict
# would overwrite anyway.
m2 = timm.create_model(f'hf_hub:{repo_id}', pretrained=False, num_classes=10).to(device)
m2.load_state_dict(torch.load(ckpt_path, map_location=device))
m2.eval(); print('Reload OK')