In [6]:
import os
import json
import cv2
import openslide
import numpy as np
from skimage.color import rgb2hsv
from skimage.filters import threshold_otsu
from matplotlib import pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import time
from tqdm.auto import tqdm
import pandas as pd
import json
import xml.etree.ElementTree as ET
import copy
import torch
import torch.nn as nn
import math
import os
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from skimage.measure import points_in_poly
from sklearn.model_selection import train_test_split

In [7]:
class WSIDataset(Dataset):
    """Generate dataset."""
    def __init__(self, filepath):
        self.all_data = pd.read_csv(filepath)
        all_normal = self.all_data[self.all_data['label']=='Normal']
        all_tumor = self.all_data[self.all_data['label']=='Tumor']
        X = self.all_data['image'].tolist()
        y = self.all_data['label'].tolist()
        print('total ')
        print('normal vs tumor = %d vs %d'%(len(all_normal), len(all_tumor)))
        X_train, X_test, y_train, y_test = self.stage_split(X, y)
    def stage_split(self, X, y):
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=1)
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=1)  # 0.25 x 0.8 = 0.2
        print('Train detail: ')
        train_normal, train_tumor = self.Obtain_detail(y_train)
        print('normal vs tumor = %d vs %d'%(train_normal, train_tumor))
        print('Valid detail: ')
        train_normal, train_tumor = self.Obtain_detail(y_val)
        print('normal vs tumor = %d vs %d'%(train_normal, train_tumor))
        print('Test detail: ')
        test_normal, test_tumor = self.Obtain_detail(y_test)
        print('normal vs tumor = %d vs %d'%(test_normal, test_tumor))
        return X_train, X_test, y_train, y_test

    def Obtain_detail(self, y):
        normal = tumor = 0
        for i in y:
            if i == 'Normal':
                normal += 1
            else:
                tumor += 1
        return normal, tumor

wsi_data = WSIDataset('/home/congz3414050/HistoGCN/data/5X/all_data.csv')

total 
normal vs tumor = 57942 vs 6886
Train detail: 
normal vs tumor = 46943 vs 5567
Valid detail: 
normal vs tumor = 5246 vs 589
Test detail: 
normal vs tumor = 5753 vs 730


In [10]:
class Formatter(object):
    """
    Format converter e.g. CAMELYON16 to internal json
    """
    def camelyon16xml2json(self, inxml, outjson):
        """
        Convert an annotation of camelyon16 xml format into a json format.
        Arguments:
            inxml: string, path to the input camelyon16 xml format
            outjson: string, path to the output json format
        """
        root = ET.parse(inxml).getroot()
        annotations_tumor = \
            root.findall('./Annotations/Annotation[@PartOfGroup="Tumor"]')
        annotations_0 = \
            root.findall('./Annotations/Annotation[@PartOfGroup="_0"]')
        annotations_1 = \
            root.findall('./Annotations/Annotation[@PartOfGroup="_1"]')
        annotations_2 = \
            root.findall('./Annotations/Annotation[@PartOfGroup="_2"]')

        annotations_positive = \
            annotations_tumor + annotations_0 + annotations_1
        annotations_negative = annotations_2


        annotations_lunghp = \
            root.findall('./Annotations/Annotation[@PartOfGroup="None"]')

        annotations_cam17 = \
            root.findall('./Annotations/Annotation[@PartOfGroup="metastases"]')

        json_dict = {}
        json_dict['positive'] = []
        json_dict['negative'] = []

        for annotation in annotations_positive:
            X = list(map(lambda x: float(x.get('X')),
                     annotation.findall('./Coordinates/Coordinate')))
            Y = list(map(lambda x: float(x.get('Y')),
                     annotation.findall('./Coordinates/Coordinate')))
            vertices = np.round([X, Y]).astype(int).transpose().tolist()
            name = annotation.attrib['Name']
            json_dict['positive'].append({'name': name, 'vertices': vertices})

        for annotation in annotations_negative:
            X = list(map(lambda x: float(x.get('X')),
                     annotation.findall('./Coordinates/Coordinate')))
            Y = list(map(lambda x: float(x.get('Y')),
                     annotation.findall('./Coordinates/Coordinate')))
            vertices = np.round([X, Y]).astype(int).transpose().tolist()
            name = annotation.attrib['Name']
            json_dict['negative'].append({'name': name, 'vertices': vertices})

        with open(outjson, 'w') as f:
            json.dump(json_dict, f, indent=1)


In [11]:
Annotations = '/home/congz3414050/HistoGCN/data/Original/test_annotation'
json_path = '/home/congz3414050/HistoGCN/data/Original/test_annotation_json'

for f in os.listdir(Annotations):
    tumor_number = f.split('.')[0]
    tumor_file = os.path.join(Annotations, f)
    output_json = os.path.join(json_path, tumor_number + '.json')
    _ = Formatter().camelyon16xml2json(tumor_file, output_json)

In [12]:
class Slide_Patch(object):
    def __init__(self, dimension_level, slide_path, json_path, output_path, minRGB=50, patch_size=256, min_tissue=0.3):
        self.dimension_level = dimension_level
        self.slide_path = slide_path
        self.tumor_number = os.path.basename(self.slide_path).split('.')[0]
        self.json_path = json_path
        print('Reading ', self.slide_path)
        self.slide = openslide.OpenSlide(self.slide_path)
        w, h = self.slide.level_dimensions[self.dimension_level]
        self.mask_tumor = np.zeros((h, w))
        self.scale = self.slide.level_downsamples[self.dimension_level]
        print(self.scale)
        self.img_RGB = np.transpose(np.array(self.slide.read_region((0, 0),
                                                          self.dimension_level,
                                                          (w,h)).convert('RGB')), axes=[1, 0, 2])
        self.img_hsv = cv2.cvtColor(self.img_RGB, cv2.COLOR_BGR2HSV)

        self.outpath = output_path
        self.minRGB = minRGB
        self.patch_size = patch_size

        self.out_path = os.path.join(self.outpath, 'test_' + str(self.tumor_number))
        if not os.path.exists(self.out_path):
            os.mkdir(self.out_path)

        self.out_path_normal = os.path.join(self.out_path, 'Normal')
        if not os.path.exists(self.out_path_normal):
            os.mkdir(self.out_path_normal)

        self.out_path_tumor = os.path.join(self.out_path, 'Tumor')
        if not os.path.exists(self.out_path_tumor):
            os.mkdir(self.out_path_tumor)

        self.min_tumor_number = self.patch_size * self.patch_size * min_tissue
        self.thresh_cal()
        self.slide.close()

    def thresh_cal(self):
        print('==> calculate threshold')
        self.color_thresh_R = threshold_otsu(self.img_RGB[:, :, 0])
        self.color_thresh_G = threshold_otsu(self.img_RGB[:, :, 1])
        self.color_thresh_B = threshold_otsu(self.img_RGB[:, :, 2])
        self.color_thresh_H = threshold_otsu(self.img_hsv[:, :, 1])
        print('==> threshold done')

    def _tissue_mask(self, img=False, check=False):
        background_R = self.img_RGB[:, :, 0] > self.color_thresh_R
        background_G = self.img_RGB[:, :, 1] > self.color_thresh_G
        background_B = self.img_RGB[:, :, 2] > self.color_thresh_B
        tissue_RGB = np.logical_not(background_R & background_G & background_B)
        tissue_S = self.img_hsv[:, :, 1] > self.color_thresh_H
        min_R =  self.img_RGB[:, :, 0] > self.minRGB
        min_G =  self.img_RGB[:, :, 1] > self.minRGB
        min_B =  self.img_RGB[:, :, 2] > self.minRGB
        tissue_mask = tissue_S & tissue_RGB & min_R & min_G & min_B  ###############tissue mask
#         tissue_mask = tissue_RGB & min_R & min_G & min_B###############tissue mask

        return tissue_mask  # levl4

    def _tumor_mask(self):
        tumor_json = os.path.basename(self.slide_path).split('.')[0] + '.json'
        tumor_json = os.path.join(self.json_path, tumor_json)

        if not os.path.exists(tumor_json):
            print('not exist')
            tumor_mask = np.array([])
        else:
            with open(tumor_json) as f:
                dicts = json.load(f)
            tumor_polygons = dicts['positive']  # dicts['mask']#

            for tumor_polygon in tumor_polygons:
                # plot a polygon
                name = tumor_polygon["name"]
                vertices = np.array(tumor_polygon["vertices"]) / self.scale
                vertices = vertices.astype(np.int32)
                cv2.fillPoly(self.mask_tumor, [vertices], (255))

            self.mask_tumor = self.mask_tumor[:] > 127
            tumor_mask = np.transpose(self.mask_tumor)

        return tumor_mask  # level4

    def mask(self, plot=True):
        tissue_mask = self._tissue_mask()
        tumor_mask = self._tumor_mask()
        if tumor_mask.shape[0] == 0:
            normal_mask, questionable_mask = tissue_mask, np.zeros((tissue_mask.shape[0], tissue_mask.shape[1]))
        else:
            normal_mask, questionable_mask = tissue_mask & (~ tumor_mask), tissue_mask & (tumor_mask)

        if plot:
            plt.figure(0, figsize=(18, 18))
            plt.subplot(1, 3, 1)
            plt.imshow(normal_mask)
            plt.subplot(1, 3, 2)
            plt.imshow(questionable_mask)
            plt.subplot(1, 3, 3)
            plt.imshow(self.img_RGB)
            plt.show()
        return normal_mask, questionable_mask, self.img_RGB  # level4

    def slide_to_img(self, item_list):
        img, label, coor = item_list

        image = Image.fromarray(img)
        if label == 1:
            img_save_path = os.path.join(self.out_path_normal, 'image')
        else:
            img_save_path = os.path.join(self.out_path_tumor, 'image')
        if not os.path.exists(img_save_path):
            os.mkdir(img_save_path)
        image.save(img_save_path + '/' + str(coor) + '.png')
        # print('==> image saved ',img_save_path)
        return [img_save_path + '/' + str(coor) + '.png', str(label)]

    def obtain_all_patchpts(self):
        '''
        random sampling positive and negative samples
        '''
        normal, questionable_mask, rgb = self.mask(plot=True)  # level4
        # normal
        X_idcs_n, Y_idcs_n = np.where(normal)
        centre_points_normal = np.stack(np.vstack((X_idcs_n.T, Y_idcs_n.T)), axis=1)
        mask_name = [1, 0]
        name = np.full((centre_points_normal.shape[0], 2), mask_name)
        normal_center_points = np.hstack((centre_points_normal, name))
        # tumor
        X_idcs_t, Y_idcs_t = np.where(questionable_mask)
        centre_points_tumor = np.stack(np.vstack((X_idcs_t.T, Y_idcs_t.T)), axis=1)
        mask_name = [0, 1]
        name = np.full((centre_points_tumor.shape[0], 2), mask_name)
        tumor_center_points = np.hstack((centre_points_tumor, name))
        return normal_center_points, tumor_center_points, normal, questionable_mask, rgb  ###########

    def is_tumor(self, x, y, size, tumor_mask):
        if y + size > tumor_mask.shape[0] or x + size > tumor_mask.shape[1]:
            return False
        select_tumor_mask = tumor_mask[y:y + size, x:x + size]
        include_tumor = np.count_nonzero(select_tumor_mask)
        return True if include_tumor / (select_tumor_mask.shape[0] * select_tumor_mask.shape[1]) > 0.001 else False

    def is_normal(self, x, y, size, normal_mask):
        if y + size > normal_mask.shape[0] or x + size > normal_mask.shape[1]:
            return False

        select_normal_mask = normal_mask[y:y + size, x:x + size]
        include_normal = np.count_nonzero(select_normal_mask)
        return True if include_normal / (select_normal_mask.shape[0] * select_normal_mask.shape[1]) > 0.1 else False

    def iter_over_slide(self, x_min, y_min, x_max, y_max, step, level, tm, nm):
        all_sample = []
        rects = []

        img_save_path_normal = os.path.join(self.out_path_normal, 'image')
        img_save_path_tumor = os.path.join(self.out_path_tumor, 'image')

        if not os.path.exists(img_save_path_normal):
            os.mkdir(img_save_path_normal)
        if not os.path.exists(img_save_path_tumor):
            os.mkdir(img_save_path_tumor)

        tumor_count = 0
        normal_count = 0
        # print(x_min, x_max,y_min, y_max, )
        for y in range(y_min, y_max, step):
            for x in range(x_min, x_max, step):

                if self.is_tumor(x, y, step, tm):  # Tumor
                    select_sample = self.img_RGB[y:y + step, x:x + step]
                    image = Image.fromarray(select_sample)
                    image_pth = os.path.join(img_save_path_tumor, str(x) + '_' + str(y) + '.png')
                    image.save(image_pth)
                    mask = tm[y:y + step, x:x + step]
                    mask = Image.fromarray(mask)
                    out_tumor_mask_pth = os.path.join(self.out_path_tumor, 'mask')
                    if not os.path.exists(out_tumor_mask_pth):
                        os.mkdir(out_tumor_mask_pth)
                    mask.save(os.path.join(out_tumor_mask_pth, str(x) + '_' + str(y) + '.png'))
                    rects.append(
                        patches.Rectangle((x, y), self.patch_size, self.patch_size, edgecolor='r', facecolor="none"))
                    tumor_count += 1
                elif self.is_normal(x, y, step, nm):  # Normal
                    select_sample = self.img_RGB[y:y + step, x:x + step]
                    image = Image.fromarray(select_sample)
                    image_pth = os.path.join(img_save_path_normal, str(x) + '_' + str(y) + '.png')
                    image.save(image_pth)
                    rects.append(
                        patches.Rectangle((x, y), self.patch_size, self.patch_size, edgecolor='b', facecolor="none"))
                    normal_count += 1
        print('=> tumor : ', tumor_count)
        print('=> normal : ', normal_count)
        return rects

    def patch_gen(self):
        normal_coord, tumor_coord, normal_mask, tumor_mask, RGB_image = self.obtain_all_patchpts()
        normal_coord = normal_coord[:, 0:2]
        rects = self.iter_over_slide(np.min(normal_coord[:, 1]), np.min(normal_coord[:, 0]), \
                                     np.max(normal_coord[:, 1]), np.max(normal_coord[:, 0]), \
                                     self.patch_size, self.dimension_level, tumor_mask, normal_mask)

        figure, ax = plt.subplots(figsize=(10, 10))

        ax.imshow(RGB_image)
        ax.add_patch(patches.Rectangle((np.min(normal_coord[:, 1]), np.min(normal_coord[:, 0])),
                                       np.max(normal_coord[:, 1]) - np.min(normal_coord[:, 1]),
                                       np.max(normal_coord[:, 0]) - np.min(normal_coord[:, 0]), edgecolor='b',
                                       facecolor="none"))
        for i in rects:
            ax.add_patch(i)
        thumbnail_pth = os.path.join(self.outpath, 'thumbnails')
        if not os.path.exists(thumbnail_pth):
            os.mkdir(thumbnail_pth)
        plot_out = os.path.join(thumbnail_pth, '%s.png'%self.tumor_number)
        plt.savefig(plot_out)
        plt.show()
        print('outpath ',self.outpath)


In [13]:
import json

with open('/home/congz3414050/HistoGCN/data/Original/test_annotation_json/test_097.json') as f:
  data = json.load(f)

# Output: {'name': 'Bob', 'languages': ['English', 'Fench']}
print(data)

{'positive': [{'name': 'Annotation 0', 'vertices': [[16695, 158590], [16604, 158707], [16501, 158817], [16582, 158942], [16714, 159011], [16849, 159070], [16996, 159095], [17146, 159095], [17303, 159095], [17454, 159095], [17593, 159033], [17681, 158909], [17578, 158802], [17364, 158823], [17368, 158663], [17146, 158645], [16996, 158703], [16860, 158645]]}], 'negative': []}


In [8]:
rt = '/home/congz3414050/HistoGCN/data/Original/test_image'
out = '/home/congz3414050/HistoGCN/data/5X/Tumor/Train'
annotation_path = '/home/congz3414050/HistoGCN/data/Original/test_annotation_json'
import time
# finished = ['tumor_102.tif','tumor_001.tif','tumor_088.tif']#,'patient_004','patient_009','patient_015','patient_016','patient_018']
for i in os.listdir(rt):
    tumor_name = i.split('.tif')[0]
#     if tumor_name not in finished:
    if i.endswith('.tif'):
        slide_pth = os.path.join(rt, i)

        patch_generator_tumor = Slide_Patch(3, slide_pth, annotation_path, out, patch_size=256)
        patch_generator_tumor.patch_gen()
        time.sleep(5)

Reading  /home/congz3414050/HistoGCN/data/Original/test_image/test_097.tif
8.0
==> calculate threshold
==> threshold done


KeyError: 'positive'

In [4]:


# root = '/home/congz3414050/HistoGCN/data/5X/Tumor_768/Train'
# out = '/home/congz3414050/HistoGCN/data/5X/Tumor_768/csvs'
# out_all = '/home/congz3414050/HistoGCN/data/5X/Tumor_768/'

# def path2csv(root, out):
#     count = 0
#     for tumors in os.listdir(root):
#         if tumors == 'thumbnails':
#             continue
#         content = {'node':[], 'coord':[], 'label':[], 'id':[], 'slide':[], 'mask':[]}
#         tumor_pth = os.path.join(root, tumors)
#         ids = 0

#         for sub_root in os.listdir(tumor_pth):
#             sub_root_pth = os.path.join(tumor_pth, sub_root, 'image')
#             if sub_root == 'Tumor':
#                 sub_root_mask_pth = os.path.join(tumor_pth, sub_root, 'mask')
#             else:
#                 sub_root_mask_pth = False
#             if len(os.listdir(sub_root_pth)) == 0:
#                 break
#             for node in os.listdir(sub_root_pth):
#                 node_pth = os.path.join(sub_root_pth, node)
#                 mask_pth = os.path.join(sub_root_mask_pth, node) if sub_root_mask_pth else 'Nothing'
#                 node_coor_combine = node.split('.png')[0]
#                 node_label = sub_root
#                 content['node'].append(node_pth)
#                 content['coord'].append(node_coor_combine)
#                 content['label'].append(node_label)
#                 content['id'].append(ids)
#                 content['slide'].append(tumors)
#                 content['mask'].append(mask_pth)
#                 ids += 1
#         count += 1
#             #     print(content)
#             #     break
#             # break
        
#         if len(content['node'])!=0:
#             out_csv_pth = os.path.join(out, '%s.csv'%tumors)
#             df = pd.DataFrame.from_dict(content)
#             df.to_csv(out_csv_pth, index=False)


#     print('total %d csvs'%count)

# def obtain_class(root):
#     out_pth = root
#     root = root + 'csvs'
#     content = {'image':[],'label':[],'slide':[], 'mask':[]}
#     # prefix = 'F:\BaiduNetdiskDownload\CAMELYON16\GCN\data'
#     for tumor_csv in os.listdir(root):
#         csv_path = os.path.join(root, tumor_csv)
#         df = pd.read_csv(csv_path)

#         content['image'] += df['node'].tolist()
#         content['label'] += df['label'].tolist()
#         content['slide'] += df['slide'].tolist()
#         content['mask'] += df['mask'].tolist()

#     out_csv_pth = os.path.join(out_pth, 'all_data.csv')
#     df = pd.DataFrame.from_dict(content)
#     df.to_csv(out_csv_pth, index=False)
#     print('total ',len(df))
#     print(out_csv_pth)
# # 
# path2csv(root, out)
# obtain_class(out_all)
# # df = pd.read_csv(r'F:\BaiduNetdiskDownload\CAMELYON16\GCN\data\5X\csvs\all_data.csv')
# # df_normal = df[df['label']=='Normal']
# # df_tumor = df[df['label']=='Tumor']
# # print(len(df_normal),len(df_tumor))

In [5]:
class Polygon(object):
    """
    Polygon represented as [N, 2] array of vertices
    """
    def __init__(self, name, vertices):
        """
        Initialize the polygon.
        Arguments:
            name: string, name of the polygon
            vertices: [N, 2] 2D numpy array of int
        """
        self._name = name
        self._vertices = vertices

    def __str__(self):
        return self._name

    def inside(self, coord):
        """
        Determine if a given coordinate is inside the polygon or not.
        Arguments:
            coord: 2 element tuple of int, e.g. (x, y)
        Returns:
            bool, if the coord is inside the polygon.
        """
        return points_in_poly([coord], self._vertices)[0]

    def vertices(self):

        return np.array(self._vertices)


class Annotation(object):
    """
    Annotation about the regions within WSI in terms of vertices of polygons.
    """
    def __init__(self, scale=8.0):
        self._json_path = ''
        self._polygons_positive = []
        self._polygons_negative = []
        self.scale = scale

    def __str__(self):
        return self._json_path

    def from_json(self, json_path, ):
        """
        Initialize the annotation from a json file.
        Arguments:
            json_path: string, path to the json annotation.
        """
        self._json_path = json_path
        with open(json_path) as f:
            annotations_json = json.load(f)

        for annotation in annotations_json['positive']:
            name = annotation['name']
            vertices = np.array(annotation["vertices"]) / self.scale
            vertices = vertices.astype(np.int32)
            polygon = Polygon(name, vertices)
            self._polygons_positive.append(polygon)

        for annotation in annotations_json['negative']:
            name = annotation['name']
            vertices = np.array(annotation["vertices"]) / self.scale
            vertices = vertices.astype(np.int32)
            polygon = Polygon(name, vertices)
            self._polygons_negative.append(polygon)

    def inside_polygons(self, coord, is_positive):
        """
        Determine if a given coordinate is inside the positive/negative
        polygons of the annotation.
        Arguments:
            coord: 2 element tuple of int, e.g. (x, y)
            is_positive: bool, inside positive or negative polygons.
        Returns:
            bool, if the coord is inside the positive/negative polygons of the
            annotation.
        """
        if is_positive:
            polygons = copy.deepcopy(self._polygons_positive)
        else:
            polygons = copy.deepcopy(self._polygons_negative)

        for polygon in polygons:
            if polygon.inside(coord):
                return True

        return False

    def polygon_vertices(self, is_positive):
        """
        Return the polygon represented as [N, 2] array of vertices
        Arguments:
            is_positive: bool, return positive or negative polygons.
        Returns:
            [N, 2] 2D array of int
        """
        if is_positive:
            return list(map(lambda x: x.vertices(), self._polygons_positive))
        else:
            return list(map(lambda x: x.vertices(), self._polygons_negative))

In [8]:
class GridImageDataset(Dataset):
    """
    Data producer that generate a square grid, e.g. 3x3, of patches and their
    corresponding labels from pre-sampled images.
    """
    def __init__(self, data_path, json_path, img_size, patch_size,
                 crop_size=224, normalize=True):
        """
        Initialize the data producer.
        Arguments:
            data_path: string, path to pre-sampled images using patch_gen.py
            img_size: int, size of pre-sampled images, e.g. 768
            patch_size: int, size of the patch, e.g. 256
            crop_size: int, size of the final crop that is feed into a CNN,
                e.g. 224 for ResNet
            normalize: bool, if normalize the [0, 255] pixel values to [-1, 1],
                mostly False for debuging purpose
        """
        self._df = data_path
        self._json_path = json_path
        self._img_size = img_size
        self._patch_size = patch_size
        self._crop_size = crop_size
        self._normalize = normalize
        self._color_jitter = transforms.ColorJitter(64.0/255, 0.75, 0.25, 0.04)
        self._preprocess()

    def _preprocess(self):
        if self._img_size % self._patch_size != 0:
            raise Exception('Image size / patch size != 0 : {} / {}'.
                            format(self._img_size, self._patch_size))

        self._patch_per_side = self._img_size // self._patch_size
        self._grid_size = self._patch_per_side * self._patch_per_side

        self._pids = list(map(lambda x: x.strip('.json'),
                              os.listdir(self._json_path)))

#         self._annotations = {}
#         for pid in self._pids:
#             pid_json_path = os.path.join(self._json_path, pid + '.json')
#             anno = Annotation()
#             anno.from_json(pid_json_path)
#             self._annotations[pid] = anno


    def __len__(self):
        return len(self._df)

    def __getitem__(self, idx):
        image_pth = self._df.iloc[idx, 0]
        mask_pth = self._df.iloc[idx, 3]

        slide_name = self._df.iloc[idx, 2]
        image_name = image_pth.split('/')[-1].split('.png')[0].split('_')
        x_top_left, y_top_left = 0,0 #int(image_name[0]), int(image_name[1])
        patch_label = self._df.iloc[idx, 1]

        # the grid of labels for each patch
        label_grid = np.zeros((self._patch_per_side, self._patch_per_side),
                              dtype=np.float32)
        if mask_pth != 'Nothing':
            mask_np = cv2.imread(mask_pth)
            for x_i in range(self._patch_per_side):
                x_t = x_top_left + self._patch_size * x_i
                for y_i in range(self._patch_per_side):
                    y_t = y_top_left + self._patch_size * y_i
                    select_tumor_mask = mask_np[x_t: x_t + self._patch_size, y_t: y_t + self._patch_size]
                    include_tumor = np.count_nonzero(select_tumor_mask)

                    if include_tumor / (mask_np.shape[0] * mask_np.shape[1]) > 0.001:
                        label = 1
                    else:
                        label = 0

                    label_grid[y_i, x_i] = label
     

        img = Image.open(image_pth)

        # color jitter
        img = self._color_jitter(img)

        # use left_right flip
        if np.random.rand() > 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            label_grid = np.fliplr(label_grid)

        # use rotate
        num_rotate = np.random.randint(0, 4)
        img = img.rotate(90 * num_rotate)
        label_grid = np.rot90(label_grid, num_rotate)

        # PIL image:   H x W x C
        # torch image: C X H X W
        img = np.array(img, dtype=np.float32).transpose((2, 0, 1))

        if self._normalize:
            img = (img - 128.0)/128.0

        # flatten the square grid
        img_flat = np.zeros(
            (self._grid_size, 3, self._crop_size, self._crop_size),
            dtype=np.float32)
        label_flat = np.zeros(self._grid_size, dtype=np.float32)

        idx = 0
        for x_idx in range(self._patch_per_side):
            for y_idx in range(self._patch_per_side):
                # center crop each patch
                x_start = int(
                    (x_idx + 0.5) * self._patch_size - self._crop_size / 2)
                x_end = x_start + self._crop_size
                y_start = int(
                    (y_idx + 0.5) * self._patch_size - self._crop_size / 2)
                y_end = y_start + self._crop_size
                img_flat[idx] = img[:, x_start:x_end, y_start:y_end]
                label_flat[idx] = label_grid[x_idx, y_idx]

                idx += 1

        return (img_flat, label_flat)

In [9]:
class WSIDataset(Dataset):
    """Generate dataset."""
    def __init__(self, filepath, annotation_pth):
        self.all_data = pd.read_csv(filepath)
        self.annotation = annotation_pth
        all_normal, all_tumor = self.Obtain_detail(self.all_data['label'].tolist())
        print('total: ', all_normal, all_tumor)
        self.train_data = self.all_data.sample(frac=0.8, replace=False, random_state=200) #random state is a seed value
        self.test_data = self.all_data.drop(self.train_data.index)
        train_normal, train_tumor = self.Obtain_detail(self.train_data['label'].tolist())
        print('train: ', train_normal, train_tumor)
        test_normal, test_tumor = self.Obtain_detail(self.test_data['label'].tolist())
        print('test: ', test_normal, test_tumor)

    def Obtain_detail(self, y):
        normal = tumor = 0
        for i in y:
            if i == 'Normal':
                normal += 1
            else:
                tumor += 1
        return normal, tumor

    def Obtain_dataset(self, stage):
        if stage == 'Train':
            self.dataset = GridImageDataset(self.train_data, self.annotation, 768, 256)
        elif stage == 'Test':
            self.dataset = GridImageDataset(self.test_data, self.annotation, 768, 256)
        return self.dataset

    def Obtain_loader(self, stage, batch_size):
        ds = self.Obtain_dataset(stage)
        self.loader = DataLoader(ds,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=1,
                                 drop_last=True)
        return self.loader



In [10]:
import torch
from torch import nn


class CRF(nn.Module):
    def __init__(self, num_nodes, iteration=10):
        """Initialize the CRF module
        Args:
            num_nodes: int, number of nodes/patches within the fully CRF
            iteration: int, number of mean field iterations, e.g. 10
        """
        super(CRF, self).__init__()
        self.num_nodes = num_nodes
        self.iteration = iteration
        self.W = nn.Parameter(torch.zeros(1, num_nodes, num_nodes))

    def forward(self, feats, logits):
        """Performing the CRF. Algorithm details is explained below:
        Within the paper, I formulate the CRF distribution using negative
        energy and cost, e.g. cosine distance, to derive pairwise potentials
        following the convention in energy based models. But for implementation
        simplicity, I use reward, e.g. cosine similarity to derive pairwise
        potentials. So now, pairwise potentials would encourage high reward for
        assigning (y_i, y_j) with the same label if (x_i, x_j) are similar, as
        measured by cosine similarity, pairwise_sim. For
        pairwise_potential_E = torch.sum(
            probs * pairwise_potential - (1 - probs) * pairwise_potential,
            dim=2, keepdim=True
        )
        This is taking the expectation of pairwise potentials using the current
        marginal distribution of each patch being tumor, i.e. probs. There are
        four cases to consider when taking the expectation between (i, j):
        1. i=T,j=T; 2. i=N,j=T; 3. i=T,j=N; 4. i=N,j=N
        probs is the marginal distribution of each i being tumor, therefore
        logits > 0 means tumor and logits < 0 means normal. Given this, the
        full expectation equation should be:
        [probs * +pairwise_potential] + [(1 - probs) * +pairwise_potential] +
                    case 1                            case 2
        [probs * -pairwise_potential] + [(1 - probs) * -pairwise_potential]
                    case 3                            case 4
        positive sign rewards logits to be more tumor and negative sign rewards
        logits to be more normal. But because of label compatibility, i.e. the
        indicator function within equation 3 in the paper, case 2 and case 3
        are dropped, which ends up being:
        probs * pairwise_potential - (1 - probs) * pairwise_potential
        In high level speaking, if (i, j) embedding are different, then
        pairwise_potential, as computed as cosine similarity, would approach 0,
        which then as no affect anyway. if (i, j) embedding are similar, then
        pairwise_potential would be a positive reward. In this case,
        if probs -> 1, then pairwise_potential promotes tumor probability;
        if probs -> 0, then -pairwise_potential promotes normal probability.
        Args:
            feats: 3D tensor with the shape of
            [batch_size, num_nodes, embedding_size], where num_nodes is the
            number of patches within a grid, e.g. 9 for a 3x3 grid;
            embedding_size is the size of extracted feature representation for
            each patch from ResNet, e.g. 512
            logits: 3D tensor with shape of [batch_size, num_nodes, 1], the
            logit of each patch within the grid being tumor before CRF
        Returns:
            logits: 3D tensor with shape of [batch_size, num_nodes, 1], the
            logit of each patch within the grid being tumor after CRF
        """
        feats_norm = torch.norm(feats, p=2, dim=2, keepdim=True)
        pairwise_norm = torch.bmm(feats_norm,
                                  torch.transpose(feats_norm, 1, 2))
        pairwise_dot = torch.bmm(feats, torch.transpose(feats, 1, 2))
        # cosine similarity between feats
        pairwise_sim = pairwise_dot / pairwise_norm
        # symmetric constraint for CRF weights
        W_sym = (self.W + torch.transpose(self.W, 1, 2)) / 2
        pairwise_potential = pairwise_sim * W_sym
        unary_potential = logits.clone()

        for i in range(self.iteration):
            # current Q after normalizing the logits
            probs = torch.transpose(logits.sigmoid(), 1, 2)
            # taking expectation of pairwise_potential using current Q
            pairwise_potential_E = torch.sum(
                probs * pairwise_potential - (1 - probs) * pairwise_potential,
                dim=2, keepdim=True)
            logits = unary_potential + pairwise_potential_E

        return logits


In [11]:
# class ResNetCRF(nn.Module):

#     def __init__(self, num_classes=1, num_nodes=1, use_crf=True):
#         """Constructs a ResNet model.
#         Args:
#             num_classes: int, since we are doing binary classification
#                 (tumor vs normal), num_classes is set to 1 and sigmoid instead
#                 of softmax is used later
#             num_nodes: int, number of nodes/patches within the fully CRF
#             use_crf: bool, use the CRF component or not
#         """
#         super(ResNetCRF, self).__init__()
#         base_model= torchvision.models.resnet50()
#         num_ftrs = base_model.fc.in_features
        
#         self.feature_et = nn.Sequential(*list(base_model.children())[:-1])
#         self.class_fc = nn.Linear(num_ftrs, num_classes)
#         self.crf = CRF(num_nodes) if use_crf else None


#     def forward(self, x):
#         """
#         Args:
#             x: 5D tensor with shape of
#             [batch_size, grid_size, 3, crop_size, crop_size],
#             where grid_size is the number of patches within a grid (e.g. 9 for
#             a 3x3 grid); crop_size is 224 by default for ResNet input;
#         Returns:
#             logits, 2D tensor with shape of [batch_size, grid_size], the logit
#             of each patch within the grid being tumor
#         """
#         batch_size, grid_size, _, crop_size = x.shape[0:4]
#         # flatten grid_size dimension and combine it into batch dimension
#         x = x.view(-1, 3, crop_size, crop_size)
#         x = self.feature_et(x)
#         feats = x.view(x.size(0), -1)
#         logits = self.class_fc(feats)

#         # restore grid_size dimension for CRF
#         feats = feats.view((batch_size, grid_size, -1))
#         logits = logits.view((batch_size, grid_size, -1))
#         if self.crf:
#             logits = self.crf(feats, logits)

# #         logits = torch.squeeze(logits)

#         return logits

In [18]:
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1, num_nodes=1,
                 use_crf=True):
        """Constructs a ResNet model.
        Args:
            num_classes: int, since we are doing binary classification
                (tumor vs normal), num_classes is set to 1 and sigmoid instead
                of softmax is used later
            num_nodes: int, number of nodes/patches within the fully CRF
            use_crf: bool, use the CRF component or not
        """
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.crf = CRF(num_nodes) if use_crf else None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x: 5D tensor with shape of
            [batch_size, grid_size, 3, crop_size, crop_size],
            where grid_size is the number of patches within a grid (e.g. 9 for
            a 3x3 grid); crop_size is 224 by default for ResNet input;
        Returns:
            logits, 2D tensor with shape of [batch_size, grid_size], the logit
            of each patch within the grid being tumor
        """
        batch_size, grid_size, _, crop_size = x.shape[0:4]
        # flatten grid_size dimension and combine it into batch dimension
        x = x.view(-1, 3, crop_size, crop_size)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # feats means features, i.e. patch embeddings from ResNet
        feats = x.view(x.size(0), -1)
        logits = self.fc(feats)

        # restore grid_size dimension for CRF
        feats = feats.view((batch_size, grid_size, -1))
        logits = logits.view((batch_size, grid_size, -1))

        if self.crf:
            logits = self.crf(feats, logits)

        logits = torch.squeeze(logits)

        return logits

def resnet50(**kwargs):
    """Constructs a ResNet-50 model.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)

    return model

In [19]:
def metric(output, labels):
    TP = TN = FP = FN = 0
    for i in range(output.size()[0]):
        probs = output[i].sigmoid()
        predicts = (probs >= 0.5).type(torch.cuda.FloatTensor)
        for j in range(len(predicts)):
            patch_pred = predicts[j].cpu().item()
            patch_label = labels[i][j].cpu().item()
            if patch_pred == patch_label == 0:
                TN += 1
            elif patch_pred == patch_label == 1:
                TP += 1
            elif patch_pred == 0 and patch_label == 1:
                FN += 1
            else:
                FP += 1
    return TN, TP, FN, FP

In [None]:
num_epochs = 30
batch_size = 16
data_pth = '/home/congz3414050/HistoGCN/data/5X/Tumor_768/all_data.csv'
model_save_path = '/home/congz3414050/HistoGCN/checkpoint/Scratch_Res50CRF_torch.pt'
annotation_path = '/home/congz3414050/HistoGCN/data/Original/annotation'
wsi_dataset = WSIDataset(data_pth, annotation_path)
trainloader = wsi_dataset.Obtain_loader('Train', batch_size)
testloader = wsi_dataset.Obtain_loader('Test', batch_size)

# check readed data
print('number of training data %d' % len(trainloader))
print('number of testing data %d' % (len(testloader)))
model = resnet50(num_classes=1, num_nodes=9)
# criterion = nn.BCEWithLogitsLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
# for batch_data in trainloader:
#     (img_flat, label_flat) = batch_data
#     print('image size ',img_flat.size())
#     print('label size ',label_flat.size())
#     pred = model(img_flat)
#     pred = pred.squeeze(-1)
#     print(pred.size(), label_flat.size())
#     loss = criterion(pred, label_flat)

#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#     break

since = time.time()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)  # , weight_decay=0.01)
criterion = nn.BCEWithLogitsLoss()
best_acc = 0.0
best_loss = float('inf')
best_f1 = 0.0
best_epoch = 0

for epoch in range(num_epochs):
    dataloaders = {'train': trainloader, 'val': testloader}
    batch_TP = 1
    batch_TN = 1
    batch_FP = 1
    batch_FN = 1
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)
    stop_count = 0
    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # Set model to training mode
        else:
            model.eval()  # Set model to evaluate mode
        running_loss = 0.0

        # Iterate over data.
        for data in tqdm(dataloaders[phase]):
            (img_flat, label_flat) = data
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward
            # track history if only in train
            with torch.set_grad_enabled(phase == 'train'):
                pred = model(img_flat)
                pred = pred.squeeze(-1)

                loss = criterion(pred, label_flat)
                TN, TP, FN, FP = metric(pred, label_flat)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()
            
            # statistics
            running_loss += loss.item() * img_flat.size(0)
            batch_TN += TN
            batch_TP += TP
            batch_FN += FN
            batch_FP += FP
            stop_count += 1
#             if stop_count > 10:
#                 break
#             break
        # epoch_f1 = metric(dataloaders[phase], graphloader, model, batch_size, use_graph=use_graph)
        print(batch_TN, batch_TP, batch_FN, batch_FP)
        epoch_loss = running_loss / len(dataloaders[phase].dataset)
        Specificity = batch_TN / (batch_TN + batch_FP)
        Sensitivity = batch_TP / (batch_FN + batch_TP)
        Precision = batch_TP / (batch_TP + batch_FP)
        F1_Score = 2 * (Precision * Sensitivity) / (Precision + Sensitivity)

        print('Stage %s'%phase)
        print('Specificity: ', Specificity)
        print('Sensitivity: ', Sensitivity)
        print('Precision: ', Precision)
        print('F1-Score: ', F1_Score)
        print('Loss: {:.4f} Lr: {}'.format(epoch_loss, optimizer.param_groups[0]["lr"]))

        # deep copy the modela
        if phase == 'val':
            val_loss = epoch_loss
            if F1_Score > best_f1:
                best_f1 = F1_Score
                best_epoch = epoch
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, model_save_path)
            print('===========End of Epoch %d============='%epoch)
        # break
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val f1: {:4f} at epoch {:d}'.format(best_f1, best_epoch))

total:  6553 1573
train:  5233 1268
test:  1320 305
number of training data 406
number of testing data 101
Epoch 0/29
----------


  0%|          | 0/406 [00:00<?, ?it/s]

52138 49 5959 322
Stage train
Specificity:  0.9938619900876858
Sensitivity:  0.008155792276964047
Precision:  0.1320754716981132
F1-Score:  0.01536290954695093
Loss: 0.3358 Lr: 0.0002


  0%|          | 0/101 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f8ccd77bf70>
Traceback (most recent call last):
  File "/home/congz3414050/HistoGCN/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/congz3414050/HistoGCN/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f8ccd77bf70>
Traceback (most recent call last):
  File "/home/congz3414050/HistoGCN/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/congz3414050/HistoGCN/venv/lib/python3.8/site-packages/torch/utils/data/dat

65236 49 7405 322
Stage val
Specificity:  0.9950883187406572
Sensitivity:  0.006573651730614435
Precision:  0.1320754716981132
F1-Score:  0.012523961661341851
Loss: 0.3063 Lr: 0.0002
Epoch 1/29
----------


  0%|          | 0/406 [00:00<?, ?it/s]

In [7]:
import numpy as np

a = np.random.rand(64,2)
a = np.swapaxes(a,0,1)
a = np.reshape(a, (2, 8, 8))
a = np.expand_dims(a, axis=0)
print(a.shape)

(1, 2, 8, 8)
