In [1]:
#Fundamentals
import numpy as np
import h5py
import pandas as pd
import matplotlib.pyplot as plt
import random
import math
import pydicom

#System
import os
import glob
import time
import multiprocess as mp
import torch.multiprocessing as tmp
import gc
import copy

#Pytorch
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader

#Sklearn
from sklearn.model_selection import train_test_split

## Settings

In [2]:
folder="/mnt/idms/PROJECTS/Lung/LungCT/Super-resolution"

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
tmp.set_start_method('spawn',force='True')

In [5]:
number_of_epochs = 80
learning_rate = 0.05
batch_size=16

## Helper functions

In [34]:
def plot_slices(ct_):
    ct, path=ct_
    print(path)
    for sl in ct:
        plt.figure()
        plt.imshow(sl.pixel_array, cmap="gray")
        plt.show()

## Reading of Ulyssys input data

In [None]:
def create_triplets_from_case(case_path, idx):
    """
    Ideas for improvement:
    -Why do we have to double the number of slices? Couldn't we predict not 1, but more in-between?
    -Why to take just the neighbouring ones?
    -Now the triplets are (1,2,3),(4,5,6),(7,8,9)..., but could get more triplets by (1,2,3), (2,3,4), (3,4,5)...
    """
    npz=np.load(case_path)
    ct=npz[npz.files[0]]
    ct=np.moveaxis(ct,-1,0)
    ct=ct[:(len(ct)-(len(ct)%3))]
    case_triplets=np.asarray(np.split(ct,len(ct)//3))
    case_triplets=(case_triplets-np.min(case_triplets))/(np.max(case_triplets)-np.min(case_triplets))
    print(f"Case {idx+1} is done")
    return case_triplets
    
def create_all_triplets(parallel_jobs=mp.cpu_count(), max_cases=None):#, save=False):
    ct_path="/mnt/idms/PROJECTS/Lung/Tudo-Ulyssys-Unzipped"
    case_paths=list(glob.glob(f'{ct_path}/*/*.npz'))
    if max_cases:
        case_paths=case_paths[:max_cases]
    with mp.Pool(parallel_jobs) as pool: #get_context("spawn")
        triplets = pool.starmap(create_triplets_from_case, zip(case_paths,list(range(len(case_paths)))))
    pool.join() #Maybe helps

    """if save:
        with h5py.File(f"{folder}/data/input_data.h5", 'w') as h5f:
            h5f.create_dataset('triplets_per_patients', data=np.asarray(triplets))#, dtype=object))
    """
    return triplets

In [None]:
%%time
triplets=create_all_triplets(max_cases=1, parallel_jobs=5) #VEEEERY SLOW WITH MORE MAX_CASES, 5000 DOESN'T EVEN RUN
#I COULDN'T FIND THE EXACT CAUSE, BUT IT CAN DEADLOCK WITH TOO MANY PARALLEL_JOBS AND MAX_CASES

## Dataloaders

In [None]:
class Triplets(Dataset):
    
    def __init__(self, data, triplets_per_case=20, transform=None):
        self.data = []
        self.len=len(data)*triplets_per_case
        for case in data:
            self.data+=list(case[np.random.choice(case.shape[0], triplets_per_case, replace=False),:,:,:])
        self.data=np.asarray(self.data)
        self.transform = transform
            
        
    def __getitem__(self, idx):
        xs=(self.data[idx,0,:,:],self.data[idx,2,:,:])
        y=self.data[idx,1,:,:]
        if self.transform:
            xs=tuple(self.transform(x) for x in xs)
            y=self.transform(y)
        return xs, y
    
    def __len__ (self):
        return self.len

In [None]:
train_transform = transforms.Compose([ #Now composition is unnecessary, but maybe more transforms will be added
    transforms.ToTensor()
])

test_transform = transforms.Compose([
    transforms.ToTensor()
])
    
train_dataset_raw, test_dataset_raw = train_test_split(triplets, test_size=0.2, shuffle=True)
    
train_dataset = Triplets(train_dataset_raw, transform=train_transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    
test_dataset = Triplets(test_dataset_raw, transform=test_transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

## Neural network

In [9]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, 9, padding = 4)
        self.conv2 = nn.Conv2d(64, 32, 1, padding = 0)
        self.conv3 = nn.Conv2d(32, 32, 5, padding = 2)
        self.conv4 = nn.Conv2d(32, 1, 1, padding = 0)
        
    def forward(self, x1, x3):
        x1 = F.selu(self.conv1(x1))
        x1 = F.selu(self.conv2(x1))
        x1 = F.selu(self.conv3(x1))
        x3 = F.selu(self.conv1(x3))
        x3 = F.selu(self.conv2(x3))
        x3 = F.selu(self.conv3(x3))
        out = torch.cat((x1, x3), dim=1)
        out = F.selu(self.conv2(out))
        out = F.selu(self.conv3(out))
        out = F.selu(self.conv4(out))
        return out

## Training

In [None]:
def train_step(model,device,optimizer,criterion,train_loader, log_freq=5):
    
    for batch_idx, (xs, ys) in enumerate(train_loader):
        xs=[x.to(device, dtype=torch.float) for x in xs]
        ys=ys.to(device, dtype=torch.float)
        outputs = model(xs[0],xs[1])
        loss = criterion(outputs, ys)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % log_freq == 0:
            print(f"Batch {batch_idx+1} / {len(train_loader)} | Loss = {loss.item()}")

In [None]:
def evaluate(model,device=device,criterion=nn.SmoothL1Loss(),test_loader=test_loader, plot_result=False, plot_error=True):
    #plot_error only applies if plot_result is True
    total_loss=0
    with torch.no_grad():
        for batch_idx, (xs, ys) in enumerate(test_loader):
            xs=[x.to(device, dtype=torch.float) for x in xs]
            ys=ys.to(device, dtype=torch.float)
            pred_tensor = model(xs[0],xs[1])
            loss = criterion(pred_tensor, ys)
            total_loss+=loss.item()
            
            if plot_result:
                
                pred_np=pred_tensor[0,0,:,:].cpu().detach().numpy()
                x1_np=xs[0][0,0,:,:].cpu().detach().numpy()
                x3_np=xs[1][0,0,:,:].cpu().detach().numpy()
                y_np=ys[0,0,:,:].cpu().detach().numpy()
                
                components=5 if plot_error else 4
                fig, axs = plt.subplots(1, components, figsize=(8*components,8))
                axs[0].set_title("Previous")
                axs[1].set_title("Next")
                axs[2].set_title("Correct middle")
                axs[3].set_title("Predicted middle")
                
                axs[0].imshow(x1_np,cmap=plt.cm.Greys_r)
                axs[1].imshow(x3_np,cmap=plt.cm.Greys_r)
                axs[2].imshow(y_np,cmap=plt.cm.Greys_r)
                axs[3].imshow(pred_np,cmap=plt.cm.Greys_r)
                
                if plot_error:
                    axs[4].set_title("Error")
                    axs[4].imshow(np.abs(pred_np-y_np),cmap=plt.cm.Greys_r)
                
                plt.show()
                
    avg_loss=total_loss/len(test_loader)
    print(f"The average test loss is {avg_loss}.")
    return avg_loss

In [None]:
def train_model(model = ConvNet(), criterion = nn.SmoothL1Loss(), optimizer = optim.SGD,
                  learning_rate=learning_rate, number_of_epochs=number_of_epochs,
                  train_loader=train_loader, test_loader=test_loader,
                  device=device, plot_loss=True):
    
    model=model.to(device)
    optimizer=optimizer(model.parameters(),lr=learning_rate)
    
    best_sofar=None
    test_losses=[]
    
    print("Training has started.")
    for epoch in range(number_of_epochs):
        print(f"EPOCH {epoch} HAS STARTED.")
        train_step(model,device,optimizer,criterion,train_loader)
        test_loss=evaluate(model,device,criterion,test_loader)
        test_losses.append(test_loss)
        if not best_sofar or test_loss<best_sofar:
            best_sofar=test_loss
            print(f"BEST TEST LOSS SO FAR HAS IMPROVED TO {test_loss}!")
    
    if plot_loss:
        plt.figure()
        plt.title("Test losses")
        plt.plot(test_losses)
        plt.show()
    print("Training has finished.")
    return model

In [None]:
trained_model=train_model()

## Saving the model

In [None]:
model_number=1
base_models_path="/mnt/idms/PROJECTS/Lung/Super-resolution/models"
while os.path.exists(f"{base_models_path}/{model_number}.pth"):
    model_number+=1
torch.save(trained_model, f"{base_models_path}/{model_number}.pth")

## Testing on same Ulyssys data

In [None]:
average_test_loss=evaluate(trained_model,plot_result=True)

In [None]:
for param in trained_model.parameters():
    print(param.data)

## Running on LIDC-IDRI data

In [5]:
def create_slices_from_case(case_path):
    print(f"CT {case_path} has started")
    slice_paths=list(glob.glob(f'{case_path}/*'))
    ct=[pydicom.dcmread(slice_path) for slice_path in slice_paths]
    ct=sorted(ct, key=lambda x:x.SliceLocation)
    print(f"CT {case_path} is done")
    return ct, case_path

In [6]:
def create_all_cases(parallel_jobs=mp.cpu_count(), max_cases=None):
    ct_path="/mnt/idms/PROJECTS/Lung/LIDC-IDRI"
    case_paths=list(glob.glob(f'{ct_path}/*'))
    if max_cases:
        case_paths=case_paths[:max_cases]
    with mp.Pool(parallel_jobs) as pool: #get_context("spawn")
        cases = pool.map(create_slices_from_case, case_paths)
    pool.join() #Maybe helps

    return cases

In [7]:
%%time
original_cts=create_all_cases(max_cases=1, parallel_jobs=1)

CT /mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0447 has started
CT /mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0447 is done
CPU times: user 785 ms, sys: 535 ms, total: 1.32 s
Wall time: 11.3 s


In [16]:
original_cts[0][0][0]

Dataset.file_meta -------------------------------
(0002, 0000) File Meta Information Group Length  UL: 204
(0002, 0001) File Meta Information Version       OB: b'\x00\x01'
(0002, 0002) Media Storage SOP Class UID         UI: CT Image Storage
(0002, 0003) Media Storage SOP Instance UID      UI: 1.3.6.1.4.1.14519.5.2.1.6279.6001.142768823987696751030166377050
(0002, 0010) Transfer Syntax UID                 UI: Implicit VR Little Endian
(0002, 0012) Implementation Class UID            UI: 1.3.6.1.4.1.22213.1.143
(0002, 0013) Implementation Version Name         SH: '0.5'
(0002, 0016) Source Application Entity Title     AE: 'POSDA'
-------------------------------------------------
(0008, 0005) Specific Character Set              CS: 'ISO_IR 100'
(0008, 0008) Image Type                          CS: ['ORIGINAL', 'PRIMARY', 'AXIAL']
(0008, 0012) Instance Creation Date              DA: '20000101'
(0008, 0013) Instance Creation Time              TM: '095239'
(0008, 0016) SOP Class UID          

In [10]:
loaded_model = torch.load("/mnt/idms/PROJECTS/Lung/LungCT/Super-resolution/models/4.pth")
loaded_model.eval()

ConvNet(
  (conv1): Conv2d(1, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (conv2): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv4): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
)

In [11]:
def create_super_resoluted(ct_):
    path=ct_[1]
    print(f"{path} has started.")
    ct=ct_[0]
    super_resoluted=[]
    for i in range(len(ct)-1):
        
        super_resoluted.append(ct[i])
        
        
        x1=transforms.ToTensor()(ct[i].pixel_array.astype(np.int16))
        x1=x1.unsqueeze(0)
        x1=x1.to(device, dtype=torch.float)
        prev_min=torch.min(x1).item()
        prev_max=torch.max(x1).item()
        x1=(x1-prev_min)/(prev_max-prev_min)
        
        x2=transforms.ToTensor()(ct[i+1].pixel_array.astype(np.int16))
        x2=x2.unsqueeze(0)
        x2=x2.to(device, dtype=torch.float)
        x2=(x2-torch.min(x2))/(torch.max(x2)-torch.min(x2))
        
        
        res=loaded_model(x1,x2)
        res_np=res.detach().cpu().squeeze().numpy()
        res_np=(prev_max-prev_min)*res_np+prev_min
        
        middle=copy.deepcopy(ct[i])
        middle.SliceLocation=(float(ct[i].SliceLocation)+float(ct[i+1].SliceLocation))/2
        middle.PixelData=res_np.astype(np.int16).tobytes()
        
        super_resoluted.append(middle)
        
        
        del x1
        del x2
        del res
        torch.cuda.empty_cache()
    
    super_resoluted.append(ct[-1])
    for sl in super_resoluted:
        sl.SliceThickness=float(sl.SliceThickness)/2
        
    
    print(f"{path} is done.")
    return super_resoluted, path

In [20]:
def create_all_super_resoluted_cts_parallel(original_cts, parallel_jobs=tmp.cpu_count()):
    #MULTIPROCESSING AND PYTORCH MULTIPROCESSING DON'T WORK IN JUPYTER NOTEBOOK SOMETIMES
    with tmp.Pool(parallel_jobs) as pool:
        super_resoluted_cts = pool.starmap(create_super_resoluted, original_cts)
    pool.join()
    return super_resoluted_cts

def create_all_super_resoluted_cts_sequential(original_cts):
    super_resoluted_cts=[]
    for ct in original_cts:
        curr=create_super_resoluted(ct)
        super_resoluted_cts.append(curr)
    return super_resoluted_cts

In [21]:
#super_resoluted_cts=create_all_super_resoluted_cts_parallel(original_cts, parallel_jobs=10)
super_resoluted_cts=create_all_super_resoluted_cts_sequential(original_cts)

/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0447 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0447 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0831 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0831 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0564 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0564 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0723 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0723 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0632 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0632 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0135 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0135 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0415 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0415 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0403 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0403 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-I

/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0311 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0896 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0896 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0877 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0877 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0206 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0206 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0455 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0455 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0383 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0383 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0884 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0884 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0175 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0175 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0182 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-I

In [27]:
super_resoluted_cts[0][1]

'/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0447'

## Save superresoluted CTs

In [32]:
def save_ct(ct, ct_path):
    print(f"{ct_path} has started.")
    l=ct_path.split("/")
    new_ct_path=f'{"/".join(l[:-1])}-SuperResoluted/{l[-1]}'
    if not os.path.exists(new_ct_path):
        os.mkdir(new_ct_path)
    for idx,dc in enumerate(ct):
        slice_path=f'{new_ct_path}/{idx}.dcm'
        dc.save_as(slice_path)
    
    print(f"{ct_path} is done.")

def save_all_cts(cts, parallel_jobs=mp.cpu_count()):
    
    print("Saving of superresoluted CTs has started.")
    
    with mp.Pool(parallel_jobs) as pool:
        _ = pool.starmap(save_ct, cts)
    pool.join()
    
    print("Saving of superresoluted CTs has finished.")

In [33]:
save_all_cts(super_resoluted_cts, parallel_jobs=40)

Saving superresoluted CTs has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0447 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0447 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0831 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0564 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0831 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0723 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0564 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0632 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0723 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0632 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0135 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0135 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0415 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0415 is done.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0403 has started.
/mnt/idms/PROJECTS/Lung/LIDC-IDRI/LIDC-IDRI-0403 is done.
/

In [None]:
plot_slices(super_resoluted_cts[0])

In [None]:
torch.cuda.empty_cache()

In [51]:
len(glob.glob("/mnt/idms/PROJECTS/Lung/sinogram_Domi/spr_0001L/out/*"))

55091