## DnCNN Experiments

This is a implementation of CT post-reconstruction denosing using DnCNN 

In [1]:
# import built-in liberies
import sys
import os
import glob
import shutil
import time

# import bsic liberies
import numpy as np
import matplotlib.pyplot as plt
import cv2

# import torch liberies
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR

# import metrics liberies
from skimage.metrics import peak_signal_noise_ratio as psnr_metric, structural_similarity as ssim_metric

# import custom liberies
sys.path.insert(0, "..")
from utils import process, visualize
# from models.DnCNN import DnCNN

import ipywidgets as widgets
from IPython.display import display
%matplotlib inline

### Define DnCNN Network

In [2]:
class DnCNN(nn.Module):
    def __init__(self, channels=1, num_layers=17, features=64):
        super(DnCNN, self).__init__()
        layers = [
            nn.Conv2d(channels, features, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        ]

        for _ in range(num_layers - 2):
            layers.append(nn.Conv2d(features, features, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(features))
            layers.append(nn.ReLU(inplace=True))

        layers.append(nn.Conv2d(features, channels, kernel_size=3, padding=1))

        self.dncnn = nn.Sequential(*layers)

    def forward(self, x):
        out = self.dncnn(x)
        return x - out


### Parepare CT Sinogram Dataset and Dataloader

In [3]:
# define the custom dataset
class CTSinogramDataset(Dataset):
    def __init__(self, clean_folder, noisy_folder, transform=None):
        self.clean_folder = clean_folder
        self.noisy_folder = noisy_folder
        self.transform = transform
        self.patient_ids = sorted(os.listdir(clean_folder))

        self.clean_slices = {}
        self.noisy_slices = {}
        for patient_id in self.patient_ids:
            clean_patient_folder = os.path.join(clean_folder, patient_id)
            noisy_patient_folder = os.path.join(noisy_folder, patient_id)
            
            clean_slice_files = sorted([f for f in os.listdir(clean_patient_folder) if self.valid_image_ext(f)])
            noisy_slice_files = sorted([f for f in os.listdir(noisy_patient_folder) if self.valid_image_ext(f)])

            clean_slice_paths = [os.path.join(clean_patient_folder, f) for f in clean_slice_files]
            noisy_slice_paths = [os.path.join(noisy_patient_folder, f) for f in noisy_slice_files]

            self.clean_slices[patient_id] = clean_slice_paths
            self.noisy_slices[patient_id] = noisy_slice_paths
        
        assert len(self.clean_slices) == len(self.noisy_slices), \
            "Number of clean slices and noisy slices are not equal."
        
    def valid_image_ext(self, filename):
        valid_exts = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
        ext = os.path.splitext(filename)[-1].lower()
        return ext in valid_exts
    
    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, index):
        patient_id = self.patient_ids[index]

        clean_slice_paths = self.clean_slices[patient_id]
        noisy_slice_paths = self.noisy_slices[patient_id]

        clean_slices = []
        for path in clean_slice_paths:
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                print(f"Warning: Skipping a corrupted or missing image at path {path}.")
                continue
            clean_slices.append(img)

        noisy_slices = []
        for path in noisy_slice_paths:
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                print(f"Warning: Skipping a corrupted or missing image at path {path}.")
                continue
            noisy_slices.append(img)

        
        if self.transform is not None:
            clean_slices = [self.transform(clean_slice) for clean_slice in clean_slices]
            noisy_slices = [self.transform(noisy_slice) for noisy_slice in noisy_slices]
        else:
            clean_slices = [torch.from_numpy(clean_slice / 255.0).unsqueeze(0) for clean_slice in clean_slices]
            noisy_slices = [torch.from_numpy(noisy_slice / 255.0).unsqueeze(0) for noisy_slice in noisy_slices]

        return torch.stack(clean_slices), torch.stack(noisy_slices)

In [4]:
# 
clean_folder = "../dataset/Kaggle_CT Low Dose Reconstruction/prepared_recon/lam_0"
noisy_folder = "../dataset/Kaggle_CT Low Dose Reconstruction/prepared_recon/lam_5"

# define data transform
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# create dataset
dataset = CTSinogramDataset(clean_folder, noisy_folder, transform=transform)

# calculate dataset length
train_len = int(0.6 * len(dataset))
val_len = int(0.2 * len(dataset))
test_len = len(dataset) - train_len - val_len

# random_split dataset
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_len, val_len, test_len])
print(f"train:{len(train_dataset)}, val:{len(val_dataset)}, test:{len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)

train:6, val:2, test:2


In [5]:
dataiter = iter(train_loader)
clean_batch, noisy_batch = dataiter.next()

print(clean_batch.shape, noisy_batch.shape)
print(f"range: {clean_batch[0].min()} , {clean_batch[0].max()}")

# visualize training CT silice
visualize.plot_slices(noisy_batch[0])

torch.Size([1, 244, 1, 256, 256]) torch.Size([1, 244, 1, 256, 256])
range: -1.0 , 1.0


interactive(children=(IntSlider(value=0, description='slice_idx', max=243), Output()), _dom_classes=('widget-i…

### Training model

In [11]:
# metrics
def calculate_psnr(clean_img, denoised_img):
    clean_img = clean_img.squeeze()
    denoised_img = denoised_img.squeeze()
    #print("PSNR", clean_img.shape, denoised_img.shape)
    mse = torch.mean((clean_img - denoised_img) ** 2).item()
    if mse == 0:
        return 100
    max_pixel = 1.0
    psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
    return psnr

def calculate_ssim(clean_img, denoised_img):
    clean_img_np = clean_img.squeeze().cpu().numpy()
    denoised_img_np = denoised_img.squeeze().cpu().numpy()
    #print("SSIM", clean_img_np.shape, denoised_img_np.shape)
    ssim = ssim_metric(clean_img_np, denoised_img_np, data_range=1.0, win_size=7)
    return ssim

In [12]:
def train(model, dataloader, criterion, optimizer, scheduler, device, epoch):
    model.train()
    running_loss = 0.0

    for batch_idx, (clean_slices, noisy_slices) in enumerate(dataloader):
        clean_slices, noisy_slices = clean_slices.to(device), noisy_slices.to(device)
    
        for i in range(clean_slices.size(1)):  # Iterate through each slice
            optimizer.zero_grad()
            
            clean_slice = clean_slices[:, i, :, :]
            noisy_slice = noisy_slices[:, i, :, :]
            
            outputs = model(noisy_slice)
            loss = criterion(outputs, clean_slice)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # writer.add_scalar('Training Loss', loss.item(), epoch * len(dataloader) + batch_idx)

    scheduler.step()
    avg_loss = running_loss / len(dataloader)
    print(f"Epoch {epoch + 1}, training Loss: {avg_loss:.4f}")
    return avg_loss

def validate(model, dataloader, criterion, device, epoch):
    model.eval()
    running_loss = 0.0
    running_psnr = 0.0
    running_ssim = 0.0

    with torch.no_grad():
        for batch_idx, (clean_slices, noisy_slices) in enumerate(dataloader):
            clean_slices, noisy_slices = clean_slices.to(device), noisy_slices.to(device)
            loss = 0.0
            psnr = 0.0
            ssim = 0.0

            for i in range(clean_slices.size(1)):  # Iterate through each slice
                clean_slice = clean_slices[:, i, :, :]
                noisy_slice = noisy_slices[:, i, :, :]
                outputs = model(noisy_slice)
                slice_loss = criterion(outputs, clean_slice)
                loss += slice_loss.item()
                psnr += calculate_psnr(clean_slice, outputs)
                ssim += calculate_ssim(clean_slice, outputs)

            running_loss += loss
            running_psnr += psnr / clean_slices.size(1)
            running_ssim += ssim / clean_slices.size(1)

            #writer.add_scalar('Validation Loss', loss, epoch * len(dataloader) + batch_idx)
            #writer.add_scalar('Validation PSNR', psnr / clean_slices.size(1), epoch * len(dataloader) + batch_idx)
            #writer.add_scalar('Validation SSIM', ssim / clean_slices.size(1), epoch * len(dataloader) + batch_idx)

    avg_loss = running_loss / len(dataloader)
    avg_psnr = running_psnr / len(dataloader)
    avg_ssim = running_ssim / len(dataloader)
    print(f"Epoch {epoch + 1}, validation loss: {avg_loss:.4f}, PSNR: {avg_psnr:.4f}, SSIM: {avg_ssim:.4f}")
    return avg_loss, avg_psnr, avg_ssim


In [13]:
def train_and_validate(model, train_loader, val_loader, num_epochs, device, log_dir, checkpoint_dir):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters())
    scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

    best_loss = float('inf')
    writer = SummaryWriter(log_dir)

    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, criterion, optimizer, scheduler, device, epoch)
        val_loss, val_psnr, val_ssim = validate(model, val_loader, criterion, device, epoch)

        # Save a checkpoint at the end of each epoch
        """
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': val_loss,
            'psnr': val_psnr,
            'ssim': val_ssim
        }, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt'))
        """
        # 
        #writer.add_scalar('Training Loss', train_loss, epoch)
        #writer.add_scalar('Validation Loss', val_loss, epoch)
        writer.add_scalars("Losses", {"Train": train_loss, "Val": val_loss}, epoch)
        writer.add_scalar('Validation PSNR', val_psnr, epoch)
        writer.add_scalar('Validation SSIM', val_ssim, epoch)
        
        
        # Save the best model
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'best_model_epoch{epoch}.pt'))
            print(f"Best model saved at epoch {epoch + 1} with val loss: {best_loss:.4f}")

    writer.close()

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

num_epochs = 50

task_name = f"DnCNN_lam5_MSE_epoch{num_epochs}"

log_dir = f"/root/tf-logs/run/{task_name}"
checkpoint_dir = f"checkpoints/{task_name}"

os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

model = DnCNN().to(device)

# start trainging
start_time = time.time()
train_and_validate(model, train_loader, val_loader, num_epochs, device, log_dir, checkpoint_dir)
end_time = time.time()
total_second = int(end_time - start_time)

device: cuda
Epoch 1, training Loss: 1.4971
Epoch 1, validation loss: 3.4993, PSNR: 18.2610, SSIM: 0.4774
Best model saved at epoch 1 with Loss: 3.4993
Epoch 2, training Loss: 1.3976
Epoch 2, validation loss: 3.0117, PSNR: 19.9534, SSIM: 0.5074
Best model saved at epoch 2 with Loss: 3.0117
Epoch 3, training Loss: 1.4508
Epoch 3, validation loss: 2.3153, PSNR: 20.8007, SSIM: 0.5283
Best model saved at epoch 3 with Loss: 2.3153
Epoch 4, training Loss: 1.4253
Epoch 4, validation loss: 2.5763, PSNR: 20.2858, SSIM: 0.5142
Epoch 5, training Loss: 1.4106
Epoch 5, validation loss: 1.8013, PSNR: 21.7724, SSIM: 0.5279
Best model saved at epoch 5 with Loss: 1.8013
Epoch 6, training Loss: 1.2508
Epoch 6, validation loss: 1.5058, PSNR: 22.1715, SSIM: 0.5664
Best model saved at epoch 6 with Loss: 1.5058
Epoch 7, training Loss: 0.5622
Epoch 7, validation loss: 1.2622, PSNR: 24.3475, SSIM: 0.7180
Best model saved at epoch 7 with Loss: 1.2622
Epoch 8, training Loss: 0.5676
Epoch 8, validation loss: 0.9

KeyboardInterrupt: 

In [None]:
# send result to wechat 

import requests
headers = {"Authorization": "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1aWQiOjM5MjI4LCJ1dWlkIjoiNTAyZTcyM2ItZjY2Mi00YTk4LWJkZmEtMzc1ZjdlOWM5NmFlIiwiaXNfYWRtaW4iOmZhbHNlLCJpc19zdXBlcl9hZG1pbiI6ZmFsc2UsInN1Yl9uYW1lIjoiIiwidGVuYW50IjoiYXV0b2RsIiwidXBrIjoiIn0.0IybMXdA-3z6KDYJDDGCj-_qqw6o4kya5usOFcLUtFL-ewBe35RnN8COQn4lO3umL-rWJ3er2PsIWZBjIl5XJw"}
resp = requests.post("https://www.autodl.com/api/v1/wechat/message/send",
                     json={
                         "title": "训练结束",
                         "name": task_name,
                         "content": f"training time: {total_second // 3600}h {(total_second % 3600) // 60}m"
                     }, headers = headers)

print(resp.content.decode())

### Evaluations model