MedSRGAN 구현 

참고: https://github.com/04RR/MedSRGAN/

In [1]:
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as ttf
import os
import random

from tqdm import tqdm
import torch.nn as nn
import torch
import torchvision.models as models
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable

import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio, structural_similarity


In [2]:

device = "cuda" if torch.cuda.is_available() else "cpu"
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor


# 난수 고정 함수
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# 난수 고정
set_seed(42)

### Dataset

In [3]:
dataset_path = '/home/ykjeong/MedSRGAN/dataset'
path_list = [os.path.join(dataset_path, f) for f in os.listdir(dataset_path) if f.endswith('.png')]

class GAN_Data(Dataset):
    def __init__(self, path_list, transform_lr=None, transform_hr=None):
        self.path_list = path_list
        self.transform_lr = transform_lr
        self.transform_hr = transform_hr

    def __len__(self):
        return len(self.path_list)

    def __getitem__(self, idx):
        img_path = self.path_list[idx]
        image = Image.open(img_path).convert('L')
        
        if self.transform_lr:
            lr_image = self.transform_lr(image)
        if self.transform_hr:
            hr_image = self.transform_hr(image)
        
        return lr_image, hr_image


transform_lr = ttf.Compose([
    ttf.Resize((128, 128)),  # LR 이미지 크기
    ttf.ToTensor(),
])

transform_hr = ttf.Compose([
    ttf.Resize((128, 128)),  # HR 이미지 크기
    ttf.GaussianBlur(3, sigma=(0.1, 2.0)),  # Gaussian Blur 추가
    ttf.ToTensor(),
])


# 데이터셋 및 DataLoader 생성
dataset = GAN_Data(path_list, transform_lr=transform_lr, transform_hr=transform_hr)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

### Modeling

In [4]:
class RWMAB(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, (3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, (3, 3), stride=1, padding=1),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, (1, 1), stride=1, padding=0),
            nn.Sigmoid(),
        )
    
    def forward(self, x):
        x_ = self.layer1(x)
        x__ = self.layer2(x_)
        x = x__ * x_ + x

        return x

class ShortResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.layers = nn.ModuleList([RWMAB(in_channels) for _ in range(16)])

    def forward(self, x):

        x_ = x.clone()

        for layer in self.layers:
            x_ = layer(x_)

        return x_ + x
    
class Generator(nn.Module):
    def __init__(self, in_channels=1, blocks=8):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1)

        self.short_blocks = nn.ModuleList(
            [ShortResidualBlock(64) for _ in range(blocks)]
        )

        self.conv2 = nn.Conv2d(64, 64, (1, 1), stride=1, padding=0)
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 64, (3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, (3, 3), stride=1, padding=1),
            nn.ReLU(),
        )

        # 최종 출력 채널 수를 1로 줄이기 위한 Conv2d 레이어
        self.final_conv = nn.Conv2d(64, 1, (3, 3), stride=1, padding=1)

    def forward(self, x):
        x = self.conv(x)
        x_ = x.clone()

        for layer in self.short_blocks:
            x_ = layer(x_)

        x = torch.cat([self.conv2(x_), x], dim=1)
        x = self.conv3(x)
        x = self.final_conv(x)

        return x

In [5]:
class D_Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        )

    def forward(self, x):

        return self.layer(x)
    
class Discriminator(nn.Module):
    def __init__(self, img_size, in_channels=1):
        super().__init__()
        
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),  
            nn.LeakyReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU()
        )

        self.classifier = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Sigmoid()
        )
        
    def forward(self, x1, x2):
        
        # 특징 추출
        x_1 = self.feature_extractor(x1)
        x_2 = self.feature_extractor(x2)
        
        # 특징 결합
        x = torch.cat([x_1, x_2], dim=1)

        # 분류
        x = self.classifier(x)
        return x


In [6]:
gen = Generator().to(device)
disc = Discriminator(img_size=(128, 128), in_channels=1).to(device)

# 학습률 이 이상으로 올릴 경우 학습되지 않는 현상 발생 
optimizer_G = optim.Adam(gen.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizer_D = optim.Adam(disc.parameters(), lr=1e-4, betas=(0.5, 0.999)) 

loss_function = torch.nn.L1Loss().to(device)
gan_loss = torch.nn.BCEWithLogitsLoss().to(device)
scaler = torch.cuda.amp.GradScaler()

In [7]:
def fit(
    gen,
    disc,
    dataloader,
    epochs,
    optimizer_G,
    optimizer_D,
    scaler,
    loss_function,
    gan_loss,
):

    t_loss_G, t_loss_D = [], []

    for epoch in tqdm(range(epochs)):
        e_loss_G, e_loss_D = [], []

        for data in dataloader:
            hr_img, lr_img = data
            hr_img = hr_img.float().to(device)
            lr_img = lr_img.float().to(device)

            valid = torch.tensor(np.ones((1, 1)), dtype=torch.float32, device=device).detach()
            fake = torch.tensor(np.zeros((1, 1)), dtype=torch.float32, device=device).detach()

            with torch.cuda.amp.autocast():

                # Train Generator

                pred_hr = gen(lr_img)

                content_loss = loss_function(pred_hr, hr_img)
                feature_loss = 0.0

                pred_real = disc(hr_img.detach(), lr_img)
                pred_fake = disc(pred_hr, lr_img)

                gan_loss_num = gan_loss(
                    pred_fake - pred_real.mean(0, keepdim=True), valid
                )

                loss_G = content_loss * 0.1 + feature_loss * 0.1 + gan_loss_num

                optimizer_G.zero_grad()
                scaler.scale(loss_G).backward()
                scaler.step(optimizer_G)
                scaler.update()
                e_loss_G.append(loss_G)

                # Train Discriminator

                pred_real = disc(hr_img, lr_img)
                pred_fake = disc(pred_hr.detach(), lr_img)

                loss_real = gan_loss(pred_real - pred_fake.mean(0, keepdim=True), valid)
                loss_fake = gan_loss(pred_fake - pred_real.mean(0, keepdim=True), fake)

                loss_real_num = gan_loss(pred_real, valid)
                loss_fake_num = gan_loss(pred_fake, fake)

                loss_D = ((loss_real + loss_fake) / 2) + (
                    (loss_real_num + loss_fake_num) / 2
                )

                optimizer_D.zero_grad()
                scaler.scale(loss_D).backward()
                scaler.step(optimizer_D)
                scaler.update()
                e_loss_D.append(loss_D)

        t_loss_D.append(sum(e_loss_D) / len(e_loss_D))
        t_loss_G.append(sum(e_loss_G) / len(e_loss_G))

        print(
            f"{epoch+1}/{epochs} -- Gen Loss: {sum(t_loss_G) / len(t_loss_G)} -- Disc Loss: {sum(t_loss_D) / len(t_loss_D)}"
        )

        torch.save(gen, "./gen_{epoch}")
        torch.save(disc, "./disc_{epoch}")

    return t_loss_G, t_loss_D

In [8]:
epochs=10
fit(gen, disc, dataloader, epochs, optimizer_G, optimizer_D, scaler, loss_function, gan_loss)

  0%|          | 0/10 [00:00<?, ?it/s]

1/10 -- Gen Loss: 0.7146836519241333 -- Disc Loss: 1.3877990245819092


 10%|█         | 1/10 [02:22<21:23, 142.65s/it]

2/10 -- Gen Loss: 0.704403281211853 -- Disc Loss: 1.387054443359375


 20%|██        | 2/10 [04:30<17:50, 133.78s/it]

3/10 -- Gen Loss: 0.7009477615356445 -- Disc Loss: 1.386803150177002


 30%|███       | 3/10 [06:38<15:18, 131.20s/it]

4/10 -- Gen Loss: 0.6992075443267822 -- Disc Loss: 1.3866766691207886


 40%|████      | 4/10 [08:43<12:52, 128.70s/it]

5/10 -- Gen Loss: 0.6981515884399414 -- Disc Loss: 1.3866006135940552


 50%|█████     | 5/10 [10:47<10:35, 127.07s/it]

6/10 -- Gen Loss: 0.6974424719810486 -- Disc Loss: 1.3865498304367065


 60%|██████    | 6/10 [12:51<08:24, 126.16s/it]

7/10 -- Gen Loss: 0.6996310353279114 -- Disc Loss: 1.3862272500991821


 70%|███████   | 7/10 [14:57<06:18, 126.13s/it]

8/10 -- Gen Loss: 0.698916494846344 -- Disc Loss: 1.3862359523773193


 80%|████████  | 8/10 [17:02<04:11, 125.53s/it]

9/10 -- Gen Loss: 0.6983544826507568 -- Disc Loss: 1.3862427473068237


 90%|█████████ | 9/10 [19:04<02:04, 124.62s/it]

10/10 -- Gen Loss: 0.6979025602340698 -- Disc Loss: 1.3862481117248535


100%|██████████| 10/10 [21:09<00:00, 126.95s/it]


([tensor(0.7147, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(0.6941, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(0.6940, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(0.6940, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(0.6939, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(0.6939, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(0.7128, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(0.6939, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(0.6939, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(0.6938, device='cuda:0', grad_fn=<DivBackward0>)],
 [tensor(1.3878, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(1.3863, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(1.3863, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(1.3863, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(1.3863, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(1.3863, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(1.3843, device='cuda:0', grad_fn=<DivBackward0

In [24]:

def visualize_results(generator, index, device, dataset):
    set_seed(42)
    generator.eval()
    
    # Load specific samples from the dataset
    hr_img, lr_img = dataset[index]
    hr_img = hr_img.unsqueeze(0).to(device)
    lr_img = lr_img.unsqueeze(0).to(device) 
    
    # 입력 이미지(lr_img)를 생성자(generator) 모델을 통해 변환
    with torch.no_grad():
        generated_img = generator(lr_img).detach().cpu()
    
    # 첫 번째 이미지만 시각적으로 비교
    hr_example = hr_img[0].permute(1, 2, 0).cpu().numpy()
    lr_example = lr_img[0].permute(1, 2, 0).cpu().numpy()
    generated_example = generated_img[0].permute(1, 2, 0).numpy()
    
    # PSNR 계산
    psnr_value_1 = peak_signal_noise_ratio(hr_example, generated_example)
    psnr_value_2 = peak_signal_noise_ratio(lr_example, generated_example)
        
    # PSNR 값 반환
    return psnr_value_1, psnr_value_2

index = 2
visualize_results(gen, index, 'cuda', dataset)

(35.67166461356242, 37.865413160832574)

In [25]:

# 모든 데이터에 대해 PSNR 평균 계산 예시
def calculate_average_psnr(generator, device, dataset):
    psnr_1_list = []
    psnr_2_list = []
    num_samples = len(dataset)  # 데이터셋의 샘플 수
    
    for index in range(num_samples):
        psnr_1, psnr_2 = visualize_results(generator, index, device, dataset)
        psnr_1_list.append(psnr_1)
        psnr_2_list.append(psnr_2)
    
    average_psnr_1 = np.mean(psnr_1_list)
    average_psnr_2 = np.mean(psnr_2_list)
    
    return average_psnr_1, average_psnr_2

# 예시 실행
avg_psnr_1, avg_psnr_2 = calculate_average_psnr(gen, device, dataset)

print(f'Average PSNR ground truth - output: {avg_psnr_1:.4f} dB')
print(f'Average PSNR ground truth - input : {avg_psnr_2:.4f} dB')

Average PSNR ground truth - output: 38.4463 dB
Average PSNR ground truth - input : 36.8205 dB
