In [29]:
import sys
sys.path.append("../")

from transformers import AutoTokenizer
from models.models import CLIPTextTransformer, CLIPVisionTransformer, CLIPModel

In [52]:
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

caption = "Supernova (SN) 2019ehk in M100 is the closest known Calcium-rich (Ca-rich) transient and the only object in this class with an X-ray detection. Prompt, high-cadence follow-up of this transient across the EM spectrum, in addition to pre-explosion HST imaging, has indicated that the progenitor star was likely low mass and surrounded by dense circumstellar material (CSM) whose geometry/density was capable of producing luminous X-ray emission as well as a double-peaked light curve. The close proximity of SN 2019ehk provides the first opportunity to track the photometric evolution of a Ca-rich transient at late phases (>300 days) when the SN luminosity is governed by radioactive decay and/or additional power sources e.g., CSM, and is too faint for ground-based observatories. These objects typically decrease in magnitude rapidly and thus their late-time decline rate and power source is unknown. Here we propose multi-color imaging of SN 2019ehk in order to understand its late-time bolometric behavior and to constrain the total mass of Ni-56 synthesized in the explosion. This will allow us to test whether SN 2019ehk is powered solely by radioactive decay or by additional CSM at large distances from the progenitor system."

txt_inputs = tokenizer(
                    4 * [caption],
                    padding="max_length",
                    truncation=True,
                    max_length=300,
                    return_tensors="np",
                )
txt_inputs = {k: txt_inputs[k] for k in ["input_ids", "attention_mask"]}

In [53]:
text_config = {
    "dtype": "float32",
    "activations": ["gelu"],
    "use_bias": False,
    "force_scale": False,
    "attention_dropout": 0.0,
    "mlp_dropout_rate": 0.0,
    "unroll": 100,
    "gradient_checkpointing": False,
    "eos_token_id": 49407,
    "vocab_size": 50000,
    "hidden_size": 512,
    "max_length": 300,
    "num_layers": 5,
    "use_rmsnorm": True,
    "ln_type": "normformer",
    "num_heads": 8,
    "position_embedding_type": "rotary",
    "use_causal_mask": False,
    "mlp_dim": 1024
  }

vision_config ={
  "position_embedding_type": "sincos2d",
  "dtype": "float32",
  "activations": ["gelu"],
  "use_bias": False,
  "force_scale": False,
  "attention_dropout": 0.0,
  "mlp_dropout_rate": 0.0,
  "use_cls_token": False,
  "unroll": 100,
  "gradient_checkpointing": True,
  "image_size": 256,
  "hidden_size": 512,
  "patch_size": 16,
  "num_layers": 5,
  "use_rmsnorm": True,
  "ln_type": "normformer",
  "num_heads": 8,
  "use_causal_mask": False,
  "mlp_dim": 1024
}

In [54]:

transformer = CLIPTextTransformer(**text_config)

In [55]:
txt_inputs['input_ids'].shape, txt_inputs['attention_mask'].shape

((4, 300), (4, 300))

In [56]:
import jax

key = jax.random.PRNGKey(0)
transformer.init_with_output(key, txt_inputs['input_ids'], txt_inputs['attention_mask']);

In [57]:
vit = CLIPVisionTransformer(**vision_config)

In [58]:
from PIL import Image
import jax.numpy as np
import matplotlib.pyplot as plt

img = Image.open("../data/observations/proposal_15727/hst_15727_01_wfc3_ir_total_ie0w01_drz.jpg")

# Convert to RGB and to tensor
img = img.convert('RGB')
img = np.array(img)

vit.init_with_output(key, np.array(4 * [img]));



In [60]:
clip = CLIPModel(text_config=text_config, vision_config=vision_config, projection_dim=256)
outputs, params = clip.init_with_output(key, txt_inputs['input_ids'], np.array(4 * [img]), txt_inputs['attention_mask'])



In [69]:
def mini_batch_sigmoid_loss(text_embeds, image_embeds, logit_scale, logit_bias, negative_samples):
    """Positive samples are on the diagonal"""
    bs = text_embeds.shape[0]
    if negative_samples:
        labels = -np.ones((bs, bs))
    else:
        labels = 2 * np.eye(bs) - np.ones((bs, bs))
    logits = np.matmul(text_embeds, image_embeds.T) * logit_scale + logit_bias
    return -np.mean(np.log(1 + np.exp(-labels * logits)))

def sigmoid_loss(outputs):
    text_embeds = outputs["text_embeds"]
    image_embeds = outputs["image_embeds"]
    logit_scale = outputs["logit_scale"]
    logit_bias = outputs["logit_bias"]
    
    bs = text_embeds.shape[0]

    # Compute the positive samples loss
    loss = mini_batch_sigmoid_loss(text_embeds, image_embeds, logit_scale, logit_bias, negative_samples=False)

    # Create a tensor of all shifted versions of image embeddings
    shifted_image_embeds = np.stack([np.roll(image_embeds, shift=-i, axis=0) for i in range(1, bs)])

    # Compute the negative samples logits using einsum
    all_neg_logits = np.einsum('bi,aji->abj', text_embeds, shifted_image_embeds)
    all_neg_logits = all_neg_logits * logit_scale + logit_bias

    neg_labels = -np.ones(all_neg_logits.shape)
    neg_loss = -np.mean(np.log(1 + np.exp(-neg_labels * all_neg_logits)))

    loss = (loss + (bs - 1) * neg_loss) / bs

    return loss

sigmoid_loss(outputs)

Array(-1.405358, dtype=float32)