In [1]:
import sys
sys.path.append("../")

import jax
import jax.numpy as np
from transformers import FlaxCLIPModel, AutoProcessor, AutoTokenizer

%matplotlib inline
%load_ext autoreload
%autoreload 2

2024-02-12 18:01:07.414616: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-12 18:01:07.414658: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-12 18:01:07.415937: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Hugging Face Hub identifier shared by the model, processor and tokenizer, so
# that the three artefacts are always loaded from the same checkpoint.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch16"

model = FlaxCLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = AutoProcessor.from_pretrained(CLIP_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(CLIP_CHECKPOINT)

In [3]:
import logging
import warnings

import matplotlib
import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
from matplotlib import cm

# Look up the 'viridis_r' colormap in matplotlib's colormap registry.
cmap = matplotlib.colormaps.get_cmap('viridis_r')

# Keep matplotlib quiet: log errors only and hide its deprecation warnings.
logging.getLogger('matplotlib').setLevel(logging.ERROR)
warnings.filterwarnings(
    "ignore",
    category=matplotlib.MatplotlibDeprecationWarning,
)

# Apply the rcParams overrides defined in the local `plot_params` module
# (imported only after the warning filters above are in place).
from plot_params import params
pylab.rcParams.update(params)

# Colors of the active property cycle, kept around for manual use in plots.
cols_default = plt.rcParams['axes.prop_cycle'].by_key()['color']

In [4]:
import optax
from flax.training import train_state
import flax
import orbax
# Import the submodule explicitly: the `orbax` top-level package is shared by
# several distributions, so a bare `import orbax` does not by itself guarantee
# that the `orbax.checkpoint` attribute used further down is available.
import orbax.checkpoint

# Short aliases for copying a pytree onto every local device and back.
replicate = flax.jax_utils.replicate
unreplicate = flax.jax_utils.unreplicate

# Learning-rate schedule: warm up from 0 to 1e-4 over 5k steps, then cosine decay.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-4,
    warmup_steps=5_000,
    decay_steps=100_000,
)

# Train state built from the pretrained weights. In this notebook it is passed
# as the target structure (`items=state`) when checkpoints are restored.
tx = optax.adamw(learning_rate=schedule, weight_decay=1e-4)
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=tx)

In [5]:
import yaml
from ml_collections.config_dict import ConfigDict

logging_dir = '../logging/proposals/'
run_name = 'ancient-pine-88'

# Path of the YAML configuration stored under the run's logging directory.
config_file = f"{logging_dir}/{run_name}/config.yaml"

# Parse the YAML and wrap it so that entries can be read as attributes
# (e.g. `config.data.max_length_words`).
with open(config_file, 'r') as config_handle:
    config = ConfigDict(yaml.safe_load(config_handle))

In [1]:

# ckpt_dir = "{}/{}".format(logging_dir, run_name)  # Load SLURM run

# best_fn = lambda metrics: metrics[f"val/top_10_accuracy"]

# mgr_options = orbax.checkpoint.CheckpointManagerOptions(step_prefix=f'step', best_fn=best_fn, best_mode='min', create=False)
# ckpt_mgr = orbax.checkpoint.CheckpointManager(f"{ckpt_dir}/ckpts/", orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options)

# restore_args = flax.training.orbax_utils.restore_args_from_target(state, mesh=None)
# restored_state = ckpt_mgr.restore(ckpt_mgr.latest_step(), items=state, restore_kwargs={'restore_args': restore_args})

# if state is restored_state:
#     raise FileNotFoundError(f"Did not load checkpoint correctly")

## Multiple runs

In [22]:
from tqdm import tqdm
import tensorflow as tf
from utils.dataset_utils import make_dataloader, create_input_iter
from utils.text_utils import process_truncate_captions
from dm_pix import center_crop, random_crop, rotate, random_flip_up_down, random_flip_left_right

In [24]:
# One entry per evaluated run. The four lists are indexed in parallel (by the
# position of the run in `run_labels`), so they must always have equal length.
run_labels = ['ancient-pine-88',]
run_legends = ["Fine-tune (abstracts)"]

data_type = ["abstract"]
use_sum1 = [False]

# Fail early if the lists get out of sync instead of silently pairing a run
# with the wrong legend / caption type (or raising an IndexError mid-evaluation).
assert len(run_labels) == len(run_legends) == len(data_type) == len(use_sum1), \
    "run_labels, run_legends, data_type and use_sum1 must have one entry per run"

In [25]:
def retrieval_eval_metric(text_embeds, image_embeds, k=(1, 5, 10, 20)):
    """Compute the top-k text-to-image retrieval accuracy across all devices.

    Must be called inside a mapped computation (e.g. ``jax.pmap``) whose axis
    is named ``"batch"``: the per-device embeddings are gathered over that axis
    so that every text is ranked against the images of *all* devices.

    Args:
        text_embeds: Per-device text embeddings; the last axis is the feature
            axis. Row ``i`` is treated as the match of row ``i`` of
            ``image_embeds`` (assumed shape ``(batch, dim)``).
        image_embeds: Per-device image embeddings, laid out like ``text_embeds``.
        k: Iterable of cutoffs to evaluate. Any non-empty sequence of ints
            works; the default is a tuple to avoid a shared mutable default.

    Returns:
        Dict mapping ``"top_{k}_accuracy"`` to the fraction of texts whose
        paired image is among the ``k`` highest-scoring images.
    """

    # Gather the embeddings from all devices and merge the device axis into the batch axis
    all_text_embeds = jax.lax.all_gather(text_embeds, axis_name="batch").reshape(-1, text_embeds.shape[-1])
    all_image_embeds = jax.lax.all_gather(image_embeds, axis_name="batch").reshape(-1, image_embeds.shape[-1])

    # Compute the full matrix of logits (rows: texts, columns: images)
    all_logits = np.matmul(all_text_embeds, all_image_embeds.T)

    # Rank once for the largest k; argsort is ascending, so the best matches
    # sit in the trailing columns and smaller k values are trailing slices.
    max_k = max(k)
    top_k_indices = np.argsort(all_logits, axis=-1)[:, -max_k:]

    # The correct image for text row i is column i (the diagonal). The row count
    # is read off the gathered matrix itself rather than recomputed via a collective.
    correct_indices = np.arange(all_logits.shape[0])[:, None]

    metrics = {}
    for current_k in k:
        # Check if the correct image (diagonal) is in the current top-k for each text embedding
        correct_in_top_k = np.any(top_k_indices[:, -current_k:] == correct_indices, axis=-1)
        accuracy = np.mean(correct_in_top_k.astype(np.float32))
        metrics[f"top_{current_k}_accuracy"] = accuracy

    return metrics

In [26]:
from functools import partial
from einops import rearrange
import numpy as onp
from models.losses import softmax_loss

@partial(jax.pmap, axis_name="batch")
def get_features(state, input_ids, pixel_values, attention_mask):
    """Run the model on one device shard and compute loss and retrieval metrics.

    The function is mapped over the leading (device) axis of every argument
    with axis name ``"batch"``.

    Args:
        state: Train state providing ``apply_fn`` and ``params``.
        input_ids: Token ids of the captions.
        pixel_values: Preprocessed images.
        attention_mask: Attention mask matching ``input_ids``.

    Returns:
        Dict with ``"loss"`` (from ``softmax_loss``) plus every entry returned
        by ``retrieval_eval_metric``.
    """
    model_params = state.params

    # A single joint forward pass yields both the text and the image embeddings.
    outputs = state.apply_fn(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        params=model_params,
    )

    # The loss reads the scale/bias from the outputs; the bias falls back to 0
    # when the parameters contain no 'logit_bias' entry.
    outputs['logit_scale'] = model_params['logit_scale']
    outputs['logit_bias'] = model_params.get('logit_bias', 0.)

    text_feat = outputs["text_embeds"]
    image_feat = outputs["image_embeds"]

    return {
        "loss": softmax_loss(outputs),
        **retrieval_eval_metric(text_feat, image_feat),
    }
    
def get_features_ds(state, ds, truncate=False):
    """Evaluate loss and retrieval metrics over a dataset, batch by batch.

    Every batch of images is randomly rotated, flipped and cropped to the
    vision model's input size, run through the processor together with its
    captions, split across the local devices and passed to ``get_features``.

    Args:
        state: Train state to evaluate; it is replicated across the local
            devices inside this function.
        ds: Re-iterable dataset yielding ``(images, captions)`` batches. It is
            iterated twice: once to count the batches, once to evaluate.
        truncate: If True, captions go through ``process_truncate_captions``
            (limited to ``config.data.max_length_words``); otherwise they are
            decoded from UTF-8 bytes as they are.

    Returns:
        List with one metrics dict (as returned by ``get_features``) per
        evaluated batch.
    """

    batches = iter(ds)

    num_local_devices = jax.local_device_count()

    # Replicate once up front: the state does not change during evaluation, so
    # copying all parameters to every device again for each batch is wasted work.
    replicated_state = flax.jax_utils.replicate(state)

    # NOTE(review): with N batches in `ds` the progress bar total is N - 1 but
    # the loop below stops after N - 2 batches, i.e. the last *two* batches are
    # skipped. Presumably only the final (possibly incomplete) batch was meant
    # to be dropped; kept as is so previously reported numbers stay comparable.
    total_batches = sum(1 for _ in ds) - 1

    retrieval_eval_metrics = []

    for current_batch, (images, captions) in enumerate(tqdm(batches, total=total_batches)):
        if current_batch == total_batches - 1:
            break

        images = np.array(images)

        if truncate:
            captions = process_truncate_captions(captions, jax.random.PRNGKey(onp.random.randint(99999)), max_length_words=config.data.max_length_words)
        else:
            captions = captions.numpy().tolist()
            captions = [c.decode('utf-8') for c in captions]

        # Fresh randomness for every batch, with one independent subkey per
        # augmentation so that no key is consumed twice.
        rng_eval = jax.random.PRNGKey(onp.random.randint(99999))
        rng_rot, rng_flip_ud, rng_flip_lr, rng_crop = jax.random.split(rng_eval, 4)

        # Rotations (one angle per image)
        rotation_angles = jax.random.uniform(rng_rot, shape=(images.shape[0],)) * 2 * np.pi  # Angles in radians
        images = jax.vmap(partial(rotate, mode='constant', cval=1.))(images, rotation_angles)

        # Flips. The same key is handed to every image, so the flip decision is
        # shared by the whole batch (unchanged from the previous behaviour).
        images = jax.vmap(partial(random_flip_up_down, key=rng_flip_ud))(image=images)
        images = jax.vmap(partial(random_flip_left_right, key=rng_flip_lr))(image=images)

        # Crop to the vision model's input size; the key is again shared across the batch.
        image_size = model.config.vision_config.image_size
        images = jax.vmap(random_crop, in_axes=(None, 0, None))(rng_crop, images, (image_size, image_size, 3))

        inputs = processor(text=captions, images=(images * 255.).astype(np.uint8), return_tensors="np", padding="max_length", truncation=True, max_length=77)

        # Split every array into one shard per local device; stacking the shards
        # below gives the leading device axis expected by the pmapped function.
        batch = jax.tree_util.tree_map(lambda x: np.split(x, num_local_devices, axis=0), inputs.data)
        batch = jax.tree_util.tree_map(lambda x: np.array(x, dtype=np.float32), batch)

        metrics = get_features(replicated_state, np.array(batch["input_ids"]), np.array(batch["pixel_values"]), np.array(batch["attention_mask"]))

        retrieval_eval_metrics.append(metrics)

    return retrieval_eval_metrics

In [2]:
# Load the submodules explicitly instead of relying on some other import having
# pulled them in as a side effect; both are accessed by attribute below.
import orbax.checkpoint
import flax.training.orbax_utils

# The validation shards are the same for every run, so look them up once.
# NOTE(review): absolute cluster path; consider moving it to a configurable data directory.
files = tf.io.gfile.glob("/n/holyscratch01/iaifi_lab/smsharma/hubble_data/tfrecords_v5/*val*.tfrecord")
if not files:
    raise FileNotFoundError("No validation tfrecord files matched the glob pattern")

accuracy_lists = []
for idx, run_name in enumerate(tqdm(run_labels)):

    ds = make_dataloader(files, batch_size=100, seed=42, split="val", shuffle=True, caption_type=data_type[idx])

    ckpt_dir = "{}/{}".format(logging_dir, run_name)  # Load SLURM run

    best_fn = lambda metrics: metrics["val/loss"]

    mgr_options = orbax.checkpoint.CheckpointManagerOptions(step_prefix='step', best_fn=best_fn, best_mode='min', create=False)
    ckpt_mgr = orbax.checkpoint.CheckpointManager(f"{ckpt_dir}/ckpts/", orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options)

    # `latest_step()` returns None when the directory holds no checkpoint; fail
    # with a clear message here. The identity check further down cannot detect
    # that case on its own.
    # NOTE(review): the manager is configured with `best_fn`, yet the *latest*
    # step is restored rather than `best_step()` — confirm this is intended.
    latest_step = ckpt_mgr.latest_step()
    if latest_step is None:
        raise FileNotFoundError(f"No checkpoint found in {ckpt_dir}/ckpts/")

    # Use the freshly built `state` as the template describing the tree to restore.
    restore_args = flax.training.orbax_utils.restore_args_from_target(state, mesh=None)
    restored_state = ckpt_mgr.restore(latest_step, items=state, restore_kwargs={'restore_args': restore_args})

    if state is restored_state:
        raise FileNotFoundError("Did not load checkpoint correctly")

    # Abstract captions are truncated before tokenization; other caption types are used as they are.
    retrieval_eval_metrics = get_features_ds(restored_state, ds, truncate=data_type[idx] == "abstract",)
    accuracy_lists.append(retrieval_eval_metrics)

In [29]:
from flax.training import common_utils

# Report, for every run, each metric averaged over all evaluated batches.
for idx, metric in enumerate(accuracy_lists):
    # Stack the per-batch metric dicts into one dict of arrays.
    val_metrics = common_utils.get_metrics(metric)
    # `jax.tree_map` is deprecated (and removed in recent JAX releases);
    # `jax.tree_util.tree_map` is the equivalent, stable spelling.
    mean_metrics = jax.tree_util.tree_map(lambda x: x.mean(), val_metrics)
    print(run_legends[idx], {f"val/{k}": v for k, v in mean_metrics.items()})

Fine-tune (abstracts) {'val/loss': 3.2589102, 'val/top_10_accuracy': 0.66666675, 'val/top_1_accuracy': 0.20033331, 'val/top_20_accuracy': 0.79533327, 'val/top_5_accuracy': 0.522}
Fine-tune (summaries) {'val/loss': 3.3698084, 'val/top_10_accuracy': 0.625, 'val/top_1_accuracy': 0.22099997, 'val/top_20_accuracy': 0.74966675, 'val/top_5_accuracy': 0.4993333}
