# HaMLeT

## Session 9: Zero-shot anomaly segmentation with CLIP
by Jin Er, Lehrstuhl für Bildverarbeitung der RWTH Aachen

### Goal of this Session

In this session, you will implement the SCLIPAD algorithm yourself, step by step. You should already be familiar with the concepts behind CLIP, the `torch` library, and the MVTec datasets.

### Import all the relevant libraries

Please run `!pip install` to install the necessary Python libraries.

In [None]:
!pip install -r requirements.txt

In [None]:
%load_ext autoreload
%autoreload 2

# plot
import matplotlib.pyplot as plt

# base
import open_clip
import torch
import einops
from torchvision import transforms
from torch.nn import functional as F

from text_prompt import (
    STATE_LEVEL_NORMAL_PROMPTS,
    STATE_LEVEL_ABNORMAL_PROMPTS,
    TEMPLATE_LEVEL_PROMPTS
)

# debugging 
import pdb

### Tasks

### Task 0: MVTec dataloader

The MVTec AD dataset is a comprehensive real-world dataset for benchmarking anomaly detection algorithms, particularly in the field of industrial inspection. Developed by MVTec Software GmbH, a leading provider of industrial machine vision software, this dataset is specifically designed to aid in the development and evaluation of anomaly detection methods.

In the MVTec AD dataset, the 15 categories are divided between object and texture types. Each category represents a distinct class of industrial product or material. Here are the specific names of these categories:

Categories (in alphabetical order; Carpet, Grid, Leather, Tile, and Wood are textures, the other ten are objects):
1. **Bottle**
2. **Cable**
3. **Capsule**
4. **Carpet**
5. **Grid**
6. **Hazelnut**
7. **Leather**
8. **Metal Nut**
9. **Pill**
10. **Screw**
11. **Tile**
12. **Toothbrush**
13. **Transistor**
14. **Wood**
15. **Zipper**


<div>
<img src="imgs/MVTecAD-0000003433-bf7e8d4c.jpg" width="800"/>
</div>


Each of these categories encompasses various examples of normal conditions as well as a range of anomalies or defects relevant to the specific type of object or material. This categorization helps in addressing diverse challenges in industrial anomaly detection and quality control.
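
The object/texture split can be written down as a small lookup, which is handy when looping over categories later. This is only a sketch: the lowercase names follow the folder names of the public dataset, and `category_type` is a hypothetical helper that is not part of the session code.

```python
# Category split of MVTec AD: five texture classes and ten object classes.
TEXTURE_CATEGORIES = ["carpet", "grid", "leather", "tile", "wood"]
OBJECT_CATEGORIES = [
    "bottle", "cable", "capsule", "hazelnut", "metal_nut",
    "pill", "screw", "toothbrush", "transistor", "zipper",
]
ALL_CATEGORIES = sorted(TEXTURE_CATEGORIES + OBJECT_CATEGORIES)


def category_type(name: str) -> str:
    """Return 'texture' or 'object' for an MVTec AD category name (hypothetical helper)."""
    if name in TEXTURE_CATEGORIES:
        return "texture"
    if name in OBJECT_CATEGORIES:
        return "object"
    raise ValueError(f"unknown MVTec AD category: {name}")


print(len(ALL_CATEGORIES), category_type("wood"))
```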

**IMPORTANT:** To be able to solve this notebook, you need to request access to the MVTec AD dataset yourself here: https://www.mvtec.com/company/research/datasets/mvtec-ad.

### Task 1: Load the MVTec dataloader

For simplicity, we only use the category `wood` during this session. Load the dataloader and pick one image from `test.dataset` for plotting.

Please read the `dataloader.py` file and try to understand the concepts of a Lightning `LightningDataModule` and a torch `Dataset` (`torch.utils.data.Dataset`).

**DO NOT MODIFY the dataloader file**
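
To illustrate the `Dataset`/`DataLoader` pattern that a data module typically wraps, here is a minimal sketch with dummy tensors. `ToyAnomalyDataset` is a hypothetical stand-in and does not reflect how `MVTecDataModule` is implemented; it only mimics the `(image, label, mask)` structure used later in this notebook.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyAnomalyDataset(Dataset):
    """Hypothetical stand-in for a test split: returns (image, label, mask)."""

    def __init__(self, labels=(0, 0, 1, 0)):
        self.labels = list(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        image = torch.zeros(3, 8, 8)                # dummy RGB image, C x H x W
        mask = torch.full((1, 8, 8), float(label))  # all ones if the sample is anomalous
        return image, label, mask


loader = DataLoader(ToyAnomalyDataset(), batch_size=1, shuffle=False)

# Take the first batch whose label marks an anomaly (label != 0).
image, label, mask = next(b for b in loader if b[1].item() != 0)
print(image.shape, label.item(), mask.shape)
```

A `LightningDataModule` usually bundles such datasets and exposes them through hooks like `test_dataloader()`; check `dataloader.py` for the names actually used here.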

In [None]:
# load the mvtec dataloader
from dataloader import MVTecDataModule

# -------------- student solution -----------------
# Each part can be solved with one or at most two lines of code.
# If you need more, you have most likely made a mistake or misunderstood the question.

"""
dataloader_mvtec = 
test_dataset = 
"""
# ------------ end of student solution ------------

ds = iter(test_dataset)

# only select an image that contains anomalies
while True:
    data = next(ds)
    if data[1] != 0:
        break

### Task 2: Plot the testing image with its respective mask 

`data` is a `list` containing the RGB image, the label (0 for a normal image, 1 for an abnormal image), and the binary mask.

The primary goal of this task is to plot the original image together with its mask.
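
Before an image can be displayed, the input normalization has to be undone. The sketch below shows the idea on a random tensor, assuming the ImageNet mean and standard deviation that appear in `plot_input_img`; the two helper functions are illustrative and not part of the session code.

```python
import torch

# ImageNet statistics, matching the constants used in `plot_input_img`.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def normalize(img):
    """Forward transform: subtract the mean, divide by the std (per channel)."""
    return (img - MEAN) / STD


def denormalize(img):
    """Inverse transform: multiply by the std, then add the mean back."""
    return img * STD + MEAN


original = torch.rand(1, 3, 4, 4)             # values in [0, 1], shape B x C x H x W
restored = denormalize(normalize(original))
print(torch.allclose(original, restored, atol=1e-6))
```

The two chained `transforms.Normalize` calls in `invTrans` compute the same thing: the first multiplies by the std, the second adds the mean.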

In [None]:
def plot_input_img(
    data: list = data,
    pred = None, 
):
    invTrans = transforms.Compose(
        [
            transforms.Normalize(
                mean = [ 0., 0., 0. ],
                std = [ 1/0.229, 1/0.224, 1/0.225 ]
            ),
            transforms.Normalize(
                mean = [ -0.485, -0.456, -0.406 ],
                std = [ 1., 1., 1. ]
            ),
        ]
    )
    def convert_to_numpy(input, inverse=False):
        "helper function to convert a torch tensor to numpy for plotting images"
        if inverse:
            input = invTrans(input)
        # channels-first (B, C, H, W) -> channels-last (B, H, W, C), then take the first sample
        input = einops.rearrange(input, "b c h w -> b h w c")[0]
        return input.detach().cpu().numpy()
        
    # -------------- student solution  -----------------

    # convert the image and mask

    imgs: list

    # --------- end of student solution --------------

    if pred is not None:
        imgs.append(convert_to_numpy(pred)[0])

    num_imgs = len(imgs)
    fig, axes = plt.subplots(1, num_imgs, figsize=(10 * num_imgs, 10))
    # Plot each image
    for ax, i in zip(axes, imgs):
        if i.shape[-1] == 1:
            ax.imshow(i, cmap="gray")
        else:
            ax.imshow(i)
        ax.axis('off')

    plt.tight_layout()  # Adjust subplots to fit into the figure area.
    plt.show()

In [None]:
plot_input_img()

### Task 3: Load CLIP visual encoder and text encoder

Please check `utils.py`.

The CLIP framework contains two networks, an image encoder and a text encoder; we use both in this session.

In [None]:
# import library

from utils import get_clip_encoders

In [None]:
# -------------- student solution  -----------------

model = None
tokenizer = None

# use print(model) to print the general structure of the model
vit = None
# ------------end of student solution-------------------

### Task 4: Image encoder with modified attention

Define a new image encoder with modified attention blocks, as shown in the following figure. 

<div>
<img src="imgs/mll_lab.jpg" width="800"/>
</div>

Fig. (a) shows the modified `qq` and `kk` attention, which replaces the original `qk` attention in the attention block.
Fig. (b) shows the general CLIP framework.

We need to rewrite the image encoder to apply the modified `qq` and `kk` attention.

#### Please note that we only apply the modification in the last resblock of the ViT
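
To make the difference between `qk` attention and the `qq`/`kk` variant concrete, the following sketch computes both on random toy tensors. It assumes that the two self-correlation maps are each passed through a softmax and then summed; please compare this with figure (a) before relying on it. All sizes are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes: 2 (batch * heads), 5 tokens, 4 channels per head.
BH, L, head_dim = 2, 5, 4
q = torch.randn(BH, L, head_dim)
k = torch.randn(BH, L, head_dim)
v = torch.randn(BH, L, head_dim)
scale = head_dim ** -0.5

# Standard attention: every query token is compared against all keys.
qk_weights = F.softmax((q @ k.transpose(1, 2)) * scale, dim=-1)

# Self-correlation variant: queries vs. queries and keys vs. keys.
qq_weights = F.softmax((q @ q.transpose(1, 2)) * scale, dim=-1)
kk_weights = F.softmax((k @ k.transpose(1, 2)) * scale, dim=-1)
csa_weights = qq_weights + kk_weights  # assumption: the two maps are summed

out = csa_weights @ v  # [BH, L, head_dim]
print(qk_weights.shape, csa_weights.shape, out.shape)
```

Note that each row of `csa_weights` sums to 2 rather than 1, because two softmax maps are added.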

In [None]:
def image_encoder(
    x, 
    vit=vit,
    layer_num=11,
    normalize=True,
):
    x = vit.conv1(x)
    x = x.reshape(x.shape[0], x.shape[1], -1)
    x = x.permute(0, 2, 1)

    # class embedding and positional embeddings
    x = torch.cat(
        [vit.class_embedding.to(x.dtype) + 
         torch.zeros(
             x.shape[0], 1, x.shape[-1], 
             dtype=x.dtype, 
             device=x.device
         ),
         x], dim=1
        )  # shape = [*, grid ** 2 + 1, width]
    
    x = x + vit.positional_embedding.to(x.dtype)
    x = vit.patch_dropout(x)
    x = vit.ln_pre(x)
    x = x.permute(1, 0, 2)  # NLD -> LND

    def csa_attn(resblock, x):
        attn = resblock.attn
        num_heads = attn.num_heads
        _, B, E = x.size()
        head_dim = E // num_heads
        scale = head_dim ** -0.5

        qkv = F.linear(x, attn.in_proj_weight, attn.in_proj_bias)
        q, k, v = qkv.chunk(3, dim=-1)  # [L, B, E]

        # [B * num_heads, L (sequence length), E // num_heads]
        q = q.contiguous().view(-1, B * num_heads, head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, B * num_heads, head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, B * num_heads, head_dim).transpose(0, 1)

        #------------    student solution     ---------

        # please fill in your solution here, and DO NOT CHANGE THE OTHER PARTS OF THE CODE!!!

        # part one: calculate the qq attention and kk attention weights

        # get the attention weights (`attn_weights`)
        
        #------------    end of student solution     ------------

        attn_output = attn_weights @ v  # [B * num_head, L, E // num_head]
        attn_output = attn_output.transpose(0, 1).contiguous().view(-1, B, E)  # [L, B, E]

        attn_output = attn.out_proj(attn_output)
        return attn_output

    resblocks = vit.transformer.resblocks
    for block_idx, block in enumerate(resblocks):
        if block_idx >= layer_num:
            """ Applying Self Attention Module in the last attention block """
            # ----------student solution  ----------
            # Please check how attention is applied in the ViT; if you are unsure, see the following link:
            # https://github.com/mlfoundations/open_clip/blob/1be2c8993b3f2628d495dfa791061e92f1cd4d0e/src/open_clip/transformer.py#L262

            # Hint: the original ResidualAttentionBlock stays the same; only replace the attention block
            # ------------------------------------ 
            
        else:
            x = block(x)
    x = x.permute(1, 0, 2)

    x = vit.ln_post(x)
    proj = vit.proj
    x = x @ proj

    if normalize:
        x = F.normalize(x, dim=-1, p=2)

    return {
        "cls": x[:, 0, :],
        "tokens": x[:, 1:, :]
    }

In [None]:
#### DO NOT CHANGE THE CODE HERE
# for sanity check 

encoded_patches = image_encoder(x=data[0])["tokens"]
shape = encoded_patches.shape
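# 196 = 14 x 14 patch tokens; the embedding width shape[2] depends on the CLIP model from `get_clip_encoders`
assert shape[1] == 196, f"expected 196 patch tokens, got shape {tuple(shape)}"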

### Task 5: Text Classifier 

Generate the average text embeddings for the abnormal and the normal text prompts. Please check `text_prompt.py` for details of the text prompt structure.
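
The prompt ensemble combines every template-level prompt with every state-level prompt for a given category. The sketch below shows this composition with made-up prompt lists; the actual contents and types in `text_prompt.py` may differ (here the state-level prompts are assumed to be format strings and the templates to be lambdas), and `compose_prompts` is a hypothetical helper.

```python
# Hypothetical prompt lists; the real ones live in `text_prompt.py`.
STATE_NORMAL = ["flawless {}", "perfect {}"]
STATE_ABNORMAL = ["damaged {}", "{} with a defect"]
TEMPLATES = [
    lambda state: f"a photo of a {state}.",
    lambda state: f"a cropped photo of the {state}.",
]


def compose_prompts(state_prompts, templates, category):
    """Combine every template with every state-level prompt for one category."""
    prompts = []
    for template in templates:
        for state in state_prompts:
            # First insert the category name, then wrap it in the template.
            prompts.append(template(state.format(category)))
    return prompts


normal_prompts = compose_prompts(STATE_NORMAL, TEMPLATES, "wood")
abnormal_prompts = compose_prompts(STATE_ABNORMAL, TEMPLATES, "wood")
print(normal_prompts[0])
print(abnormal_prompts[-1])
```

Each list is then tokenized, encoded, and averaged into a single embedding per class, as done in `build_text_classifier`.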

In [None]:
def text_encoder(
    text,
    model = model,
    normalize: bool = True
):
    model = model.eval()
    text_features = model.encode_text(text)
    if normalize:
        text_features = F.normalize(text_features, dim=-1, p=2)
    return text_features

def build_text_classifier(
    model = model,
    tokenizer=tokenizer,
    category = "wood"
):
    def _process_template(state_level_templates):
        text = []
        
        for template in TEMPLATE_LEVEL_PROMPTS:
            for state_template in state_level_templates:
                # ------------student version ------------------
                #  generate the normal and abnormal template text; please check how to use lambda
                #  DO NOT MAKE IT TOO COMPLICATED; at most 2 lines of code are enough!
                # ----------------------------------------------

                
        device = next(model.parameters()).device
        texts = tokenizer(text).to(device=device)
        class_embeddings = text_encoder(texts)
        mean_class_embeddings = torch.mean(class_embeddings, dim=0, keepdim=True)
        mean_class_embeddings = F.normalize(mean_class_embeddings, dim=-1)
        return mean_class_embeddings
    normal_text_embedding = _process_template(STATE_LEVEL_NORMAL_PROMPTS)
    abnormal_text_embedding = _process_template(STATE_LEVEL_ABNORMAL_PROMPTS)
    return torch.cat([normal_text_embedding, abnormal_text_embedding], dim=0)

In [None]:
text = build_text_classifier()

### Task 6: Generate Anomaly Heatmap

Use the text embeddings generated in Task 5 and the encoded patches from Task 4 to generate the heatmap.

Please always normalise the text and patch embeddings.

Hint: calculate the cosine similarity!
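
The following sketch walks through the idea on random tensors: cosine similarity between patch and text embeddings, a softmax over the two classes, and a reshape of the patch axis into a 2D grid. The sizes (196 patches, width 512, 224 x 224 output) are assumptions for illustration, and the final upsampling step is optional and not required by the task. The scale value 4.7 mirrors the default of `heatmap_generation`.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, num_patches, dim = 1, 196, 512  # toy sizes; 196 = 14 x 14 patches
patches = F.normalize(torch.randn(B, num_patches, dim), dim=-1)
text = F.normalize(torch.randn(2, dim), dim=-1)  # row 0: normal, row 1: abnormal

# Dot products of unit vectors are cosine similarities in [-1, 1].
cosine = patches @ text.T  # [B, 196, 2]

# Scale the similarities, then take a softmax over the two classes.
logit_scale = 4.7
probs = (logit_scale * cosine).softmax(dim=-1)  # [B, 196, 2]

# Keep the abnormal probability and fold the patch axis back into a grid.
side = math.isqrt(num_patches)
heatmap = probs[..., 1].reshape(B, 1, side, side)  # [B, 1, 14, 14]

# Optional: upsample to the (assumed) input resolution for plotting.
heatmap_full = F.interpolate(
    heatmap, size=(224, 224), mode="bilinear", align_corners=False
)
print(heatmap.shape, heatmap_full.shape)
```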

In [None]:
def heatmap_generation(
    img_patches,
    text_features,
    logit_scale=4.7
):
    """generating patch-wise abnormal score"""
    # ----------- solution -------------------
    logit_patches = (img_patches @ text_features.T)
    logit_patches = (logit_scale * logit_patches).softmax(dim=-1)
    logit_patches = logit_patches.transpose(1, -1)

    # Hint: you can change the shape of the image by using einops operation

    
    #--------------end of solution --------------------------
    return logit_patches[:, 1:, ...]

In [None]:
logit_patches = heatmap_generation(
    img_patches=encoded_patches, 
    text_features=text
)

In [None]:
plot_input_img(data, logit_patches)

### Questions 

What do you think about the predicted mask? Is there a way to improve it?



**MVTec Dataset Citations**:

[1] Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, Carsten Steger: The MVTec Anomaly Detection Dataset: A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection; in: *International Journal of Computer Vision* 129(4):1038-1059, 2021, DOI: [10.1007/s11263-020-01400-4](https://link.springer.com/article/10.1007/s11263-020-01400-4).

[2] Paul Bergmann, Michael Fauser, David Sattlegger, Carsten Steger: MVTec AD — A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection; in: *IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 9584-9592, 2019, DOI: [10.1109/CVPR.2019.00982](https://ieeexplore.ieee.org/document/8954181).