## Adding semantic context to SPAI: cross attention before/after the SCA module

In [1]:
import torch
import random
from functools import partial

# imports and constants
import matplotlib.pyplot as plt
from PIL import Image
import sys
import os

import torch

parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from spai.models import sid
from spai.models import vision_transformer
from spai.models import backbones


IMAGE_PATH = "./../data/images/fake_example.png"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#### OLD CODE ####
# def show_images_in_row(images, titles=None, figsize=(15, 5)):
#     """
#     Displays a list of image arrays in a single row.

#     Args:
#         images (list): List of image data (PIL.Image or NumPy arrays).
#         titles (list, optional): Optional list of titles for each image.
#         figsize (tuple): Size of the entire figure (width, height).
#     """
#     num_images = len(images)
#     fig, axes = plt.subplots(1, num_images, figsize=figsize)

#     if num_images == 1:
#         axes = [axes]

#     for i, (ax, img) in enumerate(zip(axes, images)):
#         ax.imshow(img)
#         ax.axis("off")
#         if titles and i < len(titles):
#             ax.set_title(titles[i], fontsize=10)

#     plt.tight_layout()
#     plt.show()


# img_og = Image.open(IMAGE_PATH)
# # resize to 224x224
# img_resized = img_og.resize((224, 224))

# show_images_in_row(
#     [img_og, img_resized], titles=["Original Image", "224x224"], figsize=(8, 4)
# )

#############################################################
# from torchvision import transforms
# from spai.models.backbones import CLIPBackbone

# clip_encoder = CLIPBackbone()

# # NOTE: we might need to normalize according to clip mean/std
# img_tensor = transforms.ToTensor()(img_resized).unsqueeze(0)
# print(f"Image tensor shape: {img_tensor.shape}") # ([1, 3, 224, 224]

# img_encoding = clip_encoder(img_tensor)
# # print encoding shape
# print(f"Encoding shape: {img_encoding.shape}") # ([1, 12, 196, 768])

### Test: loading the model with the semantic changes and the forward pass

In [3]:
# Test forward with different backbones
backbone_vits = [
    # vision_transformer.VisionTransformer(
    #     img_size=224,
    #     patch_size=16,
    #     in_chans=3,
    #     num_classes=2,
    #     embed_dim=768,
    #     depth=12,
    #     num_heads=12,
    #     mlp_ratio=4,
    #     qkv_bias=True,
    #     drop_rate=0.1,
    #     drop_path_rate=0.1,
    #     norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    #     init_values=0.1,
    #     use_abs_pos_emb=True,
    #     use_rel_pos_bias=False,
    #     use_shared_rel_pos_bias=False,
    #     use_mean_pooling=False,
    #     use_intermediate_layers=True,
    #     intermediate_layers=tuple(range(12)),
    #     return_features=True,
    # ),
    backbones.CLIPBackbone().cpu(),
    # backbones.DINOv2Backbone().cpu(),
]

vit = backbone_vits[0]

In [4]:
# NOTE: default config values used based on config.yaml/config.py
batch_size = 4
features_num = 12
input_dim = 768
masking_radius = 16


features_processor = sid.FrequencyRestorationEstimator(
    features_num=features_num,
    input_dim=input_dim,
    proj_dim=1024,
    proj_layers=2,
    patch_projection=True,
    patch_projection_per_feature=True,
)
cls_head = sid.ClassificationHead(6 * features_num, 1, mlp_ratio=3)
model = sid.PatchBasedMFViT(
    vit=vit,
    features_processor=features_processor,
    cls_head=cls_head,
    masking_radius=masking_radius,
    img_patch_size=224,
    img_patch_stride=224,
    cls_vector_dim=6 * features_num,
    num_heads=12,
    attn_embed_dim=1536,
    minimum_patches=1,
    use_semantic_cross_attn_sca="after",  # tested : None/before/after
    semantic_embed_dim=512,
)

# load the model weights from weights/PatchBasedMFViT_test.pth
model_weights_path = os.path.join(parent_dir, "weights", "PatchBasedMFViT_test_05-05.pth")
model_weights = torch.load(model_weights_path, map_location="cpu")
model.load_state_dict(model_weights, strict=False)
model.eval()

Using semantic cross-attention after!
512 12 72


  model_weights = torch.load(model_weights_path, map_location="cpu")


PatchBasedMFViT(
  (mfvit): MFViT(
    (vit): CLIPBackbone(
      (clip): CLIP(
        (visual): VisionTransformer(
          (conv1): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
          (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (transformer): Transformer(
            (resblocks): Sequential(
              (0): ResidualAttentionBlock(
                (attn): MultiheadAttention(
                  (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
                )
                (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                (mlp): Sequential(
                  (c_fc): Linear(in_features=768, out_features=3072, bias=True)
                  (gelu): QuickGELU()
                  (c_proj): Linear(in_features=3072, out_features=768, bias=True)
                )
                (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
              )
           

In [5]:
x = torch.randn((batch_size, 3, 224, 224))
model.eval()
with torch.no_grad():
    out = model(x)

print(f"Output shape (fixed resolution): {out.shape}")
assert out.shape == (batch_size, 1), f"Unexpected output shape: {out.shape}"


# # NOTE: Uncommented since this was only used for the untouched version of PatchBasedMFViT
# # save mode to weights/PatchBasedMFViT_test.pth
# model_path = os.path.join("weights", "PatchBasedMFViT_test_05-05.pth")
# os.makedirs(os.path.dirname(model_path), exist_ok=True)
# torch.save(model.state_dict(), model_path)
# print(f"Model saved to {model_path}")

Input shape: torch.Size([4, 3, 224, 224])
Output shape (fixed resolution): torch.Size([4, 1])


In [6]:
# OLD KEYS: untouched version of PatchBasedMFViT; no semantic cross-attention parameters
for key in model_weights.keys():
    print(key)

patch_aggregator
mfvit.frequencies_mask
mfvit.vit.clip.positional_embedding
mfvit.vit.clip.text_projection
mfvit.vit.clip.logit_scale
mfvit.vit.clip.visual.class_embedding
mfvit.vit.clip.visual.positional_embedding
mfvit.vit.clip.visual.proj
mfvit.vit.clip.visual.conv1.weight
mfvit.vit.clip.visual.ln_pre.weight
mfvit.vit.clip.visual.ln_pre.bias
mfvit.vit.clip.visual.transformer.resblocks.0.attn.in_proj_weight
mfvit.vit.clip.visual.transformer.resblocks.0.attn.in_proj_bias
mfvit.vit.clip.visual.transformer.resblocks.0.attn.out_proj.weight
mfvit.vit.clip.visual.transformer.resblocks.0.attn.out_proj.bias
mfvit.vit.clip.visual.transformer.resblocks.0.ln_1.weight
mfvit.vit.clip.visual.transformer.resblocks.0.ln_1.bias
mfvit.vit.clip.visual.transformer.resblocks.0.mlp.c_fc.weight
mfvit.vit.clip.visual.transformer.resblocks.0.mlp.c_fc.bias
mfvit.vit.clip.visual.transformer.resblocks.0.mlp.c_proj.weight
mfvit.vit.clip.visual.transformer.resblocks.0.mlp.c_proj.bias
mfvit.vit.clip.visual.transfo

In [7]:
# NEW KEYS: new version of PatchBasedMFViT with the semantic cross-attention parameters
new_model_weights = model.state_dict()
# print new model weights keys
for key in new_model_weights.keys():
    print(key)
    # if "vit" in key:
    #     print(key)
    #     print(model_weights[key].shape)
    #     print(model.state_dict()[key].shape)
    #     assert model_weights[key].shape == model.state_dict()[key].shape, f"Shape mismatch for {key}"

patch_aggregator
mfvit.frequencies_mask
mfvit.vit.clip.positional_embedding
mfvit.vit.clip.text_projection
mfvit.vit.clip.logit_scale
mfvit.vit.clip.visual.class_embedding
mfvit.vit.clip.visual.positional_embedding
mfvit.vit.clip.visual.proj
mfvit.vit.clip.visual.conv1.weight
mfvit.vit.clip.visual.ln_pre.weight
mfvit.vit.clip.visual.ln_pre.bias
mfvit.vit.clip.visual.transformer.resblocks.0.attn.in_proj_weight
mfvit.vit.clip.visual.transformer.resblocks.0.attn.in_proj_bias
mfvit.vit.clip.visual.transformer.resblocks.0.attn.out_proj.weight
mfvit.vit.clip.visual.transformer.resblocks.0.attn.out_proj.bias
mfvit.vit.clip.visual.transformer.resblocks.0.ln_1.weight
mfvit.vit.clip.visual.transformer.resblocks.0.ln_1.bias
mfvit.vit.clip.visual.transformer.resblocks.0.mlp.c_fc.weight
mfvit.vit.clip.visual.transformer.resblocks.0.mlp.c_fc.bias
mfvit.vit.clip.visual.transformer.resblocks.0.mlp.c_proj.weight
mfvit.vit.clip.visual.transformer.resblocks.0.mlp.c_proj.bias
mfvit.vit.clip.visual.transfo

In [8]:
# check if the numbers from the new model weights are the same as the old ones
for key in model_weights.keys():
    # if "vit" in key:
    #     print(key)
    #     print(model_weights[key].shape)
    #     print(model.state_dict()[key].shape)
    #     assert torch.allclose(
    #         model_weights[key], model.state_dict()[key]
    #     ), f"Values mismatch for {key}"
    assert torch.allclose(
        model_weights[key], model.state_dict()[key]
    ), f"Values mismatch for {key}"

In [9]:
# count the number of parameters added to the new model: new_model_weights - model_weights
added_parameters = sum(new_model_weights[key].numel() for key in new_model_weights.keys() if key not in model_weights)

print(f"Number of parameters added to the new model: {added_parameters}")
print(f"Number of parameters in the new model: {sum(p.numel() for p in model.parameters())}")

# freeze only the original model weights (not the new ones) - the new ones are trainable
# also leave the classification head trainable
for name, param in model.named_parameters():
	if name in model_weights:
		# print(f"Freezing {name}")
		param.requires_grad = False
	if "cls_head" in name:
		param.requires_grad = True
		# print(f"Unfreezing {name}")

Number of parameters added to the new model: 58104
Number of parameters in the new model: 172213010


In [10]:
from torchinfo import summary

# Print summary
summary(model)

Layer (type:depth-idx)                                                                Param #
PatchBasedMFViT                                                                       1,536
├─MFViT: 1-1                                                                          50,176
│    └─CLIPBackbone: 2-1                                                              --
│    │    └─CLIP: 3-1                                                                 (149,620,737)
│    └─FrequencyRestorationEstimator: 2-2                                             --
│    │    └─FeatureSpecificProjector: 3-2                                             (22,087,680)
│    └─Normalize: 2-3                                                                 --
├─Softmax: 1-2                                                                        --
├─Dropout: 1-3                                                                        --
├─Linear: 1-4                                                                