### Todo

1. Only implemented image captioning setup, and now implementing image-text-retrieval part.
3. Also, now only works with single token, we can learn multiple tokens later.
4. Currently does not accept any prompts, will do.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, Optional, Tuple, List, Dict
from PIL import Image
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import cv2
import pydicom as dicom
import os
from torchvision.models import resnet50
from glob import glob
import random
from transformers import OPTForCausalLM, AutoTokenizer, BioGptTokenizer, BioGptForCausalLM
from pathlib import Path
from skimage import io
import csv
import re
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class BioViL(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = resnet50()
        self._initialize_resnet()
        self.feature_extractor = self._get_feature_extractor()
        
    def _initialize_resnet(self):
        model_state_dict = torch.load("biovil_backbone_2048.pt")
        self.model.load_state_dict(model_state_dict)
    
    def _get_feature_extractor(self):
        self._return_nodes = {'avgpool': 'avgpool'}
        return create_feature_extractor(self.model, return_nodes=self._return_nodes)

    def forward(self, x):
        features = self.feature_extractor(x)["avgpool"]
        features = features.squeeze()
        return features

In [3]:
def remap_to_uint8(array: np.ndarray, percentiles: Optional[Tuple[float, float]] = None) -> np.ndarray:
    """Remap values in input so the output range is :math:`[0, 255]`.

    Percentiles can be used to specify the range of values to remap.
    This is useful to discard outliers in the input data.

    :param array: Input array.
    :param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
        Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
    :returns: Array with ``0`` and ``255`` as minimum and maximum values.
    """
    array = array.astype(float)
    if percentiles is not None:
        len_percentiles = len(percentiles)
        if len_percentiles != 2:
            message = 'The value for percentiles should be a sequence of length 2,' f' but has length {len_percentiles}'
            raise ValueError(message)
        a, b = percentiles
        if a >= b:
            raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
        if a < 0 or b > 100:
            raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
        cutoff: np.ndarray = np.percentile(array, percentiles)
        array = np.clip(array, *cutoff)
    array -= array.min()
    array /= array.max()
    array *= 255
    return array.astype(np.uint8)

In [4]:
def load_image(path: Path) -> Image.Image:
    """Load an image from disk.

    The image values are remapped to :math:`[0, 255]` and cast to 8-bit unsigned integers.

    :param path: Path to image.
    :returns: Image as ``Pillow`` ``Image``.
    """
    # Although ITK supports JPEG and PNG, we use Pillow for consistency with older trained models
    if path.suffix in [".jpg", ".jpeg", ".png"]:
        image = io.imread(path)
    elif path.suffix == ".dcm":
        image = dicom.dcmread(path).pixel_array
    else:
        raise ValueError(f"Image type not supported, filename was: {path}")

    image = remap_to_uint8(image)
    return Image.fromarray(image).convert("L")

In [5]:
class ExpandChannels:
    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        if data.shape[0] != 1:
            raise ValueError(f"Expected input of shape [1, H, W], found {data.shape}")
        return torch.repeat_interleave(data, 3, dim=0)

In [6]:
def preprocess_report(text):
    # Remove unnecessary and insensible parts
    text = re.sub(r"EXAMINATION:.*", "", text)  # Remove EXAMINATION line
    text = re.sub(r"WET READ:.*", "", text)  # Remove WET READ line
    text = re.sub(r"FINAL REPORT", "", text)  # Remove FINAL REPORT line
    text = re.sub(r"STUDY:.*", "", text)  # Remove STUDY line
    text = re.sub(r"COMPARISON:.*", "", text)  # Remove COMPARISON section
    text = re.sub(r"TECHNIQUE:.*", "", text)  # Remove TECHNIQUE section
    text = re.sub(r"_+", "_", text)  # Remove multiple underscores

    # Clean up excessive newlines and spaces
    text = re.sub(r"\s\s+", " ", text)
    text = re.sub(r" +", " ", text)
    text = text.strip()
    return text

In [7]:
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, RandomHorizontalFlip, RandomAffine

# values from BioViL repository
RESIZE = 512
CENTER_CROP_SIZE = 480

def create_chest_xray_transform_for_inference(resize: int, center_crop_size: int, train: bool) -> Compose:
    data_aug_rot = 15
    data_aug_trans = 0.10
    data_aug_scale = 0.10
    if not train:
        transforms = [Resize(resize), CenterCrop(center_crop_size), ToTensor(), ExpandChannels()]
    if train:
        transforms = [Resize(resize), 
                        RandomAffine(data_aug_rot, 
                            translate=(data_aug_trans, data_aug_trans), 
                            scale=(1.0-data_aug_scale, 1.0+data_aug_scale)
                        ),
                      CenterCrop(center_crop_size), ToTensor(), ExpandChannels()]
    return Compose(transforms)

In [9]:
class MIMICDataset(torch.utils.data.Dataset):
    IMG_ROOT = '/kuacc/users/oince22/hpc_run/physionet.org/files/mimic-cxr-jpg/2.0.0/files/'
    
    def __init__(self, tsv_fname, transform):
        self.tsv_fname = tsv_fname
        self.img_paths, self.reports = self._read_tsv_file()
        self.transform = transform

    def _read_tsv_file(self):
        reports = []
        img_paths = []
        with open(self.tsv_fname, "r") as f:
            reader = csv.reader(f, delimiter='\t')
            for report, img_path in reader:
                reports.append(report)
                img_paths.append(Path(MIMICDataset.IMG_ROOT + img_path))
                
        return img_paths, reports
        
    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        while True:
            try:
                img = load_image(self.img_paths[idx])
                text = self.reports[idx]
                transform_img = self.transform(img)
                return transform_img, text
            except:
                idx = np.random.randint(0, len(self.img_paths))
        

In [10]:
def setup_tokenizer(model_checkpoint, device):
    tokenizer = AutoTokenizer.from_pretrained(FromageModel.MODEL_CHECKPOINT, device=device)
    tokenizer.add_special_tokens({"cls_token": "<|image|>"})
    tokenizer.add_tokens("[RET]")
    ret_id = tokenizer('[RET]', add_special_tokens=False).input_ids
    assert len(ret_id) == 1, "Failed to add [RET] token to tokenizer"
    ret_token_idx = ret_id[0]
    return tokenizer, ret_token_idx

In [11]:
def contrastive_loss(logits):
    return F.cross_entropy(logits, torch.arange(len(logits), device=logits.device))

In [12]:
transform = create_chest_xray_transform_for_inference(RESIZE, CENTER_CROP_SIZE, train=True)
dataset = MIMICDataset("MIMIC_JPG.tsv", transform=transform)

In [13]:
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=84, shuffle=True, num_workers=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
class FromageModel(nn.Module):
    MODEL_CHECKPOINT = "facebook/opt-350m"
    VISION_EMBED_DIM = 2048
    VISION_EMBED_DROPOUT = 0.1
    SHARED_EMB_DIM = 512
    
    def __init__(self, device, tokenizer, ret_token_idx):
        super().__init__()
        self.ret_token_idx = ret_token_idx
        self.device = device
        self.tokenizer = tokenizer
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.image_token = self.tokenizer.cls_token_id
        
        self.lm = OPTForCausalLM.from_pretrained(FromageModel.MODEL_CHECKPOINT)
        for param in self.lm.parameters():
            param.requires_grad = False
        
        self.vm = BioViL()
        for param in self.vm.parameters():
            param.requires_grad = False

        self.lm.resize_token_embeddings(len(self.tokenizer))
        self.input_embeddings = self.lm.get_input_embeddings()
        
        self.LM_EMBED_DIM = self.input_embeddings.embedding_dim
        
        self.vm.eval()
        self.lm.eval()
        
        # currently only works with one token, we will generalize it to multiple tokens later
        self.caption_mapping = nn.Linear(FromageModel.VISION_EMBED_DIM, self.LM_EMBED_DIM)
        self.mapping_dropout = nn.Dropout(FromageModel.VISION_EMBED_DROPOUT)

        self.ret_i2t_mapping = nn.Linear(FromageModel.VISION_EMBED_DIM, FromageModel.SHARED_EMB_DIM)
        self.ret_t2i_mapping = nn.Linear(self.LM_EMBED_DIM, FromageModel.SHARED_EMB_DIM)
        
    def generate(self, embeddings, max_len, temperature=0.0, top_p=1.0, filter_value=float("-inf")):
        bsz, seq_len, _ = embeddings.shape
        out = None
        past_key_values = None
        output_embeddings = []
        output_logits = []

        with torch.no_grad():
            for i in range(max_len):
                output = self.lm(inputs_embeds=embeddings, use_cache=False, output_hidden_states=True)
                last_hidden_state = output.hidden_states[-1]
                last_hidden_state = last_hidden_state[torch.arange(last_hidden_state.shape[0]), seq_len-1, :]
                last_hidden_state = self.ret_t2i_mapping(last_hidden_state)
                last_embedding = last_hidden_state / last_hidden_state.norm(dim=-1, keepdim=True)
    
                output_embeddings.append(last_embedding)
                logits = output.logits[:,-1,:] # todo, look at here again
                output_logits.append(logits)
    
                if temperature == 0.0:
                    if top_p != 1.0:
                        assert False, "top_p cannot be set in greedy decoding"
                    next_token = torch.argmax(logits, keepdim=True, dim=-1)
                else:
                    logits = logits / temperature

                if top_p < 1.0:
                    assert top_p > 0, f"0 < top_p <= 1"
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    for j in range(sorted_indices.shape[0]):
                        indices_to_remove = sorted_indices[j, sorted_indices_to_remove[j, :]]
                        logits[j, indices_to_remove] = filter_value
            
                    token_weights = logits.exp()
                    next_token = torch.multinomial(token_weights, 1)

                next_token = next_token.long().to(embeddings.device)
                if out is not None:
                    out = torch.cat([out, next_token], dim=1)
                else:
                    out = next_token


                next_embedding = self.input_embeddings(next_token)
                embeddings = torch.cat([embeddings, next_embedding], dim=1)

                if (self.tokenizer.eos_token_id and (next_token == self.tokenizer.eos_token_id).all()):
                    break

        return out, output_embeddings, output_logits

    def get_vis_embs(self, pixel_values, mode):
        pixel_values = pixel_values.to(device)
        img_embs = self.vm(pixel_values)

        if mode == "caption":
            img_embs = self.caption_mapping(img_embs)
            img_embs = self.mapping_dropout(img_embs)
        elif mode == "retrieval":
            img_embs = self.ret_i2t_mapping(img_embs)
            img_embs = self.mapping_dropout(img_embs)

        return img_embs
            
    
    def forward(self, pixel_values, text_inputs, mode):
        assert mode in ["caption", "retrieval"], f'Mode must be in ["caption", "retrieval"], got {mode} instead'
        
        if mode == "retrieval":
            new_text_inputs = []
            for i in range(len(text_inputs)):
                new_text_inputs.append(f'{text_inputs[i]}[RET]')
            text_inputs = tuple(new_text_inputs)
        
        text_inputs = self.tokenizer(text_inputs, return_tensors="pt", padding='max_length', truncation=True, max_length=112).to(device)
        text_lens = text_inputs.attention_mask.sum(dim=1)

        if mode == "retrieval":
            for idx in range(len(text_inputs.input_ids)):
                if text_inputs.input_ids[idx][text_lens[idx]-1] != self.ret_token_idx:
                    text_inputs.input_ids[idx][text_lens[idx]-1] = self.ret_token_idx

        t2i_embs, i2t_embs = None, None

        if mode == "caption":
            img_embs = self.get_vis_embs(pixel_values, mode=mode)
            img_embs = img_embs.unsqueeze(1)

            labels = text_inputs.input_ids
            text_embs = self.input_embeddings(labels)
            additional_mask = torch.ones(img_embs.shape[:2], dtype=torch.int64).to(self.device)
            attention_mask = torch.cat([additional_mask, text_inputs.attention_mask], dim=1)
        
            full_labels = torch.full(img_embs.shape[:2], -100).to(self.device)
            full_labels = torch.cat([full_labels, labels], dim=1)
            
            input_embs = torch.cat([img_embs, text_embs], dim=1)
        
            output = self.lm(inputs_embeds=input_embs, attention_mask=attention_mask, labels=full_labels, output_hidden_states=True)
        
        elif mode == "retrieval":
            i2t_embs = self.get_vis_embs(pixel_values, mode=mode)

            labels = text_inputs.input_ids
            text_embs = self.input_embeddings(labels)
            input_embs = text_embs

            output = self.lm(inputs_embeds=input_embs, attention_mask=text_inputs.attention_mask, labels=labels, output_hidden_states=True)

            t2i_embs = output.hidden_states[-1]
            t2i_embs = t2i_embs[torch.arange(t2i_embs.shape[0]), text_lens-1, :]
            t2i_embs = self.ret_t2i_mapping(t2i_embs)

            i2t_embs = i2t_embs / i2t_embs.norm(dim=1, keepdim=True)
            t2i_embs = t2i_embs / t2i_embs.norm(dim=1, keepdim=True)

            i2t_embs = self.logit_scale.exp() * i2t_embs
        
        return output, t2i_embs, i2t_embs

    def train(self, mode=True):
        super(FromageModel, self).train(mode=mode)
        self.vm.eval()
        self.lm.eval()

In [15]:
class Fromage(nn.Module):
    def __init__(self, device, tokenizer, ret_token_idx, resize=512, center_crop_size=480):
        super().__init__()
        self.device = device
        self.ret_token_idx = ret_token_idx
        self.model = FromageModel(device=device, tokenizer=tokenizer, ret_token_idx=ret_token_idx)
        self.img_transform = create_chest_xray_transform_for_inference(resize=resize, center_crop_size=center_crop_size, train=False)

    def __call__(self, images, tgt_tokens=None, generate=False, max_len=96, temperature=0.0, top_p=1.0, mode="caption", inference=False):
        if generate:
            return self.model.generate(embeddings=images, max_len=max_len, temperature=temperature, top_p=top_p)

        return self.model(pixel_values=images, text_inputs=tgt_tokens, mode=mode)

    def generate_for_images_and_texts(self, prompts: List, max_len=32, top_p=1.0, temperature=0.0):
        input_embs = []
        input_ids = []

        add_bos = True

        for i, p in enumerate(prompts):
            if isinstance(p, Path):
                img = load_image(p)
                pixel_values = self.img_transform(img)
                pixel_values = pixel_values[None, ...]
                vis_emb = self.model.get_vis_embs(pixel_values, mode="caption")
                vis_emb = vis_emb.unsqueeze(0).unsqueeze(0)
                input_embs.append(vis_emb)
            elif type(p) == str:
                tokens = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt")
                text_ids = tokens.input_ids.to(self.device)
                if not add_bos:
                    text_ids = text_ids[:, 1:]
                else:
                    add_bos = False

                text_embs = self.model.input_embeddings(text_ids)
                input_embs.append(text_embs)
                input_ids.append(text_ids)
            else:
                assert False, "Prompt type can only be Path for images or string for text"

        input_embs = torch.cat(input_embs, dim=1)
        input_ids = torch.cat(input_ids, dim=1)

        generated_ids, generated_embeddings, _ = self.model.generate(input_embs, max_len, temperature=temperature, top_p=top_p)

        return_outputs = self.model.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]        
        return return_outputs

### Question:

**Why do we use is `torch.zeros - 100` for labels of input embeddings?**

### Answer:

One way to handle this is to only train on the tag labels for the first subtoken of a split token. We can do this in 🤗 Transformers by setting the labels we wish to ignore to -100. In the example above, if the label for @HuggingFace is 3 (indexing B-corporation), we would set the labels of `['@', 'hugging', '##face']` to `[3, -100, -100]`.

Source: [https://huggingface.co/transformers/v4.4.2/custom_datasets.html](https://huggingface.co/transformers/v4.4.2/custom_datasets.html)

In [16]:
tokenizer, ret_token_idx = setup_tokenizer(FromageModel.MODEL_CHECKPOINT, device)

In [17]:
model = Fromage(device, tokenizer, ret_token_idx).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
EPOCH = 50
MODES = ("caption", "retrieval")

In [18]:
best_checkpoint = torch.load("best.ckpt")
model.load_state_dict(best_checkpoint["model_state_dict"])

<All keys matched successfully>

In [27]:
ex_idx = random.randint(0, len(dataset.img_paths) - 1)
ex_img_path, ex_report = dataset.img_paths[ex_idx], dataset.reports[ex_idx]
print(ex_report)

INDICATION: History: F with epigastric pain, vomiting FINDINGS: Heart size is mildly enlarged with evidence of prior mitral valve replacement. Clips are seen projecting over the right hilum. Mediastinal and hilar contours are unchanged with mild atherosclerotic calcifications noted diffusely. Lungs are hyperinflated but grossly clear without focal consolidation, pleural effusion or pneumothorax. The osseous structures are diffusely demineralized. Bilateral shoulder prostheses are partially imaged. IMPRESSION: No acute cardiopulmonary abnormality.


In [32]:
with torch.inference_mode():
    model.eval()
    prompts = [ex_img_path, "INDICATION: History: F with epigastric pain, vomiting FINDINGS: Heart size is "] # " ".join(ex_report.split()[:5])
    print(model.generate_for_images_and_texts(prompts))

iliac-pulmonary-ventricular (PVC) and ventricular (VVC) size is iliac-pulmonary-ventricular


In [None]:
logfile = open("log_6_cntd.txt", "w+")

model.train()
best_losses = {mode:0x7FFFFFF for mode in MODES}
for epoch in range(1, EPOCH+1):
    losses = {mode:0.0 for mode in MODES}
    batch_count = 0 
    for step, (pixels, text) in enumerate(train_dataloader):
        step_loss = {mode:None for mode in MODES}
        for mode in MODES:
            optimizer.zero_grad()
        
            output, i2t_embs, t2i_embs = model(pixels, text, mode=mode)

            loss = output.loss
            if mode == "retrieval":
                logits_per_image = i2t_embs @ t2i_embs.t()
                logits_per_text = logits_per_image.t()

                caption_loss = contrastive_loss(logits_per_text)
                image_loss = contrastive_loss(logits_per_image)
                
                loss += (caption_loss + image_loss) / 2.0

            loss.backward()

            for param in model.model.input_embeddings.parameters():
                assert param.grad.shape[0] == len(tokenizer), "Embedding and vocabulary sizes should be equal to each other"
                mask = torch.arange(param.grad.shape[0]) != ret_token_idx
                param.grad[mask,:] = 0.0

            nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            
            batch_count += 1
            losses[mode] += loss.item()
    
        if step % 15 == 0:
            is_best_loss = True
            for mode in MODES:
                total_loss_avg = losses[mode] / batch_count
                if total_loss_avg > best_losses[mode]:
                    is_best_loss = False

            logfile.write(f"Epoch {epoch} Step {step} -- ")
            print(f"Epoch {epoch} Step {step} -- ", end="")
            for m, l in losses.items():
                logfile.write(f"{m} Loss: {(l / batch_count):.3f} ")
                print(f"{m} Loss: {(l / batch_count):.3f}", end=" ")
            print()
            logfile.write("\n")
            best_loss_avg = total_loss_avg
            logfile.write("Current best loss!\n")
            print("Current best loss!")
            logfile.flush()

            if step % 105 == 0:
                if is_best_loss:
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': total_loss_avg
                    }, "best.ckpt")

logfile.close()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 1 Step 0 -- caption Loss: 2.076 retrieval Loss: 3.870 
Current best loss!
Epoch 1 Step 15 -- caption Loss: 2.075 retrieval Loss: 3.880 
Current best loss!
Epoch 1 Step 30 -- caption Loss: 2.067 retrieval Loss: 3.873 
Current best loss!
Epoch 1 Step 45 -- caption Loss: 2.064 retrieval Loss: 3.864 
Current best loss!
Epoch 1 Step 60 -- caption Loss: 2.060 retrieval Loss: 3.854 
Current best loss!
Epoch 1 Step 75 -- caption Los