In [1]:
!python --version

Python 3.9.4


In [2]:
# Only run this if you are in google colab

# %%bash
# git clone https://github.com/fengqingthu/CLIP_Steering.git
# git clone https://github.com/openai/CLIP.git

In [3]:
# # Only run this if you are in google colab
# %%bash

# pip install ninja 2>> install.log
# git clone https://github.com/SIDN-IAP/global-model-repr.git tutorial_code 2>> install.log

In [4]:
# Only run this if you are in google colab
# %%bash
# pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
# pip install \
#   pytorch-pretrained-biggan \
#   ftfy \
#   regex \
#   tqdm \
#   git+https://github.com/openai/CLIP.git \
#   click \
#   requests \
#   pyspng \
#   ninja \
#   imageio-ffmpeg==0.4.3 \
#   scipy

In [4]:
# import google.colab
import sys, torch

# Make the `tutorial_code` directory (the clone target of the Colab setup cell above) importable.
sys.path.append('tutorial_code')
# Warn, but do not abort, when no CUDA device is visible.
if not torch.cuda.is_available():
    print("Change runtime type to include a GPU.")  

## Import all dependencies

Make sure the `CLIP_Steering` checkout is on the Python path so that we can import the GANalyze tools.

In [5]:
import logging
import os
import pathlib

import clip
import IPython.display
import numpy as np
import torch.nn.functional as F
import torchvision
import torch.hub
from netdissect import proggan
from PIL import Image

# Put the local CLIP_Steering checkout first on the import path so its GANalyze
# helper modules can be imported by bare name.
sys.path.insert(0, "./CLIP_Steering")
try:
    import ganalyze_common_utils as common
    import ganalyze_transformations as transformations
except ImportError as import_error:
    # Stay best-effort (cells that do not use these modules can still run), but report
    # why the import failed instead of hiding the reason behind a later NameError.
    print(
        "Could not import ganalyze_common_utils or ganalyze_transformations"
        f" ({import_error}). The steering cells need them: check that ./CLIP_Steering"
        " exists (see the clone cell at the top of the notebook)."
    )

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Running pytorch', torch.__version__, 'using', device.type)

# INFO-level logging with a timestamp and a padded level name on every record.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
)

# Logger named after the current module.
logger = logging.getLogger(__name__)

Running pytorch 1.9.1+cu111 using cuda


# Import and test the Progressive GAN (ProGAN) model

The GAN generator is just a function z->x that transforms random z to realistic images x.

To generate images, all we need is a source of random z.  Let's make a micro dataset with a few random z.

In [6]:
import torchvision
import torch.hub
from netdissect import nethook, proggan

# File name of the pretrained Progressive GAN checkpoint to download. Uncomment one of
# the alternatives below to use a different scene category.
n = 'proggan_bedroom-d8a89ff1.pth'
# n = 'proggan_churchoutdoor-7e701dd5.pth'
# n = 'proggan_conferenceroom-21e85882.pth'
# n = 'proggan_diningroom-3aa0ab80.pth'
# n = 'proggan_kitchen-67f1e16c.pth'
# n = 'proggan_livingroom-5ef336dd.pth'
# n = 'proggan_restaurant-b8578299.pth'

url = 'http://gandissect.csail.mit.edu/models/' + n
try:
    sd = torch.hub.load_state_dict_from_url(url) # pytorch 1.1
except AttributeError:
    # Only a missing API (older pytorch) should trigger the fallback. A bare `except`
    # would also swallow KeyboardInterrupt and genuine download errors.
    sd = torch.hub.model_zoo.load_url(url) # pytorch 1.0
# Build the generator from the downloaded state dict and move it to the chosen device.
proggan_model = proggan.from_state_dict(sd).to(device)
proggan_model

Downloading: "http://gandissect.csail.mit.edu/models/proggan_bedroom-d8a89ff1.pth" to /home/ubuntu/.cache/torch/hub/checkpoints/proggan_bedroom-d8a89ff1.pth


  0%|          | 0.00/70.0M [00:00<?, ?B/s]

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
from netdissect import zdataset,renormalize

# Number of latent vectors in the small demo dataset.
SAMPLE_SIZE = 50 # Increase this for better results (but slower to run)
# Build the latent dataset for the generator; the seed is fixed, presumably so the
# same latents are drawn on every run.
zds = zdataset.z_dataset_for_model(proggan_model, size=SAMPLE_SIZE, seed=5555)
# Display the dataset length and the shape of the first latent vector.
len(zds), zds[0][0].shape

# Import and test CLIP model

Check that the CLIP model works. We import CLIP after installing it with pip, and we cloned CLIP's repository only to get the `CLIP/CLIP.png` test image.

In [None]:
# Load CLIP together with its matching image preprocessing pipeline.
clip_model, preprocess = clip.load("ViT-B/32", device=device)
clip_model.eval()
clip_model.to(device)

# Open the test image in a context manager so the file handle is released as soon as
# the preprocessed tensor has been built.
with Image.open("CLIP/CLIP.png") as test_image:
    image = preprocess(test_image).unsqueeze(0).to(device)
text = clip.tokenize(["a livingroom", "a bedroom", "a church"]).to(device)

# Inference only: no gradients are needed for this sanity check.
with torch.no_grad():
    image_features = clip_model.encode_image(image)
    text_features = clip_model.encode_text(text)

    logits_per_image, logits_per_text = clip_model(image, text)
    # Softmax over the three prompts gives one probability per label.
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs) 

In [None]:
torch.cuda.empty_cache()

# Length of the first axis of one latent sample (the sample is indexed as [:, 0, 0]).
latent_space_dim = zds[0][0][:, 0, 0].size(0)
context_length = clip_model.context_length
vocab_size = clip_model.vocab_size

# Total number of scalar parameters in CLIP, printed with thousands separators.
print(
    "Model parameters:",
    format(sum(p.numel() for p in clip_model.parameters()), ","),
)
print("Context length:", context_length)
print("Vocab size:", vocab_size)
print("Latent space dimension:", latent_space_dim)

In [None]:
#@title Helper functions
from typing import List, Optional, Tuple
from PIL import Image


def show_images(
        images: List[Image.Image],
        resize: Optional[Tuple[int, int]] = None
    ):
    """Display a list of images side by side as a single strip.

    Args:
        images: Images to show; they are concatenated along the width axis, so
            they need matching heights and channel counts.
        resize: Optional ``(max_width, max_height)`` bound. When given, the strip
            is shrunk in place with ``Image.thumbnail`` before display.

    Nothing is displayed when ``images`` is empty (``np.concatenate`` would
    otherwise raise on an empty list).
    """
    if len(images) == 0:
        return

    # Convert each image to an array and glue them together left to right (axis=1).
    strip = np.concatenate([np.array(img) for img in images], axis=1)
    strip_image = Image.fromarray(strip)

    if resize:
        strip_image.thumbnail(resize)

    IPython.display.display(strip_image)


def show_and_save_images(
    images: List[Image.Image], batch: int, path: str, variant: str = "original"
):
    """Display ``images`` in a row and save each one as a PNG under ``path``.

    Args:
        images: Images to display and save.
        batch: Batch number embedded in every file name.
        path: Output directory; created (with parents) when missing.
        variant: Suffix embedded in every file name, e.g. ``"original"`` or
            ``"transformed"``.

    Files are named ``image_{batch}_{i}_{variant}.png`` where ``i`` is the
    position of the image in ``images``.
    """
    show_images(images)

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(path, exist_ok=True)

    for i, img in enumerate(images):
        img.save(os.path.join(path, f"image_{batch}_{i}_{variant}.png"))


def show_gan_results(gan_results: List[List[Tuple[List[Image.Image], torch.Tensor]]]):
    """Print the scores and show the image strip for every sample in every batch.

    This is the single definition of the function: the cell used to define it
    twice with identical bodies, the second silently shadowing the first.

    Args:
        gan_results: One entry per batch. Each entry is a list of
            ``(images, scores)`` pairs, one pair per steering step, where
            ``images[i]`` and ``scores[i]`` belong to the i-th sample of the
            batch. ``scores[i]`` must support ``.detach().cpu().numpy()``.
    """
    for batch_results in gan_results:
        # Number of samples in the batch, taken from the first step's image list.
        batch_size = len(batch_results[0][0])

        for i in range(batch_size):
            # Gather sample i across all steering steps.
            steering_images = [res[0][i] for res in batch_results]
            steering_scores = np.stack(
                [res[1][i].detach().cpu().numpy() for res in batch_results]
            ).tolist()
            print(steering_scores)
            show_images(steering_images, resize=(1024, 256))


def get_clip_probs(image_inputs, text_features, model, attribute_index=0):
    """Return, for each image, the softmax probability of one text prompt.

    Args:
        image_inputs: Iterable of images supporting ``.resize((w, h))`` that the
            notebook-level ``preprocess`` callable accepts.
        text_features: Text embeddings with one row per prompt; normalized here.
        model: Model exposing ``encode_image`` and ``logit_scale``.
        attribute_index: Column (prompt) whose probability is returned.

    Returns:
        A 1-D tensor with one probability per input image.

    Relies on the notebook globals ``preprocess`` and ``device``.
    """
    # Resize every image to 512x512, run the preprocessing, and batch the results.
    image_inputs = torch.stack([preprocess(img.resize((512, 512))) for img in image_inputs]).to(device)
    image_features = model.encode_image(image_inputs).float()

    # normalized features
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    # cosine similarity as logits
    logit_scale = model.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.t()

    # Softmax across prompts, then keep only the requested prompt's column.
    clip_probs = logits_per_image.softmax(dim=-1)

    return clip_probs.narrow(dim=-1, start=attribute_index, length=1).squeeze(dim=-1)


def make_images_and_probs(
    model, zdataset, clip_model, encoded_text, attribute_index=0
):
    """Generate one image per latent vector and score the images with CLIP.

    Args:
        model: Generator, called on one latent at a time with a leading batch axis.
        zdataset: Iterable of latent vectors.
        clip_model: Model handed to ``get_clip_probs`` for image encoding.
        encoded_text: Text features handed to ``get_clip_probs``.
        attribute_index: Index of the prompt whose probability is returned.

    Returns:
        ``(gan_images, clip_probs)``: the generated images and the per-image
        probability of the selected prompt.

    NOTE(review): each generator output goes through ``renormalize.as_image`` and
    is then resized and preprocessed as an image inside ``get_clip_probs``. If
    those are non-tensor image objects, the returned probabilities carry no
    gradient back to the latents. Confirm before relying on them for training.
    """
    # One forward pass per latent: add a batch axis going in, drop it coming out.
    gan_images = [
        renormalize.as_image(model(z[None, ...])[0]) for z in zdataset
    ]

    clip_probs = get_clip_probs(gan_images, encoded_text, clip_model, attribute_index)
    return gan_images, clip_probs

## Use CLIP to extract target text attributes

Use the CLIP model to extract target text attributes for steering the output of a GAN. We use the CLIP tokenizer and encoder to extract text features and normalize them.

The resulting text features are used later to steer the GAN output towards the desired attribute.

In [7]:
# Encode the target prompts with CLIP.

# Prompts describing the candidate attributes; the softmax in `get_clip_probs`
# is taken across these.
attributes = ["a luxurious bedroom", "a bedroom"]
# Position in `attributes` of the prompt whose probability we want to maximize.
attribute_index = 0
text_descriptions = [str(label) for label in attributes]


# Tokenize and encode without tracking gradients; the features are cast to float32.
with torch.no_grad():
    text_tokens = clip.tokenize(text_descriptions).to(device)
    text_features = clip_model.encode_text(text_tokens).float()
    # text_features = F.normalize(text_features, p=2, dim=-1)

text_features.shape

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

# Declare the GAN steering model

This is the model in charge of changing the vectors `z` so that it aligns with the objective declared in `text_features`.

In [None]:
# Steering model built from the latent dimensionality and CLIP's vocabulary size,
# then moved to the active device.
transformation = transformations.OneDirection(latent_space_dim, vocab_size)
transformation = transformation.to(device)

In [None]:
# Directory for checkpoints (created if missing) and the file name of the final
# checkpoint, which embeds the prompt being maximized.
checkpoint_dir = "checkpoints/results_maximize_classifier_probability"
pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
checkpoint_name = (
    "pytorch_model_progran_steering_" + attributes[attribute_index] + "_final.pth"
)

# Training steps

Now we're ready to train the GAN steering model called `transformation`. The overall algorithm looks like this:

1. Generate the noise vectors `z` for a given number of training samples.
2. For each batch, generate the GAN images and compute their CLIP scores, i.e. how strongly the image features match the target text features.
3. Use the `transformation` model to adjust the original noise. Repeat step 2 for the transformed noise `z_transformed`.
4. Compare the two sets of scores and optimize the model to minimize the difference between the target scores and the scores obtained after transforming the original noise.



In [None]:
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader

optimizer = torch.optim.Adam(
    transformation.parameters(), lr=0.0002
)  # as specified in GANalyze
losses = common.AverageMeter(name="Loss")

#  training settings
optim_iter = 0
batch_size = 64  # Do not change
train_alpha_a = -0.5  # Lower limit for step sizes
train_alpha_b = 0.5  # Upper limit for step sizes
#
# Number of samples to train for # Ganalyze uses 400,000 samples.
# Use smaller number for testing.
#
num_samples = 90_000

attribute_index = 0
checkpoint_dir = f"checkpoints/results_maximize_{attributes[attribute_index]}_probability"
pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

# NOTE: this rebinds the notebook-level `zds` to the (large) training dataset.
zds = zdataset.z_dataset_for_model(proggan_model, size=num_samples, seed=5555)
zds_dataloader = DataLoader(zds, batch_size=batch_size, shuffle=True)

# `desc=` labels the bar; passing the label positionally would make it tqdm's iterable.
progress = tqdm(desc="Training", total=len(zds_dataloader))

# NOTE(review): `make_images_and_probs` routes the generator output through image
# objects (the helper cell resizes them and converts them with np.array), so it is
# unclear that `loss` is connected to `transformation`'s parameters. Confirm that
# gradients actually reach them (e.g. inspect their `.grad` after `backward()`).

# loop over data batches
for batch_start, z_batch in enumerate(zds_dataloader):
    z_batch = z_batch[0].to(device)
    # The last batch can be smaller than `batch_size`, so size everything from the
    # batch itself. Flatten per sample rather than squeeze(), which would also drop
    # the batch axis of a single-sample batch.
    current_batch_size = z_batch.shape[0]
    z_batch = z_batch.reshape(current_batch_size, -1)

    step_sizes = (train_alpha_b - train_alpha_a) * np.random.random(
        size=(current_batch_size)
    ) + train_alpha_a  # sample step_sizes

    # Repeat each sample's step size across the latent axis: (batch, latent_space_dim).
    step_sizes_broadcast = np.repeat(step_sizes, latent_space_dim).reshape(
        [current_batch_size, latent_space_dim]
    )
    step_sizes_broadcast = (
        torch.from_numpy(step_sizes_broadcast).type(torch.FloatTensor).to(device)
    )

    #
    # Generate the original images and get their clip scores
    #
    # The reference scores only define the regression target, so no graph is kept.
    with torch.no_grad():
        gan_images, out_scores = make_images_and_probs(
            model=proggan_model,
            zdataset=z_batch,
            clip_model=clip_model,
            encoded_text=text_features,
            attribute_index=attribute_index,
        )

        # TODO: ignore z vectors with less confident clip scores
        target_scores = torch.clip(
            out_scores + torch.from_numpy(step_sizes).to(device).float(),
            0.0,
            1.0
        )

    #
    # Transform the z vector and get the clip scores for the transformed images
    #
    zb_transformed = transformation.transform(z_batch, None, step_sizes=step_sizes_broadcast)
    gan_images_transformed, out_scores_transformed = make_images_and_probs(
        model=proggan_model,
        zdataset=zb_transformed,
        clip_model=clip_model,
        encoded_text=text_features,
        attribute_index=attribute_index,
    )

    #
    # Compute loss and backpropagate
    #
    loss = transformation.criterion(out_scores_transformed, target_scores)

    # Clear gradients left over from the previous step; without this they accumulate.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    #
    # Print and save intermediate results
    #
    losses.update(loss.item(), current_batch_size)
    if optim_iter % 50 == 0:
        # Report progress in samples so it is comparable with `num_samples`.
        samples_seen = batch_start * batch_size
        print(
            f"[Maximizing score for {attributes[attribute_index]}] "
            f"Progress: [{samples_seen}/{num_samples}] {losses}"
        )

        print(
            f"[Scores] "
            f"Target: {target_scores} Out: {out_scores_transformed}"
        )

    if optim_iter % 200 == 0:
        batch_checkpoint_name = f"pytorch_model_progran_{batch_start}.pth"
        torch.save(
            transformation.state_dict(),
            os.path.join(checkpoint_dir, batch_checkpoint_name)
        )

        # plot and save sample images
        # show_and_save_images(gan_images, batch_start, checkpoint_dir)
        # show_and_save_images(
        #     gan_images_transformed, batch_start, checkpoint_dir, "transformed"
        # )

    optim_iter = optim_iter + 1
    progress.update(1)

progress.close()


In [None]:
# Persist the final steering weights as <checkpoint_dir>/<checkpoint_name>.
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
torch.save(transformation.state_dict(), checkpoint_path)

In [None]:
# Only run this if you are in google colab
# !rm -rf Weights
# !mkdir -p Weights

# # mount files from google drive
# # and follow the steps here
# # from google.colab import drive
# # drive.mount('/content/gdrive')

# import shutil

# shutil.copy(checkpoint_path, f"/content/gdrive/MyDrive/Colab Notebooks/Sabrina <> Leo/Weights/{checkpoint_name}")
# shutil.copy(checkpoint_path, f"Weights/{checkpoint_name}")

# Testing steps

Now that the model is trained, we can test it and see how the output changes when incrementing and decrementing the step_sizes on our `z` vectors.

We take the final checkpoint saved after training, located at `{checkpoint_dir}/{checkpoint_name}`.

In [None]:
# Rebuild the steering model with the same constructor arguments used for training,
# so the strict state-dict load sees an identically configured module.
transformation = transformations.OneDirection(latent_space_dim, vocab_size)
transformation.load_state_dict(
    torch.load(
        os.path.join(checkpoint_dir, checkpoint_name),
        # Remap stored tensors to the active device, so a checkpoint written on a GPU
        # can also be loaded on a CPU-only machine.
        map_location=device,
    ),
    strict=True,
)
transformation.to(device)
transformation.eval()

In [None]:
#
# Now that the model is trained, we can test it.
#
# Testing the model involves using the transformation model to transform a z vector and then
# using the GAN model to generate an image from the transformed z vector.
# We will change the step size and see how the image changes.
#
batch_size = 6  # Do not change
alpha = 0.2
num_samples = 6

iters = 10

# Prefer the copy under "Weights" (written by the optional Colab cell) when it exists;
# otherwise use the checkpoint the training section saved.
weights_path = os.path.join("Weights", checkpoint_name)
if not os.path.exists(weights_path):
    weights_path = os.path.join(checkpoint_dir, checkpoint_name)

# Same constructor arguments as at training time, so the strict load matches.
transformation = transformations.OneDirection(latent_space_dim, vocab_size)
transformation.load_state_dict(
    torch.load(weights_path, map_location=device),
    strict=True,
)
transformation.to(device)
transformation.eval()

gan_results = []

with torch.no_grad():
    # NOTE(review): this rebinds the notebook-level `zds` to a 6-sample dataset, which
    # the later dissection cells then reuse. Confirm that is intended.
    zds = zdataset.z_dataset_for_model(proggan_model, size=num_samples, seed=5555)
    zds_dataloader = DataLoader(zds, batch_size=batch_size, shuffle=True)

    # `desc=` labels the bar; passing the label positionally would make it tqdm's iterable.
    progress = tqdm(desc="Testing", total=len(zds_dataloader))

    # loop over data batches
    for batch_start, z_batch in enumerate(zds_dataloader):
        #
        # Setup the batch z vectors and the constant step sizes.
        #
        z_batch = z_batch[0].to(device)
        # Size everything from the batch itself so a short final batch still works.
        current_batch_size = z_batch.shape[0]
        z_batch = z_batch.reshape(current_batch_size, -1)

        step_sizes = torch.full(
            (current_batch_size, latent_space_dim),
            alpha,
            dtype=torch.float,
            device=device,
        )

        gan_images, out_scores = make_images_and_probs(
            model=proggan_model,
            zdataset=z_batch,
            clip_model=clip_model,
            encoded_text=text_features,
            attribute_index=attribute_index,
        )

        # The unmodified images sit in the middle; negative steps are prepended and
        # positive steps appended, so each row reads from most negative to most positive.
        batch_results = [(gan_images, out_scores)]

        # Generate images by transforming the z vector in the negative direction
        z_negative = z_batch.clone()

        for _ in range(iters):
            z_negative = transformation.transform(z_negative, None, -step_sizes)
            batch_results.insert(
                0,
                make_images_and_probs(
                    model=proggan_model,
                    zdataset=z_negative,
                    clip_model=clip_model,
                    encoded_text=text_features,
                    attribute_index=attribute_index,
                )
            )

        # Generate images by transforming the z vector in the positive direction
        z_positive = z_batch.clone()

        for _ in range(iters):
            z_positive = transformation.transform(z_positive, None, step_sizes)
            batch_results.append(
                make_images_and_probs(
                    model=proggan_model,
                    zdataset=z_positive,
                    clip_model=clip_model,
                    encoded_text=text_features,
                    attribute_index=attribute_index,
                )
            )

        gan_results.append(batch_results)

        progress.update(1)

    progress.close()

In [None]:
# For every test sample, print its per-step scores and show its strip of steered images.
show_gan_results(gan_results)

## Hooking a model with InstrumentedModel

To analyze what a model is doing inside, we can wrap it with an InstrumentedModel, which makes it easy to hook or modify a particular layer.

InstrumentedModel adds a few useful functions for inspecting a model, including:
   * `model.retain_layer('layername')` - hooks a layer to hold on to its output after computation
   * `model.retained_layer('layername')` - returns the retained data from the last computation
   * `model.edit_layer('layername', rule=...)` - runs the `rule` function after the given layer
   * `model.remove_edits()` - removes editing rules

Let's setup `retain_layer` now.  We'll pick a layer sort of in the early-middle of the generator.  You can pick whatever you like.

In [None]:
from netdissect import nethook

# Don't re-wrap it, if it's already wrapped (e.g., if you press enter twice)
if not isinstance(proggan_model, nethook.InstrumentedModel):
    proggan_model = nethook.InstrumentedModel(proggan_model)
# Hook 'layer4' so its output is kept after every forward pass.
proggan_model.retain_layer('layer4')

In [None]:
# Run the model on the first latent of `zds`, adding a leading batch axis.
# NOTE(review): `zds` is whatever the most recently executed cell bound it to
# (demo, training or testing dataset).
img = proggan_model(zds[0][0][None,...].to(device))

# As a side-effect, the proggan_model has retained the output of layer4.
acts = proggan_model.retained_layer('layer4')

# We can look at it.  How much data is it?
acts.shape

In [None]:
# Let's just look at the 0th convolutional channel.
# (first index 0 selects the single image in the batch, second index 0 the channel)
print(acts[0,0])

## Visualizing activation data

It can be informative to visualize activation data instead of just looking at the numbers.

Net dissection comes with an ImageVisualizer object for visualizing grid data as an image in a few different ways.  Here is a heatmap of the array above:

In [None]:
from netdissect import imgviz
# Visualizer constructed with argument 100 (presumably the output size in pixels; confirm).
iv = imgviz.ImageVisualizer(100)
# NOTE(review): this renders channel 1, while the preceding cell printed channel 0.
iv.heatmap(acts[0,1], mode='nearest')

In [None]:
from netdissect import show

# For units 1..5 build one block each: a label, the generated image, the image masked
# by the unit's activation, and the unit's heatmap.
show(
    [['unit %d' % u,
      [iv.image(img[0])],
      [iv.masked_image(img[0], acts, (0,u))],
      [iv.heatmap(acts, (0,u), mode='nearest')],
     ] for u in range(1, 6)]  
)

## Collecting quantile statistics for every unit

We want to know per-channel minimum or maximum values, means, medians, quantiles, etc.

We want to treat each pixel as its own sample for all the channels.  For example, here are the activations for one image as an 8x8 tensor with 512 channels.  We can disregard the geometry and just look at it as a 64x512 sample matrix, that is, 64 samples of 512-dimensional vectors.

In [None]:
# Shape of the retained activations as returned by the model.
print(acts.shape)
# Move the channel axis last, then merge all other axes: one row per pixel, one column per channel.
print(acts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1]).shape)

Net dissection has a tally package that tracks quantiles over large samples.

To use it, just define a function that returns sample matrices like the 64x512 above, and then it will call your function on every batch and tally up the statistics.

In [None]:
from netdissect import tally

# tally_quantile expects a function that maps a batch to a 2d [samples, units] matrix.
def compute_samples(zbatch):
    """Run the generator on ``zbatch`` and return layer4 activations as a 2d matrix.

    The channel axis is moved last and all remaining axes are merged, so every
    spatial location of every image becomes one row with one column per channel.
    """
    proggan_model(zbatch.to(device))  # forward pass, needed only to refresh the retained layer
    layer_acts = proggan_model.retained_layer('layer4')
    num_channels = layer_acts.shape[1]
    channels_last = layer_acts.permute(0, 2, 3, 1).contiguous()
    return channels_last.view(-1, num_channels)

# Run compute_samples over the whole dataset and accumulate the quantile statistics.
rq = tally.tally_quantile(compute_samples, zds)

# Median value for the first 20 channels.
rq.quantiles(0.5)[:20]

## Exploring quantiles

The rq object tracks a sketch of all the quantiles of the sampled data.  For example, what is the mean, median, and percentile value for each unit?

In [None]:
# This tells me now, for example, what the means are for channel,
# rq.mean()
# what median is,
# rq.quantiles(0.5)
# Or what the 99th percentile quantile is.
# rq.quantiles(0.99)

# Count the units whose 0.8 quantile is above zero.
(rq.quantiles(0.8) > 0).sum()

The quantiles can be plugged directly into the ImageVisualizer to put heatmaps on an informative per-unit scale.  When you do this:

   * Heatmaps are shown on a scale from black to white from 1% lowest to the 99% highest value.
   * Masked image lassos are shown at a 95% percentile level (by default, can be changed).

In [None]:
# Rebuild the visualizer with the collected quantiles so heatmaps and masks use a per-unit scale.
iv = imgviz.ImageVisualizer(100, quantiles=rq)
show([
    [  # for every unit, make a block containing
       'unit %d' % u,         # the unit number
       [iv.image(img[0])],    # the unmodified image
       [iv.masked_image(img[0], acts, (0,u))], # the masked image
       [iv.heatmap(acts, (0,u), mode='nearest')], # the heatmap
    ]
    for u in range(1, 6)  # units 1 through 5
])

In [None]:
def compute_image_max(zbatch):
    """Return the layer4 activations of ``zbatch`` reduced by max over axes 3 and 2.

    The result has one value per (image, channel) pair.
    """
    proggan_model(zbatch.to(device))  # forward pass, needed only to refresh the retained layer
    layer_acts = proggan_model.retained_layer('layer4')
    max_over_axis3 = layer_acts.max(3)[0]
    return max_over_axis3.max(2)[0]

# Track, for every unit, the dataset entries with the highest per-image maximum.
topk = tally.tally_topk(compute_image_max, zds)
topk.result()[1].shape

In [None]:
# Build one row of masked images for a unit from its top-activating dataset entries.
def unit_viz_row(unitnum, percent_level=0.95):
    """Return ``[image number, [masked image]]`` entries for a unit's top 8 images.

    Args:
        unitnum: Index of the layer4 unit to visualize.
        percent_level: Passed through to ``iv.masked_image``.
    """
    def _entry(imgnum):
        # Regenerate the image so the retained layer4 activations belong to it.
        generated = proggan_model(zds[imgnum][0][None,...].to(device))
        layer_acts = proggan_model.retained_layer('layer4')
        masked = iv.masked_image(
            generated[0], layer_acts, (0, unitnum), percent_level=percent_level
        )
        return [imgnum.item(), [masked]]

    # topk.result()[1] holds, per unit, the indices of its top-scoring dataset entries.
    return [_entry(imgnum) for imgnum in topk.result()[1][unitnum][:8]]

show(unit_viz_row(30))

# Evaluate matches with semantic concepts

Do the filters match any semantic concepts?  To systematically examine this question,
we have pretrained (using lots of labeled data) a semantic segmentation network to recognize
a few hundred classes of objects, parts, and textures.

Run the code in this section to look for matches between filters in our GAN and semantic
segmentation classes.

## Labeling semantics within the generated images

Let's quantify what's inside these images by segmenting them.

First, we create a segmenter network.  (We use the Unified Perceptual Parsing segmenter by Xiao, et al., https://arxiv.org/abs/1807.10221.)

Note that the segmenter we use here requires a GPU.

In [None]:
from netdissect import segmenter, setting

# segmodel = segmenter.UnifiedParsingSegmenter(segsizes=[256])
# Load the 'netpq' segmenter setting; the third returned value is not used here.
segmodel, seglabels, _ = setting.load_segmenter('netpq')
# seglabels = [l for l, c in segmodel.get_label_and_category_names()[0]]
print('segmenter has', len(seglabels), 'labels')

Then we create segmentation images for the dataset.  Here tally_cat just concatenates batches of image (or segmentation) data.

  * `segmodel.segment_batch` segments an image
  * `iv.segmentation(seg)` creates a solid-color visualization of a segmentation
  * `iv.segment_key(seg, segmodel)` makes a small legend for the segmentation

In [None]:
from netdissect import upsample
from netdissect import segviz

# Generate an image for every latent in `zds`. The previous version referenced
# `run_model_batch` and `noise_vector`, which this notebook never defines; the
# generator is called the same way as in `compute_samples`.
imgs = tally.tally_cat(lambda zbatch: proggan_model(zbatch.to(device)), zds)
# Segment every generated image at full resolution.
seg = tally.tally_cat(lambda img: segmodel.segment_batch(img.cuda(), downsample=1), imgs)

from netdissect.segviz import seg_as_image, segment_key
# Show up to five images next to their segmentation and a legend of the labels found.
show([
    (iv.image(imgs[i]),
     iv.segmentation(seg[i,0]),
     iv.segment_key(seg[i,0], segmodel)
    )
    for i in range(min(len(seg), 5))
])

In [None]:
# Release PyTorch's cached GPU memory, then print the current GPU usage.
torch.cuda.empty_cache()
!nvidia-smi

In [None]:
# We upsample activations to measure them at each segmentation location.
upfn8 = upsample.upsampler((64, 64), (8, 8)) # layer4 is resolution 8x8


def compute_conditional_samples(zbatch):
    """Return layer4 activation samples grouped by segmentation class for ``zbatch``.

    The generator is run on the latents, its images are segmented with
    ``downsample=4``, and the retained layer4 activations are upsampled with
    ``upfn8`` before being handed to ``tally.conditional_samples`` together with
    the segmentation.

    The previous version built a class vector with ``one_hot_from_int``,
    ``CLIP_CLASS_ID`` and ``truncation``, none of which exist in this notebook;
    every other cell calls the generator with the latent batch only.
    """
    image_batch = proggan_model(zbatch.to(device))

    seg = segmodel.segment_batch(image_batch, downsample=4)
    upsampled_acts = upfn8(proggan_model.retained_layer('layer4'))

    samples = tally.conditional_samples(upsampled_acts, seg)

    # Free cached GPU memory between batches.
    torch.cuda.empty_cache()
    return samples

# Run this function once to sample one image
sample = compute_conditional_samples(zds[0][0][None,...])

# The result is a list of all the conditional subsamples
[(seglabels[c], d.shape) for c, d in sample]

# Accumulate the conditional quantile statistics over the whole dataset.
cq = tally.tally_conditional_quantile(compute_conditional_samples, zds)

Conditional quantile statistics let us compute lots of relationships between units and visual concepts.

For example, IoU is the "intersection over union" ratio, measuring how much overlap there is between the top few percent activations of a unit and the presence of a visual concept.  We can estimate the IoU ratio for all pairs between units and concepts with these stats:

In [None]:
# Estimate the IoU between units and concepts from the conditional quantiles,
# using 0.99 as the activation cutoff.
iou_table = tally.iou_from_conditional_quantile(cq, cutoff=0.99)
iou_table.shape

Now let's view a few of the units, labeled with an associated concept, sorted from highest to lowest IoU.

In [None]:
# iou_table.max(1) yields a (values, indices) pair; zip(*...) pairs each row's best IoU
# with the index of the concept achieving it, enumerate adds the row number as the unit,
# and the negated key sorts from highest to lowest IoU.
unit_list = sorted(enumerate(zip(*iou_table.max(1))), key=lambda k: -k[1][0])

# Print and visualize the five best-matching units.
for unit, (iou, segc) in unit_list[:5]:
    print('unit %d: %s (iou %.2f)' % (unit, seglabels[segc], iou))
    show(unit_viz_row(unit))

We can quantify the overall match between units and segmentation concepts by counting the number of units that match a segmentation concept (omitting low-scoring matches).

In [None]:
# Report how many units there are and how many have a best-match IoU above 0.04.
print('Number of units total:', len(unit_list))
print('Number of units that match a segmentation concept with IoU > 0.04:',
   sum(1 for _, (best_iou, _) in unit_list if best_iou > 0.04))

## Examining units that select for nose

Now let's filter just units that were labeled as 'nose' units.