In [None]:
import os
import json
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

from google.colab import drive

# Mount Google Drive so the Drive-hosted dataset, images and output file
# referenced below become readable/writable under /content/drive.
_MOUNT_POINT = "/content/drive"
drive.mount(_MOUNT_POINT)

# Project folder on the mounted Drive; the Web2Code annotation file is read
# from it and the generated outputs are written back into it.
BASE_DIR = "/content/drive/MyDrive/CSE895LLMs"
JSON_SOURCE, OUTPUT_JSON = (
    os.path.join(BASE_DIR, fname) for fname in ("Web2Code.json", "wanted_outputs.json")
)
# Hugging Face checkpoint used for generation.
model_id = "llava-hf/llava-1.5-7b-hf"

# Screenshot filenames selected for this run: image_0_1 .. image_0_20,
# except image_0_14 and image_0_17. Entries in the source JSON are kept only
# when the basename of their "image" path is in this set.
wanted_images = {
    f"image_0_{idx}.png"
    for idx in range(1, 21)
    if idx not in (14, 17)
}

# Load the LLaVA checkpoint in float16; device_map="auto" lets transformers
# place the weights on the available device(s). The processor handles both
# the text prompt and the image inputs passed to the model.
_load_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}
model = LlavaForConditionalGeneration.from_pretrained(model_id, **_load_kwargs)
processor = AutoProcessor.from_pretrained(model_id)

def generate_web2code(
    image_path: str,
    user_prompt: str,
    *,
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.9,
) -> str:
    """Generate the model's answer for one image and one text prompt.

    Builds a single-turn chat (an image slot followed by *user_prompt*),
    runs ``model.generate`` with temperature / nucleus sampling, and returns
    only the newly generated text (the echoed prompt is removed).

    Args:
        image_path: Path of the image file to send to the model.
        user_prompt: Instruction text. The image slot comes from the
            ``{"type": "image"}`` entry, so the caller is expected to strip
            any ``<image>`` placeholder from this text.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded completion with surrounding whitespace stripped.
        Because sampling is enabled, results vary between runs unless the
        caller seeds torch beforehand.
    """
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": user_prompt},
            ],
        }
    ]
    prompt_text = processor.apply_chat_template(conversation, add_generation_prompt=True)

    # Open the file in its own context so its handle is closed; convert()
    # returns an independent RGB copy that outlives the `with`.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    inputs = processor(images=img, text=prompt_text, return_tensors="pt").to(model.device, torch.float16)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # temperature/top_p only take effect in sampling mode; without
            # do_sample=True generation silently falls back to greedy.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )

    # generate() returns the prompt tokens followed by the completion; drop
    # the prompt so the saved output is only the model's answer.
    prompt_len = inputs["input_ids"].shape[-1]
    outputs = processor.batch_decode(generated_ids[:, prompt_len:], skip_special_tokens=True)
    return outputs[0].strip()

with open(JSON_SOURCE, "r", encoding="utf-8") as f_in:
    # Presumably a list of Web2Code records, each with an "image" path and a
    # "conversations" list of {"from": ..., "value": ...} turns (the fields
    # read below) -- TODO confirm against the dataset.
    data = json.load(f_in)


def _first_human_prompt(item):
    """Return the first human turn of *item* with ``<image>`` placeholders removed.

    Placeholders are removed wherever they appear, not only at the start of
    the text. generate_web2code already adds the image slot via its
    ``{"type": "image"}`` entry, so a leftover placeholder in the text would
    presumably be read as an extra image token.

    Returns None when *item* has no human turn, and "" when that turn is
    empty after stripping.
    """
    for turn in item.get("conversations") or []:
        if turn.get("from") == "human":
            return turn.get("value", "").replace("<image>", "").strip()
    return None


# Directory holding the screenshots; invariant across items.
image_dir = os.path.join(BASE_DIR, "Web2Code_image/WebSight_images_new")

results = []
count = 0
seen_wanted = set()  # wanted filenames that actually occur in the JSON

for item in data:
    image_ref = item.get("image")
    if not image_ref:
        continue
    image_filename = os.path.basename(image_ref)

    if image_filename not in wanted_images:
        continue
    seen_wanted.add(image_filename)

    user_prompt = _first_human_prompt(item)
    if not user_prompt:
        print(f"WARNING: no human message for {image_filename}, skipping.")
        continue

    image_path = os.path.join(image_dir, image_filename)
    if not os.path.exists(image_path):
        print(f"WARNING: image not found => {image_path}")
        continue

    print(f"\nProcessing {image_filename} ...")
    output_text = generate_web2code(image_path, user_prompt)

    results.append({
        "image": image_filename,
        "prompt": user_prompt,
        "output": output_text,
    })
    count += 1

# Without this report, wanted images that never occur in the JSON would be
# skipped with no message at all.
never_seen = sorted(wanted_images - seen_wanted)
if never_seen:
    print(f"WARNING: wanted images not present in {JSON_SOURCE}: {', '.join(never_seen)}")

print(f"\nAll done! Processed {count} images from wanted_images. Saving to {OUTPUT_JSON}")

with open(OUTPUT_JSON, "w", encoding="utf-8") as f_out:
    json.dump(results, f_out, ensure_ascii=False, indent=2)



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]


Processing image_0_1.png ...





Processing image_0_18.png ...

Processing image_0_13.png ...

Processing image_0_10.png ...

Processing image_0_4.png ...

Processing image_0_16.png ...

Processing image_0_9.png ...

Processing image_0_5.png ...

Processing image_0_7.png ...

Processing image_0_6.png ...

Processing image_0_3.png ...

Processing image_0_2.png ...

Processing image_0_12.png ...

Processing image_0_15.png ...

Processing image_0_8.png ...

Processing image_0_20.png ...

All done! Processed 16 images from wanted_images. Saving to /content/drive/MyDrive/CSE895LLMs/wanted_outputs.json
