In [1]:
import torch

from hescape.models import CLIPModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Keyword arguments for CLIPModel, collected in one mapping so the configuration
# can be inspected separately from the (slow) model construction.
clip_model_kwargs = dict(
    input_genes=280,
    embed_dim=128,
    img_enc_name="uni",
    gene_enc_name="nicheformer",
    loss="CLIP",
    img_finetune=True,
    gene_finetune=False,
    img_proj="mlp",
    gene_proj="linear",
    # NOTE(review): absolute, cluster-specific paths -- the notebook only runs
    # where these directories exist; consider a configurable base directory.
    img_enc_path="/p/project1/hai_spatial_clip/pretrain_weights/image",
    gene_enc_path="/p/project1/hai_spatial_clip/pretrain_weights/gene",
    # drvi_model_dir="drvi_human_breast_panel"
)

# Build the model; construction prints a weight-loading message and a
# trainable-parameter summary for each encoder (see the cell output).
model = CLIPModel(**clip_model_kwargs)

Successfully loaded weights for uni
Successfully loaded weights for nicheformer
uni trainable params: 1.49M || all params: 304.84M || trainable%: 0.4887118439246093
nicheformer trainable params: 0.07M || all params: 38.59M || trainable%: 0.1701735239188342


In [3]:
# Display the module tree of the freshly constructed model (bare last
# expression, so the notebook renders its repr).
model

CLIPModel(
  (image_encoder): ImageEncoder(
    (trunk): PeftModel(
      (base_model): LoraModel(
        (model): VisionTransformer(
          (patch_embed): PatchEmbed(
            (proj): lora.Conv2d(
              (base_layer): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
              (lora_dropout): ModuleDict(
                (default): Dropout(p=0.1, inplace=False)
              )
              (lora_A): ModuleDict(
                (default): Conv2d(3, 8, kernel_size=(16, 16), stride=(16, 16), bias=False)
              )
              (lora_B): ModuleDict(
                (default): Conv2d(8, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
              )
              (lora_embedding_A): ParameterDict()
              (lora_embedding_B): ParameterDict()
              (lora_magnitude_vector): ModuleDict()
            )
            (norm): Identity()
          )
          (pos_drop): Dropout(p=0.0, inplace=False)
          (patch_drop): Identity()
          (norm_

In [4]:
# Checkpoint file to restore from.
# NOTE(review): absolute, cluster-specific path -- adjust when running elsewhere.
checkpoint_path = "/p/project1/hai_spatial_clip/outputs/human_breast_panel/11653143_11_11-uni-nicheformer-batch64-CLIP/checkpoints/last.ckpt"

# Load onto the CPU and keep only the "state_dict" entry of the checkpoint.
# NOTE(review): torch.load unpickles the file (subject to the installed torch
# version's `weights_only` default), so only load checkpoints from a trusted source.
state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))["state_dict"]

In [5]:
# Inspect the parameter names stored in the checkpoint; the keys shown in the
# output start with a "model." prefix.
state_dict.keys()

odict_keys(['model.logit_scale', 'model.image_encoder.trunk.base_model.model.cls_token', 'model.image_encoder.trunk.base_model.model.pos_embed', 'model.image_encoder.trunk.base_model.model.patch_embed.proj.base_layer.weight', 'model.image_encoder.trunk.base_model.model.patch_embed.proj.base_layer.bias', 'model.image_encoder.trunk.base_model.model.patch_embed.proj.lora_A.default.weight', 'model.image_encoder.trunk.base_model.model.patch_embed.proj.lora_B.default.weight', 'model.image_encoder.trunk.base_model.model.blocks.0.norm1.weight', 'model.image_encoder.trunk.base_model.model.blocks.0.norm1.bias', 'model.image_encoder.trunk.base_model.model.blocks.0.attn.qkv.base_layer.weight', 'model.image_encoder.trunk.base_model.model.blocks.0.attn.qkv.base_layer.bias', 'model.image_encoder.trunk.base_model.model.blocks.0.attn.qkv.lora_A.default.weight', 'model.image_encoder.trunk.base_model.model.blocks.0.attn.qkv.lora_B.default.weight', 'model.image_encoder.trunk.base_model.model.blocks.0.attn

In [6]:
# Re-key the checkpoint without the leading "model." prefix. `removeprefix`
# only strips the prefix at the very start of a key, so inner ".model."
# segments (e.g. "...base_model.model.cls_token") are left intact.
cleaned_state_dict = dict(
    (param_name.removeprefix("model."), param_value)
    for param_name, param_value in state_dict.items()
)
# Show the renamed keys for a visual check.
cleaned_state_dict.keys()

dict_keys(['logit_scale', 'image_encoder.trunk.base_model.model.cls_token', 'image_encoder.trunk.base_model.model.pos_embed', 'image_encoder.trunk.base_model.model.patch_embed.proj.base_layer.weight', 'image_encoder.trunk.base_model.model.patch_embed.proj.base_layer.bias', 'image_encoder.trunk.base_model.model.patch_embed.proj.lora_A.default.weight', 'image_encoder.trunk.base_model.model.patch_embed.proj.lora_B.default.weight', 'image_encoder.trunk.base_model.model.blocks.0.norm1.weight', 'image_encoder.trunk.base_model.model.blocks.0.norm1.bias', 'image_encoder.trunk.base_model.model.blocks.0.attn.qkv.base_layer.weight', 'image_encoder.trunk.base_model.model.blocks.0.attn.qkv.base_layer.bias', 'image_encoder.trunk.base_model.model.blocks.0.attn.qkv.lora_A.default.weight', 'image_encoder.trunk.base_model.model.blocks.0.attn.qkv.lora_B.default.weight', 'image_encoder.trunk.base_model.model.blocks.0.attn.proj.base_layer.weight', 'image_encoder.trunk.base_model.model.blocks.0.attn.proj.ba

In [7]:
# Load the re-keyed weights non-strictly so that key mismatches are reported
# rather than raised; both lists being empty means every key matched.
load_report = model.load_state_dict(cleaned_state_dict, strict=False)
missing, unexpected = load_report
# One list per line: missing keys first, then unexpected keys.
print(missing, unexpected, sep="\n")

[]
[]


In [8]:
# Keep a handle on the image branch of the model only, then display its
# module tree (bare last expression renders the repr).
image_encoder = model.image_encoder
image_encoder

ImageEncoder(
  (trunk): PeftModel(
    (base_model): LoraModel(
      (model): VisionTransformer(
        (patch_embed): PatchEmbed(
          (proj): lora.Conv2d(
            (base_layer): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Conv2d(3, 8, kernel_size=(16, 16), stride=(16, 16), bias=False)
            )
            (lora_B): ModuleDict(
              (default): Conv2d(8, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (norm): Identity()
        )
        (pos_drop): Dropout(p=0.0, inplace=False)
        (patch_drop): Identity()
        (norm_pre): Identity()
        (blocks): Sequential(
          (0): Block(
         

In [11]:
# Smoke-test the image encoder on a random batch: 8 images, 3 channels, 224x224.
# torch.rand samples uniformly from [0, 1), the same range as the legacy
# torch.Tensor(...).uniform_() construction, without the uninitialised
# intermediate tensor.
dummy_input = torch.rand(8, 3, 224, 224)

# Put the encoder in eval mode so the LoRA dropout layers shown in the module
# tree (Dropout(p=0.1)) are disabled and the forward pass is deterministic for
# a given input. Call image_encoder.train() again before any fine-tuning.
image_encoder.eval()

# A shape check needs no gradients; skipping autograd saves memory on the
# large ViT trunk.
with torch.no_grad():
    output = image_encoder(dummy_input)
print(output.shape)  # Output shape: [batch_size, num_features]

torch.Size([8, 128])
