<a href="https://colab.research.google.com/github/westjiuuu/SRCNN/blob/main/SRCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive so the dataset paths used by the
# cells below (all under /content/drive/MyDrive/...) can be read and written.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn

class SRCNN(nn.Module):
    """Three-layer super-resolution CNN for single-channel images.

    Every convolution is unpadded, so the output is 12 pixels smaller than
    the input along each spatial axis (8 lost to the 9x9 kernel, 4 to the
    5x5 kernel); e.g. a 33x33 input yields a 21x21 output.
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``model`` and the layer positions (0, 2, 4)
        # determine the state-dict keys, so they must stay as they are.
        layers = [
            nn.Conv2d(1, 64, kernel_size=9, padding=0),   # no padding
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=5, padding=0),   # no padding
        ]
        self.model = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        """Run ``x`` through the convolution stack and return the result."""
        return self.model(x)

    def _initialize_weights(self):
        """Re-initialise conv weights from N(0, 0.001^2) and zero the biases."""
        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            nn.init.normal_(conv.weight, mean=0.0, std=0.001)
            if conv.bias is None:
                continue
            nn.init.constant_(conv.bias, 0)

In [None]:
# [2] H5 Dataset 로딩 (정규화 포함)

import h5py
from torch.utils.data import Dataset

class H5Dataset(Dataset):
    """In-memory dataset of (input, label) patch pairs read from an HDF5 file.

    The file must contain two array datasets named 'lr' and 'hr'. Both are
    read completely at construction time and divided by 255.
    """

    def __init__(self, file_path):
        # ``[:]`` copies the data into NumPy arrays, so they remain valid
        # after the file is closed.
        with h5py.File(file_path, 'r') as h5:
            lr_raw = h5['lr'][:]  # (N, 33, 33) per the original author's note
            hr_raw = h5['hr'][:]  # (N, 21, 21) per the original author's note
        self.inputs = lr_raw / 255.0
        self.labels = hr_raw / 255.0

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as float32 tensors with a leading channel axis."""
        def with_channel(arr):
            return torch.from_numpy(arr).unsqueeze(0).float()

        return with_channel(self.inputs[idx]), with_channel(self.labels[idx])


In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.optim as optim
import matplotlib.pyplot as plt
import torch
import numpy as np
import os

# Settings
train_file = '/content/drive/MyDrive/SRCNN/91-image_x3_blur.h5'
batch_size = 16
num_epochs = 400

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the training dataset
train_dataset = H5Dataset(train_file)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Build the model and the loss
model = SRCNN().to(device)
criterion = torch.nn.MSELoss()

# One parameter group per conv layer (positions 0, 2 and 4 of the Sequential);
# the last layer is trained with a 10x smaller learning rate than the first two.
optimizer = torch.optim.Adam([
    {'params': model.model[0].parameters(), 'lr': 1e-4},
    {'params': model.model[2].parameters(), 'lr': 1e-4},
    {'params': model.model[4].parameters(), 'lr': 1e-5},
])

# Training loop (one full pass over train_loader per epoch)
loss_history = []

for epoch in range(1, num_epochs + 1):
    model.train()
    running_loss = 0.0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_loss = running_loss / len(train_loader)
    loss_history.append(avg_loss)
    print(f"[Epoch {epoch}/{num_epochs}] Loss: {avg_loss:.6f}")

# Save the trained weights (state dict only). The path is relative, so the
# file lands in the current working directory.
torch.save(model.state_dict(), 'srcnn_x3_epoch400.pth')

# Plot the per-epoch training loss
plt.figure(figsize=(10, 5))
plt.plot(range(1, num_epochs + 1), loss_history, marker='o', linewidth=1.5)
plt.title("Training Loss over Epochs")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

# [1] 평가용 Dataset 클래스
class H5EvalDataset(Dataset):
    """Evaluation dataset holding whole images read from an HDF5 file.

    The file must contain two groups, 'lr' and 'hr', with one dataset per
    image stored under matching keys. Every image is divided by 255 and kept
    in memory as a float32 tensor with a leading channel axis.
    """

    def __init__(self, file_path):
        def load(node):
            # ``[()]`` reads the full dataset into a NumPy array.
            return torch.tensor(node[()] / 255.0).unsqueeze(0).float()

        self.input_images = []
        self.label_images = []

        with h5py.File(file_path, 'r') as h5:
            lr_group, hr_group = h5['lr'], h5['hr']
            for name in lr_group:
                lr_tensor = load(lr_group[name])
                hr_tensor = load(hr_group[name])
                self.input_images.append(lr_tensor)
                self.label_images.append(hr_tensor)

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        """Return the ``idx``-th (input, label) tensor pair."""
        pair = (self.input_images[idx], self.label_images[idx])
        return pair


# ✅ 중심 crop 함수
def center_crop(img, target_h, target_w):
    """Return the central ``target_h`` x ``target_w`` window of ``img``.

    The last two axes of ``img`` are treated as (height, width); any leading
    axes (e.g. batch or channel) are kept unchanged. When the size difference
    is odd, the extra pixel is left on the bottom/right side.

    Args:
        img: Array with a ``shape`` attribute and NumPy-style slicing.
        target_h: Height of the returned window.
        target_w: Width of the returned window.

    Returns:
        A slice of ``img`` whose last two axes are ``(target_h, target_w)``.

    Raises:
        ValueError: If ``img`` has fewer than two axes, or the requested
            window is negative or larger than the image. Without this check
            a too-large target produced negative offsets and a silently
            wrong-sized result.
    """
    if len(img.shape) < 2:
        raise ValueError(f"center_crop expects at least 2 axes, got shape {tuple(img.shape)}")

    h, w = img.shape[-2:]
    if not (0 <= target_h <= h and 0 <= target_w <= w):
        raise ValueError(
            f"cannot crop a {target_h}x{target_w} window from a {h}x{w} image"
        )

    top = (h - target_h) // 2
    left = (w - target_w) // 2
    return img[..., top:top + target_h, left:left + target_w]


# [2] 평가 함수 (PSNR, SSIM)
def evaluate_model(model, dataloader):
    """Print average PSNR/SSIM of the raw input and of the model output.

    For every batch the model output is clamped to [0, 1]; input, output and
    label are scaled to the 0-255 range, and input and label are
    centre-cropped to the output size before both metrics are computed
    against the label. The raw input serves as the "Bicubic" baseline in the
    printed report.

    Assumes ``dataloader`` yields one single-channel image per batch
    (``batch_size=1``), because ``squeeze()`` is used to obtain 2-D arrays.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Yields ``(input, label)`` tensor pairs.

    Raises:
        ValueError: If the underlying dataset is empty (the averages would
            otherwise fail with a division by zero).
    """
    n = len(dataloader.dataset)
    if n == 0:
        raise ValueError("evaluate_model: dataloader.dataset is empty")

    # Run on the device the model actually lives on instead of relying on a
    # notebook-global ``device`` that may be missing or out of sync.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    total_psnr_bic, total_psnr_srcnn = 0.0, 0.0
    total_ssim_bic, total_ssim_srcnn = 0.0, 0.0

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(run_device), y.to(run_device)
            output = model(x).clamp(0.0, 1.0)

            # Convert to NumPy in the 0-255 range
            input_np  = x.squeeze().cpu().numpy() * 255.0
            output_np = output.squeeze().cpu().numpy() * 255.0
            label_np  = y.squeeze().cpu().numpy() * 255.0

            # Crop the references to the (possibly smaller) output size
            h_out, w_out = output_np.shape
            input_np  = center_crop(input_np,  h_out, w_out)
            label_np  = center_crop(label_np,  h_out, w_out)

            # PSNR & SSIM
            psnr_bic  = compare_psnr(label_np, input_np, data_range=255)
            psnr_src  = compare_psnr(label_np, output_np, data_range=255)
            ssim_bic  = compare_ssim(label_np, input_np, data_range=255)
            ssim_src  = compare_ssim(label_np, output_np, data_range=255)

            total_psnr_bic   += psnr_bic
            total_psnr_srcnn += psnr_src
            total_ssim_bic   += ssim_bic
            total_ssim_srcnn += ssim_src

    print("📊 Set5 평가 결과:")
    print(f"📌 Avg PSNR - Bicubic: {total_psnr_bic / n:.2f} dB / SRCNN: {total_psnr_srcnn / n:.2f} dB")
    print(f"📌 Avg SSIM - Bicubic: {total_ssim_bic / n:.4f} / SRCNN: {total_ssim_srcnn / n:.4f}")


# Load the saved weights into the existing model.
# NOTE(review): `model` and `device` are not defined in this cell; they are
# presumably left over from the training cell — confirm it has been run first.
model.load_state_dict(torch.load('/content/srcnn_x3_epoch400.pth', map_location=device))
model.to(device)

# Evaluation data loader (one image per batch, fixed order)
eval_file = '/content/drive/MyDrive/SRCNN/Set5_x3_blur.h5'
eval_dataset = H5EvalDataset(eval_file)
eval_loader = DataLoader(eval_dataset, batch_size=1, shuffle=False)

# Run the evaluation
evaluate_model(model, eval_loader)

In [None]:
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

# [1] 평가용 Dataset 클래스
class H5EvalDataset(Dataset):
    """Whole-image evaluation pairs loaded eagerly from an HDF5 file.

    Expects groups 'lr' and 'hr' whose members share the same keys. Each
    image is normalised by dividing by 255 and stored as a float32 tensor
    with a leading channel axis.
    """

    def __init__(self, file_path):
        inputs, labels = [], []
        self.input_images = inputs
        self.label_images = labels

        with h5py.File(file_path, 'r') as source:
            low_res = source['lr']
            high_res = source['hr']

            for key in low_res:
                lr_array = low_res[key][()] / 255.0   # normalise
                hr_array = high_res[key][()] / 255.0  # normalise

                inputs.append(torch.tensor(lr_array).unsqueeze(0).float())
                labels.append(torch.tensor(hr_array).unsqueeze(0).float())

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        """Return the (input, label) tensors stored at position ``idx``."""
        lr_tensor = self.input_images[idx]
        hr_tensor = self.label_images[idx]
        return lr_tensor, hr_tensor


# Border-crop helper: removes `border` pixels (6 by default) from every edge
def shave(img, border=6):
    """Remove ``border`` pixels from every edge of the first two axes of ``img``.

    Args:
        img: Array with a ``shape`` attribute and NumPy-style slicing whose
            first two axes are (height, width).
        border: Number of pixels to drop on each side; must be >= 0.

    Returns:
        The cropped slice of ``img``. ``border=0`` returns the full extent.

    Raises:
        ValueError: If ``border`` is negative.
    """
    if border < 0:
        raise ValueError(f"border must be non-negative, got {border}")

    # Use explicit end indices: the shorthand ``img[border:-border]`` turns
    # into ``img[0:0]`` (an empty array) when border is 0.
    h, w = img.shape[:2]
    return img[border:h - border, border:w - border]


# [2] 평가 함수 (PSNR, SSIM)
def evaluate_model(model, dataloader, shave_border=6):
    """Print average PSNR/SSIM of the raw input and of the model output.

    The raw input is used as the "Bicubic" baseline. All images are scaled
    to the 0-255 range and compared on a common window that excludes at
    least ``shave_border`` pixels from every edge of the label image.

    A network built from unpadded convolutions returns an output smaller
    than its input. Shaving input, output and label by the same amount would
    then leave arrays of different shapes and the metric calls would fail.
    So input and label are first centre-cropped to the output size, and only
    the part of ``shave_border`` not already removed by the network is
    trimmed from all three. For a size-preserving network this is exactly a
    ``shave_border`` crop of every image.

    Assumes ``dataloader`` yields one single-channel image per batch
    (``batch_size=1``), because ``squeeze()`` is used to obtain 2-D arrays.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Yields ``(input, label)`` tensor pairs of equal size.
        shave_border: Minimum number of edge pixels excluded from the metrics.

    Raises:
        ValueError: If the dataset is empty, ``shave_border`` is negative,
            or the model output is larger than the label image.
    """
    n = len(dataloader.dataset)
    if n == 0:
        raise ValueError("evaluate_model: dataloader.dataset is empty")
    if shave_border < 0:
        raise ValueError(f"shave_border must be non-negative, got {shave_border}")

    # Run on the device the model actually lives on instead of relying on a
    # notebook-global ``device`` that may be missing or out of sync.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    def _trim(img, margin_h, margin_w):
        # Explicit end indices keep a zero margin a no-op
        # (``img[0:-0]`` would be empty).
        rows, cols = img.shape
        return img[margin_h:rows - margin_h, margin_w:cols - margin_w]

    model.eval()
    total_psnr_bic, total_psnr_srcnn = 0.0, 0.0
    total_ssim_bic, total_ssim_srcnn = 0.0, 0.0

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(run_device), y.to(run_device)

            # Bicubic baseline = the network input as-is
            bicubic = x

            # SRCNN output
            output = model(x).clamp(0.0, 1.0)

            # Convert to NumPy in the 0-255 range
            bicubic_np = bicubic.squeeze().cpu().numpy() * 255.0
            output_np  = output.squeeze().cpu().numpy() * 255.0
            label_np   = y.squeeze().cpu().numpy() * 255.0

            # Align the references with the (possibly smaller) output
            out_h, out_w = output_np.shape
            lost_h = (label_np.shape[0] - out_h) // 2
            lost_w = (label_np.shape[1] - out_w) // 2
            if lost_h < 0 or lost_w < 0:
                raise ValueError("model output is larger than the label image")
            bicubic_np = bicubic_np[lost_h:lost_h + out_h, lost_w:lost_w + out_w]
            label_np   = label_np[lost_h:lost_h + out_h, lost_w:lost_w + out_w]

            # Trim whatever part of the requested border is still present
            extra_h = max(shave_border - lost_h, 0)
            extra_w = max(shave_border - lost_w, 0)
            bicubic_np = _trim(bicubic_np, extra_h, extra_w)
            output_np  = _trim(output_np,  extra_h, extra_w)
            label_np   = _trim(label_np,   extra_h, extra_w)

            # PSNR & SSIM
            psnr_bic  = compare_psnr(label_np, bicubic_np, data_range=255)
            psnr_src  = compare_psnr(label_np, output_np,  data_range=255)
            ssim_bic  = compare_ssim(label_np, bicubic_np, data_range=255)
            ssim_src  = compare_ssim(label_np, output_np,  data_range=255)

            total_psnr_bic   += psnr_bic
            total_psnr_srcnn += psnr_src
            total_ssim_bic   += ssim_bic
            total_ssim_srcnn += ssim_src

    print("📊 Set5 H5 평가 결과 (crop 적용):")
    print(f"📌 Avg PSNR - Bicubic: {total_psnr_bic / n:.2f} dB / SRCNN: {total_psnr_srcnn / n:.2f} dB")
    print(f"📌 Avg SSIM - Bicubic: {total_ssim_bic / n:.4f} / SRCNN: {total_ssim_srcnn / n:.4f}")


# --- Load the model weights ---
# NOTE(review): `model` and `device` come from an earlier cell, and no cell in
# view writes 'srcnn_x2_epoch200.pth' — confirm the checkpoint exists and
# matches the current SRCNN definition.
model.load_state_dict(torch.load('/content/srcnn_x2_epoch200.pth', map_location=device))
model.to(device)

# --- Evaluation data loader (one image per batch, fixed order) ---
eval_file = '/content/drive/MyDrive/SRCNN/Set5/Set5_x2.h5'
eval_dataset = H5EvalDataset(eval_file)
eval_loader = DataLoader(eval_dataset, batch_size=1, shuffle=False)

# --- Run the evaluation ---
evaluate_model(model, eval_loader, shave_border=6)


In [None]:
import os
import glob
import h5py
import numpy as np
from PIL import Image
from tqdm import tqdm

# Settings (read as globals by the preprocessing functions below)
scale = 3
patch_size = 33         # side length of an input (LR) patch
label_size = 21         # side length of the HR label patch (centre crop of the input window)
stride = 14             # step between neighbouring patch origins

# Source image directories
t91_dir = '/content/drive/MyDrive/SRCNN/T91_img'
set5_dir = '/content/drive/MyDrive/SRCNN/Set5_img'

# Output HDF5 files
t91_output = '/content/drive/MyDrive/SRCNN/91-image_x3_nopad.h5'
set5_output = '/content/drive/MyDrive/SRCNN/Set5_x3_nopad.h5'


def convert_rgb_to_y(img):
    """Extract the luma (Y) channel from an RGB array.

    ``img`` is cast to float32 and its last axis is indexed as (R, G, B).
    The result is ``16 + (65.481*R + 128.553*G + 24.966*B) / 255``, so 0-255
    RGB input maps to the 16-235 range. The last axis is dropped.
    """
    rgb = img.astype(np.float32)
    red, green, blue = rgb[..., 0], rgb[..., 1], rgb[..., 2]
    weighted = 65.481 * red + 128.553 * green + 24.966 * blue
    return 16.0 + weighted / 255.0


# ✅ 학습용 전처리 (91 이미지 → patch 기반, padding 없음 구조에 맞춤)
def preprocess_t91():
    """Build the training patch file from every image found in ``t91_dir``.

    For each image:
      1. The RGB image is bicubically resized (not cropped) so that both
         sides are multiples of ``scale``.
      2. A degraded copy is produced by bicubic down-scaling by ``scale``
         followed by bicubic up-scaling back to the same size.
      3. Both are converted to the Y channel.
      4. ``patch_size`` windows are taken from the degraded image every
         ``stride`` pixels; the label is the centred ``label_size`` window
         of the same location in the original, matching an unpadded network.

    The patches are written to ``t91_output`` as datasets 'lr' and 'hr'.
    All settings are read from module-level globals.
    """
    inputs, targets = [], []

    print("📦 91개 이미지 전처리 중 (padding 없음)...")
    image_files = sorted(glob.glob(os.path.join(t91_dir, '*')))
    for path in tqdm(image_files):
        truth = Image.open(path).convert('RGB')

        # Snap the size to a multiple of the scale factor
        full_size = ((truth.width // scale) * scale, (truth.height // scale) * scale)
        small_size = (full_size[0] // scale, full_size[1] // scale)
        truth = truth.resize(full_size, resample=Image.BICUBIC)

        # Degrade: down-scale, then up-scale back to the original size
        degraded = truth.resize(small_size, resample=Image.BICUBIC)
        degraded = degraded.resize(full_size, resample=Image.BICUBIC)

        # Keep only the Y channel
        truth_y = convert_rgb_to_y(np.array(truth))
        degraded_y = convert_rgb_to_y(np.array(degraded))

        # Label window sits `margin` pixels inside the input window, e.g. 6
        margin = (patch_size - label_size) // 2
        rows, cols = degraded_y.shape[0], degraded_y.shape[1]
        for top in range(0, rows - patch_size + 1, stride):
            label_top = top + margin
            for left in range(0, cols - patch_size + 1, stride):
                label_left = left + margin
                inputs.append(degraded_y[top:top + patch_size, left:left + patch_size])
                targets.append(
                    truth_y[label_top:label_top + label_size,
                            label_left:label_left + label_size]
                )

    # Save
    with h5py.File(t91_output, 'w') as h5:
        h5.create_dataset('lr', data=np.array(inputs))
        h5.create_dataset('hr', data=np.array(targets))
    print(f"✅ 저장 완료: {t91_output}")


# ✅ 평가용 전처리 (Set5 전체 이미지 저장 - 그대로)
def preprocess_set5():
    """Write whole-image evaluation pairs for every file in ``set5_dir``.

    Each image is bicubically resized so both sides are multiples of
    ``scale``, degraded by bicubic down-/up-scaling, and converted to the Y
    channel. The original and degraded Y images are stored in ``set5_output``
    under groups 'hr' and 'lr', keyed by the image's index in sorted order.
    All settings are read from module-level globals.
    """
    print("📦 Set5 이미지 전처리 중 (전체 저장)...")
    with h5py.File(set5_output, 'w') as h5:
        lr_group = h5.create_group('lr')
        hr_group = h5.create_group('hr')

        image_files = sorted(glob.glob(os.path.join(set5_dir, '*')))
        for index, path in enumerate(image_files):
            truth = Image.open(path).convert('RGB')

            # Snap the size to a multiple of the scale factor
            full_size = ((truth.width // scale) * scale, (truth.height // scale) * scale)
            small_size = (full_size[0] // scale, full_size[1] // scale)
            truth = truth.resize(full_size, resample=Image.BICUBIC)

            # Degrade: down-scale, then up-scale back to the original size
            degraded = truth.resize(small_size, resample=Image.BICUBIC)
            degraded = degraded.resize(full_size, resample=Image.BICUBIC)

            # Keep only the Y channel
            truth_y = convert_rgb_to_y(np.array(truth))
            degraded_y = convert_rgb_to_y(np.array(degraded))

            key = str(index)
            hr_group.create_dataset(key, data=truth_y)
            lr_group.create_dataset(key, data=degraded_y)

    print(f"✅ 저장 완료: {set5_output}")


# Run both preprocessing jobs: training patches first, then the evaluation images.
preprocess_t91()
preprocess_set5()


In [None]:
import os
import glob
import h5py
import numpy as np
from PIL import Image, ImageFilter
from tqdm import tqdm

# --- Settings (read as globals by the preprocessing functions below) ---
scale = 2
patch_size = 33      # side length of an input (LR) patch
label_size = 21      # side length of the HR label patch (centre crop of the input window)
stride = 14
blur_radius = 0.0   # Gaussian blur radius applied before down-scaling

# --- Paths ---
# Adjust the paths below to match your environment.
t91_dir = '/content/drive/MyDrive/SRCNN/T91_img'
set5_dir = '/content/drive/MyDrive/SRCNN/Set5_img'
output_dir = '/content/drive/MyDrive/SRCNN'

# Create the output directory if it is missing. ``exist_ok=True`` replaces
# the separate existence check and is safe if the directory appears in between.
os.makedirs(output_dir, exist_ok=True)

# The blur radius is encoded in the file names in tenths (0.0 -> "0", 1.5 -> "15").
blur_str = str(int(blur_radius * 10))
t91_output = os.path.join(output_dir, f'91-image_x{scale}_blur{blur_str}.h5')
set5_output = os.path.join(output_dir, f'Set5_x{scale}_blur{blur_str}.h5')

def convert_rgb_to_y(img):
    """Return the luma (Y) channel of an RGB array as float32.

    The last axis of ``img`` is indexed as (R, G, B) and removed from the
    result. Output is ``16 + (65.481*R + 128.553*G + 24.966*B) / 255``.
    """
    pixels = img.astype(np.float32)
    # Accumulate in the same left-to-right order as a single expression would.
    luma = 65.481 * pixels[..., 0]
    luma = luma + 128.553 * pixels[..., 1]
    luma = luma + 24.966 * pixels[..., 2]
    return 16.0 + luma / 255.0


def preprocess_t91():
    """Build the training patch file from every image found in ``t91_dir``.

    Prints a warning and returns without writing anything when the directory
    yields no files. Otherwise, for each image:
      1. The RGB image is bicubically resized so both sides are multiples
         of ``scale``.
      2. A degraded copy is made: Gaussian blur with ``blur_radius``, bicubic
         down-scaling by ``scale``, bicubic up-scaling back to full size.
      3. Both are converted to the Y channel.
      4. ``patch_size`` windows are taken from the degraded image every
         ``stride`` pixels; the label is the centred ``label_size`` window
         of the same location in the original, matching an unpadded network.

    The patches are written to ``t91_output`` as datasets 'lr' and 'hr'.
    All settings are read from module-level globals.
    """
    inputs, targets = [], []

    image_paths = sorted(glob.glob(os.path.join(t91_dir, '*')))
    if len(image_paths) == 0:
        print(f"경고: '{t91_dir}' 디렉토리에서 이미지를 찾을 수 없습니다.")
        return

    for path in tqdm(image_paths):
        truth = Image.open(path).convert('RGB')

        # 1. Snap the size to a multiple of the scale factor
        full_size = ((truth.width // scale) * scale, (truth.height // scale) * scale)
        small_size = (full_size[0] // scale, full_size[1] // scale)
        truth = truth.resize(full_size, resample=Image.BICUBIC)

        # 2. Degrade: blur, down-scale, then up-scale to the network input size
        softened = truth.filter(ImageFilter.GaussianBlur(radius=blur_radius))
        degraded = softened.resize(small_size, resample=Image.BICUBIC)
        degraded = degraded.resize(full_size, resample=Image.BICUBIC)

        # 3. Keep only the Y channel
        truth_y = convert_rgb_to_y(np.array(truth))
        degraded_y = convert_rgb_to_y(np.array(degraded))

        # 4. Extract patches; the label window sits `margin` pixels inside
        margin = (patch_size - label_size) // 2
        rows, cols = degraded_y.shape[0], degraded_y.shape[1]
        for top in range(0, rows - patch_size + 1, stride):
            label_top = top + margin
            for left in range(0, cols - patch_size + 1, stride):
                label_left = left + margin
                inputs.append(degraded_y[top:top + patch_size, left:left + patch_size])
                targets.append(
                    truth_y[label_top:label_top + label_size,
                            label_left:label_left + label_size]
                )

    # 5. Save to HDF5
    with h5py.File(t91_output, 'w') as h5:
        h5.create_dataset('lr', data=np.array(inputs))
        h5.create_dataset('hr', data=np.array(targets))
    print(f"✅ 저장 완료: {t91_output}")


def preprocess_set5():
    """Preprocess the Set5 evaluation images and store them in an HDF5 file.

    Prints a warning and returns without writing anything when ``set5_dir``
    yields no files. Otherwise each image is bicubically resized so both
    sides are multiples of ``scale``, degraded (Gaussian blur with
    ``blur_radius``, bicubic down-/up-scaling) and converted to the Y
    channel. The original and degraded Y images go into ``set5_output``
    under groups 'hr' and 'lr', keyed by the image's index in sorted order.
    All settings are read from module-level globals.
    """
    print(f"📦 Set5 이미지 전처리 중 (scale: x{scale}, blur_radius: {blur_radius})...")

    image_paths = sorted(glob.glob(os.path.join(set5_dir, '*')))
    if len(image_paths) == 0:
        print(f"경고: '{set5_dir}' 디렉토리에서 이미지를 찾을 수 없습니다.")
        return

    with h5py.File(set5_output, 'w') as h5:
        lr_group = h5.create_group('lr')
        hr_group = h5.create_group('hr')

        for index, path in enumerate(image_paths):
            truth = Image.open(path).convert('RGB')

            # 1. Snap the size to a multiple of the scale factor
            full_size = ((truth.width // scale) * scale, (truth.height // scale) * scale)
            small_size = (full_size[0] // scale, full_size[1] // scale)
            truth = truth.resize(full_size, resample=Image.BICUBIC)

            # 2. Degrade: blur, down-scale, then up-scale back
            softened = truth.filter(ImageFilter.GaussianBlur(radius=blur_radius))
            degraded = softened.resize(small_size, resample=Image.BICUBIC)
            degraded = degraded.resize(full_size, resample=Image.BICUBIC)

            # 3. Keep only the Y channel
            truth_y = convert_rgb_to_y(np.array(truth))
            degraded_y = convert_rgb_to_y(np.array(degraded))

            # 4. Store the pair under the same key in both groups
            key = str(index)
            hr_group.create_dataset(key, data=truth_y)
            lr_group.create_dataset(key, data=degraded_y)

    print(f"✅ 저장 완료: {set5_output}")


# --- Run ---
# Make sure the T91 and Set5 image directories configured above contain the
# source images before running (e.g. image1.bmp for T91, baby.png for Set5).
if __name__ == '__main__':
    # Both jobs run unconditionally here. The former trailing messages told
    # the user to "uncomment the main block" to run it, which contradicted
    # what had just happened, so they were removed.
    preprocess_t91()
    preprocess_set5()

