In [None]:
# Colab-only setup: mount Google Drive at /content/drive, the location the
# dataset paths configured later in this notebook point into.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the M-CLIP package and torch into the kernel's environment.
# NOTE(review): sentence-transformers, transformers, pandas, scipy, scikit-learn
# and Pillow are imported by this notebook but not installed here — presumably
# the runtime already provides them; confirm. Versions are unpinned.
%pip install multilingual-clip torch

In [None]:
import requests
import torch
import pandas as pd
import numpy as np
import transformers
from multilingual_clip import pt_multilingual_clip
from sentence_transformers import SentenceTransformer, util
from PIL import Image, ImageFile
from scipy.stats import spearmanr
from sklearn.metrics.pairwise import cosine_similarity

# Extract CLIP embeddings

In [None]:
# Root of the task data on the mounted Drive; every other path is built from it.
main_folder = "/content/drive/MyDrive/SemEval2025/task1/"

# One entry per run: (text model id, image model id). The first element is
# loaded as the multilingual CLIP text encoder together with its tokenizer, the
# second is loaded through SentenceTransformer to embed the images.
clip_model_names = [
    ('M-CLIP/XLM-Roberta-Large-Vit-L-14', "clip-ViT-L-14"),
    ('M-CLIP/XLM-Roberta-Large-Vit-B-32', "clip-ViT-B-32"),
    ('M-CLIP/LABSE-Vit-L-14', "clip-ViT-L-14"),
]

# Dataset folders to process. Despite the name, this list also holds the test
# and extended-evaluation folders, not only the training ones.
train_folders = [
    f"{main_folder}dataset results gpt4-4o/AdMIRe Subtask A {split_name}/"
    for split_name in (
        "Train",
        "PT Train",
        "Test",
        "PT Test",
        "Extended Evaluation",
        "PT Extended Evaluation",
    )
]

# LLM tags that appear in the names of the "*_meanings.tsv" input files.
llm_names = [
    "gpt-3.5",
    "gpt-4",
    "gpt-4o",
]

In [None]:
def load_image(url_or_path):
  """Open an image from a local path or an http(s) URL.

  Args:
    url_or_path: Either a string starting with "http://" or "https://", in
      which case the image is downloaded, or anything `Image.open` accepts
      (a path string, a path object or an open file object).

  Returns:
    The image object returned by `Image.open`.

  Raises:
    requests.HTTPError: If a URL is given and the server answers with an
      error status.
    Whatever `Image.open` raises when the source cannot be opened or is not
      an image.
  """
  if isinstance(url_or_path, str) and url_or_path.startswith(("http://", "https://")):
    import io

    response = requests.get(url_or_path, timeout=30)
    # Fail loudly on 4xx/5xx instead of trying to decode an error page as an image.
    response.raise_for_status()
    # Wrap the downloaded bytes in an in-memory file so the image can be read from it.
    return Image.open(io.BytesIO(response.content))
  return Image.open(url_or_path)

def get_caption_embeddings(text_model, tokenizer, captions):
  """Embed each caption with the text model, one forward pass per caption.

  Every caption is first encoded with truncation at 512 tokens and decoded
  back to text without special tokens, so over-long captions are shortened
  before they are handed to the model.

  Args:
    text_model: Object whose `forward(text, tokenizer)` returns a torch tensor.
    tokenizer: Tokenizer providing `encode(..., max_length, truncation)` and
      `decode(..., skip_special_tokens)`; also passed through to the model.
    captions: Iterable of caption strings.

  Returns:
    A list with one numpy array per caption, in input order; each array is
    the model output with singleton dimensions squeezed away. Empty input
    gives an empty list.
  """
  truncated_captions = [
      tokenizer.decode(
          tokenizer.encode(caption, max_length=512, truncation=True),
          skip_special_tokens=True,
      )
      for caption in captions
  ]

  # Inference only: disabling autograd avoids building a graph for every
  # forward pass that `.detach()` would throw away anyway.
  with torch.no_grad():
    return [
        text_model.forward(caption, tokenizer).detach().numpy().squeeze()
        for caption in truncated_captions
    ]

# Embedding extraction: for every (text model, image model) pair, every LLM tag
# and every dataset folder, embed the query text, the five images and the five
# captions of each row, then save the stacked arrays as .npy files inside the
# dataset folder.
for clip_model_name in clip_model_names:
  text_model_name, img_model_name = clip_model_name

  text_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(text_model_name)
  tokenizer = transformers.AutoTokenizer.from_pretrained(text_model_name)

  img_model = SentenceTransformer(img_model_name)

  # Tag used in every output file name; the '/' of the model id is replaced so
  # it is not read as a directory separator.
  model_tag = text_model_name.replace('/', '-')

  for llm_name in llm_names:
    for train_folder in train_folders:
      # Pick the meanings file whose name matches the split named in the folder.
      if "Train" in train_folder:
        train_df_file_path = train_folder + f"subtask_a_train_{llm_name}_meanings.tsv"
      elif "Test" in train_folder:
        train_df_file_path = train_folder + f"subtask_a_test_{llm_name}_meanings.tsv"
      elif "Dev" in train_folder:
        train_df_file_path = train_folder + f"subtask_a_dev_{llm_name}_meanings.tsv"
      elif "Extended" in train_folder:
        if "PT" in train_folder:
          train_df_file_path = train_folder + f"subtask_a_xp_{llm_name}_meanings.tsv"
        else:
          train_df_file_path = train_folder + f"subtask_a_xe_{llm_name}_meanings.tsv"
      else:
        raise NotImplementedError(f"Cannot infer the dataset split from folder: {train_folder}")

      #text_augment_bt_file_path = main_folder+ "text_augment_train_A/output_back_translated.tsv"
      #text_augment_pr_file_path = main_folder+ "text_augment_train_A/output_paraphrased.tsv"
      #image_augment_folder      = main_folder+ "image_augment_train_A/augmented_train"

      # Augmentation is switched off: with all three set to None only the
      # baseline embeddings are computed and saved.
      # NOTE(review): the augmented-caption branches read extra columns from `df`
      # itself; the two .tsv paths above are never loaded here — confirm before enabling.
      text_augment_bt_file_path = None
      text_augment_pr_file_path = None
      image_augment_folder      = None

      df = pd.read_csv(train_df_file_path, sep='\t')

      all_text_embeddings = []
      all_img_embeddings = []
      all_cap_embeddings = []
      all_bt_cap_embeddings = []
      all_pr_cap_embeddings = []
      all_aug_img_embeddings = []

      for _, row in df.iterrows():

        # Get data
        compound = row['compound']
        # Apostrophes are replaced by underscores when building the image directory name.
        compound_dir = compound.replace("'", "_")
        img_names = []
        captions = []
        bt_captions = []
        pr_captions = []
        aug_img_names = []
        for i in range(1,6):
          image_name = row[f"image{i}_name"]
          img_names.append(image_name)

          caption = row[f"image{i}_caption"]
          captions.append(caption)

          if text_augment_bt_file_path is not None:
            bt_caption = row[f"image{i}_caption_bt"]
            bt_captions.append(bt_caption)

          if text_augment_pr_file_path is not None:
            pr_caption = row[f"image{i}_caption_para"]
            pr_captions.append(pr_caption)

          if image_augment_folder is not None:
            aug_image_name = row[f"image{i}_name"].replace(".png", "_aug1.png")
            aug_img_names.append(aug_image_name)

        print(row['compound'], row["sent_type_predicted"], row["sentence_type"])

        # Get text (compound) embedding: a literal usage is represented by the
        # compound itself, an idiomatic one by the text of the "meaning" column.
        sent_type = row["sent_type_predicted"]
        if sent_type == 'literal':
          query_text = compound
        elif sent_type == 'idiomatic':
          query_text = row["meaning"]
        else:
          raise NotImplementedError(f"Unsupported sent_type_predicted value: {sent_type!r}")
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
          text_embedding = text_model.forward(query_text, tokenizer).detach().numpy()

        # Get image embeddings
        images = [load_image(f"""{train_folder}/{compound_dir}/{image_name}""") for image_name in img_names]
        try:
          img_embeddings = img_model.encode(images)
        finally:
          # Release the image files as soon as they are encoded.
          for image in images:
            image.close()

        # Get caption embeddings
        cap_embeddings = get_caption_embeddings(text_model, tokenizer, captions)

        # Save embeddings
        all_text_embeddings.append(text_embedding)
        all_img_embeddings.append(np.array(img_embeddings))
        all_cap_embeddings.append(np.array([cap_embedding.squeeze() for cap_embedding in cap_embeddings]))

        # Get embedding of augmented data
        if image_augment_folder is not None:
          aug_images = [load_image(f"""{image_augment_folder}/{compound_dir}/{image_name}""") for image_name in aug_img_names]
          try:
            aug_img_embeddings = img_model.encode(aug_images)
          finally:
            for aug_image in aug_images:
              aug_image.close()
          all_aug_img_embeddings.append(np.array(aug_img_embeddings))

        if text_augment_bt_file_path is not None:
          bt_cap_embeddings = get_caption_embeddings(text_model, tokenizer, bt_captions)
          all_bt_cap_embeddings.append(np.array([bt_cap_embedding.squeeze() for bt_cap_embedding in bt_cap_embeddings]))

        if text_augment_pr_file_path is not None:
          pr_cap_embeddings = get_caption_embeddings(text_model, tokenizer, pr_captions)
          all_pr_cap_embeddings.append(np.array([pr_cap_embedding.squeeze() for pr_cap_embedding in pr_cap_embeddings]))


      # One entry per row of the TSV, in file order.
      stacked_text_embeddings = np.vstack(all_text_embeddings)
      stacked_img_embeddings = np.stack(all_img_embeddings)
      stacked_cap_embeddings = np.stack(all_cap_embeddings)
      print("Shape compound embeddings:", stacked_text_embeddings.shape)
      print("Shape image embeddings:", stacked_img_embeddings.shape)
      print("Shape caption embeddings:", stacked_cap_embeddings.shape)
      np.save(train_folder + f"baseline_text_embeddings_{llm_name}_{model_tag}.npy", stacked_text_embeddings)
      np.save(train_folder + f"baseline_img_embeddings_{llm_name}_{model_tag}.npy", stacked_img_embeddings)
      np.save(train_folder + f"baseline_cap_embeddings_{llm_name}_{model_tag}.npy", stacked_cap_embeddings)

      if text_augment_bt_file_path is not None:
        stacked_bt_cap_embeddings = np.stack(all_bt_cap_embeddings)
        print("Shape augmented caption (back-translated) embeddings:", stacked_bt_cap_embeddings.shape)
        np.save(train_folder + f"bt_cap_embeddings_{llm_name}_{model_tag}.npy", stacked_bt_cap_embeddings)

      if text_augment_pr_file_path is not None:
        stacked_pr_cap_embeddings = np.stack(all_pr_cap_embeddings)
        print("Shape augmented caption (paraphrased) embeddings:", stacked_pr_cap_embeddings.shape)
        np.save(train_folder + f"pr_cap_embeddings_{llm_name}_{model_tag}.npy", stacked_pr_cap_embeddings)

      if image_augment_folder is not None:
        stacked_aug_img_embeddings = np.stack(all_aug_img_embeddings)
        print("Shape augmented image embeddings:", stacked_aug_img_embeddings.shape)
        # Save the augmented-image array itself (the baseline image array was
        # previously written under this name by mistake).
        np.save(train_folder + f"aug_img_embeddings_{llm_name}_{model_tag}.npy", stacked_aug_img_embeddings)