## Step 1. Import Library

In [1]:
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
from PIL import Image
from datasets import load_dataset
import json

## Step 2. Make Ground Truth

In [4]:
# Generate data following the identified pattern

# Define categories and trials
categories = ["A-Single", "B-Single", "Comp", "Coop"]
classes = {"A-Single": "Single", "B-Single": "Single", "Comp": "Competition", "Coop": "Cooperation"}
trials = list(range(1, 41))  # Generating for trials 1 to 10 as an example
pairs = list(range(12, 41))  # Generating for trials 1 to 10 as an example

# Initialize an empty list to store the JSON data
data_generated = []

# Generate data following the pattern
for pair in pairs:
    for trial in trials:
        for category in categories:
            if category == "B-Single" or category == "A-Single":
                data_generated.append({
                    "image1": f"heatmapOn_trajOn_Pair-{pair}-{category}-EYE_trial{str(trial).zfill(2)}_player.png",
                    "image2": f"heatmapOn_trajOn_Pair-{pair}-{category}-EYE_trial{str(trial).zfill(2)}_observer.png",
                    "class": classes[category]
                })
            else:  # For "Comp" and "Coop" categories
                data_generated.append({
                    "image1": f"heatmapOn_trajOn_Pair-{pair}-{category}-EYE_trial{str(trial).zfill(2)}_playerA.png",
                    "image2": f"heatmapOn_trajOn_Pair-{pair}-{category}-EYE_trial{str(trial).zfill(2)}_playerB.png",
                    "class": classes[category]
                })

# Convert to JSON format for viewing
json_data_generated = json.dumps(data_generated, indent=4)

# Save to a JSON file
with open("./image_gt.json", "w") as json_file:
    json_file.write(json_data_generated)

data_generated  # Display the first 1000 characters for preview


[{'image1': 'heatmapOn_trajOn_Pair-12-A-Single-EYE_trial01_player.png',
  'image2': 'heatmapOn_trajOn_Pair-12-A-Single-EYE_trial01_observer.png',
  'class': 'Single'},
 {'image1': 'heatmapOn_trajOn_Pair-12-B-Single-EYE_trial01_player.png',
  'image2': 'heatmapOn_trajOn_Pair-12-B-Single-EYE_trial01_observer.png',
  'class': 'Single'},
 {'image1': 'heatmapOn_trajOn_Pair-12-Comp-EYE_trial01_playerA.png',
  'image2': 'heatmapOn_trajOn_Pair-12-Comp-EYE_trial01_playerB.png',
  'class': 'Competition'},
 {'image1': 'heatmapOn_trajOn_Pair-12-Coop-EYE_trial01_playerA.png',
  'image2': 'heatmapOn_trajOn_Pair-12-Coop-EYE_trial01_playerB.png',
  'class': 'Cooperation'},
 {'image1': 'heatmapOn_trajOn_Pair-12-A-Single-EYE_trial02_player.png',
  'image2': 'heatmapOn_trajOn_Pair-12-A-Single-EYE_trial02_observer.png',
  'class': 'Single'},
 {'image1': 'heatmapOn_trajOn_Pair-12-B-Single-EYE_trial02_player.png',
  'image2': 'heatmapOn_trajOn_Pair-12-B-Single-EYE_trial02_observer.png',
  'class': 'Single'}

## Step 3. Load Dataset

In [19]:
dataset = load_dataset("json", data_files="./image_gt.json", split="train")
datasets = dataset.train_test_split(test_size=0.1)
datasets

DatasetDict({
    train: Dataset({
        features: ['image1', 'image2', 'class'],
        num_rows: 4176
    })
    test: Dataset({
        features: ['image1', 'image2', 'class'],
        num_rows: 464
    })
})

## Step 4. Preprocess Data 

In [20]:
import torch
from transformers import ViTImageProcessor, BertTokenizer, VisionEncoderDecoderModel


#model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")

def process_function(example_batch):
    # 確保 batch 中的資料長度一致
    assert len(example_batch['image1']) == len(example_batch['class']), "Batch data length mismatch"

    # 批次處理圖片
    images = []
    valid_indices = []
    for idx, image_name in enumerate(example_batch['image1']):
        img_path = "./heatmapOn_trajOn/" + image_name
        try:
            image = Image.open(img_path)
            images.append(image_processor(image, return_tensors="pt").pixel_values.squeeze())
            valid_indices.append(idx)
        except Exception as e:
            print(f"Error processing image {image_name}: {e}")
            continue

    # 保留與成功處理的圖片對應的文本
    filtered_texts = [example_batch["class"][i] for i in valid_indices]
    labels = tokenizer(filtered_texts, return_tensors="pt", padding="max_length", max_length=8)
    #inputs = tokenizer(filtered_texts, return_tensors="pt", padding="max_length", truncation=True, max_length=8)

    # 檢查是否圖片與文本長度一致
    if len(images) != len(labels):
        print("Skipping batch due to length mismatch between images and input_ids.")
        return {}  # 返回空字典來跳過這個批次

    # 回傳處理結果
    example_batch["pixel_values"] = torch.stack(images)  # 堆疊成 tensor
    example_batch["input_ids"] = labels
    #example_batch["attention_mask"] = inputs.attention_mask"  # 添加attention_mask
    return example_batch

tokenized_datasets = datasets.map(process_function, batched=True, batch_size=8, remove_columns=datasets["train"].column_names)
tokenized_datasets

Map:   0%|                              | 8/4176 [00:00<01:57, 35.38 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   0%|                             | 16/4176 [00:00<02:00, 34.63 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   1%|▏                            | 24/4176 [00:00<02:00, 34.55 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   1%|▏                            | 32/4176 [00:00<01:59, 34.81 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   1%|▎                            | 40/4176 [00:01<01:58, 34.76 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   1%|▎                            | 48/4176 [00:01<01:58, 34.87 examples/s]

Skipping batch due to length mismatch between images and input_ids.


Map:   2%|▍                            | 64/4176 [00:01<01:53, 36.29 examples/s]

Skipping batch due to length mismatch between images and input_ids.
Error processing image heatmapOn_trajOn_Pair-18-B-Single-EYE_trial12_player.png: [Errno 2] No such file or directory: './heatmapOn_trajOn/heatmapOn_trajOn_Pair-18-B-Single-EYE_trial12_player.png'
Skipping batch due to length mismatch between images and input_ids.



Map:   2%|▌                            | 72/4176 [00:02<01:54, 35.92 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   2%|▌                            | 80/4176 [00:02<01:53, 35.97 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   2%|▌                            | 88/4176 [00:02<01:54, 35.59 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   2%|▋                            | 96/4176 [00:02<01:54, 35.54 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   2%|▋                           | 104/4176 [00:02<01:55, 35.32 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   3%|▊                           | 112/4176 [00:03<01:54, 35.39 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   3%|▊                           | 120/4176 [00:03<01:54, 35.57 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   3%|▊                           | 128/4176 [00:03<01:54, 35.44 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   3%|▉                           | 136/4176 [00:03<01:54, 35.35 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   3%|▉                           | 144/4176 [00:04<01:54, 35.32 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   4%|█                           | 152/4176 [00:04<01:54, 35.03 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   4%|█                           | 160/4176 [00:04<01:56, 34.55 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   4%|█▏                          | 168/4176 [00:04<01:55, 34.78 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   4%|█▏                          | 176/4176 [00:05<01:57, 34.16 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   4%|█▏                          | 184/4176 [00:05<01:56, 34.34 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   5%|█▎                          | 192/4176 [00:05<01:55, 34.47 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   5%|█▎                          | 200/4176 [00:05<01:57, 33.84 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   5%|█▍                          | 208/4176 [00:05<01:56, 34.07 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   5%|█▍                          | 216/4176 [00:06<01:54, 34.73 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   5%|█▌                          | 224/4176 [00:06<01:53, 34.86 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   6%|█▌                          | 232/4176 [00:06<01:53, 34.79 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   6%|█▌                          | 240/4176 [00:06<01:53, 34.66 examples/s]

Skipping batch due to length mismatch between images and input_ids.



Map:   6%|█▋                          | 248/4176 [00:07<01:55, 34.11 examples/s]

Skipping batch due to length mismatch between images and input_ids.


Map:   6%|█▋                          | 256/4176 [00:07<01:55, 33.96 examples/s]

Skipping batch due to length mismatch between images and input_ids.





KeyboardInterrupt: 

## Step5. Load Model

In [11]:
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")


VisionEncoderDecoderModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


## Step 6. Create Training Arguments

In [15]:
from transformers import Trainer, TrainingArguments


# 載入模型、tokenizer、image processor
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
#tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
#image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# 設置模型生成參數
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# 設置訓練參數
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    save_steps=10_000,
    save_total_limit=2,
    push_to_hub=False,  # 禁用自動登錄到 Hugging Face Hub
    report_to="none"  # 禁用所有外部服務的報告
    #logging_dir='./logs',
    #logging_steps=500,
)

# 初始化Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
)

# 開始訓練
trainer.train()




ValueError: You have to specify either input_ids or inputs_embeds

In [16]:
print(tokenized_datasets["train"].column_names)

['pixel_values', 'input_ids', 'attention_mask']
