# Import

In [1]:
import os
import shutil

import random
import numpy as np
import time
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.transforms import CenterCrop, Resize
from PIL import Image
from tqdm import tqdm

import warnings
warnings.filterwarnings(action='ignore')


# Hyperparameter Setting

In [2]:
CFG = {
    'IMG_SIZE':224,
    'EPOCHS':15,
    'LEARNING_RATE':0.0002086672211449482,
    'BATCH_SIZE':16,
    'SEED':42,
    'HIDDEN_UNITS': 256
}

# Fixed RandomSeed

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) # Seed 고정

# Data Pre-processing

In [5]:
def Preprocess(base_dir):
    clean_dir = os.path.join(base_dir, 'clean')
    noisy_dir = os.path.join(base_dir, 'noisy')

    os.makedirs(clean_dir, exist_ok=True)
    os.makedirs(noisy_dir, exist_ok=True)


    source_dirs = []
    for root, dirs, files in os.walk(base_dir):
        for dir_name in dirs:
            if 'GT' in dir_name:
                source_dirs.append(os.path.join(root, dir_name))

    if not source_dirs:
        raise ValueError("No directory containing 'GT' found")

    for source_dir in source_dirs:
        for filename in os.listdir(source_dir):
            if filename.endswith('.jpg'):
                shutil.move(os.path.join(source_dir, filename), os.path.join(clean_dir, filename))

    for root, dirs, files in os.walk(base_dir):
        for dir_name in dirs:
            if dir_name not in ['clean', 'noisy'] and 'GT' not in dir_name:
                current_dir = os.path.join(root, dir_name)
                for filename in os.listdir(current_dir):
                    if filename.endswith('.jpg'):
                        shutil.move(os.path.join(current_dir, filename), os.path.join(noisy_dir, filename))
                        
    
    for root, dirs, files in os.walk(base_dir, topdown=False):
        for dir_name in dirs:
            dir_path = os.path.join(root, dir_name)
            if dir_name not in ['clean', 'noisy']:
                shutil.rmtree(dir_path)
                        
    print('preprocessing done')

In [5]:
data_dir = './'
training_base_dir = os.path.join(data_dir, 'Training')
validation_base_dir = os.path.join(data_dir, 'Validation')

Preprocess(training_base_dir)
Preprocess(validation_base_dir)

ValueError: No directory containing 'GT' found

# CustomDataset

In [8]:
class CustomDataset(Dataset):
    def __init__(self, clean_image_paths, noisy_image_paths, transform=None):
        self.clean_image_paths = [os.path.join(clean_image_paths, x) for x in os.listdir(clean_image_paths)]
        self.noisy_image_paths = [os.path.join(noisy_image_paths, x) for x in os.listdir(noisy_image_paths)]
        self.transform = transform
        self.center_crop = CenterCrop(1080)
        self.resize = Resize((CFG['IMG_SIZE'], CFG['IMG_SIZE']))

        # Create a list of (noisy, clean) pairs
        self.noisy_clean_pairs = self._create_noisy_clean_pairs()

    def _create_noisy_clean_pairs(self):
        clean_to_noisy = {}
        for clean_path in self.clean_image_paths:
            clean_id = '_'.join(os.path.basename(clean_path).split('_')[:-1])
            clean_to_noisy[clean_id] = clean_path
        
        noisy_clean_pairs = []
        for noisy_path in self.noisy_image_paths:
            noisy_id = '_'.join(os.path.basename(noisy_path).split('_')[:-1])
            if noisy_id in clean_to_noisy:
                clean_path = clean_to_noisy[noisy_id]
                noisy_clean_pairs.append((noisy_path, clean_path))
            else:
                pass
        
        return noisy_clean_pairs
    

    def __len__(self):
        return len(self.noisy_clean_pairs)

    def __getitem__(self, index):
        noisy_image_path, clean_image_path = self.noisy_clean_pairs[index]

        noisy_image = Image.open(noisy_image_path).convert("RGB")
        clean_image = Image.open(clean_image_path).convert("RGB")
        
        # Central Crop and Resize
        noisy_image = self.center_crop(noisy_image)
        clean_image = self.center_crop(clean_image)
        noisy_image = self.resize(noisy_image)
        clean_image = self.resize(clean_image)
        
        if self.transform:
            noisy_image = self.transform(noisy_image)
            clean_image = self.transform(clean_image)
        
        return noisy_image, clean_image
    

class CutMixDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, beta=1.0, prob=0.5):
        self.dataset = dataset
        self.beta = beta
        self.prob = prob

    def __len__(self):
        return len(self.dataset)

    def rand_bbox(self, size, lam):
        if len(size) == 3:  # (C, H, W) 형식인 경우
            W = size[1]
            H = size[2]
        elif len(size) == 4:  # (N, C, H, W) 형식인 경우
            W = size[2]
            H = size[3]
        else:
            raise ValueError(f"Unexpected size format: {size}")

        cut_rat = np.sqrt(1. - lam)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)

        return bbx1, bby1, bbx2, bby2


    def __getitem__(self, index):
        img1, label1 = self.dataset[index]

        if np.random.rand() < self.prob:
            lam = np.random.beta(self.beta, self.beta)
            rand_index = np.random.randint(len(self.dataset))
            img2, label2 = self.dataset[rand_index]

            bbx1, bby1, bbx2, bby2 = self.rand_bbox(img1.size(), lam)
            img1[:, bbx1:bbx2, bby1:bby2] = img2[:, bbx1:bbx2, bby1:bby2]

            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (img1.size(-1) * img1.size(-2)))
            label = lam * label1 + (1 - lam) * label2
        else:
            label = label1

        return img1, label



# Model Define

In [9]:
'''
class MDTA(nn.Module):
    def __init__(self, channels, num_heads):
        super(MDTA, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False)
        self.qkv_conv = nn.Conv2d(channels * 3, channels * 3, kernel_size=3, padding=1, groups=channels * 3, bias=False)
        self.project_out = nn.Conv2d(channels, channels, kernel_size=1, bias=False)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = self.qkv_conv(self.qkv(x)).chunk(3, dim=1)

        q = q.reshape(b, self.num_heads, -1, h * w)
        k = k.reshape(b, self.num_heads, -1, h * w)
        v = v.reshape(b, self.num_heads, -1, h * w)
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)

        attn = torch.softmax(torch.matmul(q, k.transpose(-2, -1).contiguous()) * self.temperature, dim=-1)
        out = self.project_out(torch.matmul(attn, v).reshape(b, -1, h, w))
        return out


class GDFN(nn.Module):
    def __init__(self, channels, expansion_factor):
        super(GDFN, self).__init__()

        hidden_channels = int(channels * expansion_factor)
        self.project_in = nn.Conv2d(channels, hidden_channels * 2, kernel_size=1, bias=False)
        self.conv = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, kernel_size=3, padding=1,
                              groups=hidden_channels * 2, bias=False)
        self.project_out = nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=False)

    def forward(self, x):
        x1, x2 = self.conv(self.project_in(x)).chunk(2, dim=1)
        x = self.project_out(F.gelu(x1) * x2)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, channels, num_heads, expansion_factor):
        super(TransformerBlock, self).__init__()

        self.norm1 = nn.LayerNorm(channels)
        self.attn = MDTA(channels, num_heads)
        self.norm2 = nn.LayerNorm(channels)
        self.ffn = GDFN(channels, expansion_factor)

    def forward(self, x):
        b, c, h, w = x.shape
        x = x + self.attn(self.norm1(x.reshape(b, c, -1).transpose(-2, -1).contiguous()).transpose(-2, -1)
                          .contiguous().reshape(b, c, h, w))
        x = x + self.ffn(self.norm2(x.reshape(b, c, -1).transpose(-2, -1).contiguous()).transpose(-2, -1)
                         .contiguous().reshape(b, c, h, w))
        return x


class DownSample(nn.Module):
    def __init__(self, channels):
        super(DownSample, self).__init__()
        self.body = nn.Sequential(nn.Conv2d(channels, channels // 2, kernel_size=3, padding=1, bias=False),
                                  nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)


class UpSample(nn.Module):
    def __init__(self, channels):
        super(UpSample, self).__init__()
        self.body = nn.Sequential(nn.Conv2d(channels, channels * 2, kernel_size=3, padding=1, bias=False),
                                  nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)


class Restormer(nn.Module):
    def __init__(self, num_blocks=[4, 6, 6, 8], num_heads=[1, 2, 4, 8], channels=[24, 48, 96, 192], num_refinement=4, expansion_factor=2.66):
        
        super(Restormer, self).__init__()

        self.embed_conv = nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False)

        self.encoders = nn.ModuleList([nn.Sequential(*[TransformerBlock(
            num_ch, num_ah, expansion_factor) for _ in range(num_tb)]) for num_tb, num_ah, num_ch in
                                       zip(num_blocks, num_heads, channels)])
        
        # the number of down sample or up sample == the number of encoder - 1
        self.downs = nn.ModuleList([DownSample(num_ch) for num_ch in channels[:-1]])
        self.ups = nn.ModuleList([UpSample(num_ch) for num_ch in list(reversed(channels))[:-1]])

        # the number of reduce block == the number of decoder - 1
        self.reduces = nn.ModuleList([nn.Conv2d(channels[i], channels[i - 1], kernel_size=1, bias=False)
                                      for i in reversed(range(2, len(channels)))])
        
        # the number of decoder == the number of encoder - 1
        self.decoders = nn.ModuleList([nn.Sequential(*[TransformerBlock(channels[2], num_heads[2], expansion_factor)
                                                       for _ in range(num_blocks[2])])])
        self.decoders.append(nn.Sequential(*[TransformerBlock(channels[1], num_heads[1], expansion_factor)
                                             for _ in range(num_blocks[1])]))
        
        # the channel of last one is not change
        self.decoders.append(nn.Sequential(*[TransformerBlock(channels[1], num_heads[0], expansion_factor)
                                             for _ in range(num_blocks[0])]))

        self.refinement = nn.Sequential(*[TransformerBlock(channels[1], num_heads[0], expansion_factor)
                                          for _ in range(num_refinement)])
        self.output = nn.Conv2d(channels[1], 3, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        fo = self.embed_conv(x)
        out_enc1 = self.encoders[0](fo)
        out_enc2 = self.encoders[1](self.downs[0](out_enc1))
        out_enc3 = self.encoders[2](self.downs[1](out_enc2))
        out_enc4 = self.encoders[3](self.downs[2](out_enc3))

        out_dec3 = self.decoders[0](self.reduces[0](torch.cat([self.ups[0](out_enc4), out_enc3], dim=1)))
        out_dec2 = self.decoders[1](self.reduces[1](torch.cat([self.ups[1](out_dec3), out_enc2], dim=1)))
        fd = self.decoders[2](torch.cat([self.ups[2](out_dec2), out_enc1], dim=1))
        fr = self.refinement(fd)
        out = self.output(fr) + x
        return out


class CGBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CGBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, groups=out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(out_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x_local = F.relu(self.bn1(self.conv1(x)))
        x_local = F.relu(self.bn2(self.conv2(x_local)))
        x_global = self.global_avg_pool(x_local)
        x_global = self.fc(x_global)
        return x_local + x_global

class CGNet(nn.Module):
    def __init__(self, num_classes=3):
        super(CGNet, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.cg_block1 = CGBlock(32, 64)
        self.cg_block2 = CGBlock(64, 128)
        self.classifier = nn.Conv2d(128, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.initial(x)
        x = self.cg_block1(x)
        x = self.cg_block2(x)
        x = self.classifier(x)
        return F.interpolate(x, size=x.shape[-2:], mode='bilinear', align_corners=False)
    '''

import torch
import torch.nn as nn
import torch.nn.functional as F

class NAFBlock(nn.Module):
    def __init__(self, in_channels):
        super(NAFBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels, bias=False)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)

        self.beta = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = F.relu(x)  # ReLU activation
        x = self.conv3(x)
        return shortcut + self.beta * x

class NAFNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_blocks=16):
        super(NAFNet, self).__init__()
        self.initial_conv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        
        self.blocks = nn.Sequential(
            *[NAFBlock(64) for _ in range(num_blocks)]
        )

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        input = x
        x = self.initial_conv(x)
        x = self.blocks(x)
        x = self.final_conv(x)
        return x + input


  

# Train

# Real Train

In [88]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler
from tqdm import tqdm
import numpy as np

# 시작 시간 기록
start_time = time.time()

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, mode='fan_in', nonlinearity='relu')

def load_img(filepath):
    img = cv2.imread(filepath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# 데이터셋 경로
noisy_image_paths = '/home/work/.default/hyunwoong/Contest/event/Training/noisy'
clean_image_paths = '/home/work/.default/hyunwoong/Contest/event/Training/clean'

# 데이터셋 및 DataLoader 설정
train_transform = Compose([
    ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

if __name__ == "__main__":
    # 커스텀 데이터셋 인스턴스 생성
    train_dataset = CustomDataset(clean_image_paths, noisy_image_paths, transform=train_transform)
    cutmix_dataset = CutMixDataset(train_dataset, beta=1.0, prob=0.5)

    # 데이터 로더 설정
    train_loader = DataLoader(cutmix_dataset, batch_size=CFG['BATCH_SIZE'], num_workers=0, shuffle=True)

    # GPU 장치 설정
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NAFNet 모델 생성 및 GPU로 이동
    model = NAFNet(in_channels=3, out_channels=3, num_blocks=8).to(device)

    # 손실 함수와 최적화 알고리즘 설정
    criterion = nn.L1Loss()
    optimizer = optim.AdamW(model.parameters(), lr=CFG['LEARNING_RATE'], weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=CFG['EPOCHS'])

    # 학습 루프
    model.train()
    best_loss = float('inf')

    for epoch in range(CFG['EPOCHS']):
        model.train()
        epoch_loss = 0.0
        
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{CFG['EPOCHS']}", unit="batch") as pbar:
            for noisy_images, clean_images in train_loader:
                noisy_images = noisy_images.to(device)
                clean_images = clean_images.to(device)
                
                optimizer.zero_grad()
                outputs = model(noisy_images)
                loss = criterion(outputs, clean_images)
                loss.backward()
                optimizer.step()
                scheduler.step()

                epoch_loss += loss.item() * noisy_images.size(0)
                pbar.set_postfix({"Loss": loss.item()})
                pbar.update(1)

        avg_epoch_loss = epoch_loss / len(train_dataset)
        print(f"Epoch {epoch+1}/{CFG['EPOCHS']}, Average Loss: {avg_epoch_loss:.4f}")

        # 각 epoch마다 모델 저장
        torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pth')
        print(f"Model saved for epoch {epoch+1} with loss {avg_epoch_loss:.4f}")

        # 최적의 모델 저장
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save(model.state_dict(), 'best_NAFNet.pth')
            print(f"Best model updated at epoch {epoch+1} with loss {best_loss:.4f}")



Epoch 1/15: 100%|██████████| 871/871 [20:20<00:00,  1.40s/batch, Loss=0.184]


Epoch 1/15, Average Loss: 0.2196
Model saved for epoch 1 with loss 0.2196
Best model updated at epoch 1 with loss 0.2196


Epoch 2/15: 100%|██████████| 871/871 [20:14<00:00,  1.39s/batch, Loss=0.0721]


Epoch 2/15, Average Loss: 0.2168
Model saved for epoch 2 with loss 0.2168
Best model updated at epoch 2 with loss 0.2168


Epoch 3/15: 100%|██████████| 871/871 [20:32<00:00,  1.42s/batch, Loss=0.179]


Epoch 3/15, Average Loss: 0.2167
Model saved for epoch 3 with loss 0.2167
Best model updated at epoch 3 with loss 0.2167


Epoch 4/15: 100%|██████████| 871/871 [20:44<00:00,  1.43s/batch, Loss=0.371]


Epoch 4/15, Average Loss: 0.2159
Model saved for epoch 4 with loss 0.2159
Best model updated at epoch 4 with loss 0.2159


Epoch 5/15: 100%|██████████| 871/871 [20:45<00:00,  1.43s/batch, Loss=0.123]


Epoch 5/15, Average Loss: 0.2153
Model saved for epoch 5 with loss 0.2153
Best model updated at epoch 5 with loss 0.2153


Epoch 6/15: 100%|██████████| 871/871 [20:40<00:00,  1.42s/batch, Loss=0.418]


Epoch 6/15, Average Loss: 0.2144
Model saved for epoch 6 with loss 0.2144
Best model updated at epoch 6 with loss 0.2144


Epoch 7/15: 100%|██████████| 871/871 [20:33<00:00,  1.42s/batch, Loss=0.0333]


Epoch 7/15, Average Loss: 0.2148
Model saved for epoch 7 with loss 0.2148


Epoch 8/15: 100%|██████████| 871/871 [20:42<00:00,  1.43s/batch, Loss=0.19] 


Epoch 8/15, Average Loss: 0.2143
Model saved for epoch 8 with loss 0.2143
Best model updated at epoch 8 with loss 0.2143


Epoch 9/15: 100%|██████████| 871/871 [20:37<00:00,  1.42s/batch, Loss=0.304]


Epoch 9/15, Average Loss: 0.2142
Model saved for epoch 9 with loss 0.2142
Best model updated at epoch 9 with loss 0.2142


Epoch 10/15: 100%|██████████| 871/871 [20:34<00:00,  1.42s/batch, Loss=0.193]


Epoch 10/15, Average Loss: 0.2133
Model saved for epoch 10 with loss 0.2133
Best model updated at epoch 10 with loss 0.2133


Epoch 11/15: 100%|██████████| 871/871 [20:42<00:00,  1.43s/batch, Loss=0.367]


Epoch 11/15, Average Loss: 0.2141
Model saved for epoch 11 with loss 0.2141


Epoch 12/15: 100%|██████████| 871/871 [20:25<00:00,  1.41s/batch, Loss=0.295]


Epoch 12/15, Average Loss: 0.2140
Model saved for epoch 12 with loss 0.2140


Epoch 13/15: 100%|██████████| 871/871 [20:13<00:00,  1.39s/batch, Loss=0.0464]


Epoch 13/15, Average Loss: 0.2119
Model saved for epoch 13 with loss 0.2119
Best model updated at epoch 13 with loss 0.2119


Epoch 14/15: 100%|██████████| 871/871 [20:44<00:00,  1.43s/batch, Loss=0.333]


Epoch 14/15, Average Loss: 0.2133
Model saved for epoch 14 with loss 0.2133


Epoch 15/15: 100%|██████████| 871/871 [20:33<00:00,  1.42s/batch, Loss=0.0656]

Epoch 15/15, Average Loss: 0.2146
Model saved for epoch 15 with loss 0.2146





# Inference

In [13]:
# CustomDatasetTest 정의
class CustomDatasetTest(torch.utils.data.Dataset):
    def __init__(self, noisy_image_paths, transform=None):
        self.noisy_image_paths = [os.path.join(noisy_image_paths, x) for x in os.listdir(noisy_image_paths)]
        self.transform = transform

    def __len__(self):
        return len(self.noisy_image_paths)

    def __getitem__(self, index):
        noisy_image_path = self.noisy_image_paths[index]
        noisy_image = load_img(noisy_image_path)

        if isinstance(noisy_image, np.ndarray):
            noisy_image = Image.fromarray(noisy_image)

        if self.transform:
            noisy_image = self.transform(noisy_image)

        return noisy_image, noisy_image_path

# 이미지 로딩 함수
def load_img(filepath):
    img = cv2.imread(filepath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img

# NAFNet 모델 정의 및 가중치 불러오기
model = NAFNet(in_channels=3, out_channels=3, num_blocks=8)

# GPU 설정
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# 테스트 데이터 경로
test_data_path = '/home/work/.default/hyunwoong/Contest/KMS_uni/open_(1)/test/Input'
output_path = './open(1)/test/submission'
test_transform = Compose([
    ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# 테스트 데이터셋 및 데이터로더 설정
test_dataset = CustomDatasetTest(test_data_path, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# 결과 저장 폴더 생성
if not os.path.exists(output_path):
    os.makedirs(output_path)

def load_epoch_model(model, epoch):
    model_path = f'model_epoch_{epoch}.pth'
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"Loaded model from {model_path}")

def denoise_and_save_images(model, test_loader, output_path):
    with torch.no_grad():
        for noisy_image, noisy_image_path in test_loader:
            noisy_image = noisy_image.to(device)
            denoised_image = model(noisy_image)
            
            # 후처리 및 저장
            denoised_image = denoised_image.cpu().squeeze(0)
            denoised_image = (denoised_image * 0.5 + 0.5).clamp(0, 1)
            denoised_image = transforms.ToPILImage()(denoised_image)

            # 파일 경로 및 저장
            output_filename = noisy_image_path[0]
            denoised_filename = output_path + '/' + output_filename.split('/')[-1][:-4] + '.jpg'
            denoised_image.save(denoised_filename) 
# 특정 epoch의 모델 가중치 로드 및 테스트셋에 적용
epoch_to_load = 15  # 불러오고 싶은 epoch 번호 설정
load_epoch_model(model, epoch_to_load)
denoise_and_save_images(model, test_loader, output_path)

Loaded model from model_epoch_15.pth


# Submission

In [15]:
def zip_folder(folder_path, output_zip):
    shutil.make_archive(output_zip, 'zip', folder_path)
    print(f"Created {output_zip}.zip successfully.")

zip_folder(output_path, './submission')

Created ./submission.zip successfully.
