In [2]:
import glob
import os
import re
import sys

import cv2
import joblib
import matplotlib.pyplot as plt
import numpy as np
import yaml

In [3]:
class ClassRatio(object):
    """Count patch images per class for a set of WSIs and plot the class ratio.

    Patch files are found with the glob ``<imgs_dir>/*/<wsi>_*/*.png``. A file
    is assigned to a class when its path contains ``/<class>/``. Presumably the
    first-level directory under ``imgs_dir`` is the class id (TODO: confirm
    against the dataset layout).

    Args:
        wsi_list: WSI names whose patches are counted.
        classes: class ids. Nested lists (merged classes) are flattened.
        imgs_dir: root directory of the patch images. A trailing slash is
            optional.
    """

    # RGB (0-255) colour for each class id, used for the pie-chart slices.
    _NUM_TO_COLOR = {
        0: (200, 200, 200),
        1: (255, 0, 0),
        2: (65, 105, 225),
        3: (0, 255, 0),
        4: (0, 255, 255),
        5: (0, 0, 255),
        6: (255, 0, 255),
        7: (128, 0, 0),
        8: (128, 128, 0),
        9: (0, 128, 0),
        10: (0, 0, 128),
        11: (64, 64, 64),
    }
    # Pie-chart labels for the known class ids.
    _CLASS_LABELS = {0: "Non-Neop.", 1: "HSIL", 2: "LSIL"}
    # Slice order in the pie chart: Non-Neop., then LSIL, then HSIL.
    _DISPLAY_ORDER = (0, 2, 1)

    def __init__(self, wsi_list, classes, imgs_dir):
        self.wsi_list = wsi_list
        self.classes = self.get_sub_classes(classes)
        self.imgs_dir = imgs_dir

    def get_files(self, wsi_name: str, classes: list, imgs_dir: str):
        """Return the PNG patch paths of one WSI that belong to any of `classes`.

        Args:
            wsi_name: WSI name; matches directories named ``<wsi_name>_*``.
            classes: class ids. Nested lists are flattened.
            imgs_dir: root directory of the patch images.

        Returns:
            A list of paths. The list is empty when `classes` is empty.
        """
        sub_classes = self.get_sub_classes(classes)
        if not sub_classes:
            # An empty alternation would compile to "" and match every file.
            return []
        re_pattern = re.compile("|".join(re.escape(f"/{c}/") for c in sub_classes))
        # os.path.join makes the trailing slash on imgs_dir optional.
        pattern = os.path.join(imgs_dir, "*", f"{wsi_name}_*", "*.png")
        return [p for p in glob.glob(pattern) if re_pattern.search(p)]

    def get_class_num(self, files: list, classes: list):
        """Count, for each class, the paths in `files` that contain ``/<class>/``.

        A path that matches several classes is counted once for each of them.

        Returns:
            A dict mapping every class in `classes` to its count (0 if unseen).
        """
        result_dict = {cls: 0 for cls in classes}
        for file_path in files:
            for cls in classes:
                if f'/{cls}/' in file_path:
                    result_dict[cls] += 1
        return result_dict

    def get_sub_classes(self, classes):
        """Flatten one level of nesting in `classes`.

        Example: ``[0, [1, 2]]`` becomes ``[0, 1, 2]``.
        """
        sub_cl_list = []
        for cl in classes:
            if isinstance(cl, list):
                sub_cl_list.extend(cl)
            else:
                sub_cl_list.append(cl)
        return sub_cl_list

    def draw_pie(self, x, label, colors, title=""):
        """Draw a clockwise pie chart of `x` on a new 8x8-inch figure and return it.

        Side effect: sets the global matplotlib ``font.size`` to 12.
        """
        figure = plt.figure(figsize=(8, 8))
        plt.rcParams["font.size"] = 12
        plt.title(title)
        plt.pie(
            x,
            labels=label,
            counterclock=False,
            startangle=90,
            autopct="%.1f%%",
            textprops={'weight': "bold"},
            colors=colors
        )
        return figure

    def num_to_color(self, num):
        """Return the RGB (0-255) tuple for a class id.

        If `num` is a list (a merged class), its first element is used.
        Unknown ids terminate via ``sys.exit``, as the original code did.
        """
        if isinstance(num, list):
            num = num[0]
        color = self._NUM_TO_COLOR.get(num)
        if color is None:
            sys.exit("invalid number:" + str(num))
        return color

    def _count_patches(self):
        """Sum the per-class patch counts over every WSI, keyed by ``str(class)``."""
        counts_dict = {str(cls): 0 for cls in self.classes}
        for wsi in self.wsi_list:
            files = self.get_files(wsi_name=wsi, classes=self.classes, imgs_dir=self.imgs_dir)
            counts = self.get_class_num(files, classes=self.classes)
            for cls in self.classes:
                counts_dict[str(cls)] += counts[cls]
        return counts_dict

    def draw_class_ratio(self, output_dir, title):
        """Save a pie chart of the class ratio as ``<output_dir>/<title>_patch_class_ratio.png``.

        Classes 0, 2 and 1 come first, in that order, labelled Non-Neop., LSIL
        and HSIL. Any other class follows in ``self.classes`` order and is
        labelled by its id. This also works for 2-class sets such as [0, 1].
        """
        # Resolve colours before scanning the disk so an invalid id fails fast.
        colors_by_class = {
            cls: tuple(v / 255 for v in self.num_to_color(cls)) for cls in self.classes
        }

        counts_dict = self._count_patches()
        print(f"counts: {counts_dict}")

        preferred = [cls for cls in self._DISPLAY_ORDER if cls in self.classes]
        ordered = list(dict.fromkeys(preferred + list(self.classes)))

        sizes = [counts_dict[str(cls)] for cls in ordered]
        labels = [self._CLASS_LABELS.get(cls, str(cls)) for cls in ordered]
        colors = [colors_by_class[cls] for cls in ordered]

        fig = self.draw_pie(sizes, labels, colors, title=f"patch class ratio: {title}")
        fig.savefig(os.path.join(output_dir, title + "_patch_class_ratio.png"), dpi=300)
        plt.close(fig)

    def get_patch_nums(self):
        """Print and return the per-class patch counts, keyed by ``str(class)``."""
        counts_dict = self._count_patches()
        print(f"counts: {counts_dict}")
        return counts_dict

def get_files(wsi_name: str, classes: list, imgs_dir: str):
    """Return the PNG patch paths of one WSI that belong to any of `classes`.

    Files are found with the glob ``<imgs_dir>/*/<wsi_name>_*/*.png``. A file is
    kept when its path contains ``/<class>/`` for some class.

    Args:
        wsi_name: WSI name; matches directories named ``<wsi_name>_*``.
        classes: class ids. One level of nested lists (merged classes) is
            flattened.
        imgs_dir: root directory of the patch images. A trailing slash is
            optional.

    Returns:
        A list of matching paths. The list is empty when `classes` is empty.
    """
    # Flatten merged classes, e.g. [0, [1, 2]] becomes [0, 1, 2].
    sub_classes = []
    for cl in classes:
        if isinstance(cl, list):
            sub_classes.extend(cl)
        else:
            sub_classes.append(cl)
    if not sub_classes:
        # An empty alternation would compile to "" and match every file.
        return []

    re_pattern = re.compile("|".join(re.escape(f"/{c}/") for c in sub_classes))
    # os.path.join makes the trailing slash on imgs_dir optional.
    pattern = os.path.join(imgs_dir, "*", f"{wsi_name}_*", "*.png")
    return [p for p in glob.glob(pattern) if re_pattern.search(p)]


In [3]:
# # WSIのリストを取得 (source)
# src_wsis = joblib.load("/mnt/secssd/AL_SSDA_WSI_strage/dataset/MF0012/cv0_train_MF0012_wsi.jb")
# src_wsis += joblib.load("/mnt/secssd/AL_SSDA_WSI_strage/dataset/MF0012/cv0_valid_MF0012_wsi.jb")
# src_wsis += joblib.load("/mnt/secssd/AL_SSDA_WSI_strage/dataset/MF0012/cv0_test_MF0012_wsi.jb")

# # WSIのリストを取得 (target)
# l_trg_wsis = joblib.load("/mnt/secssd/AL_SSDA_WSI_strage/dataset/MF0003/trg_l_wsi.jb")
# val_trg_wsis = joblib.load("/mnt/secssd/AL_SSDA_WSI_strage/dataset/MF0003/valid_wsi.jb")
# unl_trg_wsis = joblib.load("/mnt/secssd/AL_SSDA_WSI_strage/dataset/MF0003/trg_unl_wsi.jb")
# trg_wsis = l_trg_wsis + unl_trg_wsis + val_trg_wsis

# src_imgs_dir = "/mnt/secssd/SSDA_Annot_WSI_strage/mnt2/MF0012/"
# trg_imgs_dir = "/mnt/secssd/SSDA_Annot_WSI_strage/mnt2/MF0003/"

In [5]:
def _load_wsi_lists(paths):
    """Load each joblib WSI list in `paths` and concatenate them in order.

    The object returned for the first path is extended in place with the
    others, exactly as repeated ``+=`` would do.
    """
    merged = joblib.load(paths[0])
    for extra_path in paths[1:]:
        merged += joblib.load(extra_path)
    return merged


# Source domain (MF0012): cv0 train + valid + test WSIs.
src_wsis = _load_wsi_lists([
    "/mnt/secssd/AL_SSDA_WSI_MICCAI_strage/dataset/MF0012/cv0_train_MF0012_wsi.jb",
    "/mnt/secssd/AL_SSDA_WSI_MICCAI_strage/dataset/MF0012/cv0_valid_MF0012_wsi.jb",
    "/mnt/secssd/AL_SSDA_WSI_MICCAI_strage/dataset/MF0012/cv0_test_MF0012_wsi.jb",
])

# Target domain (MF0003): labeled WSIs from the top / med / btm lists.
l_trg_wsis = _load_wsi_lists([
    "/mnt/secssd/AL_SSDA_WSI_MICCAI_strage/dataset/MF0003/trg_l_top_wsi.jb",
    "/mnt/secssd/AL_SSDA_WSI_MICCAI_strage/dataset/MF0003/trg_l_med_wsi.jb",
    "/mnt/secssd/AL_SSDA_WSI_MICCAI_strage/dataset/MF0003/trg_l_btm_wsi.jb",
])

# Target-domain validation and unlabeled WSIs; trg_wsis is the union of all target lists.
val_trg_wsis = _load_wsi_lists(["/mnt/secssd/AL_SSDA_WSI_MICCAI_strage/dataset/MF0003/valid_wsi.jb"])
unl_trg_wsis = _load_wsi_lists(["/mnt/secssd/AL_SSDA_WSI_MICCAI_strage/dataset/MF0003/trg_unl_wsi.jb"])
trg_wsis = l_trg_wsis + unl_trg_wsis + val_trg_wsis

# Patch-image root directories for each domain.
src_imgs_dir = "/mnt/secssd/SSDA_Annot_WSI_strage/mnt2/MF0012/"
trg_imgs_dir = "/mnt/secssd/SSDA_Annot_WSI_strage/mnt2/MF0003/"

In [4]:
# Shared settings for the class-ratio cells: class ids 0, 1, 2 and the output folder.
output_dir = "./"
classes = list(range(3))

# Disabled: draw one class-ratio pie chart per WSI group. Uncomment to run.
# for wsi_list, imgs_dir, title in [
#     (src_wsis, src_imgs_dir, "src_MF0012"),
#     (trg_wsis, trg_imgs_dir, "trg_MF0003"),
#     (unl_trg_wsis, trg_imgs_dir, "unl_trg_test_MF0003"),
#     (val_trg_wsis, trg_imgs_dir, "trg_valid_MF0003"),
# ]:
#     obj = ClassRatio(wsi_list=wsi_list, classes=classes, imgs_dir=imgs_dir)
#     obj.draw_class_ratio(output_dir=output_dir, title=title)

# Disabled: one pie chart per labeled target WSI.
# for wsi in l_trg_wsis:
#     print(wsi)
#     obj = ClassRatio(wsi_list=[wsi], classes=classes, imgs_dir=trg_imgs_dir)
#     obj.draw_class_ratio(output_dir=output_dir, title=f"l_trg_train_{wsi}_MF0003")

In [9]:
# Per-WSI patch counts for the labeled target WSIs in trg_l_btm_wsi.jb.
# Relies on ClassRatio, joblib and trg_imgs_dir from earlier cells.
cv = 5  # number of CV folds; only used by the disabled source-domain loop below
classes = [0, 1, 2]

# Disabled: source-domain counts per CV fold.
# for cv_num in range(cv):
#     print(f"=== cv{cv_num}: test (!! unused) ===")
#     src_train_wsis = joblib.load(f"/mnt/secssd/AL_SSDA_WSI_strage/dataset/MF0012/cv{cv_num}_test_MF0012_wsi.jb")
#     title=f"src_MF0012_train_cv{cv_num}"
#     obj = ClassRatio(wsi_list=src_train_wsis, classes=classes, imgs_dir=src_imgs_dir)
#     patch_nums = obj.get_patch_nums()
#     print(f"all: {sum(patch_nums.values())}")

# Disabled: totals over all labeled target WSIs.
# print("=== l_trg (train, all) ===")
# # wsis = joblib.load("/mnt/secssd/AL_SSDA_WSI_MICCAI_strage/dataset/MF0003/valid_wsi.jb")
# wsis = l_trg_wsis
# obj = ClassRatio(wsi_list=wsis, classes=classes, imgs_dir=trg_imgs_dir)
# patch_nums = obj.get_patch_nums()
# print(f"all: {sum(patch_nums.values())}\n")

# NOTE(review): the saved output of this cell is headed "l_trg: top (train)",
# so it appears to come from an earlier run against trg_l_top_wsi.jb.
# Re-run to refresh it.
print("=== l_trg: btm (train) ===")
wsis = joblib.load("/mnt/secssd/AL_SSDA_WSI_MICCAI_strage/dataset/MF0003/trg_l_btm_wsi.jb")
for wsi in wsis:
    print(f"{wsi}: ")
    obj = ClassRatio(wsi_list=[wsi], classes=classes, imgs_dir=trg_imgs_dir)
    patch_nums = obj.get_patch_nums()
    # Sum over every configured class so this keeps working if `classes` changes.
    print(f"all: {sum(patch_nums.values())}\n")

=== l_trg: top (train) ===
0067_a-1: 
counts: {'0': 343, '1': 52, '2': 0}
all: 395

0056_a-4: 
counts: {'0': 0, '1': 111, '2': 0}
all: 111

0289_a-1: 
counts: {'0': 295, '1': 0, '2': 25}
all: 320

0055_a-1: 
counts: {'0': 896, '1': 131, '2': 0}
all: 1027

0299_a-1: 
counts: {'0': 1239, '1': 0, '2': 11}
all: 1250

