In [1]:
import os
import pandas as pd

from collections import defaultdict
from torch.utils.data import DataLoader
from sklearn.neighbors import KNeighborsClassifier
from configs import configs
from dataset import ChestXRayCaptionDataset
from nltk.translate.bleu_score import corpus_bleu
import torch
import numpy as np
from model import Chexnet
from tqdm import tqdm
from utils import train_transform, evaluate_transform, quantize_probs
from tokenizer import create_tokenizer
from test import evaluation_matrix
from chexpert import chexpert
import time
import faiss
from nlgeval import NLGEval

2091lines [00:00, 174255.71lines/s]


In [2]:
# nlg-eval metrics to skip when loading the evaluator. write_nlg_file reads
# Bleu_1..Bleu_4, ROUGE_L and CIDEr from the computed metrics, so those must stay enabled.
metrics_to_omit = [
    'METEOR', 
    'SkipThoughtCS', 
    'EmbeddingAverageCosineSimilarity', 
    'VectorExtremaCosineSimilarity', 
    'GreedyMatchingScore', 
    # NOTE(review): misspelled twin of 'EmbeddingAverageCosineSimilarity' above; presumably
    # listed for nlg-eval versions that use this spelling — confirm before removing.
    'EmbeddingAverageCosineSimilairty',
]

nlgeval = NLGEval(metrics_to_omit=metrics_to_omit)  # loads the models

# Pipeline switches read by get_vectors:
#   LOAD_RANDOM_PROJECTION_DATA - only consulted when USE_CACHED_MAPS is False
#   BUILD_CACHED_MAPS           - with USE_CACHED_MAPS, run the encoder and write raw feature-map chunks first
#   USE_CACHED_MAPS             - project raw feature maps cached on disk (random_project)
#   SAVE_PROJECTED              - save projected vectors to the filename() paths (see get_vectors for which modes honour it)
LOAD_RANDOM_PROJECTION_DATA = False
BUILD_CACHED_MAPS = False
USE_CACHED_MAPS = True
SAVE_PROJECTED = True
# CSV logs appended to by write_time_file / write_nlg_file. No header row is written, and
# open(..., 'a') does not create directories, so results/ must already exist.
time_file_path = 'results/inference_time.csv'
nlg_file_path = 'results/results_nlg.csv'


def filename(k, seed, project_dim):
    """
    Return the path of one cached baseline-data file.

    Keys are '<split>_x' (projected image embeddings, one file per project_dim/seed pair)
    and '<split>_y' (encoded captions, shared by every projection), with split in
    train/val/test. An unknown key raises KeyError.
    """
    base_dir = configs['mimic_dir'] + 'baseline_data/'
    paths = {}
    for split in ('train', 'val', 'test'):
        paths[f'{split}_x'] = base_dir + f'{split}_image_embeddings_{project_dim}_{seed}.npy'
        paths[f'{split}_y'] = base_dir + f'{split}_captions.npy'
    return paths[k]


In [3]:
# Tokenizer shared by the evaluation helpers (decoding predicted / reference captions).
tokenizer = create_tokenizer()
# NOTE(review): torch.load unpickles the whole encoder object stored in the checkpoint, which
# can execute arbitrary code — only load checkpoints from a trusted source.
checkpoint = torch.load('weights/pretrained_encoder/pretrained_enc_epoch_5_2022-03-08_15-43-47.540586.pth.tar')
print(f"loaded epoch {checkpoint['epoch']+1} model, val_loss: {checkpoint['val_loss']}")
# The encoder keeps whatever train/eval mode it was pickled in; the generate_* helpers call
# encoder.eval() before inference. NOTE(review): confirm every inference path does the same.
encoder = checkpoint['encoder'].cuda()

# Pre-computed quantized label vectors per split. The train labels key the coarse buckets of
# OneNearestNeighborCoarse2Fine and must be row-aligned with the train embeddings.
# (val/test quantized labels are not referenced by the cells shown here.)
train_probs_quantized = np.load(configs['mimic_dir'] + 'baseline_data/train_probs_quantized.npy')
val_probs_quantized = np.load(configs['mimic_dir'] + 'baseline_data/val_probs_quantized.npy')
test_probs_quantized = np.load(configs['mimic_dir'] + 'baseline_data/test_probs_quantized.npy')


# shuffle=False everywhere: cached embeddings, captions and labels are matched to images by
# row order (predict() zips embedding batches with image batches).
# NOTE(review): the train loader uses train_transform while val/test use evaluate_transform; if
# train_transform applies random augmentation, the cached train feature maps are augmented and
# not reproducible — confirm this is intended for the retrieval database.
train_loader = DataLoader(
    ChestXRayCaptionDataset('train', transform=train_transform),
    batch_size=16,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

val_loader = DataLoader(
    ChestXRayCaptionDataset('val', transform=evaluate_transform),
    batch_size=16,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

test_loader = DataLoader(
    ChestXRayCaptionDataset('test', transform=evaluate_transform),
    batch_size=16,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

# Number of *batches* per loader when the raw feature maps were cached; random_project uses
# these to work out how many chunk files exist. NOTE(review): presumably equal to
# len(train_loader) / len(val_loader) / len(test_loader) — consider computing them directly.
len_train_loader = 16740
len_val_loader = 131
len_test_loader = 229

2091lines [00:00, 160863.71lines/s]


loaded epoch 5 model, val_loss: 0.28202417492866516


In [4]:
def generate_image_embeddings_random(encoder, data_loader, projection_matrix, project_every=2):
    """
    Encode every image in ``data_loader`` and random-project the feature maps on the fly.

    Encoder outputs are buffered for ``project_every`` batches (and flushed after the last
    batch). Each flush flattens them to 1024*8*8 values per sample and multiplies them by
    ``projection_matrix``.

    Returns:
        (projected embeddings stacked into one ndarray, captions as one ndarray)
    """
    encoder.eval()
    projected_chunks = []
    caption_batches = []
    pending = []

    with torch.no_grad():
        for step, (images, batch_captions, _) in enumerate(tqdm(data_loader), start=1):
            features, _ = encoder(images.cuda())
            pending.append(features.cpu())
            caption_batches.append(batch_captions.cpu())

            if step % project_every and step != len(data_loader):
                continue  # keep buffering until a flush point

            flat = torch.cat(pending).reshape(-1, 1024*8*8).numpy()
            projected_chunks.append(np.matmul(flat, projection_matrix))
            pending = []

    return np.vstack(projected_chunks), torch.cat(caption_batches).numpy()

def generate_image_embeddings_save_every(encoder, data_split, data_loader, save_every=1024):
    """
    Encode every image in ``data_loader`` and save the raw feature maps to disk in chunks.

    Every ``save_every`` batches (and after the last batch) the buffered outputs are written as
    ``raw_embeddings/{data_split}/feature_maps_{k}.npy`` (flattened to 1024*8*8 per sample) and
    ``raw_embeddings/{data_split}/captions_{k}.npy`` under ``configs['mimic_dir']``, k = 0, 1, ...
    The target directory must already exist. Returns None.
    """
    encoder.eval()
    feature_buffer, caption_buffer = [], []
    file_index = 0

    with torch.no_grad():
        for step, (images, batch_captions, _) in enumerate(tqdm(data_loader), start=1):
            feature_maps, _ = encoder(images.cuda())
            feature_buffer.append(feature_maps.cpu())
            caption_buffer.append(batch_captions.cpu())

            if step % save_every and step != len(data_loader):
                continue  # not a chunk boundary yet

            stacked_maps = torch.cat(feature_buffer).reshape(-1, 1024*8*8).numpy()
            stacked_captions = torch.cat(caption_buffer).numpy()
            np.save(configs['mimic_dir'] + f'raw_embeddings/{data_split}/feature_maps_{file_index}.npy', stacked_maps)
            np.save(configs['mimic_dir'] + f'raw_embeddings/{data_split}/captions_{file_index}.npy', stacked_captions)

            feature_buffer, caption_buffer = [], []
            file_index += 1

def generate_probs_save_every(encoder, data_split, data_loader, save_every=1024):
    """
    Run the encoder over ``data_loader`` and save its second output (probabilities) in chunks.

    Every ``save_every`` batches (and after the last batch) the buffered outputs are concatenated
    and written to ``configs['mimic_dir'] + 'raw_embeddings/{data_split}/probs_{k}.npy'``,
    k = 0, 1, ... The target directory must already exist. Returns None.
    """
    encoder.eval()
    buffered = []
    file_index = 0

    with torch.no_grad():
        for step, (images, _, _) in enumerate(tqdm(data_loader), start=1):
            _, batch_probs = encoder(images.cuda())
            buffered.append(batch_probs.cpu())

            if step % save_every and step != len(data_loader):
                continue  # not a chunk boundary yet

            stacked = torch.cat(buffered).numpy()
            np.save(configs['mimic_dir'] + f'raw_embeddings/{data_split}/probs_{file_index}.npy', stacked)
            buffered = []
            file_index += 1

def random_project(len_data_loader, data_split, projection_matrix, save_every=1024, project_every=256):
    """
    Load the cached raw feature-map chunks of one split and random-project them.

    The chunks are the files written by ``generate_image_embeddings_save_every``:
    ``raw_embeddings/{data_split}/feature_maps_{k}.npy`` and ``captions_{k}.npy`` under
    ``configs['mimic_dir']``, one pair per ``save_every`` batches.

    Parameters:
        len_data_loader (int):           Number of batches the split's loader had when the chunks were written
        data_split (str):                Sub-directory name under raw_embeddings/ ('train', 'val', 'test')
        projection_matrix (numpy.array): Matrix of shape (feature_dim, project_dim)
        save_every (int):                Batches per cached chunk (must match the value used when caching)
        project_every (int):             Number of pieces each chunk is split into before projecting;
                                         despite the name this is a count, and it bounds the size of each matmul

    Returns:
        (projected embeddings, captions), each stacked over all chunks
    """
    # Ceiling division: one chunk per full `save_every` batches plus one for a trailing partial
    # chunk. `len_data_loader // save_every + 1` would ask for a nonexistent extra file whenever
    # len_data_loader is an exact multiple of save_every.
    n_split = (len_data_loader + save_every - 1) // save_every
    print(f"{n_split=}")
    projected_image_embeddings = []
    captions = []

    for file_index in tqdm(range(n_split)):
        feat_maps = np.load(configs['mimic_dir'] + f'raw_embeddings/{data_split}/feature_maps_{file_index}.npy')
        caps = np.load(configs['mimic_dir'] + f'raw_embeddings/{data_split}/captions_{file_index}.npy')
        captions.append(caps)

        # Project piece by piece to keep each matmul's intermediate memory bounded.
        for batch in np.array_split(feat_maps, project_every):
            projected_image_embeddings.append(np.matmul(batch, projection_matrix))

    projected_image_embeddings = np.vstack(projected_image_embeddings)
    captions = np.vstack(captions)

    return projected_image_embeddings, captions

In [5]:
# generate_probs_save_every(encoder, 'train', train_loader, save_every=1024)
# generate_probs_save_every(encoder, 'val', val_loader, save_every=1024)
# generate_probs_save_every(encoder, 'test', test_loader, save_every=1024)

In [6]:
def get_vectors(SEED, RANDOM_PROJECT_DIM, load=False):
    """
    End to end train, val, test projected vectors function

    Input:
        SEED (int):               Seed of the random projection matrix
        RANDOM_PROJECT_DIM (int): Dimension of the random projection matrix
        load (bool):              First try the projections saved for this (SEED, RANDOM_PROJECT_DIM);
                                  if any of the six files is missing, project from scratch instead
    Output:
        train_x, train_y, val_x, val_y, test_x, test_y

    How vectors are produced when not loaded is controlled by the notebook flags:
        USE_CACHED_MAPS:               project raw feature maps cached on disk (random_project);
                                       with BUILD_CACHED_MAPS the cache is (re)built with the encoder first
        else LOAD_RANDOM_PROJECTION_DATA: load previously saved projections from the filename() paths
        else:                          encode and project on the fly (generate_image_embeddings_random)
        SAVE_PROJECTED:                save freshly projected vectors to the filename() paths
    """
    splits = ('train', 'val', 'test')

    def _paths(split):
        """(embeddings path, captions path) of one split for this seed / dimension."""
        return (
            filename(f'{split}_x', seed=SEED, project_dim=RANDOM_PROJECT_DIM),
            filename(f'{split}_y', seed=SEED, project_dim=RANDOM_PROJECT_DIM),
        )

    def _load(split):
        """Load the saved (embeddings, captions) pair of one split."""
        x_path, y_path = _paths(split)
        return np.load(x_path), np.load(y_path)

    if load:
        # Check every file, not just train, so a partial cache falls back to projecting
        # instead of failing halfway through loading.
        if all(os.path.exists(path) for split in splits for path in _paths(split)):
            return tuple(arr for split in splits for arr in _load(split))
        print(f"Attempting to load file with seed={SEED} and project_dim={RANDOM_PROJECT_DIM} but doesn't exist")

    print(f"Random projecting with {SEED=}, {RANDOM_PROJECT_DIM=}")
    # Create a whole new projection
    rng = np.random.RandomState(SEED)
    # Gaussian random projection of the flattened 1024*8*8 = 65536-dim feature maps.
    # The std is 1/RANDOM_PROJECT_DIM rather than the usual 1/sqrt(dim); a global scale does not
    # change nearest-neighbour rankings, and saved projections depend on it, so it is kept.
    projection_matrix = rng.normal(0.0, 1/RANDOM_PROJECT_DIM, (65536, RANDOM_PROJECT_DIM))

    loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}
    n_batches = {'train': len_train_loader, 'val': len_val_loader, 'test': len_test_loader}
    # In this mode the vectors come straight from disk, so there is nothing new to save.
    loaded_from_disk = (not USE_CACHED_MAPS) and LOAD_RANDOM_PROJECTION_DATA

    vectors = []
    for split in splits:
        print(f"Projecting {split} vectors...")
        if USE_CACHED_MAPS:
            # Predict first, cache the raw feature maps, then project
            if BUILD_CACHED_MAPS:
                generate_image_embeddings_save_every(encoder, split, loaders[split], save_every=1024)
            image_embeddings, captions = random_project(
                n_batches[split], split, projection_matrix, save_every=1024, project_every=64
            )
        elif LOAD_RANDOM_PROJECTION_DATA:
            image_embeddings, captions = _load(split)
        else:
            # Project as we predict
            image_embeddings, captions = generate_image_embeddings_random(
                encoder, loaders[split], projection_matrix, project_every=256
            )
        print(image_embeddings.shape)
        print(captions.shape)

        if SAVE_PROJECTED and not loaded_from_disk:
            x_path, y_path = _paths(split)
            np.save(x_path, image_embeddings)
            np.save(y_path, captions)

        vectors.extend((image_embeddings, captions))

    return tuple(vectors)

In [7]:
class SimilaritySearch:
    """
    Interface for nearest-neighbour report retrieval models.

    Subclasses override ``fit`` (index database vectors ``xb`` together with their
    reports ``yb``) and ``predict`` (return the reports retrieved for query vectors
    ``xq``). The base implementations do nothing and return None.
    """

    def fit(self, xb, yb):
        """Index the database vectors; a no-op in the base class."""

    def predict(self, xq):
        """Retrieve reports for the query vectors; a no-op (returns None) in the base class."""

class OneNearestNeighbor(SimilaritySearch):
    """Exact 1-NN report retrieval backed by scikit-learn's ``KNeighborsClassifier``."""

    def __init__(self):
        self.yb = None
        self.knn = KNeighborsClassifier(n_neighbors=1)

    def fit(self, xb, yb):
        """Index ``xb`` (as float32) and keep the reports ``yb`` (as float32) for lookup."""
        # Each database row is its own "class"; only kneighbors() row positions are used.
        row_ids = list(range(xb.shape[0]))
        self.knn.fit(xb.astype(np.float32), row_ids)
        self.yb = yb.astype(np.float32)

    def predict(self, xq):
        """Return the report of the nearest database row for every query row, shape (n_queries, report_len)."""
        _, nearest = self.knn.kneighbors(xq.astype(np.float32))
        # nearest has one column (k=1): gather those report rows in a single fancy-index.
        gathered = self.yb[nearest[:, 0]]
        return gathered.reshape(xq.shape[0], self.yb.shape[1])

class FaissFlatIndexL2CPU(SimilaritySearch):
    """Exact (brute-force) L2 1-NN report retrieval with a CPU ``faiss.IndexFlatL2``."""

    def __init__(self):
        self.yb = None
        self.index = None

    def fit(self, xb, yb):
        """Build a flat L2 index over ``xb`` (float32) and keep the reports ``yb`` (float32)."""
        self.index = faiss.IndexFlatL2(xb.shape[1])
        self.index.add(xb.astype(np.float32))
        self.yb = yb.astype(np.float32)

    def predict(self, xq):
        """Return the report of the nearest indexed vector for each query row."""
        _, neighbour_ids = self.index.search(xq.astype(np.float32), 1)
        return self.yb[neighbour_ids[:, 0]].reshape(xq.shape[0], self.yb.shape[1])

class FaissFlatIndexL2GPU(SimilaritySearch):
    """Exact L2 1-NN report retrieval with a flat faiss index moved to GPU 0."""

    def __init__(self):
        self.res = faiss.StandardGpuResources()
        self.yb = None
        self.gpu_index = None

    def fit(self, xb, yb):
        """Build a flat L2 index over ``xb`` (float32) on GPU 0 and keep the reports ``yb`` (float32)."""
        cpu_index = faiss.IndexFlatL2(xb.shape[1])
        self.gpu_index = faiss.index_cpu_to_gpu(self.res, 0, cpu_index)
        self.gpu_index.add(xb.astype(np.float32))
        self.yb = yb.astype(np.float32)

    def predict(self, xq):
        """Return the report of the nearest indexed vector for each query row."""
        _, neighbour_ids = self.gpu_index.search(xq.astype(np.float32), 1)
        return self.yb[neighbour_ids[:, 0]].reshape(xq.shape[0], self.yb.shape[1])

class FaissHNSW32(SimilaritySearch):
    """Approximate 1-NN report retrieval with ``faiss.IndexHNSWFlat`` (M=32 links per node)."""

    def __init__(self):
        self.yb = None
        self.index = None

    def fit(self, xb, yb):
        """Build the HNSW graph over ``xb`` (float32) and keep the reports ``yb`` (float32)."""
        self.index = faiss.IndexHNSWFlat(xb.shape[1], 32)
        self.index.add(xb.astype(np.float32))
        self.yb = yb.astype(np.float32)

    def predict(self, xq):
        """Return the report of the (approximately) nearest indexed vector for each query row."""
        _, neighbour_ids = self.index.search(xq.astype(np.float32), 1)
        return self.yb[neighbour_ids[:, 0]].reshape(xq.shape[0], self.yb.shape[1])

class FaissLSH32(SimilaritySearch):
    """Approximate 1-NN report retrieval with ``faiss.IndexLSH`` using 32-bit codes."""

    def __init__(self):
        self.yb = None
        self.index = None

    def fit(self, xb, yb):
        """Hash ``xb`` (float32) into the LSH index and keep the reports ``yb`` (float32)."""
        self.index = faiss.IndexLSH(xb.shape[1], 32)
        self.index.add(xb.astype(np.float32))
        self.yb = yb.astype(np.float32)

    def predict(self, xq):
        """Return the report of the nearest indexed vector (by LSH code) for each query row."""
        _, neighbour_ids = self.index.search(xq.astype(np.float32), 1)
        return self.yb[neighbour_ids[:, 0]].reshape(xq.shape[0], self.yb.shape[1])

In [8]:
class SimilaritySearchCoarse2Fine(SimilaritySearch):
    """
    Base for coarse-to-fine retrieval models.

    Besides the ``fit``/``predict`` interface these models carry per-training-sample labels
    (used to build the coarse buckets) and an image encoder set via ``assign_encoder``
    (e2e_benchmark passes the notebook's encoder).
    """

    def assign_labels(self, train_labels):
        """Attach the per-training-sample labels, row-aligned with the training vectors."""
        self.train_labels = train_labels

    def assign_encoder(self, encoder):
        """Attach the image encoder for use at prediction time."""
        self.encoder = encoder

class OneNearestNeighborCoarse2Fine(SimilaritySearchCoarse2Fine):
    """
    Coarse-to-fine 1-NN report retrieval.

    Coarse step: training samples are bucketed by their quantized label vector
    (``train_probs_quantized`` by default), rendered as a digit string such as "0100...".
    At query time the encoder's probabilities are quantized with ``quantize_probs`` and
    only the matching bucket is searched.

    Fine step: exact 1-NN (scikit-learn ``KNeighborsClassifier``) over the embeddings in
    that bucket.

    When the query's label string never occurs in training, all buckets at Hamming distance 1
    are pooled. If none of those exist either, the whole training set is searched.
    """

    def __init__(self):
        # label string -> training embeddings / reports in that bucket
        self.map = defaultdict(list)
        self.reports = defaultdict(list)
        # label string -> fitted 1-NN index over that bucket
        self.knns = dict()
        # lazily built (knn, reports) over the whole training set, for the last-resort fallback
        self._fallback = None
        self.assign_labels(train_probs_quantized)

    @staticmethod
    def _label_key(label):
        """Render one quantized label vector as the digit string used as a bucket key."""
        return ''.join(label.astype(str))

    def fit(self, xb, yb):
        """
        Bucket the training set by label and fit one exact 1-NN index per bucket.

        Parameters:
            xb: training embeddings, row-aligned with ``self.train_labels``
            yb: training reports (encoded captions), row-aligned with ``xb``

        Raises:
            ValueError: if labels, embeddings and reports differ in length; the buckets
                would otherwise pair embeddings with the wrong labels/reports.
        """
        if not (len(self.train_labels) == len(xb) == len(yb)):
            raise ValueError(
                f"length mismatch: {len(self.train_labels)=}, {len(xb)=}, {len(yb)=}"
            )

        # Start from empty buckets so refitting does not accumulate stale entries.
        self.map = defaultdict(list)
        self.reports = defaultdict(list)
        self.knns = dict()
        self._fallback = None

        for label, feat_map, report in zip(self.train_labels, xb, yb):
            key = self._label_key(label)
            self.map[key].append(feat_map)
            self.reports[key].append(report)

        for key, feat_maps in self.map.items():
            knn = KNeighborsClassifier(n_neighbors=1)
            knn.fit(np.array(feat_maps).astype(np.float32), list(range(len(feat_maps))))
            self.knns[key] = knn

    def predict(self, xq, x_image):
        """
        Retrieve one report per query row.

        Parameters:
            xq:      projected query embeddings, row-aligned with ``x_image``
            x_image: the matching image batch (``predict()`` passes it already moved with ``.cuda()``)

        Returns:
            numpy array with one flattened report per query row
        """
        enc = getattr(self, 'encoder', None)
        if enc is None:
            # Backward compatibility: use the notebook-level encoder if none was assigned.
            enc = encoder
        # Inference only. Use eval mode (like the generate_* helpers, which call encoder.eval())
        # and skip building an autograd graph.
        enc.eval()
        with torch.no_grad():
            _, probs = enc(x_image)
        label_keys = [self._label_key(label) for label in quantize_probs(probs.cpu().numpy())]

        results = []
        for key, feat_map in zip(label_keys, xq):
            query = feat_map.astype(np.float32).reshape(1, -1)
            if key in self.knns:
                _, index = self.knns[key].kneighbors(query)
                report = self.reports[key][index[0][0]]
            else:
                # The predicted label combination never occurs in the training set.
                report = self._nearest_for_unseen_label(key, query)
            results.append(np.array(report).reshape(-1))

        return np.array(results)

    def _nearest_for_unseen_label(self, key, query):
        """1-NN over all buckets at Hamming distance 1 from ``key``; whole training set if none exist."""
        candidate_maps, candidate_reports = [], []
        for neighbour_key in self.get_similar_binaries(key):
            # Membership test does not insert into the defaultdict.
            if neighbour_key in self.map:
                candidate_maps.extend(self.map[neighbour_key])
                candidate_reports.extend(self.reports[neighbour_key])

        if candidate_maps:
            knn = KNeighborsClassifier(n_neighbors=1)
            knn.fit(np.array(candidate_maps).astype(np.float32), list(range(len(candidate_maps))))
        else:
            # Fitting on zero candidates would raise; search the whole training set instead.
            knn, candidate_reports = self._full_fallback()

        _, index = knn.kneighbors(query)
        return candidate_reports[index[0][0]]

    def _full_fallback(self):
        """Build (once per fit) and return a 1-NN index plus report list over the whole training set."""
        if getattr(self, '_fallback', None) is None:
            all_maps = [m for maps in self.map.values() for m in maps]
            all_reports = [r for reps in self.reports.values() for r in reps]
            knn = KNeighborsClassifier(n_neighbors=1)
            knn.fit(np.array(all_maps).astype(np.float32), list(range(len(all_maps))))
            self._fallback = (knn, all_reports)
        return self._fallback

    def get_similar_binaries(self, binary_str: str):
        """
        Returns a list of similar binary strings with edit (Hamming) distance = 1,
        i.e. ``binary_str`` with exactly one digit flipped (non-zero -> '0', zero -> '1').
        """
        return [
            binary_str[:i] + ('0' if int(ch) else '1') + binary_str[i + 1:]
            for i, ch in enumerate(binary_str)
        ]

In [9]:
def predict(model, embeddings, batch_size=64, image_loader=None):
    """
    Predict captions given a similarity search model and image embeddings.

    Coarse2Fine models additionally need ``image_loader`` so the encoder can predict disease
    labels for the coarse search step. In that case the embeddings are batched with the image
    loader's batch size so that both iterators stay row-aligned.

    Parameters:
        model (SimilaritySearch):                   Similarity search model to evaluate
        embeddings (numpy.array):                   Image embeddings of size (sample_size, project_dim)
        batch_size (int):                           Queries per search call (ignored for Coarse2Fine models)
        image_loader (torch.utils.data.DataLoader): Image loader for the encoder; required for, and only
                                                    allowed with, Coarse2Fine models

    Returns:
        Predicted captions (numpy.array), one row per embedding
        Total time in seconds spent in the timed prediction calls (float)

    Raises:
        AssertionError: if ``image_loader`` is given for a non-Coarse2Fine model or missing for a
            Coarse2Fine one, or if the two loaders yield different numbers of batches.
    """
    is_coarse2fine = isinstance(model, SimilaritySearchCoarse2Fine)
    has_image_loader = image_loader is not None
    assert is_coarse2fine == has_image_loader, (
        f"isinstance={is_coarse2fine} but image_loader != None is {has_image_loader}"
    )

    captions = []
    total_time = 0.0

    if is_coarse2fine:
        print(f"forced batch size: {image_loader.batch_size}")
        # pin_memory is omitted: the batches are consumed on the CPU via .numpy().
        data_loader = DataLoader(
            embeddings,
            batch_size=image_loader.batch_size,
            num_workers=0,
        )

        assert len(data_loader) == len(image_loader), (
            f"{len(data_loader)=}, {len(image_loader)=}"
        )

        for batch, (image, _, _) in tqdm(zip(data_loader, image_loader), total=len(data_loader)):
            # perf_counter is monotonic and high-resolution, unlike the wall clock time.time().
            start = time.perf_counter()
            yq = model.predict(batch.numpy(), image.cuda())
            total_time += time.perf_counter() - start
            captions.extend(yq)
    else:
        data_loader = DataLoader(
            embeddings,
            batch_size=batch_size,
            num_workers=0,
        )
        for batch in tqdm(data_loader):
            start = time.perf_counter()
            yq = model.predict(batch.numpy())
            total_time += time.perf_counter() - start
            captions.extend(yq)

    captions = np.array(captions).reshape(embeddings.shape[0], -1)

    return captions, total_time

def evaluate_clinical(true_captions, pred_captions, batch_size):
    """
    Evaluate clinical accuracy of predicted reports vs ground truth
    Using VisualCheXbert to compare between ground truth and predicted reports

    Parameters:
        true_captions (numpy.array): Ground truth, row-aligned with ``pred_captions``
        pred_captions (numpy.array): Predicted reports
        batch_size (int): Size of the batch for VisualCheXbert to predict each iteration

    Returns:
        Evaluation matrix (Precision, recall, F1) of each disease and micro/macro avg (pandas.DataFrame)

    Raises:
        ValueError: if the two caption sets have different numbers of rows (the labels would be
            compared against the wrong samples).
    """
    if true_captions.shape[0] != pred_captions.shape[0]:
        raise ValueError(
            f"caption count mismatch: {true_captions.shape[0]} true vs {pred_captions.shape[0]} predicted"
        )

    def _label_all(captions):
        """Label every caption with chexpert in batches; one row per caption, index 0..n-1."""
        loader = DataLoader(
            captions,
            batch_size=batch_size,
            num_workers=0,
            pin_memory=True,
        )
        frames = [chexpert(batch, tokenizer) for batch in tqdm(loader)]
        return pd.concat(frames).reset_index(drop=True)

    true_df = _label_all(true_captions)
    pred_df = _label_all(pred_captions)
    return evaluation_matrix(true_df, pred_df)

# def calculate_bleu_scores(true_captions, pred_captions):
#     """
#     Calculates BLEU 1-4 scores based on NLTK functionality

#     Parameters:
#         true_captions: List of reference sentences
#         pred_captions: List of generated sentences

#     Returns:
#         bleu_1, bleu_2, bleu_3, bleu_4: BLEU scores

#     """
#     # Put each sentence in references in a list
#     # Because nltk accepts list of possible references for each sample
#     true_captions = [[e.split()] for e in true_captions]
#     pred_captions = [e.split() for e in pred_captions]

#     bleu_1 = np.round(corpus_bleu(true_captions, pred_captions, weights=(1.0, 0., 0., 0.)), decimals=4)
#     bleu_2 = np.round(corpus_bleu(true_captions, pred_captions, weights=(0.50, 0.50, 0., 0.)), decimals=4)
#     bleu_3 = np.round(corpus_bleu(true_captions, pred_captions, weights=(0.33, 0.33, 0.33, 0.)), decimals=4)
#     bleu_4 = np.round(corpus_bleu(true_captions, pred_captions, weights=(0.25, 0.25, 0.25, 0.25)), decimals=4)
#     return bleu_1, bleu_2, bleu_3, bleu_4 

def calculate_nlg_metrics(true_captions, pred_captions):
    """
    Score generated reports against references with nlg-eval (BLEU-1..4, ROUGE_L, CIDEr
    given the metrics omitted at load time).

    Parameters:
        true_captions: list of reference sentences, one per sample
        pred_captions: list of generated sentences, aligned with ``true_captions``

    Returns:
        metrics_dict (dict): metric name -> score, as returned by ``nlgeval.compute_metrics``
    """
    # nlg-eval takes a list of reference *sets*, each aligned with the hypotheses.
    # There is a single reference per sample, so there is exactly one set.
    reference_sets = [true_captions]
    return nlgeval.compute_metrics(reference_sets, pred_captions)

def write_time_file(model, project_dim, seed, val_time, test_time):
    """
    Append one CSV row ``model_class,timestamp,project_dim,seed,val_time,test_time``
    to ``time_file_path`` (no header is written).
    """
    with open(time_file_path, 'a') as out:
        fields = (type(model).__name__, pd.Timestamp.now(), project_dim, seed, val_time, test_time)
        out.write(','.join(f"{value}" for value in fields) + "\n")

def write_nlg_file(model, project_dim, seed, split, metrics_dict):
    """
    Print and append one CSV row of NLG scores to ``nlg_file_path``:
    ``model_class,timestamp,project_dim,seed,split,Bleu_1,Bleu_2,Bleu_3,Bleu_4,ROUGE_L,CIDEr``.
    Raises KeyError if ``metrics_dict`` lacks one of those metrics.
    """
    metric_keys = ('Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'CIDEr')
    with open(nlg_file_path, 'a') as out:
        fields = [type(model).__name__, pd.Timestamp.now(), project_dim, seed, split]
        fields.extend(metrics_dict[key] for key in metric_keys)
        row = ','.join(map(str, fields))
        print(row)
        out.write(row + '\n')

def evaluate_all(model, val_image_embeddings, val_captions, test_image_embeddings, test_captions, seed, project_dim,
                 clinical=False, log_time=False):
    """
    Evaluate the model on the val and then the test set, appending results under results/.

    For each split: retrieve reports with ``predict`` (Coarse2Fine models also get the split's
    image loader for their coarse step), decode predictions and references with the tokenizer,
    compute NLG metrics (BLEU-1..4, ROUGE_L, CIDEr) and append them via ``write_nlg_file``.

    Parameters:
        model (SimilaritySearch):            Similarity search model to evaluate
        val_image_embeddings (numpy.array):  Image embeddings of size (sample_size, project_dim)
        val_captions (numpy.array):          Encoded captions of size (sample_size, max_caption_len)
        test_image_embeddings (numpy.array): Image embeddings of size (sample_size, project_dim)
        test_captions (numpy.array):         Encoded captions of size (sample_size, max_caption_len)
        seed (int):                          (For logging purpose only) Seed used to random project
        project_dim (int):                   Random project dimension
        clinical (bool):                     Also compute VisualCheXbert clinical accuracy per split and
                                             save it to results/{model}_{split}_results_{dim}_{seed}.csv
        log_time (bool):                     Also append the val/test prediction times via write_time_file
    """
    val_time = _evaluate_split(model, 'val', val_image_embeddings, val_captions, val_loader,
                               seed, project_dim, clinical)
    test_time = _evaluate_split(model, 'test', test_image_embeddings, test_captions, test_loader,
                                seed, project_dim, clinical)

    if log_time:
        write_time_file(model, project_dim, seed, val_time, test_time)


def _evaluate_split(model, split, image_embeddings, captions, image_loader, seed, project_dim, clinical):
    """Predict, score and log one split for evaluate_all; returns the prediction time in seconds."""
    if isinstance(model, SimilaritySearchCoarse2Fine):
        # Coarse2Fine models need an image loader for the encoder to predict the 14 diseases
        predicted_reports, elapsed = predict(model, image_embeddings, batch_size=1024, image_loader=image_loader)
    else:
        predicted_reports, elapsed = predict(model, image_embeddings, batch_size=1024)
    print(f"Time taken to predict {split}: {elapsed:.3f} seconds")

    # NLG Metrics
    decoded_predicted_reports = tokenizer.decode(predicted_reports)
    decoded_true_reports = tokenizer.decode(captions)
    metrics_dict = calculate_nlg_metrics(decoded_true_reports, decoded_predicted_reports)
    print(metrics_dict)
    write_nlg_file(model, project_dim, seed, split, metrics_dict)

    # Clinical Accuracy (optional)
    if clinical:
        eval_matrix = evaluate_clinical(captions, predicted_reports, batch_size=8)
        eval_matrix.to_csv(
            f'results/{type(model).__name__}_{split}_results_{project_dim}_{seed}.csv', index=False
        )
        print(eval_matrix)

    return elapsed

In [10]:
def e2e_benchmark(model_class, seed, project_dim, load=False):
    """
    Run one end-to-end benchmark configuration.

    Steps: fetch train/val/test embeddings and captions via ``get_vectors``,
    instantiate ``model_class`` (attaching the global ``encoder`` when it is a
    ``SimilaritySearchCoarse2Fine``), fit it on the training split, and hand the
    fitted model to ``evaluate_all`` for val/test scoring and result logging.

    Args:
        model_class: zero-argument callable returning a model with ``fit``.
        seed: seed identifying which embedding variant ``get_vectors`` returns.
        project_dim: embedding dimensionality passed through to ``get_vectors``.
        load: forwarded to ``get_vectors`` (presumably "load precomputed
            embeddings from disk" -- TODO confirm against get_vectors).

    Returns:
        None. Results are produced by ``evaluate_all`` as a side effect.
    """
    (train_x, train_y,
     val_x, val_y,
     test_x, test_y) = get_vectors(seed, project_dim, load=load)

    model = model_class()
    # Coarse2Fine variants rely on the image encoder to predict disease labels,
    # so they must be wired to it before fitting.
    needs_encoder = isinstance(model, SimilaritySearchCoarse2Fine)
    if needs_encoder:
        model.assign_encoder(encoder)

    model.fit(train_x, train_y)

    evaluate_all(model, val_x, val_y, test_x, test_y, seed, project_dim)

In [15]:
models = [OneNearestNeighbor, OneNearestNeighborCoarse2Fine]
seeds = [2000, 3000, 4000]
dims = [128, 256, 512, 1024, 2048]

# Full grid: every model x every embedding dim x every seed.
# Order (model, then dim, then seed) matches the printed log below.
# Note: every configuration is re-run; there is no skip-if-results-exist check.
benchmark_grid = [
    (model_class, dim, seed)
    for model_class in models
    for dim in dims
    for seed in seeds
]
for model_class, dim, seed in benchmark_grid:
    print(model_class.__name__, dim, seed)
    e2e_benchmark(model_class, seed=seed, project_dim=dim, load=True)

OneNearestNeighbor 128 2000


100%|██████████| 3/3 [00:10<00:00,  3.64s/it]


Time taken to predict val: 10.876 seconds
{'Bleu_1': 0.3851666246581032, 'Bleu_2': 0.22431177600505312, 'Bleu_3': 0.14294507079829166, 'Bleu_4': 0.09649506046901199, 'ROUGE_L': 0.271416630620714, 'CIDEr': 0.1281270202316191}
OneNearestNeighbor,2022-05-02 18:46:13.582486,128,2000,val,0.3851666246581032,0.22431177600505312,0.14294507079829166,0.09649506046901199,0.271416630620714,0.1281270202316191


100%|██████████| 4/4 [00:17<00:00,  4.30s/it]


Time taken to predict test: 17.150 seconds
{'Bleu_1': 0.34100706701727673, 'Bleu_2': 0.18396719381405655, 'Bleu_3': 0.1073921312427076, 'Bleu_4': 0.06729205252292166, 'ROUGE_L': 0.23272179826916842, 'CIDEr': 0.0697661325315319}
OneNearestNeighbor,2022-05-02 18:46:49.021982,128,2000,test,0.34100706701727673,0.18396719381405655,0.1073921312427076,0.06729205252292166,0.23272179826916842,0.0697661325315319
OneNearestNeighbor 128 3000


100%|██████████| 3/3 [00:09<00:00,  3.18s/it]


Time taken to predict val: 9.536 seconds
{'Bleu_1': 0.38165138741950333, 'Bleu_2': 0.2216532610022771, 'Bleu_3': 0.14002622831695125, 'Bleu_4': 0.09372589719497788, 'ROUGE_L': 0.27100540965478476, 'CIDEr': 0.12629952141326783}
OneNearestNeighbor,2022-05-02 18:47:08.381862,128,3000,val,0.38165138741950333,0.2216532610022771,0.14002622831695125,0.09372589719497788,0.27100540965478476,0.12629952141326783


100%|██████████| 4/4 [00:17<00:00,  4.47s/it]


Time taken to predict test: 17.849 seconds
{'Bleu_1': 0.33894636307856185, 'Bleu_2': 0.18358649281528508, 'Bleu_3': 0.10764939891962928, 'Bleu_4': 0.06745009842237873, 'ROUGE_L': 0.233474916946053, 'CIDEr': 0.07226210406848409}
OneNearestNeighbor,2022-05-02 18:47:44.877925,128,3000,test,0.33894636307856185,0.18358649281528508,0.10764939891962928,0.06745009842237873,0.233474916946053,0.07226210406848409
OneNearestNeighbor 128 4000


100%|██████████| 3/3 [00:09<00:00,  3.16s/it]


Time taken to predict val: 9.469 seconds
{'Bleu_1': 0.3773189706762382, 'Bleu_2': 0.2180727898581333, 'Bleu_3': 0.1376134180106751, 'Bleu_4': 0.0928207057626722, 'ROUGE_L': 0.26774046967359794, 'CIDEr': 0.12797976259636373}
OneNearestNeighbor,2022-05-02 18:48:03.883438,128,4000,val,0.3773189706762382,0.2180727898581333,0.1376134180106751,0.0928207057626722,0.26774046967359794,0.12797976259636373


100%|██████████| 4/4 [00:16<00:00,  4.11s/it]


Time taken to predict test: 16.427 seconds
{'Bleu_1': 0.33887426496700607, 'Bleu_2': 0.18214606603080472, 'Bleu_3': 0.10585270359020955, 'Bleu_4': 0.06577013748799636, 'ROUGE_L': 0.23150700739956637, 'CIDEr': 0.07492961406103807}
OneNearestNeighbor,2022-05-02 18:48:37.987057,128,4000,test,0.33887426496700607,0.18214606603080472,0.10585270359020955,0.06577013748799636,0.23150700739956637,0.07492961406103807
OneNearestNeighbor 256 2000


100%|██████████| 3/3 [00:11<00:00,  3.80s/it]


Time taken to predict val: 11.390 seconds
{'Bleu_1': 0.38669479422389047, 'Bleu_2': 0.22564265104755848, 'Bleu_3': 0.14363864654736716, 'Bleu_4': 0.09736466730627144, 'ROUGE_L': 0.2747899060951425, 'CIDEr': 0.15402486510029487}
OneNearestNeighbor,2022-05-02 18:48:59.222174,256,2000,val,0.38669479422389047,0.22564265104755848,0.14363864654736716,0.09736466730627144,0.2747899060951425,0.15402486510029487


100%|██████████| 4/4 [00:19<00:00,  4.83s/it]


Time taken to predict test: 19.288 seconds
{'Bleu_1': 0.34201100721765865, 'Bleu_2': 0.185108074983537, 'Bleu_3': 0.10839928303109103, 'Bleu_4': 0.06774496438127756, 'ROUGE_L': 0.23451050915161073, 'CIDEr': 0.0793206512724363}
OneNearestNeighbor,2022-05-02 18:49:35.945526,256,2000,test,0.34201100721765865,0.185108074983537,0.10839928303109103,0.06774496438127756,0.23451050915161073,0.0793206512724363
OneNearestNeighbor 256 3000


100%|██████████| 3/3 [00:11<00:00,  3.75s/it]


Time taken to predict val: 11.221 seconds
{'Bleu_1': 0.3828955559872012, 'Bleu_2': 0.225259733191817, 'Bleu_3': 0.1441249901449956, 'Bleu_4': 0.09777121869976654, 'ROUGE_L': 0.27501430050432396, 'CIDEr': 0.14586516616667644}
OneNearestNeighbor,2022-05-02 18:49:57.187977,256,3000,val,0.3828955559872012,0.225259733191817,0.1441249901449956,0.09777121869976654,0.27501430050432396,0.14586516616667644


100%|██████████| 4/4 [00:19<00:00,  4.85s/it]


Time taken to predict test: 19.363 seconds
{'Bleu_1': 0.33547332671171703, 'Bleu_2': 0.17984331743040521, 'Bleu_3': 0.10375664188228957, 'Bleu_4': 0.06346153769176875, 'ROUGE_L': 0.2305317865431016, 'CIDEr': 0.06733790778984762}
OneNearestNeighbor,2022-05-02 18:50:33.877080,256,3000,test,0.33547332671171703,0.17984331743040521,0.10375664188228957,0.06346153769176875,0.2305317865431016,0.06733790778984762
OneNearestNeighbor 256 4000


100%|██████████| 3/3 [00:11<00:00,  3.78s/it]


Time taken to predict val: 11.309 seconds
{'Bleu_1': 0.385325317251143, 'Bleu_2': 0.22554866752296315, 'Bleu_3': 0.14391753672772556, 'Bleu_4': 0.09809792589119896, 'ROUGE_L': 0.2747302161353351, 'CIDEr': 0.157531823997193}
OneNearestNeighbor,2022-05-02 18:50:54.990818,256,4000,val,0.385325317251143,0.22554866752296315,0.14391753672772556,0.09809792589119896,0.2747302161353351,0.157531823997193


100%|██████████| 4/4 [00:19<00:00,  4.90s/it]


Time taken to predict test: 19.561 seconds
{'Bleu_1': 0.3443989109404857, 'Bleu_2': 0.18771114923220097, 'Bleu_3': 0.11068636083302628, 'Bleu_4': 0.0696434266120027, 'ROUGE_L': 0.2351346437674789, 'CIDEr': 0.07829656636923034}
OneNearestNeighbor,2022-05-02 18:51:32.391410,256,4000,test,0.3443989109404857,0.18771114923220097,0.11068636083302628,0.0696434266120027,0.2351346437674789,0.07829656636923034
OneNearestNeighbor 512 2000


100%|██████████| 3/3 [00:15<00:00,  5.05s/it]


Time taken to predict val: 15.133 seconds
{'Bleu_1': 0.38642374369720667, 'Bleu_2': 0.22710088451559476, 'Bleu_3': 0.14617565845558117, 'Bleu_4': 0.10038664328642324, 'ROUGE_L': 0.27635580543055255, 'CIDEr': 0.16229649182041483}
OneNearestNeighbor,2022-05-02 18:51:58.056073,512,2000,val,0.38642374369720667,0.22710088451559476,0.14617565845558117,0.10038664328642324,0.27635580543055255,0.16229649182041483


100%|██████████| 4/4 [00:25<00:00,  6.27s/it]


Time taken to predict test: 25.036 seconds
{'Bleu_1': 0.3416901004190246, 'Bleu_2': 0.1857929190286701, 'Bleu_3': 0.1101022784324904, 'Bleu_4': 0.06978758001458982, 'ROUGE_L': 0.23464368891282436, 'CIDEr': 0.08071289967998382}
OneNearestNeighbor,2022-05-02 18:52:40.642773,512,2000,test,0.3416901004190246,0.1857929190286701,0.1101022784324904,0.06978758001458982,0.23464368891282436,0.08071289967998382
OneNearestNeighbor 512 3000


100%|██████████| 3/3 [00:16<00:00,  5.38s/it]


Time taken to predict val: 16.108 seconds
{'Bleu_1': 0.3818267289299594, 'Bleu_2': 0.2236015975133133, 'Bleu_3': 0.14366263654488043, 'Bleu_4': 0.09830654719644734, 'ROUGE_L': 0.274369035355379, 'CIDEr': 0.14486543633748897}
OneNearestNeighbor,2022-05-02 18:53:07.561045,512,3000,val,0.3818267289299594,0.2236015975133133,0.14366263654488043,0.09830654719644734,0.274369035355379,0.14486543633748897


100%|██████████| 4/4 [00:26<00:00,  6.57s/it]


Time taken to predict test: 26.261 seconds
{'Bleu_1': 0.34322912125865035, 'Bleu_2': 0.18663465765824186, 'Bleu_3': 0.10976878668078784, 'Bleu_4': 0.06879634851883272, 'ROUGE_L': 0.23408144582189097, 'CIDEr': 0.07725186574216192}
OneNearestNeighbor,2022-05-02 18:53:51.644938,512,3000,test,0.34322912125865035,0.18663465765824186,0.10976878668078784,0.06879634851883272,0.23408144582189097,0.07725186574216192
OneNearestNeighbor 512 4000


100%|██████████| 3/3 [00:15<00:00,  5.16s/it]


Time taken to predict val: 15.464 seconds
{'Bleu_1': 0.3817403325838768, 'Bleu_2': 0.22351731985821144, 'Bleu_3': 0.14343743434331366, 'Bleu_4': 0.09787013621346655, 'ROUGE_L': 0.2743951186975061, 'CIDEr': 0.14308431868936775}
OneNearestNeighbor,2022-05-02 18:54:17.667972,512,4000,val,0.3817403325838768,0.22351731985821144,0.14343743434331366,0.09787013621346655,0.2743951186975061,0.14308431868936775


100%|██████████| 4/4 [00:26<00:00,  6.54s/it]


Time taken to predict test: 26.096 seconds
{'Bleu_1': 0.3436262538630122, 'Bleu_2': 0.18601949464672632, 'Bleu_3': 0.10963293874236098, 'Bleu_4': 0.06916439967143259, 'ROUGE_L': 0.23427576826999078, 'CIDEr': 0.07541453025980668}
OneNearestNeighbor,2022-05-02 18:55:01.388706,512,4000,test,0.3436262538630122,0.18601949464672632,0.10963293874236098,0.06916439967143259,0.23427576826999078,0.07541453025980668
OneNearestNeighbor 1024 2000


100%|██████████| 3/3 [00:22<00:00,  7.64s/it]


Time taken to predict val: 22.906 seconds
{'Bleu_1': 0.3875724487032092, 'Bleu_2': 0.22904047072745548, 'Bleu_3': 0.14806252203933593, 'Bleu_4': 0.10156013781022297, 'ROUGE_L': 0.2768749579044496, 'CIDEr': 0.1515306340726452}
OneNearestNeighbor,2022-05-02 18:55:36.450311,1024,2000,val,0.3875724487032092,0.22904047072745548,0.14806252203933593,0.10156013781022297,0.2768749579044496,0.1515306340726452


100%|██████████| 4/4 [00:38<00:00,  9.67s/it]


Time taken to predict test: 38.658 seconds
{'Bleu_1': 0.3418744749089257, 'Bleu_2': 0.18526110280139949, 'Bleu_3': 0.10900685001905178, 'Bleu_4': 0.06846551135897302, 'ROUGE_L': 0.23449138935461508, 'CIDEr': 0.07819480548473798}
OneNearestNeighbor,2022-05-02 18:56:32.568963,1024,2000,test,0.3418744749089257,0.18526110280139949,0.10900685001905178,0.06846551135897302,0.23449138935461508,0.07819480548473798
OneNearestNeighbor 1024 3000


100%|██████████| 3/3 [00:22<00:00,  7.64s/it]


Time taken to predict val: 22.884 seconds
{'Bleu_1': 0.3855432979351549, 'Bleu_2': 0.22621197096275159, 'Bleu_3': 0.14461269801595483, 'Bleu_4': 0.09774148963947935, 'ROUGE_L': 0.2741198823049274, 'CIDEr': 0.14208014927071086}
OneNearestNeighbor,2022-05-02 18:57:07.565651,1024,3000,val,0.3855432979351549,0.22621197096275159,0.14461269801595483,0.09774148963947935,0.2741198823049274,0.14208014927071086


100%|██████████| 4/4 [00:38<00:00,  9.69s/it]


Time taken to predict test: 38.741 seconds
{'Bleu_1': 0.3440515534043194, 'Bleu_2': 0.18736751091326742, 'Bleu_3': 0.11065406990732024, 'Bleu_4': 0.06987331433688984, 'ROUGE_L': 0.2348759988909057, 'CIDEr': 0.08013059117516382}
OneNearestNeighbor,2022-05-02 18:58:04.268722,1024,3000,test,0.3440515534043194,0.18736751091326742,0.11065406990732024,0.06987331433688984,0.2348759988909057,0.08013059117516382
OneNearestNeighbor 1024 4000


100%|██████████| 3/3 [00:22<00:00,  7.66s/it]


Time taken to predict val: 22.946 seconds
{'Bleu_1': 0.38187047689738, 'Bleu_2': 0.2219842106463303, 'Bleu_3': 0.1408398258697893, 'Bleu_4': 0.09453120894891902, 'ROUGE_L': 0.27230570225202017, 'CIDEr': 0.14074895719170197}
OneNearestNeighbor,2022-05-02 18:58:39.470098,1024,4000,val,0.38187047689738,0.2219842106463303,0.1408398258697893,0.09453120894891902,0.27230570225202017,0.14074895719170197


100%|██████████| 4/4 [00:38<00:00,  9.70s/it]


Time taken to predict test: 38.769 seconds
{'Bleu_1': 0.3436602943554784, 'Bleu_2': 0.18769387806305132, 'Bleu_3': 0.11099767124368434, 'Bleu_4': 0.07044037138525842, 'ROUGE_L': 0.23555435162370494, 'CIDEr': 0.07433691805224799}
OneNearestNeighbor,2022-05-02 18:59:35.787357,1024,4000,test,0.3436602943554784,0.18769387806305132,0.11099767124368434,0.07044037138525842,0.23555435162370494,0.07433691805224799
OneNearestNeighbor 2048 2000


100%|██████████| 3/3 [00:44<00:00, 14.79s/it]


Time taken to predict val: 44.342 seconds
{'Bleu_1': 0.38900598600454905, 'Bleu_2': 0.22881723600754794, 'Bleu_3': 0.147169275108359, 'Bleu_4': 0.10064597260204772, 'ROUGE_L': 0.2762343984993211, 'CIDEr': 0.15314477592471193}
OneNearestNeighbor,2022-05-02 19:00:35.856371,2048,2000,val,0.38900598600454905,0.22881723600754794,0.147169275108359,0.10064597260204772,0.2762343984993211,0.15314477592471193


100%|██████████| 4/4 [01:11<00:00, 17.85s/it]


Time taken to predict test: 71.354 seconds
{'Bleu_1': 0.34699881912844205, 'Bleu_2': 0.18883834787385043, 'Bleu_3': 0.11140620165303977, 'Bleu_4': 0.07001457727845664, 'ROUGE_L': 0.23500773618372237, 'CIDEr': 0.08380085753433965}
OneNearestNeighbor,2022-05-02 19:02:05.037201,2048,2000,test,0.34699881912844205,0.18883834787385043,0.11140620165303977,0.07001457727845664,0.23500773618372237,0.08380085753433965
OneNearestNeighbor 2048 3000


100%|██████████| 3/3 [00:43<00:00, 14.41s/it]


Time taken to predict val: 43.188 seconds
{'Bleu_1': 0.38771566712269123, 'Bleu_2': 0.22794212158547325, 'Bleu_3': 0.1460835210247413, 'Bleu_4': 0.09892184254509868, 'ROUGE_L': 0.2755242267479713, 'CIDEr': 0.1434520418452554}
OneNearestNeighbor,2022-05-02 19:03:03.501158,2048,3000,val,0.38771566712269123,0.22794212158547325,0.1460835210247413,0.09892184254509868,0.2755242267479713,0.1434520418452554


100%|██████████| 4/4 [01:11<00:00, 17.96s/it]


Time taken to predict test: 71.761 seconds
{'Bleu_1': 0.3438537435489242, 'Bleu_2': 0.18679584625250464, 'Bleu_3': 0.11012675589754968, 'Bleu_4': 0.06930731828139797, 'ROUGE_L': 0.235368545564509, 'CIDEr': 0.0787850948911098}
OneNearestNeighbor,2022-05-02 19:04:32.877138,2048,3000,test,0.3438537435489242,0.18679584625250464,0.11012675589754968,0.06930731828139797,0.235368545564509,0.0787850948911098
OneNearestNeighbor 2048 4000


100%|██████████| 3/3 [00:42<00:00, 14.18s/it]


Time taken to predict val: 42.520 seconds
{'Bleu_1': 0.3868340453870168, 'Bleu_2': 0.22711265581077705, 'Bleu_3': 0.1452446701886665, 'Bleu_4': 0.09853515959163726, 'ROUGE_L': 0.2752621557077365, 'CIDEr': 0.1470349293679516}
OneNearestNeighbor,2022-05-02 19:05:30.814665,2048,4000,val,0.3868340453870168,0.22711265581077705,0.1452446701886665,0.09853515959163726,0.2752621557077365,0.1470349293679516


100%|██████████| 4/4 [01:11<00:00, 17.79s/it]


Time taken to predict test: 71.110 seconds
{'Bleu_1': 0.3441497520485206, 'Bleu_2': 0.18767617816169713, 'Bleu_3': 0.11063535179574789, 'Bleu_4': 0.0696093989557363, 'ROUGE_L': 0.23437415916994386, 'CIDEr': 0.08026117768589609}
OneNearestNeighbor,2022-05-02 19:06:59.622474,2048,4000,test,0.3441497520485206,0.18767617816169713,0.11063535179574789,0.0696093989557363,0.23437415916994386,0.08026117768589609
OneNearestNeighborCoarse2Fine 128 2000
forced batch size: 16


100%|██████████| 131/131 [01:02<00:00,  2.09it/s]


Time taken to predict val: 55.580 seconds
{'Bleu_1': 0.38299645178362485, 'Bleu_2': 0.22291360353388726, 'Bleu_3': 0.14138369331044728, 'Bleu_4': 0.09535268585093719, 'ROUGE_L': 0.27124081292679403, 'CIDEr': 0.13350435525591098}
OneNearestNeighborCoarse2Fine,2022-05-02 19:08:17.204674,128,2000,val,0.38299645178362485,0.22291360353388726,0.14138369331044728,0.09535268585093719,0.27124081292679403,0.13350435525591098
forced batch size: 16


100%|██████████| 229/229 [00:56<00:00,  4.02it/s]


Time taken to predict test: 45.532 seconds
{'Bleu_1': 0.34089702921135673, 'Bleu_2': 0.18422519699858275, 'Bleu_3': 0.10766278080617825, 'Bleu_4': 0.06740032986443346, 'ROUGE_L': 0.23415012321453937, 'CIDEr': 0.07064426397473428}
OneNearestNeighborCoarse2Fine,2022-05-02 19:09:28.436003,128,2000,test,0.34089702921135673,0.18422519699858275,0.10766278080617825,0.06740032986443346,0.23415012321453937,0.07064426397473428
OneNearestNeighborCoarse2Fine 128 3000
forced batch size: 16


100%|██████████| 131/131 [00:56<00:00,  2.31it/s]


Time taken to predict val: 52.121 seconds
{'Bleu_1': 0.3831598102722573, 'Bleu_2': 0.22331888535865033, 'Bleu_3': 0.14173756269105883, 'Bleu_4': 0.09538797553035011, 'ROUGE_L': 0.2716896400668826, 'CIDEr': 0.13017003837993976}
OneNearestNeighborCoarse2Fine,2022-05-02 19:10:36.892858,128,3000,val,0.3831598102722573,0.22331888535865033,0.14173756269105883,0.09538797553035011,0.2716896400668826,0.13017003837993976
forced batch size: 16


100%|██████████| 229/229 [00:52<00:00,  4.37it/s]


Time taken to predict test: 44.494 seconds
{'Bleu_1': 0.3444465994182055, 'Bleu_2': 0.18765287268548075, 'Bleu_3': 0.11095112443373517, 'Bleu_4': 0.07041071580420676, 'ROUGE_L': 0.2352278939065356, 'CIDEr': 0.07089528142774225}
OneNearestNeighborCoarse2Fine,2022-05-02 19:11:43.603673,128,3000,test,0.3444465994182055,0.18765287268548075,0.11095112443373517,0.07041071580420676,0.2352278939065356,0.07089528142774225
OneNearestNeighborCoarse2Fine 128 4000
forced batch size: 16


100%|██████████| 131/131 [00:56<00:00,  2.32it/s]


Time taken to predict val: 52.003 seconds
{'Bleu_1': 0.3830478913795432, 'Bleu_2': 0.22389251347787736, 'Bleu_3': 0.14301648134180991, 'Bleu_4': 0.0974029575982961, 'ROUGE_L': 0.27334878328030726, 'CIDEr': 0.14781898997946183}
OneNearestNeighborCoarse2Fine,2022-05-02 19:12:51.715110,128,4000,val,0.3830478913795432,0.22389251347787736,0.14301648134180991,0.0974029575982961,0.27334878328030726,0.14781898997946183
forced batch size: 16


100%|██████████| 229/229 [00:52<00:00,  4.37it/s]


Time taken to predict test: 44.464 seconds
{'Bleu_1': 0.341504249373113, 'Bleu_2': 0.18511299078015228, 'Bleu_3': 0.10835649662585373, 'Bleu_4': 0.06799057186271243, 'ROUGE_L': 0.23377168948944993, 'CIDEr': 0.07539117433022936}
OneNearestNeighborCoarse2Fine,2022-05-02 19:13:58.227216,128,4000,test,0.341504249373113,0.18511299078015228,0.10835649662585373,0.06799057186271243,0.23377168948944993,0.07539117433022936
OneNearestNeighborCoarse2Fine 256 2000
forced batch size: 16


100%|██████████| 131/131 [01:35<00:00,  1.38it/s]


Time taken to predict val: 90.530 seconds
{'Bleu_1': 0.383464081943943, 'Bleu_2': 0.22467617136822315, 'Bleu_3': 0.143095504069328, 'Bleu_4': 0.09690253328785003, 'ROUGE_L': 0.273898096970324, 'CIDEr': 0.15357015993692108}
OneNearestNeighborCoarse2Fine,2022-05-02 19:15:45.106631,256,2000,val,0.383464081943943,0.22467617136822315,0.143095504069328,0.09690253328785003,0.273898096970324,0.15357015993692108
forced batch size: 16


100%|██████████| 229/229 [01:15<00:00,  3.04it/s]


Time taken to predict test: 67.530 seconds
{'Bleu_1': 0.3421722617788217, 'Bleu_2': 0.1844057391861504, 'Bleu_3': 0.10752914186035067, 'Bleu_4': 0.0673415776568895, 'ROUGE_L': 0.2322441666666971, 'CIDEr': 0.07148363979694566}
OneNearestNeighborCoarse2Fine,2022-05-02 19:17:14.817411,256,2000,test,0.3421722617788217,0.1844057391861504,0.10752914186035067,0.0673415776568895,0.2322441666666971,0.07148363979694566
OneNearestNeighborCoarse2Fine 256 3000
forced batch size: 16


100%|██████████| 131/131 [01:33<00:00,  1.40it/s]


Time taken to predict val: 89.234 seconds
{'Bleu_1': 0.3823095731377234, 'Bleu_2': 0.223123379934912, 'Bleu_3': 0.1419996843125351, 'Bleu_4': 0.09616778886372021, 'ROUGE_L': 0.2735227320353744, 'CIDEr': 0.14483679870045182}
OneNearestNeighborCoarse2Fine,2022-05-02 19:19:00.260921,256,3000,val,0.3823095731377234,0.223123379934912,0.1419996843125351,0.09616778886372021,0.2735227320353744,0.14483679870045182
forced batch size: 16


100%|██████████| 229/229 [01:15<00:00,  3.02it/s]


Time taken to predict test: 67.768 seconds
{'Bleu_1': 0.3412224251618969, 'Bleu_2': 0.18423503824392226, 'Bleu_3': 0.10726253994925358, 'Bleu_4': 0.06644384992775854, 'ROUGE_L': 0.23300343047792899, 'CIDEr': 0.07216851010858212}
OneNearestNeighborCoarse2Fine,2022-05-02 19:20:30.143720,256,3000,test,0.3412224251618969,0.18423503824392226,0.10726253994925358,0.06644384992775854,0.23300343047792899,0.07216851010858212
OneNearestNeighborCoarse2Fine 256 4000
forced batch size: 16


100%|██████████| 131/131 [01:34<00:00,  1.39it/s]


Time taken to predict val: 89.986 seconds
{'Bleu_1': 0.38683772485335255, 'Bleu_2': 0.22751346926043145, 'Bleu_3': 0.14561618395621478, 'Bleu_4': 0.09954590147084239, 'ROUGE_L': 0.27742419928944667, 'CIDEr': 0.15510117829514525}
OneNearestNeighborCoarse2Fine,2022-05-02 19:22:16.622713,256,4000,val,0.38683772485335255,0.22751346926043145,0.14561618395621478,0.09954590147084239,0.27742419928944667,0.15510117829514525
forced batch size: 16


100%|██████████| 229/229 [01:16<00:00,  3.01it/s]


Time taken to predict test: 68.148 seconds
{'Bleu_1': 0.3418091676980154, 'Bleu_2': 0.18467825363276338, 'Bleu_3': 0.10808888181993037, 'Bleu_4': 0.0676465755790161, 'ROUGE_L': 0.2341807713127488, 'CIDEr': 0.07032767723035778}
OneNearestNeighborCoarse2Fine,2022-05-02 19:23:46.877215,256,4000,test,0.3418091676980154,0.18467825363276338,0.10808888181993037,0.0676465755790161,0.2341807713127488,0.07032767723035778
OneNearestNeighborCoarse2Fine 512 2000
forced batch size: 16


100%|██████████| 131/131 [02:52<00:00,  1.32s/it]


Time taken to predict val: 167.992 seconds
{'Bleu_1': 0.38746534176100494, 'Bleu_2': 0.22773143295767762, 'Bleu_3': 0.1467592041458735, 'Bleu_4': 0.10081739173369496, 'ROUGE_L': 0.27823225832787796, 'CIDEr': 0.1617172419781009}
OneNearestNeighborCoarse2Fine,2022-05-02 19:26:52.025149,512,2000,val,0.38746534176100494,0.22773143295767762,0.1467592041458735,0.10081739173369496,0.27823225832787796,0.1617172419781009
forced batch size: 16


100%|██████████| 229/229 [02:03<00:00,  1.86it/s]


Time taken to predict test: 115.376 seconds
{'Bleu_1': 0.3386828496998539, 'Bleu_2': 0.18279330722174236, 'Bleu_3': 0.10697683906503591, 'Bleu_4': 0.06713726210929288, 'ROUGE_L': 0.233063857431398, 'CIDEr': 0.06459605688708628}
OneNearestNeighborCoarse2Fine,2022-05-02 19:29:09.419578,512,2000,test,0.3386828496998539,0.18279330722174236,0.10697683906503591,0.06713726210929288,0.233063857431398,0.06459605688708628
OneNearestNeighborCoarse2Fine 512 3000
forced batch size: 16


100%|██████████| 131/131 [02:52<00:00,  1.32s/it]


Time taken to predict val: 168.042 seconds
{'Bleu_1': 0.38443568106114073, 'Bleu_2': 0.22568634376425734, 'Bleu_3': 0.14441077435562888, 'Bleu_4': 0.09808178848870326, 'ROUGE_L': 0.27665217601057646, 'CIDEr': 0.1483562736604661}
OneNearestNeighborCoarse2Fine,2022-05-02 19:32:14.524852,512,3000,val,0.38443568106114073,0.22568634376425734,0.14441077435562888,0.09808178848870326,0.27665217601057646,0.1483562736604661
forced batch size: 16


100%|██████████| 229/229 [02:04<00:00,  1.84it/s]


Time taken to predict test: 116.200 seconds
{'Bleu_1': 0.34128882441499986, 'Bleu_2': 0.18399846495264455, 'Bleu_3': 0.10684625941741388, 'Bleu_4': 0.06640543936691712, 'ROUGE_L': 0.23175707093714612, 'CIDEr': 0.06937802527785739}
OneNearestNeighborCoarse2Fine,2022-05-02 19:34:32.893729,512,3000,test,0.34128882441499986,0.18399846495264455,0.10684625941741388,0.06640543936691712,0.23175707093714612,0.06937802527785739
OneNearestNeighborCoarse2Fine 512 4000
forced batch size: 16


100%|██████████| 131/131 [02:53<00:00,  1.33s/it]


Time taken to predict val: 169.167 seconds
{'Bleu_1': 0.3839033240122222, 'Bleu_2': 0.22390344231189577, 'Bleu_3': 0.14286818348677155, 'Bleu_4': 0.09684164760544983, 'ROUGE_L': 0.2744398578033106, 'CIDEr': 0.1378280429495642}
OneNearestNeighborCoarse2Fine,2022-05-02 19:37:39.403630,512,4000,val,0.3839033240122222,0.22390344231189577,0.14286818348677155,0.09684164760544983,0.2744398578033106,0.1378280429495642
forced batch size: 16


100%|██████████| 229/229 [02:04<00:00,  1.83it/s]


Time taken to predict test: 116.802 seconds
{'Bleu_1': 0.34139962640203797, 'Bleu_2': 0.1844511720776192, 'Bleu_3': 0.10789486882470538, 'Bleu_4': 0.06772538390758015, 'ROUGE_L': 0.23319955294762926, 'CIDEr': 0.07122056145614389}
OneNearestNeighborCoarse2Fine,2022-05-02 19:39:58.297570,512,4000,test,0.34139962640203797,0.1844511720776192,0.10789486882470538,0.06772538390758015,0.23319955294762926,0.07122056145614389
OneNearestNeighborCoarse2Fine 1024 2000
forced batch size: 16


100%|██████████| 131/131 [05:37<00:00,  2.57s/it]


Time taken to predict val: 332.535 seconds
{'Bleu_1': 0.38893213442547553, 'Bleu_2': 0.23019395176310525, 'Bleu_3': 0.1485943767862617, 'Bleu_4': 0.10173034464726048, 'ROUGE_L': 0.27794234478451113, 'CIDEr': 0.15814639628867888}
OneNearestNeighborCoarse2Fine,2022-05-02 19:45:50.070707,1024,2000,val,0.38893213442547553,0.23019395176310525,0.1485943767862617,0.10173034464726048,0.27794234478451113,0.15814639628867888
forced batch size: 16


100%|██████████| 229/229 [03:42<00:00,  1.03it/s]


Time taken to predict test: 214.455 seconds
{'Bleu_1': 0.33978287446828065, 'Bleu_2': 0.18304016360905273, 'Bleu_3': 0.10666009793667154, 'Bleu_4': 0.06673419077878213, 'ROUGE_L': 0.23371633798652716, 'CIDEr': 0.07253752359169491}
OneNearestNeighborCoarse2Fine,2022-05-02 19:49:46.774374,1024,2000,test,0.33978287446828065,0.18304016360905273,0.10666009793667154,0.06673419077878213,0.23371633798652716,0.07253752359169491
OneNearestNeighborCoarse2Fine 1024 3000
forced batch size: 16


100%|██████████| 131/131 [05:37<00:00,  2.57s/it]


Time taken to predict val: 332.531 seconds
{'Bleu_1': 0.38466414363428586, 'Bleu_2': 0.22665316951605882, 'Bleu_3': 0.14565077446534846, 'Bleu_4': 0.09952345989814733, 'ROUGE_L': 0.27586650345643016, 'CIDEr': 0.15208447263987115}
OneNearestNeighborCoarse2Fine,2022-05-02 19:55:38.386875,1024,3000,val,0.38466414363428586,0.22665316951605882,0.14565077446534846,0.09952345989814733,0.27586650345643016,0.15208447263987115
forced batch size: 16


100%|██████████| 229/229 [03:41<00:00,  1.03it/s]


Time taken to predict test: 213.776 seconds
{'Bleu_1': 0.3414315332367484, 'Bleu_2': 0.18503868344234625, 'Bleu_3': 0.10894787172260996, 'Bleu_4': 0.06874662320203524, 'ROUGE_L': 0.2329991440060694, 'CIDEr': 0.07266256542045374}
OneNearestNeighborCoarse2Fine,2022-05-02 19:59:34.332107,1024,3000,test,0.3414315332367484,0.18503868344234625,0.10894787172260996,0.06874662320203524,0.2329991440060694,0.07266256542045374
OneNearestNeighborCoarse2Fine 1024 4000
forced batch size: 16


100%|██████████| 131/131 [05:39<00:00,  2.59s/it]


Time taken to predict val: 334.728 seconds
{'Bleu_1': 0.3847586000853468, 'Bleu_2': 0.22487314316621626, 'Bleu_3': 0.14371881366302686, 'Bleu_4': 0.09719225259885546, 'ROUGE_L': 0.2734702076772536, 'CIDEr': 0.1402689180905492}
OneNearestNeighborCoarse2Fine,2022-05-02 20:05:28.356463,1024,4000,val,0.3847586000853468,0.22487314316621626,0.14371881366302686,0.09719225259885546,0.2734702076772536,0.1402689180905492
forced batch size: 16


100%|██████████| 229/229 [03:43<00:00,  1.03it/s]


Time taken to predict test: 215.098 seconds
{'Bleu_1': 0.34183629774504315, 'Bleu_2': 0.18539784089977743, 'Bleu_3': 0.10884708802266972, 'Bleu_4': 0.0687844967693281, 'ROUGE_L': 0.23404579048496868, 'CIDEr': 0.07249214478557643}
OneNearestNeighborCoarse2Fine,2022-05-02 20:09:25.631331,1024,4000,test,0.34183629774504315,0.18539784089977743,0.10884708802266972,0.0687844967693281,0.23404579048496868,0.07249214478557643
OneNearestNeighborCoarse2Fine 2048 2000
forced batch size: 16


100%|██████████| 131/131 [11:09<00:00,  5.11s/it]


Time taken to predict val: 663.155 seconds
{'Bleu_1': 0.3872947272111692, 'Bleu_2': 0.22740952312875445, 'Bleu_3': 0.1458779376647093, 'Bleu_4': 0.09951423421100498, 'ROUGE_L': 0.27618053100766915, 'CIDEr': 0.1475973141616982}
OneNearestNeighborCoarse2Fine,2022-05-02 20:20:53.727879,2048,2000,val,0.3872947272111692,0.22740952312875445,0.1458779376647093,0.09951423421100498,0.27618053100766915,0.1475973141616982
forced batch size: 16


100%|██████████| 229/229 [06:56<00:00,  1.82s/it]


Time taken to predict test: 405.833 seconds
{'Bleu_1': 0.3415027952160681, 'Bleu_2': 0.1850045052849492, 'Bleu_3': 0.10868861232731189, 'Bleu_4': 0.06841548589888627, 'ROUGE_L': 0.23362773233872947, 'CIDEr': 0.07265239908288947}
OneNearestNeighborCoarse2Fine,2022-05-02 20:28:04.244066,2048,2000,test,0.3415027952160681,0.1850045052849492,0.10868861232731189,0.06841548589888627,0.23362773233872947,0.07265239908288947
OneNearestNeighborCoarse2Fine 2048 3000
forced batch size: 16


100%|██████████| 131/131 [11:04<00:00,  5.07s/it]


Time taken to predict val: 658.744 seconds
{'Bleu_1': 0.38531075838351664, 'Bleu_2': 0.22665312775039692, 'Bleu_3': 0.14546795113409094, 'Bleu_4': 0.09918486411134457, 'ROUGE_L': 0.2769879028810313, 'CIDEr': 0.1499924053553085}
OneNearestNeighborCoarse2Fine,2022-05-02 20:39:26.944314,2048,3000,val,0.38531075838351664,0.22665312775039692,0.14546795113409094,0.09918486411134457,0.2769879028810313,0.1499924053553085
forced batch size: 16


100%|██████████| 229/229 [06:58<00:00,  1.83s/it]


Time taken to predict test: 408.804 seconds
{'Bleu_1': 0.342281023310038, 'Bleu_2': 0.18618295092497203, 'Bleu_3': 0.1097582162174425, 'Bleu_4': 0.06959025506714356, 'ROUGE_L': 0.23498385443420397, 'CIDEr': 0.07190956914805945}
OneNearestNeighborCoarse2Fine,2022-05-02 20:46:39.692846,2048,3000,test,0.342281023310038,0.18618295092497203,0.1097582162174425,0.06959025506714356,0.23498385443420397,0.07190956914805945
OneNearestNeighborCoarse2Fine 2048 4000
forced batch size: 16


100%|██████████| 131/131 [10:59<00:00,  5.04s/it]


Time taken to predict val: 654.330 seconds
{'Bleu_1': 0.3853319568186967, 'Bleu_2': 0.2267026927890868, 'Bleu_3': 0.14551311759812988, 'Bleu_4': 0.09908181609473281, 'ROUGE_L': 0.276378986867147, 'CIDEr': 0.15593824569082942}
OneNearestNeighborCoarse2Fine,2022-05-02 20:57:58.374450,2048,4000,val,0.3853319568186967,0.2267026927890868,0.14551311759812988,0.09908181609473281,0.276378986867147,0.15593824569082942
forced batch size: 16


100%|██████████| 229/229 [06:59<00:00,  1.83s/it]


Time taken to predict test: 410.099 seconds
{'Bleu_1': 0.3405762663095114, 'Bleu_2': 0.18504360921440205, 'Bleu_3': 0.10910053944152859, 'Bleu_4': 0.06915633653176438, 'ROUGE_L': 0.23434117316595268, 'CIDEr': 0.07346506694150733}
OneNearestNeighborCoarse2Fine,2022-05-02 21:05:11.943960,2048,4000,test,0.3405762663095114,0.18504360921440205,0.10910053944152859,0.06915633653176438,0.23434117316595268,0.07346506694150733


In [16]:
models = [OneNearestNeighbor, OneNearestNeighborCoarse2Fine]
seeds = [0]
dims = [4096, 8192]

# Large-dimension sweep with a single seed; same (model, dim, seed) ordering
# as the main grid. Every configuration is re-run unconditionally.
benchmark_grid = [
    (model_class, dim, seed)
    for model_class in models
    for dim in dims
    for seed in seeds
]
for model_class, dim, seed in benchmark_grid:
    print(model_class.__name__, dim, seed)
    e2e_benchmark(model_class, seed=seed, project_dim=dim, load=True)

OneNearestNeighbor 4096 0


  0%|          | 0/3 [00:00<?, ?it/s]

In [12]:
def chexpertify(captions, batch_size):
    """
    Label tokenized captions with the CheXpert labeler, batch by batch.

    Each batch from a ``DataLoader`` over ``captions`` is passed to
    ``chexpert(batch, tokenizer)``. The per-batch results are concatenated
    into one DataFrame with a fresh 0..N-1 index, and a ``'captions'`` column
    holding the decoded text is appended.

    Example usage:

        true_df = chexpertify(test_captions, batch_size=12)
        pred_df = chexpertify(test_predicted_reports, batch_size=12)

    Args:
        captions: tokenized captions accepted by both ``DataLoader`` and
            ``tokenizer.decode``.
        batch_size: number of captions passed to ``chexpert`` per call.

    Returns:
        pandas.DataFrame of labels plus a ``'captions'`` column.

    Note:
        Empty ``captions`` produces no batches, and ``pd.concat`` then raises
        ``ValueError``.
    """
    caption_loader = DataLoader(
        captions,
        batch_size=batch_size,
        num_workers=0,
        pin_memory=True,
    )

    label_frames = [chexpert(batch, tokenizer) for batch in tqdm(caption_loader)]

    labels_df = pd.concat(label_frames).reset_index(drop=True)
    labels_df['captions'] = tokenizer.decode(captions)
    return labels_df

def _chexpertify_split(model, split, image_embeddings, true_captions, image_loader, project_dim, seed):
    """
    Predict reports for one split, CheXpert-label predictions and ground truth, and save both as csv.

    Args:
        model: an already fitted similarity-search model.
        split: split name ('val' or 'test'); used in the timing message and the output file names.
        image_embeddings: embeddings of this split's images, passed to `predict`.
        true_captions: ground-truth captions of this split.
        image_loader: image loader handed to `predict` for Coarse2Fine models; ignored otherwise.
        project_dim: projection dimension, recorded in the output file names.
        seed: seed, recorded in the output file names.
    """
    if isinstance(model, SimilaritySearchCoarse2Fine):
        # Coarse2Fine models need an image loader for the encoder to predict the 14 diseases.
        predicted_reports, elapsed = predict(model, image_embeddings, batch_size=1024, image_loader=image_loader)
    else:
        predicted_reports, elapsed = predict(model, image_embeddings, batch_size=1024)
    print(f"Time taken to predict {split}: {elapsed:.3f} seconds")

    pred_df = chexpertify(predicted_reports, batch_size=12)
    true_df = chexpertify(true_captions, batch_size=12)

    # Files: chexpertify/{ModelName}_{split}_{pred|true}_df_{project_dim}_{seed}.csv
    prefix = f'chexpertify/{type(model).__name__}_{split}'
    pred_df.to_csv(f'{prefix}_pred_df_{project_dim}_{seed}.csv', index=False)
    true_df.to_csv(f'{prefix}_true_df_{project_dim}_{seed}.csv', index=False)
    print(pred_df.head())
    print(true_df.head())


def e2e_chexpertify(model_class, seed, project_dim, load=False):
    """
    Fit a 1-NN style model, predict val and test reports, CheXpert-label them, and save the results.

    For each split (val first, then test), the predicted reports and the ground-truth captions are
    both passed through `chexpertify`. The two frames are written to `chexpertify/` (created if
    missing), and their heads are printed.

    Args:
        model_class: model class, instantiated with no arguments.
        seed: seed that, together with `project_dim`, selects the embeddings via `get_vectors`.
        project_dim: projection dimension of the embeddings.
        load: forwarded unchanged to `get_vectors`.
    """
    (train_image_embeddings, train_captions,
     val_image_embeddings, val_captions,
     test_image_embeddings, test_captions) = get_vectors(seed, project_dim, load=load)

    model = model_class()
    is_coarse2fine = isinstance(model, SimilaritySearchCoarse2Fine)
    if is_coarse2fine:
        model.assign_encoder(encoder)

    model.fit(train_image_embeddings, train_captions)

    # to_csv does not create parent directories; without this a fresh checkout fails on the first save.
    os.makedirs('chexpertify', exist_ok=True)

    # The global image loaders are only looked up for Coarse2Fine models, the only ones that use them.
    _chexpertify_split(model, 'val', val_image_embeddings, val_captions,
                       val_loader if is_coarse2fine else None, project_dim, seed)
    _chexpertify_split(model, 'test', test_image_embeddings, test_captions,
                       test_loader if is_coarse2fine else None, project_dim, seed)

def load_chexpertify_results(model_class, seed, project_dim):
    """
    Load the CheXpert label csvs written by `e2e_chexpertify` for one model/seed/projection.

    Files are read from `chexpertify/{ModelName}_{split}_{pred|true}_df_{project_dim}_{seed}.csv`.

    Args:
        model_class: the model class; only its name is used, and no instance is constructed.
        seed: seed used when the results were produced.
        project_dim: projection dimension used when the results were produced.

    Returns:
        tuple (val_pred_df, val_true_df, test_pred_df, test_true_df).
    """
    # Read the name off the class; building a model just to ask its type is wasteful and may have side effects.
    model_name = model_class.__name__
    return tuple(
        pd.read_csv(f'chexpertify/{model_name}_{split}_{kind}_df_{project_dim}_{seed}.csv')
        for split in ('val', 'test')
        for kind in ('pred', 'true')
    )

In [13]:
# CheXpert-label OneNearestNeighbor predictions (seed 0) for the 128- and 256-dim projections.
for chexpert_dim in (128, 256):
    e2e_chexpertify(OneNearestNeighbor, 0, chexpert_dim, load=True)

100%|██████████| 3/3 [00:09<00:00,  3.13s/it]


Time taken to predict val: 9.380 seconds


100%|██████████| 174/174 [00:28<00:00,  6.01it/s]
100%|██████████| 174/174 [00:25<00:00,  6.92it/s]


   Enlarged Cardiomediastinum  Cardiomegaly  Lung Opacity  Lung Lesion  Edema  \
0                         1.0           1.0           0.0          0.0    0.0   
1                         0.0           0.0           0.0          0.0    0.0   
2                         1.0           1.0           1.0          0.0    0.0   
3                         1.0           1.0           1.0          0.0    1.0   
4                         1.0           1.0           1.0          0.0    1.0   

   Consolidation  Pneumonia  Atelectasis  Pneumothorax  Pleural Effusion  \
0            0.0        0.0          0.0           0.0               0.0   
1            0.0        0.0          0.0           0.0               0.0   
2            1.0        1.0          0.0           0.0               1.0   
3            1.0        1.0          1.0           0.0               1.0   
4            1.0        1.0          1.0           0.0               1.0   

   Pleural Other  Fracture  Support Devices  No Finding 

100%|██████████| 4/4 [00:16<00:00,  4.19s/it]


Time taken to predict test: 16.728 seconds


100%|██████████| 305/305 [00:56<00:00,  5.40it/s]
100%|██████████| 305/305 [00:50<00:00,  6.05it/s]


   Enlarged Cardiomediastinum  Cardiomegaly  Lung Opacity  Lung Lesion  Edema  \
0                         0.0           0.0           1.0          0.0    0.0   
1                         1.0           1.0           1.0          0.0    0.0   
2                         0.0           0.0           0.0          0.0    0.0   
3                         0.0           0.0           0.0          0.0    0.0   
4                         0.0           0.0           0.0          0.0    0.0   

   Consolidation  Pneumonia  Atelectasis  Pneumothorax  Pleural Effusion  \
0            0.0        0.0          1.0           0.0               0.0   
1            0.0        0.0          1.0           0.0               0.0   
2            0.0        0.0          0.0           0.0               0.0   
3            0.0        0.0          0.0           0.0               0.0   
4            0.0        0.0          0.0           0.0               0.0   

   Pleural Other  Fracture  Support Devices  No Finding 

100%|██████████| 3/3 [00:11<00:00,  3.78s/it]


Time taken to predict val: 11.335 seconds


100%|██████████| 174/174 [00:29<00:00,  5.80it/s]
100%|██████████| 174/174 [00:25<00:00,  6.81it/s]


   Enlarged Cardiomediastinum  Cardiomegaly  Lung Opacity  Lung Lesion  Edema  \
0                         1.0           1.0           1.0          0.0    1.0   
1                         1.0           1.0           1.0          0.0    0.0   
2                         1.0           1.0           1.0          0.0    1.0   
3                         1.0           1.0           1.0          0.0    0.0   
4                         1.0           1.0           1.0          0.0    1.0   

   Consolidation  Pneumonia  Atelectasis  Pneumothorax  Pleural Effusion  \
0            0.0        0.0          0.0           0.0               0.0   
1            1.0        1.0          1.0           0.0               1.0   
2            1.0        1.0          1.0           0.0               1.0   
3            0.0        0.0          1.0           0.0               1.0   
4            1.0        1.0          1.0           0.0               1.0   

   Pleural Other  Fracture  Support Devices  No Finding 

100%|██████████| 4/4 [00:19<00:00,  4.95s/it]


Time taken to predict test: 19.770 seconds


100%|██████████| 305/305 [00:55<00:00,  5.47it/s]
100%|██████████| 305/305 [00:50<00:00,  6.07it/s]


   Enlarged Cardiomediastinum  Cardiomegaly  Lung Opacity  Lung Lesion  Edema  \
0                         0.0           0.0           0.0          0.0    0.0   
1                         0.0           0.0           0.0          0.0    0.0   
2                         1.0           1.0           1.0          1.0    1.0   
3                         0.0           0.0           0.0          0.0    0.0   
4                         0.0           0.0           0.0          0.0    0.0   

   Consolidation  Pneumonia  Atelectasis  Pneumothorax  Pleural Effusion  \
0            0.0        0.0          0.0           0.0               0.0   
1            0.0        0.0          0.0           0.0               0.0   
2            1.0        1.0          1.0           0.0               1.0   
3            0.0        1.0          0.0           0.0               0.0   
4            0.0        0.0          0.0           0.0               0.0   

   Pleural Other  Fracture  Support Devices  No Finding 

In [31]:
# Reload the saved CheXpert label frames for OneNearestNeighborCoarse2Fine (seed 0, 128-dim projection).
val_pred_df, val_true_df, test_pred_df, test_true_df = load_chexpertify_results(
    OneNearestNeighborCoarse2Fine, seed=0, project_dim=128
)

In [32]:
# Inspect the CheXpert labels computed from the predicted validation reports.
val_pred_df

Unnamed: 0,Enlarged Cardiomediastinum,Cardiomegaly,Lung Opacity,Lung Lesion,Edema,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices,No Finding,captions
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,rightsided terminates in the low svc without e...
1,1.0,1.0,1.0,0.0,0.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,0.0,the cardiac silhouette is mildly enlarged but ...
2,1.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,compared with prior there has been no signific...
3,1.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,blunting at the left costophrenic may represen...
4,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,there is no change . and rightsided chest tube...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2080,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,multiple bilateral focal concerning for pneumo...
2081,0.0,0.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,there is a persistent lower lingular opacifica...
2082,0.0,0.0,1.0,1.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,the patient is status post right upper lobe re...
2083,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,frontal and lateral views of the chest were ob...


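### Evaluate predictions from the remote server run

Requires `remote_server/predicted_val_captions.npy` and `remote_server/predicted_test_captions.npy`. Each file's predictions are compared against the first N ground-truth captions of the matching split, where N is the number of predictions.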

In [8]:
# Clinical (CheXpert) evaluation of the remote-server val predictions.
# Assumes the predictions correspond, in order, to the first N val captions — confirm with the remote run.
pred_captions = np.load('remote_server/predicted_val_captions.npy')
until = pred_captions.shape[0]
true_captions = np.load('mimic_cxr/raw_embeddings/val/captions_0.npy')[:until]
# Slicing never raises; fail loudly instead of silently pairing predictions with too few ground truths.
assert true_captions.shape[0] == until, (
    f"only {true_captions.shape[0]} ground-truth val captions for {until} predictions"
)
evaluate_clinical(true_captions, pred_captions, batch_size=12)

100%|██████████| 174/174 [00:52<00:00,  3.29it/s]
100%|██████████| 174/174 [01:05<00:00,  2.64it/s]


Metrics,Recall,Precision,F1
Enlarged Cardiomediastinum,0.729483,0.658436,0.692141
Cardiomegaly,0.657244,0.613861,0.634812
Lung Opacity,0.744152,0.712885,0.728183
Lung Lesion,0.266355,0.360759,0.306452
Edema,0.603448,0.550562,0.575793
Consolidation,0.620123,0.683258,0.650161
Pneumonia,0.438119,0.517544,0.474531
Atelectasis,0.624299,0.660079,0.641691
Pneumothorax,0.148148,0.171429,0.15894
Pleural Effusion,0.640426,0.650108,0.64523


In [12]:
# Clinical (CheXpert) evaluation of the remote-server test predictions.
# Assumes the predictions correspond, in order, to the first N test captions — confirm with the remote run.
pred_captions = np.load('remote_server/predicted_test_captions.npy')
until = pred_captions.shape[0]
true_captions = np.load('mimic_cxr/raw_embeddings/test/captions_0.npy')[:until]
# Slicing never raises; fail loudly instead of silently pairing predictions with too few ground truths.
assert true_captions.shape[0] == until, (
    f"only {true_captions.shape[0]} ground-truth test captions for {until} predictions"
)
evaluate_clinical(true_captions, pred_captions, batch_size=12)

100%|██████████| 42/42 [00:15<00:00,  2.72it/s]
100%|██████████| 42/42 [00:09<00:00,  4.50it/s]


Metrics,Recall,Precision,F1
Enlarged Cardiomediastinum,0.746429,0.741135,0.743772
Cardiomegaly,0.691304,0.679487,0.685345
Lung Opacity,0.784722,0.733766,0.758389
Lung Lesion,0.2,0.367647,0.259067
Edema,0.565934,0.559783,0.562842
Consolidation,0.480952,0.554945,0.515306
Pneumonia,0.413408,0.544118,0.469841
Atelectasis,0.610465,0.486111,0.541237
Pneumothorax,0.071429,0.142857,0.095238
Pleural Effusion,0.527607,0.530864,0.529231
