Clone https://github.com/jfhealthcare/Chexpert and move the CheXpert dataset into the cloned repository directory (the directory this notebook runs from).
Run inference for the 5 classes using the pretrained weights.

In [1]:
import torch
from easydict import EasyDict as edict
import json
from torch.nn import DataParallel
from model.classifier import Classifier
import time

# model = torch.load("config/pre_train.pth", map_location=torch.device('cpu'))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Locations of the model configuration and the pretrained checkpoint,
# relative to the directory the notebook is run from.
cfg_path = "config/example.json"
pre_train = "config/pre_train.pth"

# EasyDict exposes the JSON keys as attributes (cfg.criterion, cfg.num_classes, ...).
with open(cfg_path) as f:
    cfg = edict(json.load(f))

# CPU-only setup: no GPU ids are passed to DataParallel and every tensor
# stays on the CPU device.
# NOTE(review): an empty device_ids list is expected to be accepted on
# CPU-only hosts; behaviour on a CUDA-enabled host is not shown here — TODO confirm.
device_ids = []
num_devices = torch.cuda.device_count()
device = torch.device('cpu')

# The checkpoint is loaded through `model.module`, so the DataParallel
# wrapper is kept even though only the CPU is used.
# This notebook only runs inference, so the model is put in eval mode right
# away instead of relying on a later cell to undo `.train()`.
model = Classifier(cfg)
model = DataParallel(model, device_ids=device_ids).to(device).eval()

# SECURITY: torch.load unpickles the file; only load checkpoints from a
# trusted source.
ckpt = torch.load(pre_train, map_location=device)
model.module.load_state_dict(ckpt)

<All keys matched successfully>

In [3]:
from torch.utils.data import DataLoader
from data.dataset import ImageDataset
import numpy as np

# Loader over the CSV named by the config's `dev_csv` entry, with the dataset
# in 'test' mode. One image per batch, original order, nothing dropped.
dataloader_test = DataLoader(
    ImageDataset(cfg.dev_csv, cfg, mode='test'),
    shuffle=False,
    drop_last=False,
    num_workers=4,
    batch_size=1,
)

In [4]:
# Inference only: disable autograd globally and switch the model to eval mode.
torch.set_grad_enabled(False)
model.eval()
device = torch.device('cpu')

# Iteration state shared with the following cells.
dataloader = dataloader_test
dataiter = iter(dataloader)
steps = len(dataloader)

# One prediction task per entry of cfg.num_classes.
num_tasks = len(cfg.num_classes)

# File name used by later cells for the list of processed image paths.
txt_file = "plot.txt"

# Column names of the printed CSV header: image path followed by one column per task.
test_header = [
    'Path',
    'Cardiomegaly', 'Edema', 'Consolidation', 'Atelectasis', 'Pleural Effusion',
]

In [5]:
def get_pred(output, cfg):
    """Convert the raw output of one task head into probabilities.

    Args:
        output: torch.Tensor with the logits of a single task. For the
            'BCE'/'FL' criteria it is flattened with ``view(-1)``; for 'CE'
            it must be 2-D, since column 1 of the softmax is selected.
        cfg: configuration object exposing ``criterion`` (str) and
            ``num_classes`` (iterable of int, one entry per task).

    Returns:
        1-D numpy array of probabilities: the sigmoid of each logit for
        'BCE'/'FL', or the softmax probability of class index 1 for 'CE'.

    Raises:
        AssertionError: if ``cfg.num_classes`` is inconsistent with the
            criterion (all 1 for 'BCE'/'FL', all >= 2 for 'CE').
        Exception: if ``cfg.criterion`` is not one of 'BCE', 'FL', 'CE'.
    """
    if cfg.criterion in ('BCE', 'FL'):
        for num_class in cfg.num_classes:
            assert num_class == 1
        pred = torch.sigmoid(output.view(-1)).cpu().detach().numpy()
    elif cfg.criterion == 'CE':
        for num_class in cfg.num_classes:
            assert num_class >= 2
        # torch.softmax replaces the former F.softmax call: `F` was never
        # imported in this notebook, and the class dimension is now explicit
        # instead of relying on the deprecated implicit-dim behaviour.
        prob = torch.softmax(output, dim=1)
        pred = prob[:, 1].cpu().detach().numpy()
    else:
        raise Exception('Unknown criterion : {}'.format(cfg.criterion))

    return pred

In [6]:
print(','.join(test_header) + '\n')
print("steps", len(dataloader), "num_tasks", num_tasks)

# Only the first few batches are scored. Capping at the loader length avoids
# a StopIteration when the dataset holds fewer than five batches.
steps = min(5, len(dataloader))
images = []

# `with` guarantees the path list is flushed and closed; previously the
# handle was opened for writing and then leaked without being written.
with open(txt_file, "w") as image_file:
    # Iterating the loader directly (instead of a shared, partially consumed
    # iterator) makes the cell start from the first batch on every run.
    for step, (image, path) in zip(range(steps), dataloader):
        print("step ", step)
        image = image.to(device)
        # The model returns a pair; only the first element (one output per
        # task) is used here.
        output, _unused = model(image)
        batch_size = len(path)

        # pred[t, b] is the probability of task t for sample b of the batch.
        pred = np.zeros((num_tasks, batch_size))
        for i in range(num_tasks):
            pred[i] = get_pred(output[i], cfg)

        for i in range(batch_size):
            batch = ','.join('{}'.format(p) for p in pred[:, i])
            result = path[i] + ',' + batch
            print(result + '\n')
            print('{}, Image : {}, Prob : {}'.format(
                time.strftime("%Y-%m-%d %H:%M:%S"), path[i], batch))
            print('\n\n')
            # Record the processed image both on disk and in memory.
            image_file.write(path[i] + '\n')
            images.append(path[i])
        

Path,Cardiomegaly,Edema,Consolidation,Atelectasis,Pleural Effusion

steps 234 num_tasks 5
step  0
CheXpert-v1.0-small/valid/patient64541/study1/view1_frontal.jpg,0.6085793375968933,0.5058494210243225,0.2419632077217102,0.45605579018592834,0.12261803448200226

2022-10-01 00:34:06, Image : CheXpert-v1.0-small/valid/patient64541/study1/view1_frontal.jpg, Prob : 0.6085793375968933,0.5058494210243225,0.2419632077217102,0.45605579018592834,0.12261803448200226



step  1
CheXpert-v1.0-small/valid/patient64542/study1/view1_frontal.jpg,0.05178125575184822,0.1388433873653412,0.17289796471595764,0.29364722967147827,0.08102937042713165

2022-10-01 00:34:07, Image : CheXpert-v1.0-small/valid/patient64542/study1/view1_frontal.jpg, Prob : 0.05178125575184822,0.1388433873653412,0.17289796471595764,0.29364722967147827,0.08102937042713165



step  2
CheXpert-v1.0-small/valid/patient64542/study1/view2_lateral.jpg,0.22355185449123383,0.07542623579502106,0.22232073545455933,0.17489545047283173,0.0746905803

In [7]:
# `os` and `cv2` are used below but were never imported anywhere in the
# notebook, so the cell only worked with leftover kernel state.
import os

import cv2
from util.heatmaper import Heatmaper

disease_classes = [
    'Cardiomegaly',
    'Edema',
    'Consolidation',
    'Atelectasis',
    'Pleural Effusion'
]
plot_path = "plots"
alpha = 0.2
prefix = "none"

# Validate the configured prefix before it is handed to Heatmaper.
assert prefix in ['none', *(disease_classes)]

# Create the output folder; exist_ok makes the cell safe to re-run.
os.makedirs(plot_path, exist_ok=True)

heatmaper = Heatmaper(alpha, prefix, cfg, model, device)

# NOTE(review): the recorded run of this cell stopped with
# "RuntimeError: stack expects a non-empty TensorList", apparently raised
# during gen_heatmap; that is not addressed here — TODO investigate.
for image_path in images:
    time_start = time.time()
    jpg_file = image_path.strip('\n')
    print(jpg_file)
    # A separate name keeps the configured `prefix` from being overwritten
    # by the value gen_heatmap returns.
    out_prefix, figure_data = heatmaper.gen_heatmap(jpg_file)
    bn = os.path.basename(jpg_file)
    save_file = '{}/{}{}'.format(plot_path, out_prefix, bn)
    # The write must not live inside an `assert`: assertions are removed
    # under `python -O`, which would silently skip saving the image.
    if not cv2.imwrite(save_file, figure_data):
        raise IOError("write failed!")
    time_spent = time.time() - time_start
    print(
        '{}, {}, heatmap generated, Run Time : {:.2f} sec'
        .format(time.strftime("%Y-%m-%d %H:%M:%S"),
                jpg_file, time_spent))

CheXpert-v1.0-small/valid/patient64541/study1/view1_frontal.jpg
[[[  7   7   7]
  [  7   7   7]
  [  7   7   7]
  ...
  [ 11  11  11]
  [ 11  11  11]
  [ 11  11  11]]

 [[  7   7   7]
  [  7   7   7]
  [  7   7   7]
  ...
  [ 11  11  11]
  [ 11  11  11]
  [ 11  11  11]]

 [[  7   7   7]
  [  7   7   7]
  [  7   7   7]
  ...
  [ 11  11  11]
  [ 11  11  11]
  [ 11  11  11]]

 ...

 [[ 66  66  66]
  [ 85  85  85]
  [103 103 103]
  ...
  [ 14  14  14]
  [ 14  14  14]
  [ 14  14  14]]

 [[ 65  65  65]
  [ 96  96  96]
  [106 106 106]
  ...
  [ 14  14  14]
  [ 14  14  14]
  [ 14  14  14]]

 [[ 77  77  77]
  [ 90  90  90]
  [131 131 131]
  ...
  [ 14  14  14]
  [ 14  14  14]
  [ 14  14  14]]]
[tensor([[0.4413]]), tensor([[0.0234]]), tensor([[-1.1419]]), tensor([[-0.1762]]), tensor([[-1.9679]])] []


RuntimeError: stack expects a non-empty TensorList