In [None]:
import os
import argparse
import json
from easydict import EasyDict as edict
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.nn import DataParallel
import torch.nn.functional as F
from tqdm import tqdm
import pandas as pd
import pickle 
import matplotlib.pyplot as plt
from concept_utils import DensenetCXRBottom, DensenetCXRTop
from concept_utils import ConceptBank, get_concept_scores_mv_valid

from data.dataset import ImageDataset  
from model.classifier import Classifier  

# Experiment configuration: which trained model (model_ds) to explain on which
# test split (test_ds); all default paths below are derived from these two keys.
model_paths = {"nih": "/path/projects/ptx/chexpert_models/nih/nih_sr_1"}

test_dfs = {
            "cxp": "/path/cxp/chexpert_full/splits/v0/test_df.csv"
}

broden_path = {
    "nih":"/path/banks/concept_densenet_0.001_nih.pkl",
}
model_ds = "nih"
test_ds = "cxp"

parser = argparse.ArgumentParser(description='Test model')
MODEL_PATH = model_paths[model_ds]
CONCEPT_BANK_PATH = f"./xr_concept_densenet_{model_ds}_on_{test_ds}.pkl"
IN_CSV_PATH = test_dfs[test_ds]
BRODEN_CONCEPTS_PATH = broden_path[model_ds]

parser.add_argument('--model_path', default=MODEL_PATH, metavar='MODEL_PATH',
                    type=str, help="Path to the trained models")
parser.add_argument('--in_csv_path', default=IN_CSV_PATH, metavar='IN_CSV_PATH',
                    type=str, help="Path to the input image path in csv")
parser.add_argument('--out_csv_path', default='test/test.csv',
                    metavar='OUT_CSV_PATH', type=str,
                    help="Path to the output predictions in csv")
# NOTE(review): inside a Jupyter kernel sys.argv contains '-f <connection file>';
# this option also absorbs that value, so args.f is not reliable as GPU indices there.
parser.add_argument('-f', default='0', type=str, help="GPU indices "
                    "comma separated, e.g. '0,1' ")
parser.add_argument('--concept-bank-path', default=CONCEPT_BANK_PATH)
parser.add_argument('--device', default="cuda", type=str)

# Default output directory (matches the --out_csv_path default); exist_ok keeps
# re-running this cell idempotent.
os.makedirs('test', exist_ok=True)
# parse_known_args tolerates extra kernel/launcher arguments that a plain
# parse_args() would reject with SystemExit.
args, _unknown_args = parser.parse_known_args()
# Per-channel std/mean (the common ImageNet normalization values); used by the
# plotting cells to undo normalization before display.
std_pxs = np.array([0.229, 0.224, 0.225])
mean_pxs = np.array([0.485, 0.456, 0.406])

In [None]:
# Read the training configuration stored next to the checkpoint.
cfg_file = os.path.join(args.model_path, 'cfg.json')
with open(cfg_file) as cfg_fh:
    cfg = edict(json.load(cfg_fh))

# Build the classifier, wrap it in DataParallel, move it to the requested
# device and put it in eval mode in one chain.
model = DataParallel(Classifier(cfg)).to(args.device).eval()
print("Model is initialized!!")

# Restore the weights into the wrapped module (DataParallel keeps it in .module).
# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True; confirm the
# checkpoint only holds tensors/primitive types or pass weights_only explicitly.
ckpt_path = os.path.join(args.model_path, 'best1.ckpt')
ckpt = torch.load(ckpt_path, map_location=args.device)
model.module.load_state_dict(ckpt['state_dict'])
print("Model is loaded!!")


In [None]:
# Gradients are switched on globally, presumably for the lr/momentum-based
# concept-score optimization run later; the model itself stays in eval mode.
torch.set_grad_enabled(True)
model.eval()

# Split the classifier into the part producing an embedding ("bottom") and the
# part mapping that embedding to logits ("top"); both halves are put in eval mode.
model_bottom, model_top = (
    part.eval() for part in (DensenetCXRBottom(model), DensenetCXRTop(model))
)

In [None]:
# Restrict the test split to lateral views and feed them to the dataset via a
# temporary CSV. `ds` (the full split) is kept: the explanation loop looks up
# rows in it by Path. Reads args.in_csv_path so the --in_csv_path option is honoured.
# NOTE(review): to_csv also writes the DataFrame index as the first column;
# confirm ImageDataset expects that layout.
ds = pd.read_csv(args.in_csv_path)
ds_sampled = ds[ds["Frontal/Lateral"] == "Lateral"]
ds_sampled.to_csv("./temp_df.csv")
dataloader_test = DataLoader(
    ImageDataset("./temp_df.csv", cfg, mode='valid', sample_n=110),
    batch_size=1, num_workers=1,
    drop_last=False, shuffle=False)


In [None]:
# Load the X-ray concept bank and the base ("broden") concept bank.
# Both are unpickled, so only point these paths at trusted files.
with open(args.concept_bank_path, "rb") as bank_fh:
    xr_bank = pickle.load(bank_fh)
with open(BRODEN_CONCEPTS_PATH, "rb") as bank_fh:
    concept_bank = pickle.load(bank_fh)
# Size of the base bank; the X-ray concepts appended below come after this index.
n_broden = len(concept_bank)


# Keep only X-ray concepts whose entry at index 2 exceeds 0.6
# (presumably a per-concept quality score — TODO confirm its meaning).
for name, entry in xr_bank.items():
    if entry[2] > .6:
        concept_bank[name] = entry

concept_bank = ConceptBank(concept_bank, args.device)
print(len(concept_bank.concept_names))
# The appended X-ray concepts; assumes ConceptBank preserves insertion order and
# that no X-ray concept name collides with a base concept.
target_concepts = concept_bank.concept_names[n_broden:]


In [None]:
# For every misclassified image, run the concept-score optimization and record
# where each target concept lands in the resulting ranking. Results are collected
# into `all_df`, which the analysis cells below group and aggregate.
examples = 0
trues = 0
all_rows = []
for img, path, label in tqdm(dataloader_test):
    examples += 1
    img = img.to(args.device)
    label = label.to(args.device)
    # Plain forward pass for the prediction only; no gradients are needed here.
    with torch.no_grad():
        embedding = model_bottom(img)
        out = model_top(embedding)
    prob = torch.sigmoid(out.view(-1)).cpu().numpy()
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the replacement.
    pred = (prob > 0.5).astype(int)

    # Correct prediction: count it and skip the (expensive) explanation step.
    if label.cpu().numpy()[0] == pred:
        trues += 1
        continue

    opt_result = get_concept_scores_mv_valid(img, label,
                                             concept_bank,
                                             model_bottom, model_top,
                                             alpha=1e-1, beta=1e-2, lr=1e-1,
                                             enforce_validity=True, momentum=0.9)
    # Presumably concept names ordered by score; top/bottom-3 and position are read from it.
    ranked = opt_result.concept_scores_list
    row = ds[ds.Path == path[0]].copy()
    for c in target_concepts:
        row[f"{c}-Top3"] = (c in ranked[:3])
        row[f"{c}-Bottom3"] = (c in ranked[-3:])
        row[f"{c}-Order"] = ranked.index(c)
    all_rows.append(row)

print(f"Correctly classified: {trues}/{examples}")
# Build the frame used by the following cells; pd.concat rejects an empty list.
all_df = pd.concat(all_rows) if all_rows else pd.DataFrame()

In [None]:
# Mean rank of the "Lateral" concept per (in-bottom-3, view) group. The column is
# selected before aggregating: .mean() over the whole frame raises on non-numeric
# columns such as Path in pandas >= 2.0.
all_df.groupby(["Lateral-Bottom3", "Frontal/Lateral"])["Lateral-Order"].mean()

In [None]:
# Number of non-null "Lateral" ranks per (in-bottom-3, view) group.
all_df.groupby(["Lateral-Bottom3", "Frontal/Lateral"])["Lateral-Order"].count()

In [None]:
# Mean rank of the "View Position_AP" concept per (in-top-3, view) group. The column
# is selected first because .mean() over all columns fails on non-numeric ones
# in pandas >= 2.0.
all_df.groupby(["View Position_AP-Top3", "Frontal/Lateral"])["View Position_AP-Order"].mean()

### Plotting images

In [None]:
# Restrict the test split to every row not marked "Lateral" (so rows with a missing
# view label are included too) for the frontal-view figure. Reads
# args.in_csv_path so the --in_csv_path option is honoured.
# NOTE(review): to_csv also writes the DataFrame index as the first column;
# confirm ImageDataset expects that layout.
ds = pd.read_csv(args.in_csv_path)
ds_sampled = ds[ds["Frontal/Lateral"] != "Lateral"]
ds_sampled.to_csv("./temp_df.csv")
dataloader_test = DataLoader(
    ImageDataset("./temp_df.csv", cfg, mode='valid', sample_n=110),
    batch_size=1, num_workers=1,
    drop_last=False, shuffle=False)


In [None]:

# Show the first four images from dataloader_test in a 2x2 grid and save the
# figure under ../paper_figures/cxr/.
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
# zip stops after the fourth axis without pulling a fifth batch.
for ax, (img, _path, _label) in zip(axs.flatten(), dataloader_test):
    # CHW -> HWC, undo the per-channel normalization, then clip to [0, 1] so
    # imshow gets valid RGB floats instead of warning and clipping itself.
    pixels = img[0].cpu().permute(1, 2, 0).numpy() * std_pxs + mean_pxs
    ax.imshow(np.clip(pixels, 0, 1))
    ax.axis("off")
fig.tight_layout()
out_path = "../paper_figures/cxr/frontal_view.png"
# savefig does not create missing directories.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path)
# plt.show takes no figure argument; passing one breaks on non-inline backends.
plt.show()
    

In [None]:
# Rebuild the lateral-only loader for the lateral-view figure. Reads
# args.in_csv_path so the --in_csv_path option is honoured.
# NOTE(review): to_csv also writes the DataFrame index as the first column;
# confirm ImageDataset expects that layout.
ds = pd.read_csv(args.in_csv_path)
ds_sampled = ds[ds["Frontal/Lateral"] == "Lateral"]
ds_sampled.to_csv("./temp_df.csv")
dataloader_test = DataLoader(
    ImageDataset("./temp_df.csv", cfg, mode='valid', sample_n=110),
    batch_size=1, num_workers=1,
    drop_last=False, shuffle=False)


In [None]:

# Show the first four images from dataloader_test in a 2x2 grid and save the
# figure under ../paper_figures/cxr/.
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
# zip stops after the fourth axis without pulling a fifth batch.
for ax, (img, _path, _label) in zip(axs.flatten(), dataloader_test):
    # CHW -> HWC, undo the per-channel normalization, then clip to [0, 1] so
    # imshow gets valid RGB floats instead of warning and clipping itself.
    pixels = img[0].cpu().permute(1, 2, 0).numpy() * std_pxs + mean_pxs
    ax.imshow(np.clip(pixels, 0, 1))
    ax.axis("off")
fig.tight_layout()
out_path = "../paper_figures/cxr/lateral_view.png"
# savefig does not create missing directories.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path)
# plt.show takes no figure argument; passing one breaks on non-inline backends.
plt.show()
    