## DnCNN Experiments



In [1]:
# import built-in libraries
import sys
import os
import glob

# import basic third-party libraries
import numpy as np
import matplotlib.pyplot as plt
import cv2

# import PyTorch libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms

# import project-local (custom) modules
sys.path.insert(0, "..")
from utils import process, visualize
from models.DnCNN import DnCNN

import ipywidgets as widgets
from IPython.display import display
%matplotlib inline

### Define DnCNN Network

In [2]:
class DnCNN(nn.Module):
    """DnCNN denoiser: predicts the noise map and subtracts it from the input.

    Layer stack (all convolutions are 3x3 with padding 1, so H and W are
    preserved)::

        Conv(channels -> features) + ReLU
        (num_layers - 2) x [Conv(features -> features) + BatchNorm + ReLU]
        Conv(features -> channels)

    Args:
        channels: number of image channels in the input and the output.
        num_layers: total number of convolution layers in the stack.
        features: number of feature maps in the hidden layers.
    """

    def __init__(self, channels=1, num_layers=17, features=64):
        super().__init__()

        def conv3x3(in_channels, out_channels):
            return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        head = [conv3x3(channels, features), nn.ReLU(inplace=True)]
        # Each hidden stage is built conv -> BN -> ReLU, in that order, so
        # parameter initialisation consumes the RNG in a fixed sequence.
        body = [
            module
            for _ in range(num_layers - 2)
            for module in (
                conv3x3(features, features),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True),
            )
        ]
        tail = [conv3x3(features, channels)]

        # Attribute name kept as `dncnn` so state_dict keys stay "dncnn.<i>.*".
        self.dncnn = nn.Sequential(*head, *body, *tail)

    def forward(self, x):
        """Return the denoised estimate ``x - predicted_noise``."""
        predicted_noise = self.dncnn(x)
        return x - predicted_noise


### Prepare the CT Sinogram Dataset and DataLoaders

In [3]:
# define the custom dataset
class CTSinogramDataset(Dataset):
    """Paired clean/noisy CT sinogram slices, one dataset item per patient.

    Expected layout (one sub-folder per patient under both roots)::

        clean_folder/<patient_id>/<slice files>
        noisy_folder/<patient_id>/<slice files>

    Patient ids are the sub-directories of ``clean_folder``. Clean and noisy
    slices are paired by their position in the sorted file listing of each
    patient folder.

    Each item is a ``(clean, noisy)`` pair of tensors, with the patient's
    slices stacked along a new leading axis, i.e. ``(num_slices, 1, H, W)``
    for single-channel output of ``transform``.

    Args:
        clean_folder: root directory of the clean (target) sinograms.
        noisy_folder: root directory of the noisy (input) sinograms.
        transform: optional callable applied to every grayscale ``uint8``
            ``H x W`` slice. When ``None``, slices are scaled to ``[0, 1]``
            as ``float32`` and given a leading channel axis.

    Raises:
        ValueError: if a patient's clean and noisy folders hold a different
            number of slice files (the positional pairing would be wrong).
        OSError: from ``__getitem__`` when a slice file cannot be decoded.
    """

    def __init__(self, clean_folder, noisy_folder, transform=None):
        self.clean_folder = clean_folder
        self.noisy_folder = noisy_folder
        self.transform = transform
        # Only sub-directories are patients; a stray file at the root would
        # otherwise make os.listdir(<file>) below raise NotADirectoryError.
        self.patient_ids = sorted(
            entry
            for entry in os.listdir(clean_folder)
            if os.path.isdir(os.path.join(clean_folder, entry))
        )

        self.clean_slices = {}
        self.noisy_slices = {}
        for patient_id in self.patient_ids:
            clean_patient_folder = os.path.join(clean_folder, patient_id)
            noisy_patient_folder = os.path.join(noisy_folder, patient_id)

            clean_slice_paths = [
                os.path.join(clean_patient_folder, f)
                for f in sorted(os.listdir(clean_patient_folder))
            ]
            noisy_slice_paths = [
                os.path.join(noisy_patient_folder, f)
                for f in sorted(os.listdir(noisy_patient_folder))
            ]

            # Pairing is purely positional, so a count mismatch would give
            # clean/noisy volumes of different depth and misaligned pairs.
            if len(clean_slice_paths) != len(noisy_slice_paths):
                raise ValueError(
                    f"patient {patient_id!r}: {len(clean_slice_paths)} clean "
                    f"slices but {len(noisy_slice_paths)} noisy slices"
                )

            self.clean_slices[patient_id] = clean_slice_paths
            self.noisy_slices[patient_id] = noisy_slice_paths

    def __len__(self):
        """Return the number of patients."""
        return len(self.patient_ids)

    def __getitem__(self, index):
        """Return ``(clean, noisy)`` stacked slice tensors for one patient."""
        patient_id = self.patient_ids[index]
        # All clean slices are transformed before any noisy slice, matching
        # the original call order of `transform`.
        clean_slices = [self._load_slice(path) for path in self.clean_slices[patient_id]]
        noisy_slices = [self._load_slice(path) for path in self.noisy_slices[patient_id]]
        return torch.stack(clean_slices), torch.stack(noisy_slices)

    def _load_slice(self, path):
        """Read one grayscale slice from `path` and convert it to a tensor."""
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise OSError(f"could not read image file: {path}")
        if self.transform is not None:
            return self.transform(image)
        # float32 to match what ToTensor-based transforms produce.
        return torch.from_numpy(image.astype(np.float32) / 255.0).unsqueeze(0)

In [4]:
# Paired sinogram roots: `lam_0` is used as the clean target and `lam_5` as
# the noisy input.
clean_folder = "../dataset/Kaggle_CT Low Dose Reconstruction/prepared_sinogram/lam_0"
noisy_folder = "../dataset/Kaggle_CT Low Dose Reconstruction/prepared_sinogram/lam_5"

# uint8 slice -> float tensor in [0, 1] (ToTensor) -> [-1, 1] (Normalize).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# create dataset (one item per patient)
dataset = CTSinogramDataset(clean_folder, noisy_folder, transform=transform)

# 60/20/20 split over patients, so no patient's slices are shared between
# train, validation and test.
train_len = int(0.6 * len(dataset))
val_len = int(0.2 * len(dataset))
test_len = len(dataset) - train_len - val_len

# Seeded split: re-running the notebook yields the same partition. An
# unseeded split reshuffles patients on every run, so a model trained in an
# earlier session could be "tested" on its own training patients.
split_seed = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_len, val_len, test_len],
    generator=torch.Generator().manual_seed(split_seed),
)
print(f"train:{len(train_dataset)}, val:{len(val_dataset)}, test:{len(test_dataset)}")

# batch_size=1 -> each batch is one whole patient volume.
loader_kwargs = dict(batch_size=1, num_workers=8, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


train:6, val:2, test:2


In [5]:
patient_idx = 0
# Pull one training batch (one patient volume) to inspect its layout.
# Use the builtin next(): the `.next()` method alias on DataLoader iterators
# is not available in newer PyTorch releases.
dataiter = iter(train_loader)
clean_batch, noisy_batch = next(dataiter)

# Layout: (batch, slices, channels, H, W).
clean_batch.shape

torch.Size([1, 128, 1, 256, 256])

In [6]:
# Visualize the CT slices of the first patient volume in the training batch.
first_volume = clean_batch[0]
visualize.plot_slices(first_volume)


interactive(children=(IntSlider(value=0, description='slice_idx', max=127), Output()), _dom_classes=('widget-i…

### Train the model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

model = DnCNN().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 50
# Number of slices pushed through the network per optimizer step. A whole
# patient volume (128 slices of 256x256 in the sample batch) through 17 conv
# layers with 64 feature maps needs far more activation memory than a typical
# GPU has, so each volume is processed in chunks of this many slices.
slices_per_step = 16

for epoch in range(num_epochs):
    model.train()
    train_loss_sum, train_count = 0.0, 0
    for clean_sinograms, noisy_sinograms in train_loader:
        # Batches are (batch, slices, C, H, W); nn.Conv2d only accepts 4-D
        # (N, C, H, W) input, so fold the batch and slice axes together.
        # (The previous `.unsqueeze(1)` produced a 6-D tensor instead.)
        clean_sinograms = clean_sinograms.to(device).float().flatten(0, 1)
        noisy_sinograms = noisy_sinograms.to(device).float().flatten(0, 1)

        for start in range(0, noisy_sinograms.shape[0], slices_per_step):
            noisy_chunk = noisy_sinograms[start:start + slices_per_step]
            clean_chunk = clean_sinograms[start:start + slices_per_step]

            outputs = model(noisy_chunk)
            loss = criterion(outputs, clean_chunk)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by slice count so the epoch mean is per-slice, not per-chunk.
            train_loss_sum += loss.item() * noisy_chunk.shape[0]
            train_count += noisy_chunk.shape[0]

    model.eval()
    val_loss_sum, val_count = 0.0, 0
    with torch.no_grad():
        for clean_sinograms, noisy_sinograms in val_loader:
            clean_sinograms = clean_sinograms.to(device).float().flatten(0, 1)
            noisy_sinograms = noisy_sinograms.to(device).float().flatten(0, 1)
            for start in range(0, noisy_sinograms.shape[0], slices_per_step):
                noisy_chunk = noisy_sinograms[start:start + slices_per_step]
                clean_chunk = clean_sinograms[start:start + slices_per_step]
                val_batch_loss = criterion(model(noisy_chunk), clean_chunk)
                val_loss_sum += val_batch_loss.item() * noisy_chunk.shape[0]
                val_count += noisy_chunk.shape[0]

    train_loss = train_loss_sum / train_count if train_count else float("nan")
    val_loss = val_loss_sum / val_count if val_count else float("nan")
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

### Evaluate the model

In [None]:
model.eval()
# Slices per forward pass; processing a whole volume in chunks bounds peak
# activation memory.
eval_slices_per_step = 16
test_losses = []
with torch.no_grad():
    for i, (clean_sinograms, noisy_sinograms) in enumerate(test_loader):
        # Batches are (batch, slices, C, H, W); nn.Conv2d needs 4-D input, so
        # fold batch and slice axes into one (the old `.unsqueeze(1)` made 6-D).
        clean_sinograms = clean_sinograms.to(device).float().flatten(0, 1)
        noisy_sinograms = noisy_sinograms.to(device).float().flatten(0, 1)

        outputs = torch.cat([
            model(noisy_sinograms[start:start + eval_slices_per_step])
            for start in range(0, noisy_sinograms.shape[0], eval_slices_per_step)
        ])
        loss = criterion(outputs, clean_sinograms)
        test_losses.append(loss.item())
        print(f"Test Image {i + 1}, Loss: {loss.item():.4f}")

if test_losses:
    print(f"Mean test loss: {sum(test_losses) / len(test_losses):.4f}")
