In [None]:
!pip install accelerate

In [None]:
import io
import os
import shutil
import tarfile

import piexif
import piexif.helper
import torch
import torchvision.transforms.functional as TVF
from accelerate import Accelerator
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Configuration. All paths below should be absolute.
MODEL_NAME = "fancyfeast/llama-joycaption-alpha-two-hf-llava"
TAR_INPUT_PATH = ""   # absolute path of the input .tar archive of images
TAR_OUTPUT_PATH = ""  # absolute path of the output .tar archive to write
TEMP_DIR = ""         # absolute path of the scratch directory used for extraction

# Captioning instruction sent as the user turn of the chat for every image.
# Plain (non-f) string: the text contains no placeholders to interpolate.
prompt = '''Write a stable diffusion prompt for this image within 150 words. Focus on the subjects and their physical traits and physical interactions. 
Do NOT name series or character names. 
Do NOT use any ambiguous language. 
Do NOT use tags or lists. 
Do NOT describe the mood or atmosphere. 
Do NOT describe emotions. 
If applicable, describe sexual scenes or nudity with informal, sexual slang. 
If a vagina (pussy), penis, butt (ass), nipples or sexual fluid (cum) is present, describe them with informal, sexual slang. 
If characters are engaging in sex (fucking), describe the position and type of sex with informal, sexual slang. 
If there are multiple named characters in the image, differentiate each character based on its features, and use character names instead of personal pronouns.
'''
# Fail fast with a clear message: the shipped defaults are empty strings, which
# otherwise surface later as an obscure FileNotFoundError from os/tarfile.
for _setting_name, _setting_value in (
    ("TAR_INPUT_PATH", TAR_INPUT_PATH),
    ("TAR_OUTPUT_PATH", TAR_OUTPUT_PATH),
    ("TEMP_DIR", TEMP_DIR),
):
    if not _setting_value:
        raise ValueError(f"{_setting_name} is empty; set it in the configuration before running.")

# Scratch directory that the input archive is extracted into.
os.makedirs(TEMP_DIR, exist_ok=True)

accelerator = Accelerator()
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# On GPU machines let transformers place the weights automatically; on CPU-only
# machines load without a device map.
device_map = "auto" if torch.cuda.is_available() else None
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME, torch_dtype=torch.bfloat16, device_map=device_map
)
llava_model.eval()  # inference only: disables dropout etc.
# Only the model needs preparing; Accelerator.prepare returns non-torch objects
# such as the processor unchanged.
llava_model = accelerator.prepare(llava_model)
# Extract the input archive into the scratch directory.
with tarfile.open(TAR_INPUT_PATH, "r") as tar:
    if hasattr(tarfile, "data_filter"):
        # The "data" filter (Python 3.12+, and security backports) rejects absolute
        # paths, ".." traversal and special files from an untrusted archive.
        tar.extractall(path=TEMP_DIR, filter="data")
    else:
        tar.extractall(path=TEMP_DIR)

# Case-insensitive suffixes of the image files to caption (leading dot included
# so that e.g. "foopng" does not match).
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

# Walk recursively so images stored in subfolders of the archive are included;
# sort for a deterministic processing order.
image_files = sorted(
    os.path.join(root, name)
    for root, _dirs, names in os.walk(TEMP_DIR)
    for name in names
    if name.lower().endswith(IMAGE_EXTENSIONS)
)
def _caption_exif(existing_exif, caption):
    """Return EXIF bytes that carry *caption* in the Exif UserComment tag.

    existing_exif: raw EXIF bytes taken from the source image's ``info["exif"]``,
        or None / b"" when the image has no EXIF block.
    caption: the generated caption text (str).

    piexif.load() treats input that is not TIFF/JPEG/WebP/"Exif" data as a file
    path, so calling it with b"" raises; images without EXIF therefore start
    from an empty IFD dictionary instead.
    """
    if existing_exif:
        exif_dict = piexif.load(existing_exif)
    else:
        exif_dict = {"0th": {}, "Exif": {}, "GPS": {}, "Interop": {}, "1st": {}, "thumbnail": None}
    # UserComment.dump takes a str and encodes it itself according to `encoding`.
    exif_dict["Exif"][piexif.ExifIFD.UserComment] = piexif.helper.UserComment.dump(
        caption, encoding="unicode"
    )
    return piexif.dump(exif_dict)


# The chat prompt is identical for every image, so render it once.
convo = [
    {"role": "system", "content": "You are a helpful image captioner."},
    {"role": "user", "content": prompt},
]
convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
if not isinstance(convo_string, str):
    raise TypeError(f"apply_chat_template returned {type(convo_string).__name__}, expected str")

# Caption every image and stream a JPEG re-encode (caption stored in EXIF) into
# the output archive. Files are built in memory, so the extracted inputs are
# never overwritten.
with tarfile.open(TAR_OUTPUT_PATH, "w") as tar_out:
    written_names = set()  # archive member names already used
    with torch.no_grad():
        for image_path in image_files:
            try:
                print(f"Processing {image_path}")
                with Image.open(image_path) as src:
                    image = src.convert("RGB")
                    existing_exif = src.info.get("exif")

                # The model gets a 384x384 LANCZOS copy; the full-resolution
                # image is what is written to the output archive.
                if image.size != (384, 384):
                    model_image = image.resize((384, 384), Image.LANCZOS)
                else:
                    model_image = image

                inputs = processor(text=[convo_string], images=[model_image], return_tensors="pt")
                inputs = {k: v.to(accelerator.device) for k, v in inputs.items()}
                inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
                generate_ids = llava_model.generate(
                    **inputs,
                    max_new_tokens=300,
                    do_sample=True,
                    suppress_tokens=None,
                    use_cache=True,
                    temperature=0.6,
                    top_k=None,
                    top_p=0.9,
                )[0]
                # Keep only the newly generated tokens, not the echoed prompt.
                generate_ids = generate_ids[inputs["input_ids"].shape[1]:]
                caption = processor.tokenizer.decode(
                    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
                ).strip()
                print(f"Generated Caption for {image_path}: {caption}")

                jpeg_buffer = io.BytesIO()
                image.save(jpeg_buffer, "jpeg", exif=_caption_exif(existing_exif, caption))
                jpeg_bytes = jpeg_buffer.getvalue()

                # The content is JPEG, so give non-JPEG inputs a .jpg name; keep the
                # path relative to the extraction dir so subfolders do not collide.
                rel_path = os.path.relpath(image_path, TEMP_DIR).replace(os.sep, "/")
                stem, ext = os.path.splitext(rel_path)
                arcname = rel_path if ext.lower() in (".jpg", ".jpeg") else stem + ".jpg"
                if arcname in written_names:
                    # e.g. both "a.png" and "a.jpg" exist: disambiguate "a.png" -> "a.png.jpg".
                    arcname = rel_path + ".jpg"

                member = tarfile.TarInfo(arcname)
                member.size = len(jpeg_bytes)
                member.mtime = int(os.path.getmtime(image_path))
                tar_out.addfile(member, io.BytesIO(jpeg_bytes))
                written_names.add(arcname)
            except Exception as e:
                # Best-effort batch job: report the failure and move on to the next image.
                print(f"Failed to process {image_path}: {e}")
print(f"Processed images saved to {TAR_OUTPUT_PATH}")
# Remove the scratch directory and everything in it. Unlike deleting top-level
# files and then calling os.rmdir, shutil.rmtree also handles subdirectories
# extracted from the archive.
shutil.rmtree(TEMP_DIR)