In [None]:
# Clone the MemeCap repository; its data/ folder holds the memes-{phase}.json files read by prepare_data() below.
!git clone https://github.com/eujhwang/meme-cap.git

Cloning into 'meme-cap'...
remote: Enumerating objects: 52, done.
remote: Counting objects: 100% (52/52), done.
remote: Compressing objects: 100% (45/45), done.
remote: Total 52 (delta 10), reused 13 (delta 0), pack-reused 0
Receiving objects: 100% (52/52), 844.03 KiB | 3.34 MiB/s, done.
Resolving deltas: 100% (10/10), done.


In [None]:
# Work from the dataset folder: prepare_data() reads memes-{phase}.json and writes imgs/ and llava-{phase}.json relative to the current directory.
%cd meme-cap/data

/content/meme-cap/data


In [None]:
import os
import json
import requests
from tqdm import tqdm
import uuid

In [None]:
def download_image(url, img_path, timeout=30):
    """Download a single image and save it to disk.

    Args:
        url: HTTP(S) URL of the image.
        img_path: Destination file path; overwritten if it already exists.
        timeout: Seconds ``requests`` waits on the server (connect / read)
            before giving up. Without a timeout, ``requests.get`` can block
            indefinitely on an unresponsive host and stall the whole loop.

    Returns:
        True if the server answered 200 and the body was written to
        ``img_path``; False on any other status code or on a network-level
        error (connection failure, timeout, invalid URL, ...).
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Treat network errors like a failed download instead of letting one
        # bad URL abort the caller's loop over thousands of images.
        return False

    if response.status_code != 200:
        return False

    with open(img_path, 'wb') as f:
        f.write(response.content)
    return True

In [None]:
def _make_llava_record(meme, img_path):
    """Build one LLaVA-style conversation record for a downloaded meme.

    ``meme`` is one entry of memes-{phase}.json; only its 'img_captions' and
    'meme_captions' lists are read here (each list is joined with spaces).
    """
    return {
        "id": str(uuid.uuid4()),
        "image": img_path,
        "conversations": [
            {
                "from": "human",
                "value": "<image>\nWhat is in this meme?"
            },
            {
                "from": "gpt",
                "value": " ".join(meme['img_captions'])
            },
            {
                "from": "human",
                # No "<image>" placeholder here: LLaVA substitutes one image's
                # features per placeholder, and each record has a single image,
                # so the token belongs only in the first human turn.
                "value": "What metaphor is this meme trying to convey?"
            },
            {
                "from": "gpt",
                "value": " ".join(meme['meme_captions'])
            }
        ]
    }


def prepare_data(phase, skip_existing=False):
    """Download MemeCap images for one split and write a LLaVA fine-tuning file.

    Reads ``memes-{phase}.json`` from the current directory, downloads each
    entry's image into ``imgs/{phase}/``, and writes ``llava-{phase}.json``
    with one conversation record per successfully downloaded image::

        ./imgs
            /trainval
                abcd.png
                efgh.jpg
                ...
            /test
                lmno.png
                wxyz.jpg
                ...

    Args:
        phase: Split name, e.g. "trainval" or "test"; selects both the input
            JSON file and the image sub-folder.
        skip_existing: If True, an image already present on disk is reused
            instead of being downloaded again (makes re-runs fast). Defaults
            to False, which re-downloads every image.

    Images that fail to download (non-200 status, network or file-system
    error) are counted as failed and get no record in the output. Entries
    missing expected keys raise ``KeyError`` instead of being silently
    swallowed.
    """
    with open(f'memes-{phase}.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    output_folder = os.path.join("imgs", phase)
    os.makedirs(output_folder, exist_ok=True)  # also creates the "imgs" root

    dataset = []
    tot_count = len(data)
    fail_count = 0

    for meme in tqdm(data, total=tot_count, desc="Downloading Images", ncols=80):
        img_path = os.path.join(output_folder, meme['img_fname'])

        if skip_existing and os.path.exists(img_path):
            ok = True
        else:
            try:
                ok = download_image(meme['url'], img_path)
            except (requests.RequestException, OSError):
                # A single network/disk failure must not abort the split and
                # discard every record collected so far.
                ok = False

        if ok:
            dataset.append(_make_llava_record(meme, img_path))
        else:
            fail_count += 1

    with open(f'llava-{phase}.json', 'w', encoding='utf-8') as out:
        json.dump(dataset, out, indent=4)

    print(f"{tot_count} images: {len(dataset)} downloaded, {fail_count} failed")

In [None]:
# Build the train/val split: downloads ~5.8k images and writes llava-trainval.json.
prepare_data(phase="trainval")

Downloading Images: 100%|███████████████████| 5823/5823 [24:20<00:00,  3.99it/s]


5823 images: 5341 downloaded, 482 failed


In [None]:
# Build the held-out test split and write llava-test.json.
prepare_data(phase="test")

Downloading Images: 100%|█████████████████████| 559/559 [02:23<00:00,  3.91it/s]

559 images: 518 downloaded, 41 failed



