In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from torch.utils.data import Dataset, DataLoader
from enum import Enum
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    CLIPProcessor,
    CLIPModel,
    AdamW,
    get_linear_schedule_with_warmup,
)
from tqdm import tqdm
import os
import pickle
import sys
import argparse
import json
from typing import Tuple, Optional, Union
from datasets import load_dataset

# Choose the compute device. CUDA (or CPU) is the starting choice; MPS takes
# precedence whenever it is available. The MPS status message is printed in
# every case, including when CUDA ends up being used.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print("MPS available.")
elif torch.backends.mps.is_built():
    # PyTorch was built with MPS support, but the backend cannot be used here.
    print(
        "MPS not available because the current MacOS version is not 12.3+ "
        "and/or you do not have an MPS-enabled device on this machine."
    )
else:
    print(
        "MPS not available because the current PyTorch install was not "
        "built with MPS enabled."
    )



MPS available.


In [2]:

# The CLIP model and its processor are loaded from the same checkpoint name.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch16"

clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
clip_model = clip_model.to(DEVICE)  # move the weights to the selected device
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

# GPT-2 tokenizer, used below to turn captions into token ids.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [3]:
# Load every split of the RSICD dataset (no `split=` argument, so a DatasetDict is returned).
dataset = load_dataset("arampacha/rsicd")


In [4]:
# Display the available splits with their columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 8734
    })
    test: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1093
    })
    valid: Dataset({
        features: ['filename', 'captions', 'image'],
        num_rows: 1094
    })
})

In [5]:
data_path = "../data/"
file_name = "rsicd_dataset.json"

# Create the output directory; a no-op when it already exists.
os.makedirs(data_path, exist_ok=True)

# Write one JSON object per line (JSON Lines) for every image in every split.
# Each record holds:
#   caption         - the first caption of the image (other captions are ignored)
#   caption_tokens  - GPT-2 token ids of that caption
#   clip_embedding  - CLIP image features as a plain list of floats
#   split           - name of the source split ("train" / "valid" / "test")
#   id              - index of the row *within its split* (restarts at 0 per split)
with open(os.path.join(data_path, file_name), "w", encoding="utf-8") as f:
    for split in ["train", "valid", "test"]:
        # Iterate the progress bar directly so it advances with the loop.
        progress = tqdm(dataset[split], desc=f"Processing {split} dataset")
        for i, row in enumerate(progress):
            caption = row["captions"][0]
            # `encode` already returns a list of ints, ready for JSON.
            caption_tokens = tokenizer.encode(caption)

            # Inference only: disable autograd so no graph is built per image.
            with torch.no_grad():
                clip_inputs = clip_processor(images=row["image"], return_tensors="pt").to(DEVICE)
                clip_embedding = clip_model.get_image_features(**clip_inputs).squeeze(0).to("cpu")

            # Free device memory held by the preprocessed inputs.
            del clip_inputs
            if DEVICE.type == "mps":
                # The MPS cache API is only valid when running on the MPS backend.
                torch.mps.empty_cache()

            record = {
                "caption": caption,
                "caption_tokens": caption_tokens,
                "clip_embedding": clip_embedding.tolist(),
                "split": split,
                "id": i,
            }
            f.write(json.dumps(record) + "\n")


Processing train dataset: 100%|██████████| 8734/8734 [05:47<00:00, 25.11it/s]
Processing valid dataset: 100%|██████████| 1094/1094 [00:49<00:00, 22.30it/s]
Processing test dataset: 100%|██████████| 1093/1093 [00:49<00:00, 22.19it/s]


In [12]:
# Reload the exported JSON-lines file. NOTE: split="train" is the split name used by
# the JSON loader for this single data file, not the RSICD train split -- the result
# contains the rows of all three source splits; use the "split" column to separate them.
clip_rsicd_dataset = load_dataset("json", data_files=data_path + file_name, split="train")

In [13]:
# Display the columns and total row count of the reloaded dataset.
clip_rsicd_dataset

Dataset({
    features: ['caption', 'caption_tokens', 'clip_embedding', 'split', 'id'],
    num_rows: 10921
})

In [16]:
# Spot-check the stored caption token ids of a single record.
clip_rsicd_dataset[5]['caption_tokens']

389