<a href="https://colab.research.google.com/github/rlsn/COPD_Classification/blob/main/nodule_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
!pip install SimpleITK

Collecting SimpleITK
  Downloading SimpleITK-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.7/52.7 MB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.3.1


In [8]:
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import PIL
from PIL import Image
import glob, os

datadir="datasets/luna16"

In [3]:
# download dataset
!mkdir -p $datadir
!wget -O $datadir/annotations.csv https://zenodo.org/records/3723295/files/annotations.csv?download=1


!wget -O $datadir/subset0.zip https://zenodo.org/records/3723295/files/subset0.zip?download=1
!unzip $datadir/subset0.zip -d $datadir
# for i in range(7):
#     !wget -O $datadir/subset$i.zip https://zenodo.org/records/3723295/files/subset$i.zip?download=1
#     !unzip $datadir/subset$i.zip -d $datadir
# for i in range(7,10):
#     !wget -O $datadir/subset$i.zip https://zenodo.org/records/4121926/files/subset$i.zip?download=1
#     !unzip $datadir/subset$i.zip -d $datadir

--2024-04-07 10:12:05--  https://zenodo.org/records/3723295/files/annotations.csv?download=1
Resolving zenodo.org (zenodo.org)... 188.184.103.159, 188.184.98.238, 188.185.79.172, ...
Connecting to zenodo.org (zenodo.org)|188.184.103.159|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 136986 (134K) [text/plain]
Saving to: ‘datasets/luna16/annotations.csv’


2024-04-07 10:12:06 (205 KB/s) - ‘datasets/luna16/annotations.csv’ saved [136986/136986]

--2024-04-07 10:12:06--  https://zenodo.org/records/3723295/files/subset0.zip?download=1
Resolving zenodo.org (zenodo.org)... 188.184.103.159, 188.184.98.238, 188.185.79.172, ...
Connecting to zenodo.org (zenodo.org)|188.184.103.159|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6811924508 (6.3G) [application/octet-stream]
Saving to: ‘datasets/luna16/subset0.zip’


2024-04-07 10:23:02 (9.92 MB/s) - ‘datasets/luna16/subset0.zip’ saved [6811924508/6811924508]

Archive:  datasets/luna16/subset0.z

# Dataset

In [51]:
def read_image(image_file):
    # Read the MetaImage file
    image = sitk.ReadImage(image_file, imageIO="MetaImageIO")
    image_array = sitk.GetArrayFromImage(image)

    # print the image's dimensions
    return image_array, np.array(image.GetOrigin()), np.array(image.GetSpacing())

def read_csv(fn):
    with open(fn,"r") as f:
        lines = [l.strip().split(",") for l in f.readlines()]
    return lines

def survey_dataset(datadir="."):
    data_split = dict()
    for i in range(10):
        files = glob.glob(f"{datadir}/subset{i}/*mhd")
        data_split[i]=files
    return data_split

def add_marker(img, bbox):
    low, high = bbox
    center = ((low+high)/2).astype(int)
    mark = np.zeros_like(img)
    new_img = np.copy(img)
    value = img.max() if new_img[center[0],center[1]]<(img.max()-img.min())/2 else img.min()
    new_img[low[0]:high[0],low[1]]=value
    new_img[low[0]:high[0],high[1]]=value
    new_img[low[0],low[1]:high[1]]=value
    new_img[high[0],low[1]:high[1]]=value
    return new_img

def convert_loc(coord, origin, space):
    displacement = np.array(coord[:3]).astype(float)-origin
    loc = np.round(displacement/space)[::-1]
    return loc

def convert_radius(coord, space):
    r = np.round(float(coord[-1])/2/space)[::-1]
    return r

def convert_bounding_box(coord, origin, space):
    center = convert_loc(coord, origin, space)
    rad = convert_radius(coord, space)
    low = np.round(center-rad).astype(int)
    high = np.round(center+rad).astype(int)
    return low, high

def mark_bbox(img, bbox):
    low, high = bbox
    marked_imgs = np.copy(img)
    for z in range(low[0],high[0]+1):
        marked_imgs[z] = add_marker(img[z],(low[1:],high[1:]))
    return marked_imgs

def export_as_gif(filename, image_array, frames_per_second=20, rubber_band=False):
    images = []
    image_array = (image_array-image_array.min())/(image_array.max()-image_array.min())
    for arr in image_array:
        im = Image.fromarray(np.uint8(arr*255))
        images.append(im)
    if rubber_band:
        images += images[2:-1][::-1]
    images[0].save(
        filename,
        save_all=True,
        append_images=images[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )

In [53]:
from torch.utils.data import Dataset

def getUID(filename):
    return os.path.basename(filename)[:-4]

def random_crop_3D(img, crop_size):
    size = np.array(img.shape)
    high = size-crop_size
    start = [np.random.randint(0, high=high[0]),
           np.random.randint(0, high=high[1]),
           np.random.randint(0, high=high[2])]
    return img[start[0]:start[0]+crop_size[0],
               start[1]:start[1]+crop_size[1],
               start[2]:start[2]+crop_size[2]]

def random_crop_around_3D(img, bbox, crop_size, margin=[5,20,20]):
    blow, bhigh = bbox
    margin = np.array(margin)
    low = bhigh+margin-crop_size
    high = blow-margin
    offset = [np.random.randint(low[0], high=high[0]),
           np.random.randint(low[1], high=high[1]),
           np.random.randint(low[2], high=high[2])]
    return img[offset[0]:offset[0]+crop_size[0],
               offset[1]:offset[1]+crop_size[1],
               offset[2]:offset[2]+crop_size[2]], offset

def random_flip(img, axis):
    if np.random.rand()<0.5:
        return np.flip(img, axis=axis)
    else:
        return img

class LUNA16_Dataset(Dataset):
    """
    https://luna16.grand-challenge.org/
    """
    def __init__(self, split=None, data_dir=".", crop_size=[40,128,128], patch_size=[4,16,16]):
        annotations_csv = read_csv(f"{data_dir}/annotations.csv")[1:]
        data_subsets = survey_dataset(data_dir)
        # to filenames
        if split is None:
            split = np.arange(10) # all subsets
        self.filenames = []
        for s in split:
            self.filenames+=data_subsets[s]
        # annotation to dict
        self.annotations = dict([(getUID(k),[]) for k in self.filenames])
        for entry in annotations_csv:
            self.annotations.setdefault(entry[0], [])
            self.annotations[entry[0]]+=[entry[1:]]

        self.crop_size = np.array(crop_size)
        self.patch_size = np.array(patch_size)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        fn = self.filenames[idx]
        uid = getUID(fn)
        image, origin, space = read_image(fn)
        coords = self.annotations[uid]
        patch_size_mm = self.patch_size * space[::-1]

        result = dict()
        if len(coords)>0 and np.random.rand()<1:
            # crop with a nodule
            target_idx = np.random.randint(len(coords))
            coord = coords[target_idx]

            bbox = convert_bounding_box(coord, origin, space)
            cropped_img, offset = random_crop_around_3D(image, bbox, self.crop_size)
            offset_bbox = bbox[0] - offset, bbox[1] - offset

            # for debugging
            marked_imgs = mark_bbox(cropped_img, offset_bbox)
            result["marked_imgs"]=marked_imgs

        else:
            # random crop
            cropped_img = random_crop_3D(image, self.crop_size)

        # random flip
        cropped_img = random_flip(cropped_img, 0)
        cropped_img = random_flip(cropped_img, 1)
        cropped_img = random_flip(cropped_img, 2)

        result["pixel_values"]=cropped_img
        return result


dataset = LUNA16_Dataset(data_dir=datadir)
re = dataset[23]
image = re["pixel_values"]
export_as_gif("ct.gif",image)
if "marked_imgs" in re:
    export_as_gif("ct_marked.gif",re["marked_imgs"])

TypeError: type numpy.ndarray doesn't define __round__ method

# Model

In [6]:
from transformers import ViTModel, ViTConfig

