<a href="https://colab.research.google.com/github/wayiwc/Simple-Stock-Challenge-/blob/master/Classifier_IoU_score_update.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Set up the environment

In [None]:
%%bash
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit 
pip install ninja 2>> install.log

Collecting ninja
  Downloading ninja-1.10.2-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (108 kB)
Installing collected packages: ninja
Successfully installed ninja-1.10.2


In [None]:
import torch

# Warn early if the Colab runtime has no GPU; the cells below call .cuda().
if not torch.cuda.is_available():
    print("Change runtime type to include a GPU.")

In [None]:
# Access file from drive 
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# use file directory on google drive
import sys

sys.path.append('/content/drive/MyDrive/Colab Notebooks/xai')

In [None]:
# autoreload imports
%load_ext autoreload
%autoreload 2

In [None]:
import torch, os, matplotlib.pyplot as plt
from netdissect import nethook, imgviz, show, segmenter, renormalize, upsample, tally, pbar, setting

torch.backends.cudnn.benchmark = True
torch.set_grad_enabled(False) # not training anything!

<torch.autograd.grad_mode.set_grad_enabled at 0x7fc9e6ac68d0>

### Load images dataset, classification model, segmentation model

In [None]:
# load the visual data
ds = setting.load_dataset('places', 'val')
iv = imgviz.ImageVisualizer(224, source=ds, percent_level=0.99)

Downloading http://gandissect.csail.mit.edu/datasets/places_val.zip to datasets/places_val.zip


  0%|          | 0/499777515 [00:00<?, ?it/s]

Extracting datasets/places_val.zip to datasets


In [None]:
# load classifier model
model = setting.load_vgg16()
model = nethook.InstrumentedModel(model)
model.cuda()
renorm = renormalize.renormalizer(source=ds, target='zc')
ivsmall = imgviz.ImageVisualizer((56, 56), source=ds, percent_level=0.99)

Downloading: "http://gandissect.csail.mit.edu/models/vgg16_places365-6e38b568.pth" to /root/.cache/torch/hub/checkpoints/vgg16_places365-6e38b568.pth


  0%|          | 0.00/518M [00:00<?, ?B/s]

In [None]:
# load segmentation model
segmodel, seglabels, segcatlabels = setting.load_segmenter('netpqc')


Downloading https://dissect.csail.mit.edu/models/segmodel/upp-resnet50-upernet/decoder_epoch_40.pth
Downloading https://dissect.csail.mit.edu/models/segmodel/upp-resnet50-upernet/encoder_epoch_40.pth
Downloading https://dissect.csail.mit.edu/models/segmodel/upp-resnet50-upernet/labels.json
Downloading https://dissect.csail.mit.edu/models/segmodel/color-resnet18dilated-ppm_deepsup/decoder_epoch_20.pth
Downloading https://dissect.csail.mit.edu/models/segmodel/color-resnet18dilated-ppm_deepsup/encoder_epoch_20.pth
Downloading https://dissect.csail.mit.edu/models/segmodel/color-resnet18dilated-ppm_deepsup/labels.json
Loading weights for net_encoder
Loading weights for net_decoder


### Compute top unit activations 

We look at a late convolutional layer, because higher, more abstract layers were found to correspond more strongly to the concepts of the segmentation model.

Below, we examine the activations of each unit (i.e., each channel of the chosen convolutional layer) and determine the 99th-percentile activation level of every unit over a sample of 1000 images from the dataset.
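
The thresholding idea can be sketched on toy data. The snippet below is a minimal pure-Python illustration, not the `tally` implementation (which keeps a running quantile estimate over batches); the helper names `per_unit_threshold` and `indicator` are hypothetical, and the nearest-rank quantile rule is an assumption made for simplicity.

```python
def per_unit_threshold(samples, percent=99):
    """samples: list of rows (one per spatial location), one column per unit.
    Returns the nearest-rank `percent`-th percentile of every unit (column)."""
    n_units = len(samples[0])
    thresholds = []
    for u in range(n_units):
        column = sorted(row[u] for row in samples)
        # nearest-rank index, computed with integer arithmetic (ceil division)
        rank = -(-percent * len(column) // 100) - 1
        thresholds.append(column[max(rank, 0)])
    return thresholds

def indicator(row, thresholds):
    """1.0 where a unit's activation exceeds its own threshold, else 0.0."""
    return [1.0 if a > t else 0.0 for a, t in zip(row, thresholds)]

# 200 toy locations, 2 units: unit 0 ramps 0..199, unit 1 ramps 0, 0.5, ..., 99.5
samples = [[float(i), i / 2.0] for i in range(200)]
thresholds = per_unit_threshold(samples)
n_active_unit0 = sum(indicator(row, thresholds)[0] for row in samples)
```

With a 99th-percentile threshold, only about 1% of locations per unit end up marked as active, which is what makes the indicator selective.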

In [None]:
layername = 'features.conv5_3'  
model.retain_layer(layername)

upfn = upsample.upsampler(
    target_shape=(56, 56),
    data_shape=(7, 7),
)

def flatten_activations(batch, *args):
    image_batch = batch.cuda()
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    hacts = upfn(acts)
    return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

rq = tally.tally_quantile(
    flatten_activations,
    dataset=ds,
    sample_size=1000,
    batch_size=100)

level_at_99 = rq.quantiles(0.99).cuda()[None,:,None,None]



#### **Original approach** to computing segmentation model concepts presence for high unit activations

In [None]:
def compute_selected_segments(batch, *args):
    image_batch = batch.cuda()
    seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    hacts = upfn(acts)
    iacts = (hacts > level_at_99).float()  # 1 where the activation exceeds the unit's 99th-percentile level, else 0
    return tally.conditional_samples(iacts, seg)

In [None]:
condi99 = tally.tally_conditional_mean(
    compute_selected_segments,
    dataset=ds,
    sample_size=1000)



#### **Updated approach** to computing segmentation model concepts presence for high unit activations

In [None]:
def compute_selected_segments_semantic(batch, *args):
    image_batch = batch.cuda()
    seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    hacts = upfn(acts)
    iacts = (hacts > level_at_99).float()  # 1 where the activation exceeds the unit's 99th-percentile level, else 0
    return tally.conditional_samples_semantic(iacts, seg, trust_rate=0.00001, places=True)

In [None]:
condi99_semantic = tally.tally_conditional_mean(
    compute_selected_segments_semantic,
    dataset=ds,
    sample_size=1000)



**The execution flow is structured as follows:**

       tally_conditional_mean
       ---> compute_selected_segments
           ---> tally.conditional_samples

`compute_selected_segments` keeps only the high (significant) unit activations and passes them to `tally.conditional_samples`, which identifies every condition present in each image. Using both the matched conditions and the activation indicators, `tally.tally_conditional_mean` keeps a running mean of the indicators that overlap each condition.

For example, consider the output of `compute_selected_segments` for one image (referred to as `batch` in the code). It is a sequence of `(condition, (sample, unit)-tensor)` tuples, where

`cond` is the condition (concept id) from the segmentation model, and

`sample` is the `(sample, unit)` tensor of unit indicators at every location where that condition was identified.

```
cond  sample 

674   torch.Size([433, 512])
676   torch.Size([173, 512])
681   torch.Size([2, 512])
684   torch.Size([17, 512])

```

In this particular image, conditions 674, 676, 681 and 684 were identified as present.
Specifically, condition 674 covers 433 samples, and the tensor of size
`torch.Size([433, 512])` records the response of all 512 units at these samples.
Each response (indicator) is either `0` or `1`. In general, not every unit responds at a given sample; more often, very few or none of the units do.

With each new image (processed individually in `tally_conditional_mean`), the algorithm updates the running mean indicator of every condition present: it creates a new record for a condition not met before, or updates the existing one.
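
The running update described above can be sketched as follows. This is a simplified pure-Python stand-in for what `runningstats.RunningConditionalVariance` is used for here (it tracks only means, not variances); the function names `update_running_means` and `mean_for` are hypothetical.

```python
def update_running_means(state, cond, rows):
    """state maps condition -> (sample count, per-unit indicator sums).
    rows is a list of 0/1 indicator rows, one row per sample."""
    count, sums = state.get(cond, (0, [0.0] * len(rows[0])))
    for row in rows:
        sums = [s + v for s, v in zip(sums, row)]
    state[cond] = (count + len(rows), sums)

def mean_for(state, cond):
    """Mean indicator of every unit over all samples seen for `cond`."""
    count, sums = state[cond]
    return [s / count for s in sums]

state = {}
# image 1: condition 674 covers three samples (two units per row)
update_running_means(state, 674, [[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]])
# image 2: condition 674 appears again, condition 676 is met for the first time
update_running_means(state, 674, [[0.0, 1.0]])
update_running_means(state, 676, [[0.0, 0.0], [0.0, 1.0]])
```

After both images, condition 674 has four samples in total, so its mean indicators are averages over all four rows rather than per-image averages.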

### Compute iou score

In general, the `iou` score measures the similarity between the set of locations where a classification-model unit is highly activated and the set of locations where the segmentation model identifies a given condition.

In the original approach, the presence of a semantic concept (from the segmentation model) is determined solely by the conditions that the segmentation model was trained to identify.
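
As a minimal illustration of intersection over union on binary masks (a toy sketch, not the `tally` code path; the `iou` helper below is hypothetical and works on flat lists of 0/1 values):

```python
def iou(unit_mask, concept_mask):
    """Intersection over union of two equally long 0/1 masks."""
    inter = sum(1 for a, s in zip(unit_mask, concept_mask) if a and s)
    union = sum(1 for a, s in zip(unit_mask, concept_mask) if a or s)
    return inter / union if union else 0.0

unit_mask    = [1, 1, 0, 0, 1, 0]  # locations where the unit is highly active
concept_mask = [1, 0, 0, 1, 1, 0]  # locations labelled with the concept
score = iou(unit_mask, concept_mask)
```

Here the masks overlap at two locations and cover four locations together, so the score is 2 / 4.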

In [None]:
iou99 = tally.iou_from_conditional_indicator_mean(condi99)
iou99_semantic = tally.iou_from_conditional_indicator_mean(condi99_semantic)

In [None]:
unit_list = sorted(enumerate(zip(*iou99.max(1))), key=lambda k: -k[1][0])
unit_list_semantic = sorted(enumerate(zip(*iou99_semantic.max(1))), key=lambda k: -k[1][0])

In [None]:
iou_unit_label_99 = sorted([(
    unit, concept.item(), seglabels[concept], bestiou.item())
    for unit, (bestiou, concept) in enumerate(zip(*iou99.max(0)))],
    key=lambda x: -x[-1])

In [None]:
iou_unit_label_99_semantic = sorted([(
    unit, concept.item(), seglabels[concept], bestiou.item())
    for unit, (bestiou, concept) in enumerate(zip(*iou99_semantic.max(0)))],
    key=lambda x: -x[-1])

In [None]:
sample_related_concepts = ['person',
                          'person-t',
                          'hair',
                          'skin',
                          'window',
                          'window-b',
                          'window-t',
                          'glass',
                          'floor',
                          'floor-t',
                          'building-b',
                          'building-l',
                          'building-t',
                          'door',
                          'table',
                          'table-t',
                          'chair']

In [None]:
[(unit, concpt, seg_lab, score) for (unit, concpt, seg_lab, score) in iou_unit_label_99]

[(338, 25, 'mountain', 0.19951239228248596),
 (404, 348, 'car-t', 0.19666534662246704),
 (299, 1018, 'car-b', 0.1727447509765625),
 (265, 1695, 'skin', 0.15590347349643707),
 (143, 72, 'sea', 0.15147200226783752),
 (436, 1685, 'hair', 0.1439235955476761),
 (184, 10, 'road', 0.1388206034898758),
 (40, 1016, 'grass-b', 0.13860926032066345),
 (107, 1822, 'red', 0.13473007082939148),
 (311, 32, 'shelf', 0.13351619243621826),
 (191, 346, 'grass-t', 0.13177068531513214),
 (89, 1012, 'ceiling-b', 0.1312350481748581),
 (217, 9, 'window', 0.1266450583934784),
 (148, 370, 'water-t', 0.12060254067182541),
 (166, 35, 'water', 0.1203669011592865),
 (427, 240, 'washer', 0.11652055382728577),
 (87, 240, 'washer', 0.11437157541513443),
 (74, 240, 'washer', 0.11358436942100525),
 (84, 14, 'plant', 0.11278567463159561),
 (16, 37, 'flower', 0.11137192696332932),
 (149, 25, 'mountain', 0.11075198650360107),
 (71, 37, 'flower', 0.11018384248018265),
 (206, 348, 'car-t', 0.10992279648780823),
 (185, 291, 's

In [None]:
[(unit, concpt, seg_lab, score) for (unit, concpt, seg_lab, score) in iou_unit_label_99_semantic]

[(338, 25, 'mountain', 0.19951239228248596),
 (404, 348, 'car-t', 0.19666534662246704),
 (299, 1018, 'car-b', 0.1727447509765625),
 (265, 1695, 'skin', 0.1611902117729187),
 (143, 72, 'sea', 0.15147200226783752),
 (436, 1685, 'hair', 0.14759010076522827),
 (184, 10, 'road', 0.1388206034898758),
 (40, 1016, 'grass-b', 0.13860926032066345),
 (107, 1822, 'red', 0.13473007082939148),
 (217, 9, 'window', 0.13425813615322113),
 (311, 32, 'shelf', 0.13351619243621826),
 (191, 346, 'grass-t', 0.13177068531513214),
 (89, 1012, 'ceiling-b', 0.1312350481748581),
 (148, 370, 'water-t', 0.12060254067182541),
 (166, 35, 'water', 0.1203669011592865),
 (427, 240, 'washer', 0.11652055382728577),
 (87, 240, 'washer', 0.11437157541513443),
 (74, 240, 'washer', 0.11358436942100525),
 (84, 14, 'plant', 0.11278567463159561),
 (16, 37, 'flower', 0.11137192696332932),
 (149, 25, 'mountain', 0.11075198650360107),
 (71, 37, 'flower', 0.11018384248018265),
 (206, 348, 'car-t', 0.10992279648780823),
 (275, 1695, 

### Inject semantic knowledge into concept recognition

Our objective is to support the presence of a particular condition from the
segmentation model if a related condition has also been identified in the same image. Relation between the conditions can come from various sources as long as it can be quantitatively measured. 

We want to update the `iou` score, which measures how closely high unit activations correspond to a meaningful semantic concept. `iou` is defined as the intersection over the union of two sets: 1) the locations where a feature map of the convolutional classification model is highly activated and 2) the locations where the segmentation model recognises a concept.
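
This definition can be written out explicitly. The notation below is introduced here for illustration and is not taken from the library: $A_u$ is the set of locations where unit $u$ exceeds its 99th-percentile level, and $S_c$ is the set of locations labelled with condition $c$.

$$
\mathrm{IoU}(u, c) = \frac{|A_u \cap S_c|}{|A_u \cup S_c|} = \frac{|A_u \cap S_c|}{|A_u| + |S_c| - |A_u \cap S_c|}
$$

If $N$ is the total number of sampled locations, $N_c = |S_c|$, $\mu_u$ is the overall mean indicator of unit $u$, and $\mu_{u \mid c}$ is its mean indicator over the locations with condition $c$, then $|A_u| = N \mu_u$ and $|A_u \cap S_c| = N_c\, \mu_{u \mid c}$, so

$$
\mathrm{IoU}(u, c) = \frac{N_c\, \mu_{u \mid c}}{N \mu_u + N_c - N_c\, \mu_{u \mid c}}
$$

This is why per-condition running means, together with the unconditioned record for condition 0, carry enough information to recover the score. Whether `tally.iou_from_conditional_indicator_mean` evaluates exactly this expression is an assumption here.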

Hence we propose to extend the matching of the segmentation model concept by arguing that a score for an individual concept can be incremented if a related concept also has been recognised by a segmentation model (and decremented if a highly related semantic concept is absent). 

In order to mimic the relationship between concepts, in this experiment, we use a mock dictionary that connects conditions by assigning their relationship a numeric value between 0 and 1 where stronger/closer relationships have a higher relation score. A simplified example of such a dictionary could be:

```
{
    'building': {'wall': 0.7, 'roof': 0.75},
    'wall': {'building': 0.7, 'roof': 0.2},
    'roof': {'building': 0.75, 'wall': 0.2 }
}
```

Note that for simplicity we assume that relationship scores are bidirectional, e.g. the relation score between `building` and `wall` is the same as the relation score between `wall` and `building`.
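
One way to guarantee the bidirectional assumption is to build the dictionary from a list of pairs, so that each score is written in both directions. The sketch below is illustrative only; `build_relations` is a hypothetical helper and not part of the notebook's `tally` module.

```python
def build_relations(pairs):
    """pairs: iterable of (concept_a, concept_b, score) with 0 <= score <= 1.
    Returns a nested dict in which every score is stored symmetrically."""
    relations = {}
    for a, b, score in pairs:
        if not 0.0 <= score <= 1.0:
            raise ValueError(f"relation score out of range: {score}")
        relations.setdefault(a, {})[b] = score
        relations.setdefault(b, {})[a] = score
    return relations

relations = build_relations([
    ('building', 'wall', 0.7),
    ('building', 'roof', 0.75),
    ('wall', 'roof', 0.2),
])
```

The result has the same content as the hand-written example dictionary above, but a score can no longer differ between the two directions.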



#### Examine the code implementation from Bau et al. in more detail

To understand where and how the relationship dictionary can influence the `iou` score computation, we look in more detail at the implementation of each of the following:

       tally_conditional_mean
       ---> compute_selected_segments
           ---> tally.conditional_samples


In [None]:
def tally_conditional_mean(compute, dataset,
        sample_size=None, batch_size=1, cachefile=None, **kwargs):
    '''
    Computes conditional mean and variance for a large data sample that
    can be computed from a dataset.  The compute function should return a
    sequence of sample batch tuples (condition, (sample, unit)-tensor),
    one for each condition relevant to the batch.
    '''
    with torch.no_grad():
        args = dict(sample_size=sample_size)
        loader = make_loader(dataset, sample_size, batch_size, **kwargs)
        cv = runningstats.RunningConditionalVariance()
        for i, batch in enumerate(pbar(loader)):
            sample_set = call_compute(compute, batch)
            for cond, sample in sample_set:
                # Move uncommon conditional data to the cpu before collating.
                cv.add(cond, sample)
        # At the end, move all to the CPU
        cv.to_('cpu')
        save_cached_state(cachefile, cv, args)
        return cv

Broadly, `tally_conditional_mean` collects conditional running statistics (unit activations and segmentation concept indications). 

In [None]:
def compute_selected_segments(batch, *args):
    """
    The compute function should return a
    sequence of sample batch tuples (condition, (sample, unit)-tensor),
    one for each condition relevant to the batch.
    
    Here `batch` is an image, and the conditions are the
    segmentation-model concepts identified in that image.
    """
    image_batch = batch.cuda()
    seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    hacts = upfn(acts)
    iacts = (hacts > level_at_99).float()  # 1 where the activation exceeds the unit's 99th-percentile level, else 0
    return tally.conditional_samples(iacts, seg)

In [None]:
def conditional_samples(activations, segments):
    '''
    Helper function when defining generators for *_conditional tallies.
    Transforms a batch of activations and segmentations into a
    sequence of conditional statistics, i.e., activations that
    are at the same location as the segmentation label.
    Both activations and segments should be 4d tensors with
    the same sample, y, and x dimensions.  Segments can be
    a multilabel segmentation.  The zero segmentation value is
    assumed to be unused.

    Returns a generator for a sequence of (condition, (sample, unit)-tensor)
    listing every condition present in the segments, along with the
    set of activations overlapping that condition.  The activation tensor
    is 2d in (sample, unit) order, where sample is the number of samples
    for the condition.
    '''
    channels = activations.shape[1]
    activations_by_channel = activations.permute(0, 2, 3, 1).contiguous()
    segcounts = segments.view(-1).bincount()
    conditions = (segcounts[1:].nonzero() + 1)[:, 0]
    def sample_generator():
        # First yield the full set of activations, unconditioned
        yield (0, activations_by_channel.view(-1, channels))  # yield==return, but function returns a generator
        # Then a set of activations for each condition present in the image
        for condition in conditions:
            mask = (segments == condition).max(1)[0][...,None]
            mask = mask.expand(activations_by_channel.shape)
            yield (condition.item(),
                    activations_by_channel[mask].view(-1, channels))
    return sample_generator()


`compute_selected_segments` thresholds the activations of the chosen network layer and calls `tally.conditional_samples`, which transforms a batch of activations and segmentations into a sequence of conditional statistics, i.e., the activations that
are at the same location as each segmentation label.
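
The grouping step can be sketched without tensors. The toy function below assumes one label per location (the real function also handles multilabel segmentations) and uses a hypothetical name, `group_by_condition`; it only mirrors the shape of the output, with condition 0 holding the full, unconditioned set.

```python
def group_by_condition(indicator_rows, labels):
    """indicator_rows: one 0/1 row per location (one column per unit).
    labels: one segmentation label per location; label 0 means 'unused'."""
    groups = {0: list(indicator_rows)}  # condition 0: all rows, unconditioned
    for row, label in zip(indicator_rows, labels):
        if label != 0:
            groups.setdefault(label, []).append(row)
    return groups

rows = [[1.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
labels = [674, 674, 0, 676]
groups = group_by_condition(rows, labels)
```

Each entry of `groups` plays the role of one `(condition, (sample, unit)-tensor)` tuple: the key is the condition and the list of rows is the set of samples overlapping it.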

In [None]:
# This update to the original conditional_samples implementation is discussed in the context of GAN dissection

def conditional_samples_semantic(activations, segments, trust_rate=0.0001, places=False):
    '''
    Helper function when defining generators for *_conditional tallies.
    Transforms a batch of activations and segmentations into a
    sequence of conditional statistics, i.e., activations that
    are at the same location as the segmentation label.
    Both activations and segments should be 4d tensors with
    the same sample, y, and x dimensions.  Segments can be
    a multilabel segmentation.  The zero segmentation value is
    assumed to be unused.

    Returns a generator for a sequence of (condition, (sample, unit)-tensor)
    listing every condition present in the segments, along with the
    set of activations overlapping that condition.  The activation tensor
    is 2d in (sample, unit) order, where sample is the number of samples
    for the condition.
    '''
    channels = activations.shape[1]
    activations_by_channel = activations.permute(0, 2, 3, 1).contiguous()
    segcounts = segments.view(-1).bincount()
    conditions = (segcounts[1:].nonzero() + 1)[:, 0]
    conditions_list = [c.item() for c in conditions]
    if places:
        related_cond = PLACES_RELATED_CONCEPTS_DICT
    else:
        related_cond = CHURCH_RELATED_CONCEPTS_DICT
    def sample_generator():
        # First yield the full set of activations, unconditioned
        yield (0, activations_by_channel.view(-1, channels))
        # Then a set of activations for each condition present in the image
        for condition in conditions:
            # 1. find all related conditions that are present in this batch
            # (a separate name is used so that the outer `related_cond` dict is
            # neither shadowed inside the generator nor overwritten)
            related_here = related_cond.get(condition.item(), {})
            related_present_cond = {
                c: r for c, r in related_here.items()
                if c in conditions_list
            }
            
            # 2. Compute update based on present related conditions and their indications
            update_semantic = 0
            for cond, rel in related_present_cond.items():
                # count the pixels of the related condition `cond`, not of `condition` itself
                other_mask = (segments == cond).max(1)[0][...,None]
                update_semantic += other_mask.count_nonzero()*rel
            
            mask = (segments == condition).max(1)[0][...,None]
            mask = mask.expand(activations_by_channel.shape)
            
            # 3. Scale the degree of update
            act_all = activations_by_channel[mask].view(-1, channels)
            if update_semantic > 0:
                update_semantic = update_semantic*trust_rate
                act_all = act_all.add(update_semantic)
            yield (condition.item(), act_all)
    return sample_generator()
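
To get a feel for the size of the semantic update, the arithmetic of steps 2 and 3 can be reproduced on plain numbers. This sketch assumes the intended update is the sum, over related conditions present in the image, of pixel count times relation score, scaled by `trust_rate`; the helper `semantic_update` and the pixel counts are made up for illustration.

```python
def semantic_update(related_present, pixel_counts, trust_rate):
    """related_present: {condition id: relation score} for related conditions
    found in the image. pixel_counts: {condition id: number of labelled pixels}."""
    total = sum(pixel_counts[c] * rel for c, rel in related_present.items())
    return total * trust_rate

# hypothetical image: two related conditions present, covering 100 and 200 pixels
update = semantic_update(
    {1685: 0.8, 1695: 0.7},
    {1685: 100, 1695: 200},
    trust_rate=0.00001,
)
```

With these numbers the weighted pixel count is 220, so the value added to every indicator of the condition is about 0.0022. A small `trust_rate` keeps the adjustment well below the 0/1 scale of the indicators themselves.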

In [None]:
PLACES_RELATED_CONCEPTS_DICT = {
    6: {
        1685: 0.8,
        1695: 0.7,
    },
    341: {
        1685: 0.8,
        1695: 0.7,
    },
    1685: {
        6: 0.8,
        341: 0.8,
        1695: 0.4,
    },
    1695:{
        6: 0.7,
        341: 0.7,
        1685: 0.4,
    },
    1010: {
        9: 0.9,
        1014: 0.9,
        344: 0.9,
        89: 0.4,
        3: 0.5,
        338: 0.5,
        16: 0.6,
    },
    675: {
        9: 0.9,
        1014: 0.9,
        344: 0.9,
        89: 0.4,
        3: 0.5,
        338: 0.5,
        16: 0.6,
    },
    340: {
        9: 0.9,
        1014: 0.9,
        344: 0.9,
        89: 0.4,
        3: 0.5,
        338: 0.5,
        16: 0.6,
    },
    9: {
        1010: 0.9,
        675: 0.9,
        340: 0.9,
        89: 0.9,
        3: 0.7,
        338: 0.7,
        16: 0.6,
    },
    1014: {
        1010: 0.9,
        675: 0.9,
        340: 0.9,
        89: 0.9,
        3: 0.7,
        338: 0.7,
        16: 0.6,
    },
    344: {
        1010: 0.9,
        675: 0.9,
        340: 0.9,
        89: 0.9,
        3: 0.7,
        338: 0.7,
        16: 0.6,
    },
    89: {
        1010: 0.4,
        675: 0.4,
        340: 0.4,
        9: 0.9,
        1014: 0.9,
        344: 0.9,
        3: 0.4,
        338: 0.4,
    },
    3: {
        1010: 0.5,
        675: 0.5,
        340: 0.5,
        89: 0.4,
        9: 0.7,
        1014: 0.7,
        344: 0.7,
    },
    338: {
        1010: 0.5,
        675: 0.5,
        340: 0.5,
        89: 0.4,
        9: 0.7,
        1014: 0.7,
        344: 0.7,
    },
    16: {
        1010: 0.6,
        675: 0.6,
        340: 0.6,
        9: 0.6,
        1014: 0.6,
        344: 0.6,
    },
    8: {
        12: 0.7,
    },
    343: {
        12: 0.7,
    },
    12: {
        8: 0.7,
        343: 0.7,
    },
}
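
Because a hand-written dictionary of this size is easy to get out of sync, a small consistency check against the bidirectional assumption can help. The function below is a hypothetical helper shown on a toy dictionary; it lists every entry whose reverse entry is missing or carries a different score.

```python
def asymmetric_pairs(relations):
    """Return (a, b, score, reverse_score) for every entry a -> b whose
    reverse entry b -> a is missing (None) or has a different score."""
    problems = []
    for a, neighbours in relations.items():
        for b, score in neighbours.items():
            back = relations.get(b, {}).get(a)
            if back != score:
                problems.append((a, b, score, back))
    return problems

toy = {
    1: {2: 0.7},
    2: {1: 0.8},   # disagrees with 1 -> 2
    3: {4: 0.5},
    4: {3: 0.5},   # consistent pair
    5: {6: 0.9},   # reverse entry 6 -> 5 is missing
}
problems = asymmetric_pairs(toy)
```

Calling `asymmetric_pairs(PLACES_RELATED_CONCEPTS_DICT)` in the same way would list any pairs in the dictionary above that do not satisfy the bidirectional assumption.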