In [1]:
import jax.numpy as np
import jax
from transformers import FlaxCLIPModel, AutoProcessor

import sys
sys.path.append("../")

from models.dataset_utils import make_dataloader, create_input_iter
from models.text_utils import process_truncate_captions, tokenize_captions

  from .autonotebook import tqdm as notebook_tqdm
2023-12-03 22:51:33.187798: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-03 22:51:33.187844: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-03 22:51:33.188997: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Number of examples requested per batch from the dataloader.
batch_size = 8

# Training TFRecord files to read from.
files = [
    '../data/tfrecords_v3/observations_train_1.tfrecord',
    '../data/tfrecords_v3/observations_train_2.tfrecord',
]

# Build the dataset with a fixed seed, then wrap it with the project's
# input-iterator helper.
ds = make_dataloader(files, seed=42, batch_size=batch_size)
batches = create_input_iter(ds)

In [3]:
# Load the preprocessor and the Flax CLIP model from the same pretrained
# checkpoint name so that tokenization/image preprocessing match the weights.
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")

In [4]:
# Seed a PRNG key with 42 and immediately split it: `rng` is kept for later
# use, `rng_aug` is handed to the caption-processing step.
rng, rng_aug = jax.random.split(jax.random.PRNGKey(42))

In [5]:
# Pull one (images, captions) pair from a fresh iterator over the dataset.
images, captions = next(iter(ds))
# Convert to a JAX array (`np` is `jax.numpy` in this notebook).
images = np.array(images)

captions = process_truncate_captions(captions, rng_aug, max_length_words=77)

# Images are multiplied by 255 before being handed to the processor
# (presumably they are stored in [0, 1] — TODO confirm against the dataloader).
# Text is padded/truncated to a fixed length of 77 tokens.
inputs = processor(
    images=images * 255.,
    text=captions,
    max_length=77,
    padding="max_length",
    truncation=True,
    return_tensors="np",
)
batch = inputs.data

In [6]:
# Forward pass on the preprocessed batch using the model's current parameters.
outputs = model(params=model.params, **batch)

In [7]:
# Initialize a fresh parameter tree for the underlying Flax module from
# all-zero dummy inputs: batch of 1, 77 text positions, one 224x224x3 image.
# NOTE(review): the dummy ids/mask are float zeros (jax.numpy default dtype);
# presumably only the shapes matter for initialization — confirm.
params_init = model.module.init(
    rng,
    input_ids=np.zeros((1, 77)),
    attention_mask=np.zeros((1, 77)),
    position_ids=np.zeros((1, 77)),
    pixel_values=np.zeros((1, 224, 224, 3)),
)

In [8]:
# Display the top-level keys of the model's parameter tree.
model.params.keys()

dict_keys(['logit_scale', 'text_model', 'text_projection', 'vision_model', 'visual_projection'])

In [21]:
# Replace the vision tower and its projection head with the freshly
# initialized parameters from `params_init`; the remaining entries of
# `model.params` are left untouched.
model.params.update({
    key: params_init['params'][key]
    for key in ('vision_model', 'visual_projection')
})

In [22]:
# model.params['vision_model']

In [23]:
import optax
from flax.core import FrozenDict
from flax.training import train_state

# Adam optimizer with a learning rate of 1e-3.
tx = optax.adam(learning_rate=1e-3)

# Bundle the model's call function, a frozen copy of its current parameters,
# and the optimizer into a single train state.
state = train_state.TrainState.create(
    tx=tx,
    params=FrozenDict(model.params),
    apply_fn=model.__call__,
)

In [46]:
# Import the submodule explicitly so that `orbax.checkpoint` is guaranteed to
# resolve, rather than relying on another import having loaded it.
import orbax.checkpoint
from flax.training import orbax_utils

# Directory the checkpoint manager writes into.
CKPT_DIR = '/n/holystore01/LABS/iaifi_lab/Users/smsharma/multimodal-data/notebooks/tmp/logging/'

# At the top level
# NOTE: an unfinished `best_fn` helper (truncated mid-string, which made this
# whole cell a SyntaxError) used to sit here. It was never passed to the
# options below (`best_fn=None`), so it has been removed. If metric-based
# checkpoint selection is wanted, define a function of the metrics dict and
# pass it as `best_fn=`.
mgr_options = orbax.checkpoint.CheckpointManagerOptions(
    create=True,
    step_prefix='step',
    max_to_keep=2,
    best_fn=None,
    best_mode='max',
)

ckpt_mgr = orbax.checkpoint.CheckpointManager(
    CKPT_DIR,
    orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()),
    mgr_options,
)

# Inside your training loop
for step in range(10):
    # do training
    # Save arguments are derived from the current state on every step, then
    # the state is saved under this step number with an empty metrics dict.
    save_args = orbax_utils.save_args_from_target(state)
    ckpt_mgr.save(step, state, save_kwargs={'save_args': save_args}, metrics={})