In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
%%bash
# From PyTorch 1.11, the THC namespace is removed which prevents the PreciseRoIPooling library used in segmentation model to be built successfully
# -> pin torch 1.10.1 / torchvision 0.11.2 (CUDA 10.2 wheels from the PyTorch wheel index).
# pip install torch==1.10.1+cu102 torchvision==0.11.2+cu102 torchaudio==0.10.1 torchtext>=0.11.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.10.1+cu102 torchvision==0.11.2+cu102 -f https://download.pytorch.org/whl/torch_stable.html

# Everything below is Colab-only: when the google.colab package directory does not exist, `!` turns the
# failed `stat` into success and the script exits here.
# NOTE(review): the torch pin above is NOT behind this check, so it also (re)installs torch outside Colab.
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit 
# ninja: presumably needed by torch's C++/CUDA extension build (PreciseRoIPooling) - confirm. stderr is appended to install.log.
pip install ninja 2>> install.log
# Clone the dissect repo into ./dissect; the next cell adds /content/dissect to sys.path.
git clone https://github.com/davidbau/dissect.git dissect 2>> install.log
#pip list -v >> attribution_packages.log

In [None]:
# Colab-only setup: make the cloned dissect repo importable and warn when no GPU is attached.
# Outside Colab the google.colab import fails with ImportError and this cell does nothing.
# Only ImportError is swallowed; any other failure is a real problem and should surface.
try:
    import google.colab, sys, torch
except ImportError:
    pass
else:
    sys.path.append('/content/dissect')
    if not torch.cuda.is_available():
        print("Change runtime type to include a GPU.")

In [None]:
import torch, os, pickle, shutil, json, netdissect
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from collections import OrderedDict
from tqdm import tqdm
from PIL import Image
from imageio import imwrite
from IPython.display import display
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
from torchvision.datasets.folder import pil_loader
from sklearn.feature_selection import SelectKBest, VarianceThreshold, mutual_info_classif

from netdissect import nethook, renormalize, parallelfolder, upsample, imgviz, show
from netdissect.easydict import EasyDict
from netdissect.workerpool import WorkerPool
from netdissect.imgsave import SaveImageWorker
from experiment import dissect_experiment as experiment

# Let cuDNN auto-tune convolution algorithms (pays off when input shapes stay the same across batches).
torch.backends.cudnn.benchmark = True
# Log the torch version actually loaded (the install cell pins 1.10.1 for PreciseRoIPooling).
print('PyTorch version:', torch.__version__)

In [None]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print("Cuda available:", torch.cuda.is_available())

# seed = 2021
# torch.manual_seed(seed)
# if device.type == 'cuda':
#     torch.cuda.manual_seed_all(seed)
#     torch.backends.cudnn.benchmark = True

In [None]:
# ---- Pipeline settings shared with the other POEM pipeline notebooks ----
# NOTE(review): pickle.load can execute arbitrary code; only load a settings.pkl you produced yourself.
settings_path = '/content/drive/My Drive/Python Projects/POEM Pipeline Results/settings.pkl'
with open(settings_path, 'rb') as f:
    settings = pickle.load(f)

# The first line of current_setting.txt names the model/dataset combination to process.
current_setting_path = '/content/drive/My Drive/Python Projects/POEM Pipeline Results/current_setting.txt'
with open(current_setting_path, 'r') as f:
    current_setting_title = f.read().splitlines()[0]
    print('Current setting:', current_setting_title)

model_settings = settings['model_settings']
model_dataset_settings = settings['model_dataset_settings']

# Title format: '<model>_<dataset>'; the dataset part may itself contain '_' and its first token is
# the "pure" dataset name.
title_parts = current_setting_title.split('_')
model_name = title_parts[0]
dataset_name = '_'.join(title_parts[1:]) 
dataset_pure_name = dataset_name.split('_')[0]

current_setting = model_dataset_settings[current_setting_title] 
num_classes = current_setting['num_classes']
# False: use our own trained classifier; True: use the dissect experiment's models/datasets instead.
use_dissection_models = False
target_layer = model_settings[model_name]['target_layer'] if not use_dissection_models else experiment.instrumented_layername(EasyDict(model=model_name))

# Segmentation model key (only passed on via `args.seg` below), input crop size and batch size.
seg_model_name = 'netpc'
image_size = 224
batch_size = 32

# Channels whose tally IoU is not above min_iou are marked invalid and ignored.
min_iou = 0.04
# Quantile used as the per-image "high" activation threshold (torch.quantile q) and passed as args.quantile.
activation_high_thresh = 0.95
# Minimum number of upsampled pixels above threshold for a channel to count as active on an image.
min_thresh_pixels = 10
# Not used in the cells shown here; presumably for the activation-image saving step - confirm.
n_top_channels_per_concept = 3
overlay_opacity = 0.5

# Require the thresholded activation to overlap the segmentation of the channel's concept.
check_seg_overlap = True
overlap_mode = 'overlap_to_activation_ratio'   # 'overlap_pixels_count', 'overlap_to_union_ratio', 'overlap_to_activation_ratio', 'overlap_to_segmentation_ratio'
min_overlap_ratio = 0.5
min_overlap_pixels = 5
# Tally category -> index of the segmentation map for that category in the segmenter output.
category_index_map = {
    'object': 0,
    'material': 1,
    'part': 2,
    'color': 3
}

# Weight activations by the gradient of the predicted class; optionally average gradients spatially first.
check_gradients = True
pool_gradients = False

# Concept filtering: VarianceThreshold with p*(1-p) for p=low_variance_thresh, then keep at most max_concepts.
filter_concepts = True
low_variance_thresh = 0.99
max_concepts = 10

# Feature binning: with binning, concepts/channels get high/mid/low values (mid uses activation_low_thresh);
# without it, features are binary (high_value == 1).
binning_features = False
activation_low_thresh = 0.7
high_value = 2 if binning_features else 1
mid_value = 1
low_value = 0

# Class binning: predictions whose runner-up is close to the top class become 'maybe <class>' (class + num_classes).
binning_classes = False
certainty_thresh = 0.6

# Standard ImageNet normalization statistics used by the image transforms.
norm_mean = (0.485, 0.456, 0.406)
norm_std = (0.229, 0.224, 0.225)

# Argument bundle for the dissect experiment helpers.
args = EasyDict(model=model_name, dataset=dataset_name, seg=seg_model_name, layer=target_layer, quantile=activation_high_thresh)

In [None]:
def save_imported_packages(packages_path):
    """Write '<name> <version>' for every imported top-level module to `packages_path`.

    The version is taken from the first lookup that succeeds: ``module.__version__``, then
    ``module.version`` (used as-is when it is a str, otherwise called), then ``module.VERSION``.
    Modules exposing none of these are skipped. Lines are sorted by module name.

    Args:
        packages_path: output text file path (overwritten).

    Returns:
        The sorted list of (module_name, version) pairs that was written.
    """
    import sys

    def _module_version(mod):
        # Returns (found, version). Any failure (missing attribute, non-callable `version` such as a
        # submodule, None placeholders in sys.modules) falls through to the next lookup.
        try:
            return True, mod.__version__
        except Exception:
            pass
        try:
            version = mod.version
            return True, (version if type(version) is str else version())
        except Exception:
            pass
        try:
            return True, mod.VERSION
        except Exception:
            return False, None

    modules_info = []
    # Iterate over a snapshot: reading version attributes can trigger lazy imports that add entries to
    # sys.modules, which would break iteration over the live dict.
    for name, mod in list(sys.modules.items()):
        if '.' in name:   # ignoring subpackages
            continue
        found, version = _module_version(mod)
        if found:
            modules_info.append((name, version))

    modules_info.sort(key=lambda x: x[0])
    with open(packages_path, 'w') as f:
        for name, version in modules_info:
            f.write('{} {}\n'.format(name, version))

    return modules_info

In [None]:
def get_target_layer_name (final=False, name=None):
    """Return the layer name to hook for dissection of a given architecture.

    Args:
        final: when True and `use_dissection_models` is set, defer to
            `experiment.instrumented_layername(args)`.
        name: architecture name; defaults to the notebook-level `model_name`.

    Returns:
        The layer name string, or None for architectures not listed here.
    """
    if final and use_dissection_models:
        return experiment.instrumented_layername(args)

    # Compare by value (dict lookup) rather than with `is`: `model_name` is produced by str.split(), so it
    # is generally not the same object as a string literal and identity checks silently fail.
    layer_names = {
        'resnet18': 'layer4.1.conv2',   # 'layer4'
        'resnet50': 'layer4.2.conv3',
        'vgg16': 'features.conv5_3',    # 'features.28'
        'alexnet': 'conv5',
    }
    return layer_names.get(model_name if name is None else name)

In [None]:
def vgg16_model (*args, **kwargs):
    """Build torchvision's VGG16 with its layers renamed to the names used in the research literature.

    All arguments are forwarded to `models.vgg16`. The `features` modules are renamed, in order, to
    conv<b>_<i> / relu<b>_<i> / pool<b> (31 names), and the `classifier` modules to
    fc6, relu6, drop6, fc7, relu7, drop7, fc8a. Names are paired with modules positionally.
    """
    net = models.vgg16(*args, **kwargs)

    # Convolutions per block: 2, 2, 3, 3, 3; every conv is followed by a ReLU and every block ends in a pool.
    feature_names = []
    for block, n_convs in enumerate((2, 2, 3, 3, 3), start=1):
        for conv in range(1, n_convs + 1):
            feature_names.append('conv{}_{}'.format(block, conv))
            feature_names.append('relu{}_{}'.format(block, conv))
        feature_names.append('pool{}'.format(block))

    classifier_names = ['fc6', 'relu6', 'drop6', 'fc7', 'relu7', 'drop7', 'fc8a']

    net.features = nn.Sequential(OrderedDict(zip(feature_names, net.features)))
    net.classifier = nn.Sequential(OrderedDict(zip(classifier_names, net.classifier)))
    return net

In [None]:
def load_model (model_file=None):
    """Build the classifier and load its trained weights.

    With `use_dissection_models`, the model is returned as produced by `experiment.load_model(args)`.
    Otherwise the torchvision architecture `model_name` (vgg16 uses the research-named `vgg16_model`)
    is created with `num_classes` outputs, its state dict is loaded from `model_file`, and it is wrapped
    in `nethook.InstrumentedModel`, moved to CUDA and switched to eval mode.

    Args:
        model_file: path to a saved state dict; required unless `use_dissection_models` is set.

    Returns:
        The model.

    Raises:
        ValueError: if `model_file` is missing when it is needed.
    """
    if use_dissection_models:
        return experiment.load_model(args)

    if model_file is None:
        raise ValueError('load_model needs a model_file when use_dissection_models is False')

    if model_name == 'vgg16':
        model = vgg16_model(num_classes=num_classes)
    else:
        model = models.__dict__[model_name](num_classes=num_classes)

    # Load onto the CPU first: without map_location, tensors are restored to the device they were saved
    # from, which fails when that device is unavailable. The whole model is moved to CUDA below anyway.
    # NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
    state_dict = torch.load(model_file, map_location='cpu')
    model.load_state_dict(state_dict)
    return nethook.InstrumentedModel(model).cuda().eval()

In [None]:
class CustomImageFolder (datasets.ImageFolder):
    """torchvision ImageFolder whose items also carry the source file path.

    Each item is the parent's item tuple with the image path appended, i.e. (image, target, path).
    """

    def __getitem__(self, index):
        # self.imgs holds ImageFolder's (path, class_index) pairs; append the path to the parent's tuple.
        return (*super().__getitem__(index), self.imgs[index][0])

In [None]:
def load_data (dataset_dir=None):
    """Create the evaluation dataset and an unshuffled DataLoader over it.

    With `use_dissection_models` the dataset comes from `experiment.load_dataset(args)`. Otherwise it is a
    CustomImageFolder rooted at `dataset_dir` (items are (image, label, path)); images are resized to 256,
    center-cropped to `image_size` and normalized with `norm_mean` / `norm_std`.

    Returns:
        (dataset, data_loader)
    """
    if use_dissection_models:
        dataset = experiment.load_dataset(args)
    else:
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
        ])
        dataset = CustomImageFolder(root=dataset_dir, transform=preprocess)

    data_loader = DataLoader(dataset, batch_size=batch_size)
    summary = 'Processing {} data examples in {} batches with class index {}'
    print(summary.format(len(dataset), len(data_loader), dataset.class_to_idx))
    return dataset, data_loader

In [None]:
def load_channels_data (tally_path, iou_thresh=None):
    """Load per-channel concept labels from a dissection tally JSON file.

    Args:
        tally_path: path to a JSON file with a top-level "units" list; each item must carry 'unit', 'label',
            'cat', 'iou', 'low_thresh' and 'high_thresh' (incomplete items are reported and skipped).
        iou_thresh: a channel is marked valid only if its IoU is strictly above this value;
            defaults to the notebook-level `min_iou`.

    Returns:
        (channels_map, channels, concepts), or None if the file has no "units" entry.
        channels_map maps the 1-based channel number (unit + 1) to a dict with 'concept', 'category',
        'low_thresh', 'high_thresh' and 'is_valid'; channels and concepts are sorted lists.
    """
    if iou_thresh is None:
        iou_thresh = min_iou

    # Context manager closes the file on every path, including the early return and JSON decode errors.
    with open(tally_path) as f:
        tally_data = json.load(f)

    if "units" not in tally_data:
        print('Error: units not present in loaded tally data from path {}'.format(tally_path))
        return None

    required_keys = ('iou', 'unit', 'label', 'cat', 'low_thresh', 'high_thresh')
    channels_map = {}

    for ch_item in tally_data["units"]:
        if any(key not in ch_item for key in required_keys):
            print('Error: incomplete data in channel item:', ch_item)
            continue

        # Channels are numbered from 1 in this notebook (tensors are indexed with ch - 1).
        channel = ch_item['unit'] + 1
        channels_map[channel] = {
            'concept': ch_item['label'],
            'category': ch_item['cat'],
            'low_thresh': ch_item['low_thresh'],
            'high_thresh': ch_item['high_thresh'],
            'is_valid': ch_item['iou'] > iou_thresh
        }

    channels = sorted(channels_map)
    concepts = sorted({info['concept'] for info in channels_map.values()})

    print('Processing {} concepts and {} channels'.format(len(concepts), len(channels)))
    print('channels_map:', channels_map)

    return channels_map, channels, concepts

In [None]:
def get_file_name_from_path (image_path):
    """Return the final component of `image_path`, e.g. '/data/cls/img.jpg' -> 'img.jpg'.

    os.path.basename handles the OS-native separator (a hand-rolled rfind('/') misses '\\' paths on
    Windows); on POSIX the result is unchanged, including '' for a trailing slash.
    """
    return os.path.basename(image_path)

In [None]:
def save_image_set (images, filenames):
    """Queue each image with its paired filename on a SaveImageWorker pool, then wait on the pool's join()."""
    workers = WorkerPool(worker=SaveImageWorker)
    for pair in zip(images, filenames):
        workers.add(*pair)
    workers.join()

In [None]:
def get_binned_predictions (logits, thresh=None, n_classes=None):
    """Turn logits into class predictions, marking uncertain ones as 'maybe <class>'.

    A prediction is uncertain when any other class has a softmax probability above
    ``(1 - thresh) / thresh`` times the top probability (2/3 for thresh = 0.6). Uncertain predictions
    are encoded as ``class + n_classes``, e.g. 2 for class 0 in a binary setting.

    Args:
        logits: (batch, classes) tensor.
        thresh: certainty threshold; defaults to the notebook-level `certainty_thresh`.
        n_classes: offset for 'maybe' labels; defaults to the notebook-level `num_classes`.

    Returns:
        numpy int array of shape (batch,).
    """
    if thresh is None:
        thresh = certainty_thresh
    if n_classes is None:
        n_classes = num_classes

    certainty_ratio = (1.0 - thresh) / thresh

    probs = torch.softmax(logits, dim=1)
    top_probs, preds = torch.max(probs, dim=1)
    # Count classes above ratio * top probability; the top class itself always counts while ratio < 1,
    # so a count above 1 means some other class is competitive. Vectorized: no per-element GPU syncs.
    n_competitive = (probs > top_probs.unsqueeze(1) * certainty_ratio).sum(dim=1)
    preds = torch.where(n_competitive > 1, preds + n_classes, preds)

    return preds.cpu().numpy()

In [None]:
def extract_class_titles (ds_name, binning=None):
    """Map class indices to titles parsed from a dataset name '<dataset>_<class0>_<class1>...'.

    Args:
        ds_name: dataset name; each '_'-separated token after the first is a class title.
        binning: also add 'maybe <title>' at index i + number_of_classes (the offset used for uncertain
            predictions); defaults to the notebook-level `binning_classes`.

    Returns:
        dict index -> title; empty when the name has no class tokens.
    """
    class_names = ds_name.split('_')[1:]
    if not class_names:
        return {}
    if binning is None:
        binning = binning_classes

    n_classes = len(class_names)
    ctitles = {}
    for i, name in enumerate(class_names):
        ctitles[i] = name
        if binning:
            ctitles[i + n_classes] = 'maybe ' + name

    return ctitles

In [None]:
def plot_activation_histogram (act):
    """Print the counts and edges of a 10-bin histogram of `act` and show it as a bar chart.

    Args:
        act: activation (or activation-gradient) tensor.
            NOTE(review): torch.histogram expects a floating-point CPU tensor - confirm before passing CUDA tensors.
    """
    counts, edges = (t.tolist() for t in torch.histogram(act, bins=10))
    print(counts)
    print(edges)

    bin_labels = ['{:.4f}-{:.4f}'.format(lo, hi) for lo, hi in zip(edges[:-1], edges[1:])]

    _, ax = plt.subplots(figsize=(15, 15))
    ax.bar(bin_labels, counts, align="center")
    ax.set_xlabel('Activation/Gradient Value')
    ax.set_ylabel('Frequency')
    plt.show()

In [None]:
def plot_sample_image_activations (dataset, image, channel, seg, act, grad, act_grad, upact_grad, channels_map, seg_concept_index_map, image_threshs):
    """Show diagnostic visualizations of one channel on one image.

    Prints the channel/concept info, then shows: the input image, the concept's segmentation mask, the
    thresholded activation-gradient mask, their overlap, heatmaps of activations, gradients and
    activation-gradients, and the image highlighted by each of these.

    Args:
        dataset: passed as `source` to renormalize/imgviz (presumably so they can undo its normalization).
        image: the normalized input image tensor.
        channel: 1-based channel number (key of channels_map).
        seg: per-category segmentation of the image; seg[category_index_map[category]] is compared with the
            concept's label index.
        act, grad, act_grad: per-image layer activations, gradients and their combination
            (assumed (channels, H, W) - a batch axis is added below).
        upact_grad: upsampled act_grad; upact_grad[channel - 1] is thresholded.
        channels_map: channel -> info dict with 'concept', 'category', 'high_thresh'.
        seg_concept_index_map: concept name -> segmentation label index.
        image_threshs: dict whose 'high_thresh' is the per-image activation-gradient threshold.
    """

    # Add a leading batch axis; the imgviz calls below index these with (0, ch_index).
    act = act[None, :, :, :]
    grad = grad[None, :, :, :]
    act_grad = act_grad[None, :, :, :]
    iv = imgviz.ImageVisualizer(size=(image_size, image_size), image_size=(image_size, image_size), source=dataset)

    img = renormalize.as_image(image, source=dataset)
    # channels_map keys are 1-based; tensor channel indices are 0-based.
    ch_index = channel - 1
    ch_info = channels_map[channel]
    concept = ch_info['concept']
    category = ch_info['category']
    act_thresh = ch_info['high_thresh']
    grad_thresh = image_threshs['high_thresh']
    print('Visualizing filter {} mapped to concept {} from category {}, with activation thresh {} and gradient thresh {}'
        .format(channel, concept, category, act_thresh, grad_thresh))

    # Pixels where the upsampled activation-gradient map exceeds the per-image threshold.
    act_grad_high_mask = (upact_grad[ch_index] > grad_thresh)

    # Pixels labelled with this channel's concept in the segmentation map of its category.
    seg_concept_index = seg_concept_index_map[concept]
    cat_index = category_index_map[category]
    target_seg = seg[cat_index]
    target_seg_concept_mask = (target_seg == seg_concept_index)

    overlap_mask = act_grad_high_mask & target_seg_concept_mask

    print('Input image:')
    show([[img]])

    print('Segmentation mask:')
    show([[iv.segmentation(target_seg_concept_mask)]])   # seg[0]

    print('Activation-gradients mask:')
    show([[iv.segmentation(act_grad_high_mask)]])

    print('Segmentation and activation-gradient overlap mask:')
    show([[iv.segmentation(overlap_mask)]])

    print('Activations heatmap:')
    show([[iv.heatmap(act, (0, ch_index), mode='nearest')]])

    print('Gradients heatmap:')
    # NOTE(review): debug output - prints the raw gradient map of this channel and its range.
    print(grad[0, ch_index])
    print(grad[0, ch_index].min(), grad[0, ch_index].max())
    show([[iv.heatmap(grad, (0, ch_index), mode='nearest')]])

    print('Activation-gradients heatmap:')
    show([[iv.heatmap(act_grad, (0, ch_index), mode='nearest')]])

    # Boolean masks are converted to float so a level just below 1 highlights exactly the True pixels.
    print('Segmentation highlighted:')
    show([[iv.masked_image(image, target_seg_concept_mask.float(), level=0.99)]])

    print('Activations highlighted:')
    show([[iv.masked_image(image, act, (0, ch_index), level=act_thresh)]])

    print('Activation-gradients highlighted:')
    show([[iv.masked_image(image, act_grad, (0, ch_index), level=grad_thresh)]])

    print('Segmentation and activation-gradient overlap highlighted:')
    show([[iv.masked_image(image, overlap_mask.float(), level=0.99)]])

In [None]:
def _seg_overlap_passes (act_mask, num_act_pixels, seg, seg_concept_index_map, ch, channel_concept, channel_category):
    """Return True if `act_mask` overlaps the segmentation of the channel's concept enough under `overlap_mode`.

    Prints an error and returns False when the concept or its category has no segmentation index.
    """
    seg_concept_index = seg_concept_index_map[channel_concept] if channel_concept in seg_concept_index_map else None
    cat_index = category_index_map[channel_category] if channel_category in category_index_map else None
    if (seg_concept_index is None) or (cat_index is None):
        print('Error: Missing segmentation concept index ({}) or category index ({}) for channel {} mapped to concept {} from category {}'
            .format(seg_concept_index, cat_index, ch, channel_concept, channel_category))
        return False

    # seg[cat_index] is the segmentation for this concept's category; keep only the concept's label.
    target_seg_concept_mask = (seg[cat_index] == seg_concept_index)
    num_overlap = np.sum((act_mask & target_seg_concept_mask).numpy())

    # Compare mode strings with ==: `is` only works when both strings happen to be the same interned object.
    if overlap_mode == 'overlap_to_union_ratio':
        num_union = np.sum((act_mask | target_seg_concept_mask).numpy())
        return bool(num_overlap / num_union >= min_overlap_ratio)
    if overlap_mode == 'overlap_to_activation_ratio':
        return bool(num_overlap / num_act_pixels >= min_overlap_ratio)
    if overlap_mode == 'overlap_to_segmentation_ratio':
        # An empty concept segmentation gives 0/0 -> nan (numpy warns), which fails the comparison.
        num_concept_seg = np.sum(target_seg_concept_mask.numpy())
        return bool(num_overlap / num_concept_seg >= min_overlap_ratio)
    # Any other mode (e.g. 'overlap_pixels_count'): absolute number of overlapping pixels.
    return bool(num_overlap >= min_overlap_pixels)


def _channel_passes_thresh (ch_upact, thresh, seg, seg_concept_index_map, ch, channel_concept, channel_category):
    """Threshold one upsampled channel map and decide whether the channel counts as active.

    Returns (passes, num_pixels): num_pixels is the number of pixels above `thresh`; passing requires at
    least `min_thresh_pixels` of them and, when `check_seg_overlap` is set, enough segmentation overlap.
    """
    act_mask = (ch_upact > thresh)
    num_pixels = np.sum(act_mask.numpy())
    if num_pixels < min_thresh_pixels:
        return False, num_pixels
    if not check_seg_overlap:
        return True, num_pixels
    passes = _seg_overlap_passes(act_mask, num_pixels, seg, seg_concept_index_map, ch, channel_concept, channel_category)
    return passes, num_pixels


def extract_concepts_from_image (act, upact, seg, path, channels_map, seg_concept_index_map, channels, concepts, image_threshs=None):
    """Decide, for one image, which channels and concepts are high (or, with binning, mid) valued.

    Args:
        act: target-layer activations of the image (not used in the computation).
        upact: upsampled activations; upact[ch - 1] is the map of 1-based channel `ch`.
        seg: per-category segmentation of the image; only read when `check_seg_overlap` is set.
        path: image path (not used in the computation).
        channels_map: channel -> {'concept', 'category', 'low_thresh', 'high_thresh', 'is_valid'}.
        seg_concept_index_map: concept name -> segmentation label index.
        channels, concepts: the channels and concept names to score.
        image_threshs: {'high_thresh', 'low_thresh'} per-image thresholds; required when `check_gradients`
            is set, otherwise the per-channel thresholds from channels_map are used.

    Returns:
        (image_concepts, image_channels, image_concepts_counts, image_channels_counts):
        concept -> high/mid/low value; channel -> high/mid/low value; concept -> number of high channels;
        channel -> number of above-threshold pixels (0 unless the channel was high or mid).
    """
    image_concepts = {con: low_value for con in concepts}   # high/mid/low value of each concept for this image
    image_channels = {ch: low_value for ch in channels}     # high/mid/low value of each channel for this image
    image_concepts_counts = {con: 0 for con in concepts}    # how many high channels map to each concept (stats only)
    image_channels_counts = {ch: 0 for ch in channels}      # above-threshold pixel count of each high/mid channel

    for ch in channels:
        ch_upact = upact[ch - 1]   # channels are 1-based
        ch_info = channels_map[ch]
        is_valid = ch_info['is_valid']
        channel_concept = ch_info['concept']
        channel_category = ch_info['category']
        # With gradient weighting the thresholds are per image; otherwise they are per channel (from the tally).
        channel_high_thresh = image_threshs['high_thresh'] if check_gradients else ch_info['high_thresh']
        channel_low_thresh = image_threshs['low_thresh'] if check_gradients else ch_info['low_thresh']

        if (ch_upact is None) or (not is_valid) or (channel_concept is None) or (channel_category is None) or (channel_high_thresh is None):
            continue

        seg_args = (seg, seg_concept_index_map, ch, channel_concept, channel_category)

        # High level: enough pixels above the high threshold (and enough segmentation overlap).
        is_high, num_high_thresh = _channel_passes_thresh(ch_upact, channel_high_thresh, *seg_args)
        if is_high:
            image_channels[ch] = high_value
            image_concepts[channel_concept] = high_value
            image_concepts_counts[channel_concept] += 1
            image_channels_counts[ch] = num_high_thresh
            continue

        if not binning_features:
            continue

        # Mid level: same test with the low threshold; never downgrades a concept that is already high.
        is_mid, num_mid_thresh = _channel_passes_thresh(ch_upact, channel_low_thresh, *seg_args)
        if is_mid:
            image_channels[ch] = mid_value
            image_channels_counts[ch] = num_mid_thresh
            if image_concepts[channel_concept] != high_value:
                image_concepts[channel_concept] = mid_value

    return image_concepts, image_channels, image_concepts_counts, image_channels_counts

In [None]:
def extract_concepts (model, segmodel, upfn, renorm, data_loader, channels_map, seg_concept_index_map, channels, concepts,
                      debug_concept='sea', debug_channel=144):
    """Run the model over `data_loader` and build per-image concept and channel tables.

    For each batch:
    - capture the `target_layer` activations with a forward hook;
    - when `check_gradients` is set, weight them by the gradient of each image's predicted-class output
      (optionally spatially averaged with `pool_gradients`) and apply ReLU;
    - upsample with `upfn`, and optionally segment the images with `segmodel` (`check_seg_overlap`);
    - score every image with `extract_concepts_from_image`.

    Args:
        model: network containing `target_layer`, either as a direct child or under the 'model.' prefix.
        segmodel: segmentation model exposing `segment_batch`; only used when `check_seg_overlap` is set.
        upfn: callable that upsamples a batch of activations.
        renorm: callable that converts the normalized images into the segmenter's input.
        data_loader: yields (images, labels, paths) batches.
        channels_map, seg_concept_index_map, channels, concepts: channel/concept metadata.
        debug_concept, debug_channel: diagnostic plots. When `debug_concept` is high for an image (with
            segmentation overlap and gradients both checked), visualize `debug_channel` for that image.
            Pass debug_concept=None to disable.

    Returns:
        (concepts_df, channels_df, acts_list, image_channels_counts_list, image_threshs_list)

    Raises:
        ValueError: if `target_layer` cannot be found in `model`.
    """
    activations = []
    gradients = []

    def activations_hook (module, input, output):
        activations.append(output.detach().cpu())
        if check_gradients:
            # Tensor hook: receives the gradient w.r.t. this layer's output during backward().
            output.register_hook(gradients_hook)

    def gradients_hook (grad):
        gradients.append(grad.detach().cpu())

    # Look the layer up directly, then under the 'model.' prefix (presumably nethook.InstrumentedModel's wrapper).
    layer = model._modules.get(target_layer)
    if layer is None:
        for n, l in model.named_modules():
            if n == 'model.' + target_layer:
                layer = l
                print('Target layer found:', n)
                break
    if layer is None:
        raise ValueError('Target layer {} not found in model'.format(target_layer))
    hook_handle = layer.register_forward_hook(activations_hook)

    model.eval()

    concepts_counts = {con: 0 for con in concepts}
    channels_counts = {ch: 0 for ch in channels}
    # NOTE(review): `classes` is a notebook-level global not defined in this view; it must contain every value
    # `pred` can take (including the shifted 'maybe' classes when binning_classes is set).
    concepts_counts_by_class = {c: {con: 0 for con in concepts} for c in classes}

    concepts_rows_list = []
    channels_rows_list = []
    image_channels_counts_list = []
    image_threshs_list = []
    acts_list = []
    num_images = 0
    total_acc = 0

    # No torch.no_grad(): the backward pass needs the graph when check_gradients is set.
    try:
        for i, (images, labels, paths) in tqdm(enumerate(data_loader)):
            del activations[:]
            del gradients[:]
            images_gpu = images.cuda()
            labels_gpu = labels.cuda()

            model.zero_grad()

            output = model(images_gpu)
            _, preds = torch.max(output, 1)

            acts = activations[0]
            raw_acts = acts

            grads = None
            if check_gradients:
                # Backpropagate a one-hot of each image's predicted class to the hooked layer.
                one_hot = torch.zeros_like(output).cuda()
                one_hot.scatter_(1, preds[:, None], 1.0)
                output.backward(gradient=one_hot, retain_graph=True)
                grads = gradients[0]

                wgrads = grads
                if pool_gradients:
                    wgrads = torch.mean(grads, dim=[2, 3], keepdims=True)

                acts = F.relu(acts * wgrads)
                if i == 0:
                    print('output: {}, preds: {}, one_hot: {}, grads: {}, wgrads: {}'
                        .format(output.shape, preds.shape, one_hot.shape, grads.shape, wgrads.shape))

            total_acc += (preds == labels_gpu).float().sum()
            preds = preds.cpu().numpy()
            labels = labels.numpy()

            if binning_classes:
                preds = get_binned_predictions(output)

            upacts = upfn(acts)
            segs = None
            if check_seg_overlap:
                segs = segmodel.segment_batch(renorm(images_gpu), downsample=4).cpu()

            if i == 0:
                print('images: {}, labels: {}, preds: {}'.format(images.shape, labels.shape, preds.shape))
                print('acts: {}, upacts: {}, grads: {}, segs: {}'.format(acts.shape, upacts.shape,
                    grads.shape if grads is not None else 0, segs.shape if segs is not None else 0))

            for j in range(images.shape[0]):
                num_images += 1
                pred = preds[j]
                label = labels[j]
                path = paths[j]
                image = images[j]
                act = acts[j]
                upact = upacts[j]
                raw_act = raw_acts[j]
                seg = segs[j] if segs is not None else None
                grad = grads[j] if grads is not None else None
                fname = get_file_name_from_path(path)

                # With gradient weighting, thresholds are quantiles of this image's values over all channels.
                image_threshs = {'high_thresh': 0, 'low_thresh': 0}
                if check_gradients:
                    image_threshs['high_thresh'] = torch.quantile(act, q=activation_high_thresh).item()
                    if binning_features:
                        image_threshs['low_thresh'] = torch.quantile(act, q=activation_low_thresh).item()

                image_concepts, image_channels, image_concepts_counts, image_channels_counts = extract_concepts_from_image(
                    act, upact, seg, path, channels_map, seg_concept_index_map, channels, concepts, image_threshs)

                if (i == 0) and (j == 0):
                    print('Image {} with label {}, pred {}, act {}, upact {}, high threshold {}, low threshold {}, and concepts {}'
                        .format(path, label, pred, act.shape, upact.shape, image_threshs['high_thresh'], image_threshs['low_thresh'], image_concepts))
                    print('Image {} with min activation {}, max activation {}, and thresholds {}'.format(fname, act.min(), act.max(), image_threshs))
                    plot_activation_histogram(act)

                # Optional diagnostic plot (guarded so a channel missing from channels_map can't raise KeyError).
                if ((debug_concept is not None) and check_seg_overlap and check_gradients
                        and (debug_channel in channels_map)
                        and (debug_concept in image_concepts) and (image_concepts[debug_concept] == high_value)):
                    print('Image {} with pred {} and label {}'.format(fname, pred, label))
                    plot_sample_image_activations(data_loader.dataset, image, debug_channel, seg, raw_act, grad, act, upact,
                                                  channels_map, seg_concept_index_map, image_threshs)

                acts_list.append(act)
                image_channels_counts_list.append(image_channels_counts)
                image_threshs_list.append(image_threshs)

                # The per-image value dicts become table rows once the metadata columns are added.
                image_concepts_row = image_concepts
                image_concepts_row['pred'] = pred
                image_concepts_row['label'] = label
                image_concepts_row['id'] = num_images
                image_concepts_row['file'] = fname
                image_concepts_row['path'] = path
                concepts_rows_list.append(image_concepts_row)

                image_channels_row = image_channels
                image_channels_row['pred'] = pred
                image_channels_row['label'] = label
                image_channels_row['id'] = num_images
                image_channels_row['file'] = fname
                image_channels_row['path'] = path
                channels_rows_list.append(image_channels_row)

                for con in concepts:
                    val = 1 if image_concepts_counts[con] > 0 else 0
                    concepts_counts[con] += val
                    concepts_counts_by_class[pred][con] += val

                for ch in channels:
                    channels_counts[ch] += 1 if image_channels_counts[ch] > 0 else 0
    finally:
        # Detach the forward hook so repeated calls don't stack hooks on the same layer.
        hook_handle.remove()

    # An empty loader would otherwise divide by zero.
    accuracy = float(total_acc) / num_images if num_images else float('nan')
    print('\nExtracted concepts from {} images with accuracy {:.3f}.'.format(num_images, accuracy))
    print('\nConcept counts:', concepts_counts)
    for c, counts in concepts_counts_by_class.items():
        print('\nConcept counts of class {}: {}'.format(c, counts))
    print('\nChannel counts:', channels_counts)

    concepts_df = pd.DataFrame(concepts_rows_list)
    channels_df = pd.DataFrame(channels_rows_list)

    return concepts_df, channels_df, acts_list, image_channels_counts_list, image_threshs_list

In [None]:
def filter_extracted_concepts (concepts_df, channels_df, channels_map, random_state=0):
    """Reduce the concept features to the most informative ones and keep only the matching channels.

    Two passes over the concept columns:
      1. VarianceThreshold with threshold p*(1-p), p = low_variance_thresh, dropping near-constant concepts.
      2. SelectKBest on mutual information with the predicted class ('pred'), keeping at most `max_concepts`.
    Channel columns are then restricted to channels whose concept survived.

    Args:
        concepts_df, channels_df: per-image tables with feature columns plus 'pred', 'label', 'id', 'file', 'path'.
        channels_map: channel -> info dict with a 'concept' entry.
        random_state: seed for mutual_info_classif's noise, so repeated runs select the same concepts;
            pass None to get the previous unseeded behaviour.

    Returns:
        (filtered_concepts_df, filtered_concepts_df's concept list, ...) as
        (filtered_concepts_df, filtered_channels_df, filtered_concepts, filtered_channels).
    """
    from functools import partial

    preds_df = concepts_df['pred']
    meta_cols = ['pred', 'label', 'id', 'file', 'path']
    meta_df = concepts_df[meta_cols]
    concept_cols = sorted(set(concepts_df.columns) - set(meta_cols))
    cons_df = concepts_df[concept_cols]

    initial_concepts = list(cons_df.columns)
    print('Initial concepts ({}): {}'.format(len(initial_concepts), initial_concepts))

    # 1) For a 0/1 feature, p*(1-p) is the variance of a feature that has the same value in a fraction p of images.
    var_selector = VarianceThreshold(threshold=(low_variance_thresh * (1 - low_variance_thresh)))
    var_selector.fit(cons_df)
    cons_df = cons_df.iloc[:, var_selector.get_support(indices=True)]
    var_filtered_concepts = list(cons_df.columns)
    var_removed_concepts = set(initial_concepts) - set(var_filtered_concepts)
    print('Concepts removed by variance filtering ({}): {}'.format(len(var_removed_concepts), var_removed_concepts))

    # 2) mutual_info_classif adds random noise to features it treats as continuous; seed it for reproducibility.
    k = max_concepts if len(var_filtered_concepts) > max_concepts else 'all'
    mut_selector = SelectKBest(partial(mutual_info_classif, random_state=random_state), k=k)
    mut_selector.fit(cons_df, preds_df)
    cons_df = cons_df.iloc[:, mut_selector.get_support(indices=True)]
    filtered_concepts = list(cons_df.columns)
    mut_removed_concepts = set(var_filtered_concepts) - set(filtered_concepts)
    print('Concepts removed by mutual info filtering ({}): {}'.format(len(mut_removed_concepts), mut_removed_concepts))
    print('Concepts reduced from {} to {} by concept filtering.'.format(len(initial_concepts), len(filtered_concepts)))
    print('Final concepts after filtering ({}): {}'.format(len(filtered_concepts), filtered_concepts))

    filtered_concepts_df = pd.concat([cons_df, meta_df], axis=1)
    display(filtered_concepts_df.head())

    # Sorted (rather than set order) so the channel column order is deterministic.
    channel_cols = sorted(set(channels_df.columns) - set(meta_cols))
    filtered_channels = [ch for ch in channel_cols if (channels_map[ch]['concept'] in filtered_concepts)]
    print('Channels reduced from {} to {} by concept filtering.'.format(len(channel_cols), len(filtered_channels)))

    cols_to_keep = filtered_channels + meta_cols
    filtered_channels_df = channels_df[cols_to_keep]
    display(filtered_channels_df.head())

    return filtered_concepts_df, filtered_channels_df, filtered_concepts, filtered_channels

In [None]:
def _activation_image_path (image_fname, feature_value_title, ch, con, output_dir):
    """Return ``<output_dir>/<stem>_<feature_value_title><ch>_<con>.jpg`` for one saved channel mask.

    ``<stem>`` is ``image_fname`` without its extension. ``os.path.splitext`` is used because slicing
    at ``rfind('.')`` would drop the last character of a name that has no extension (rfind returns -1).
    """
    image_fname_raw = os.path.splitext(image_fname)[0]
    new_fname = image_fname_raw + '_' + feature_value_title + str(ch) + '_' + con + '.jpg'
    return os.path.join(output_dir, new_fname)


def save_activation_images_of_image (iv, image_index, image_path, image_fname, acts, img_concepts_row, img_channels_row, 
                                     image_channels_counts, channels_map, concepts, output_dir, image_threshs=None):
    """Save masked activation images of one input image for its top channels per concept.

    A channel is rendered for its concept when it has a positive count in ``image_channels_counts``,
    is marked valid in ``channels_map`` and its value in ``img_channels_row`` equals the concept's value
    in ``img_concepts_row``. Per concept, such channels are ranked by count and the first
    ``n_top_channels_per_concept`` are masked with ``iv.masked_image`` and saved as JPEG in ``output_dir``.

    Args:
        iv: visualizer providing ``masked_image`` (the caller passes an ``imgviz.ImageVisualizer``).
        image_index: image identifier used in log messages; index 1 also prints debug information.
        image_path: path of the source image, read with ``pil_loader``.
        image_fname: file name of the source image; its stem prefixes the output file names.
        acts: activations of the image; a leading batch axis is added here and channel ``ch`` is read
            at channel index ``ch - 1`` (channel ids appear to be 1-based -- TODO confirm).
        img_concepts_row: per-image concept values, indexed by concept name.
        img_channels_row: per-image channel values, indexed by channel id.
        image_channels_counts: dict channel -> count of mid/high pixels for this image.
        channels_map: dict channel -> info with keys 'is_valid', 'concept', 'high_thresh', 'low_thresh'.
        concepts: concepts for which images may be saved.
        output_dir: directory the images are written to.
        image_threshs: per-image dict with 'high_thresh' / 'low_thresh', used instead of the channel
            thresholds when ``check_gradients`` is set. If it is missing, the threshold is reported as missing.

    Relies on notebook globals: image_size, norm_mean, norm_std, n_top_channels_per_concept,
    check_gradients and binning_features.

    Returns:
        int: the number of images saved.
    """
    acts = acts[None, :, :, :]   # as required by iv.masked_image
    if image_index == 1:
        print('acts.shape in save_activation_images:', acts.shape)

    # Only keep those channels which have been either mid or high for the image
    image_activated_channels = [k for k, v in image_channels_counts.items() if v > 0]
    if len(image_activated_channels) == 0:
        print('Image {} with path {} has no activated channels!'.format(image_index, image_path))
        return 0

    # concept -> [(channel, count)] for the channels worth rendering for this image
    image_concept_channels = {con: [] for con in concepts}
    for ch in image_activated_channels:
        ch_info = channels_map[ch]
        if not ch_info['is_valid']:
            # The channel IoU with the concept is lower than the min threshold: no image needed
            continue
        channel_concept = ch_info['concept']
        if img_concepts_row[channel_concept] != img_channels_row[ch]:
            # e.g. the concept is high for the image but the channel is only mid: no image needed
            continue
        # With binning features the count is of either mid or high pixels, depending on the channel's value
        image_concept_channels[channel_concept].append((ch, image_channels_counts[ch]))

    if not any(image_concept_channels.values()):
        # Nothing to render, so skip reading and transforming the image from disk
        print('Saved {} activation images for image {} with path {}'.format(0, image_index, image_path))
        return 0

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std)
    ])
    image = transform(pil_loader(image_path))

    num_saved = 0   # only the count is needed; rendered images are not kept in memory
    for con, lst in image_concept_channels.items():
        if not lst:
            continue

        top_channel_nums = sorted(lst, key=lambda x: x[1], reverse=True)[:n_top_channels_per_concept]
        if image_index == 1:
            print('Top channels for concept {}: {}'.format(con, top_channel_nums))

        for ch, _ in top_channel_nums:
            ch_info = channels_map[ch]
            ch_value = img_channels_row[ch]

            thresh_key = 'high_thresh'
            feature_value_title = ''
            if binning_features:
                if ch_value == 2:
                    feature_value_title = 'high_'
                else:
                    # Mid-valued channel: mask with the low threshold instead
                    feature_value_title = 'mid_'
                    thresh_key = 'low_thresh'

            if check_gradients:
                # Per-image thresholds; a missing dict is reported below instead of raising TypeError
                channel_thresh = image_threshs[thresh_key] if image_threshs is not None else None
            else:
                channel_thresh = ch_info[thresh_key]

            if channel_thresh is None:
                print('Error: Missing activation or threshold ({}) for channel {}!'.format(channel_thresh, ch))
                continue

            new_image = iv.masked_image(image, acts, (0, ch - 1), level=channel_thresh)
            new_path = _activation_image_path(image_fname, feature_value_title, ch, con, output_dir)
            new_image.save(new_path, optimize=True, quality=99)
            num_saved += 1

    print('Saved {} activation images for image {} with path {}'.format(num_saved, image_index, image_path))
    return num_saved

In [None]:
def save_image_concepts_dataset (concepts_df, channels_df, image_channels_counts_list, image_threshs_list, acts_list, upfn, dataset, channels_map, 
                                 filtered_concepts, filtered_channels, concepts_output_path, channels_output_path, activation_images_path):
    """Save activation images for every image, print concept/channel counts and write both tables to CSV.

    Row ``k`` (by position) of ``concepts_df`` and ``channels_df`` is paired with entry ``k`` of
    ``acts_list``, ``image_channels_counts_list`` and ``image_threshs_list``.

    Args:
        concepts_df: one row per image with the concept columns plus 'pred', 'id', 'path' and 'file'.
        channels_df: one row per image with the channel columns, aligned row-by-row with concepts_df.
        image_channels_counts_list: per image, dict channel -> activated pixel count.
        image_threshs_list: per image, thresholds forwarded to save_activation_images_of_image.
        acts_list: per image, activation tensor (upsampled here with upfn on a batch of one).
        upfn: upsampling function applied to the activations.
        dataset: image dataset used as the visualizer source.
        channels_map: dict channel -> channel info, forwarded to save_activation_images_of_image.
        filtered_concepts: concepts to visualise and count.
        filtered_channels: channels to visualise and count.
        concepts_output_path: CSV destination for concepts_df.
        channels_output_path: CSV destination for channels_df.
        activation_images_path: directory receiving the activation images.

    Relies on notebook globals: image_size, classes (every 'pred' value must be one of them) and high_value.
    """
    iv = imgviz.ImageVisualizer(size=(image_size, image_size), image_size=(image_size, image_size), source=dataset)   # renormalizer=renorm

    num_images = len(concepts_df.index)
    filtered_channel_set = set(filtered_channels)   # O(1) membership tests in the per-image loop

    filtered_concepts_counts = {con: 0 for con in filtered_concepts}
    filtered_channels_counts = {ch: 0 for ch in filtered_channels}
    filtered_concepts_counts_by_class = {c: {con: 0 for con in filtered_concepts} for c in classes}

    num_act_images_saved = 0

    # Iterate by position: channels_df is read with iloc and the per-image lists are positional,
    # so using the DataFrame index label would mis-align rows whenever the index is not 0..n-1.
    for pos, (_, con_row) in enumerate(concepts_df.iterrows()):
        ch_row = channels_df.iloc[pos]
        pred = con_row['pred']
        image_id = con_row['id']
        act = acts_list[pos]
        upact = upfn(torch.unsqueeze(act, dim=0))[0]
        image_channels_counts = image_channels_counts_list[pos]
        filtered_image_channels_counts = {ch: cnt for ch, cnt in image_channels_counts.items()
                                          if ch in filtered_channel_set}

        num_act_images_saved += save_activation_images_of_image(
            iv, image_id, con_row['path'], con_row['file'], upact, con_row, ch_row,
            filtered_image_channels_counts, channels_map, filtered_concepts,
            activation_images_path, image_threshs_list[pos])

        for con in filtered_concepts:
            val = 1 if con_row[con] == high_value else 0
            filtered_concepts_counts[con] += val
            filtered_concepts_counts_by_class[pred][con] += val

        for ch in filtered_channels:
            filtered_channels_counts[ch] += 1 if ch_row[ch] == high_value else 0

    print('Saved {} activation images for {} images.'.format(num_act_images_saved, num_images))
    print('\nFiltered concept counts:', filtered_concepts_counts)
    for c, counts in filtered_concepts_counts_by_class.items():
        print('\nFiltered concept counts of class {}: {}'.format(c, counts))
    print('\nFiltered channel counts:', filtered_channels_counts)

    concepts_df.to_csv(concepts_output_path, index=False)
    channels_df.to_csv(channels_output_path, index=False)

In [None]:
model_file = 'model.pth'
dataset_dir = 'dataset'
drive_result_path = '/content/drive/My Drive/Python Projects/POEM Pipeline Results/' + model_name + '_' + dataset_name

if not use_dissection_models:
    dataset_file = dataset_dir + '.zip'
    drive_dataset_dir = drive_result_path + '/' + dataset_file
    !cp "$drive_dataset_dir" '.'
    !unzip -qq -n $dataset_file -d '.'

    drive_model_path = drive_result_path + '/' + model_file
    !cp "$drive_model_path" '.'

result_dir = 'identification_results'
result_file = result_dir + '.zip'
result_path = drive_result_path + "/" + result_file
!cp "$result_path" '.'
!unzip -qq -n $result_file -d '.'

activation_images_path = 'activation_images'
if os.path.exists(activation_images_path):
    shutil.rmtree(activation_images_path)
os.makedirs(activation_images_path)

# Optional: loading the segmenter models to avoid downloading them from netdissect server: 
segmodel_dir = 'segmodel'
segmodel_file = segmodel_dir + '.zip'
drive_segmodel_path = '/content/drive/My Drive/Python Projects/Pretrained Models/' + segmodel_file
target_segmodel_dir = 'datasets'
if not os.path.exists(target_segmodel_dir):
    os.makedirs(target_segmodel_dir)

target_segmodel_file = target_segmodel_dir + '/' + segmodel_file
!cp "$drive_segmodel_path" $target_segmodel_dir
!unzip -qq -n $target_segmodel_file -d $target_segmodel_dir

In [None]:
# Load the network; target_layer is retained while the helpers below are built.
model = load_model(model_file)
model.retain_layer(target_layer)

dataset, data_loader = load_data(dataset_dir)
class_titles = extract_class_titles(dataset_name)
classes = list(class_titles.keys())

# Files produced by the identification stage (unzipped into result_dir).
tally_path = f'./{result_dir}/report.json'
thresholds_path = f'./{result_dir}/channel_quantiles.npy'
channels_map, channels, concepts = load_channels_data(tally_path)

upfn = experiment.make_upfn(args, dataset, model, target_layer)
renorm = renormalize.renormalizer(dataset, target='zc')

# The segmenter (and its label -> index map) is only loaded when segmentation overlap is checked.
segmodel = None
seg_concept_index_map = {}
if check_seg_overlap:
    segmodel, seglabels, segcatlabels = experiment.setting.load_segmenter(seg_model_name)
    seg_concept_index_map = {label: label_index for label_index, label in enumerate(seglabels)}

model.stop_retaining_layers([target_layer])

In [None]:
# Per-image concept/channel tables plus the raw activations, per-image channel counts and thresholds.
(concepts_df, channels_df, acts_list,
 image_channels_counts_list, image_threshs_list) = extract_concepts(
    model, segmodel, upfn, renorm, data_loader, channels_map,
    seg_concept_index_map, channels, concepts)

# NOTE(review): this rebinds `concepts` and `channels` to the filtered lists, so re-running this cell
# without re-running the loading cell passes the already-filtered lists to extract_concepts.
if filter_concepts:
    concepts_df, channels_df, concepts, channels = filter_extracted_concepts(
        concepts_df, channels_df, channels_map)

In [None]:
# Local CSV files for the per-image concept and channel tables.
concepts_output_path = 'image_concepts.csv'
channels_output_path = 'image_channels.csv'

save_image_concepts_dataset(
    concepts_df, channels_df, image_channels_counts_list, image_threshs_list, acts_list,
    upfn, dataset, channels_map, concepts, channels,
    concepts_output_path, channels_output_path, activation_images_path,
)

In [None]:
!cp $concepts_output_path "$drive_result_path"
!cp $channels_output_path "$drive_result_path"

activation_images_file = activation_images_path + '.zip'
!zip -qq -r $activation_images_file $activation_images_path
!cp $activation_images_file '$drive_result_path'

packages_path = drive_result_path + '/attribution_packages.log'
save_imported_packages(packages_path)