# Convert images and text into CLIP embeddings

Make sure your working directory contains `hmc_info.csv` and a folder named `HMC` holding all the images.

In [None]:
import json
import os
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import Resize, ToTensor
import clip
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import csv
import torch.multiprocessing as mp

# Run CLIP on the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device:", device)

In [None]:
# Directory that must contain every image referenced by the CSV.
img_folder = 'HMC/'
# Counts every directory entry (files and sub-folders alike).
num_images = sum(1 for _entry in os.listdir(img_folder))
print("Number of elements in the image folder:", num_images)

# Metadata CSV, resolved against the current working directory. The variable
# keeps the name ``jsonl_file`` because the export cell passes it by that name.
jsonl_file = os.path.join(os.getcwd(), 'hmc_info.csv')

In [None]:
class MultimodalDataset(Dataset):
    """Map-style dataset of CSV rows enriched with CLIP text/image embeddings.

    Each item is a shallow copy of one CSV row (a dict of column -> string)
    belonging to ``split``. When enabled, ``'text_embedding'`` and
    ``'image_embedding'`` tensors are added to it; they are computed on the fly
    by the CLIP model inside ``__getitem__`` and returned on the CPU.

    Args:
        jsonl_file: Path to the metadata file. Despite the name it is parsed as
            CSV with ``csv.DictReader``; it must have a ``'split'`` column,
            plus ``'text'`` when ``load_text`` is true. ``'img'`` is used to
            locate the image when present.
        img_folder: Directory containing the images; required when
            ``load_images`` is true.
        model_name: CLIP model name passed to ``clip.load``.
        device: Device the CLIP model and its inputs are placed on.
        load_text: Compute a text embedding for every item.
        load_images: Compute an image embedding for items whose image file
            exists (items with a missing file get no ``'image_embedding'``).
        split: Keep only rows whose ``'split'`` column equals this value.
    """

    # Character cap applied to the text before tokenization.
    MAX_TEXT_CHARS = 250
    # Per-channel RGB mean/std used by CLIP's reference preprocessing.
    CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, jsonl_file, img_folder=None, model_name="ViT-L/14", device='cpu', load_text=True, load_images=True, split='train'):
        self.device = device
        self.load_text = load_text
        self.load_images = load_images
        self.model_name = model_name
        self.data = self._load_data(jsonl_file, split)
        self.split = split

        # One CLIP model serves both modalities; clip.load returns
        # (model, preprocess) and only the model is kept.
        if self.load_text or self.load_images:
            self.clip_model = clip.load(model_name, device=device)[0]
            self.clip_processor = clip.tokenize

        if self.load_images:
            assert img_folder is not None, "Image folder must be provided to load images."
            self.img_folder = img_folder
            # Use the vision tower's own input size (224 for ViT-L/14) so
            # higher-resolution variants such as "ViT-L/14@336px" also work.
            visual = getattr(self.clip_model, 'visual', None)
            self.image_size = int(getattr(visual, 'input_resolution', 224))
            self.image_transform = Resize((self.image_size, self.image_size))
            self.to_tensor = ToTensor()
            # CLIP was trained on mean/std-normalized pixels; feeding raw
            # [0, 1] values produces out-of-distribution embeddings.
            self._pixel_mean = torch.tensor(self.CLIP_MEAN).view(3, 1, 1)
            self._pixel_std = torch.tensor(self.CLIP_STD).view(3, 1, 1)

    def _load_data(self, file_path, split):
        """Return the CSV rows (as dicts) whose ``'split'`` column equals ``split``."""
        # newline='' lets the csv module handle line breaks inside quoted
        # fields (e.g. multi-line meme text) correctly.
        with open(file_path, 'r', encoding='utf-8', newline='') as file:
            return [row for row in csv.DictReader(file) if row['split'] == split]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a copy of row ``idx`` with the enabled CLIP embeddings added.

        ``'text_embedding'`` is the output of ``encode_text`` for the first
        ``MAX_TEXT_CHARS`` characters of ``'text'``; ``'image_embedding'`` is
        the output of ``encode_image`` for the row's image. Both are moved to
        the CPU. ``self.data`` itself is never modified.
        """
        item_data = self.data[idx]
        # Shallow copy: embeddings added below must not leak into self.data.
        result = item_data.copy()

        if self.load_text:
            # Truncate a local value so the stored row keeps its full text.
            text = item_data['text'][:self.MAX_TEXT_CHARS]
            # A character cap does not bound the token count; truncate=True
            # clips to CLIP's context length instead of raising on overflow.
            text_tokens = self.clip_processor([text], truncate=True).to(self.device)
            with torch.no_grad():
                text_features = self.clip_model.encode_text(text_tokens)
            result['text_embedding'] = text_features.cpu()

        if self.load_images and 'img' in item_data:
            # CSV paths are prefixed with 'img/'; files live directly in img_folder.
            image_path = os.path.join(self.img_folder, item_data['img'].replace('img/', ''))
            if os.path.exists(image_path):
                # Context manager closes the file handle; convert() returns a
                # fully loaded copy that outlives it.
                with Image.open(image_path) as img:
                    image = img.convert('RGB')
                image = self.to_tensor(self.image_transform(image))
                image = (image - self._pixel_mean) / self._pixel_std
                image_inputs = image.unsqueeze(0).to(self.device)
                with torch.no_grad():
                    image_features = self.clip_model.encode_image(image_inputs)
                result['image_embedding'] = image_features.cpu()

        return result


In [None]:
def _to_record(sample):
    """Turn one dataset item into the dict that is saved to disk.

    Drops the raw ``'text'``, ``'img'`` and ``'split'`` fields and stores the
    CLIP embeddings, cast to float32, under ``'text'`` and ``'image'``. All
    other CSV columns are kept unchanged and in their original order, followed
    by ``'text'`` then ``'image'``.

    Raises:
        KeyError: if a required field is missing, e.g. no ``'image_embedding'``
            because the dataset did not find the image file on disk.
    """
    record = dict(sample)
    for key in ('text', 'img', 'split'):
        del record[key]
    text_embedding = record.pop('text_embedding')
    image_embedding = record.pop('image_embedding')
    record['text'] = text_embedding.to(dtype=torch.float32)
    record['image'] = image_embedding.to(dtype=torch.float32)
    return record


# 'spawn' is the start method CUDA requires if worker processes are used;
# force=True lets this cell be re-run without "context has already been set".
mp.set_start_method('spawn', force=True)

splits = ['train', 'dev_seen', 'dev_unseen', 'test_seen', 'test_unseen']

data_dir = 'data'
out_dir = os.path.join(data_dir, 'hateful_memes')
os.makedirs(out_dir, exist_ok=True)

for split in splits:
    dataset = MultimodalDataset(
        jsonl_file=jsonl_file,
        img_folder=img_folder,
        device=device,
        load_text=True,
        load_images=True,
        split=split,
    )
    # Embeddings are computed in __getitem__, so a single in-process pass both
    # drives the progress bar and collects the results, in CSV order. Keeping
    # it in this process avoids copying the CLIP model into worker processes.
    datalist = [_to_record(sample) for sample in tqdm(dataset, desc='Processing Batches')]
    torch.save(datalist, os.path.join(out_dir, f'{split}.pth'))