In [21]:
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure
from bson.objectid import ObjectId
import requests
from PIL import Image
import pandas as pd
from tqdm import tqdm
import os
from dotenv import load_dotenv
import matplotlib.pyplot as plt

In [22]:
# Load variables from a local .env file (if present) into the process environment.
load_dotenv()

# Fail fast when the connection string is missing: os.environ.get() would return
# None, and MongoClient(None) falls back to its default host instead of the
# intended database, which only shows up later as confusing query results.
mongodb_uri = os.environ.get("MONGODB_URI")
if not mongodb_uri:
    raise RuntimeError(
        "MONGODB_URI is not set; define it in the environment or in a .env file."
    )
client = MongoClient(mongodb_uri)

In [34]:
def fetchData(client):
  """Query the records that carry a human-assigned category.

  Calls find() on client.reddit.images.metadata with a filter that requires
  the "category_human" field to exist and a projection requesting "url",
  "phash" and "category_human". The value returned by find() is passed back
  unchanged.
  """
  has_human_label = {"category_human": {"$exists": True}}
  requested_fields = {"url": 1, "phash": 1, "category_human": 1}
  collection = client.reddit.images.metadata
  return collection.find(has_human_label, requested_fields)

def retrieveImages(data, n, timeout=30):
  """Download the image behind each of the first `n` entries.

  Args:
    data: sequence of dicts, each with the image location under "url".
    n: number of leading entries to process.
    timeout: seconds to wait on each HTTP request before giving up.

  Returns:
    The slice data[0:n]. The slice holds the same dict objects as `data`,
    so the processed entries of `data` gain an "image" key as well.

  Raises:
    Whatever requests.get / raise_for_status raise on a timeout, connection
    problem or HTTP error status, and whatever Image.open raises when the
    payload cannot be opened as an image.
  """
  from io import BytesIO

  subset = data[0:n]
  for entry in tqdm(subset):
    # A timeout keeps one stalled host from hanging the whole run.
    response = requests.get(entry["url"], timeout=timeout)
    # Surface 4xx/5xx responses here instead of handing an error page to PIL.
    response.raise_for_status()
    # Read the whole decoded body up front: the image no longer depends on
    # an open network stream, and the connection can be released right away.
    entry["image"] = Image.open(BytesIO(response.content))
  return subset

# Materialize the query result as a list, then keep only the first 500
# records, each with its image attached by retrieveImages.
data = retrieveImages(list(fetchData(client)), 500)


In [24]:
import torch
from transformers import BitsAndBytesConfig

# bitsandbytes settings handed to the model loader: 4-bit loading, with
# float16 as the compute dtype.
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

In [25]:
from transformers import pipeline

# Checkpoint used for labelling.
model_id = "llava-hf/llava-1.5-7b-hf"

# Image-to-text pipeline; the quantization settings are forwarded to the
# model through model_kwargs.
pipe = pipeline(
    "image-to-text",
    model=model_id,
    model_kwargs={"quantization_config": quantization_config},
)


Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.03s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [30]:
# Prompt sent with every image. It is built from adjacent string literals so
# each part stays readable; the resulting text is a single string.
prompt = (
    "USER: <image>\n"
    " Please categorize the image into one of the following categories: "
    "'Reddit', 'Instagram', 'Facebook', 'Discord', 'TikTok', "
    "'Messaging', 'News', 'Meme', 'Tweet', or 'Other'. "
    "Focus on visual and textual clues that might indicate "
    "the source platform or content type. \n"
    "ASSISTANT:"
)

In [35]:
def llava_label(data):
    """Attach the model's category answer to every entry, in place.

    For each dict in `data`, calls the module-level `pipe` on entry["image"]
    with the module-level `prompt`, and stores the text generated after the
    "ASSISTANT:" marker under entry["category_llava"].

    Args:
        data: iterable of dicts, each holding the input image under "image".

    Returns:
        None; the entries are modified in place.
    """
    marker = "ASSISTANT:"
    for entry in tqdm(data):
        outputs = pipe(entry["image"], prompt=prompt, generate_kwargs={"max_new_tokens": 1000})
        generated = outputs[0]["generated_text"]
        # Take whatever follows the last marker (or the whole text when the
        # marker is absent) rather than indexing the result of
        # split("ASSISTANT: "), which raises IndexError when nothing follows
        # the marker. Stripping removes surrounding whitespace that would
        # otherwise defeat an exact string comparison with the human label.
        entry["category_llava"] = generated.rpartition(marker)[2].strip()

# Run the labeller over every entry in `data`; each dict gains a
# "category_llava" key in place.
llava_label(data)

100%|██████████| 100/100 [00:32<00:00,  3.08it/s]


In [None]:
# Tabulate the records and keep just the two labels and the source URL.
df = pd.DataFrame(data).loc[:, ['category_human', 'category_llava', 'url']]

In [38]:
# Show the first 20 rows as plain text.
print(df.head(20))

   category_human category_llava                                  url
0           Tweet       Facebook  https://i.redd.it/q49trvdqfa3a1.jpg
1           Tweet         Reddit  https://i.redd.it/huab7xr7ed3a1.jpg
2            Meme      Instagram  https://i.redd.it/g5picrl8ie3a1.jpg
3           Other          Other  https://i.redd.it/gm6a3y4aif3a1.jpg
4          Reddit         Reddit  https://i.redd.it/t8b1v46hjg3a1.jpg
5           Tweet        Twitter  https://i.redd.it/5vbryiwc5f3a1.png
6           Tweet      Instagram  https://i.redd.it/dmzg6wi10g3a1.jpg
7            Meme           Meme  https://i.redd.it/o6sljqhq0i3a1.jpg
8            Meme           Meme  https://i.redd.it/bihyeo5imh3a1.jpg
9       Messaging         Reddit  https://i.redd.it/e12o69tm1i3a1.jpg
10          Other          Other  https://i.redd.it/qpqlgn4p5g3a1.jpg
11          Other          Other  https://i.redd.it/lhmgbzi1a83a1.jpg
12           Meme           Meme  https://i.redd.it/oaf8fa75fl3a1.jpg
13          Other   

In [37]:
# Count the rows where the model label equals the human label exactly.
matches = (df['category_human'] == df['category_llava']).sum()
# Divide by the length of the frame that was compared (not the separate
# `data` list), so the percentage stays correct if the two ever differ in size.
print("Accuracy: %s%%" % ((matches/len(df))*100))

Accuracy: 42.0%


In [None]:
# Save the comparison table to output.xlsx (path relative to the working
# directory), omitting the index column and writing through openpyxl.
df.to_excel("output.xlsx", index=False, engine='openpyxl')