In [1]:
import torch
import torch.nn as nn
from functools import partial
from modelling_finetune import get_vit_config, LongViTForClassification
from models.cmil import CMILModel, FeatureExtractor, SliceFusionTransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
config = get_vit_config(img_size=(64,256,256), patch_size=(4,16,16), embed_dim=384, depth=12, num_heads=16,
                        norm_layer=partial(nn.LayerNorm, eps=1e-6))

v = LongViTForClassification(config, num_classes=14).to(device).half()

ct = torch.randn(4, 1, 64, 256, 256, device=device).half()
preds = v(ct)
print(preds.shape)

Number of patches: 4096
Using Torchscale LongNetEncoder
torch.Size([4, 14])


In [7]:
v.model.patch_embed.num_patches

4096

In [6]:
preds

tensor([[ 0.0527, -0.0679,  0.0416, -0.4260, -0.0988,  0.2468, -0.0457, -0.1605,
          0.0790, -0.0021, -0.3213, -0.0792,  0.1904,  0.1860],
        [ 0.2742, -0.1193,  0.2974, -0.0947,  0.1022,  0.1819, -0.1227, -0.1455,
         -0.0143,  0.2477, -0.2988,  0.4233, -0.0638,  0.2346],
        [-0.1184, -0.0112,  0.1346, -0.1506,  0.0205,  0.2908,  0.0127,  0.3157,
          0.0383, -0.0862,  0.0596, -0.0310,  0.0005, -0.0715],
        [ 0.0455,  0.1278,  0.1597, -0.0884, -0.0680,  0.0740, -0.0955, -0.1681,
         -0.2808,  0.0016, -0.0254, -0.0112, -0.0260, -0.1202]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward0>)

In [2]:
embed_dim = 384  # Must match the embed_dim in FeatureExtractor (DINOv2 output)
num_heads = 16
hidden_dim = 2048
num_layers = 1
patch_size = 1  # Patch size for the SliceFusionTransformer

max_seq_len = 256  # Adjust based on your data
transformer_model = SliceFusionTransformer(
    seq_len=max_seq_len,
    embed_dim=embed_dim,
    num_heads=num_heads,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    patch_size=patch_size
)

model = CMILModel(FeatureExtractor(model_name='dinov2_vits14'), transformer_model).to(device).half()

ct = torch.randn(4, 64, 3, 224, 224, device=device).half()
preds = model(ct)
print(preds.shape)

Using cache found in /home/than/.cache/torch/hub/facebookresearch_dinov2_main


torch.Size([4, 384])
