In [1]:
# Pin the job to one GPU. This has to happen before torch initialises CUDA,
# which is why it precedes `import torch`.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
import requests
from PIL import Image
from transformers import BlipProcessor, BlipForImageTextRetrieval
from datasets import load_dataset

BLIP_CHECKPOINT = "Salesforce/blip-itm-base-coco"
REWARD_DATASET = "zai-org/VisionRewardDB-Image"

# Processor and image-text-retrieval model share the same pretrained checkpoint.
processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
model = BlipForImageTextRetrieval.from_pretrained(BLIP_CHECKPOINT).cuda()

# Login using e.g. `huggingface-cli login` to access this dataset.
# The first 40,000 rows of the "train" split are used for training and the
# remainder for evaluation.
train_ds = load_dataset(REWARD_DATASET, split='train[:40000]')
test_ds = load_dataset(REWARD_DATASET, split='train[40000:]')

import io, math, random
import numpy as np
from PIL import Image, ImageFilter, ImageEnhance
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import pandas as pd
import re

# Rating dimensions that are left out of training.
EXCLUDED_DIMS = {"unsafe type", "hands", "face", "body", "safety", "lighting aesthetic"}

# Scoring rubric read from rules.csv (columns: Dimension, Score, Option, Description).
df = pd.read_csv("rules.csv")
df.columns = df.columns.str.strip()
# Rows with an empty "Dimension" inherit the value from the row above
# (presumably the CSV names each dimension only once per group).
df['Dimension'] = df['Dimension'].ffill()

_PAREN_KEY = re.compile(r'\((.*?)\)')


def _dimension_key(label):
    """Return the text inside the first "(...)" of ``label``, or ``label`` itself if there is none."""
    match = _PAREN_KEY.search(label)
    return match.group(1) if match else label


df['dim_key'] = df['Dimension'].apply(_dimension_key)

# guide[dim_key][score] -> "<Option>: <Description>"
# NOTE(review): a missing Description becomes the literal text "nan" via str().
guide = {
    dim_key: {
        int(row['Score']): row['Option'] + ": " + str(row['Description']).strip()
        for _, row in group.iterrows()
    }
    for dim_key, group in df.groupby('dim_key')
}

# Dimensions used for training, in guide order (groupby sorts its keys by default).
dims = [key for key in guide if key not in EXCLUDED_DIMS]
# Lowest score defined for every dimension (including excluded ones).
dim_min = {key: min(scores) for key, scores in guide.items()}

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading dataset shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [74]:
import json

# Maps each rating dimension to the "this aspect of the image is bad" prompt
# paired with images during training. JSON is UTF-8 by specification, so do
# not rely on the platform's default text encoding.
with open("prompts.json", "r", encoding="utf-8") as f:
    prompt_dict = json.load(f)


In [62]:
# Show the dimension -> prompt mapping loaded from prompts.json.
prompt_dict

{'background': "The images's background is bad, there is no background or the background is ugly",
 'clarity': "The images's clarity is blurry, the image is noticeably blurry, as though noise or distortion is present.",
 'color aesthetic': "The images's color aesthetic is bad, there are ugly colors, the colors are unpleasant, often due to lack of harmony, abruptness, overexposure, or poor color choice.",
 'color brightness': "The images's color brightness is dark: Purely judging the color, it is dark.",
 'detail realism': "The images's detail realism is bad, inauthentic; The image is inconsistent with reality, featuring heavy object distortion or deformation, even when scaled down, the flaws in realism are apparent.",
 'detail refinement': "Important elements appear noticeably rough on closer inspection, or the overall image feels rough. Important details are blatantly rough or unfinished. So rough that it is impossible to discern objects in the image. The image isn't just indecipherab

In [75]:
def format_data(sample):
    """Build one BLIP image-text-matching training batch from VisionRewardDB rows.

    Used as a ``datasets`` on-the-fly transform, so ``sample`` is a batch: a
    mapping whose ``'image'`` and ``'annotation'`` entries are parallel lists.
    For each image a quality dimension is drawn from the global ``dims``:
    with probability 1/2 among the dimensions the image scores ``>= 0`` on,
    otherwise among those it scores ``< 0`` on. If that pool is empty, any
    dimension is used. The image is paired with ``prompt_dict[dimension]``
    (a "this aspect is bad" description). The target is "match" exactly when
    the image's score on that dimension is negative.

    Args:
        sample: batch mapping with ``'image'`` (images accepted by
            ``processor``) and ``'annotation'`` (per-image objects indexed by
            dimension name -- presumably dicts of dim -> numeric score; TODO
            confirm against the dataset card).

    Returns:
        dict with:
        - the processor tensors ``pixel_values``, ``input_ids`` and
          ``attention_mask``;
        - ``labels``, a float32 tensor of shape (n, 2) holding one-hot
          (no-match, match) targets;
        - per-row metadata: ``dims``, ``n_images`` (the batch size,
          repeated), ``prompts`` and ``annotation`` (the score on the
          selected dimension).
    """
    images = list(sample['image'])
    annotations = sample['annotation']
    n_images = len(images)

    dims_selected = []
    for idx in range(n_images):
        scores = annotations[idx]
        if random.random() > 0.5:
            # dimensions the image is fine on -> the "bad" prompt should NOT match
            pool = [d for d in dims if scores[d] >= 0]
        else:
            # dimensions the image is bad on -> the "bad" prompt should match
            pool = [d for d in dims if scores[d] < 0]
        # No dimension on the requested side: fall back to a uniform choice.
        dims_selected.append(random.choice(pool) if pool else random.choice(dims))

    prompts = [prompt_dict[d] for d in dims_selected]
    inputs = processor(images=images, text=prompts, return_tensors="pt", padding=True)

    selected_scores = [annotations[idx][d] for idx, d in enumerate(dims_selected)]
    # Class 1 ("prompt matches the image") when the image is bad on that dimension.
    matches = [1 if s < 0 else 0 for s in selected_scores]
    # One-hot (no-match, match) targets. They must be floating point:
    # F.cross_entropy treats a target with the logits' shape as class
    # probabilities and rejects integer dtypes in that case.
    labels = torch.tensor([(1 - m, m) for m in matches], dtype=torch.float32)

    return {
        'pixel_values': inputs['pixel_values'],
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask'],
        'labels': labels,
        'dims': dims_selected,
        # Every row carries the batch size so the model can sanity-check shapes.
        'n_images': [n_images] * len(inputs['input_ids']),
        "prompts": prompts,
        "annotation": selected_scores,
    }


In [None]:
# Presumably the snippet used to bootstrap prompts.json from the rubric in `guide`
# (it takes the lowest-score option text of each dimension). Kept for reference.
# NOTE(review): this dumps a *list* of prompts, but prompts.json is loaded above as
# a dict keyed by dimension name, so the current file was likely edited by hand.
# sorted_items = [sorted(i.items()) for i in [guide[dim] for dim in dims]]

# import json
# with open("prompts.json", "w") as f:
#     json.dump([f"The images's {dim} is bad: {sorted_items[i][0][1]}" for i, dim in enumerate(dims)], f, indent=4)

In [64]:
# Names of the dimensions used for training.
[str(dim) for dim in dims]

['background',
 'clarity',
 'color aesthetic',
 'color brightness',
 'detail realism',
 'detail refinement',
 'emotion',
 'lighting distinction',
 'main object',
 'object pairing',
 'richness',
 'symmetry']

In [76]:
# format_data is applied on the fly whenever rows are read from either split.
train_ds, test_ds = [split.with_transform(format_data) for split in (train_ds, test_ds)]

# Smoke-test the transform on a small slice.
inputs = train_ds[0:18]

In [77]:
# Prompts sampled by format_data for the 18-row preview batch.
inputs["prompts"]

["The images's shows negative emotions. The image invokes negative emotions, such as creepiness, eeriness, or hostility. The image elicits negative emotions, such as horror or disgust.",
 "The images's symmetry is bad: Asymmetrical: Choose asymmetrical only if the following conditions are met:1. The object is naturally symmetrical in real life but appears clearly asymmetrical in the image.2. The object itself is asymmetrical due to reasons other than perspective.",
 "The images's symmetry is bad: Asymmetrical: Choose asymmetrical only if the following conditions are met:1. The object is naturally symmetrical in real life but appears clearly asymmetrical in the image.2. The object itself is asymmetrical due to reasons other than perspective.",
 "The images's color brightness is dark: Purely judging the color, it is dark.",
 "The images's symmetry is bad: Asymmetrical: Choose asymmetrical only if the following conditions are met:1. The object is naturally symmetrical in real life but app

In [4]:
# print(inputs["annotation"]) 


In [5]:
# train_ds[:10]["labels"] 

In [6]:
# [len(i) for i in train_ds[:8].values()]


import wandb

import torch
from transformers import PreTrainedModel, PretrainedConfig

class Rater(PreTrainedModel):
    """Reward model: a BLIP image-text-matching backbone trained to decide whether
    a "this aspect of the image is bad" prompt matches the image.

    ``forward`` returns the backbone's output object. When ``labels`` are given,
    it also computes a cross-entropy loss between the backbone's first output
    (the 2-way ITM logits) and soft (no-match, match) targets. That loss is
    stored under ``outputs['loss']``, where ``transformers.Trainer`` reads it.
    """

    def __init__(self, backbone):
      super().__init__(PretrainedConfig())
      self.backbone = backbone
      # Learnable scalar initialised to 0.2 (presumably meant as a temperature).
      # NOTE(review): forward() never reads it; it is kept so the state_dict keys
      # of saved checkpoints do not change.
      self.t = torch.nn.Parameter(torch.tensor(0.2))

    def forward(self, pixel_values, input_ids, attention_mask, dim=None, n_images=None,
                labels=None, annotation=None):
      """Run the backbone and, if ``labels`` is given, attach the training loss.

      Args:
        pixel_values, input_ids, attention_mask: processor outputs, forwarded
          unchanged to the backbone.
        dim: unused; kept for backward compatibility. Optional because the
          dataset emits its (string) dimension names under ``dims``, never
          as ``dim``.
        n_images: optional per-row batch size produced by ``format_data``.
          When given, its first entry is checked against the batch size.
        labels: optional (batch, 2) one-hot (no-match, match) targets. They are
          cast to float because class-probability targets must be floating point.
        annotation: unused. It is the per-row score column emitted by
          ``format_data``, which reaches the model because the Trainer keeps
          unused columns (``remove_unused_columns=False``).

      Returns:
        The backbone output, with ``'loss'`` set when ``labels`` is given.
      """
      outputs = self.backbone(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
      itm_scores = outputs[0]

      if labels is not None:
        expected = int(n_images[0]) if n_images is not None else itm_scores.shape[0]
        assert itm_scores.shape[0] == labels.shape[0] == expected, f"{itm_scores.shape[0]} {labels.shape[0]} {expected}"
        assert itm_scores.shape[1] == labels.shape[1] == 2
        # Soft-target cross-entropy in float32 (targets must be floating point,
        # and fp32 keeps the loss stable under fp16 mixed precision).
        loss = torch.nn.functional.cross_entropy(itm_scores.float(), labels.float())
        outputs['loss'] = loss

        # Metric logging is best-effort: it must never interrupt training, but
        # failures are reported instead of silently swallowed.
        try:
          if wandb.run is not None:
            acc = (itm_scores.argmax(-1) == labels.argmax(-1)).float().mean()
            wandb.log({"bce_loss": loss.item(), "acc": acc.item()})
        except Exception as exc:
          warnings.warn(f"wandb logging skipped: {exc}", RuntimeWarning)

      return outputs

# Wrap the BLIP retrieval model loaded at the top of the notebook and move the wrapper to the GPU.
my_rater = Rater(model).cuda()

from transformers import TrainingArguments
import os

# os.cpu_count() returns None when the CPU count cannot be determined; in that
# case load data in the main process instead of passing None to the DataLoader.
num_dataloader_workers = os.cpu_count() or 0

training_args = TrainingArguments(
    output_dir="BLIP-Reward",
    learning_rate=1e-4,
    # Effective batch per device: 64 x 4 gradient-accumulation steps = 256.
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=4,
    num_train_epochs=7,
    weight_decay=0.0001,
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=500,
    save_steps=500,
    logging_steps=1,
    load_best_model_at_end=True,
    push_to_hub=True,
    max_grad_norm=1.0,
    # Keep the transform's extra columns; Rater.forward consumes some of them.
    remove_unused_columns=False,
    dataloader_num_workers=num_dataloader_workers,
    fp16=True,
    warmup_ratio=0.01,
    lr_scheduler_type="warmup_stable_decay",
    lr_scheduler_kwargs={"num_decay_steps": 500},
    report_to="wandb"
)

from transformers import Trainer

trainer = Trainer(
    model=my_rater,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    processing_class=processor,
)

In [None]:
# Launch fine-tuning; metrics are reported to Weights & Biases (report_to="wandb").
trainer.train()



[34m[1mwandb[0m: Currently logged in as: [33mwguo6358[0m ([33m3dsmile[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


64 6464 64

646464 64 
64 6464
 
646464646464
     6464646464
 

6464 64

64
6464646464
 6464646464   
 64 6464 64646464    6464
646464646464 

 646464
64 

 64

64 

  6464 646464646464 64
  64 64
64
64
646464
6464
6464  64

 646464646464  
 64
 64
 64646464
646464
6464


 

 646464  64 64 64646464 64 64
 64 6464 64646464

64 64 64
 6464
 646464
 
64
 
64 
  

 
 646464
64 
 6464 64646464 6464646464 
64 



 64646464

64646464
 646464 

64
64

  6464
 
6464
6464 
6464 64
64
64
64 64 64
   6464646464 6464 
 64


6464
64 

6464 64 64

 6464
64


64
64  6464

6464 64  64 64646464 64
 6464

64
6464  
6464
64
64  6464

64
 6464
 6464 
6464 64
64
 64 646464

 64
64 64
