<a href="https://colab.research.google.com/github/rlsn/COPD_Classification/blob/main/nodule_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
# Install dependencies: SimpleITK (reads the .mhd CT volumes), transformers with
# its torch extra (ViT encoder blocks and Trainer), and an upgraded accelerate.
!pip install SimpleITK
!pip install transformers[torch]
!pip install accelerate -U

In [2]:
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import PIL
from PIL import Image
import glob, os

# Root folder for the LUNA16 download; the shell cells below expand it as $datadir.
datadir="datasets/luna16"

In [None]:
# download dataset
# Fetches the nodule annotations and only subset0 of the CT scans. The
# commented-out loops below download all ten subsets instead: subsets 0-6 are
# in Zenodo record 3723295, subsets 7-9 in record 4121926.
!mkdir -p $datadir
!wget -O $datadir/annotations.csv https://zenodo.org/records/3723295/files/annotations.csv?download=1


!wget -O $datadir/subset0.zip https://zenodo.org/records/3723295/files/subset0.zip?download=1
!unzip $datadir/subset0.zip -d $datadir
# for i in range(7):
#     !wget -O $datadir/subset$i.zip https://zenodo.org/records/3723295/files/subset$i.zip?download=1
#     !unzip $datadir/subset$i.zip -d $datadir
# for i in range(7,10):
#     !wget -O $datadir/subset$i.zip https://zenodo.org/records/4121926/files/subset$i.zip?download=1
#     !unzip $datadir/subset$i.zip -d $datadir

# Dataset

In [3]:
def read_image(image_file):
    """Load a MetaImage (.mhd) volume with SimpleITK.

    Returns a tuple ``(voxels, origin, spacing)``: ``voxels`` is the numpy
    array produced by ``sitk.GetArrayFromImage``; ``origin`` and ``spacing``
    are numpy arrays built from the image's ``GetOrigin()`` / ``GetSpacing()``.
    SimpleITK orders the array axes (z, y, x) but origin/spacing (x, y, z),
    which is why callers reverse converted coordinates with ``[::-1]``.
    """
    itk_image = sitk.ReadImage(image_file, imageIO="MetaImageIO")
    voxels = sitk.GetArrayFromImage(itk_image)
    origin = np.array(itk_image.GetOrigin())
    spacing = np.array(itk_image.GetSpacing())
    return voxels, origin, spacing

def read_csv(fn):
    """Read a comma-separated file into a list of rows (lists of strings).

    Uses the standard ``csv`` module, so quoted fields that contain commas are
    parsed correctly. Surrounding whitespace is stripped from every field, and
    rows whose fields are all blank are skipped rather than returned as
    ``['']``. A header row, if present, is returned as the first row; callers
    drop it themselves.
    """
    import csv

    with open(fn, "r", newline="", encoding="utf-8") as f:
        rows = [[field.strip() for field in row] for row in csv.reader(f)]
    return [row for row in rows if any(row)]

def survey_dataset(datadir=".", num_subsets=10):
    """Map each LUNA16 subset index to the sorted list of its ``.mhd`` files.

    Args:
        datadir: root directory containing ``subset0`` ... ``subset{num_subsets-1}``.
        num_subsets: number of subset folders to scan (the full download has 10).

    Returns:
        dict ``{i: [path, ...]}`` for ``i in range(num_subsets)``; a missing
        subset folder maps to an empty list. Paths are sorted because ``glob``
        order depends on the filesystem, and dataset indices must be
        reproducible across runs and machines.
    """
    return {
        i: sorted(glob.glob(os.path.join(datadir, f"subset{i}", "*.mhd")))
        for i in range(num_subsets)
    }

def add_marker(img, bbox):
    """Return a copy of 2-D ``img`` with a one-pixel rectangle drawn around ``bbox``.

    Args:
        img: 2-D array (a single slice). It is not modified.
        bbox: ``(low, high)`` index pairs ``[row, col]``. Edges are drawn at rows
            ``low[0]`` / ``high[0]`` and columns ``low[1]`` / ``high[1]``,
            inclusive, after clipping to the image bounds.

    The marker intensity is the image maximum when the pixel at the box centre
    lies below the midpoint ``(min + max) / 2`` of the slice's intensity range,
    and the image minimum otherwise, so the outline contrasts with its centre.
    """
    low, high = bbox
    low = np.asarray(low).astype(int)
    high = np.asarray(high).astype(int)
    n_rows, n_cols = img.shape[0], img.shape[1]
    new_img = np.copy(img)

    vmin, vmax = img.min(), img.max()
    center = ((low + high) / 2).astype(int)
    cr = int(np.clip(center[0], 0, n_rows - 1))
    cc = int(np.clip(center[1], 0, n_cols - 1))
    value = vmax if img[cr, cc] < (vmin + vmax) / 2 else vmin

    # Clip so boxes touching the border neither raise IndexError nor wrap
    # around through negative indices.
    r0, r1 = (int(v) for v in np.clip([low[0], high[0]], 0, n_rows - 1))
    c0, c1 = (int(v) for v in np.clip([low[1], high[1]], 0, n_cols - 1))
    # "+ 1" makes both far edges inclusive so the (r1, c1) corner is drawn.
    new_img[r0:r1 + 1, c0] = value
    new_img[r0:r1 + 1, c1] = value
    new_img[r0, c0:c1 + 1] = value
    new_img[r1, c0:c1 + 1] = value
    return new_img

def convert_loc(coord, origin, space):
    """Convert a world-space position to voxel indices, returned in (z, y, x) order.

    ``coord`` supplies at least three entries (x, y, z), possibly as strings;
    extra entries are ignored. ``origin`` and ``space`` are the image origin
    and spacing in (x, y, z). The result is a float array rounded to whole
    voxels and reversed to (z, y, x).
    """
    xyz = np.array([float(c) for c in coord[:3]])
    voxel_xyz = np.round((xyz - origin) / space)
    return voxel_xyz[::-1]

def convert_radius(coord, space):
    """Per-axis radius in voxels, returned in (z, y, x) order.

    The last entry of ``coord`` is treated as a diameter in world units (it is
    halved). It is divided by the (x, y, z) ``space``, rounded, and reversed.
    """
    diameter = float(coord[-1])
    radius_xyz = np.round(diameter / 2 / space)
    return radius_xyz[::-1]

def convert_bounding_box(coord, origin, space):
    """Axis-aligned voxel box ``(low, high)`` around an annotated position.

    Both corners are (z, y, x) float arrays: the rounded voxel centre from
    ``convert_loc`` minus / plus the per-axis voxel radius from
    ``convert_radius``, rounded again.
    """
    center = convert_loc(coord, origin, space)
    radius = convert_radius(coord, space)
    return np.round(center - radius), np.round(center + radius)

def mark_bbox(img, bbox):
    """Return a copy of the volume ``img`` with ``bbox`` outlined on its slices.

    ``bbox`` is ``(low, high)`` in (z, y, x) voxel coordinates. Every slice
    ``z`` from ``low[0]`` to ``high[0]`` inclusive is replaced by
    ``add_marker`` applied to the in-plane (y, x) part of the box. The input
    volume is not modified.
    """
    lo = bbox[0].astype(int)
    hi = bbox[1].astype(int)
    plane_box = (lo[1:], hi[1:])
    out = np.copy(img)
    z = lo[0]
    while z <= hi[0]:
        out[z] = add_marker(img[z], plane_box)
        z += 1
    return out

def export_as_gif(filename, image_array, frames_per_second=10, rubber_band=False):
    """Write a stack of 2-D frames to a looping animated GIF.

    Args:
        filename: output path.
        image_array: frames stacked along axis 0. Intensities are min-max
            scaled to 0-255 over the whole stack.
        frames_per_second: playback rate; each frame lasts ``1000 // fps`` ms.
        rubber_band: if True, append the interior frames in reverse order so
            the animation plays forward then backward without showing either
            end frame twice at the turn-around.

    Raises:
        ValueError: if ``image_array`` contains no frames.
    """
    if len(image_array) == 0:
        raise ValueError("export_as_gif: image_array has no frames")
    lo, hi = image_array.min(), image_array.max()
    span = hi - lo
    if span > 0:
        scaled = (image_array - lo) / span
    else:
        # Constant input: avoid 0/0 (NaN frames) and write an all-black animation.
        scaled = image_array * 0.0
    images = [Image.fromarray(np.uint8(frame * 255)) for frame in scaled]
    if rubber_band:
        # Frames n-2 .. 1: the loop then continues naturally back to frame 0.
        images += images[1:-1][::-1]
    images[0].save(
        filename,
        save_all=True,
        append_images=images[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )

In [4]:
from torch.utils.data import Dataset
import torch

def compute_stats(dataset):
    """Compute the global voxel mean and standard deviation over a dataset.

    Reads every file in ``dataset.filenames`` with ``read_image`` twice: once
    to accumulate the mean, once to accumulate squared deviations from it
    (population variance). The two-pass form avoids the cancellation error of
    the one-pass ``E[x^2] - E[x]^2`` formula.

    Returns:
        ``(mean, std)``.

    Raises:
        ValueError: if the dataset contains no voxels.
    """
    total = 0
    count = 0
    for fn in dataset.filenames:
        image, _, _ = read_image(fn)
        total += np.sum(image)
        count += image.size
    if count == 0:
        raise ValueError("compute_stats: dataset contains no voxels")
    mean = total / count

    sq_total = 0
    for fn in dataset.filenames:
        image, _, _ = read_image(fn)
        sq_total += np.sum((image - mean) ** 2)
    std = np.sqrt(sq_total / count)
    return mean, std

def getUID(filename):
    """Return the scan's UID: the file's base name without its 4-character extension (".mhd")."""
    base = os.path.basename(filename)
    return base[:len(base) - 4] if len(base) > 4 else ""

def random_crop_3D(img, crop_size):
    """Return a uniformly random sub-volume of ``img`` with shape ``crop_size``.

    Every valid start position is possible, including ``size - crop_size``.
    A crop equal to the image size therefore returns the whole volume. The
    leading ``len(crop_size)`` axes of ``img`` are cropped.

    Raises:
        ValueError: if ``crop_size`` exceeds ``img.shape`` on any axis.
    """
    crop_size = np.asarray(crop_size, dtype=int)
    size = np.asarray(img.shape[:len(crop_size)])
    slack = size - crop_size
    if np.any(slack < 0):
        raise ValueError(
            f"crop_size {crop_size.tolist()} exceeds image shape {list(img.shape)}"
        )
    # randint's upper bound is exclusive, hence "+ 1" to reach the last start.
    starts = [np.random.randint(0, s + 1) for s in slack]
    return img[tuple(slice(st, st + c) for st, c in zip(starts, crop_size))]

def random_crop_around_3D(img, bbox, crop_size, margin=(5, 20, 20)):
    """Randomly crop ``img`` so that ``bbox`` plus ``margin`` lies inside the crop.

    Args:
        img: volume to crop.
        bbox: ``(low, high)`` voxel coordinates of the object, one entry per axis.
        crop_size: crop shape, one entry per axis.
        margin: minimum voxels kept between the box and each crop face.

    Returns:
        ``(crop, offset)``, where ``offset`` is the list of crop start indices.

    Every offset satisfying the constraints is equally likely (both bounds
    inclusive). On an axis where no offset satisfies them, the crop is centred
    on the box and clamped to the image instead. This happens when the box plus
    margin is larger than the crop or lies too close to the image edge.

    Raises:
        ValueError: if ``crop_size`` exceeds the image shape.
    """
    crop_size = np.asarray(crop_size, dtype=int)
    im_size = np.asarray(img.shape[:len(crop_size)])
    blow = np.asarray(bbox[0]).astype(int)
    bhigh = np.asarray(bbox[1]).astype(int)
    margin = np.asarray(margin, dtype=int)

    max_start = im_size - crop_size
    if np.any(max_start < 0):
        raise ValueError(
            f"crop_size {crop_size.tolist()} exceeds image shape {list(img.shape)}"
        )

    # Inclusive range of start offsets keeping [blow - margin, bhigh + margin]
    # inside the crop and the crop inside the image.
    low = np.maximum(bhigh + margin - crop_size, 0)
    high = np.minimum(blow - margin, max_start)

    offset = []
    for lo, hi, b0, b1, c, m in zip(low, high, blow, bhigh, crop_size, max_start):
        if lo <= hi:
            # randint's upper bound is exclusive; "+ 1" makes `hi` reachable.
            offset.append(int(np.random.randint(lo, hi + 1)))
        else:
            start = (b0 + b1) // 2 - c // 2
            offset.append(int(np.clip(start, 0, m)))

    crop = img[tuple(slice(o, o + c) for o, c in zip(offset, crop_size))]
    return crop, offset

def random_flip(img, axis):
    """Mirror ``img`` along ``axis`` with probability 0.5, otherwise return it as is.

    Draws exactly one ``np.random.rand()`` sample per call. A flipped result is
    the view returned by ``np.flip``, not a copy.
    """
    flip = np.random.rand() < 0.5
    return np.flip(img, axis=axis) if flip else img

class LUNA16_Dataset(Dataset):
    """Random-crop samples from the LUNA16 nodule dataset.

    https://luna16.grand-challenge.org/

    Each item is a dict with:
        label: ``tensor(1)`` if the crop was taken around an annotated nodule,
            else ``tensor(0)``.
        bbox: float32 tensor ``[low_z, low_y, low_x, high_z, high_y, high_x]``
            with each coordinate divided by ``crop_size`` (all zeros for
            negatives). It is mirrored together with the image on every random flip.
        bbox_imgs: only for positives when ``return_bbox`` is set; the
            un-normalised (flipped) crop with the box outlined, for debugging.
        pixel_values: float32 tensor ``(1, *crop_size)`` normalised with the
            class-level ``mean`` / ``std``.
    """
    # Global intensity statistics used for normalisation (presumably produced
    # by compute_stats; they are not recomputed here).
    mean = -718.0491779355748
    std = 889.6629126452339

    def __init__(self, split=None, data_dir=".", crop_size=(40, 128, 128), patch_size=(4, 16, 16),
                 return_bbox=False, positive_prob=1.0):
        """
        Args:
            split: iterable of subset indices to use (default: all ten).
            data_dir: directory holding ``annotations.csv`` and ``subset*/``.
            crop_size: (z, y, x) shape of each returned crop.
            patch_size: (z, y, x) patch size; stored as ``self.patch_size`` but
                not used by this class.
            return_bbox: add a ``bbox_imgs`` debugging volume to positive samples.
            positive_prob: probability of cropping around a nodule when the scan
                has annotations. Scans without annotations always yield random
                negative crops. The default 1.0 keeps the original behaviour.
        """
        annotations_csv = read_csv(f"{data_dir}/annotations.csv")[1:]  # drop header row
        data_subsets = survey_dataset(data_dir)
        if split is None:
            split = np.arange(10)  # all subsets
        self.filenames = []
        for s in split:
            self.filenames += data_subsets[s]
        # scan uid -> list of annotation rows (columns after the uid); scans
        # without annotations map to an empty list.
        self.annotations = {getUID(fn): [] for fn in self.filenames}
        for entry in annotations_csv:
            self.annotations.setdefault(entry[0], []).append(entry[1:])

        self.crop_size = np.array(crop_size)
        self.patch_size = np.array(patch_size)
        self.return_bbox = return_bbox
        self.positive_prob = positive_prob

    @staticmethod
    def _flip_bbox(bbox, axis, size):
        """Mirror a ``(low, high)`` voxel box along ``axis`` of a volume of extent ``size``.

        ``np.flip`` maps voxel index ``i`` to ``size - 1 - i``, so on that axis
        the low and high corners swap roles.
        """
        low = np.array(bbox[0], dtype=float)
        high = np.array(bbox[1], dtype=float)
        new_low, new_high = low.copy(), high.copy()
        new_low[axis] = size - 1 - high[axis]
        new_high[axis] = size - 1 - low[axis]
        return new_low, new_high

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        fn = self.filenames[idx]
        image, origin, space = read_image(fn)
        coords = self.annotations[getUID(fn)]

        result = dict()
        offset_bbox = None  # box in crop coordinates; None for negatives
        if len(coords) > 0 and np.random.rand() < self.positive_prob:
            # crop with a nodule
            coord = coords[np.random.randint(len(coords))]
            bbox = convert_bounding_box(coord, origin, space)
            cropped_img, offset = random_crop_around_3D(image, bbox, self.crop_size)
            offset_bbox = bbox[0] - offset, bbox[1] - offset
            result["label"] = torch.tensor(1)
        else:
            # random crop
            cropped_img = random_crop_3D(image, self.crop_size)
            result["label"] = torch.tensor(0)
            result["bbox"] = torch.zeros(6)

        # Random flips. The box must be mirrored together with the image,
        # otherwise the regression target no longer matches the pixels.
        pixel_values = cropped_img
        for axis in range(3):
            if np.random.rand() < 0.5:
                pixel_values = np.flip(pixel_values, axis=axis)
                if offset_bbox is not None:
                    offset_bbox = self._flip_bbox(offset_bbox, axis, self.crop_size[axis])

        if offset_bbox is not None:
            target = np.concatenate([offset_bbox[0] / self.crop_size, offset_bbox[1] / self.crop_size])
            result["bbox"] = torch.tensor(target).to(torch.float32)
            # for debugging
            if self.return_bbox:
                result["bbox_imgs"] = mark_bbox(pixel_values, offset_bbox)

        # normalize
        pixel_values = (pixel_values - LUNA16_Dataset.mean) / LUNA16_Dataset.std
        # to tensor; copy() because np.flip yields negative-stride views
        pixel_values = torch.tensor(pixel_values.copy()).to(torch.float32)
        result["pixel_values"] = pixel_values.unsqueeze(0)  # add channel dim
        return result


# dataset = LUNA16_Dataset(data_dir=datadir)
# re = dataset[56]
# image = re["pixel_values"]
# export_as_gif("ct.gif",image[0])
# if "bbox_imgs" in re:
#     export_as_gif("ct_marked.gif",re["bbox_imgs"])

In [5]:
from torch.utils.data import DataLoader

def collate_fn(examples):
    """Batch dataset items into the dict consumed by the model and Trainer.

    Stacks each example's ``pixel_values``, ``label`` and ``bbox`` tensors along
    a new leading batch dimension. The items' singular ``label`` key becomes
    ``labels`` in the batch, matching the argument name of ``VitDet.forward``.
    """
    key_map = (("pixel_values", "pixel_values"), ("labels", "label"), ("bbox", "bbox"))
    return {
        out_key: torch.stack([ex[in_key] for ex in examples])
        for out_key, in_key in key_map
    }


# dataloader = DataLoader(dataset, collate_fn=collate_fn,batch_size=4)
# x=next(iter(dataloader))
# x["pixel_values"].shape,x["labels"].shape,x["bbox"].shape

# Model

In [6]:
from transformers import ViTModel, ViTConfig, PreTrainedModel
from transformers.utils import ModelOutput
from transformers.models.vit.modeling_vit import ViTPooler, ViTEncoder
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import torch.nn as nn

class Vit3DEmbeddings(nn.Module):
    """Patch and position embeddings for a 3-D ViT.

    A ``Conv3d`` whose kernel and stride both equal ``config.patch_size`` cuts
    a ``(batch, num_channels, depth, height, width)`` volume into
    non-overlapping patches. A learnable CLS token is prepended and learnable
    position embeddings are added. Output shape:
    ``(batch, num_patches + 1, config.hidden_size)``.
    """
    def __init__(self, config):
        super().__init__()
        # Conv3d with stride == kernel floors each axis independently; count
        # patches the same way so the position table matches for any size.
        grid = np.array(config.image_size) // np.array(config.patch_size)
        num_patches = int(np.prod(grid))
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        self.projection = nn.Conv3d(config.num_channels, config.hidden_size,
                                    kernel_size=config.patch_size, stride=config.patch_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, pixel_values):
        batch_size = pixel_values.shape[0]
        # (B, hidden, d', h', w') -> (B, d'*h'*w', hidden)
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
        embeddings = embeddings + self.position_embeddings
        return self.dropout(embeddings)

class VitDet(PreTrainedModel):
    """3-D ViT with a classification head and a 6-value box regression head.

    ``forward`` returns a ``ModelOutput`` with ``loss``, ``logits``, ``bbox``,
    ``last_hidden_state`` and ``pooler_output``. ``loss`` is set only when both
    ``labels`` and ``bbox`` are given; ``bbox`` holds the regressed values.
    """
    def __init__(self, config, add_pooling_layer = True):
        super().__init__(config)
        self.embeddings = Vit3DEmbeddings(config)

        self.encoder = ViTEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = ViTPooler(config) if add_pooling_layer else None
        self.classification_head = nn.Linear(config.hidden_size, config.num_labels)
        self.bbox_head = nn.Linear(config.hidden_size, 6)

        self.config = config

    def forward(self, pixel_values, labels=None, bbox=None):
        embeddings = self.embeddings(pixel_values)
        encoder_outputs = self.encoder(embeddings)
        sequence_output = self.layernorm(encoder_outputs[0])
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        # Without a pooler, feed the heads the raw CLS token instead of None.
        head_input = pooled_output if pooled_output is not None else sequence_output[:, 0]
        logits = self.classification_head(head_input)
        bbox_pred = self.bbox_head(head_input)

        loss = None
        if labels is not None and bbox is not None:
            if self.config.num_labels == 1:
                loss = BCEWithLogitsLoss()(logits.view(-1), labels.float())
            else:
                loss = CrossEntropyLoss()(logits, labels)
            # Box regression applies to positive samples only. Average over
            # their coordinates, not the whole batch, so the box-loss weight
            # does not depend on the fraction of positives in the batch.
            mask = labels.view(-1, 1).bool().to(bbox_pred.dtype)
            sq_err = MSELoss(reduction="none")(bbox_pred, bbox) * mask
            n_terms = (mask.sum() * bbox_pred.shape[-1]).clamp(min=1)
            loss = loss + sq_err.sum() / n_terms

        return ModelOutput(
            loss=loss,
            logits=logits,
            bbox=bbox_pred,
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )

In [7]:
# Small 3-D ViT: 4 encoder layers, 6 attention heads over a 384-wide hidden state.
# image_size / patch_size are (z, y, x) and equal the LUNA16_Dataset defaults
# (crop_size [40,128,128], patch_size [4,16,16]), giving a 10 x 8 x 8 = 640-patch grid.
# num_labels=1 selects the BCEWithLogitsLoss branch in VitDet.forward.
config = ViTConfig(hidden_size=384,
                   num_hidden_layers=4,
                   num_attention_heads=6,
                   patch_size=[4,16,16],
                   image_size=[40,128,128],
                   num_channels=1,
                   num_labels=1)

model = VitDet(config)

In [8]:
class DummyDataset(Dataset):
    """Ten random samples with the keys and shapes LUNA16_Dataset items use.

    Each item has ``pixel_values`` of shape ``(1, 40, 128, 128)`` drawn from a
    standard normal, a random 0-d ``label`` in {0, 1} and a random 6-value
    ``bbox``. Values are redrawn on every access; intended for smoke tests.
    """
    def __init__(self):
        super().__init__()

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        result = dict()
        result["pixel_values"] = torch.randn(1, 40, 128, 128)
        result["label"] = torch.randint(0, 2, [])
        result["bbox"] = torch.randn(6)
        return result
# Smoke-test data: random tensors stand in for the real LUNA16 volumes.
train_dataset = DummyDataset()
valid_dataset = DummyDataset()

# Build one batch to check that collate_fn yields the keys VitDet.forward accepts.
dataloader = DataLoader(train_dataset, collate_fn=collate_fn,batch_size=4)
x=next(iter(dataloader))
# model(**x)

# Train

In [94]:
from transformers import TrainingArguments, Trainer
from sklearn.metrics import f1_score

def iou_3d(bbox_pred, bbox):
    """Mean 3-D intersection-over-union between predicted and reference boxes.

    Both inputs have shape ``(N, 6)`` (a single box of shape ``(6,)`` is also
    accepted), laid out as ``[low_0, low_1, low_2, high_0, high_1, high_2]``.
    A box whose high corner is below its low corner on some axis counts as
    empty (zero volume) rather than producing a negative volume. Pairs whose
    union is empty score 0 instead of dividing by zero.
    """
    bbox_pred = np.atleast_2d(np.asarray(bbox_pred, dtype=float))
    bbox = np.atleast_2d(np.asarray(bbox, dtype=float))
    inter_low = np.maximum(bbox_pred[:, :3], bbox[:, :3])
    inter_high = np.minimum(bbox_pred[:, 3:], bbox[:, 3:])
    inter_vol = np.prod(np.clip(inter_high - inter_low, 0, None), axis=-1)
    pred_vol = np.prod(np.clip(bbox_pred[:, 3:] - bbox_pred[:, :3], 0, None), axis=-1)
    ref_vol = np.prod(np.clip(bbox[:, 3:] - bbox[:, :3], 0, None), axis=-1)
    union = pred_vol + ref_vol - inter_vol
    iou = np.divide(inter_vol, union, out=np.zeros_like(inter_vol), where=union > 0)
    return iou.mean()

def _logits_to_preds(logits):
    """Convert raw classifier logits into integer class predictions.

    A single-logit output (``num_labels == 1``, trained with
    ``BCEWithLogitsLoss``) is thresholded at 0, which equals sigmoid
    probability 0.5. Multi-logit outputs take the arg-max class.
    """
    logits = np.asarray(logits)
    if logits.ndim == 2 and logits.shape[1] > 1:
        return logits.argmax(axis=-1)
    return (logits.reshape(-1) > 0).astype(int)


def compute_metrics(eval_pred):
    """Trainer metric hook: classification F1 and mean box IoU on positives.

    ``eval_pred`` is ``(predictions, groundtruth)``:
    - ``predictions[0]`` are the logits and ``predictions[1]`` the predicted boxes.
    - ``groundtruth`` is ``(labels, bbox)``, following the ``label_names`` order.

    IoU is computed only over samples whose label is 1 and is reported as 1.0
    when there are none.
    """
    predictions, groundtruth = eval_pred
    logits, bbox_pred_all = predictions[0], predictions[1]
    labels = np.asarray(groundtruth[0]).reshape(-1)
    bbox_all = groundtruth[1]

    preds = _logits_to_preds(logits)
    f1 = f1_score(labels, preds)

    mask = labels.astype(bool)
    bbox_pred = bbox_pred_all[mask]
    bbox = bbox_all[mask]
    iou = iou_3d(bbox_pred, bbox) if bbox.shape[0] > 0 else 1.0
    return dict(f1=f1, iou=iou)

# label_names marks both "labels" and "bbox" as targets, so they are gathered as
# label_ids for compute_metrics in this order. remove_unused_columns=False keeps
# every key produced by collate_fn. With eval_steps unset, evaluation follows
# logging_steps=1. Together with save_steps=1, this evaluates, logs and saves a
# checkpoint after every optimizer step: fine for the dummy smoke test,
# expensive for real training.
args = TrainingArguments(
        "luna-train",
        save_strategy="steps",
        evaluation_strategy="steps",
        learning_rate=2e-5,
        per_device_train_batch_size=10,
        per_device_eval_batch_size=10,
        num_train_epochs=300,
        weight_decay=0.01,
        logging_steps=1,
        save_steps=1,
        save_total_limit=5,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        logging_dir='logs',
        label_names=["labels","bbox"],
        remove_unused_columns=False,
    )

trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

# DataLoaderConfiguration is not defined anywhere else in this notebook and
# must be imported before use.
# NOTE(review): the object is never passed to TrainingArguments or Trainer, so
# it currently has no effect on training.
from accelerate import DataLoaderConfiguration

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [100]:
# trainer.train()