In [1]:
import torch
import torch.nn as nn
from functools import partial
from modelling_finetune import get_vit_config, LongViTForClassification
from models.cmil import CMILModel, FeatureExtractor, SliceFusionTransformer

device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

In [None]:
config = get_vit_config(img_size=(384,512,512), patch_size=(8,32,32), embed_dim=384, depth=12, num_heads=16,
                        norm_layer=partial(nn.LayerNorm, eps=1e-6))

v = LongViTForClassification(config, num_classes=14).to(device).half()

ct = torch.randn(4, 1, 384, 512, 512, device=device).half()
preds = v(ct)
print(preds.shape)

Number of patches: 12288
Using Torchscale LongNetEncoder
torch.Size([4, 14])


In [3]:
preds

tensor([[-0.0820, -0.2520,  0.3376, -0.1135, -0.1403, -0.0159, -0.0265,  0.1301,
         -0.2076, -0.1617, -0.1147, -0.1781,  0.2322, -0.3887],
        [-0.0188,  0.1843,  0.0089, -0.0826, -0.3115, -0.0507,  0.2876,  0.3669,
         -0.0900, -0.0640, -0.1317, -0.0127,  0.1244,  0.1436],
        [ 0.0865,  0.5117, -0.0729, -0.4053, -0.0200,  0.2365,  0.2277, -0.2998,
          0.1593, -0.0018, -0.2058, -0.4873,  0.0013,  0.0353],
        [ 0.3340,  0.3708, -0.0269,  0.0812, -0.4121, -0.0415,  0.1624,  0.0815,
         -0.2766,  0.1267,  0.0272,  0.1121,  0.3857,  0.0342]],
       device='cuda:2', dtype=torch.float16, grad_fn=<AddmmBackward0>)

In [2]:
embed_dim = 384  # Must match the embed_dim in FeatureExtractor (DINOv2 output)
num_heads = 16
hidden_dim = 2048
num_layers = 1
patch_size = 1  # Patch size for the SliceFusionTransformer

max_seq_len = 256  # Adjust based on your data
transformer_model = SliceFusionTransformer(
    seq_len=max_seq_len,
    embed_dim=embed_dim,
    num_heads=num_heads,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    patch_size=patch_size
)

model = CMILModel(FeatureExtractor(model_name='dinov2_vits14'), transformer_model).to(device).half()

ct = torch.randn(4, 64, 3, 224, 224, device=device).half()
preds = model(ct)
print(preds.shape)

Using cache found in /home/than/.cache/torch/hub/facebookresearch_dinov2_main


torch.Size([4, 384])
