In [18]:
from transformers import PretrainedConfig, BertConfig
import json

In [16]:
import yaml
# Load the ALBEF YAML/JSON configs; `with` closes each file handle promptly
# (the bare `open(...)` calls previously leaked their handles).
# NOTE(review): yaml.Loader can construct arbitrary Python objects -- only use it
# on trusted config files (yaml.safe_load is the safer choice otherwise).
with open('ALBEF/configs/Cosine-Retrieval.yaml', 'r') as config_file:
    cosine_config = yaml.load(config_file, Loader=yaml.Loader)
with open('ALBEF/configs/ITM.yaml', 'r') as config_file:
    itm_config = yaml.load(config_file, Loader=yaml.Loader)
with open('ALBEF/configs/config_bert.json', 'r') as config_file:
    bert_config = json.load(config_file)

In [21]:
class CosineConfig(PretrainedConfig):
    """Hugging Face configuration for the X-REM cosine-similarity retrieval stage.

    ``bert_config`` is kept as a single dict attribute. Every key of
    ``cosine_config`` is forwarded to ``PretrainedConfig.__init__`` as a keyword
    argument, so it becomes a top-level attribute (``image_res``,
    ``queue_size``, ``optimizer``, ...).
    """
    model_type = 'XREM-Cosine'
    # Default text-encoder (BERT) settings stored under ``bert_config``.
    default_bert_config = {'architectures': ['BertForMaskedLM'],
                           'attention_probs_dropout_prob': 0.1,
                           'hidden_act': 'gelu',
                           'hidden_dropout_prob': 0.1,
                           'hidden_size': 768,
                           'initializer_range': 0.02,
                           'intermediate_size': 3072,
                           'layer_norm_eps': 1e-12,
                           'max_position_embeddings': 512,
                           'model_type': 'bert',
                           'num_attention_heads': 12,
                           'num_hidden_layers': 12,
                           'pad_token_id': 0,
                           'type_vocab_size': 2,
                           'vocab_size': 30522,
                           'fusion_layer': 6,
                           'encoder_width': 768}
    # Default retrieval / training hyper-parameters. Learning rates are kept as
    # strings and 'schedular' (sic) keeps its original key name.
    default_cosine_config = {
        'image_res': 256,
        'batch_size_train': 32,
        'batch_size_test': 64,
        'queue_size': 65536,
        'momentum': 0.995,
        'vision_width': 768,
        'embed_dim': 256,
        'temp': 0.07,
        'k_test': 128,
        'alpha': 0.4,
        'distill': True,
        'warm_up': True,
        'optimizer': {'opt': 'adamW', 'lr': '1e-5', 'weight_decay': 0.02},
        'schedular': {'sched': 'cosine',
                      'lr': '1e-5',
                      'epochs': 10,
                      'min_lr': '1e-6',
                      'decay_rate': 1,
                      'warmup_lr': '1e-5',
                      'warmup_epochs': 1,
                      'cooldown_epochs': 0}}

    def __init__(
        self,
        bert_config=default_bert_config,
        cosine_config=default_cosine_config,
        **kwargs
    ):
        """Build the config.

        Args:
            bert_config: dict of text-encoder settings. It is deep-copied, so
                instances never share (or mutate) the class-level default.
            cosine_config: dict of hyper-parameters. Each key becomes a
                keyword argument of ``PretrainedConfig.__init__``; the dict is
                deep-copied for the same reason.
            **kwargs: forwarded to ``PretrainedConfig.__init__``. A key present
                both here and in ``cosine_config`` takes its value from kwargs.
        """
        import copy

        self.bert_config = copy.deepcopy(bert_config)
        # Merge instead of `super().__init__(**kwargs, **cosine_config)`.
        # When a saved config is re-loaded, its attributes (image_res, ...)
        # come back as kwargs, which previously raised
        # "got multiple values for keyword argument".
        merged = {**copy.deepcopy(cosine_config), **kwargs}
        super().__init__(**merged)
        

In [27]:
class ITMConfig(PretrainedConfig):
    """Hugging Face configuration for the X-REM image-text-matching stage.

    ``bert_config`` is kept as a single dict attribute. Every key of
    ``itm_config`` is forwarded to ``PretrainedConfig.__init__`` as a keyword
    argument, so it becomes a top-level attribute (``image_res``, ``alpha``,
    ``optimizer``, ...).
    """
    model_type = 'XREM-ITM'
    # Default text-encoder (BERT) settings stored under ``bert_config``.
    default_bert_config = {'architectures': ['BertForMaskedLM'],
                           'attention_probs_dropout_prob': 0.1,
                           'hidden_act': 'gelu',
                           'hidden_dropout_prob': 0.1,
                           'hidden_size': 768,
                           'initializer_range': 0.02,
                           'intermediate_size': 3072,
                           'layer_norm_eps': 1e-12,
                           'max_position_embeddings': 512,
                           'model_type': 'bert',
                           'num_attention_heads': 12,
                           'num_hidden_layers': 12,
                           'pad_token_id': 0,
                           'type_vocab_size': 2,
                           'vocab_size': 30522,
                           'fusion_layer': 6,
                           'encoder_width': 768}
    # Default ITM hyper-parameters. Learning rates are kept as strings and
    # 'schedular' (sic) keeps its original key name.
    default_itm_config = {
        'image_res': 384,
        'batch_size_train': 32,
        'batch_size_test': 64,
        'alpha': 0.4,
        'distill': True,
        'warm_up': False,
        'optimizer': {'opt': 'adamW', 'lr': '2e-5', 'weight_decay': 0.02},
        'schedular': {'sched': 'cosine',
                      'lr': '2e-5',
                      'epochs': 5,
                      'min_lr': '1e-6',
                      'decay_rate': 1,
                      'warmup_lr': '1e-5',
                      'warmup_epochs': 1,
                      'cooldown_epochs': 0}}

    def __init__(
        self,
        bert_config=default_bert_config,
        itm_config=default_itm_config,
        **kwargs
    ):
        """Build the config.

        Args:
            bert_config: dict of text-encoder settings. It is deep-copied, so
                instances never share (or mutate) the class-level default.
            itm_config: dict of hyper-parameters. Each key becomes a keyword
                argument of ``PretrainedConfig.__init__``; the dict is
                deep-copied for the same reason.
            **kwargs: forwarded to ``PretrainedConfig.__init__``. A key present
                both here and in ``itm_config`` takes its value from kwargs.
        """
        import copy

        self.bert_config = copy.deepcopy(bert_config)
        # Merge instead of `super().__init__(**kwargs, **itm_config)`.
        # When a saved config is re-loaded, its attributes (image_res, ...)
        # come back as kwargs, which previously raised
        # "got multiple values for keyword argument".
        merged = {**copy.deepcopy(itm_config), **kwargs}
        super().__init__(**merged)

In [1]:
import numpy as np
import torch
from PIL import Image
import h5py
from torch.utils import data
from torchvision import transforms

# Adapted from cxr-repair.
# Input: an .h5 file containing the images.
class CXRTestDataset_h5(data.Dataset):
    """Test-time dataset of chest X-rays stored in an HDF5 file.

    Each item is read from the ``'cxr'`` dataset of ``img_path``, stacked into
    three identical channels, resized to ``(input_resolution, input_resolution)``
    with bicubic interpolation and normalised with fixed per-channel mean/std.
    """

    def __init__(self, img_path, input_resolution):
        super().__init__()
        # NOTE(review): the HDF5 file stays open for the dataset's lifetime.
        self.img_dset = h5py.File(img_path, 'r')['cxr']
        self.transform = transforms.Compose([
            # InterpolationMode is torchvision's enum for `interpolation`;
            # passing PIL's integer constant is deprecated.
            transforms.Resize((input_resolution, input_resolution),
                              interpolation=transforms.InterpolationMode.BICUBIC),
            # Fully qualified: this cell never imports a bare `Normalize`, so the
            # unqualified name raised NameError unless another cell leaked it.
            transforms.Normalize((101.48761, 101.48761, 101.48761),
                                 (83.43944, 83.43944, 83.43944)),
        ])

    def __len__(self):
        """Number of images in the ``'cxr'`` dataset."""
        return len(self.img_dset)

    def __getitem__(self, idx):
        """Return image ``idx`` as a transformed 3-channel tensor."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Assumes each stored image is 2-D (H, W) -- TODO confirm; stack it
        # into 3 identical channels along a new leading axis.
        img = np.repeat(np.expand_dims(self.img_dset[idx], axis=0), 3, axis=0)
        img = torch.from_numpy(img)
        if self.transform:
            img = self.transform(img)
        return img

# Adapted from cxr-repair.
# Input: a sequence of paths to image files.
class CXRTestDataset(data.Dataset):
    """Test-time dataset of chest X-rays loaded from individual image files.

    Each image is scaled so its longer side equals ``desired_size`` (aspect
    ratio preserved), centred on a black ``desired_size`` x ``desired_size``
    grayscale canvas, and stacked into three identical channels. It is then
    resized to ``input_resolution`` with bicubic interpolation and normalised
    with fixed per-channel mean/std.
    """

    def __init__(self, target_files, input_resolution, desired_size=320):
        super().__init__()
        self.files = target_files
        # Side length of the square padding canvas (previously hard-coded 320).
        self.desired_size = desired_size
        self.transform = transforms.Compose([
            # InterpolationMode is torchvision's enum for `interpolation`;
            # passing PIL's integer constant is deprecated.
            transforms.Resize((input_resolution, input_resolution),
                              interpolation=transforms.InterpolationMode.BICUBIC),
            # Fully qualified: this cell never imports a bare `Normalize`, so the
            # unqualified name raised NameError unless another cell leaked it.
            transforms.Normalize((101.48761, 101.48761, 101.48761),
                                 (83.43944, 83.43944, 83.43944)),
        ])

    def __len__(self):
        """Number of image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load, pad and transform image ``idx``; returns a 3-channel tensor."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        size = self.desired_size
        # `with` closes the underlying file; resize() returns an independent image.
        with Image.open(self.files[idx]) as src:
            ratio = float(size) / max(src.size)
            new_size = tuple(int(x * ratio) for x in src.size)
            # LANCZOS is the filter the old Image.ANTIALIAS alias referred to;
            # ANTIALIAS was removed in Pillow 10.
            resized = src.resize(new_size, Image.LANCZOS)
        canvas = Image.new('L', (size, size))
        canvas.paste(resized, ((size - new_size[0]) // 2,
                               (size - new_size[1]) // 2))
        img = np.asarray(canvas, np.float64)
        img = np.repeat(np.expand_dims(img, axis=0), 3, axis=0)
        img = torch.from_numpy(img)
        if self.transform:
            img = self.transform(img)
        return img

In [3]:
import argparse
import yaml
import numpy as np
import random
import pandas as pd
from tqdm import tqdm
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils import data

import models
from models.model_itm import ALBEF as ALBEF_itm
from models.model_retrieval import ALBEF as ALBEF_retrieval
from models.vit import interpolate_pos_embed
from models.tokenization_bert import BertTokenizer
from torchvision.transforms import Compose, Normalize, Resize, InterpolationMode
import utils
from PIL import Image
from XREM_dataset import CXRTestDataset, CXRTestDataset_h5



class RETRIEVAL_MODULE:
    """ALBEF-based report retrieval for chest X-ray images.

    Two modes are supported:

    * ``'cosine-sim'``: embed the whole candidate-report pool once, embed each
      image, and rank the pool by standardised image-text similarity.
    * ``'image-text-matching'``: re-score a per-image list of candidate reports
      with the ALBEF classification head and rank by its ``'positive'`` column.

    In both modes the ``topk`` best candidates per image are concatenated, best
    first and each followed by ``delimiter``, into one string per image.
    """

    def __init__(self,
                 mode,
                 config,
                 checkpoint,
                 topk,
                 input_resolution,
                 delimiter,
                 max_token_len):
        """Load the tokenizer, YAML config and ALBEF checkpoint for ``mode``.

        Args:
            mode: ``'cosine-sim'`` or ``'image-text-matching'``.
            config: path to the ALBEF YAML config given to the model constructor.
            checkpoint: path to a checkpoint whose ``'model'`` entry is a state dict.
            topk: number of candidates kept per image.
            input_resolution: stored on the instance only; the predict methods
                expect images already transformed by the dataset.
            delimiter: string appended after every selected candidate.
            max_token_len: tokenizer ``max_length`` used by ``generate_embeddings``.
        """
        assert mode == 'cosine-sim' or mode == 'image-text-matching', 'mode should be cosine-sim or image-text-matching'
        self.mode = mode
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # `with` closes the config file (the bare open() previously leaked it).
        # NOTE(review): yaml.Loader can construct arbitrary objects; trusted files only.
        with open(config, 'r') as config_file:
            self.config = yaml.load(config_file, Loader=yaml.Loader)
        self.input_resolution = input_resolution
        self.topk = topk
        self.max_token_len = max_token_len
        self.delimiter = delimiter
        # Column of the ITM classification head whose score ranks candidates.
        self.itm_labels = {'negative': 0, 'positive': 2}

        if mode == 'cosine-sim':
            self.load_albef_retrieval(checkpoint)
        else:
            self.load_albef_itm(checkpoint)

    # Adapted from the ALBEF codebase.
    # For image-text matching, ALBEF fine-tuned on visual entailment is used to
    # perform binary classification (entail / non-entail).
    def load_albef_itm(self, checkpoint_path):
        """Build the ITM variant of ALBEF and load ``checkpoint_path`` into it."""
        model = ALBEF_itm(config=self.config,
                          text_encoder='bert-base-uncased',
                          tokenizer=self.tokenizer
                          ).to(self.device)
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = checkpoint['model']
        # Presumably adapts the stored ViT positional embedding to the model's
        # patch grid (see models.vit.interpolate_pos_embed) -- TODO confirm.
        state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
            state_dict['visual_encoder.pos_embed'], model.visual_encoder)
        # strict=False: missing / unexpected keys are tolerated instead of raising.
        model.load_state_dict(state_dict, strict=False)
        self.model = model.eval()

    # Adapted from the ALBEF codebase.
    def load_albef_retrieval(self, checkpoint_path):
        """Build the retrieval variant of ALBEF and load ``checkpoint_path`` into it."""
        model = ALBEF_retrieval(config=self.config,
                                text_encoder='bert-base-uncased',
                                tokenizer=self.tokenizer
                                ).to(device=self.device)
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = checkpoint['model']
        # Both the online and the momentum ('_m') visual encoders carry a
        # positional embedding that is passed through interpolate_pos_embed.
        state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
            state_dict['visual_encoder.pos_embed'], model.visual_encoder)
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(
            state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)
        for key in list(state_dict.keys()):
            # Strip the 'bert.' path segment from the key. Matching on 'bert.'
            # (not just 'bert') guarantees the renamed key differs; otherwise
            # the old "copy then del" sequence deleted the weight outright.
            if 'bert.' in key:
                state_dict[key.replace('bert.', '')] = state_dict.pop(key)
        # strict=False: missing / unexpected keys are tolerated instead of raising.
        model.load_state_dict(state_dict, strict=False)
        self.model = model.eval()

    def __forward__(self, images_dataset, reports):
        """Run retrieval for every image in ``images_dataset``.

        Returns:
            DataFrame with one "Report Impression" column, one row per image.
        """
        if self.mode == 'cosine-sim':
            embeddings = self.generate_embeddings(reports)
            return self.cosine_sim_predict(images_dataset, reports, embeddings)
        return self.itm_predict(images_dataset, reports)

    # Python never invokes `__forward__` by itself; expose it under the names
    # callers naturally use (`module(images, reports)` / `module.forward(...)`).
    __call__ = __forward__
    forward = __forward__

    # Adapted from the cxr-repair codebase.
    def generate_embeddings(self, reports, batch_size=2000):
        """Embed ``reports`` with the ALBEF text encoder, ``batch_size`` at a time.

        Args:
            reports: sliceable sequence of report strings supporting ``len()``
                (e.g. a pandas Series or a list).
            batch_size: number of reports tokenized per forward pass.

        Returns:
            Tensor of L2-normalised embeddings, one row per report, on ``self.device``.
        """
        # Adapted from the ALBEF codebase.
        def _embed_text(batch):
            with torch.no_grad():
                text_input = self.tokenizer(batch,
                                            padding='max_length',
                                            truncation=True,
                                            max_length=self.max_token_len,
                                            return_tensors="pt").to(self.device)
                text_output = self.model.text_encoder(text_input.input_ids,
                                                      attention_mask=text_input.attention_mask,
                                                      mode='text')
                # Project the first-token ([CLS]) hidden state.
                text_feat = text_output.last_hidden_state
                text_embed = F.normalize(self.model.text_proj(text_feat[:, 0, :]))
                text_embed /= text_embed.norm(dim=-1, keepdim=True)
            return text_embed

        # Step through the pool in strides of batch_size. The previous
        # `range(n // batch_size + 1)` loop bounded each slice with the
        # never-assigned `self.reports` (AttributeError). It also issued an
        # empty final batch whenever n was a multiple of batch_size.
        tensors = [
            _embed_text(list(reports[start:start + batch_size]))
            for start in tqdm(range(0, len(reports), batch_size))
        ]
        return torch.cat(tensors)

    # Adapted from the cxr-repair codebase.
    def select_reports(self, reports, y_pred):
        """Concatenate the ``topk`` highest-scoring candidates for each image.

        Args:
            reports: cosine mode -- the shared candidate pool, indexed ``reports[j]``;
                ITM mode -- per-image candidate lists, indexed ``reports[i][j]``.
            y_pred: per-image sequence of scores aligned with those candidates.

        Returns:
            List with one string per image: selected candidates best-first, each
            followed by ``self.delimiter``.
        """
        reports_list = []
        for i, simscores in tqdm(enumerate(y_pred), total=len(y_pred)):
            # Descending order, first topk. Same order as flip(argsort[-k:]) for
            # k > 0, but topk == 0 now selects nothing instead of everything.
            idxes = np.argsort(np.array(simscores))[::-1][:self.topk]
            candidates = reports if self.mode == 'cosine-sim' else reports[i]
            reports_list.append(''.join(candidates[j] + self.delimiter for j in idxes))
        return reports_list

    # Adapted from the ALBEF codebase.
    def itm_predict(self, images_dataset, reports):
        """Re-rank each image's candidate reports with the ALBEF ITM head.

        Args:
            images_dataset: indexable dataset; item ``i`` is an image tensor
                without batch dimension, matching ``reports[i]``.
            reports: list whose ``i``-th element is the list of candidate
                strings for image ``i``.

        Returns:
            DataFrame with one "Report Impression" column, one row per image.
        """
        import warnings

        y_preds = []
        bs = 100  # candidate reports scored per text-encoder call
        with torch.no_grad():  # inference only: do not build autograd graphs
            for i in tqdm(range(len(images_dataset))):
                image = images_dataset[i].to(self.device, dtype=torch.float)
                image = torch.unsqueeze(image, 0)
                image_embeds = self.model.visual_encoder(image)
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
                local_reports = reports[i]
                batch_scores = []
                for idx in range(0, len(local_reports), bs):
                    batch = local_reports[idx:idx + bs]
                    try:
                        text = self.tokenizer(batch,
                                              padding='longest',
                                              return_tensors="pt").to(self.device)
                        output = self.model.text_encoder(text.input_ids,
                                                         attention_mask=text.attention_mask,
                                                         encoder_hidden_states=image_embeds,
                                                         encoder_attention_mask=image_atts,
                                                         return_dict=True)
                        prediction = self.model.cls_head(output.last_hidden_state[:, 0, :])
                        positive_score = prediction[:, self.itm_labels['positive']]
                    except Exception as exc:
                        # Best effort: give every candidate of the failed batch a
                        # score so indices stay aligned with local_reports (the old
                        # fallback appended a single value, shifting every later
                        # index). -inf ranks them last. The tensor is built on
                        # self.device, unlike the old `.cuda()`, which crashed on CPU.
                        warnings.warn(f'ITM scoring failed for image {i}, '
                                      f'candidates {idx}-{idx + len(batch) - 1}: {exc!r}')
                        positive_score = torch.full((len(batch),), float('-inf'),
                                                    device=self.device)
                    batch_scores.append(positive_score)
                if batch_scores:
                    scores = torch.cat(batch_scores)
                else:
                    scores = torch.empty(0, device=self.device)
                # reshape(-1) keeps a 1-D array even for a single candidate,
                # where squeeze() produced a 0-d array.
                y_preds.append(scores.reshape(-1).detach().cpu().numpy())

        return pd.DataFrame({"Report Impression": self.select_reports(reports, y_preds)})

    # Adapted from the cxr-repair codebase.
    def cosine_sim_predict(self, images_dataset, reports, embeddings):
        """Rank the shared report pool for every image by image-text similarity.

        Args:
            images_dataset: dataset of image tensors, loaded one image per batch.
            reports: candidate pool; row ``j`` of ``embeddings`` embeds ``reports[j]``.
            embeddings: 2-D tensor of report embeddings from ``generate_embeddings``.

        Returns:
            DataFrame with one "Report Impression" column, one row per image.
        """
        def softmax(x):
            # Shifting by the max avoids overflow in exp; the result is unchanged.
            e = np.exp(x - np.max(x))
            return e / e.sum()

        def embed_img(images):
            images = images.to(self.device, dtype=torch.float)
            image_features = self.model.visual_encoder(images)
            image_features = self.model.vision_proj(image_features[:, 0, :])
            return F.normalize(image_features, dim=-1)

        y_pred = []
        # batch_size=1: the squeeze(axis=0) below relies on one image per batch.
        image_loader = torch.utils.data.DataLoader(images_dataset, batch_size=1, shuffle=False)
        with torch.no_grad():
            for image in tqdm(image_loader):
                image_features = embed_img(image)
                logits = image_features @ embeddings.T
                logits = np.squeeze(logits.to('cpu').numpy(), axis=0).astype('float64')
                std = logits.std()
                if std > 0:
                    norm_logits = (logits - logits.mean()) / std
                else:
                    # Constant scores (e.g. a single candidate): avoid 0/0 -> NaN.
                    norm_logits = np.zeros_like(logits)
                y_pred.append(softmax(norm_logits))

        y_pred = np.array(y_pred)
        return pd.DataFrame({"Report Impression": self.select_reports(reports, y_pred)})

        


In [4]:
from transformers import PreTrainedModel
class XREM(PreTrainedModel):
    """Two-stage X-REM report retrieval: cosine-similarity retrieval, then ITM re-ranking.

    ``config`` must expose ``albef_retrieval_config``, ``albef_retrieval_ckpt``,
    ``albef_retrieval_top_k``, ``albef_retrieval_delimiter``, ``albef_itm_config``,
    ``albef_itm_ckpt``, ``albef_itm_top_k`` and ``albef_itm_delimiter``.
    """
    # TODO: set `config_class` once a combined X-REM config class exists
    # (an earlier draft referenced an undefined `XREMConfig`).

    def __init__(self, config):
        super().__init__(config)
        # Stage 1: rank the whole report pool by image-text cosine similarity.
        self.cosine_sim_module = RETRIEVAL_MODULE(mode='cosine-sim',
                                                  config=config.albef_retrieval_config,
                                                  checkpoint=config.albef_retrieval_ckpt,
                                                  topk=config.albef_retrieval_top_k,
                                                  input_resolution=256,
                                                  delimiter=config.albef_retrieval_delimiter,
                                                  max_token_len=25)
        # Stage 2: re-rank the stage-1 candidates with the ITM head. The earlier
        # draft also passed `impressions=new_impressions` (undefined) and
        # `img_path=...`; RETRIEVAL_MODULE accepts neither (TypeError).
        self.itm_module = RETRIEVAL_MODULE(mode='image-text-matching',
                                           config=config.albef_itm_config,
                                           checkpoint=config.albef_itm_ckpt,
                                           topk=config.albef_itm_top_k,
                                           input_resolution=384,
                                           delimiter=config.albef_itm_delimiter,
                                           max_token_len=30)

    def forward(self, image_dataset, reports, itm_image_dataset=None):
        """Retrieve report impressions for every image.

        Args:
            image_dataset: images prepared for the cosine-similarity stage.
            reports: candidate report pool for the cosine-similarity stage.
            itm_image_dataset: optional images prepared for the ITM stage (the
                two stages are built with different ``input_resolution``);
                defaults to ``image_dataset``.

        Returns:
            DataFrame with one "Report Impression" column. It comes from the ITM
            stage when ``config.albef_itm_top_k > 0``, else from the cosine stage.
        """
        output = self.cosine_sim_module.__forward__(image_dataset, reports)
        if self.config.albef_itm_top_k > 0:
            if itm_image_dataset is None:
                itm_image_dataset = image_dataset
            delimiter = self.config.albef_retrieval_delimiter
            # Stage 1 appends the delimiter after every candidate, so splitting
            # leaves an empty trailing piece; drop empty pieces.
            candidates = [[piece for piece in report.split(delimiter) if piece]
                          for report in output['Report Impression']]
            output = self.itm_module.__forward__(itm_image_dataset, candidates)
        return output

    # Backward-compatible alias for the earlier method name.
    __forward__ = forward