In [1]:
import os
import traceback

from tqdm import tqdm

from data_handler import DataHandler
from evaluation import get_evaluation
from image_processor import ImageProcessor
from model_handler import get_model

In [2]:
def main(model_name, model_size, model_path, device_map, data_base_path, output_base_path,
         settings=None, data_types=None, max_entries=5):
    """Run the selected evaluation settings over each data split and checkpoint results.

    For every (setting, data_type) pair, the split is loaded through ``DataHandler``.
    One evaluator is built per data file, and the file's entries are fed to it one
    at a time. Partial results are saved after every entry, so an error part-way
    through a file still leaves the already-processed entries on disk.

    Args:
        model_name: Model identifier passed to ``get_model``. It is also the first part
            of the output directory name (``model_name + model_size``).
        model_size: Size tag passed to ``get_model`` and appended to the output dir name.
        model_path: Checkpoint path forwarded to ``get_model``.
        device_map: Device placement forwarded to ``get_model`` (e.g. ``"cuda"``).
        data_base_path: Root data directory, given to ``DataHandler`` and to each evaluator.
        output_base_path: Root directory for results, written to
            ``<output_base_path>/<model_name><model_size>/<setting>/<data_type>``.
        settings: Evaluation settings to run; defaults to ``["student-forcing"]``.
            Names used previously in this notebook: "default", "student-forcing",
            "teacher-forcing", "single".
        data_types: Data splits to evaluate; defaults to ``["train", "validation"]``.
        max_entries: Maximum number of entries processed per data file. The default
            of 5 keeps the previous debugging cap; pass ``None`` to process all entries.
    """
    if settings is None:
        settings = ["student-forcing"]
    if data_types is None:
        data_types = ["train", "validation"]

    # Model, data access and image preprocessing are shared across all runs.
    model = get_model(model_name, model_size, model_path, device_map)
    data_handler = DataHandler(data_base_path)
    image_processor = ImageProcessor()

    for setting in tqdm(settings, desc="Settings"):
        # The evaluator class depends only on the setting, so look it up once per setting.
        evaluation_class = get_evaluation(setting)

        for data_type in tqdm(data_types, desc=f"Data Types for {setting}", leave=False):
            # NOTE(review): reloaded for every setting. This is presumably so each
            # setting sees entries unmodified by a previous run; confirm before hoisting.
            data = data_handler.load_data(data_type)

            output_path = os.path.join(output_base_path, model_name + model_size, setting, data_type)

            for file_name, entries in tqdm(data.items(), desc=f"Files for {data_type}", leave=False):
                processed_data = []
                photo2answer = {}
                # Accumulator passed to process_entry / calculate_metrics. NOTE(review):
                # presumably one counter per bounding-box object (the prompt asks about
                # five boxes); confirm against the evaluator.
                acc_list = [0, 0, 0, 0, 0]
                evaluator = evaluation_class(model, image_processor, data_base_path, data_type)

                try:
                    # Slicing with max_entries=None keeps every entry.
                    for entry in tqdm(entries[:max_entries], desc=f"Entries for {file_name}", leave=False):
                        evaluator.process_entry(entry, acc_list, photo2answer)
                        processed_data.append(entry)

                        # Checkpoint after each entry so a later failure loses at most one entry.
                        metrics = evaluator.calculate_metrics(acc_list)
                        data_handler.save_partial_results(processed_data, metrics, photo2answer, output_path, file_name)
                    # TODO: final results are not written; data_handler.save_final_results(...)
                    # was disabled in favour of the per-entry partial checkpoints above.
                except Exception as e:
                    # Best-effort: log the failure with its traceback and move on to the
                    # next file, so one bad file does not abort the whole sweep.
                    print(f"Error processing file {file_name} in setting {setting}, data type {data_type}: {e}")
                    traceback.print_exc()

In [3]:
# Run configuration.
# NOTE(review): model_size is "7b", but model_path points at a llava-v1.6-34b checkpoint.
# The output directory is built from model_name + model_size ("operallava7b"), so confirm
# which size is actually being evaluated.

# Model selection
model_name = "operallava"
model_size = "7b"
model_path = "/scratch/rqa8sm/ROPE/llava-v1.6-34b-hf"
device_map = "cuda"

# Input data and result locations
data_base_path = "/scratch/rqa8sm/rebuttal/data/ROPE"
output_base_path = "/scratch/rqa8sm/ROPE/output-experiments-rebuttal"

# Launch the evaluation sweep
main(
    model_name=model_name,
    model_size=model_size,
    model_path=model_path,
    device_map=device_map,
    data_base_path=data_base_path,
    output_base_path=output_base_path,
)

Initializing Model


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  return torch.load(checkpoint_file, map_location="cpu")
Some weights of the model checkpoint at openai/clip-vit-large-patch14-336 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.11.self_attn.out_proj.bias', 'text_model.final_layer_norm.bias', 'text_model.encoder.layers.10.layer_norm2.bias', 'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.7.mlp.fc2.bias', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.7.self_attn.v_proj.weight', 'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.3.mlp.fc1.weight', 'text_model.encoder.layers.2.self_attn.v_proj.bias', 'text_model.encoder.layers.11.self_attn.v_proj.weight', 'text_model.encoder.layers.1.self_attn.out_proj.weight', 'text_model.encoder.layers.5.self_attn.k_proj.weight', 'tex

CLIPImageProcessor {
  "crop_size": {
    "height": 336,
    "width": 336
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": false,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 336
  }
}

Done!


Settings:   0%|          | 0/1 [00:00<?, ?it/s]
[A

[A[A

prompt:  USER: <ImageHere> Given the classes: arm, armchair, bed, book, bottle, box, building, cabinet, car, ceiling, chair, column, curtain, cushion, door, drawer, fence, floors, flower, glass, grass, handle, head, lamp, leg, light, light source, mirror, mountain, pane, person, picture, pillow, plant, plate, pole, pot, road, rock, seat, shelf, sign, sofa, spotlight, streetlight, table, tree, vase, wheel, window. There are five red bounding boxes in this image. For each object within the red bounding boxes, identify its class from the list. Provide the class names in the format: 'obj1: <class1>, obj2: <class2>, obj3: <class3>, obj4: <class4>, obj5: <class5>', with no additional words or punctuation. For example: obj1: class, obj2: class, obj3: class, obj4: class, obj5: class. Replace class with the actual names of the classes from your class list. Ensure that no placeholders or brackets are used around the class names and that no additional words or punctuation are added to the respons

  return torch.cuda.amp.autocast(dtype=dtype)


OPERA's output:
book, obj2: book, obj3: book, obj4: book, obj5: book
predicted_class_str:  book, obj2: book, obj3: book, obj4: book, obj5: book
predicted_class:  book
prompt:  USER: <ImageHere> Given the classes: arm, armchair, bed, book, bottle, box, building, cabinet, car, ceiling, chair, column, curtain, cushion, door, drawer, fence, floors, flower, glass, grass, handle, head, lamp, leg, light, light source, mirror, mountain, pane, person, picture, pillow, plant, plate, pole, pot, road, rock, seat, shelf, sign, sofa, spotlight, streetlight, table, tree, vase, wheel, window. There are five red bounding boxes in this image. For each object within the red bounding boxes, identify its class from the list. Provide the class names in the format: 'obj1: <class1>, obj2: <class2>, obj3: <class3>, obj4: <class4>, obj5: <class5>', with no additional words or punctuation. For example: obj1: class, obj2: class, obj3: class, obj4: class, obj5: class. Replace class with the actual names of the cla