In [None]:
# The script below was adapted from code by Eric Wu (GitHub: ericwu09).
# It was used for the following project: https://github.com/ericwu09/medical-ai-evaluation

import os
import argparse
import json
from easydict import EasyDict as edict
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.nn import DataParallel
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import pickle
from tqdm import tqdm
from concept_utils import DensenetCXRBottom, DensenetCXRTop, learn_concept, get_embedding
import pandas as pd

from data.dataset import ImageDataset  # noqa
from model.classifier import Classifier  # noqa

parser = argparse.ArgumentParser(description='Test model')

model_paths = {"nih": "/path/ptx/chexpert_models/nih/nih_sr_1"}

train_dfs = {
            "cxp": "/path/data/cxp/chexpert_full/splits/v0/train_df.csv"
}

model_ds = "nih"
test_ds = "cxp"

MODEL_PATH = model_paths[model_ds]
IN_CSV_PATH = train_dfs[test_ds]



parser = argparse.ArgumentParser()
parser.add_argument('--model-path', default=MODEL_PATH, metavar='MODEL-PATH',
                    type=str, help="Path to the trained models")
parser.add_argument('--in_csv_path', default=IN_CSV_PATH, metavar='IN_CSV_PATH',
                    type=str, help="Path to the input image path in csv")
parser.add_argument('--num-workers', default=2, type=int)
parser.add_argument('--batch-size', default=8, type=int)
parser.add_argument('-f')
parser.add_argument("--seed", default=1, type=int)
parser.add_argument("--n-samples", default=100, type=int)
parser.add_argument("--C", default=0.001, type=float)

if not os.path.exists('test'):
    os.mkdir('test')
args = parser.parse_args()


In [None]:
# Load the classifier described by `cfg.json` in the model directory, restore
# the weights from `best1.ckpt`, and wrap it as bottom/top halves via
# concept_utils (the bottom half is what the embedding step uses later).
concept_bank = {}  # collects learned concepts; pickled in the final cell
with open(os.path.join(args.model_path, 'cfg.json')) as f:
    cfg = edict(json.load(f))

# Fall back to CPU so the notebook still runs on a machine without a GPU;
# DataParallel simply wraps the module when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Classifier(cfg)
model = DataParallel(model).to(device).eval()
print("Model is initialized!!")
ckpt_path = os.path.join(args.model_path, 'best1.ckpt')
ckpt = torch.load(ckpt_path, map_location=device)
# Loaded into `.module`: the checkpoint keys are expected to match the
# unwrapped Classifier (load_state_dict is strict by default).
model.module.load_state_dict(ckpt['state_dict'])

model_bottom, model_top = DensenetCXRBottom(model), DensenetCXRTop(model)
print("Model is loaded!!")
model_bottom = model_bottom.eval()

In [None]:
## CXP
# Load the dataset's training split and derive one 0/1 column per attribute
# that a concept will be learned for.
# Read the path from the CLI so `--in_csv_path` is honoured (defaults to IN_CSV_PATH).
df = pd.read_csv(args.in_csv_path)
metavars = ["Sex", "AP/PA", "Frontal/Lateral"]
# dtype=int keeps the one-hot columns as 0/1 integers (pandas >= 2.0 would
# otherwise emit booleans), matching the Age flags built below.
df = pd.get_dummies(df, columns=metavars, dtype=int)
df["View Position_AP"] = df["AP/PA_AP"]
df["Patient Gender_F"] = df["Sex_Female"]
df["Lateral"] = df["Frontal/Lateral_Lateral"]
df["Age<24"] = (df["Age"] < 24).astype(int)
df["Age>60"] = (df["Age"] > 60).astype(int)
# The renamed copies above supersede these one-hot columns; assign the result,
# since DataFrame.drop returns a new frame rather than modifying `df`.
df = df.drop(["AP/PA_AP", "Sex_Female", "Frontal/Lateral_Lateral"], axis=1)
# Concept columns. Sampling later keeps only rows equal to 1 (positive) or
# 0 (negative); rows with any other value in a column (e.g. uncertain or
# missing labels, if present) are never sampled for that concept.
metalabels = ["Patient Gender_F", "Age<24", "Age>60", "Atelectasis", "Cardiomegaly", "Lateral", "View Position_AP"]
df.head()

In [None]:
def _make_concept_loader(csv_path):
    """Return a DataLoader over the images listed in ``csv_path``.

    Uses the test-mode ImageDataset with the loaded model config and the
    batch-size / worker settings from the command line.
    """
    return DataLoader(
        ImageDataset(csv_path, cfg, mode='test', sample_n=args.n_samples),
        batch_size=args.batch_size, num_workers=args.num_workers,
        drop_last=False, shuffle=True)


# For every metadata attribute: sample positive and negative images, embed
# them with the bottom half of the model, and fit a concept on the embeddings.
xr_concepts = {}
with torch.no_grad():
    for var in metalabels:
        pos_pool = df[df[var] == 1]
        neg_pool = df[df[var] == 0]
        if min(len(pos_pool), len(neg_pool)) < args.n_samples:
            raise ValueError(
                f"Concept {var!r} has {len(pos_pool)} positive and {len(neg_pool)} "
                f"negative rows; need at least {args.n_samples} of each (--n-samples).")
        # Reseed per concept so each concept's sample is reproducible on its
        # own; DataFrame.sample draws from numpy's global RNG. Uses --seed.
        np.random.seed(args.seed)
        pos_df = pos_pool.sample(args.n_samples)
        neg_df = neg_pool.sample(args.n_samples)
        # ImageDataset takes a CSV path, so the samples go through temp files.
        pos_df.to_csv("./temp_pos.csv")
        neg_df.to_csv("./temp_neg.csv")
        pos_loader = _make_concept_loader("./temp_pos.csv")
        neg_loader = _make_concept_loader("./temp_neg.csv")
        # `device` is the device chosen when the model was loaded.
        pos_acts = get_embedding(pos_loader, model_bottom, n_samples=args.n_samples, device=device)
        neg_acts = get_embedding(neg_loader, model_bottom, n_samples=args.n_samples, device=device)
        acts = torch.cat([pos_acts, neg_acts], dim=0).cpu().numpy()
        # Size the labels from the embeddings actually returned so they always
        # line up row-for-row with `acts`.
        c_labels = torch.cat([torch.ones(len(pos_acts)), torch.zeros(len(neg_acts))], dim=0).cpu().numpy()
        concept_info = learn_concept(acts, c_labels, args, C=args.C)
        xr_concepts[var] = concept_info
        # Entries 1 and 2 of learn_concept's result are reported as-is; see
        # concept_utils.learn_concept for their meaning.
        print(var, concept_info[1], concept_info[2])

In [None]:
# Merge the newly learned concepts into the bank, refusing to overwrite an
# existing entry, then persist the whole bank as a pickle named after the
# model dataset and the dataset the concepts were learned on.
for concept_name in xr_concepts:
    assert concept_name not in concept_bank
    concept_bank[concept_name] = xr_concepts[concept_name]

bank_path = "./cxr_concept_densenet_{}_on_{}.pkl".format(model_ds, test_ds)
with open(bank_path, "wb") as bank_file:
    pickle.dump(concept_bank, bank_file)