In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
%%bash
# PyTorch 1.11 removed the THC namespace, which prevents the PreciseRoIPooling library used by the
# segmentation model from building; so pin torch 1.10.1 / torchvision 0.11.2 (CUDA 10.2 wheels).
# NOTE(review): this install runs unconditionally, i.e. before the Colab check below.
pip install torch==1.10.1+cu102 torchvision==0.11.2+cu102 -f https://download.pytorch.org/whl/torch_stable.html

# Intended to exit this cell early when the google/colab package directory is absent (not running on Colab);
# the remaining setup below only runs on Colab.
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit 
# ninja: presumably needed to build the segmenter's compiled extension -- TODO confirm
pip install ninja 2>> install.log
# Clone the netdissect code into ./dissect (added to sys.path in the next cell).
git clone https://github.com/davidbau/dissect.git dissect 2>> install.log
#pip list -v >> identification_packages.log

In [None]:
# Colab-only setup: make the cloned `dissect` repository importable and warn when no GPU is attached.
# Outside Colab `import google.colab` raises ImportError and this cell does nothing; any other error
# is no longer silently swallowed.
try:
    import google.colab, sys, torch
except ImportError:
    pass
else:
    sys.path.append('/content/dissect')
    if not torch.cuda.is_available():
        print("Change runtime type to include a GPU.")

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib as mpl
import IPython
from importlib import reload

mpl.rcParams['lines.linewidth'] = 0.25
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.linewidth'] = 0.25

In [None]:
import torch, os, pickle, json, numpy, netdissect

from collections import OrderedDict
from torch import nn
from torchvision import datasets, transforms, models
from netdissect import pbar, nethook, renormalize, parallelfolder, upsample, tally, imgviz, imgsave, show
from netdissect.easydict import EasyDict
from experiment import dissect_experiment as experiment

# Let cuDNN benchmark convolution algorithms and reuse the fastest one; this pays off when input sizes stay fixed.
torch.backends.cudnn.benchmark = True

In [None]:
# All pipeline inputs/outputs live under this Google Drive folder (Drive must be mounted first).
drive_results_root = '/content/drive/My Drive/Python Projects/POEM Pipeline Results/'

# NOTE: pickle.load can execute arbitrary code -- only load a settings file produced by your own pipeline.
settings_path = drive_results_root + 'settings.pkl'
with open(settings_path, 'rb') as f:
    settings = pickle.load(f)

# The first line of current_setting.txt names the active setting as "<model>_<dataset>".
current_setting_path = drive_results_root + 'current_setting.txt'
with open(current_setting_path, 'r') as f:
    setting_lines = f.read().splitlines()
if not setting_lines or not setting_lines[0].strip():
    raise ValueError('No setting title found on the first line of ' + current_setting_path)
# strip() so stray spaces around the title cannot break the settings lookup below
current_setting_title = setting_lines[0].strip()
print('Current setting:', current_setting_title)

model_settings = settings['model_settings']
model_dataset_settings = settings['model_dataset_settings']

# The model name is the first '_'-separated token; the dataset name (which may contain '_') is the rest.
title_parts = current_setting_title.split('_')
model_name = title_parts[0]
dataset_name = '_'.join(title_parts[1:])
dataset_pure_name = dataset_name.split('_')[0]   # first '_'-separated token of the dataset name

current_setting = model_dataset_settings[current_setting_title]
num_classes = current_setting['num_classes']
# True: load model/dataset through netdissect's `experiment` helpers instead of the files copied from Drive.
use_dissection_models = False
target_layer = model_settings[model_name]['target_layer'] if not use_dissection_models else experiment.instrumented_layername(EasyDict(model=model_name))

# It is better to exclude concepts that are very similar to the dataset classes,
# e.g. "laptop" or "computer" for a laptop-vs-mobile dataset.
exclude_similar_concepts = True
excluded_concepts = current_setting['excluded_concepts'] if exclude_similar_concepts and ('excluded_concepts' in current_setting) else []

seg_model_name = 'netpc'   # alternative: 'netpqc'
image_size = 224

min_iou = 0.04                  # channels whose best IoU is not above this are left out of the concept graph
activation_high_thresh = 0.99   # activation quantile above which a channel counts as "on"
activation_low_thresh = 0.7     # lower activation quantile, saved per channel in the report

# Input normalization statistics used by the dataset transform.
norm_mean = (0.485, 0.456, 0.406)
norm_std = (0.229, 0.224, 0.225)

args = EasyDict(model=model_name, dataset=dataset_name, seg=seg_model_name, layer=target_layer, quantile=activation_high_thresh)

In [None]:
def save_imported_packages(packages_path):
    """Write one "<module> <version>" line for every imported top-level module to `packages_path`.

    Only top-level modules (names without '.') are recorded. The version comes from the first
    attribute that works: `__version__`, then `version` (used as-is when it is a string, otherwise
    called), then `VERSION`. Modules exposing none of these are skipped. Lines are sorted by name.
    """
    import sys

    missing = object()   # sentinel: no version attribute could be read

    def lookup_version(module):
        try:
            return module.__version__
        except Exception:
            pass
        try:
            version = module.version
            return version if isinstance(version, str) else version()
        except Exception:
            pass
        try:
            return module.VERSION
        except Exception:
            return missing

    # Iterate over a snapshot: reading version attributes can trigger lazy imports that add entries
    # to sys.modules, and mutating a dict while iterating over it raises RuntimeError.
    modules_info = []
    for name, module in list(sys.modules.items()):
        if '.' in name or module is None:   # skip subpackages and import placeholders
            continue
        version = lookup_version(module)
        if version is not missing:
            modules_info.append((name, version))

    modules_info.sort(key=lambda item: item[0])
    with open(packages_path, 'w') as f:
        for name, version in modules_info:
            f.write('{} {}\n'.format(name, version))

In [None]:
# Previously used to drop excluded concepts from the segmenter's concept list. It turned out not to be
# effective in practice: editing the label list makes it disagree with the output classes of the
# pretrained segmenter model. Kept for reference.
def exclude_concepts_from_segmenter (seglabels_path, concepts_to_exclude=None):
    """Remove excluded concepts from the 'object' list of a segmenter labels JSON file, in place.

    Args:
        seglabels_path: path of the labels JSON file. It is rewritten only if it has an 'object' key.
        concepts_to_exclude: concept names to drop; defaults to the notebook-level `excluded_concepts`.
    """
    if concepts_to_exclude is None:
        concepts_to_exclude = excluded_concepts

    # `with` closes the file even if the JSON is malformed.
    with open(seglabels_path) as f:
        labels_data = json.load(f)
    if 'object' not in labels_data:
        return

    labels_data['object'] = [obj for obj in labels_data['object'] if obj not in concepts_to_exclude]

    with open(seglabels_path, 'w') as f:
        json.dump(labels_data, f)

In [None]:
# Default layer to dissect for each supported architecture
# (the vgg16 entry uses the renamed layers produced by vgg16_model).
_DEFAULT_TARGET_LAYERS = {
    'resnet18': 'layer4.1.conv2',
    'resnet50': 'layer4.2.conv3',
    'vgg16': 'features.conv5_3',
    'alexnet': 'conv5',
}

def get_target_layer_name (final=False, name=None):
    """Return the name of the layer to dissect.

    Args:
        final: when True and `use_dissection_models` is set, ask the experiment helpers for the layer name.
        name: architecture name to look up; defaults to the notebook-level `model_name`.

    Returns:
        The layer name, or None for an unknown architecture.
    """
    if final and use_dissection_models:
        return experiment.instrumented_layername(args)
    if name is None:
        name = model_name
    # Dict lookup compares by equality; the previous `is` checks compared object identity,
    # which fails for equal strings built at runtime.
    return _DEFAULT_TARGET_LAYERS.get(name)

In [None]:
def vgg16_model (*args, **kwargs):
    """Build torchvision's VGG16 with its layers renamed to the names used in the literature.

    `features` modules are named conv{b}_{i} / relu{b}_{i} / pool{b} and `classifier` modules
    fc6 ... fc8a, so layers can be addressed as e.g. 'features.conv5_3'.
    All arguments are forwarded to torchvision.models.vgg16.
    """
    net = models.vgg16(*args, **kwargs)

    # (block number, conv layers in that block) for the five VGG16 conv blocks.
    conv_blocks = [(1, 2), (2, 2), (3, 3), (4, 3), (5, 3)]
    feature_names = []
    for block, depth in conv_blocks:
        for i in range(1, depth + 1):
            feature_names += ['conv%d_%d' % (block, i), 'relu%d_%d' % (block, i)]
        feature_names.append('pool%d' % block)

    classifier_names = ['fc6', 'relu6', 'drop6', 'fc7', 'relu7', 'drop7', 'fc8a']

    net.features = nn.Sequential(OrderedDict(zip(feature_names, net.features)))
    net.classifier = nn.Sequential(OrderedDict(zip(classifier_names, net.classifier)))
    return net

In [None]:
def load_model (model_file=None):
    """Load the classifier to dissect.

    With `use_dissection_models`, defer to the experiment helpers. Otherwise build architecture
    `model_name` (the renamed VGG16 for 'vgg16') with `num_classes` outputs and load the checkpoint at
    `model_file`. The checkpoint may be a raw state dict or a dict with a 'state_dict' entry, and its keys
    may carry a leading 'module.' prefix. The model is returned wrapped in nethook.InstrumentedModel,
    on the GPU and in eval mode.

    Raises:
        ValueError: if `model_file` is missing when a checkpoint has to be loaded.
    """
    if use_dissection_models:
        return experiment.load_model(args)

    if model_file is None:
        raise ValueError('model_file is required when use_dissection_models is False')

    if model_name == 'vgg16':
        model = vgg16_model(num_classes=num_classes)
    else:
        model = models.__dict__[model_name](num_classes=num_classes)

    # Load onto the CPU first so checkpoints saved on any device work; the model is moved to the GPU below.
    checkpoint = torch.load(model_file, map_location='cpu')
    statedict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    model.load_state_dict(_strip_module_prefix(statedict))
    return nethook.InstrumentedModel(model).cuda().eval()


def _strip_module_prefix(statedict):
    """Remove a leading 'module.' (as added by nn.DataParallel) from every key of `statedict`."""
    prefix = 'module.'
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in statedict.items()}

In [None]:
def load_dataset (dataset_dir=None):
    """Return the image dataset to dissect.

    With `use_dissection_models`, defer to the experiment helpers. Otherwise build a shuffled
    ParallelImageFolders classification dataset over `dataset_dir`. Images are resized to 256,
    center-cropped to `image_size` and normalized with `norm_mean` / `norm_std`.
    """
    if not use_dissection_models:
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
        ])
        dataset = parallelfolder.ParallelImageFolders(
            [dataset_dir], classification=True, shuffle=True, transform=preprocess)
        print('Processing {} data examples from these classes: {}'.format(len(dataset), dataset.classes))
        return dataset
    return experiment.load_dataset(args)

In [None]:
def show_sample_images (model, dataset, sample_batch, batch_indices, classlabels):
    """Display each image of `sample_batch` with the model's predicted label and its true label.

    `batch_indices` are the dataset indices the batch was built from; they are used to look up ground truth.
    """
    true_names = [classlabels[dataset[idx][1]] for idx in batch_indices]
    outputs = model(sample_batch.cuda())
    pred_names = [classlabels[cls.item()] for cls in outputs.max(1)[1]]
    images = [renormalize.as_image(t, source=dataset) for t in sample_batch]
    show([[image, 'pred: ' + p, 'true: ' + t]
          for image, p, t in zip(images, pred_names, true_names)])

In [None]:
def show_sample_segmentations (segmodel, dataset, sample_batch, renorm):
    """Display each image of `sample_batch` next to its segmentation and a key of the segment labels.

    `renorm` is applied to the batch before it is passed to `segmodel.segment_batch`.
    Only the first segmentation channel (seg[i, 0]) is visualized.
    """
    iv = imgviz.ImageVisualizer(120, source=dataset)
    with torch.no_grad():
        seg = segmodel.segment_batch(renorm(sample_batch).cuda(), downsample=4)
    # The former torch.set_printoptions(profile="full") debug call is dropped: it changed tensor
    # printing for the whole session.
    print('seg.shape:', seg.shape)

    show([(iv.image(sample_batch[i]), iv.segmentation(seg[i, 0]), iv.segment_key(seg[i, 0], segmodel))
          for i in range(len(seg))])

In [None]:
def show_sample_heatmaps (model, dataset, sample_batch, num_units=12):
    """For the first image of `sample_batch`, show masked images and heatmaps of `target_layer` channels.

    The masked images use `activation_high_thresh` as percent level. Only the first `num_units`
    channels (or fewer, if the layer is smaller) are shown.
    """
    # Run the model on this batch so the retained activations belong to `sample_batch`,
    # not to whichever batch the model happened to process last.
    with torch.no_grad():
        model(sample_batch.cuda())
    acts = model.retained_layer(target_layer).cpu()
    print('acts.shape:', acts.shape)
    print('acts_reshaped.shape:', acts.view(acts.shape[0], acts.shape[1], -1).shape)

    ivsmall = imgviz.ImageVisualizer((100, 100), source=dataset)
    display(show.blocks(
        [[[ivsmall.masked_image(sample_batch[0], acts, (0, u), percent_level=activation_high_thresh)],
          [ivsmall.heatmap(acts, (0, u), mode='nearest')]]
         for u in range(min(acts.shape[1], num_units))]
    ))

In [None]:
def show_sample_image_activation (model, dataset, rq, topk, classlabels, sample_unit_number, sample_image_index):
    """Inspect one channel of `target_layer` on one of its top-activating images.

    Takes the `sample_image_index`-th top image of channel `sample_unit_number` from `topk`.
    Prints its dataset index and path, and how many of its activations exceed the channel's
    `activation_high_thresh` quantile. Then shows the image (predicted vs. true label), the
    channel's masked image and its heatmap.
    """
    # topk.result()[1] holds, per channel, the dataset indices of its top images; fetch it once.
    top_indices = topk.result()[1]
    image_index = top_indices[sample_unit_number][sample_image_index]
    print(image_index, dataset.images[image_index])
    image_number = image_index.item()

    high_levels = rq.quantiles(activation_high_thresh)
    iv = imgviz.ImageVisualizer((image_size, image_size), source=dataset, quantiles=rq,
            level=high_levels)

    sample = dataset[image_number]   # load the example once; [0] is the image, [1] the label index
    batch = sample[0][None, ...]
    truth = [classlabels[sample[1]]]
    with torch.no_grad():
        preds = model(batch.cuda()).max(1)[1]
    imgs = [renormalize.as_image(t, source=dataset) for t in batch]
    prednames = [classlabels[p.item()] for p in preds]

    acts = model.retained_layer(target_layer)
    print('acts.shape:', acts.shape)
    print('acts_reshaped.shape:', acts.view(acts.shape[0], acts.shape[1], -1).shape)
    image_acts = acts[0, sample_unit_number].cpu().numpy()
    unit_quant = high_levels[sample_unit_number].item()
    print('number of activations higher than quantile {}: {}'.format(unit_quant, numpy.sum(image_acts > unit_quant)))

    show([[img, 'pred: ' + pred, 'true: ' + gt] for img, pred, gt in zip(imgs, prednames, truth)])
    show([[iv.masked_image(batch[0], acts, (0, sample_unit_number))]])
    show([[iv.heatmap(acts, (0, sample_unit_number), mode='nearest')]])

In [None]:
def save_top_channel_images (model, dataset, rq, topk, k=5):
    """Render each channel's top-`k` images masked at its `activation_high_thresh` level and save them.

    The image strips are cached in top{k}images.npz and written to image/unit<N>.jpg under the results
    directory.

    Returns:
        The per-channel image strips (`unit_images[u]` can be displayed).
    """
    pbar.descnext('unit_images')
    iv = imgviz.ImageVisualizer((100, 100), source=dataset, quantiles=rq,
            level=rq.quantiles(activation_high_thresh))

    def compute_acts(image_batch, label_batch):
        with torch.no_grad():   # inference only; avoid building autograd graphs
            model(image_batch.cuda())
        return model.retained_layer(target_layer)

    # One path derived from k, used both as the cache and as the source recorded with the saved images;
    # previously "5" was repeated in three places.
    topk_cache = resfile('top%dimages.npz' % k)
    unit_images = iv.masked_images_for_topk(
            compute_acts, dataset, topk, k=k, num_workers=2, pin_memory=True,
            cachefile=topk_cache)

    pbar.descnext('saving images')
    imgsave.save_image_set(unit_images, resfile('image/unit%d.jpg'), sourcefile=topk_cache)

    return unit_images

In [None]:
def show_sample_channel_images (unit_images, sample_unit_numbers, unit_label_high=None):
    """Print a header for each requested channel and display its top-image strip.

    When `unit_label_high` is given, the header also shows the channel's concept label and IoU
    (fields 1 and 3 of each entry).
    """
    for unit in sample_unit_numbers:
        header = ('unit %d' % unit if unit_label_high is None
                  else 'unit %d, label %s, iou %.3f' % (unit, unit_label_high[unit][1], unit_label_high[unit][3]))
        print(header)
        display(unit_images[unit])

In [None]:
# Tally the activations of every channel over all images so that any per-channel quantile can be read off later.
def compute_tally_quantile (model, dataset, upfn, sample_size):
    """Return a tally of `target_layer` activations over `dataset` from which per-channel quantiles are read.

    Activations are upsampled with `upfn`, and each spatial position contributes one sample per
    channel. The result is cached in rq.npz under the results directory.
    """
    def compute_samples(batch, *args):
        model(batch.cuda())
        layer_acts = model.retained_layer(target_layer)
        channels = layer_acts.shape[1]
        # presumably (N, C, H, W) -> (N*H*W, C): one row per spatial position
        return upfn(layer_acts).permute(0, 2, 3, 1).contiguous().view(-1, channels)

    pbar.descnext('rq')
    return tally.tally_quantile(compute_samples, dataset,
                                sample_size=sample_size,
                                r=8192,
                                num_workers=2,
                                pin_memory=True,
                                cachefile=resfile('rq.npz'))

In [None]:
# Tally each channel's maximum activation per image, so the top-k images of every channel can be identified.
def compute_tally_topk (model, dataset, sample_size):
    """Return a top-k tally over `dataset` of each channel's per-image maximum `target_layer` activation.

    The result is cached in topk.npz under the results directory.
    """
    def compute_image_max(batch, *args):
        model(batch.cuda())
        layer_acts = model.retained_layer(target_layer)
        n_images, n_channels = layer_acts.shape[0], layer_acts.shape[1]
        # Flatten the spatial dims and keep each channel's maximum per image.
        return layer_acts.view(n_images, n_channels, -1).max(dim=2).values

    pbar.descnext('topk')
    return tally.tally_topk(compute_image_max, dataset, sample_size=sample_size,
                            batch_size=50, num_workers=2, pin_memory=True,
                            cachefile=resfile('topk.npz'))

In [None]:
# Find the best-matching concept for each channel from the IoU between concept segmentations and channel activations.
def compute_top_channel_concepts (model, segmodel, upfn, dataset, rq, seglabels, segcatlabels, sample_size, renorm):
    """Label each channel of `target_layer` with the segmentation concept it overlaps best (highest IoU).

    A channel counts as "on" where its upsampled activation exceeds its `activation_high_thresh`
    quantile. The IoU with every concept of `segmodel` is tallied over `dataset` (cached in
    condi_high.npz). Concepts in `excluded_concepts` are passed over for the next-best of the top 3.

    Returns:
        unit_label_high: per channel, (concept index, label, category label, IoU), or
            (0, '-', ('-', '-'), 0.0) when all top candidates are excluded.
        label_list: category labels of the channels whose IoU exceeds `min_iou`.
        level_high, level_low: per-channel activation thresholds at the high / low quantiles,
            shaped (1, C, 1, 1) on the GPU.
    """
    # Renamed from earlier versions: level_at_99 -> level_high, condi99 -> condi_high,
    # iou_99 -> iou_high, unit_label_99 -> unit_label_high.

    level_high = rq.quantiles(activation_high_thresh).cuda()[None, :, None, None]
    level_low = rq.quantiles(activation_low_thresh).cuda()[None, :, None, None]

    # Per batch: overlap between each channel's "on" mask and each concept's segmentation.
    def compute_conditional_indicator(batch, *args):
        image_batch = batch.cuda()
        seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
        _ = model(image_batch)
        hacts = upfn(model.retained_layer(target_layer))
        iacts = (hacts > level_high).float()   # indicator: 1 where the channel is "on"
        return tally.conditional_samples(iacts, seg)

    pbar.descnext('condi_high')
    condi_high = tally.tally_conditional_mean(compute_conditional_indicator,
            dataset, sample_size=sample_size,
            num_workers=10, pin_memory=True,
            cachefile=resfile('condi_high.npz'))

    # IoU of every concept (rows) with every channel (columns).
    iou_high = tally.iou_from_conditional_indicator_mean(condi_high)

    unit_label_high = [_pick_unit_label(i, row, seglabels, segcatlabels)
                       for i, row in enumerate(iou_high.t())]

    label_list = [labelcat for concept, label, labelcat, iou in unit_label_high if iou > min_iou]

    print(len(unit_label_high))
    print(unit_label_high)

    return unit_label_high, label_list, level_high, level_low


def _pick_unit_label(unit, iou_row, seglabels, segcatlabels, top_k=3):
    """Return (concept, label, category label, IoU) of the best non-excluded concept among a channel's top `top_k`."""
    top_ious, top_concepts = iou_row.topk(k=min(top_k, iou_row.numel()))
    candidates = []
    for con, iou in zip(top_concepts, top_ious):
        idx = con.item()
        candidates.append((idx, seglabels[idx], segcatlabels[idx], iou.item()))

    best = candidates[0]
    top_label = best[1]
    if top_label not in excluded_concepts:
        return best

    # Though not ideal, this is the best we can do to exclude concepts which are very similar to the dataset classes:
    print('Channel {} top concepts: {}'.format(unit, candidates))
    best = next((c for c in candidates[1:] if c[1] not in excluded_concepts), (0, '-', ('-', '-'), 0.0))
    print('Because top concept {} is among the excluded concepts, concept {} with iou {} is selected for channel {}'
        .format(top_label, best[1], best[3], unit))
    return best

In [None]:
def save_final_data (unit_label_high, label_list, level_high, level_low):
    """Write the dissection report and per-channel activation thresholds to the results directory.

    Produces concepts_high.svg (the concept-category graph of `label_list`, also displayed inline),
    report.json and report.html (one entry per channel with label, category, IoU and thresholds), and
    channel_quantiles.npy (high thresholds only). Both threshold lists are also printed.
    """
    display(IPython.display.SVG(experiment.graph_conceptcatlist(label_list)))
    experiment.save_conceptcat_graph(resfile('concepts_high.svg'), label_list)

    print('level_high.shape:', level_high.shape)
    print('level_low.shape:', level_low.shape)

    high_quantiles, low_quantiles = (lvl.view(-1).cpu().numpy() for lvl in (level_high, level_low))

    print('high_quantiles.shape:', high_quantiles.shape)
    print('low_quantiles.shape:', low_quantiles.shape)

    units = []
    for u, (concept, label, labelcat, iou) in enumerate(unit_label_high):
        units.append(dict(image='image/unit%d.jpg' % u, unit=u, iou=iou, label=label, cat=labelcat[1],
                          high_thresh=float(high_quantiles[u]), low_thresh=float(low_quantiles[u])))
    header = dict(name='%s %s %s' % (model_name, dataset_name, seg_model_name), image='concepts_high.svg')
    experiment.dump_json_file(resfile('report.json'), dict(header=header, units=units))
    experiment.copy_static_file('report.html', resfile('report.html'))

    numpy.save(resfile('channel_quantiles.npy'), high_quantiles)

    for title, values in (('Channel high quantiles:', high_quantiles), ('Channel low quantiles:', low_quantiles)):
        print(title)
        for i, q in enumerate(values):
            print('{}: {}'.format(i, q))

In [None]:
model_file = 'model.pth'   # model_name + '_' + dataset_name + '.pth'
dataset_dir = 'dataset'   # dataset_name
result_dir = 'identification_results'
drive_result_path = '/content/drive/My Drive/Python Projects/POEM Pipeline Results/' + model_name + '_' + dataset_name

def resfile(f):
    return os.path.join(result_dir, f)

if not use_dissection_models:
    dataset_file = dataset_dir + '.zip'
    drive_dataset_dir = drive_result_path + '/' + dataset_file   # '/content/drive/My Drive/Python Projects/Other Data/' + dataset_file
    !cp "$drive_dataset_dir" '.'
    !unzip -qq -n $dataset_file -d '.'

    drive_model_path = drive_result_path + '/' + model_file   # "/content/drive/My Drive/Python Projects/Network Dissection/NetDissect-Lite-master/zoo/" + model_file
    !cp "$drive_model_path" '.'

    # if dataset_pure_name in ['imagenette', 'imagewoof', 'places']:
    #     dataset_dir = dataset_name + '/train'

    #     # Removing any train.txt or val.txt file that may interfere with loading the dataset properly: 
    #     for f in os.listdir(dataset_name):
    #         if f.endswith(".txt"):
    #             os.remove(os.path.join(dataset_name, f))

# Optional: loading the segmenter models to avoid downloading them from netdissect server: 
segmodel_dir = 'segmodel'
segmodel_file = segmodel_dir + '.zip'
drive_segmodel_path = '/content/drive/My Drive/Python Projects/Pretrained Models/' + segmodel_file
target_segmodel_dir = 'datasets'
if not os.path.exists(target_segmodel_dir):
    os.makedirs(target_segmodel_dir)

target_segmodel_file = target_segmodel_dir + '/' + segmodel_file
!cp "$drive_segmodel_path" $target_segmodel_dir
!unzip -qq -n $target_segmodel_file -d $target_segmodel_dir

# if exclude_similar_concepts and (len(excluded_concepts) > 0):
#     seglabels_path = target_segmodel_dir + '/' + segmodel_dir + '/upp-resnet50-upernet/labels.json'
#     exclude_concepts_from_segmenter(seglabels_path)

In [None]:
# Load the classifier (retaining target_layer's activations on every forward pass) and the image dataset.
model = load_model(model_file)
model.retain_layer(target_layer)

dataset = load_dataset(dataset_dir)
classlabels, sample_size = dataset.classes, len(dataset)

print('Inspecting layer %s of model %s on dataset %s' % (target_layer, model_name, dataset_name))
print(model)

In [None]:
# Activation upsampler, segmenter-input renormalizer, and the segmentation model with its labels.
upfn = experiment.make_upfn(args, dataset, model, target_layer)
renorm = renormalize.renormalizer(dataset, target='zc')
segmodel, seglabels, segcatlabels = experiment.setting.load_segmenter(seg_model_name)

print('Segmentation labels:')
for idx in range(len(seglabels)):
    print('{}: {} from category {}'.format(idx, seglabels[idx], segcatlabels[idx]))

In [None]:

# A fixed sample batch (dataset items 10, 20, ..., 80) reused by the inspection cells below.
batch_indices = list(range(10, 90, 10))
batch = torch.stack([dataset[i][0] for i in batch_indices])
show_sample_images(model, dataset, batch, batch_indices, classlabels)


In [None]:

show_sample_segmentations(segmodel=segmodel, dataset=dataset, sample_batch=batch, renorm=renorm)


In [None]:

show_sample_heatmaps(model=model, dataset=dataset, sample_batch=batch)


In [None]:

rq = compute_tally_quantile(model=model, dataset=dataset, upfn=upfn, sample_size=sample_size)


In [None]:

topk = compute_tally_topk(model=model, dataset=dataset, sample_size=sample_size)


In [None]:

show_sample_image_activation(model=model, dataset=dataset, rq=rq, topk=topk, classlabels=classlabels,
                             sample_unit_number=2, sample_image_index=0)


In [None]:

# Save every channel's top-image strip, then preview channels 10, 20, 30 and 40.
unit_images = save_top_channel_images(model=model, dataset=dataset, rq=rq, topk=topk)
sample_unit_numbers = list(range(10, 50, 10))
show_sample_channel_images(unit_images=unit_images, sample_unit_numbers=sample_unit_numbers)


In [None]:

unit_label_high, label_list, level_high, level_low = compute_top_channel_concepts(
    model=model, segmodel=segmodel, upfn=upfn, dataset=dataset, rq=rq,
    seglabels=seglabels, segcatlabels=segcatlabels, sample_size=sample_size, renorm=renorm)


In [None]:

show_sample_channel_images(unit_images=unit_images, sample_unit_numbers=sample_unit_numbers,
                           unit_label_high=unit_label_high)


In [None]:

save_final_data(unit_label_high=unit_label_high, label_list=label_list,
                level_high=level_high, level_low=level_low)


In [None]:
result_file = result_dir + '.zip'
!zip -qq -r $result_file $result_dir
!cp $result_file '$drive_result_path'

packages_path = drive_result_path + '/identification_packages.log'
save_imported_packages(packages_path)