In [7]:
import sys
import torch 
import torch.nn as nn
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from timm.models.vision_transformer import Block

sys.path.append('..')
sys.path.append('Data2Seq')
sys.path.append('../mimic_study')

from Data2Seq.Data2Seq import Data2Seq
from Data2Seq.Text import get_text_embeddings

from torch.utils.data import DataLoader
from mimic_study.data.datasets import MIMIC_DataSet, MIMIC_Multi_Modal_DataSet

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
transform = A.Compose(
    [
        A.Resize(256, 256, always_apply=True),
        A.RandomCrop(224, 224, always_apply=True),
        ToTensorV2(),
    ]
)

dataset = MIMIC_Multi_Modal_DataSet(
    path="/home/le/Projects/mimic_study/data/mimic-cxr-jpg-2.0.0-small", 
    label_file="/home/le/Projects/mimic_study/data/mimic-cxr-jpg-2.0.0-small/mimic-cxr-2.0.0-chexpert.csv",
    task="multilabel",
    transform=transform,
    target_label="No Finding",
    tokenize=False,
)

In [3]:
(image, text), label = dataset[2]
image: torch.Tensor = image.unsqueeze(0).type(torch.float32)

In [13]:
img_tokenier = Data2Seq(modality='x-ray',dim=768)
txt_tokenier = Data2Seq(modality='text',dim=768)

img_features = img_tokenier(image)
txt_features = txt_tokenier(text).unsqueeze(0)

features = torch.concat([img_features.to(device), txt_features.to(device)],dim=1)

In [14]:
txt_features.shape

torch.Size([1, 1, 768])

In [17]:
# For base-scale encoder:
ckpt = torch.load("Meta-Transformer_base_patch16_encoder.pth")
encoder = nn.Sequential(*[
            Block(
                dim=768,
                num_heads=12,
                mlp_ratio=4.,
                qkv_bias=True,
                norm_layer=nn.LayerNorm,
                act_layer=nn.GELU
            )
            for i in range(12)])
encoder.load_state_dict(ckpt,strict=True)
encoder.to(device)

Sequential(
  (0): Block(
    (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (attn): Attention(
      (qkv): Linear(in_features=768, out_features=2304, bias=True)
      (q_norm): Identity()
      (k_norm): Identity()
      (attn_drop): Dropout(p=0.0, inplace=False)
      (proj): Linear(in_features=768, out_features=768, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
    )
    (ls1): Identity()
    (drop_path1): Identity()
    (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (act): GELU(approximate='none')
      (drop1): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (fc2): Linear(in_features=3072, out_features=768, bias=True)
      (drop2): Dropout(p=0.0, inplace=False)
    )
    (ls2): Identity()
    (drop_path2): Identity()
  )
  (1): Block(
    (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (attn): Attention(
      

In [18]:
encoded_features = encoder(features)


In [19]:
encoded_features

tensor([[[ 61.7122, -43.1789,   8.7162,  ...,   7.5555,  -7.6878,  -7.0379],
         [ 57.1193, -63.3660, -15.0408,  ...,  -5.4871, -22.5931, -20.7037],
         [ 54.1568, -63.1641,  10.3465,  ..., -13.3596, -29.4774,  -4.6906],
         ...,
         [228.9890, -58.1208, -22.8079,  ...,  -3.1398, -74.1099,  25.7346],
         [224.1574, -62.9149, -17.1602,  ...,  -2.4421, -78.9703,  25.2712],
         [  0.5240,   2.3075,   3.3165,  ...,   5.0046,   0.6962,  -0.6744]]],
       device='cuda:0', grad_fn=<AddBackward0>)