In [None]:
# magic/dataset.py

import json
import logging
import os

from PIL import Image
from torch.utils.data import Dataset


class MAGICDataset(Dataset):
    """Image/question/answer dataset for the ImageCLEF 2024 MAGIC challenge.

    Samples are read from ``<file_path>/<split>_downloaded.json``. Each sample
    is paired with its first image id, resolved to a ``.jpg`` or ``.png`` file
    under ``<file_path>/images/<split>``. Samples without image ids, or whose
    image file cannot be found, are skipped with a warning.
    """

    # Extensions probed, in this order, when resolving an image id to a file.
    _IMAGE_EXTENSIONS = (".jpg", ".png")

    def __init__(self, file_path: str = "data/", split: str = "train"):
        """
        :param file_path: main path with data (a trailing slash is optional)
        :param split: which dataset should be chosen
        """
        # os.path.join keeps both paths valid whether or not file_path ends
        # with a separator; plain string concatenation did not.
        self.json_file = os.path.join(file_path, split + "_downloaded.json")
        self.folder_path = os.path.join(file_path, "images", split)
        self.data = self._get_preprocessed_data()

    def _find_image(self, image_id):
        """Return the path of the existing image file for ``image_id``.

        Logs a warning naming the last candidate tried and returns ``None``
        when no file with a known extension exists.
        """
        candidate = None
        for extension in self._IMAGE_EXTENSIONS:
            candidate = os.path.join(self.folder_path, image_id + extension)
            if os.path.exists(candidate):
                return candidate
        logging.warning("Couldn't find path %s", candidate)
        return None

    def _get_preprocessed_data(self):
        """Load the split's JSON file and keep the samples that have an image.

        :return: list of dicts with the keys ``image`` (file path),
            ``description``, ``answer`` and ``encounter_id``
        """
        with open(self.json_file, encoding="utf8") as f:
            json_data = json.load(f)
        temp_data = []
        for sample in json_data:
            image_ids = sample["image_ids"]
            if len(image_ids) != 1:
                logging.warning(
                    "Different number of images (%d) for question than 1",
                    len(image_ids),
                )
            if not image_ids:
                # Nothing to index below; skip instead of raising IndexError.
                continue
            # Only the first image of a sample is used.
            image_path = self._find_image(image_ids[0])
            if image_path is None:
                continue
            temp_data.append({
                "image": image_path,
                "description": sample["query_title_en"],
                # Only the first response is kept as the reference answer.
                "answer": sample["responses"][0]["content_en"],
                "encounter_id": sample["encounter_id"],
            })
        return temp_data

    @staticmethod
    def _build_prompt(description):
        """Return the question text wrapped around a sample's description."""
        # The explicit spaces keep the description from running into the
        # neighbouring sentences.
        return (
            "This is additional information about the dermatology issue on the image: "
            + description
            + " What dermatological disease is on the image and how can it be treated?"
        )

    def __len__(self):
        """Return the number of usable samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        :return: dict with ``image`` (a PIL image opened from the stored
            path), ``qa`` (a one-element list holding ``question`` and
            ``answer``) and ``encounter_id``
        """
        sample = self.data[idx]
        return {
            "image": Image.open(sample["image"]),  # Should be a PIL image
            # NOTE(review): "qa" is a list with a single pair; the reason for
            # the list wrapper is not visible here (original: "Why array?").
            "qa": [
                {
                    "question": self._build_prompt(sample["description"]),
                    "answer": sample["answer"],
                }
            ],
            "encounter_id": sample["encounter_id"],
        }

In [None]:
# tiny-llava-v1-hf/inference.py

import requests
import torch
# from magic.dataset import MAGICDataset
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
#from llava.model import LlavaLlamaForCausalLM
import torch
from PIL import Image
from torch.utils.data import Dataset
import json
import logging
import os

# Run the tiny LLaVA checkpoint over the "valid" split and store one generated
# response per encounter as JSON.
print("START")
model_id = 'bczhou/tiny-llava-v1-hf'
dataset = MAGICDataset("/content/drive/MyDrive/reddit/", "valid")

model = LlavaForConditionalGeneration.from_pretrained(model_id)

processor = AutoProcessor.from_pretrained(model_id)

# Marker that closes the prompt built below; the answer is extracted by
# splitting the decoded text on it.
answer_marker = "ASSISTANT:"
output_path = '/content/drive/MyDrive/reddit/valid_data.json'

response = []
for sample in dataset:
    text = f"USER: <image>\n {sample['qa'][0]['question']} ASSISTANT:"
    # Keyword arguments avoid depending on the positional order of text and
    # images, which is not the same in every transformers release.
    inputs = processor(text=text, images=sample['image'], return_tensors='pt')
    output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
    # NOTE(review): the first two token ids are dropped before decoding,
    # presumably leading prompt tokens - confirm against the model card.
    decoded_response = processor.decode(output[0][2:], skip_special_tokens=True)
    print(decoded_response)
    # maxsplit=1 keeps the whole answer even if the model emits the marker
    # again; the length check avoids an IndexError when the marker is missing.
    pieces = decoded_response.split(answer_marker, 1)
    if len(pieces) == 2:
        answer = pieces[1]
    else:
        logging.warning(
            "No %r marker in decoded output for encounter %s",
            answer_marker,
            sample["encounter_id"],
        )
        answer = decoded_response
    result = {
        "encounter_id": sample["encounter_id"],
        "responses": [{
            "content_en": answer
        }]
    }
    response.append(result)
    # The file is rewritten after every sample, so partial results survive an
    # interrupted run (presumably intentional checkpointing).
    with open(output_path, 'w') as f:
        json.dump(response, f)