In [29]:
# Model code adapted from open_clip's timm adapter (class renamed to CLIPTimmModel; attention pooling disabled):
# https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/timm_model.py
# Config field definitions: https://github.com/mlfoundations/open_clip/blob/b2f1403605aade5a004434076246b6bc741aa47d/src/open_clip/model.py#L27


import functools
import json
import logging
from collections import OrderedDict

import torch
import torch.nn as nn

import timm
from timm.layers import Mlp, to_2tuple
#from timm.layers import RotAttentionPool2d 
#from timm.layers import AttentionPool2d as AbsAttentionPool2d

In [34]:
# open_clip-style model config (embed_dim / vision_cfg / text_cfg).
CONFIG_PATH = "custom-vit3.json"

# JSON is UTF-8; don't depend on the platform's default encoding.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = json.load(f)

# Fail fast if the keys used by the cells below are missing.
assert {"embed_dim", "vision_cfg"} <= config.keys(), f"unexpected config keys: {sorted(config)}"

config

{'embed_dim': 768,
 'custom_text': True,
 'vision_cfg': {'image_size': 224,
  'timm_model_name': 'vit_base_patch16_224',
  'timm_model_pretrained': False,
  'timm_proj': 'mlp',
  'timm_drop_path': 0.1},
 'text_cfg': {'context_length': 512,
  'vocab_size': 49408,
  'width': 768,
  'heads': 8,
  'layers': 12}}

In [2]:
class CLIPTimmModel(nn.Module):
    """timm image-tower adapter (notebook copy of open_clip's timm adapter).

    Builds a timm backbone as ``self.trunk`` and an optional projection head as
    ``self.head`` (an ``nn.Sequential``). Submodule names are kept identical to
    the upstream adapter: ``trunk``, then ``head.drop``/``head.proj`` (linear) or
    ``head.mlp`` (MLP). This lets an open_clip checkpoint's ``visual.*`` weights
    load once the ``visual.`` prefix is stripped.

    Args:
        model_name: timm architecture name passed to ``timm.create_model``.
        embed_dim: size of the output embedding.
        image_size: int or (h, w); stored as a 2-tuple in ``self.image_size``.
            It is not forwarded to timm.
        pool: global pooling type for the timm trunk. ``'abs_attn'`` and
            ``'rot_attn'`` are rejected because the attention-pool layers are
            not imported in this notebook.
        proj: ``'linear'`` or ``'mlp'`` adds a projection head on top of the
            trunk features. ``'none'`` returns the trunk features unchanged. Any
            falsy value reuses timm's own classifier (sized ``embed_dim``) as
            the projection.
        proj_bias: whether the final projection layer has a bias.
        drop: dropout rate used inside the projection head.
        drop_path: forwarded to timm as ``drop_path_rate`` when not None.
        patch_drop: forwarded to timm as ``patch_drop_rate`` when not None.
        pretrained: forwarded to ``timm.create_model``.

    Raises:
        NotImplementedError: if ``pool`` is ``'abs_attn'`` or ``'rot_attn'``.
        AssertionError: if ``proj`` is truthy but not ``'linear'``, ``'mlp'`` or ``'none'``.
    """

    def __init__(
            self,
            model_name,
            embed_dim,
            image_size=224,
            pool='avg',
            proj='linear',
            proj_bias=False,
            drop=0.,
            drop_path=None,
            patch_drop=None,
            pretrained=False,
    ):
        super().__init__()

        # Validate options before building the trunk.
        if pool in ('abs_attn', 'rot_attn'):
            # Upstream appends AbsAttentionPool2d / RotAttentionPool2d to the head for
            # these pools. Those imports are disabled here, so building the model would
            # strip the trunk's pooling and feed unpooled feature maps to the head.
            raise NotImplementedError(
                f"pool={pool!r} needs attention-pool layers that are not imported in this notebook"
            )
        if proj:
            assert proj in ("linear", "mlp", "none")

        # `timm` is imported unconditionally at the top of the notebook; check kept
        # for parity with upstream, where the import is optional.
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")
        self.image_size = to_2tuple(image_size)

        # Only forward kwargs that were explicitly set; not every timm model accepts them.
        timm_kwargs = {}
        if drop_path is not None:
            timm_kwargs['drop_path_rate'] = drop_path
        if patch_drop is not None:
            timm_kwargs['patch_drop_rate'] = patch_drop

        extra_proj = proj in ("linear", "mlp")
        if not extra_proj:
            # No explicit projection: timm's classifier head (sized embed_dim) acts as
            # the projection. proj == 'none' sets num_classes=0, i.e. raw pooled features.
            proj_dim = 0 if proj == 'none' else embed_dim
            self.trunk = timm.create_model(
                model_name,
                num_classes=proj_dim,
                global_pool=pool,
                pretrained=pretrained,
                **timm_kwargs,
            )
            # Unused below in this branch (the head stays empty).
            prev_chs = embed_dim
        else:
            self.trunk = timm.create_model(
                model_name,
                pretrained=pretrained,
                **timm_kwargs,
            )
            # Remove the classifier. Reset global pool only if one was requested,
            # otherwise keep the network default.
            reset_kwargs = dict(global_pool=pool) if pool else {}
            self.trunk.reset_classifier(0, **reset_kwargs)
            prev_chs = self.trunk.num_features

        # Key names ('drop', 'proj', 'mlp') determine state_dict keys; keep them stable.
        head_layers = OrderedDict()
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))

        self.head = nn.Sequential(head_layers)

    def forward(self, x):
        """Return ``self.head(self.trunk(x))``."""
        x = self.trunk(x)
        x = self.head(x)
        return x

In [35]:
vision_cfg = config['vision_cfg']

# Forward every vision_cfg field the adapter understands (as open_clip's vision-tower
# builder does). Fallbacks equal CLIPTimmModel's own defaults for keys absent from the JSON.
model = CLIPTimmModel(
    model_name=vision_cfg['timm_model_name'],
    embed_dim=config['embed_dim'],
    image_size=vision_cfg.get('image_size', 224),
    pool=vision_cfg.get('timm_pool', 'avg'),
    proj=vision_cfg.get('timm_proj', 'linear'),
    proj_bias=vision_cfg.get('timm_proj_bias', False),
    drop=vision_cfg.get('timm_drop', 0.),
    drop_path=vision_cfg.get('timm_drop_path'),
    # A patch_dropout of 0 means "disabled", which the adapter expresses as None.
    patch_drop=vision_cfg.get('patch_dropout') or None,
    pretrained=vision_cfg.get('timm_model_pretrained', False),
)

In [36]:
# Display the module tree: `trunk` is the timm backbone, `head` the projection Sequential.
# NOTE(review): the saved output is headed `TimmModel(` rather than `CLIPTimmModel(`,
# so it looks stale (produced by an earlier class name) — re-run to refresh.
model

TimmModel(
  (trunk): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
  

In [9]:
# open_clip training checkpoint; only its "state_dict" entry is used here.
CHECKPOINT_PATH = "epoch_10.pt"

# torch.load is pickle-based. weights_only=True restricts unpickling to tensors and
# plain containers instead of running arbitrary code from the file.
# NOTE(review): needs torch >= 1.13; only drop the flag for a file you trust.
weights = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"), weights_only=True)
state_dict = weights['state_dict']

# Checkpoints saved from a DistributedDataParallel-wrapped model prefix every key with "module.".
DDP_PREFIX = "module."
if state_dict and all(k.startswith(DDP_PREFIX) for k in state_dict):
    state_dict = {k[len(DDP_PREFIX):]: v for k, v in state_dict.items()}

# Keep only image-tower weights and strip exactly the leading "visual." prefix.
# A substring test / str.replace would also match or rewrite inner occurrences.
VISUAL_PREFIX = "visual."
state_dict = {
    k[len(VISUAL_PREFIX):]: v
    for k, v in state_dict.items()
    if k.startswith(VISUAL_PREFIX)
}

model.load_state_dict(state_dict)

In [28]:
# Smoke test: a random image batch must map to one embed_dim vector per image.
torch.manual_seed(0)  # reproducible dummy batch

# eval() turns off stochastic depth (timm_drop_path) and dropout. In train mode the
# forward pass would be random even for a fixed input.
model.eval()

# model.image_size is the (h, w) tuple the adapter stored at construction.
x = torch.randn(10, 3, *model.image_size)
with torch.no_grad():  # inference only: skip autograd bookkeeping
    embeddings = model(x)

assert embeddings.shape == (10, config['embed_dim'])
embeddings.shape

torch.Size([10, 768])

In [94]:
from transformers import AutoTokenizer
from open_clip.tokenizer import SimpleTokenizer

@functools.lru_cache(maxsize=None)
def _get_simple_tokenizer():
    """Build open_clip's SimpleTokenizer once and reuse it across calls."""
    return SimpleTokenizer()


def check_simple_tokenization(word):
    """Print and return the sub-word pieces open_clip's SimpleTokenizer produces for ``word``.

    Each token id is decoded on its own, spaces are removed from the decoded
    piece, and pieces that end up empty are dropped.

    Args:
        word: text to tokenize.

    Returns:
        list[str]: the non-empty decoded pieces, in token order.
    """
    tokenizer = _get_simple_tokenizer()

    encoding = tokenizer.encode(word)
    # Strip spaces first, then filter, so a whitespace-only piece is dropped
    # rather than kept as "".
    decoded_output = [tokenizer.decode([e]).replace(" ", "") for e in encoding]
    decoded_output = [token for token in decoded_output if token]

    print(f"Word: {word}")
    print(f"Decoded Text: {decoded_output}")

    return decoded_output

@functools.lru_cache(maxsize=None)
def _load_hf_tokenizer(model_name: str):
    """Load a Hugging Face tokenizer once per model name and reuse it."""
    return AutoTokenizer.from_pretrained(model_name)


def check_hf_tokenization(model_name: str, word: str):
    """Print and return the per-token pieces a Hugging Face tokenizer produces for ``word``.

    Special tokens are added during encoding. Decoding with
    ``skip_special_tokens=True`` turns them into empty strings, which are then
    dropped. WordPiece continuation markers (``##``) are kept as decoded.

    Args:
        model_name: Hugging Face Hub id of the tokenizer.
        word: text to tokenize.

    Returns:
        list[str]: the non-empty decoded pieces, in token order.
    """
    # Cached: the table cell calls this for many terms with the same two models.
    tokenizer = _load_hf_tokenizer(model_name)

    tokenized_output = tokenizer.encode(word, add_special_tokens=True)

    # Decode one id at a time to expose the sub-word split.
    decoded_output = [tokenizer.decode(token, skip_special_tokens=True) for token in tokenized_output]
    decoded_output = [token for token in decoded_output if token]

    print(f"Model: {model_name}")
    print(f"Word: {word}")
    print(f"Decoded Text: {decoded_output}")

    return decoded_output

In [95]:
text = "Pneumomediastinum"

# CLIP's BPE tokenizer first, then the two WordPiece vocabularies.
check_simple_tokenization(text)

for model_name in ("microsoft/BiomedVLP-CXR-BERT-general", "google-bert/bert-base-uncased"):
    hf_tokens = check_hf_tokenization(model_name, text)

# Show the pieces from the last (general BERT) tokenizer as the cell output.
hf_tokens

Word: Pneumomediastinum
Decoded Text: ['pneu', 'mom', 'edi', 'ast', 'in', 'um']
Model: microsoft/BiomedVLP-CXR-BERT-general
Word: Pneumomediastinum
Decoded Text: ['pneumomediastinum']
Model: google-bert/bert-base-uncased
Word: Pneumomediastinum
Decoded Text: ['p', '##ne', '##um', '##ome', '##dia', '##sti', '##num']


['p', '##ne', '##um', '##ome', '##dia', '##sti', '##num']

In [98]:
# Tokenizers compared in the table; order matches the CXR-BERT / BERT header columns.
# Defined here so the cell runs on a fresh kernel (they were previously undefined).
hf_model_1 = "microsoft/BiomedVLP-CXR-BERT-general"
hf_model_2 = "bert-base-uncased"  # legacy Hub alias of google-bert/bert-base-uncased used above

medical_terms = [
    "Cardiomediastinum", "Cardiomegaly", "Lung Lesion",
    "Lung Opacity", "Edema", "Consolidation", "Pneumonia", "Atelectasis",
    "Pneumothorax", "Pleural Effusion", "Pleural Other", "Fracture",
    "Support Devices", "Emphysema", "Fibrosis", "Hernia", "Infiltration",
    "Mass", "Nodule", "Pleural Thickening", "Pneumoperitoneum",
    "Pneumomediastinum", "Subcutaneous Emphysema", "Tortuous Aorta",
    "Calcification of the Aorta"
]

# Characters that must be backslash-escaped inside LaTeX text.
_LATEX_SPECIALS = {c: "\\" + c for c in "&%$#_{}"}


def _latex_escape(text):
    """Backslash-escape LaTeX special characters in plain text."""
    return "".join(_LATEX_SPECIALS.get(ch, ch) for ch in text)


def _latex_token_cell(tokens):
    """Hyphen-join tokens for one table cell, dropping WordPiece '##' continuation prefixes."""
    return _latex_escape("-".join(t[2:] if t.startswith("##") else t for t in tokens))


latex_rows = []

for term in medical_terms:
    simple_tokens = check_simple_tokenization(term)
    hf_tokens_1 = check_hf_tokenization(hf_model_1, term)
    hf_tokens_2 = check_hf_tokenization(hf_model_2, term)

    cells = [
        _latex_escape(term),
        _latex_token_cell(simple_tokens),
        _latex_token_cell(hf_tokens_1),
        _latex_token_cell(hf_tokens_2),
    ]
    latex_rows.append(" & ".join(cells) + r" \\")

# Construct the full LaTeX table
latex_table = r"""
\begin{table}[h!]
\centering
\begin{adjustbox}{width=1.2\textwidth,center=1.1\textwidth}
\begin{tabular}{|l|l|l|l|}
\hline
\textbf{Term} & \textbf{CLIP Tokenizer} & \textbf{CXR-BERT} & \textbf{BERT} \\
\hline
"""
latex_table += "\n".join(latex_rows)
latex_table += r"""
\hline
\end{tabular}
\end{adjustbox}
\caption{Tokenization Comparison for various CXR-related terms}
\label{tab:tokenization_results}
\end{table}
"""

# '##' markers are already removed per token, so the table is printed as-is.
print(latex_table)

Word: Cardiomediastinum
Decoded Text: ['cardi', 'ome', 'di', 'ast', 'in', 'um']
Model: microsoft/BiomedVLP-CXR-BERT-general
Word: Cardiomediastinum
Decoded Text: ['cardiomedias', '##tinum']
Model: bert-base-uncased
Word: Cardiomediastinum
Decoded Text: ['card', '##iom', '##ed', '##ias', '##tin', '##um']
Word: Cardiomegaly
Decoded Text: ['cardi', 'ome', 'gal', 'y']
Model: microsoft/BiomedVLP-CXR-BERT-general
Word: Cardiomegaly
Decoded Text: ['cardiomegaly']
Model: bert-base-uncased
Word: Cardiomegaly
Decoded Text: ['card', '##iom', '##ega', '##ly']
Word: Lung Lesion
Decoded Text: ['lung', 'le', 'sion']
Model: microsoft/BiomedVLP-CXR-BERT-general
Word: Lung Lesion
Decoded Text: ['lung', 'lesion']
Model: bert-base-uncased
Word: Lung Lesion
Decoded Text: ['lung', 'les', '##ion']
Word: Lung Opacity
Decoded Text: ['lung', 'op', 'acity']
Model: microsoft/BiomedVLP-CXR-BERT-general
Word: Lung Opacity
Decoded Text: ['lung', 'opacity']
Model: bert-base-uncased
Word: Lung Opacity
Decoded Text: ['