체크포인트를 불러와서 결과를 시각화

In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from pathlib import Path
from tqdm import tqdm
import random

변수 설정

In [2]:
# ===== VISUALIZATION 설정 =====
VISUALIZE_ALL = False  # True: 전체 테스트셋, False: 일부만
NUM_SAMPLES = 20  # VISUALIZE_ALL=False일 때 시각화할 샘플 수
RANDOM_SAMPLE = True  # True: 랜덤 샘플링, False: 순차적

# ===== 실험 설정 =====
DATASET_ROOT = "./dataset"
test_dataset = 'LEVIR-CD+'
test_model = 'A2Net'
use_base = True

경로 설정

In [3]:
checkpoint_path = Path(f"experiments/{test_dataset}/{test_model}/checkpoints/best_model.pth")
save_dir = Path(f"experiments/{test_dataset}/{test_model}/visualization")
save_dir.mkdir(parents=True, exist_ok=True)


GPU

In [4]:
# 단일 GPU 사용
GPU_ID = 3
os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using GPU: {GPU_ID}")
print(f"Device: {DEVICE}")

# 멀티 GPU 사용 
# # GPU_IDS = [0, 1, 2, 3]  # 사용할 GPU 리스트
# os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, GPU_IDS))
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# USE_MULTI_GPU = len(GPU_IDS) > 1 and torch.cuda.device_count() > 1
# print(f"Using GPUs: {GPU_IDS}")
# print(f"Available GPU count: {torch.cuda.device_count()}")
# if USE_MULTI_GPU:
#     BATCH_SIZE = BATCH_SIZE * len(GPU_IDS)  # 멀티 GPU시 배치 크기 조정
#     print(f"Adjusted batch size for multi-GPU: {BATCH_SIZE}")
# 시드 설정 (재현가능성)

Using GPU: 3
Device: cuda


In [5]:

# %% 모델 로드
import importlib

def get_model_class(model_name, use_base=False):
    model_name_lower = model_name.lower()
    if use_base:
        module_path = f'models.{model_name_lower}_base'
        class_name = f'{model_name}Base'
    else:
        module_path = f'models.{model_name_lower}'
        class_name = model_name
    
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# 모델 로드
ModelClass = get_model_class(test_model, use_base=use_base)
model = ModelClass(num_classes=1).to(DEVICE)

checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Model: {ModelClass.__name__}")
print(f"Best F1: {checkpoint.get('best_f1', 0):.4f}")

# %% 데이터로더
from utils import create_dataloaders

_, _, test_loader = create_dataloaders(
    root_dir=DATASET_ROOT,
    dataset_name=test_dataset,
    batch_size=1,
    num_workers=2,
    augment=False
)

total_samples = len(test_loader)
print(f"Test samples: {total_samples}")

# %% Visualization 실행
from utils.visualization import save_visualization

# 샘플 선택
if VISUALIZE_ALL:
    sample_indices = list(range(total_samples))
else:
    if RANDOM_SAMPLE:
        sample_indices = random.sample(range(total_samples), min(NUM_SAMPLES, total_samples))
        sample_indices.sort()
    else:
        sample_indices = list(range(min(NUM_SAMPLES, total_samples)))

print(f"Visualizing {len(sample_indices)} samples...")

# 실행
metrics_list = []
pbar = tqdm(enumerate(test_loader), total=total_samples, desc="Processing")
viz_count = 0

for idx, batch in pbar:
    if idx not in sample_indices:
        continue
    
    # 데이터
    img1 = batch['img1'].to(DEVICE)
    img2 = batch['img2'].to(DEVICE)
    label = batch['label'].to(DEVICE)
    filename = batch['filename'][0]
    
    # 추론
    with torch.no_grad():
        output = model(img1, img2)
        pred = torch.sigmoid(output)
    
    # 메트릭
    pred_binary = (pred > 0.5).float()
    tp = ((pred_binary == 1) & (label == 1)).sum().item()
    fp = ((pred_binary == 1) & (label == 0)).sum().item()
    fn = ((pred_binary == 0) & (label == 1)).sum().item()
    
    f1 = 2*tp / (2*tp + fp + fn + 1e-7)
    metrics_list.append(f1)
    
    # 저장
    save_name = f"{viz_count:04d}_{Path(filename).stem}_f1_{f1:.3f}"
    save_visualization(
        img1.squeeze(0),
        img2.squeeze(0),
        label.squeeze(0),
        pred.squeeze(0),
        save_dir,
        save_name
    )
    
    viz_count += 1
    pbar.set_postfix({'saved': viz_count, 'f1': f'{f1:.3f}'})

# %% 결과
print(f"\n✓ Saved {viz_count} visualizations to {save_dir}")
if metrics_list:
    print(f"Average F1: {np.mean(metrics_list):.4f}")
    print(f"Best F1: {max(metrics_list):.4f}")
    print(f"Worst F1: {min(metrics_list):.4f}")



Model: A2NetBase
Best F1: 0.4893


  from .autonotebook import tqdm as notebook_tqdm


Loaded 10192 images from LEVIR-CD+/train
Loaded 1568 images from LEVIR-CD+/val
Loaded 4000 images from LEVIR-CD+/test
Test samples: 4000
Visualizing 20 samples...


Processing: 100%|██████████| 4000/4000 [00:31<00:00, 127.01it/s, saved=20, f1=0.206]


✓ Saved 20 visualizations to experiments/LEVIR-CD+/A2Net/visualization
Average F1: 0.1446
Best F1: 0.7416
Worst F1: 0.0000



