In [1]:
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Subset, Dataset
from tqdm import tqdm
from rl_benchmarks.metrics import *

from rl_benchmarks.trainers.torch_trainer import TorchTrainer
from rl_benchmarks.models.slide_models.meanpool import MeanPool
from rl_benchmarks.models.slide_models.chowder import Chowder
from rl_benchmarks.models.slide_models.dsmil import DSMIL
from rl_benchmarks.models.slide_models.abmil import ABMIL
from rl_benchmarks.models.slide_models.hiptmil import HIPTMIL
from rl_benchmarks.models.slide_models.transmil import TransMIL

from pathlib import Path
from metrics import report


In [2]:
def get_model(model_name, in_dim, out_dim):
    """Build an untrained slide-level MIL model by name.

    Args:
        model_name: one of 'abmil', 'chowder', 'dsmil', 'hiptmil',
            'transmil', 'meanpool'.
        in_dim: per-tile feature dimension, passed as the model's input size.
        out_dim: output specification forwarded unchanged to the model
            constructor (this notebook passes a list such as [3, 3]).

    Returns:
        A freshly constructed model instance.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    if model_name == 'abmil':
        return ABMIL(in_dim, out_dim, d_model_attention=128, temperature=1.0,
                     mlp_hidden=[128, 64])
    if model_name == 'chowder':
        return Chowder(in_dim, out_dim, n_top=2, n_bottom=2,
                       tiles_mlp_hidden=[128], mlp_hidden=[128, 64])
    if model_name == 'dsmil':
        return DSMIL(in_dim, out_dim, d_tiles_values=32, d_tiles_queries=32,
                     passing_values=False, tiles_scores_mlp_hidden=[200, 100],
                     tiles_queries_mlp_hidden=[200, 100], mlp_hidden=[200, 100])
    if model_name == 'hiptmil':
        return HIPTMIL(in_dim, out_dim)
    if model_name == 'transmil':
        return TransMIL(in_dim, out_features=out_dim)
    if model_name == 'meanpool':
        return MeanPool(in_dim, out_dim)
    # Raising a bare string is itself a TypeError in Python 3 and hides the
    # real problem; raise a proper exception that names the bad input.
    raise ValueError(f'model not found: {model_name!r}')


In [10]:
class FakeDataset(Dataset):
    """Synthetic dataset of random tile-feature bags (for smoke tests).

    Each item is a bag of 0-19 random 768-d feature vectors together with a
    two-element label list whose entries are each drawn uniformly from
    {0, 1, 2}. Every draw uses torch's global RNG.
    """

    def __init__(self):
        # Nothing to load: items are generated on demand.
        pass

    def __len__(self):
        # Fixed, arbitrary dataset size.
        return 185

    def __getitem__(self, item):
        # `item` is ignored; every call produces a new random sample.
        bag_size = torch.randint(20, [1]).item()
        bag = torch.rand(bag_size, 768)
        first = torch.randint(3, [1]).item()
        second = torch.randint(3, [1]).item()
        return bag, [first, second]

    @staticmethod
    def collate_fn(batch):
        """Stack (features, label) pairs into a features tensor and a label tensor.

        Modeled on PyTorch's default_collate, see
        https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py

        NOTE(review): torch.stack needs every bag to have the same number of
        tiles, which the random bag sizes above do not guarantee.
        """
        feature_bags, label_pairs = zip(*batch)
        stacked = torch.stack(feature_bags, dim=0)
        return stacked, torch.as_tensor(label_pairs)

class Feature(FakeDataset):
    """Dataset of pre-extracted tile features stored as ``.npy`` files.

    Every ``*.npy`` file found recursively under ``feature_dir`` is one bag.
    Its label is looked up from the name of the file's parent folder in
    ``LABEL_MAP`` (e.g. ``.../dcis-2/x.npy`` -> ``(1, 1)``).

    Args:
        feature_dir: root directory searched for ``.npy`` files.
        phase: 'train' enables random tile resampling in ``__getitem__``;
            any other value returns the stored features unchanged.
        n_sample: number of tiles drawn (with replacement) per bag in the
            'train' phase. Defaults to 8, the previously hard-coded value.
    """

    # Parent-folder name -> (type label, grade label).
    LABEL_MAP = {
        'normal': (0, 0),
        'dcis-1': (1, 0),
        'dcis-2': (1, 1),
        'dcis-3': (1, 2),
        'ibc-1': (2, 0),
        'ibc-2': (2, 1),
        'ibc-3': (2, 2),
    }

    def get_label(self, type):
        """Map a label-folder name to a ``(type, grade)`` pair of ints.

        Raises:
            KeyError: if the folder name is not a known label.
        """
        try:
            return self.LABEL_MAP[type]
        except KeyError:
            raise KeyError(f'unknown label folder name: {type!r}') from None

    def __init__(self, feature_dir, phase='test', n_sample=8):
        folder = Path(feature_dir)
        # rglob is already recursive. Sorting makes the item order independent
        # of the filesystem's listing order, so runs are reproducible.
        feature_paths = sorted(folder.rglob('*.npy'))
        self.labels = [self.get_label(path.parent.name) for path in feature_paths]
        self.features = [np.load(path) for path in feature_paths]
        self.phase = phase
        self.n_sample = n_sample

    def __len__(self):
        return len(self.features)

    def __getitem__(self, item):
        """Return ``(features, (type, grade))`` for bag ``item``.

        In the 'train' phase the bag is resampled (with replacement, using
        numpy's global RNG) to exactly ``n_sample`` tiles.
        """
        features = self.features[item]
        if self.phase == 'train':
            indices = np.random.choice(features.shape[0], self.n_sample, replace=True)
            # Fancy indexing gathers the chosen rows, same as stacking them.
            features = features[indices]
        return features, self.labels[item]

    @staticmethod
    def collate_fn(batch):
        """Stack (features, label) pairs into a features tensor and a label tensor.

        Modeled on PyTorch's default_collate, see
        https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py

        NOTE(review): stacking still requires equal tile counts per bag. That
        holds in the 'train' phase (fixed n_sample) but not necessarily in 'test'.
        """
        images, labels = zip(*batch)
        # __getitem__ yields numpy arrays, and torch.stack only accepts tensors.
        images = torch.stack([torch.as_tensor(x) for x in images], dim=0)
        labels = torch.as_tensor(labels)
        return images, labels


In [11]:
# Runtime configuration: RNG seeds, device, loss, metrics and datasets.
SEED = 42
# Seed numpy's global RNG (used by Feature's np.random.choice tile resampling)
# and torch's global RNG (e.g. model weight initialisation) for repeatable runs.
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
# The misspelled name `merics` is referenced by later cells; keep it as-is.
merics = {'acc': compute_multiclass_accuracy, 'auc': compute_mean_one_vs_all_auc}

# Earlier runs used the 'suqh' and 'suqh_full' feature sets under the same
# root, each with a 'model' (train) and 'test' sub-folder.
# Training reads the '_balance' variant (presumably class-balanced); testing
# reads 'suqh_all_patch' (no '_balance' suffix).
train_dir = '/mnt/hd0/project/bcacad/model/roi_features/suqh_all_patch_balance/model'
test_dir = '/mnt/hd0/project/bcacad/model/roi_features/suqh_all_patch/test'
train_set = Feature(train_dir, 'train')
test_set = Feature(test_dir, 'test')


In [5]:
# Synthetic random dataset, usable in place of train_set/test_set to
# smoke-test the training loop without the real feature files.
fake = FakeDataset()

In [6]:
# Architectures to train in this run. get_model() also supports
# 'meanpool', 'chowder', 'dsmil' and 'hiptmil'.
model_names = [
    'abmil',
    'transmil',
]
# Per-tile feature size expected by the models (FakeDataset also emits 768-d features).
input_dim = 768
# Output sizes for the two label components:
# type (Normal/nonIBC/IBC) and grade (Low/Medium/High).
out_dim = [3, 3]

In [12]:
# Train every model in `model_names`, keep the trained instances in `models`,
# and print per-task metrics on the train and test splits.
cfg = dict(
    epoch=200,
    bs=16,
)

# Class names per evaluation task. 'type' covers every sample; the grading
# tasks cover only samples whose TRUE type label is nonIBC (1) or IBC (2).
class_names = {
    'type': ['Normal', 'nonIBC', 'IBC'],
    'nonibc': ['Low', 'Medium', 'High'],
    'ibc': ['Low', 'Medium', 'High'],
}


def print_task_reports(model_name, pha, labels, probs, preds):
    """Print accuracy and one-vs-all AUCs for the type and grading tasks.

    Args:
        model_name: model name used as the line prefix.
        pha: split name ('train' / 'test') used as the line prefix.
        labels, probs, preds: array-likes where index 0 holds the type-head
            values and index 1 the grade-head values, with samples along the
            next axis (assumed; TODO confirm against TorchTrainer's output).
    """
    labels, probs, preds = np.array(labels), np.array(probs), np.array(preds)
    type_labels = labels[0]
    # task -> (sample selector, output-head index)
    selections = {
        'type': (slice(None), 0),
        'nonibc': (type_labels == 1, 1),
        'ibc': (type_labels == 2, 1),
    }

    # Compute all metrics first, then print, matching the original ordering.
    reports, avg_aucs = {}, {}
    for task, (mask, head) in selections.items():
        task_labels = labels[head][mask]
        task_probs = probs[head][mask]
        task_preds = preds[head][mask]
        if len(task_labels) == 0:
            # No samples of this type in the split, so the metrics are undefined.
            continue
        reports[task] = report(task_labels, task_preds, task_probs, class_names[task])
        avg_aucs[task] = compute_mean_one_vs_all_auc(task_labels, task_probs)

    for task in selections:
        if task not in reports:
            print(f'{model_name} {pha} {task}: no samples, skipped')
            continue
        r = reports[task]
        print("{} {} {} acc: {:.4f}, auc: {:.4f} [{:.4f} {:.4f} {:.4f}]".format(
            model_name, pha, task, r['accuracy'], avg_aucs[task],
            r['0']['auc'], r['1']['auc'], r['2']['auc']))


models = {}
for model_name in model_names:
    model = get_model(model_name, input_dim, out_dim).to(device)
    trainer = TorchTrainer(model, criterion, merics, device=device,
                           num_epochs=cfg['epoch'], batch_size=cfg['bs'])
    # Swap in `trainer.train(fake, fake)` to smoke-test without real features.
    res = trainer.train(train_set, test_set)
    models[model_name] = model

    # res[split] is indexed as (labels, probs, preds); TODO confirm against TorchTrainer.train.
    for pha in ['train', 'test']:
        print_task_reports(model_name, pha, res[pha][0], res[pha][1], res[pha][2])
        print()
    print()
    

    


100%|██████████| 200/200 [18:50<00:00,  5.65s/it]


abmil train type acc: 0.9982, auc: 0.9999 [1.0000 0.9998 1.0000]
abmil train nonibc acc: 0.9801, auc: 0.9971 [0.9993 0.9944 0.9975]
abmil train ibc acc: 0.9908, auc: 0.9989 [0.9998 0.9985 0.9984]

abmil test type acc: 0.8940, auc: 0.9851 [0.9940 0.9704 0.9912]
abmil test nonibc acc: 0.5331, auc: 0.7283 [0.7565 0.6091 0.8212]
abmil test ibc acc: 0.6682, auc: 0.7704 [0.8356 0.6768 0.8001]




100%|██████████| 200/200 [1:14:10<00:00, 22.25s/it]


transmil train type acc: 0.9984, auc: 0.9989 [1.0000 0.9980 0.9988]
transmil train nonibc acc: 0.9740, auc: 0.9803 [0.9997 0.9467 0.9941]
transmil train ibc acc: 0.9929, auc: 0.9904 [0.9996 0.9718 0.9997]

transmil test type acc: 0.9396, auc: 0.9736 [0.9953 0.9608 0.9647]
transmil test nonibc acc: 0.5784, auc: 0.7384 [0.7199 0.6538 0.8451]
transmil test ibc acc: 0.6541, auc: 0.7672 [0.8685 0.6562 0.7783]




In [24]:
# Display the training configuration.
# NOTE(review): the saved output shows epoch=100, but the training cell sets
# epoch=200, so the output is stale (cells were run out of order).
cfg

{'epoch': 100, 'bs': 16}

In [1]:
# Persist each trained model's weights as <save_dir>/<model_name>.pth.
# Requires the imports cell (Path) and the training cell (models) to have run
# first. The saved NameError output came from running this cell in a fresh kernel.
save_dir = Path('/mnt/hd0/project/bcacad/model/roi_models/model2')
save_dir.mkdir(parents=True, exist_ok=True)
for model_name, model in models.items():
    ckpt_path = save_dir / f'{model_name}.pth'
    torch.save(model.state_dict(), ckpt_path)
    print(f'save {model_name} done')

NameError: name 'Path' is not defined

In [7]:
# Reload the saved checkpoints and report per-task metrics on the test set.
pha = 'test'
load_dir = Path('/mnt/hd0/project/bcacad/model/roi_models/model2')
# 'type' covers every sample; the grading tasks cover only samples whose TRUE
# type label is nonIBC (1) or IBC (2).
class_names = {
    'type': ['Normal', 'nonIBC', 'IBC'],
    'nonibc': ['Low', 'Medium', 'High'],
    'ibc': ['Low', 'Medium', 'High'],
}
for model_name in model_names:
    model = get_model(model_name, input_dim, out_dim).to(device)
    # map_location remaps tensors saved on another device (e.g. CUDA) onto
    # `device`, so GPU-trained checkpoints also load on CPU-only machines.
    # torch.load unpickles the file: only load trusted checkpoints.
    ckpt = torch.load(load_dir / f'{model_name}.pth', map_location=device)
    model.load_state_dict(ckpt)

    # NOTE(review): num_epochs presumably has no effect on predict().
    trainer = TorchTrainer(model, criterion, merics, device=device, num_epochs=50)
    res = trainer.predict(test_set)

    # res is indexed as (labels, probs, preds); TODO confirm against TorchTrainer.predict.
    labels = np.array(res[0])
    probs = np.array(res[1])
    preds = np.array(res[2])

    type_labels = labels[0]
    # task -> (sample selector, output-head index)
    selections = {
        'type': (slice(None), 0),
        'nonibc': (type_labels == 1, 1),
        'ibc': (type_labels == 2, 1),
    }

    reports = {}
    avg_aucs = {}
    for task, (mask, head) in selections.items():
        task_labels = labels[head][mask]
        task_probs = probs[head][mask]
        task_preds = preds[head][mask]
        if len(task_labels) == 0:
            # No samples of this type in the test set, so the metrics are undefined.
            continue
        reports[task] = report(task_labels, task_preds, task_probs, class_names[task])
        avg_aucs[task] = compute_mean_one_vs_all_auc(task_labels, task_probs)

    for task in selections:
        if task not in reports:
            print(f'{model_name} {pha} {task}: no samples, skipped')
            continue
        r = reports[task]
        fs = "{} {} {} acc: {:.4f}, auc: {:.4f} [{:.4f} {:.4f} {:.4f}]".format(
            model_name, pha, task, r['accuracy'], avg_aucs[task],
            r['0']['auc'], r['1']['auc'], r['2']['auc'])
        print(fs)





abmil test type acc: 0.9235, auc: 0.9916 [0.9932 0.9914 0.9902]
abmil test nonibc acc: 0.5784, auc: 0.7118 [0.6742 0.6248 0.8400]
abmil test ibc acc: 0.6729, auc: 0.7539 [0.8320 0.6772 0.7539]
transmil test type acc: 0.9477, auc: 0.9813 [0.9922 0.9735 0.9781]
transmil test nonibc acc: 0.5401, auc: 0.7023 [0.6550 0.6448 0.8105]
transmil test ibc acc: 0.7058, auc: 0.7648 [0.8328 0.6676 0.7956]
