In [None]:
from torch import nn
import torch.nn.functional as F
import torch
import pdb
import numpy as np

class ImageTextContrastiveLoss(nn.Module):
    '''Contrastive image-text loss built on top of a CLIP-style model.

    Two modes are supported by ``forward``:
    - hard loss: when either label tensor is missing, the wrapped model is
      asked to compute its own loss (``return_loss=True``);
    - soft loss: when both label tensors are given, the image/text label
      similarity is turned into soft targets for a symmetric cross entropy.
    '''
    def __init__(self, model):
        '''args:
        model: callable returning a dict with 'logits', 'img_embeds' and
            'loss_value'; it must also expose ``encode_text`` and
            ``compute_logits`` when augmented text is used.
        '''
        super().__init__()
        self.model = model

    def forward(self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        img_labels=None,
        text_labels=None,
        aug_input_ids=None,
        aug_attention_mask=None,
        **kwargs,
        ):
        '''args:
        input_ids, pixel_values, attention_mask: forwarded to the wrapped model
        img_labels: the image corresponds to which classes of diagnoses
        text_labels: the text corresponds to which classes of diagnoses
        aug_input_ids, aug_attention_mask: optional augmented text; when given,
            the soft loss is averaged over the original and augmented text

        returns:
        dict with the single key 'loss_value'
        '''
        if img_labels is None or text_labels is None:
            # use hard clip loss as the original clip
            outputs = self.model(
                    input_ids=input_ids,
                    pixel_values=pixel_values,
                    attention_mask=attention_mask,
                    return_loss=True,
                    )
            return {'loss_value': outputs['loss_value']}

        # use soft clip loss
        outputs = self.model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                return_loss=False,
                )
        logits = outputs['logits']

        # pairwise label agreement between every image and every text
        label_sim = self._label_similarity(img_labels, text_labels, logits)
        loss_value = self._soft_clip_loss(logits, label_sim)

        if aug_input_ids is not None:
            aug_text_embeds = self.model.encode_text(aug_input_ids, aug_attention_mask)
            logits_aug = self.model.compute_logits(outputs['img_embeds'], aug_text_embeds)
            aug_loss_value = self._soft_clip_loss(logits_aug, label_sim)
            loss_value = (aug_loss_value + loss_value) / 2

        return {'loss_value': loss_value}

    @staticmethod
    def _label_similarity(img_labels, text_labels, logits):
        '''dot product between image labels and text labels, on the device of logits.

        Both label tensors are moved to ``logits.device`` and brought to a
        common floating dtype *before* the product, so that integer labels or
        labels living on another device do not break matmul / softmax.
        '''
        img_labels = torch.as_tensor(img_labels)
        text_labels = torch.as_tensor(text_labels)
        common_dtype = torch.promote_types(img_labels.dtype, text_labels.dtype)
        if not common_dtype.is_floating_point:
            # softmax is not defined for integer tensors; follow the logits dtype
            common_dtype = logits.dtype
        img_labels = img_labels.to(device=logits.device, dtype=common_dtype)
        text_labels = text_labels.to(device=logits.device, dtype=common_dtype)
        return torch.matmul(img_labels, text_labels.T)

    def _soft_clip_loss(self, logits_per_img, soft_label):
        '''take labels of images and sentences as a softlabel
        e.g., image_label = [1, 0, 1, -1], sentence_label = [0, 0, 1, -1]
        this pair has similarity as: 1 * 0 + 0 * 0 + 1 * 1 + -1 * -1 = 2.
        The similarity is passed through a softmax (row-wise for images,
        column-wise for captions) to obtain the soft-label.
        NOTE: no clamping of the similarity is applied here.
        '''
        # InfoNCE-like loss, averaged over both matching directions
        image_loss = self._soft_xent_loss(logits_per_img, F.softmax(soft_label, 1))
        caption_loss = self._soft_xent_loss(logits_per_img.T, F.softmax(soft_label.T, 1))
        return (image_loss + caption_loss) / 2

    def _soft_xent_loss(self, input, target):
        '''cross entropy between softmax(input) and the soft target, averaged over rows.'''
        logprobs = torch.nn.functional.log_softmax(input, dim = 1)
        return  -(target * logprobs).sum() / input.shape[0]

    def _soft_bce_loss(self, input, target):
        '''multilabel alternative: binary cross entropy on raw logits (not used by forward).'''
        return nn.functional.binary_cross_entropy_with_logits(input, target)

In [None]:
class MedCLIPModel(nn.Module):
    '''CLIP-style model pairing a vision encoder with a text encoder.

    Image and text embeddings are L2-normalised and compared through a scaled
    dot product; the scale is a learnable parameter stored in log space.
    '''
    def __init__(self,
        vision_cls=MedCLIPVisionModel,
        checkpoint=None,
        vision_checkpoint=None,
        logit_scale_init_value=0.07,
        ) -> None:
        '''args:
        vision_cls: vision encoder class, MedCLIPVisionModel or MedCLIPVisionModelViT
        checkpoint: optional directory holding the full model weights
        vision_checkpoint: forwarded to the vision encoder as ``checkpoint``
        logit_scale_init_value: initial temperature; the parameter is log(1 / value)
        '''
        super().__init__()
        assert vision_cls in [MedCLIPVisionModel, MedCLIPVisionModelViT], 'vision_cls should be one of [MedCLIPVisionModel, MedCLIPVisionModelViT]'

        self.vision_model = vision_cls(checkpoint=vision_checkpoint)
        self.text_model = MedCLIPTextModel(proj_bias=False)

        # learnable temperature for contrastive loss
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1/logit_scale_init_value)))

        if checkpoint is not None:
            self._load_weights(checkpoint)

    @property
    def _device(self):
        '''device the model parameters currently live on (taken from logit_scale).'''
        return self.logit_scale.device

    def _load_weights(self, weight_dir):
        '''load the state dict stored under ``weight_dir`` into this model.'''
        import os
        # map to CPU first so checkpoints saved on GPU also load on CPU-only hosts;
        # load_state_dict copies the values onto the parameters' own device.
        state_dict = torch.load(os.path.join(weight_dir, constants.WEIGHTS_NAME), map_location='cpu')
        self.load_state_dict(state_dict)
        print('load model weight from:', weight_dir)

    def from_pretrained(self, input_dir=None):
        '''
        Load pretrained weights from ``input_dir``.
        If input_dir is None, a default directory depending on the vision
        encoder type is used. When the directory does not exist yet, it is
        created and the pretrained archive is downloaded and unzipped into it.

        raises:
        ValueError: if the vision encoder is neither ResNet nor ViT based.
        '''
        import os
        import wget
        import zipfile
        if isinstance(self.vision_model, MedCLIPVisionModel):
            # resnet
            pretrained_url = constants.PRETRAINED_URL_MEDCLIP_RESNET
            if input_dir is None:
                input_dir = './pretrained/medclip-resnet'
        elif isinstance(self.vision_model, MedCLIPVisionModelViT):
            # ViT
            pretrained_url = constants.PRETRAINED_URL_MEDCLIP_VIT
            if input_dir is None:
                input_dir = './pretrained/medclip-vit'
        else:
            raise ValueError(f'We only have pretrained weight for MedCLIP-ViT or MedCLIP-ResNet, get {type(self.vision_model)} instead.')

        # NOTE(review): the download is skipped whenever the directory exists,
        # even if a previous download failed half-way - confirm this is intended.
        if not os.path.exists(input_dir):
            os.makedirs(input_dir)

            # the constant URL serves a text body containing the actual download link
            pretrained_url = requests.get(pretrained_url).text
            filename = wget.download(pretrained_url, input_dir)

            # unzip; the context manager closes the archive even if extraction fails
            with zipfile.ZipFile(filename) as zipf:
                zipf.extractall(input_dir)
            print('\n Download pretrained model from:', pretrained_url)

        self._load_weights(input_dir)

    def encode_text(self, input_ids=None, attention_mask=None):
        '''encode tokenised text into L2-normalised embeddings.

        Inputs are moved to the model's device (instead of unconditionally to
        CUDA) so the model also runs on CPU-only machines.
        '''
        device = self._device
        input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        text_embeds = self.text_model(input_ids, attention_mask)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        return text_embeds

    def encode_image(self, pixel_values=None):
        '''encode images into L2-normalised embeddings (inputs moved to the model's device).'''
        pixel_values = pixel_values.to(self._device)
        vision_output = self.vision_model(pixel_values=pixel_values)
        img_embeds = vision_output / vision_output.norm(dim=-1, keepdim=True)
        return img_embeds

    def forward(self,
            input_ids=None,
            pixel_values=None,
            attention_mask=None,
            return_loss=None,
            **kwargs,
            ):
        '''encode both modalities and compute the image-text logits.

        returns:
        dict with 'img_embeds', 'text_embeds', 'logits' (image rows, text
        columns), 'logits_per_text' (its transpose) and 'loss_value', which is
        the hard CLIP loss when ``return_loss`` is truthy and None otherwise.
        '''
        # device placement is handled inside the encoders
        img_embeds = self.encode_image(pixel_values)
        text_embeds = self.encode_text(input_ids, attention_mask)

        logits_per_image = self.compute_logits(img_embeds, text_embeds)
        logits_per_text = logits_per_image.t()

        loss = self.clip_loss(logits_per_text) if return_loss else None

        return {'img_embeds':img_embeds, 'text_embeds':text_embeds,
            'logits':logits_per_image, 'loss_value':loss, 'logits_per_text':logits_per_text}

    def compute_logits(self, img_emb, text_emb):
        '''scaled dot product between image and text embeddings, image rows first.'''
        # keep the log-scale in [0, 4.6052], i.e. the multiplier exp(scale) in about [1, 100]
        self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, 4.6052)
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_emb, img_emb.t()) * logit_scale
        return logits_per_text.t()

    def clip_loss(self, similarity: torch.Tensor) -> torch.Tensor:
        '''symmetric contrastive loss: mean of the losses over rows and over columns.'''
        caption_loss = self.contrastive_loss(similarity)
        image_loss = self.contrastive_loss(similarity.T)
        return (caption_loss + image_loss) / 2.0

    def contrastive_loss(self, logits: torch.Tensor) -> torch.Tensor:
        '''cross entropy where row i is expected to match column i.'''
        return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))

In [None]:
# Sample phenotyping and image labels for a single stay / study.
# stay is 14851532_episode3_timeseries.csv
# stay_id is 30681041
# image name is 721e19bf-893cd83c-ea610180-ee56a931-b0b7c146
# study_id is 59956491
# phenotyping classes
# pheno classes: Acute and unspecified renal failure,Acute cerebrovascular disease,Acute myocardial infarction,Cardiac dysrhythmias,Chronic kidney disease,Chronic obstructive pulmonary disease and bronchiectasis,Complications of surgical procedures or medical care,Conduction disorders,Congestive heart failure; nonhypertensive,Coronary atherosclerosis and other heart disease,Diabetes mellitus with complications,Diabetes mellitus without complication,Disorders of lipid metabolism,Essential hypertension,Fluid and electrolyte disorders,Gastrointestinal hemorrhage,Hypertension with complications and secondary hypertension,Other liver diseases,Other lower respiratory disease,Other upper respiratory disease,Pleurisy; pneumothorax; pulmonary collapse,Pneumonia (except that caused by tuberculosis or sexually transmitted disease),Respiratory failure; insufficiency; arrest (adult),Septicemia (except in labor),Shock
# image classes: Atelectasis,Cardiomegaly,Consolidation,Edema,Enlarged Cardiomediastinum,Fracture,Lung Lesion,Lung Opacity,No Finding,Pleural Effusion,Pleural Other,Pneumonia,Pneumothorax,Support Devices
# NOTE(review): each vector below presumably follows the order of the matching class list above - confirm.
# NOTE(review): -1.0 in img_label presumably marks an uncertain finding - confirm against the dataset's label convention.
# NOTE(review): the vectors have different lengths (25 vs 14), so they cannot be fed as-is
# to a dot product such as torch.matmul(img_labels, text_labels.T).
pheno_label = [0,0,0,0,0,1,1,1,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,1,0] # 25 classes
img_label = [1.0,1.0,0,0,0,0,0,1.0,0,-1.0,0,0,0,1.0] # 14 classes

# To get the soft targets 
def _soft_clip_loss(self, logits_per_img, soft_label):
    '''symmetric soft-target contrastive loss.

    soft_label holds the pairwise label similarity between images (rows) and
    sentences (columns), e.g. image_label = [1, 0, 1, -1] and
    sentence_label = [0, 0, 1, -1] give 1 * 0 + 0 * 0 + 1 * 1 + -1 * -1 = 2.
    A softmax over the similarities yields the target distribution for each
    direction; the two cross entropies are then averaged.
    The similarity is used as given, without clamping.
    '''
    # target distributions: per image over sentences, per sentence over images
    img_targets = F.softmax(soft_label, 1)
    txt_targets = F.softmax(soft_label.T, 1)

    # InfoNCE-like loss in both matching directions
    loss_img = self._soft_xent_loss(logits_per_img, img_targets)
    loss_txt = self._soft_xent_loss(logits_per_img.T, txt_targets)
    return (loss_img + loss_txt) / 2

    # alternative kept for reference: multilabel bce loss
    # image_loss = self._soft_bce_loss(logits_per_img, soft_label)
    # return image_loss

def _soft_xent_loss(self, input, target):
    '''cross entropy against a soft target distribution.

    input: raw scores, normalised with a log-softmax along dim 1
    target: weights multiplied element-wise with the log-probabilities
    returns: the negated total, divided by the size of input's first dimension
    '''
    n_rows = input.shape[0]
    weighted_logprobs = target * input.log_softmax(dim=1)
    return -weighted_logprobs.sum() / n_rows



