# References
### Data preprocessing: https://www.kaggle.com/carlossouza/end-to-end-model-ct-scans-tabular

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
import sys
sys.path.append("../input/utils-model/")
import os
import copy
from datetime import timedelta, datetime
import imageio
import matplotlib.pyplot as plt
from matplotlib import cm
import multiprocessing
from pathlib import Path
import pydicom
import pytest
import scipy.ndimage as ndimage
from scipy.ndimage.interpolation import zoom
from skimage import measure, morphology, segmentation
from time import time, sleep
from tqdm import trange
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DistributedSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import warnings
import glob
# You can write up to 5GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from models import SegModel

In [None]:
# Build the segmentation network, wrap it for multi-GPU data parallelism and
# move it to CUDA in one step.
# NOTE(review): the two positional arguments are presumably the input and
# output channel counts -- confirm against models.SegModel.
model = nn.DataParallel(SegModel(1, 1)).cuda()

In [None]:
# --- Dataset locations -----------------------------------------------------
basepath = "../input/osic-pulmonary-fibrosis-progression"
root_dir = basepath + "/osic/"
test_dir = basepath + "/test/"
train_dir = basepath + "/train/"
# model_file = f'/kaggle/working/diophantus.pt'

# --- Pre-processing --------------------------------------------------------
# Target volume shape handed to Resize and the value window handed to
# Clip / Normalize further down.
resize_dims = (64, 256, 256)
clip_bounds = (-1000, 200)
watershed_iterations = 1
pre_calculated_mean = 0.02865046213070556

# --- Training hyper-parameters ---------------------------------------------
# NOTE(review): none of these are referenced by the cells visible in this
# notebook; they look like leftovers from the training notebook.
latent_features = 10
batch_size = 16
learning_rate = 3e-5
num_epochs = 10
val_size = 0.2
tensorboard_dir = './runs'

# Tabular training data that ships alongside the scans.
tr = pd.read_csv(basepath + "/train.csv")

In [None]:
import copy
from datetime import timedelta, datetime
import imageio
import matplotlib.pyplot as plt
from matplotlib import cm
import multiprocessing
import numpy as np
import os
from pathlib import Path
import pydicom
import pytest
import scipy.ndimage as ndimage
from scipy.ndimage.interpolation import zoom
from skimage import measure, morphology, segmentation
from time import time, sleep
from tqdm import trange
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DistributedSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import warnings
import glob

In [None]:
class CTScansDataset(Dataset):
    """Dataset of CT scans with one sample per patient directory.

    Every sample is a dict with two keys:

    * ``'image'``: the patient's DICOM slices stacked along axis 0, in the
      order produced by ``load_scan``.
    * ``'metadata'``: the first DICOM dataset of the sorted slice list.

    Args:
        root_dir: Directory whose entries are the per-patient scan
            directories. ``str`` or path-like; a trailing separator is
            optional.
        transform: Stored on the instance but not applied by
            ``__getitem__`` (its use is disabled in this notebook).
        transform2: Optional callable applied to the sample dict. It must
            return a dict that has ``'image'`` and ``'metadata'`` keys.
    """

    def __init__(self, root_dir, transform=None, transform2=None):
        super().__init__()
        self.root_dir = Path(root_dir)
        # Join with an explicit separator: plain ``root_dir + '*'`` matches
        # sibling entries (e.g. ``train_extra``) when the trailing slash is
        # missing and fails outright for Path objects.
        pattern = os.path.join(os.fspath(root_dir), '*')
        self.patients = sorted(glob.glob(pattern))
        self.transform = transform
        self.transform2 = transform2

    def __len__(self):
        """Return the number of patient directories found."""
        return len(self.patients)

    def __getitem__(self, idx):
        """Load the scan at position ``idx`` and run ``transform2`` on it.

        When ``transform2`` is not set the raw sample is returned unchanged.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image, metadata = self.load_scan(self.patients[idx])
        sample = {'image': image, 'metadata': metadata}
        if self.transform2:
            sample = self.transform2(sample)
        # Only these two keys are part of the sample contract.
        return {'image': sample['image'], 'metadata': sample['metadata']}

    def save(self, path):
        """Write every pre-processed image to ``path`` as ``<PatientID>.pt``.

        Also prints the elapsed time and the mean of the per-scan mean
        values (``nan`` for an empty dataset).

        Args:
            path: Output directory; created together with missing parents.
        """
        t0 = time()
        out_dir = Path(path)
        out_dir.mkdir(exist_ok=True, parents=True)
        print('Saving pre-processed dataset to disk')
        # Presumably gives the message time to flush before tqdm draws.
        sleep(1)
        cum = 0
        count = len(self)

        bar = trange(count)
        for i in bar:
            sample = self[i]
            data = sample['metadata']
            # The transform pipeline may return a NumPy array (no ToTensor
            # step); torch.mean / torch.save need a tensor. This is a no-op
            # for inputs that already are tensors.
            image = torch.as_tensor(sample['image'])
            cum += torch.mean(image).item()

            bar.set_description(f'Saving CT scan {data.PatientID}')
            fname = out_dir / f'{data.PatientID}.pt'
            torch.save(image, fname)

        sleep(1)
        bar.close()
        # Guard the average against an empty dataset.
        mean_value = cum / count if count else float('nan')
        print(f'Done! Time {timedelta(seconds=time() - t0)}\n'
              f'Mean value: {mean_value}')

    def get_patient(self, patient_id):
        """Return the sample whose directory name (stem) is ``patient_id``.

        Raises:
            ValueError: If no patient directory has that name.
        """
        # ``self.patients`` holds plain strings, so wrap them in Path to
        # extract the stem.
        patient_ids = [Path(p).stem for p in self.patients]
        return self[patient_ids.index(patient_id)]

    @staticmethod
    def load_scan(path):
        """Read every DICOM file in ``path`` and stack the pixel data.

        Slices are ordered by the third component of
        ``ImagePositionPatient``.

        Returns:
            Tuple ``(image, first_slice)`` where ``image`` is a float array
            with one slice per index along axis 0 and ``first_slice`` is the
            DICOM dataset of the first slice in that order.
        """
        path = os.fspath(path)
        slices = [pydicom.dcmread(os.path.join(path, name))
                  for name in os.listdir(path)]
        slices.sort(key=lambda s: float(s.ImagePositionPatient[2]))
        image = np.stack([s.pixel_array.astype(float) for s in slices])
        return image, slices[0]

In [None]:
class ConvertToHU:
    """Apply the DICOM rescale slope/intercept to turn raw pixels into HU.

    Falls back to an intercept of -1024 and a slope of 1 when the metadata
    does not carry the respective attribute. The result is cast to int16.
    """

    def __call__(self, sample):
        pixels, dicom = sample['image'], sample['metadata']

        if 'RescaleIntercept' in dicom:
            offset = dicom.RescaleIntercept
        else:
            offset = -1024

        if 'RescaleSlope' in dicom:
            gain = dicom.RescaleSlope
        else:
            gain = 1

        rescaled = pixels * gain + offset
        return {'image': rescaled.astype(np.int16), 'metadata': dicom}
class Resize:
    """Resample the image so that its shape becomes ``output_size``."""

    def __init__(self, output_size):
        assert isinstance(output_size, tuple)
        self.output_size = output_size

    def __call__(self, sample):
        volume, data = sample['image'], sample['metadata']
        # Per-axis zoom factor = requested extent / current extent.
        target_shape = np.array(self.output_size)
        current_shape = np.array(volume.shape)
        resized = zoom(volume, target_shape / current_shape, mode='nearest')
        return {'image': resized, 'metadata': data}
class Clip:
    """Clamp image values to the closed interval given by ``bounds``.

    The image array is modified in place and the same object is returned
    inside the sample dict.
    """

    def __init__(self, bounds=(-1000, 500)):
        # Accept the bounds in either order.
        self.min, self.max = min(bounds), max(bounds)

    def __call__(self, sample):
        voxels, data = sample['image'], sample['metadata']
        below = voxels < self.min
        voxels[below] = self.min
        above = voxels > self.max
        voxels[above] = self.max
        return {'image': voxels, 'metadata': data}
    
    
class Normalize:
    """Linearly map ``bounds`` onto [0, 1].

    The lower bound maps to 0.0 and the upper bound to 1.0; values outside
    the bounds are not clipped here.

    Args:
        bounds: Pair of numbers, in either order.
    """

    def __init__(self, bounds=(-1000, 500)):
        self.min = min(bounds)
        self.max = max(bounds)

    def __call__(self, sample):
        image, data = sample['image'], sample['metadata']
        # ``np.float`` was only an alias of the builtin float and was removed
        # in NumPy 1.24; float64 is the dtype it always produced.
        image = image.astype(np.float64)
        image = (image - self.min) / (self.max - self.min)
        return {'image': image, 'metadata': data}
    

class ToTensor:
    """Convert the NumPy image of a sample into a ``torch`` tensor.

    Args:
        add_channel: When true, a new leading axis of length 1 is inserted
            before the conversion.
    """

    def __init__(self, add_channel=True):
        self.add_channel = add_channel

    def __call__(self, sample):
        array, data = sample['image'], sample['metadata']
        array = np.expand_dims(array, axis=0) if self.add_channel else array
        tensor = torch.from_numpy(array)
        return {'image': tensor, 'metadata': data}
    
    
class ZeroCenter:
    """Subtract a fixed, pre-computed mean from whatever it is called with.

    Unlike the dict-based transforms in this file, the call takes the value
    itself (anything that supports ``-`` with the stored mean).
    """

    def __init__(self, pre_calculated_mean):
        self.pre_calculated_mean = pre_calculated_mean

    def __call__(self, tensor):
        offset = self.pre_calculated_mean
        centered = tensor - offset
        return centered
    

In [None]:
# Pre-processing chain: rescale to HU, resample to the fixed volume size,
# clamp to the value window, then scale that window to [0, 1].
hu_pipeline = transforms.Compose([
    ConvertToHU(),
    Resize(resize_dims),
    Clip(bounds=clip_bounds),
    Normalize(bounds=clip_bounds),
])

# NOTE: despite its name, `test` reads the scans under `train_dir`.
test = CTScansDataset(root_dir=train_dir, transform2=hu_pipeline)

In [None]:
# Locate the trained weights among the attached input datasets.
# glob returns matches in arbitrary order, so sort them to make the choice
# deterministic, and fail with a readable message instead of a bare
# IndexError when nothing matches.
checkpoint_paths = sorted(glob.glob('../input/*/best.pt'))
if not checkpoint_paths:
    raise FileNotFoundError(
        "No 'best.pt' checkpoint found in any directory under ../input/"
    )
model_path = checkpoint_paths[0]
model.load_state_dict(torch.load(model_path))

In [None]:
# Run the model on sample 2 of `test` and overlay the sigmoid of its output
# on slices 23-31 (axis 0) of the pre-processed scan, in a 3x3 grid.
list_imgs = test[2]
image = list_imgs['image']
# Two leading singleton axes are added before the forward pass.
image_tensor = torch.tensor(image, dtype=torch.float32).cuda()[None, None]
with torch.no_grad():
    z = torch.sigmoid(model(image_tensor))[0, 0].cpu().numpy()
plt.figure(figsize=[20, 20])
for panel, row in enumerate(range(23, 32), start=1):
    plt.subplot(3, 3, panel)
    plt.imshow(image[row], cmap='gray')
    plt.imshow(z[row], alpha=0.5, cmap='hot')
plt.show()

In [None]:
# Run the model on sample 80 of `test` and overlay the sigmoid of its output
# on slices 23-31 (axis 0) of the pre-processed scan, in a 3x3 grid.
list_imgs = test[80]
image = list_imgs['image']
# Two leading singleton axes are added before the forward pass.
image_tensor = torch.tensor(image, dtype=torch.float32).cuda()[None, None]
with torch.no_grad():
    z = torch.sigmoid(model(image_tensor))[0, 0].cpu().numpy()
plt.figure(figsize=[20, 20])
for panel, row in enumerate(range(23, 32), start=1):
    plt.subplot(3, 3, panel)
    plt.imshow(image[row], cmap='gray')
    plt.imshow(z[row], alpha=0.5, cmap='hot')
plt.show()

In [None]:
# Run the model on sample 100 of `test` and overlay the sigmoid of its output
# on slices 23-31 (axis 0) of the pre-processed scan, in a 3x3 grid.
list_imgs = test[100]
image = list_imgs['image']
# Two leading singleton axes are added before the forward pass.
image_tensor = torch.tensor(image, dtype=torch.float32).cuda()[None, None]
with torch.no_grad():
    z = torch.sigmoid(model(image_tensor))[0, 0].cpu().numpy()
plt.figure(figsize=[20, 20])
for panel, row in enumerate(range(23, 32), start=1):
    plt.subplot(3, 3, panel)
    plt.imshow(image[row], cmap='gray')
    plt.imshow(z[row], alpha=0.5, cmap='hot')
plt.show()

In [None]:
# WARNING: IPython shell escape that recursively deletes every non-hidden
# entry in the current working directory.
# NOTE(review): presumably meant to leave the Kaggle output directory empty
# after the run -- confirm before executing this anywhere else.
!rm -rf *