In [None]:
import itertools
import logging

import colorcet as cc
import h5py
import matplotlib.gridspec as grid_spec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.stats
import seaborn as sns
from KDEpy import FFTKDE
from matplotlib.colors import LinearSegmentedColormap
from tqdm import tqdm

In [None]:
# Path of the HDF5 dataset file read by the loading cells below (via pd.read_hdf and h5py.File).
# The Ellipsis is a placeholder: replace it with a real path before running the notebook.
DATASET = ...  # e.g. "/workspace/data/output/cvpr22/dataset.h5"

In [None]:
def _clean_task_label(task):
    """Map every task label that mentions "Segmentation" to plain "Segmentation"."""
    return "Segmentation" if "Segmentation" in task else task


def _expand_id_range(span):
    """Expand an inclusive "first:last" string into an array of consecutive integer IDs."""
    bounds = span.split(":")
    return np.arange(int(bounds[0]), int(bounds[1]) + 1)


df_meta = pd.read_hdf(DATASET, key="meta")

# v1.0.0 had unreliable labeling of Segmentation models, so the labels are normalized here
df_meta["Task"] = df_meta["Task"].apply(_clean_task_label)

logging.info("unfolding IDs")
df_meta["filter_ids"] = df_meta["filter_ids"].apply(_expand_id_range)

In [None]:
# Read the complete "filters" dataset into memory, then view it as rows of 9 weights each
with h5py.File(DATASET, "r") as dataset_file:
    raw_filters = dataset_file["filters"][:]

filters = raw_filters.reshape(-1, 9)

In [None]:
# The scaling and SVD transform were precomputed and stored in a separate file;
# they could also be computed on demand.

with h5py.File("/workspace/data/output/cvpr22/transformed_filters.h5", "r") as transformed_file:
    transformed = transformed_file["svd/maxscaled/transformed_fp32"][:]

filters_transformed = transformed.reshape(-1, 9)

In [None]:
# Manual metadata corrections, keyed by model name.
visual_category_fixes = {
    "compnet_weights_sagittal_improvement_09_NITRC_IITmean_b0_256_12": "medical mri",
}

training_dataset_fixes = {
    "compnet_weights_axial_improvement_08_NITRC_IITmean_b0_256_12": "nitrc_iitmean_b0/axial",
    "compnet_weights_coronal_improvement_08_NITRC_IITmean_b0_256_12": "nitrc_iitmean_b0/coronal",
    "compnet_weights_sagittal_improvement_09_NITRC_IITmean_b0_256_12": "nitrc_iitmean_b0/sagittal",
    "torchxrayvision_densenet121_mimic_ch_11": "mimic_cxr_ch",
    "torchxrayvision_densenet121_mimic_nb_11": "mimic_cxr_nb",
    "torchxrayvision_resnet101_elastic_ae_padchest_nih_chexpert_mimic_nb_mimic_ch_11": "Aggregated",
    "torchxrayvision_resnet50_512_all_11": "Aggregated",
    "torchxrayvision_densenet121_all_11": "Aggregated",
    "torchxrayvision_densenet121_kaggle_11": "Kaggle RSNA",
    "unet_carvana_carvana_11": "Kaggle Carvana",
    "unet_lgg_mri_segmentation_11": "Kaggle LGG",
}

for model_name, visual_category in visual_category_fixes.items():
    df_meta.loc[df_meta["model"] == model_name, "Visual Category"] = visual_category

for model_name, training_dataset in training_dataset_fixes.items():
    df_meta.loc[df_meta["model"] == model_name, "Training-Dataset"] = training_dataset

# Independent copy of all rows whose visual category mentions "medical"
is_medical = df_meta["Visual Category"].str.contains("medical")
df_medical_meta = df_meta[is_medical].copy()

In [None]:
# One row per medical model; each column shows the maximum value found among that model's rows
summary_columns = ["model", "Task", "Visual Category", "Training-Dataset"]
df_medical_meta[summary_columns].groupby("model").max()

### Generate filter visualizations

In [None]:
def hide_border(ax):
    """Remove the frame and ticks of ``ax`` and draw an all-zero 1x1 RGB image on it.

    Args:
        ax: The matplotlib axes to strip; it is modified in place.
    """
    for side in ("bottom", "top", "right", "left"):
        ax.spines[side].set_visible(False)

    # A NullLocator places no ticks at all on the axis
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(plt.NullLocator())

    ax.imshow(np.zeros((1, 1, 3)))

# Plot the first 25 filters referenced by metadata row 48 of the sagittal CompNet model
# (presumably one layer -- confirm against the dataset) as a 5x5 grid of 3x3 images.
ids = df_medical_meta[df_medical_meta.model == "compnet_weights_sagittal_improvement_09_NITRC_IITmean_b0_256_12"].iloc[48].filter_ids[:25]
subfilters = filters[ids]

# Symmetric color limits (-t, t) so that a weight of zero lands on the white midpoint
t = abs(subfilters).max()

# The colormap does not depend on the individual filter, so build it once up front
filter_cmap = LinearSegmentedColormap.from_list("CyanOrange", ["C0", "white", "C1"])

fig, axes = plt.subplots(5, 5, figsize=(3, 3), squeeze=False)
for kernel, ax in zip(subfilters, axes.ravel()):
    hide_border(ax)
    ax.imshow(kernel.reshape(3, 3), vmin=-t, vmax=t, cmap=filter_cmap)

plt.savefig("plots/compnet_binary.pdf", bbox_inches='tight')

## KDE plots

In [None]:
def ridge_plot(X, xrange, shape, row_labels=None, col_labels=None, figsize=(40, 10)):
    """Draw a grid of filled Gaussian KDE curves, one subplot per ``X[i][j]``.

    A new figure is created as a side effect; nothing is returned.

    Args:
        X: Nested indexable; ``X[i][j]`` holds the samples shown in grid row ``i``
            and grid column ``j``.
        xrange: ``(low, high)`` x-axis limits applied to every subplot.
        shape: ``(rows, cols)`` of the subplot grid.
        row_labels: Optional per-row labels, written to the left of the first column
            and cut off after 35 characters.
        col_labels: Optional per-column labels, used as x-axis labels on the bottom row.
        figsize: Figure size handed to ``plt.figure``.
    """
    grid = grid_spec.GridSpec(*shape)
    figure = plt.figure(figsize=figsize)
    n_rows, n_cols = shape[0], shape[1]
    max_label_len = 35

    for row in tqdm(range(n_rows)):
        # all subplots of a row share one of the color-cycle entries C0..C9
        row_color = f"C{row % 10}"
        is_bottom_row = row == n_rows - 1

        for col in range(n_cols):
            estimator = FFTKDE(kernel="gaussian", bw='silverman')
            grid_x, density = estimator.fit(X[row][col]).evaluate()

            ax = figure.add_subplot(grid[row:row + 1, col:col + 1])

            # light outline drawn together with a translucent fill in the row color
            ax.plot(grid_x, density, color="#f0f0f0", lw=1)
            ax.fill_between(grid_x, density, alpha=.7, color=row_color)

            # identical x limits everywhere, transparent axes background
            ax.set_xlim(*xrange)
            ax.patch.set_alpha(0)

            # no y ticks or labels; the y axis starts at zero
            ax.set_yticklabels([])
            ax.set_yticks([])
            ax.set_ylim([0, None])

            for side in ("top", "right", "left", "bottom"):
                ax.spines[side].set_visible(False)

            if is_bottom_row:
                # only the bottom row keeps its x axis
                ax.tick_params(direction="inout")
                ax.spines["bottom"].set_visible(True)
                if col_labels is not None:
                    ax.set_xlabel(col_labels[col])
            else:
                ax.set_xticks([])
                ax.set_xticklabels([])

            if col == 0 and row_labels is not None:
                text = row_labels[row]
                if len(text) > max_label_len:
                    text = f"{text[:max_label_len]} ..."
                ax.text(xrange[0] - 0.1, 0, text, ha="right", wrap=True, color=row_color)

    # negative vertical spacing lets neighbouring rows overlap
    grid.update(hspace=-0.5)

In [None]:
x_range = (-3, 3)

# LaTeX labels c_{00} ... c_{22}: one per cell of a 3x3 grid, in row-major order
col_labels = [
    "$c_{" + f"{row}{col}" + "}$"
    for row, col in (divmod(position, 3) for position in range(9))
]

In [None]:
def _transformed_filters_of(id_arrays):
    """Look up the transformed filters for all given ID arrays and return them as a (9, n) array."""
    return filters_transformed[np.hstack(id_arrays)].T


# One (9, n) array per medical model, indexed by model name
datatype_distributions = df_medical_meta.groupby("model")["filter_ids"].apply(_transformed_filters_of)

n_models = len(datatype_distributions)
figsize = (18, 1 + 0.25 * n_models)
ridge_plot(
    datatype_distributions.values,
    xrange=x_range,
    shape=(n_models, 9),
    row_labels=datatype_distributions.index,
    col_labels=col_labels,
    figsize=figsize,
)
plt.subplots_adjust(hspace=0, wspace=0.1)
plt.savefig("plots/kdes_models.pdf", bbox_inches='tight')

## Sparsity of torchxrayvision_resnet50

In [None]:
def sparsity(ids):
    """Count the filters of ``filters[ids]`` that are NOT near-zero.

    A filter is counted when its largest absolute weight reaches at least 1% of the
    largest absolute weight found anywhere in the selected subset. Magnitudes are
    used so that strongly negative filters are not mistaken for near-zero ones.

    Args:
        ids: Row indices into the module-level ``filters`` array.

    Returns:
        Number of selected filters at or above the 1% threshold; 0 for an empty selection.
    """
    magnitudes = np.abs(filters[ids])
    if magnitudes.size == 0:
        # max() of an empty array raises; an empty selection has nothing to count
        return 0
    threshold = 0.01 * magnitudes.max()
    return (magnitudes.max(axis=1) >= threshold).sum()

# (sum of the per-row sparsity() counts, total number of filter IDs) for the ResNet50 model
resnet50_filter_ids = df_meta[df_meta.model == "torchxrayvision_resnet50_512_all_11"].filter_ids
sparsity_total = np.sum([sparsity(row_ids) for row_ids in resnet50_filter_ids])
filter_total = resnet50_filter_ids.apply(len).sum()
sparsity_total, filter_total

## Symmetric KL divergence between models

In [None]:
def kl_sym(p, q):
    """Return the symmetric KL divergence ``KL(p || q) + KL(q || p)``.

    Both directions are computed with ``scipy.stats.entropy``.
    """
    forward = scipy.stats.entropy(p, q)
    backward = scipy.stats.entropy(q, p)
    return forward + backward


def nd_kl_sym(p, q, weights=None):
    """Symmetric KL divergence for one distribution or a weighted sum over several.

    Args:
        p: A 1-D array (single distribution), or a list / 2-D array whose entries
            ``p[i]`` are individual distributions.
        q: Same layout as ``p``.
        weights: Per-entry weights for the multi-distribution case. ``None`` means a
            weight of 1 for every entry, i.e. a plain sum. Ignored for 1-D input.

    Returns:
        ``kl_sym(p, q)`` for 1-D input, otherwise ``sum_i weights[i] * kl_sym(p[i], q[i])``.
    """
    if not isinstance(p, list) and np.ndim(p) == 1:
        return kl_sym(p, q)

    if weights is None:
        # the declared default used to fail on weights[i]; fall back to unit weights
        weights = np.ones(len(p))

    return np.sum([weights[i] * kl_sym(p[i], q[i]) for i in range(len(p))])


def get_kl(data, bins, x_range, weights):
    """Histogram ``data[0]`` and ``data[1]`` and return their (weighted) symmetric KL divergence.

    Args:
        data: Indexable holding the two sample sets to compare at positions 0 and 1.
        bins: Histogram bin specification forwarded to the discretization.
        x_range: Histogram value range forwarded to the discretization.
        weights: Weights forwarded to ``nd_kl_sym``.
    """
    first, second = data[0], data[1]
    first_dist = get_nd_discrete_probability_distribution(first, x_range, bins)
    second_dist = get_nd_discrete_probability_distribution(second, x_range, bins)
    return nd_kl_sym(first_dist, second_dist, weights)


def get_discrete_probability_distribution(X, _range, bins):
    """Histogram ``X`` into a strictly positive discrete distribution that sums to 1.

    Args:
        X: Samples to histogram.
        _range: ``(low, high)`` range passed to ``np.histogram``.
        bins: Bin specification passed to ``np.histogram``.

    Returns:
        Double-precision array with one probability per bin. Empty bins hold a tiny
        positive value instead of exactly 0.
    """
    # density=True does not make the bins sum to 1; per the original note it helps the
    # eps used for empty bins not to underflow during the normalization below
    density, _edges = np.histogram(X, bins=bins, range=_range, density=True)
    density = density.astype(np.double)

    floor = np.double(np.finfo(np.float32).eps)
    filled = np.where(density == 0, floor, density)
    return filled / filled.sum()


def get_nd_discrete_probability_distribution(X, _range, bins):
    """Discretize one sample array, or each entry of a collection of sample arrays.

    Args:
        X: Either a 1-D array of samples, or a list / multi-dimensional array whose
            entries are discretized one by one.
        _range: Histogram range forwarded to ``get_discrete_probability_distribution``.
        bins: Histogram bins forwarded to ``get_discrete_probability_distribution``.

    Returns:
        A single distribution for 1-D array input, otherwise the per-entry
        distributions stacked row-wise.
    """
    is_flat_array = type(X) is not list and len(X.shape) == 1
    if is_flat_array:
        return get_discrete_probability_distribution(X, _range, bins)

    per_entry = [get_discrete_probability_distribution(entry, _range, bins) for entry in X]
    return np.vstack(per_entry)


def kl_plot(s, figsize=(10, 10), ax=None, sort=True, **kwargs):
    """Show the pairwise KL matrix of the entries of ``s`` as a heatmap.

    Args:
        s: Series whose values are the sample sets to compare and whose index
            supplies the tick labels.
        figsize: Size of the figure created when ``ax`` is not given.
        ax: Axes to draw on. When ``None`` a new figure is created and a colorbar
            is added to it.
        sort: If true, rows and columns are reordered by ascending column mean.
        **kwargs: Forwarded to ``get_kl_matrix`` (``bins``, ``x_range``, ``weights``).

    Returns:
        ``(ax, matrix)`` where ``matrix`` is the KL matrix as drawn (after any reordering).
    """
    owns_figure = ax is None
    if owns_figure:
        plt.figure(figsize=figsize)
        ax = plt.gca()

    kl_matrix = get_kl_matrix(s.values, **kwargs)
    tick_labels = s.index
    if sort:
        order = np.argsort(kl_matrix.mean(axis=0))
        kl_matrix = kl_matrix[order][:, order]
        tick_labels = tick_labels[order]

    heatmap = ax.imshow(kl_matrix, cmap=cc.cm["fire"])

    tick_positions = range(len(s))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels, rotation=0)

    # a colorbar is only added to figures this function created itself
    if owns_figure:
        plt.colorbar(heatmap)

    return ax, kl_matrix


def get_kl_matrix(data, bins, x_range, weights=None):
    """Build the full pairwise matrix of ``nd_kl_sym`` values between the entries of ``data``.

    Args:
        data: Sequence of sample sets; entry ``k`` is discretized exactly once.
        bins: Histogram bins used for the discretization.
        x_range: Histogram range used for the discretization.
        weights: Forwarded to ``nd_kl_sym``.

    Returns:
        A ``(len(data), len(data))`` array whose ``[i, j]`` entry compares entries
        ``i`` and ``j``.
    """
    d = len(data)

    # discretize every entry up front so each histogram is computed only once
    distributions = [
        get_nd_discrete_probability_distribution(data[k], x_range, bins)
        for k in range(d)
    ]

    kl_matrix = np.zeros((d, d))
    for row, col in tqdm(itertools.product(range(d), repeat=2), total=d**2):
        kl_matrix[row, col] = nd_kl_sym(distributions[row], distributions[col], weights=weights)
    return kl_matrix

In [None]:
def dist(df, bins=70):
    """Plot and return the pairwise KL matrix between the models contained in ``df``.

    The transformed filters of each model are gathered into one ``(9, n)`` array and
    compared through ``kl_plot`` with equal weights of 1/9 per row and no reordering.

    Args:
        df: Metadata frame with the columns ``model``, ``filter_ids``,
            ``Training-Dataset`` and ``Visual Category``.
        bins: Number of histogram bins passed on to ``kl_plot``.

    Returns:
        ``(kl_matrix, labels)`` where ``labels`` is the index of
        ``"<Training-Dataset>\\n(<Visual Category>)"`` strings, one per model.
    """
    distr_series = df.groupby("model").agg({
        # the array is wrapped in a list and unwrapped again below, presumably so that
        # agg treats it as a single aggregated value
        "filter_ids": lambda x: [filters_transformed[np.hstack(x)].T],
        "Training-Dataset": "max",
        "Visual Category": "max"
    })
    distr_series.index = [
        f"{dataset}\n({category})"
        for dataset, category in zip(distr_series["Training-Dataset"], distr_series["Visual Category"])
    ]
    distr_series.filter_ids = distr_series.filter_ids.apply(lambda x: x[0])

    # only the matrix is needed here; the axes returned by kl_plot are not used
    _, kl_m = kl_plot(distr_series.filter_ids, bins=bins, x_range=x_range, weights=np.ones(9) / 9, sort=False)
    return kl_m, distr_series.index


combined_plot_labels = list()
combined_plot_kl_mats = list()
bins = 70

# Two groups: every model whose name contains "unet", then every "densenet121"
# model whose name does not contain "hso_"
is_unet = df_meta.model.str.contains("unet")
is_densenet121 = df_meta.model.str.contains("densenet121")
is_hso = df_meta.model.str.contains("hso_")

for group in (df_meta[is_unet], df_meta[is_densenet121 & ~is_hso]):
    mat, labels = dist(group)
    combined_plot_kl_mats.append(mat)
    combined_plot_labels.append(labels)

In [None]:
# Side-by-side heatmaps of the KL matrices computed above, one panel per architecture
fig, axes = plt.subplots(1, 2, figsize=(15, 6), facecolor="white")

titles = ["UNet", "DenseNet121"]

for ax, img, labels, title in zip(axes.ravel(), combined_plot_kl_mats, combined_plot_labels, titles):
    # identical vmin/vmax on both panels so their colors are directly comparable
    cim = ax.imshow(img, vmin=0, vmax=0.5, cmap=cc.cm["fire"])
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, rotation=0)
    ax.set_title(title, fontweight="bold")
    cbar = plt.colorbar(cim, pad=0.01, ax=ax)
    cbar.set_label('KL Divergence', rotation=270, labelpad=20, fontweight="bold")

plt.tight_layout()
plt.subplots_adjust(hspace=0, wspace=0)

plt.savefig("plots/kl_combined.pdf", bbox_inches='tight')