In [1]:
import numpy as np
import pandas as pd
import pydicom
from matplotlib import pyplot as plt
import os
from mask_functions import rle2mask
from google.cloud import storage

In [2]:
# GCS bucket that holds the chest X-ray DICOMs.
gcs_client = storage.Client()
bucket = gcs_client.get_bucket("pneumothorax_chest_x-rays")

In [3]:
blobs = [blob for blob in bucket.list_blobs(prefix='train/')]

In [4]:
# Sanity check: number of objects listed under the train/ prefix.
len(blobs)

10675

In [5]:
blobs_2 = [blob for blob in bucket.list_blobs(prefix='test/')]

In [None]:
# Sanity check: number of objects listed under the test/ prefix.
len(blobs_2)

1372

In [None]:
# Stage-2 annotation table: one row per (ImageId, RLE) pair, '-1' = no region.
train_csv_path = './stage_2/stage_2_train.csv'
df = pd.read_csv(train_csv_path)
df.head()

Unnamed: 0,ImageId,EncodedPixels
0,1.2.276.0.7230010.3.1.4.8323329.6904.151787520...,-1
1,1.2.276.0.7230010.3.1.4.8323329.13666.15178752...,557374 2 1015 8 1009 14 1002 20 997 26 990 32 ...
2,1.2.276.0.7230010.3.1.4.8323329.11028.15178752...,-1
3,1.2.276.0.7230010.3.1.4.8323329.10366.15178752...,514175 10 1008 29 994 30 993 32 991 33 990 34 ...
4,1.2.276.0.7230010.3.1.4.8323329.10016.15178752...,592184 33 976 58 956 73 941 88 926 102 917 109...


In [None]:
import torch
import torch.utils.data
import collections
from tqdm import tqdm_notebook
from PIL import Image

class TorchDataset(torch.utils.data.Dataset):
    """Mask R-CNN training dataset over annotated pneumothorax chest X-rays.

    Only images whose annotation list contains at least one RLE other than
    '-1' are indexed. Each indexed DICOM blob is downloaded once and converted
    to a PNG; both files are cached on disk and reused on later runs.

    Args:
        blobs: first list of GCS blobs; cached under ./dataset and ./dataset_png.
        blobs_2: second list of GCS blobs; cached under ./testset and ./testset_png.
        df: annotation table with an ``ImageId`` column and the RLE strings in
            column position 1 ('-1' meaning "no annotated region").
        transforms: optional callable ``(img, target) -> (img, target)``.
    """

    # Side length of the square masks decoded from the RLE strings.
    MASK_SIZE = 1024

    def __init__(self, blobs, blobs_2, df, transforms=None):
        self.blobs = blobs
        self.blobs_2 = blobs_2
        self.transforms = transforms
        self.df = df

        # A list rather than a defaultdict: looking up a missing index must not
        # silently insert an empty entry (which also grew len(self)).
        self.image_info = []

        # Group the RLEs by image once; filtering the whole frame per blob is
        # O(num_blobs * num_rows).
        rles_by_id = self._build_rle_index(self.df)

        first = self._index_blobs(self.blobs, rles_by_id, './dataset', './dataset_png')
        print('first portion is', first)
        second = self._index_blobs(self.blobs_2, rles_by_id, './testset', './testset_png')
        print('second portion is', second)
        print('total train on', len(self.image_info))

    @staticmethod
    def _build_rle_index(df):
        """Map each ImageId to its whitespace-stripped RLE strings, in row order."""
        rles_by_id = collections.defaultdict(list)
        for img_id, rle in zip(df['ImageId'], df.iloc[:, 1]):
            rles_by_id[img_id].append(rle.strip())
        return dict(rles_by_id)

    def _index_blobs(self, blobs, rles_by_id, dcm_dir, png_dir):
        """Cache and register every annotated image in ``blobs``; return the count added."""
        os.makedirs(dcm_dir, exist_ok=True)
        os.makedirs(png_dir, exist_ok=True)
        added = 0
        for blob in tqdm_notebook(blobs):
            # Drop the folder prefix and the 4-character extension
            # (presumably '.dcm') to get the ImageId.
            img_id = blob.name.split('/')[-1][:-4]
            rle_list = rles_by_id.get(img_id, [])
            if all(rle == '-1' for rle in rle_list):
                continue  # unannotated (or unknown) image: not used for training

            dcm_path = '{}/{}.dcm'.format(dcm_dir, img_id)
            png_path = '{}/{}.png'.format(png_dir, img_id)
            self._cache_image(blob, dcm_path, png_path)

            self.image_info.append({
                "image_id": img_id,
                "image_path": png_path,
                "annotations": rle_list,
            })
            added += 1
        return added

    @staticmethod
    def _cache_image(blob, dcm_path, png_path):
        """Download the DICOM and convert it to PNG unless already cached.

        Each file is written under a temporary name and renamed only once it is
        complete, so an interrupted or failed run cannot leave a truncated file
        that a later run would mistake for a valid cache entry.
        """
        if not os.path.exists(dcm_path):
            tmp_path = dcm_path + '.part'
            with open(tmp_path, 'wb') as file_obj:
                blob.download_to_file(file_obj)
            os.replace(tmp_path, dcm_path)
        if not os.path.exists(png_path):
            # Decode before creating the output file so a bad DICOM leaves nothing behind.
            img = pydicom.dcmread(dcm_path).pixel_array
            tmp_path = png_path + '.part'
            Image.fromarray(img).save(tmp_path, format='PNG')
            os.replace(tmp_path, png_path)

    def __getitem__(self, idx):
        """Return ``(img, target)`` for the ``idx``-th indexed image.

        ``img`` is the cached PNG loaded as RGB (then passed through
        ``self.transforms`` if set). ``target`` holds ``boxes`` (N, 4) float32
        as (xmin, ymin, xmax, ymax), ``labels`` (N,) int64 all ones, ``masks``
        (N, H, W) uint8, ``image_id``, ``area`` and ``iscrowd``. All annotated
        regions of an image are merged into one binary mask, so N is 1 (or 0
        when nothing decodes).

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``.
        """
        if not 0 <= idx < len(self.image_info):
            raise IndexError('index {} is out of range for dataset of size {}'.format(
                idx, len(self.image_info)))
        info = self.image_info[idx]

        # images
        img = Image.open(info["image_path"]).convert("RGB")

        # masks: union of every annotated region of this image. '-1' marks
        # "no region" and is skipped rather than decoded. The transpose and the
        # >127 threshold presumably match rle2mask's output layout and fill
        # value (255) -- confirm against mask_functions.
        mask = np.zeros((self.MASK_SIZE, self.MASK_SIZE))
        for rle in info["annotations"]:
            if rle != '-1':
                mask += rle2mask(rle, self.MASK_SIZE, self.MASK_SIZE).T
        mask = (mask > 127).astype(np.uint8)

        # Drop the background id 0 by value (not by position), so a mask with
        # no background pixels still yields its object.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]
        # one binary mask per object id
        masks = mask == obj_ids[:, None, None]
        num_objs = len(obj_ids)

        # bounding box of each mask
        boxes = []
        for obj_mask in masks:
            rows, cols = np.where(obj_mask)
            boxes.append([cols.min(), rows.min(), cols.max(), rows.max()])
        # reshape keeps boxes (N, 4) even when N == 0, so the area slice works
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of indexed (annotated) images."""
        return len(self.image_info)

In [None]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

      
def get_instance_segmentation_model(num_classes):
    """Build a COCO-pretrained Mask R-CNN (ResNet-50 FPN) with new heads for ``num_classes``.

    The box predictor and the mask predictor are replaced by newly constructed
    modules sized for ``num_classes``; everything else keeps the pretrained weights.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    roi_heads = model.roi_heads

    box_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    mask_in_channels = roi_heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_dim = 256
    roi_heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, mask_hidden_dim, num_classes)

    return model

In [None]:
from engine import train_one_epoch, evaluate
import utils
import transforms as T


def get_transform(train):
    """Compose the joint image/target transforms for training or evaluation.

    The PIL image is always converted to a tensor. When ``train`` is truthy, a
    random horizontal flip (p=0.5) of the image and its ground truth is added
    as data augmentation.
    """
    steps = [T.ToTensor()]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

In [None]:
# Training dataset (both blob lists) with flip augmentation enabled.
train_transform = get_transform(train=True)
dataset = TorchDataset(blobs, blobs_2, df, transforms=train_transform)
# dataset_test = TorchDataset(blobs, blobs_2, df, get_transform(train=False))

HBox(children=(IntProgress(value=0, max=10675), HTML(value='')))


first portion is 2379


HBox(children=(IntProgress(value=0, max=1372), HTML(value='')))


second portion is 290
total train on 2669


In [None]:
# Number of indexed images, i.e. those with at least one annotation other than '-1'.
len(dataset)

2669

In [None]:
# Inspect the last sample; derived from len() instead of a hard-coded index
# that silently goes stale when the CSV or the blob listing changes.
dataset[len(dataset) - 1]

(tensor([[[0.0157, 0.0118, 0.0118,  ..., 0.0078, 0.0000, 0.0000],
          [0.0196, 0.0196, 0.0157,  ..., 0.0078, 0.0039, 0.0000],
          [0.0196, 0.0196, 0.0196,  ..., 0.0118, 0.0039, 0.0000],
          ...,
          [0.1922, 0.1922, 0.1961,  ..., 0.3882, 0.1255, 0.0000],
          [0.1569, 0.1608, 0.1608,  ..., 0.2902, 0.1255, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
 
         [[0.0157, 0.0118, 0.0118,  ..., 0.0078, 0.0000, 0.0000],
          [0.0196, 0.0196, 0.0157,  ..., 0.0078, 0.0039, 0.0000],
          [0.0196, 0.0196, 0.0196,  ..., 0.0118, 0.0039, 0.0000],
          ...,
          [0.1922, 0.1922, 0.1961,  ..., 0.3882, 0.1255, 0.0000],
          [0.1569, 0.1608, 0.1608,  ..., 0.2902, 0.1255, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
 
         [[0.0157, 0.0118, 0.0118,  ..., 0.0078, 0.0000, 0.0000],
          [0.0196, 0.0196, 0.0157,  ..., 0.0078, 0.0039, 0.0000],
          [0.0196, 0.0196, 0.0196,  ...,

In [None]:
# Disabled: hold out the last 50 shuffled samples for evaluation. Re-enabling it
# also requires building dataset_test, data_loader_test and the evaluate() call,
# which are commented out elsewhere in this notebook.
# # split the dataset in train and test set
# torch.manual_seed(1)
# indices = torch.randperm(len(dataset)).tolist()
# dataset = torch.utils.data.Subset(dataset, indices[:-50])
# dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

In [None]:
# Training loader: shuffled batches of 2, 4 worker processes, and the
# project's utils.collate_fn to batch the (img, target) pairs.
loader_kwargs = dict(batch_size=2, shuffle=True, num_workers=4,
                     collate_fn=utils.collate_fn)
data_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

# data_loader_test = torch.utils.data.DataLoader(
#     dataset_test, batch_size=1, shuffle=False, num_workers=4,
#     collate_fn=utils.collate_fn)

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# two classes: background and pneumothorax
num_classes = 2

# COCO-pretrained Mask R-CNN with box/mask heads resized to num_classes
model = get_instance_segmentation_model(num_classes)
# move model to the right device
model.to(device)

# SGD over the trainable parameters only
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)

# multiply the learning rate by 0.1 every 3 scheduler steps; the training loop
# calls lr_scheduler.step() once per epoch, so this is every 3 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=3,
                                               gamma=0.1)

In [None]:
# Confirm whether training will run on the GPU or fall back to the CPU.
print(device)

cuda


In [None]:
# Train for 10 epochs: train_one_epoch logs every 10 iterations, and the LR
# scheduler is advanced once at the end of each epoch.
num_epochs = 10

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    lr_scheduler.step()
    # Evaluation stays disabled until data_loader_test is re-enabled:
    # evaluate(model, data_loader_test, device=device)

Epoch: [0]  [   0/1335]  eta: 0:56:04  lr: 0.000010  loss: 4.1696 (4.1696)  loss_classifier: 0.8585 (0.8585)  loss_box_reg: 0.0408 (0.0408)  loss_mask: 3.1458 (3.1458)  loss_objectness: 0.1156 (0.1156)  loss_rpn_box_reg: 0.0090 (0.0090)  time: 2.5205  data: 0.7106  max mem: 2247
Epoch: [0]  [  10/1335]  eta: 0:27:48  lr: 0.000060  loss: 2.5744 (2.5019)  loss_classifier: 0.7981 (0.7653)  loss_box_reg: 0.0415 (0.0419)  loss_mask: 1.6522 (1.6283)  loss_objectness: 0.0522 (0.0593)  loss_rpn_box_reg: 0.0061 (0.0071)  time: 1.2593  data: 0.0733  max mem: 2898
Epoch: [0]  [  20/1335]  eta: 0:26:21  lr: 0.000110  loss: 1.7933 (2.0244)  loss_classifier: 0.4700 (0.5317)  loss_box_reg: 0.0338 (0.0432)  loss_mask: 1.0467 (1.3779)  loss_objectness: 0.0530 (0.0640)  loss_rpn_box_reg: 0.0063 (0.0075)  time: 1.1366  data: 0.0101  max mem: 2898
Epoch: [0]  [  30/1335]  eta: 0:25:43  lr: 0.000160  loss: 0.9262 (1.6599)  loss_classifier: 0.1275 (0.3880)  loss_box_reg: 0.0338 (0.0425)  loss_mask: 0.6881 (

In [None]:
# Create the output directory first: a missing ./model would otherwise raise
# FileNotFoundError here and lose the weights from the whole training run.
os.makedirs('./model', exist_ok=True)
with open('./model/model_para_0831.pt', 'wb') as f:
    torch.save(model.state_dict(), f)