## DnCNN Experiments

This is an implementation of CT post-reconstruction denoising using DnCNN.

In [1]:
# import built-in libraries
import sys
import os
import glob
import shutil

# import basic libraries
import numpy as np
import matplotlib.pyplot as plt
import cv2

# import torch libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# import metrics libraries
from skimage.metrics import peak_signal_noise_ratio as psnr
from pytorch_ssim import ssim
from tqdm import tqdm

# import custom libraries
sys.path.insert(0, "..")
from utils import process, visualize
from models.DnCNN import DnCNN

import ipywidgets as widgets
from IPython.display import display
%matplotlib inline

### Define DnCNN Network

In [2]:
class DnCNN(nn.Module):
    """Residual denoising CNN (DnCNN).

    The convolutional stack estimates the noise component of its input, and
    ``forward`` returns the input minus that estimate.

    Args:
        channels: number of image channels in both input and output.
        num_layers: total number of convolution layers (first + hidden + last);
            ``num_layers - 2`` hidden conv/BN/ReLU stages are created.
        features: number of feature maps in the hidden layers.
    """

    def __init__(self, channels=1, num_layers=17, features=64):
        super().__init__()
        # First stage: conv + ReLU, no batch norm.
        stem = [
            nn.Conv2d(channels, features, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        # Hidden stages: conv + batch norm + ReLU, repeated num_layers - 2 times.
        hidden = []
        for _ in range(num_layers - 2):
            hidden.extend((
                nn.Conv2d(features, features, kernel_size=3, padding=1),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True),
            ))
        # Final stage: a plain conv mapping back to `channels` maps (the noise estimate).
        tail = [nn.Conv2d(features, channels, kernel_size=3, padding=1)]
        # Keep the attribute name and layer order so saved state_dict keys
        # ("dncnn.<index>.<param>") remain loadable.
        self.dncnn = nn.Sequential(*stem, *hidden, *tail)

    def forward(self, x):
        """Return ``x`` with the predicted noise subtracted."""
        noise_estimate = self.dncnn(x)
        return x - noise_estimate


### Prepare CT Sinogram Dataset and DataLoader

In [3]:
# define the custom dataset
class CTSinogramDataset(Dataset):
    """Paired clean/noisy CT slices, one dataset item per patient.

    Expected layout (one sub-folder per patient in each root)::

        clean_folder/<patient_id>/<slice files>
        noisy_folder/<patient_id>/<slice files>

    Clean and noisy slices are paired by their position after sorting the
    file names inside each patient folder, so both roots must hold the same
    number of slices for every patient.

    Args:
        clean_folder: root folder of the clean (target) slices.
        noisy_folder: root folder of the noisy (input) slices.
        transform: optional callable applied to every 2-D grayscale ``uint8``
            slice; it must return a tensor. When ``None``, slices are scaled
            to [0, 1] as float32 and given a leading channel axis.

    Raises:
        ValueError: if a patient has a different number of clean and noisy
            slices.
    """

    def __init__(self, clean_folder, noisy_folder, transform=None):
        self.clean_folder = clean_folder
        self.noisy_folder = noisy_folder
        self.transform = transform
        # Each sub-directory of clean_folder is a patient; stray files are skipped.
        self.patient_ids = sorted(
            entry for entry in os.listdir(clean_folder)
            if os.path.isdir(os.path.join(clean_folder, entry))
        )

        self.clean_slices = {}
        self.noisy_slices = {}
        for patient_id in self.patient_ids:
            clean_paths = self._list_slices(os.path.join(clean_folder, patient_id))
            noisy_paths = self._list_slices(os.path.join(noisy_folder, patient_id))
            # Slices are paired by sorted position, so the counts must agree
            # for every patient (the two dicts always have equal sizes, so
            # comparing them cannot detect a mismatch).
            if len(clean_paths) != len(noisy_paths):
                raise ValueError(
                    f"Patient {patient_id!r}: {len(clean_paths)} clean slices "
                    f"but {len(noisy_paths)} noisy slices."
                )
            self.clean_slices[patient_id] = clean_paths
            self.noisy_slices[patient_id] = noisy_paths

    @staticmethod
    def _list_slices(folder):
        """Return the full paths of the entries in *folder*, sorted by name."""
        return [os.path.join(folder, name) for name in sorted(os.listdir(folder))]

    def __len__(self):
        """Return the number of patients (each item is one patient's slices)."""
        return len(self.patient_ids)

    def __getitem__(self, index):
        """Return ``(clean, noisy)`` stacks of all slices of one patient.

        Each element is ``torch.stack`` of the per-slice tensors, i.e. shape
        ``(num_slices, *slice_shape)``. Without a transform a slice is a
        ``(1, H, W)`` float32 tensor in [0, 1].
        """
        patient_id = self.patient_ids[index]
        clean = [self._to_tensor(self._read_slice(path)) for path in self.clean_slices[patient_id]]
        noisy = [self._to_tensor(self._read_slice(path)) for path in self.noisy_slices[patient_id]]
        return torch.stack(clean), torch.stack(noisy)

    @staticmethod
    def _read_slice(path):
        """Load *path* as a grayscale ``uint8`` array, raising if it cannot be read."""
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports missing/unreadable files by returning None, not by raising.
        if image is None:
            raise OSError(f"Could not read slice image: {path}")
        return image

    def _to_tensor(self, image):
        """Apply ``self.transform``; without one, scale to [0, 1] float32 with a channel axis."""
        if self.transform is not None:
            return self.transform(image)
        # Cast to float32 first: `uint8_array / 255.0` would give float64, which
        # does not match the dtype of the model's weights.
        return torch.from_numpy(image.astype(np.float32) / 255.0).unsqueeze(0)

In [4]:
# Paths to the paired clean (target) and noisy (input) reconstructions.
clean_folder = "../dataset/Kaggle_CT Low Dose Reconstruction/prepared_recon/lam_0"
noisy_folder = "../dataset/Kaggle_CT Low Dose Reconstruction/prepared_recon/lam_5"

# define data transform: ToTensor scales uint8 slices to [0, 1] and
# Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# create dataset (one item per patient volume)
dataset = CTSinogramDataset(clean_folder, noisy_folder, transform=transform)

# calculate dataset length: 60/20/20 split over patients, rounding remainder goes to test
train_len = int(0.6 * len(dataset))
val_len = int(0.2 * len(dataset))
test_len = len(dataset) - train_len - val_len

# random_split dataset with a fixed seed so the same patients land in the
# test set on every run; otherwise a best_model.pt reloaded in a later session
# may be evaluated on patients it was trained on.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_len, val_len, test_len], generator=split_generator
)
print(f"train:{len(train_dataset)}, val:{len(val_dataset)}, test:{len(test_dataset)}")

NUM_WORKERS = 8
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)


train:6, val:2, test:2


In [5]:
# Peek at one training item: a whole patient volume of paired slices.
dataiter = iter(train_loader)
# DataLoader iterators no longer provide a `.next()` method in current
# PyTorch releases; the builtin next() works on every version.
clean_batch, noisy_batch = next(dataiter)

print(clean_batch.shape, noisy_batch.shape)

# visualize training CT slices
visualize.plot_slices(clean_batch[0])

torch.Size([1, 128, 1, 256, 256])

### Train the Model

In [None]:
# train functions
def save_checkpoint(state, is_best, checkpoint_path, best_model_path):
    """Write *state* to *checkpoint_path*; if it is the best so far, also copy it.

    Args:
        state: object handed to ``torch.save`` (in this notebook, a dict with
            the epoch, model/optimizer state and best validation loss).
        is_best: whether this checkpoint has the lowest validation loss so far.
        checkpoint_path: destination of the latest checkpoint (overwritten).
        best_model_path: destination of the best checkpoint (overwritten).
    """
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    # Copy the file just written instead of serializing the state twice.
    shutil.copyfile(checkpoint_path, best_model_path)
        
def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch, taking one optimizer step per slice.

    Each dataloader item is a ``(clean_slices, noisy_slices)`` pair whose
    second dimension indexes slices (the dataset returns one stacked patient
    volume per item). Every slice ``[:, j]`` is fed to the model separately.

    Args:
        model: network mapping noisy slices to denoised slices.
        dataloader: iterable of ``(clean_slices, noisy_slices)`` tensor pairs.
        criterion: loss callable ``criterion(outputs, target)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device to move each slice to before the forward pass.

    Returns:
        The mean training loss over all slices seen in the epoch.

    Raises:
        ValueError: if the dataloader produced no slices.
    """
    model.train()
    running_loss = 0.0
    num_samples = 0
    for clean_slices, noisy_slices in dataloader:
        for j in range(clean_slices.size(1)):
            clean_slice = clean_slices[:, j].to(device)
            noisy_slice = noisy_slices[:, j].to(device)
            optimizer.zero_grad()
            outputs = model(noisy_slice)
            loss = criterion(outputs, clean_slice)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_samples += 1
    # Fail with a clear message instead of a bare ZeroDivisionError.
    if num_samples == 0:
        raise ValueError("train(): dataloader yielded no slices")
    return running_loss / num_samples

def validate(model, dataloader, criterion, device, data_range=1.0):
    """Evaluate *model* slice by slice without updating it.

    Args:
        model: network mapping noisy slices to denoised slices.
        dataloader: iterable of ``(clean_slices, noisy_slices)`` tensor pairs
            whose second dimension indexes slices.
        criterion: loss callable ``criterion(outputs, target)``; its mean is
            the reported loss.
        device: device to move each slice to before the forward pass.
        data_range: peak-to-peak range of the pixel values, used for PSNR
            and SSIM. The default 1.0 assumes images in [0, 1]; data
            normalized with ``Normalize((0.5,), (0.5,))`` lies in [-1, 1]
            and needs 2.0.

    Returns:
        ``(mean_loss, mean_psnr, mean_ssim)`` as Python floats, averaged over
        all slices.

    Raises:
        ValueError: if the dataloader produced no slices.
    """
    import math

    # skimage.metrics is already a dependency of this notebook (PSNR comes from it).
    from skimage.metrics import structural_similarity

    model.eval()
    running_loss = 0.0
    running_psnr = 0.0
    running_ssim = 0.0
    num_samples = 0
    with torch.no_grad():
        for clean_slices, noisy_slices in dataloader:
            for j in range(clean_slices.size(1)):
                clean_slice = clean_slices[:, j].to(device)
                noisy_slice = noisy_slices[:, j].to(device)
                outputs = model(noisy_slice)
                running_loss += criterion(outputs, clean_slice).item()

                # PSNR = 10 * log10(MAX^2 / MSE). Using .item() keeps the maths
                # on Python floats, which also works when tensors live on CUDA.
                mse = nn.functional.mse_loss(outputs, clean_slice).item()
                if mse > 0:
                    running_psnr += 10.0 * math.log10(data_range ** 2 / mse)
                else:
                    running_psnr += math.inf

                # skimage needs host numpy arrays; squeeze (1, 1, H, W) -> (H, W)
                # (the loaders here use batch_size=1 and single-channel slices).
                running_ssim += float(structural_similarity(
                    clean_slice.squeeze().cpu().numpy(),
                    outputs.squeeze().cpu().numpy(),
                    data_range=data_range,
                ))
                num_samples += 1
    # Fail with a clear message instead of a bare ZeroDivisionError.
    if num_samples == 0:
        raise ValueError("validate(): dataloader yielded no slices")
    return (
        running_loss / num_samples,
        running_psnr / num_samples,
        running_ssim / num_samples,
    )


In [None]:
import time  # used to report the total training time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

model = DnCNN().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.1 every 10 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

num_epochs = 50
best_val_loss = float("inf")

writer = SummaryWriter("tf_logs/runs")
checkpoint_path = "checkpoint.pt"
best_model_path = "best_model.pt"

start_time = time.time()
try:
    for epoch in range(1, num_epochs + 1):
        train_loss = train(model, train_loader, criterion, optimizer, device)
        val_loss, val_psnr, val_ssim = validate(model, val_loader, criterion, device)
        # StepLR counts epochs, so step it once per epoch after the optimizer updates.
        scheduler.step()

        writer.add_scalars("Loss", {"train_loss": train_loss, "val_loss": val_loss}, epoch)
        writer.add_scalar("PSNR", val_psnr, epoch)
        writer.add_scalar("SSIM", val_ssim, epoch)

        is_best = val_loss < best_val_loss
        best_val_loss = min(val_loss, best_val_loss)
        save_checkpoint(
            {
                "epoch": epoch,
                "state_dict": model.state_dict(),
                "best_val_loss": best_val_loss,
                "optimizer": optimizer.state_dict(),
                # Saved so a resumed run continues the same LR schedule.
                "scheduler": scheduler.state_dict(),
            },
            is_best,
            checkpoint_path,
            best_model_path,
        )
        print(f"Epoch: {epoch}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val PSNR: {val_psnr:.4f}, Val SSIM:{val_ssim:.4f}")
finally:
    # Flush buffered TensorBoard events and release the log file even if
    # training is interrupted.
    writer.close()

# Total wall-clock training time in seconds (referenced by the notification cell).
total_second = int(time.time() - start_time)
print(f"training time: {total_second // 3600}h {(total_second % 3600) // 60}m")

In [None]:
# send a training-finished notification to WeChat (via the AutoDL API)

# import requests
# headers = {"Authorization": os.environ["AUTODL_TOKEN"]}  # read the API token from the environment; never commit it to the notebook
# resp = requests.post("https://www.autodl.com/api/v1/wechat/message/send",
#                      json={
#                          "title": "训练结束",
#                          "name": "DnCNN-noisy10-MSE-50epoch",
#                          "content": f"training time: {total_second // 3600}h {(total_second % 3600) // 60}m"
#                      }, headers = headers)

# print(resp.content.decode())

### Evaluate the Model

In [None]:
# Load the best model
from skimage.metrics import structural_similarity  # skimage is already used for PSNR

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

model = DnCNN().to(device)

# save_checkpoint() writes a dict whose "state_dict" entry holds the weights;
# fall back to treating the file as a bare state_dict otherwise.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
checkpoint = torch.load("best_model.pt", map_location=device)
model.load_state_dict(checkpoint.get("state_dict", checkpoint))
model.eval()


def _to_uint8_image(tensor):
    """Convert a [-1, 1]-normalized slice tensor to a 2-D uint8 image.

    Undoes ``Normalize((0.5,), (0.5,))`` (x * 0.5 + 0.5), clips to [0, 1] so
    out-of-range predictions do not wrap around when cast to uint8, and
    rounds rather than truncates so float error cannot shift a pixel by one.
    """
    image = tensor.squeeze().cpu().numpy() * 0.5 + 0.5
    return np.rint(np.clip(image, 0.0, 1.0) * 255.0).astype(np.uint8)


# Create test data loader
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)

# Per-slice PSNR and SSIM values
psnr_values = []
ssim_values = []

# Test the model on test dataset
with torch.no_grad():
    for clean, noisy in tqdm(test_loader):
        # Each item is a whole patient volume of shape (1, num_slices, 1, H, W);
        # Conv2d cannot take this 5-D tensor, so denoise one slice at a time.
        for j in range(clean.size(1)):
            denoised = model(noisy[:, j].to(device))

            clean_img = _to_uint8_image(clean[:, j])
            denoised_img = _to_uint8_image(denoised)

            # Calculate PSNR and SSIM for the current slice
            psnr_values.append(psnr(clean_img, denoised_img, data_range=255))
            ssim_values.append(structural_similarity(clean_img, denoised_img, data_range=255))

# Calculate average PSNR and SSIM over all test slices
avg_psnr = np.mean(psnr_values)
avg_ssim = np.mean(ssim_values)

print(f"Average PSNR: {avg_psnr:.4f}")
print(f"Average SSIM: {avg_ssim:.4f}")
