## Segment Training Images for Classification

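Run `create_COVIDx_v2_RICORD.ipynb` first to build the COVID-Net-style dataset (`train`, `test` and `evaluate` folders, each with a matching `.txt` index file) that this notebook reads. The notebook segments every image with a pre-trained segmentation model and hard-links each image into a per-class folder under the output directory. It also saves a mask/bounding-box overlay per image for inspection and writes the rescaled bounding boxes and confidence scores to `regions.csv`. **Note:** the output directory is deleted at the start of each run.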

In [1]:
from fastai import *
from fastai.vision.all import *
from fastai.vision.core import *
from fastai.vision.models import resnet34, resnet18
from fastai.data.transforms import *
from pathlib import Path
import numpy as np
import cv2
import os
import gc
import sys
import shutil
from sklearn.model_selection import train_test_split
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from segmentation_data import *
import torch
from PIL import Image
import pandas as pd

# Trained lung-segmentation model and the COVIDx-style dataset it is applied to.
model_file = "/data/output/NLM_Shenzhen_Montgomery-06/segmentation-model.pkl"
dataset_dir = "/data/datasets/extended_v2"
# Each split has an image sub-folder and a "<split>.txt" index file inside dataset_dir.
dataset_prefixes = ["train", "test", "evaluate"]

# Label as written in the index files -> name of the output class folder.
class_map = dict([
    ('COVID-19', 'COVID-19'),
    ('normal', 'Normal'),
    ('pneumonia', 'Pneumonia'),
])

# Destination for linked images, overlay figures and regions.csv.
# NOTE: anything already in this folder is deleted so every run starts from scratch.
out_dir = "/data/datasets/extended_v2_Masked"
shutil.rmtree(out_dir, ignore_errors=True)

In [2]:
def save_mask(img, pred, mask, bounds, confidence, saveas):
    fig, ax = plt.subplots(figsize=(4, 4))
    
    img = np.transpose(img, (1, 2, 0))
    img = (img - torch.min(img)) / (torch.max(img) - torch.min(img))
    
    ax.imshow(img)
    rect = patches.Rectangle((bounds[0], bounds[1]), bounds[2] - bounds[0], bounds[3] - bounds[1], linewidth=1, edgecolor='yellow', facecolor='none')
    ax.add_patch(rect)
    ax.imshow(mask, cmap='coolwarm', alpha=0.3)
    fig.suptitle(f"[{confidence:.4f}]")
    plt.savefig(saveas)
    plt.close(fig)
    
def show_mask(img, pred, mask, bounds, confidence):
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    
    img = np.transpose(img, (1, 2, 0))
    img = (img - torch.min(img)) / (torch.max(img) - torch.min(img))
    
    ax[0].imshow(img)
    ax[1].imshow(mask, cmap='gray')
    ax[2].imshow(img)
    rect = patches.Rectangle((bounds[0], bounds[1]), bounds[2] - bounds[0], bounds[3] - bounds[1], linewidth=1, edgecolor='yellow', facecolor='none')
    ax[2].add_patch(rect)
    ax[2].imshow(mask, cmap='coolwarm', alpha=0.3)
    fig.suptitle(f"[{confidence:.4f}]")
    plt.show()
    plt.close(fig)

def add_dir(dataset_dir, prefix, learn):
    img_dir = os.path.join(dataset_dir, prefix)
    
    with open(os.path.join(dataset_dir, f"{prefix}.txt"), "r") as dsfile:
        ds = [l.split() for l in dsfile.readlines()]
    
    regions = []
    
    for k, v in class_map.items():
        os.makedirs(os.path.join(out_dir, prefix, v), exist_ok=True)
        os.makedirs(os.path.join(out_dir, "segmentation", prefix, v), exist_ok=True)

    step_size = 256
#     n_files = 64
    n_files = len(ds)
    
    for i in range(0, n_files, step_size):
        i_end = min(i+step_size, n_files)
        ds_slice = ds[i:i_end]
        xray_files = [os.path.join(img_dir, l[1]) for l in ds_slice]
        dl = learn.dls.test_dl(xray_files, bs=64)
        inputs, preds, _, masks = learn.get_preds(dl=dl, with_input=True, with_decoded=True)

        for in_file, img, pred, mask, line in zip(xray_files, inputs, preds, masks, ds_slice):
            out_class = class_map[line[2]]
            
            out_file = os.path.join(out_dir, prefix, out_class, line[1])
            mask_file = os.path.join(out_dir, "segmentation", prefix, out_class, line[1])

            bounds = find_bounds(mask, padding=0.05)
            confidence = compute_confidence_score(pred, mask)
#             show_mask(img, pred, mask, bounds, confidence)
            save_mask(img, pred, mask, bounds, confidence, mask_file)

            os.link(in_file, out_file)
            im = Image.open(out_file)
            width, height = im.size

            real_bounds = (
                int(bounds[0] / mask.shape[1] * width),
                int(bounds[1] / mask.shape[0] * height),
                int((bounds[2] - bounds[0]) / mask.shape[1] * width),
                int((bounds[3] - bounds[1]) / mask.shape[0] * height)
            )
            regions.append((out_file, *real_bounds, confidence, out_class))
        
        print(f"{i_end}/{n_files}")
        gc.collect()
    
    return regions

In [3]:
# Load the trained segmentation model (cpu=False: not forced onto the CPU) and freeze it;
# it is only used for inference below.
learn = load_learner(model_file, cpu=False)
learn.freeze()

# Segment every split and gather one (file, x, y, width, height, confidence, class)
# tuple per image across all splits.
regions = []
for prefix in dataset_prefixes:
    print(prefix)
    regions.extend(add_dir(dataset_dir, prefix, learn))

train


256/14552


512/14552


768/14552


1024/14552


1280/14552


1536/14552


1792/14552


2048/14552


2304/14552


2560/14552


2816/14552


3072/14552


3328/14552


3584/14552


3840/14552


4096/14552


4352/14552


4608/14552


4864/14552


5120/14552


5376/14552


5632/14552


5888/14552


6144/14552


6400/14552


6656/14552


6912/14552


7168/14552


7424/14552


7680/14552


7936/14552


8192/14552


8448/14552


8704/14552


8960/14552


9216/14552


9472/14552


9728/14552


9984/14552


10240/14552


10496/14552


10752/14552


11008/14552


11264/14552


11520/14552


11776/14552


12032/14552


12288/14552


12544/14552


12800/14552


13056/14552


13312/14552


13568/14552


13824/14552


14080/14552


14336/14552


14552/14552
test


256/1579


512/1579


768/1579


1024/1579


1280/1579


1536/1579


1579/1579
evaluate


147/147


In [4]:
# One row per image: bounding box in original-image pixels plus segmentation confidence.
region_columns = ["file", "x", "y", "width", "height", "confidence", "class"]
df = pd.DataFrame(regions, columns=region_columns)
df.to_csv(os.path.join(out_dir, "regions.csv"), index=False)

In [5]:
# Number of images whose segmentation confidence is below 0.9.
(df["confidence"] < 0.9).sum()

657