In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
import os
import shutil
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import OrderedDict
from tqdm import tqdm
from PIL import Image
from imageio import imwrite
from sklearn.feature_selection import SelectKBest, VarianceThreshold, mutual_info_classif

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader, Subset, sampler
from torchvision import datasets, transforms, models
from torchvision.datasets.folder import pil_loader

In [None]:
# Pick the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print("Cuda available:", use_cuda)

# Fixed seed so torch's random number generators are repeatable across runs.
seed = 2021
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed_all(seed)

In [None]:
# Folder on Google Drive that holds the pipeline's settings and results.
drive_results_root = '/content/drive/My Drive/Python Projects/POEM Pipeline Results/'

# NOTE: pickle.load can execute arbitrary code — only load settings files you created yourself.
settings_path = drive_results_root + 'settings.pkl'
with open(settings_path, 'rb') as settings_file:
    settings = pickle.load(settings_file)

# The first line of this file names the active "<model>_<dataset>..." setting.
current_setting_path = drive_results_root + 'current_setting.txt'
with open(current_setting_path, 'r') as setting_file:
    current_setting_title = setting_file.read().splitlines()[0]
    print('Current setting:', current_setting_title)

model_settings = settings['model_settings']
model_dataset_settings = settings['model_dataset_settings']

# Title layout: first '_'-separated token is the model, the rest is the dataset name.
title_parts = current_setting_title.split('_')
model_name, *dataset_parts = title_parts
dataset_name = '_'.join(dataset_parts)
dataset_pure_name = dataset_parts[0] if dataset_parts else ''

current_setting = model_dataset_settings[current_setting_title]
num_classes = current_setting['num_classes']
layer_names = [model_settings[model_name]['target_layer']]

# Input pipeline
image_size = 224
batch_size = 32

# Concept / channel detection
min_iou = 0.04
quantile_thresh = 0.99
min_thresh_pixels = 1
n_top_channels_per_concept = 3
overlay_opacity = 0.5

# Optional concept filtering
filter_concepts = False
low_variance_thresh = 0.99
max_concepts = 10

# ImageNet normalization statistics
norm_mean = (0.485, 0.456, 0.406)
norm_std = (0.229, 0.224, 0.225)

In [None]:
def save_imported_packages(packages_path):

    # Saving imported packages and their versions: 
    import sys
    modules_info = []

    for module in sys.modules:
        if len(module.split('.')) > 1:   # ignoring subpackages
            continue

        try:
            modules_info.append((module, sys.modules[module].__version__))
        except:
            try:
                if type(sys.modules[module].version) is str:
                    modules_info.append((module, sys.modules[module].version))
                else:
                    modules_info.append((module, sys.modules[module].version()))
            except:
                try:
                    modules_info.append((module, sys.modules[module].VERSION))
                except:
                    pass

    modules_info.sort(key=lambda x: x[0])
    with open(packages_path, 'w') as f:
        for m in modules_info:
            f.write('{} {}\n'.format(m[0], m[1]))

In [None]:
def vgg16_model (*args, **kwargs):
    """Build torchvision's VGG16 with its layers renamed to their research names.

    All arguments are forwarded to ``models.vgg16``. The ``features`` layers are
    renamed conv<b>_<i> / relu<b>_<i> / pool<b>; the ``classifier`` layers are
    renamed fc6, relu6, drop6, fc7, relu7, drop7, fc8a. Names are paired with the
    existing layers by position.
    """
    # Conv+ReLU pairs in each of the five blocks; every block ends with a pooling layer.
    convs_per_block = (2, 2, 3, 3, 3)
    feature_names = []
    for block, num_convs in enumerate(convs_per_block, start=1):
        for conv in range(1, num_convs + 1):
            feature_names.append('conv{}_{}'.format(block, conv))
            feature_names.append('relu{}_{}'.format(block, conv))
        feature_names.append('pool{}'.format(block))

    classifier_names = ['fc6', 'relu6', 'drop6', 'fc7', 'relu7', 'drop7', 'fc8a']

    model = models.vgg16(*args, **kwargs)
    model.features = nn.Sequential(OrderedDict(zip(feature_names, model.features)))
    model.classifier = nn.Sequential(OrderedDict(zip(classifier_names, model.classifier)))
    return model

In [None]:
def load_model (model_path, map_location='cpu'):
    """Instantiate the architecture named by the global ``model_name`` and load its weights.

    Args:
        model_path: Path to a saved ``state_dict``.
        map_location: Device the stored tensors are mapped onto while loading.
            Defaults to CPU so checkpoints saved on a GPU also load on CPU-only
            machines; move the returned model with ``.to(device)`` afterwards.

    Returns:
        The model, built with ``num_classes`` outputs, holding the loaded weights.
    """
    if model_name == 'vgg16':
        model = vgg16_model(num_classes=num_classes)   # VGG16 variant with renamed layers
    else:
        model = getattr(models, model_name)(num_classes=num_classes)
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

In [None]:
class ImageDataset (Dataset):
    """Dataset over image files listed by name inside a single directory.

    Items are ``(image, label)`` when ``labels`` is given, otherwise
    ``(image, file_name)``. Images are decoded as RGB (like torchvision's default
    ImageFolder loader) so 3-channel normalization always applies.
    """

    def __init__(self, images_path, file_names, labels, transform):
        self.images_path = images_path
        self.file_names = file_names
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, index):
        fname = self.file_names[index]
        path = os.path.join(self.images_path, fname)
        # pil_loader reads the file inside a context manager (no leaked handle) and
        # converts grayscale / RGBA / palette images to 3-channel RGB.
        img = pil_loader(path)

        if self.transform:
            img = self.transform(img)

        if self.labels is not None:
            return img, self.labels[index]

        return img, fname

In [None]:
class CustomImageFolder (datasets.ImageFolder):
    """ImageFolder whose items also carry the source file path: ``(image, target, path)``."""

    def __getitem__(self, index):
        # imgs holds (path, class_index) pairs in dataset order.
        source_path, _ = self.imgs[index]
        return (*super().__getitem__(index), source_path)

In [None]:
def load_data (dataset_dir=None):
    """Build a DataLoader over an ImageFolder-style directory.

    Each batch is ``(images, labels, paths)`` (see ``CustomImageFolder``). Images
    are resized to 256, center-cropped to ``image_size``, converted to tensors and
    normalized with ``norm_mean`` / ``norm_std``. The loader is not shuffled, so
    batches follow the dataset order.

    Args:
        dataset_dir: Root directory with one sub-directory per class.

    Returns:
        A ``DataLoader`` yielding ``batch_size`` images per batch.

    Raises:
        ValueError: if ``dataset_dir`` is not given.
    """
    if dataset_dir is None:
        raise ValueError('load_data requires dataset_dir (an ImageFolder-style root directory)')

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std)
    ])

    dataset = CustomImageFolder(root=dataset_dir, transform=transform)
    data_loader = DataLoader(dataset, batch_size=batch_size)   # shuffle defaults to False
    print('Processing {} data examples in {} batches with class index {}'.format(len(dataset), len(data_loader), dataset.class_to_idx))

    return data_loader

In [None]:
def load_channels_data (tally_path, thresholds_path, score_thresh=None):
    """Load the channel-to-concept labels and the per-channel activation thresholds.

    Args:
        tally_path: CSV with at least ``unit``, ``label`` and ``score`` columns;
            only rows with ``score > score_thresh`` are kept.
        thresholds_path: ``.npy`` array; entry ``i`` is used as the threshold of
            unit ``i + 1``.
        score_thresh: Minimum score to keep a row; defaults to the global ``min_iou``.

    Returns:
        ``(channel_concept_map, channel_thresh_map, channels, concepts)``:
        unit -> label, unit -> threshold (only for kept units), and the sorted
        unique kept units and labels.
    """
    if score_thresh is None:
        score_thresh = min_iou

    tally_data = pd.read_csv(tally_path)
    tally_data = tally_data[tally_data['score'] > score_thresh]

    channels_list = tally_data['unit'].tolist()
    concepts_list = tally_data['label'].tolist()

    # If a unit appears in several rows, its last row wins.
    channel_concept_map = dict(zip(channels_list, concepts_list))
    channels = sorted(set(channels_list))
    concepts = sorted(set(concepts_list))

    thresholds = np.load(thresholds_path)
    channel_set = set(channels)   # O(1) membership instead of scanning the list per threshold
    channel_thresh_map = {i + 1: t for i, t in enumerate(thresholds) if i + 1 in channel_set}

    print('Processing {} concepts and {} channels'.format(len(concepts), len(channels)))
    print('channel_concept_map:', channel_concept_map)
    print('channel_thresh_map:', channel_thresh_map)

    return channel_concept_map, channel_thresh_map, channels, concepts

In [None]:
def get_file_name_from_path (image_path):
    """Return the part of ``image_path`` after the last '/' (the whole string when there is none)."""
    return image_path.rsplit('/', 1)[-1]

In [None]:
def extract_class_titles (ds_name):
    """Map class index to class title from a name like "<dataset>_<class0>_<class1>...".

    Every '_'-separated part after the first becomes a title, indexed from 0.
    Returns an empty dict when the name has no class parts.
    """
    _, *class_parts = ds_name.split('_')
    return dict(enumerate(class_parts))

In [None]:
def extract_concepts_from_batch (acts, channel_concept_map, channel_thresh_map, channels, concepts, min_pixels=None):
    """Count, per image, which labelled channels fire above their threshold.

    A channel "fires" on an image when at least ``min_pixels`` of its activation
    values exceed the channel's threshold.

    Args:
        acts: Batch of activations indexed as ``acts[image][channel - 1]``
            (presumably (batch, channels, H, W) — TODO confirm).
        channel_concept_map: channel number -> concept label.
        channel_thresh_map: channel number -> activation threshold.
        channels: 1-based channel numbers to check.
        concepts: All concept labels (keys of each per-image concept dict).
        min_pixels: Minimum number of above-threshold values; defaults to the
            global ``min_thresh_pixels``.

    Returns:
        ``(batch_concepts_counts, batch_channels_counts)``, one dict per image:
        the number of firing channels per concept, and, for each channel, its
        above-threshold count when it fires (0 otherwise).
    """
    if min_pixels is None:
        min_pixels = min_thresh_pixels

    batch_concepts_counts = []
    batch_channels_counts = []

    for act in acts:
        num_channels = act.shape[0]
        image_concepts_counts = {con: 0 for con in concepts}
        image_channels_counts = {ch: 0 for ch in channels}

        for ch in channels:
            ch_index = ch - 1   # channel numbers are 1-based
            channel_thresh = channel_thresh_map.get(ch)
            channel_concept = channel_concept_map.get(ch)

            # Skip channels that are out of range for this layer or lack a label/threshold.
            if not (0 <= ch_index < num_channels) or channel_thresh is None or channel_concept is None:
                print('Error: Missing activation, concept ({}), or threshold ({}) for channel {}!'.format(channel_concept, channel_thresh, ch))
                continue

            num_high_thresh = np.sum(act[ch_index] > channel_thresh)
            if num_high_thresh >= min_pixels:
                image_concepts_counts[channel_concept] += 1
                image_channels_counts[ch] = num_high_thresh

        batch_concepts_counts.append(image_concepts_counts)
        batch_channels_counts.append(image_channels_counts)

    return batch_concepts_counts, batch_channels_counts

In [None]:
def _find_layer (model, name):
    """Return the sub-module of ``model`` named ``name``, or None if there is none."""
    # Direct children first, then the (dotted) names from named_modules().
    layer = model._modules.get(name)
    if layer is not None:
        return layer
    for module_name, module in model.named_modules():
        if module_name == name:
            print('Target layer found:', module_name)
            return module
    return None


def extract_concepts (model, data_loader, channel_concept_map, channel_thresh_map, channels, concepts):
    """Run ``model`` over ``data_loader`` and record which concepts/channels fire per image.

    Activations of the layers in ``layer_names`` are captured with forward hooks
    (only the first captured layer is used). The hooks are removed before
    returning, so calling this again does not stack stale hooks on ``model``.

    Returns:
        concepts_df: one row per image with a 0/1 column per concept plus
            'pred', 'label', 'id' (1-based running index), 'file', 'path'.
        channels_df: same layout with a 0/1 column per channel.
        acts_list: per-image activation arrays from the hooked layer.
        image_channels_counts_list: per-image {channel: above-threshold count}.

    Raises:
        ValueError: if a name in ``layer_names`` does not match any sub-module.
    """
    batch_activations = []

    def get_activations(module, inputs, output):
        batch_activations.append(output.data.cpu().numpy())

    hook_handles = []
    for name in layer_names:
        layer = _find_layer(model, name)
        if layer is None:
            raise ValueError('Target layer {!r} not found in model'.format(name))
        hook_handles.append(layer.register_forward_hook(get_activations))

    model.eval()

    concepts_counts = {con: 0 for con in concepts}
    channels_counts = {ch: 0 for ch in channels}
    concepts_counts_by_class = {c: {con: 0 for con in concepts} for c in classes}

    concepts_rows_list = []
    channels_rows_list = []
    image_channels_counts_list = []
    acts_list = []
    num_images = 0
    num_correct = 0

    try:
        with torch.no_grad():
            for i, (images, labels, paths) in enumerate(tqdm(data_loader)):
                batch_activations.clear()
                output = model(images.to(device))

                _, preds = torch.max(output, 1)
                num_correct += (preds == labels.to(device)).sum().item()
                preds = preds.cpu().numpy()
                labels = labels.numpy()
                acts = batch_activations[0]   # currently assume there is only one target layer

                if i == 0:
                    print('images: {}, labels: {}, preds: {}'.format(tuple(images.shape), labels.shape, preds.shape))
                    print('batch_activations.shape: {} * {}'.format(len(batch_activations), batch_activations[0].shape))
                    print('paths:', paths)

                batch_concepts_counts, batch_channels_counts = extract_concepts_from_batch(acts, channel_concept_map, channel_thresh_map, channels, concepts)

                for j, (image_concepts_counts, image_channels_counts) in enumerate(zip(batch_concepts_counts, batch_channels_counts)):
                    num_images += 1
                    pred = preds[j]
                    label = labels[j]
                    path = paths[j]
                    act = acts[j]
                    fname = get_file_name_from_path(path)

                    if i == 0 and j == 0:
                        print('Image {} with label {}, pred {}, act shape {}, and concepts counts {}'
                            .format(path, label, pred, act.shape, image_concepts_counts))

                    acts_list.append(act)
                    image_channels_counts_list.append(image_channels_counts)

                    meta = {'pred': pred, 'label': label, 'id': num_images, 'file': fname, 'path': path}

                    image_concepts_row = {k: int(v > 0) for k, v in image_concepts_counts.items()}
                    image_concepts_row.update(meta)
                    concepts_rows_list.append(image_concepts_row)

                    image_channels_row = {k: int(v > 0) for k, v in image_channels_counts.items()}
                    image_channels_row.update(meta)
                    channels_rows_list.append(image_channels_row)

                    # Predictions outside `classes` (e.g. a dataset name without class parts) get their own bucket.
                    class_counts = concepts_counts_by_class.setdefault(int(pred), {con: 0 for con in concepts})
                    for con in concepts:
                        val = int(image_concepts_counts[con] > 0)
                        concepts_counts[con] += val
                        class_counts[con] += val

                    for ch in channels:
                        channels_counts[ch] += int(image_channels_counts[ch] > 0)
    finally:
        # Always detach the hooks, even if the loop fails part-way.
        for handle in hook_handles:
            handle.remove()

    total_acc = num_correct / num_images if num_images else float('nan')
    print('\nExtracted concepts from {} images with accuracy {:.3f}.'.format(num_images, total_acc))
    print('\nConcept counts:', concepts_counts)
    for c, counts in concepts_counts_by_class.items():
        print('\nConcept counts of class {}: {}'.format(c, counts))
    print('\nChannel counts:', channels_counts)

    concepts_df = pd.DataFrame(concepts_rows_list)
    channels_df = pd.DataFrame(channels_rows_list)

    return concepts_df, channels_df, acts_list, image_channels_counts_list

In [None]:
def filter_extracted_concepts (concepts_df, channels_df, channel_concept_map):
    """Reduce the concept and channel columns to the most informative concepts.

    Two stages on the 0/1 concept columns of ``concepts_df``:
      1. drop near-constant concepts with a VarianceThreshold of
         ``low_variance_thresh * (1 - low_variance_thresh)``;
      2. keep the ``max_concepts`` concepts with the highest mutual information
         with the 'pred' column (all of them when there are not more than that).
    A channel column is kept only when its concept survived.

    Returns:
        ``(filtered_concepts_df, filtered_channels_df, filtered_concepts, filtered_channels)``.
    """
    preds_df = concepts_df['pred']
    meta_cols = ['pred', 'label', 'id', 'file', 'path']
    meta_df = concepts_df[meta_cols]
    concept_cols = sorted(set(concepts_df.columns) - set(meta_cols))
    cons_df = concepts_df[concept_cols]

    initial_concepts = list(cons_df.columns)
    print('Initial concepts ({}): {}'.format(len(initial_concepts), initial_concepts))

    # p * (1 - p) is the variance of a 0/1 column that equals 1 in a fraction p of the rows.
    var_selector = VarianceThreshold(threshold=(low_variance_thresh * (1 - low_variance_thresh)))
    var_selector.fit(cons_df)
    cons_df = cons_df.iloc[:, var_selector.get_support(indices=True)]
    var_filtered_concepts = list(cons_df.columns)
    var_removed_concepts = set(initial_concepts) - set(var_filtered_concepts)
    print('Concepts removed by variance filtering ({}): {}'.format(len(var_removed_concepts), var_removed_concepts))

    def mutual_info_scores(X, y):
        # mutual_info_classif adds random noise to continuous features (dense input is
        # treated as continuous by default); seed it so the selection is reproducible.
        return mutual_info_classif(X, y, random_state=seed)

    k = max_concepts if len(var_filtered_concepts) > max_concepts else 'all'
    mut_selector = SelectKBest(mutual_info_scores, k=k)
    mut_selector.fit(cons_df, preds_df)
    cons_df = cons_df.iloc[:, mut_selector.get_support(indices=True)]
    filtered_concepts = list(cons_df.columns)
    mut_removed_concepts = set(var_filtered_concepts) - set(filtered_concepts)
    print('Concepts removed by mutual info filtering ({}): {}'.format(len(mut_removed_concepts), mut_removed_concepts))
    print('Concepts reduced from {} to {} by concept filtering.'.format(len(initial_concepts), len(filtered_concepts)))
    print('Final concepts after filtering ({}): {}'.format(len(filtered_concepts), filtered_concepts))

    filtered_concepts_df = pd.concat([cons_df, meta_df], axis=1)
    display(filtered_concepts_df.head())

    # Sorted so the channel column order (and the saved CSV) is deterministic.
    channel_cols = sorted(set(channels_df.columns) - set(meta_cols))
    kept_concepts = set(filtered_concepts)
    filtered_channels = [ch for ch in channel_cols if (ch in channel_concept_map) and (channel_concept_map[ch] in kept_concepts)]
    print('Channels reduced from {} to {} by concept filtering.'.format(len(channel_cols), len(filtered_channels)))

    cols_to_keep = filtered_channels + meta_cols
    filtered_channels_df = channels_df[cols_to_keep]
    display(filtered_channels_df.head())

    return filtered_concepts_df, filtered_channels_df, filtered_concepts, filtered_channels

In [None]:
def save_activation_images (image_index, image_path, image_fname, acts, image_channels_counts, channel_concept_map, channel_thresh_map, concepts, output_dir):
    """Save overlay images showing where each concept's strongest channels fire.

    For every concept in ``concepts`` with at least one activated channel, the
    ``n_top_channels_per_concept`` channels with the most above-threshold values
    are taken. Each channel's activation map is resized to the image, thresholded
    with the channel's threshold, and used as a mask: masked pixels keep full
    brightness, the rest are scaled by ``1 - overlay_opacity``. Overlays are written
    to ``<output_dir>/<file stem>_<channel>_<concept>.jpg``.

    Args:
        image_index: Image id, used in log messages (ids 1 and 2 print extra debug info).
        image_path: Path of the original image file.
        image_fname: File name used to build the output file names.
        acts: Activations of this image, indexed by ``channel - 1``.
        image_channels_counts: channel -> number of above-threshold values.
        channel_concept_map: channel -> concept label.
        channel_thresh_map: channel -> activation threshold.
        concepts: Concepts to draw overlays for.
        output_dir: Directory the overlays are written to.

    Returns:
        The number of overlay images written.
    """
    image_activated_channels = [ch for ch, count in image_channels_counts.items() if count > 0]
    if not image_activated_channels:
        print('Image {} with path {} has no activated channels!'.format(image_index, image_path))
        return 0

    verbose = image_index in [1, 2]

    # Reopen the image and apply the same resize/crop as the model input, but skip the
    # normalization (simpler than un-normalizing the tensor). pil_loader closes the file
    # and converts to RGB, matching the ImageFolder loader that fed the model.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(image_size),
        transforms.ToTensor()
    ])
    image = transform(pil_loader(image_path)).numpy()
    image = np.transpose(image, (1, 2, 0))   # CxHxW -> HxWxC

    # Stretch pixel values to 0..255 for saving; a constant image (zero range) is
    # only scaled, avoiding a 0/0 division.
    image_min, image_max = np.min(image), np.max(image)
    if image_max > image_min:
        image = (image - image_min) / (image_max - image_min)
    image = (image * 255).astype(np.uint8)
    if verbose:
        print('image_min: {}, image_max: {}, new_image.min: {}, new_image.max: {}'.format(image_min, image_max, np.min(image), np.max(image)))

    # Group activated channels by concept, skipping channels without a known concept.
    image_concept_channels = {con: [] for con in concepts}
    for ch in image_activated_channels:
        channel_concept = channel_concept_map.get(ch)
        if channel_concept not in image_concept_channels:
            continue
        image_concept_channels[channel_concept].append((ch, image_channels_counts[ch]))

    image_fname_raw = os.path.splitext(image_fname)[0]

    cnt = 0
    for con, lst in image_concept_channels.items():
        if not lst:
            continue

        top_channel_nums = sorted(lst, key=lambda x: x[1], reverse=True)[:n_top_channels_per_concept]
        if verbose:
            print('Top channels for concept {}: {}'.format(con, top_channel_nums))

        for rank, (ch, _) in enumerate(top_channel_nums):
            channel_thresh = channel_thresh_map.get(ch)
            if channel_thresh is None:
                print('Error: Missing activation or threshold ({}) for channel {}!'.format(channel_thresh, ch))
                continue
            channel_activation = acts[ch - 1]   # channel numbers are 1-based

            should_print = verbose and rank == 0

            # Upsample the activation map to the image size, then threshold it into a mask.
            mask = np.array(Image.fromarray(channel_activation).resize(size=(image.shape[1], image.shape[0]), resample=Image.BILINEAR))
            mask = mask > channel_thresh
            new_image = (mask[:, :, np.newaxis] * overlay_opacity + (1 - overlay_opacity)) * image
            final_image = new_image.astype(np.uint8)
            if should_print:
                print('new_image.min: {}, new_image.max: {}, final_image.min: {}, final_image.max: {}'
                    .format(np.min(new_image), np.max(new_image), np.min(final_image), np.max(final_image)))

            new_fname = '{}_{}_{}.jpg'.format(image_fname_raw, ch, con)
            imwrite(os.path.join(output_dir, new_fname), final_image)
            cnt += 1

    print('Saved {} activation images for image {} with path {}'.format(cnt, image_index, image_path))
    return cnt

In [None]:
def save_image_concepts_dataset (concepts_df, channels_df, image_channels_counts_list, acts_list, channel_concept_map, channel_thresh_map, filtered_concepts, 
                                 filtered_channels, concepts_output_path, channels_output_path, activation_images_path):
    """Write overlay images for every image and save the concept/channel tables as CSV.

    Also prints how often each filtered concept and channel fired, overall and per
    predicted class.

    Args:
        concepts_df: Per-image 0/1 concept table with 'pred', 'id', 'path', 'file' columns.
        channels_df: Per-image 0/1 channel table, row-aligned with ``concepts_df``.
        image_channels_counts_list: Per-image {channel: above-threshold count}.
        acts_list: Per-image activations, passed on to ``save_activation_images``.
        channel_concept_map: channel -> concept label.
        channel_thresh_map: channel -> activation threshold.
        filtered_concepts: Concepts to count and draw.
        filtered_channels: Channels to count and draw.
        concepts_output_path: Destination CSV for ``concepts_df``.
        channels_output_path: Destination CSV for ``channels_df``.
        activation_images_path: Directory for the overlay images.
    """
    num_images = len(concepts_df.index)
    filtered_channel_set = set(filtered_channels)

    filtered_concepts_counts = {con: 0 for con in filtered_concepts}
    filtered_channels_counts = {ch: 0 for ch in filtered_channels}
    filtered_concepts_counts_by_class = {c: {con: 0 for con in filtered_concepts} for c in classes}

    num_act_images_saved = 0

    # The tables and lists are assumed to be aligned by position (presumably built in the
    # same pass by extract_concepts), so walk them positionally instead of relying on
    # the DataFrame index labels.
    for pos, (_, con_row) in enumerate(concepts_df.iterrows()):
        ch_row = channels_df.iloc[pos]
        pred = con_row['pred']
        image_id = con_row['id']
        path = con_row['path']
        fname = con_row['file']
        act = acts_list[pos]
        filtered_image_channels_counts = {ch: count for ch, count in image_channels_counts_list[pos].items() if ch in filtered_channel_set}

        # Predictions outside `classes` get their own bucket instead of raising KeyError.
        class_counts = filtered_concepts_counts_by_class.setdefault(int(pred), {con: 0 for con in filtered_concepts})
        for con in filtered_concepts:
            val = con_row[con]
            filtered_concepts_counts[con] += val
            class_counts[con] += val

        for ch in filtered_channels:
            filtered_channels_counts[ch] += ch_row[ch]

        num_act_images_saved += save_activation_images(image_id, path, fname, act, filtered_image_channels_counts, channel_concept_map, channel_thresh_map, 
                                                       filtered_concepts, activation_images_path)

    print('Saved {} activation images for {} images.'.format(num_act_images_saved, num_images))
    print('\nFiltered concept counts:', filtered_concepts_counts)
    for c, counts in filtered_concepts_counts_by_class.items():
        print('\nFiltered concept counts of class {}: {}'.format(c, counts))
    print('\nFiltered channel counts:', filtered_channels_counts)

    concepts_df.to_csv(concepts_output_path, index=False)
    channels_df.to_csv(channels_output_path, index=False)

In [None]:
# Copy the inputs of this step from Drive into the working directory: the zipped
# dataset, the trained model weights and the zipped unit-identification results.
# `!` lines are IPython shell escapes; `$name` interpolates the Python variable
# (double-quoted where the path contains spaces, e.g. "My Drive").
model_file = 'model.pth'
dataset_dir = 'dataset'
drive_result_path = '/content/drive/My Drive/Python Projects/POEM Pipeline Results/' + model_name + '_' + dataset_name + '_old'

dataset_file = dataset_dir + '.zip'
drive_dataset_dir = drive_result_path + '/' + dataset_file   # path of the dataset zip on Drive
!cp "$drive_dataset_dir" '.'
# unzip: -qq very quiet, -n never overwrite files that already exist.
!unzip -qq -n $dataset_file -d '.'

# The model is copied locally; load_model is later called with the local `model_file`.
model_path = drive_result_path + '/' + model_file
!cp "$model_path" '.'

# Identification results provide tally.csv and quantile.npy for load_channels_data.
result_dir = 'identification_results'
result_file = result_dir + '.zip'
result_path = drive_result_path + "/" + result_file
!cp "$result_path" '.'
!unzip -qq -n $result_file -d '.'

# Start from an empty output directory: overlays from a previous run are deleted.
activation_images_path = 'activation_images'
if os.path.exists(activation_images_path):
    shutil.rmtree(activation_images_path)
os.makedirs(activation_images_path)

In [None]:
data_loader = load_data(dataset_dir)

# Class indices parsed from the dataset name; used as keys of the per-class counts.
class_titles = extract_class_titles(dataset_name)
classes = list(class_titles)

identification_dir = os.path.join('.', result_dir)
tally_path = os.path.join(identification_dir, 'tally.csv')
thresholds_path = os.path.join(identification_dir, 'quantile.npy')
channel_concept_map, channel_thresh_map, channels, concepts = load_channels_data(tally_path, thresholds_path)

model = load_model(model_file).to(device)

In [None]:
extraction = extract_concepts(model, data_loader, channel_concept_map, channel_thresh_map, channels, concepts)
concepts_df, channels_df, acts_list, image_channels_counts_list = extraction

# Optional pruning: replaces the tables and the concept/channel lists used below.
if filter_concepts:
    filtered = filter_extracted_concepts(concepts_df, channels_df, channel_concept_map)
    concepts_df, channels_df, concepts, channels = filtered

In [None]:
concepts_output_path = 'image_concepts.csv'
channels_output_path = 'image_channels.csv'

# Writes the overlay images plus both CSV tables.
save_image_concepts_dataset(
    concepts_df, channels_df,
    image_channels_counts_list, acts_list,
    channel_concept_map, channel_thresh_map,
    concepts, channels,
    concepts_output_path, channels_output_path,
    activation_images_path,
)

In [None]:
# Upload the CSV tables to the Drive result folder (IPython `!` shell escapes).
!cp $concepts_output_path "$drive_result_path"
!cp $channels_output_path "$drive_result_path"

# Zip the overlay images and upload the archive.
activation_images_file = activation_images_path + '.zip'
!zip -qq -r $activation_images_file $activation_images_path
# NOTE(review): this relies on IPython expanding $drive_result_path before the shell
# sees the single quotes; the double-quoted form used above would be clearer — confirm.
!cp $activation_images_file '$drive_result_path'

# Record the versions of the loaded packages next to the results.
packages_path = drive_result_path + '/attribution_packages.log'
save_imported_packages(packages_path)