In [2]:
%load_ext autoreload
%load_ext tensorboard
%matplotlib inline

In [4]:
import matplotlib
import numpy as np
import os
import random
import yaml
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
from matplotlib import cm
import seaborn as sns
from importlib import reload
from pathlib import Path
import sklearn
import joblib
import torch
import pandas as pd
import copy

In [5]:
################################################################################
## Global Variables Defining Experiment Flow
################################################################################

GPU = 1
NUM_WORKERS = 8
rc('text', usetex=False)
# matplotlib 3.6 renamed its bundled seaborn style sheets to 'seaborn-v0_8-*'
# and later releases dropped the old names, so use whichever one the installed
# version actually ships.
_WHITEGRID_STYLE = (
    'seaborn-whitegrid'
    if 'seaborn-whitegrid' in plt.style.available
    else 'seaborn-v0_8-whitegrid'
)
plt.style.use(_WHITEGRID_STYLE)
plt.rcParams["font.family"] = "serif"

# Load data

In [6]:
import cem.data.CUB200.cub_loader as cub_data_module
# Directory holding the saved experiment config and the trained model files.
result_dir = "results/cub_interventions/"
# NOTE(review): FullLoader can construct more than plain scalars/containers;
# that is fine for a config produced locally, but prefer yaml.safe_load for
# any file that comes from an untrusted source.
with open(os.path.join(result_dir, "experiment_2023_05_13_23_50_config.yaml"), "r") as f:
    experiment_config = yaml.load(f, Loader=yaml.FullLoader)
# Override the stored data-loading settings before the loaders are built.
# NOTE(review): NUM_WORKERS is defined earlier in the notebook but a literal 4
# is used here -- confirm which value is intended.
experiment_config['shared_params']['batch_size'] = 512
experiment_config['shared_params']['num_workers'] = 4
# The result is unpacked into the three data loaders, `imbalance`, and a
# (n_concepts, n_tasks, concept_map) tuple; presumably the trailing tuple is
# what output_dataset_vars=True adds -- confirm against generate_data.
train_dl, val_dl, test_dl, imbalance, (n_concepts, n_tasks, concept_map) = \
    cub_data_module.generate_data(
        config=experiment_config['shared_params'],
        seed=42,
        output_dataset_vars=True,
        root_dir=experiment_config.get('root_dir', None),
    )

Global seed set to 42
Global seed set to 42


# Load Models

In [7]:
# Zero-based fold index; the saved file names use the one-based `split + 1`.
split = 0
intcem_model_path = os.path.join(
    result_dir,
    f"IntAwareConceptEmbeddingModelRetry_intervention_weight_5_horizon_rate_1.005_intervention_discount_1_task_discount_1.1_resnet34_fold_{split + 1}.pt",
)
# The config file name is the weights path with ".pt" swapped for
# "_experiment_config.joblib".
# NOTE(review): joblib.load unpickles its input, so only point it at files
# from a trusted source.
intcem_model_config = joblib.load(
    intcem_model_path.replace(".pt", "_experiment_config.joblib"),
)

# Same naming scheme for the second model of the same fold.
cem_model_path = os.path.join(
    result_dir,
    f"ConceptEmbeddingModel_resnet34_fold_{split + 1}.pt"
)
cem_model_config = joblib.load(
    cem_model_path.replace(".pt", "_experiment_config.joblib"),
)

In [11]:
from cem.train.training import load_trained_model
from cem.interventions.random import IndependentRandomMaskIntPolicy
from cem.interventions.global_policies import ConstantMaskPolicy

# Build the first model from its saved config via the project loader.
# `gpu` is 1 when CUDA is available and 0 otherwise.
# NOTE(review): presumably the loader locates the weights from `config`,
# `result_dir` and `split` -- confirm against load_trained_model.
intcem = load_trained_model(
    config=intcem_model_config,
    n_tasks=n_tasks,
    result_dir=result_dir,
    n_concepts=n_concepts,
    split=split,
    imbalance=imbalance,
    task_class_weights=None,
    train_dl=train_dl,
    sequential=False,
    logger=False,
    independent=False,
    gpu=int(torch.cuda.is_available()),
    output_latent=True,
    output_interventions=True,
    enable_checkpointing=False,
)

# Build the second model from ITS OWN saved config. This call previously
# passed `intcem_model_config`, so `cem` was constructed from the other
# model's configuration and `cem_model_config` was never used.
# `gpu` is 1 when CUDA is available and 0 otherwise.
cem = load_trained_model(
    config=cem_model_config,
    n_tasks=n_tasks,
    result_dir=result_dir,
    n_concepts=n_concepts,
    split=split,
    imbalance=imbalance,
    task_class_weights=None,
    train_dl=train_dl,
    sequential=False,
    logger=False,
    independent=False,
    gpu=int(torch.cuda.is_available()),
    intervention_policy=None,
    output_latent=True,
    output_interventions=True,
    enable_checkpointing=False,
)


# Generate Latent Spaces

In [10]:
def concepts_from_competencies(c, competencies):
    """Randomly flip binary concept labels according to a success probability.

    For every entry a Bernoulli draw with probability ``competencies`` decides
    whether the label is kept (draw == 1) or replaced by ``1 - c`` (draw == 0).

    Args:
        c: Tensor (or array-like convertible by ``torch.as_tensor``) of 0/1
            concept labels.
        competencies: Probability of keeping a label; a scalar or anything
            ``np.random.binomial`` can broadcast to ``c.shape``.

    Returns:
        A CPU ``torch.FloatTensor`` with the same shape as ``c``.

    Note:
        Sampling uses NumPy's global random state, so seed it with
        ``np.random.seed`` for reproducible draws.
    """
    c = torch.as_tensor(c)
    sampled = np.random.binomial(
        n=1,
        p=competencies,
        size=tuple(c.shape),
    )
    # Convert the NumPy draw explicitly so the arithmetic below is pure torch
    # instead of relying on implicit tensor/ndarray mixing.
    correct_interventions = torch.as_tensor(sampled, device=c.device).to(c.dtype)
    return (
        c * correct_interventions + (1 - c) * (1 - correct_interventions)
    ).type(torch.FloatTensor)


In [12]:
import pytorch_lightning as pl
import time

# Attach a constant policy whose mask is all ones over every concept group
# listed in `concept_map`.
intcem.intervention_policy = ConstantMaskPolicy(
    cbm=intcem,
    mask=np.ones((len(concept_map),)),
    concept_group_map=concept_map,
    num_groups_intervened=len(concept_map),
    group_based=True,
    include_prior=False,
)
# Use the same device selection as when the models were loaded instead of a
# hard-coded `gpus=1`, which fails on a machine without CUDA.
# NOTE(review): the recorded run of this cell ended in a CUDA out-of-memory
# error; the batch size of 512 set when loading the data may be too large for
# the available GPU -- confirm.
trainer = pl.Trainer(
    gpus=int(torch.cuda.is_available()),
    logger=False,
)
test_batch_results = trainer.predict(
    intcem,
    test_dl,
)


def _stack_batch_outputs(batch_results, position):
    """Concatenate element `position` of every batch result along axis 0."""
    return np.concatenate(
        [batch[position].detach().cpu().numpy() for batch in batch_results],
        axis=0,
    )


# Each batch result is indexed positionally: 0, 1 and 2 feed the concept
# predictions, concept embeddings and task predictions respectively.
intcem_c_preds = _stack_batch_outputs(test_batch_results, 0)
intcem_c_embs = _stack_batch_outputs(test_batch_results, 1)
intcem_y_preds = _stack_batch_outputs(test_batch_results, 2)

GPU available: True, used: True
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
from sklearn.neighbors import NearestNeighbors
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.cluster.hierarchy import fcluster
from sklearn.manifold import TSNE

class UnNormalize(object):
    """Callable that reverses a per-channel ``(x - mean) / std`` scaling."""

    def __init__(self, mean, std):
        # Per-channel statistics, stored exactly as given.
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor: Normalized image of size (C, H, W).
        Returns:
            ``tensor * std + mean`` with ``std`` and ``mean`` broadcast over
            the two trailing axes.
        """
        # Two trailing singleton axes let the (C,) statistics broadcast
        # against the (C, H, W) image.
        scale = np.asarray(self.std)[..., np.newaxis, np.newaxis]
        offset = np.asarray(self.mean)[..., np.newaxis, np.newaxis]
        return tensor * scale + offset
    
def show_bird_image(image, ax):
    """Draw a normalized channels-first image on ``ax`` with grid and axes off.

    Args:
        image: Array indexed as (C, H, W); it is un-normalized with
            mean 0.5 / std 2 per channel, scaled by 255 and cast to int32.
        ax: Object exposing ``grid``, ``axis`` and ``imshow``.
    """
    ax.grid(False)
    ax.axis(False)
    restore = UnNormalize(mean=[0.5, 0.5, 0.5], std=[2, 2, 2])
    pixels = (restore(image) * 255).astype(np.int32)
    # imshow expects the channel axis last.
    ax.imshow(np.transpose(pixels, axes=[1, 2, 0]))

def show_closest_activation_examples(
    x_test,
    test_dl,
    test_c_embs,
    concept_semantics=None,
    num_examples=5,
    shown_neighs=5,
    scale=1.5,
    selected_concepts=None,
    seed=None,
):
    """Show, per concept, random positive samples next to their nearest
    neighbours in that concept's embedding space.

    Args:
        x_test: Images indexed as ``x_test[sample, :, :, :]``.
        test_dl: Iterable yielding 3-tuples whose third element is the batch
            of concept labels; it must iterate in the same order as
            ``x_test`` / ``test_c_embs`` (i.e. not shuffled).
        test_c_embs: Array indexed as ``[sample, concept, emb_dim]``.
        concept_semantics: Optional names indexed by concept id.
        num_examples: Maximum number of rows (query samples) per concept.
        shown_neighs: Number of neighbours shown per query sample.
        scale: Figure size multiplier.
        selected_concepts: Concept ids to show; all concepts when falsy.
        seed: Passed to ``np.random.seed`` before sampling.
    """
    np.random.seed(seed)
    selected_concepts = selected_concepts or list(range(test_c_embs.shape[1]))

    # Read the labels ONCE instead of re-iterating the loader per concept.
    # Concatenating the batches also yields correct global indices when the
    # final batch is smaller than the others (the previous
    # `i * c_batch.shape[0] + idx` formula was wrong for that batch).
    label_batches = []
    for (_, _, c_batch) in test_dl:
        if torch.is_tensor(c_batch):
            c_batch = c_batch.detach().cpu().numpy()
        label_batches.append(np.asarray(c_batch))
    c_test = np.concatenate(label_batches, axis=0)

    for selected_concept in selected_concepts:
        if concept_semantics is not None:
            print(
                "Selected concept at index",
                selected_concept,
                "with semantics",
                concept_semantics[selected_concept],
            )

        candidate_inds = np.flatnonzero(c_test[:, selected_concept] == 1)
        # Never ask for more rows than there are positive samples.
        n_rows = min(num_examples, len(candidate_inds))
        if n_rows == 0:
            print(
                "No samples with concept",
                selected_concept,
                "active; skipping",
            )
            continue
        selected_inds = np.random.choice(
            candidate_inds,
            size=n_rows,
            replace=False,
        )

        # The neighbour index depends only on the concept, so fit it once
        # rather than once per displayed example.
        concept_embs = test_c_embs[:, selected_concept, :]
        nbrs = NearestNeighbors(
            n_neighbors=(shown_neighs + 1),
            algorithm='ball_tree',
        ).fit(concept_embs)

        # squeeze=False keeps `axs` two-dimensional even for a single row.
        fig, axs = plt.subplots(
            n_rows,
            shown_neighs + 2,
            figsize=(scale*shown_neighs, scale*n_rows),
            squeeze=False,
        )
        if concept_semantics is not None:
            fig.suptitle(
                f'Closest Embeddings for Concept {concept_semantics[selected_concept]}',
                fontsize=15,
            )
        for i, example_idx in enumerate(selected_inds):
            show_bird_image(x_test[example_idx, :, :, :], axs[i, 0])
            if i == 0:
                # Then add a title here
                axs[i, 0].set_title("Sample", fontsize=20)
            # Let's add an empty image in between as a separator
            axs[i, 1].grid(False)
            axs[i, 1].axis(False)
            _, [nearest_indices] = nbrs.kneighbors(
                concept_embs[example_idx:example_idx+1, :]
            )
            # The first hit is skipped on the assumption that it is the
            # query sample itself.
            for j, sample_idx in enumerate(nearest_indices[1:], start=2):
                show_bird_image(x_test[sample_idx, :, :, :], axs[i, j])
                if (i == 0) and ((j - 2) == shown_neighs // 2):
                    axs[i, j].set_title("Nearest Neighbors", fontsize=20)
        fig.tight_layout()
        fig.subplots_adjust(
            wspace=0,
            hspace=0,
        )
        plt.show()

def show_concept_clusters(
    x_test,
    test_c_embs,
    test_c_sems,
    max_d=50, #100
    concept_semantics=None,
    show_activated_only=True,
    show_examples=True,
    max_clusters=5,
    shown_samples=5,
    scale=1.5,
    selected_concepts=None,
    model_name="",
):
    """Cluster each concept's embeddings and show the images closest to every
    cluster centroid.

    Args:
        x_test: Images indexed as ``x_test[sample, :, :, :]``.
        test_c_embs: Array indexed as ``[sample, concept, emb_dim]``.
        test_c_sems: Array indexed as ``[sample, concept]``; values above 0.5
            count as "activated".
        max_d: Distance threshold handed to ``fcluster`` (criterion
            ``'distance'``) on a Ward linkage.
        concept_semantics: Optional names indexed by concept id.
        show_activated_only: Cluster only samples where the concept is
            activated; otherwise use every sample.
        show_examples: Also plot up to 8 of the selected samples first.
        max_clusters: Maximum number of clusters (rows) drawn.
        shown_samples: Maximum number of images (columns) per cluster.
        scale: Figure size multiplier.
        selected_concepts: Concept ids to process; all concepts when falsy.
        model_name: Prefix for the figure title.
    """
    selected_concepts = selected_concepts or list(range(test_c_embs.shape[1]))
    for selected_concept in selected_concepts:
        if concept_semantics is not None:
            print(
                "Selected concept at index",
                selected_concept,
                "with semantics",
                concept_semantics[selected_concept],
            )
        if show_activated_only:
            selected_inds = np.arange(0, test_c_embs.shape[0])[
                test_c_sems[:, selected_concept] > 0.5,
            ].astype(np.int32)
        else:
            selected_inds = np.arange(0, test_c_embs.shape[0]).astype(np.int32)

        # Hierarchical clustering needs at least two observations.
        if len(selected_inds) < 2:
            print(
                "Skipping concept",
                selected_concept,
                "- need at least 2 samples to cluster, got",
                len(selected_inds),
            )
            continue

        # Embeddings of the selected samples for this concept only.
        concept_embs = test_c_embs[selected_inds, selected_concept, :]

        if show_examples:
            print("Examples of selected test samples:")
            fig = plt.figure(figsize=(14, 6))
            for i, idx in enumerate(selected_inds[:8]):
                # Draw on the axes just created rather than passing the
                # pyplot module around as if it were an axes.
                example_ax = fig.add_subplot(1, 8, i + 1)
                show_bird_image(x_test[idx, :, :, :], example_ax)
            plt.show()

        Z = linkage(concept_embs, 'ward')
        clusters = fcluster(Z, max_d, criterion='distance')
        cluster_types = np.unique(clusters)
        print("Found", len(cluster_types), "clusters from", clusters.shape[0], "samples")
        # fcluster labels start at 1, hence the `- 1` below.
        cluster_map = [
            [] for _ in range(len(cluster_types))
        ]
        for i, cluster_type in enumerate(clusters):
            cluster_map[cluster_type - 1].append(i)

        n_rows = min(len(cluster_map), max_clusters)
        # squeeze=False keeps `axs` two-dimensional when there is a single
        # cluster or a single column, so the indexing below always works.
        fig, axs = plt.subplots(
            n_rows,
            shown_samples,
            figsize=(scale*shown_samples, scale*n_rows),
            squeeze=False,
        )
        if concept_semantics is not None:
            fig.suptitle(
                f'{model_name} Sample Concept Clusters for Concept {concept_semantics[selected_concept]}',
                fontsize=15,
            )
        for ax in axs.ravel():
            ax.grid(False)
            ax.axis(False)

        for cluster_id, samples in enumerate(cluster_map[:max_clusters]):
            real_shown_samples = min(shown_samples, len(samples))
            cluster_embs = concept_embs[samples, :]
            centroid = np.mean(cluster_embs, axis=0, keepdims=True)
            nbrs = NearestNeighbors(
                n_neighbors=real_shown_samples,
                algorithm='ball_tree',
            ).fit(cluster_embs)
            _, [nearest_indices] = nbrs.kneighbors(centroid)
            for i, sample_idx in enumerate(nearest_indices):
                # Map cluster-local position -> selected position -> test index.
                real_idx = selected_inds[samples[sample_idx]]
                show_bird_image(x_test[real_idx, :, :, :], axs[cluster_id, i])
        fig.tight_layout()
        fig.subplots_adjust(
            wspace=0,
            hspace=0.1,
        )
        plt.show()

def show_inter_concept_similarity(
    x_test,
    test_c_embs,
    test_c_sems,
    concept_semantics=None,
    show_activated_only=True,
    selected_concepts=None,
    normalize=True,
    n_closest=5,
    metric='cosine',
    to_console=True,
):
    """Find, for every selected concept, the concepts with the closest mean
    embedding.

    Args:
        x_test: Unused; kept for signature compatibility.
        test_c_embs: Array indexed as ``[sample, concept, emb_dim]``.
        test_c_sems: Array indexed as ``[sample, concept]``; values above 0.5
            count as "activated".
        concept_semantics: Optional names indexed by concept id.
        show_activated_only: Average only over samples where the concept is
            activated; otherwise over every sample.
        selected_concepts: Concept ids to compare; all concepts when falsy.
        normalize: Normalise each centroid row before the neighbour search.
        n_closest: Number of neighbours reported per concept (capped by the
            number of other selected concepts).
        metric: Distance metric handed to ``NearestNeighbors``.
        to_console: Print the neighbours as they are found.

    Returns:
        ``(centroids, result)`` where ``centroids[i]`` belongs to
        ``selected_concepts[i]`` and ``result[i]`` is a list of
        ``(j, distance)`` pairs, ``j`` being a row index into ``centroids``
        (equal to the concept id when all concepts are selected).
    """
    # Imported here so the submodule is guaranteed to be loaded; a bare
    # `import sklearn` does not promise `sklearn.preprocessing` is available.
    from sklearn.preprocessing import normalize as _normalize_rows

    selected_concepts = selected_concepts or list(range(test_c_embs.shape[1]))
    concept_ids = np.asarray(selected_concepts)

    centroids = np.zeros((len(selected_concepts), test_c_embs.shape[-1]))
    for i, concept_idx in enumerate(selected_concepts):
        if show_activated_only:
            selected_inds = np.arange(0, test_c_embs.shape[0])[
                test_c_sems[:, concept_idx] > 0.5,
            ].astype(np.int32)
        else:
            selected_inds = np.arange(0, test_c_embs.shape[0]).astype(np.int32)

        centroids[i, :] = np.mean(
            test_c_embs[selected_inds, concept_idx, :],
            axis=0,
        )
    if normalize:
        centroids = _normalize_rows(centroids, axis=1)
    nbrs = NearestNeighbors(
        # +1 because the query centroid is expected to come back as its own
        # nearest neighbour; capped so a small selection does not raise.
        n_neighbors=min(n_closest + 1, centroids.shape[0]),
        algorithm='auto',
        metric=metric,
    ).fit(centroids)

    result = []
    for i, concept_idx in enumerate(selected_concepts):
        [distances], [nearest_positions] = nbrs.kneighbors(centroids[i:i+1, :])
        # kneighbors returns ROW positions in `centroids`; translate them to
        # concept ids before looking up names, otherwise a subset selection
        # would be labelled with the wrong concepts.
        nearest_ids = concept_ids[nearest_positions]
        if concept_semantics is not None:
            concept_name = concept_semantics[concept_idx]
            nearest_names = np.array(concept_semantics)[nearest_ids]
        else:
            concept_name = concept_idx
            nearest_names = nearest_ids

        if to_console:
            print(f"Nearest concepts to concept {concept_name}:")
        partial_lst = []
        # Skip entry 0 (assumed to be the query itself) in ALL three
        # sequences so that index, name and distance stay aligned.
        for j, name, dist in zip(
            nearest_positions[1:],
            nearest_names[1:],
            distances[1:],
        ):
            if to_console:
                print(f"\t{name} (distance {dist})")
            partial_lst.append((j, dist))
        result.append(partial_lst)
        if to_console:
            print()
    return centroids, result

def plot_concept_centroids(
    x_test,
    test_c_embs,
    test_c_sems,
    concept_semantics=None,
    selected_concepts=None,
    perplexity=50,
    n_iter=1000,
    figsize=(8, 6),
    show_activated_only=True,
    model_name="SplitEmb",
    annotation_size=5,
    concept_colors=None,
    dot_size=10,
    half_emb=False,
):
    """Plot a 2-D t-SNE projection of the per-concept mean embeddings.

    Args:
        x_test: Unused; kept for signature compatibility.
        test_c_embs: Array indexed as ``[sample, concept, emb_dim]``.
        test_c_sems: Array indexed as ``[sample, concept]``; values above 0.5
            count as "activated".
        concept_semantics: Optional names indexed by concept id; when missing
            or empty, ``concept_<id>`` is used.
        selected_concepts: Concept ids to plot; all concepts when falsy.
        perplexity, n_iter: Forwarded to ``TSNE``.
        figsize: Figure size.
        show_activated_only: Average only over samples where the concept is
            activated; otherwise over every sample.
        model_name: Prefix for the plot title.
        annotation_size: Font size of the per-point labels.
        concept_colors: Optional colours indexed by concept id. When omitted,
            concepts whose name contains "color" are coloured by the text
            after "::" in the name and all others are black.
        dot_size: Marker size.
        half_emb: Use only the first half of the embedding dimension.
    """
    selected_concepts = selected_concepts or list(range(test_c_embs.shape[1]))
    emb_size = test_c_embs.shape[-1]
    if half_emb:
        emb_size = emb_size // 2

    centroids = np.zeros((len(selected_concepts), emb_size))
    for i, concept_idx in enumerate(selected_concepts):
        if show_activated_only:
            selected_inds = np.arange(0, test_c_embs.shape[0])[
                test_c_sems[:, concept_idx] > 0.5,
            ].astype(np.int32)
        else:
            selected_inds = np.arange(0, test_c_embs.shape[0]).astype(np.int32)
        centroids[i, :] = np.mean(
            test_c_embs[selected_inds, concept_idx, :emb_size],
            axis=0,
        )

    tsne = TSNE(
        n_components=2,
        verbose=1,
        perplexity=perplexity,
        n_iter=n_iter,
        init='pca',
        learning_rate='auto',
    )
    tsne_results = tsne.fit_transform(centroids)
    fig, ax = plt.subplots(
        1,
        1,
        figsize=figsize,
    )
    ax.set_title(
        f"{model_name} Cluster Centroids",
        fontsize=15,
    )

    # Resolve display names up front, keyed by concept id, so they exist no
    # matter which colour branch runs and stay correct for a subset selection.
    if concept_semantics is None or len(concept_semantics) == 0:
        concept_names = {idx: f'concept_{idx}' for idx in selected_concepts}
    else:
        concept_names = {idx: concept_semantics[idx] for idx in selected_concepts}

    if concept_colors is None:
        colors = []
        for concept_idx in selected_concepts:
            concept_name = concept_names[concept_idx]
            if "color" in concept_name:
                colors.append(concept_name[concept_name.find("::") + 2:])
            else:
                colors.append("black")
    else:
        colors = list(np.array(concept_colors)[selected_concepts])

    # Substitute colour names that are swapped for a different colour in the
    # plot; non-string colour specs are passed through untouched.
    color_substitutes = {
        "buff": "palegoldenrod",
        "multi-colored": "palegreen",
        "white": "cyan",
    }
    colors = [
        color_substitutes.get(color, color) if isinstance(color, str) else color
        for color in colors
    ]

    # TODO: distinguish colour from non-colour concepts by marker; a single
    # scatter call takes one marker, so that needs one call per marker.
    ax.scatter(
        tsne_results[:, 0],
        tsne_results[:, 1],
        c=colors,
        s=dot_size,
    )

    for i, concept_idx in enumerate(selected_concepts):
        ax.annotate(
            concept_names[concept_idx],
            (tsne_results[i, 0], tsne_results[i, 1]),
            fontsize=annotation_size,
        )
    ax.grid(False)
    ax.axis(False)
    # No legend: the scatter has no labelled artists, so a legend call would
    # only emit a warning.
    plt.show()
            
def plot_tsne_embeddings(
    test_c_embs,
    c_test,
    color_activations=None,
    color_activation_labels=None,
    attributes=None,
    perplexity=50,
    n_iter=1000,
    figsize=(8, 6),
    selected_concepts=None,
    y_test=None,
    model_name="SplitEmb",
):
    """Project each concept's embeddings to 2-D with t-SNE and plot them.

    For every selected concept one projection is fitted. It is drawn coloured
    by ``y_test`` (when given) and then once per activation vector with
    entries equal to 1 in red and the rest in blue.

    Args:
        test_c_embs: Array indexed as ``[sample, concept, emb_dim]``.
        c_test: Array indexed as ``[sample, concept]``; its column for the
            current concept is the activation vector when
            ``color_activations`` is None.
        color_activations: Optional list of per-sample activation vectors
            used for every concept instead of the ``c_test`` column.
        color_activation_labels: Optional legend labels, one per activation.
        attributes: Optional concept names indexed by concept id.
        perplexity, n_iter: Forwarded to ``TSNE``.
        figsize: Size of every figure.
        selected_concepts: Concept ids to plot; all concepts when falsy.
        y_test: Optional per-sample values for the "by class" plot.
        model_name: Prefix for the activation plot titles.

    Returns:
        List with one 2-D projection per selected concept.
    """
    concept_ids = selected_concepts or list(range(test_c_embs.shape[1]))
    named = attributes is not None
    projections = []

    for concept_id in concept_ids:
        if named:
            print(
                "Selected concept at index",
                concept_id,
                "with semantics",
                attributes[concept_id],
            )

        projector = TSNE(
            n_components=2,
            verbose=1,
            perplexity=perplexity,
            n_iter=n_iter,
            init='pca',
            learning_rate='auto',
        )
        points = projector.fit_transform(test_c_embs[:, concept_id, :])
        projections.append(points)

        # Optional view coloured by the task label.
        if y_test is not None:
            class_fig, class_ax = plt.subplots(1, 1, figsize=figsize)
            if named:
                class_ax.set_title(
                    f"TSNE Embeddings for {attributes[concept_id]} (by class)",
                    fontsize=15,
                )
            class_ax.scatter(
                points[:, 0],
                points[:, 1],
                c=y_test,
                s=5,
                cmap='ocean',
            )
            class_ax.grid(False)
            class_ax.axis(False)
            plt.show()

        if color_activations is None:
            activation_sets = [c_test[:, concept_id]]
        else:
            activation_sets = color_activations

        for position, activation in enumerate(activation_sets):
            # Legend label: explicit label > positional name for custom
            # activations > concept name > one-based concept number.
            if color_activation_labels is not None:
                activation_label = color_activation_labels[position]
            elif color_activations is not None:
                activation_label = f"Concept {position + 1}"
            elif named:
                activation_label = attributes[concept_id]
            else:
                activation_label = f"Concept {concept_id + 1}"

            active = activation == 1
            inactive = np.logical_not(active)

            fig, ax = plt.subplots(1, 1, figsize=figsize)
            if named:
                ax.set_title(
                    f"{model_name} TSNE Embeddings for {attributes[concept_id]}",
                    fontsize=15,
                )
            # Active points first, then inactive ones.
            for subset, colour, suffix in (
                (active, 'red', " active"),
                (inactive, 'blue', " not active"),
            ):
                ax.scatter(
                    points[subset, 0],
                    points[subset, 1],
                    color=colour,
                    label=activation_label + suffix,
                    s=5,
                )
            ax.grid(False)
            ax.axis(False)
            fig.legend(fontsize=10)
            plt.show()

    return projections


def plot_tsne_latent_space(
    test_c_embs,
    c_test,
    attributes=None,
    perplexity=50,
    n_iter=1000,
    figsize=(8, 6),
    selected_concepts=None,
    y_test=None,
):
    """Project the flattened concept embeddings to 2-D with a single t-SNE
    fit and plot the result.

    The projection is drawn coloured by ``y_test`` (when given) and then once
    per selected concept, with samples whose ``c_test`` entry equals 1 in red
    and the rest in blue.

    Args:
        test_c_embs: Array whose first axis is the sample axis; all remaining
            axes are flattened into one feature vector per sample.
        c_test: Array indexed as ``[sample, concept]``.
        attributes: Optional concept names indexed by concept id.
        perplexity, n_iter: Forwarded to ``TSNE``.
        figsize: Size of every figure.
        selected_concepts: Concept ids to colour by; all concepts when falsy.
        y_test: Optional per-sample values for the "by class" plot.

    Returns:
        The 2-D projection of all samples.
    """
    named = attributes is not None

    flattened = test_c_embs.reshape(test_c_embs.shape[0], -1)
    projector = TSNE(
        n_components=2,
        verbose=1,
        perplexity=perplexity,
        n_iter=n_iter,
        init='pca',
        learning_rate='auto',
    )
    tsne_results = projector.fit_transform(flattened)
    concept_ids = selected_concepts or list(range(test_c_embs.shape[1]))

    # Optional view coloured by the task label.
    if y_test is not None:
        class_fig, class_ax = plt.subplots(1, 1, figsize=figsize)
        if named:
            class_ax.set_title(
                f"SplitEmb COMPLETE Latent Space TSNE colored by class",
                fontsize=15,
            )
        class_ax.scatter(
            tsne_results[:, 0],
            tsne_results[:, 1],
            c=y_test,
            s=5,
            cmap='ocean',
        )
        class_ax.grid(False)
        class_ax.axis(False)
        plt.show()

    # One figure per concept, reusing the same projection.
    for concept_id in concept_ids:
        if named:
            print(
                "Selected concept at index",
                concept_id,
                "with semantics",
                attributes[concept_id],
            )

        present = c_test[:, concept_id] == 1
        absent = np.logical_not(present)

        fig, ax = plt.subplots(1, 1, figsize=figsize)
        if named:
            ax.set_title(
                f"SplitEmb COMPLETE Latent Space TSNE colored by {attributes[concept_id]}",
                fontsize=15,
            )
        for subset, colour, legend_text in (
            (present, 'red', "Concept activated"),
            (absent, 'blue', "Concept not present"),
        ):
            ax.scatter(
                tsne_results[subset, 0],
                tsne_results[subset, 1],
                color=colour,
                label=legend_text,
                s=5,
            )
        ax.grid(False)
        ax.axis(False)
        fig.legend(fontsize=10)
        plt.show()
    return tsne_results