# Label Extractor

With this label extractor, we could label (part of) blood cell and impurity with morphological features. Filter parameters should be tuned to make the label set relatively clean. These samples are simple, but they do contain visional features that could be extracted by deep CNN network.  
Then we could easily label the rest samples by hand, by pointing out the cell, impurity and mixture.  

Then, use the labels to train deep CNN classifer network.  
The classifier could be used as an encoder, to map image into embedding space.  

Then we train a GMM (with the embeddings), and aim for the following samples:  
1. samples with low probability  
2. samples with relatively high probability for multiple clusters  

In both situations, the focused samples are hard. We could label them and boost the classifier by finetuning.  

However, we may find that some classes are in the same cluster. That means these classes are indistinct. It is unlikely to happen, since we could differentiate (at least part of) them even with morphological features.


In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import sys
import os
import pickle
import argparse
import itertools
from datetime import datetime
import gc

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm
import multiprocessing as mp

from datasets.simple import *
from resnet import *
from transforms import *
from plot import *

In [2]:
BINARY_THRESHOLD = 225

NOISE_THRESHOLD = 200
OPENING_KERNEL_SIZE = 5
HUGE_THRESHOLD = 30000

CIRCULARITY_THRESHOLD = 0.65

CELL_THRESHOLD = 7500

DBSCAN_EPS = 200
DBSCAN_SAMPLES = 5

In [3]:
SAMPLE_COUNT = -1

data_root = '../../data/chromosome'

sample_path = os.path.join(data_root, 'raw/sample')
# chunk_path = os.path.join(data_root, 'neg_chunk')

chunk_path = '/media/ssd-ext4/neg-chunk'

if not os.path.exists(chunk_path):
    os.mkdir(chunk_path)
    
image_list = os.listdir(sample_path)[:SAMPLE_COUNT]


In [4]:
class LabelExtractor:
    def __init__(
        self,
        binary_threshold=225,
        noise_threshold=200,
        opening_kernel_size=5,
        huge_threshold=30000,
        circularity_threshold=0.65,
        cell_threshold=7500,
        dbscan_eps=200,
        dbscan_samples=5
    ):
        self.binary_threshold = binary_threshold
        self.noise_threshold = noise_threshold
        self.opening_kernel_size = opening_kernel_size
        self.huge_threshold = huge_threshold
        self.circularity_threshold = circularity_threshold
        self.cell_threshold = cell_threshold
        self.dbscan_eps = dbscan_eps
        self.dbscan_samples = dbscan_samples
        
    def __call__(self, img):
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        b_img = binaryzation(gray, self.binary_threshold)
        o_img = opening(b_img, self.opening_kernel_size)
        clear_indexes, areas, contours = denoise(o_img, self.noise_threshold)
        indexes = huge_filter(clear_indexes, areas, huge_threshold=self.huge_threshold)
        indexes = cell_filter(indexes, areas, contours, self.circularity_threshold, self.cell_threshold)
        indexes = dbscan_filter(indexes, contours, self.dbscan_eps, self.dbscan_samples, [gray.shape[1], gray.shape[0]])
        
        label_indexes = []
        
        for clear_index in clear_indexes:
            if clear_index not in indexes:
                label_indexes.append(clear_index)       
        
        return gray, label_indexes, contours
    

In [5]:
extractor = LabelExtractor()

def extract_neg_chunks(raw_filename):
    img = cv2.imread(os.path.join(sample_path, raw_filename))
    gray, indexes, contours = extractor(img)
    
    arr = raw_filename.rsplit('.', 1)
    file_id = arr[0]
    
    neg_list = []
    
    for i in indexes:
        contour = contours[i]

        # create mask
        mask = np.zeros_like(gray)
        cv2.drawContours(mask, contours, i, 255, -1)

        # apply mask
        new_img = np.full_like(gray, 255)
        np.copyto(new_img, gray, where=(mask>127))

        # create bbox & crop roi
        x, y, w, h = cv2.boundingRect(contour)
        roi = new_img[y:y+h, x:x+w]

        # write to file
        filename = '{}_{}.jpg'.format(os.path.join(chunk_path, file_id), i)
        cv2.imwrite(filename, roi)
        
        neg_list.append(filename)
        
    return neg_list

In [6]:
pool = mp.Pool()

neg_lists = list(tqdm(
    pool.imap_unordered(extract_neg_chunks, image_list),
    total=len(image_list),
    file=sys.stdout
))

neg_record = {}
for image_filename, neg_list in zip(image_list, neg_lists):
    neg_record[image_filename] = neg_list
    
torch.save(neg_record, 'neg_record.pth')

100%|██████████| 83268/83268 [09:43<00:00, 142.75it/s]


In [22]:
# serial process, keep for debugging

neg_record = {}

for image_filename in image_list:
    neg_list = extract_neg_chunks(extractor, sample_path, chunk_path, image_filename)
    neg_record[image_filename] = neg_list
    
torch.save(neg_record, 'neg_record.pth')