In [1]:
# params
test_valid_percentage = 20 # (test - 10, valid - 10)

train_data_percentage = 20
valid_data_percentage = 40
test_data_percentage = 40

In [2]:
import os
import re
import json
import torch
import numpy as np
import pandas as pd

In [3]:
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor,AutoTokenizer
os.environ["WANDB_DISABLED"] = "true"

In [4]:
import nltk
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    nltk.download("punkt", quiet=True)

## Initialize VisionEncoderDecoderModelPermalink

In [5]:
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor

image_encoder_model = "google/vit-base-patch16-224-in21k"
text_decode_model = "gpt2"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(image_encoder_model, text_decode_model)

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.9.crossattention.q_attn.weight', 'h.9.crossattention.c_attn.weight', 'h.11.crossattention.q_attn.bias', 'h.4.crossattention.q_attn.bias', 'h.7.ln_cross_attn.bias', 'h.5.crossattention.q_attn.bias', 'h.5.ln_cross_attn.weight', 'h.3.ln_cross_attn.weight', 'h.7.crossattention.c_attn.bias', 'h.2.crossattention.c_proj.weight', 'h.6.crossattention.c_proj.weight', 'h.0.ln_cross_attn.weight', 'h.3.crossattention.c_attn.bias', 'h.1.ln_cross_attn.bias', 'h.8.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.11.crossattention.c_proj.weight', 'h.6.ln_cross_attn.bias', 'h.3.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.5.crossattention.c_attn.bias', 'h.2.crossattention.q_attn.weight', 'h.0.crossattention.q_attn.weight', 'h.7.crossattention.q_attn.weight', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_at

In [6]:
# image feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
# text tokenizer
tokenizer = AutoTokenizer.from_pretrained(text_decode_model)



In [7]:
# GPT2 only has bos/eos tokens but not decoder_start/pad tokens
tokenizer.pad_token = tokenizer.eos_token

# update the model config
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

In [8]:
output_dir = "vit-gpt-model"
model.save_pretrained(output_dir)
feature_extractor.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

('vit-gpt-model/tokenizer_config.json',
 'vit-gpt-model/special_tokens_map.json',
 'vit-gpt-model/vocab.json',
 'vit-gpt-model/merges.txt',
 'vit-gpt-model/added_tokens.json',
 'vit-gpt-model/tokenizer.json')

## Data Loading and Preparation

In [9]:
# file paths data
image_dir = './all_data/train/radiology/images/'
data_file = './all_data/train/radiology/traindata.csv'

In [10]:
data = pd.read_csv(data_file)
data

Unnamed: 0,id,name,caption
0,ROCO_00002,PMC4083729_AMHSR-4-14-g002.jpg,Computed tomography scan in axial view showin...
1,ROCO_00003,PMC2837471_IJD2009-150251.001.jpg,Bacterial contamination occurred after comple...
2,ROCO_00004,PMC2505281_11999_2007_30_Fig6_HTML.jpg,The patient had residual paralysis of the han...
3,ROCO_00005,PMC3745845_IJD2013-683423.005.jpg,Panoramic radiograph after immediate loading.\n
4,ROCO_00007,PMC4917066_amjcaserep-17-301-g001.jpg,Plain abdomen x-ray: Multiple air levels at t...
...,...,...,...
65445,ROCO_81819,PMC3517833_CRIM.HEMATOLOGY2012-490438.001.jpg,Initial CT abdomen with contrast showing a di...
65446,ROCO_81820,PMC5487234_rb-50-03-0190-g13.jpg,44-year-old male patient after surgical amput...
65447,ROCO_81821,PMC2974222_kjr-11-612-g001.jpg,Primary pulmonary tuberculosis in 18-year-old...
65448,ROCO_81822,PMC3532764_AJNS-7-151-g002.jpg,"MRI brain with gadolinium, coronal view, show..."


In [11]:
# Replace column name 'name' with 'image_path'
data['image_path'] = data.pop('name')

# Prepend 'image_dir' to all entries in 'image_path' column
data['image_path'] = image_dir + data['image_path']

data

Unnamed: 0,id,caption,image_path
0,ROCO_00002,Computed tomography scan in axial view showin...,./all_data/train/radiology/images/PMC4083729_A...
1,ROCO_00003,Bacterial contamination occurred after comple...,./all_data/train/radiology/images/PMC2837471_I...
2,ROCO_00004,The patient had residual paralysis of the han...,./all_data/train/radiology/images/PMC2505281_1...
3,ROCO_00005,Panoramic radiograph after immediate loading.\n,./all_data/train/radiology/images/PMC3745845_I...
4,ROCO_00007,Plain abdomen x-ray: Multiple air levels at t...,./all_data/train/radiology/images/PMC4917066_a...
...,...,...,...
65445,ROCO_81819,Initial CT abdomen with contrast showing a di...,./all_data/train/radiology/images/PMC3517833_C...
65446,ROCO_81820,44-year-old male patient after surgical amput...,./all_data/train/radiology/images/PMC5487234_r...
65447,ROCO_81821,Primary pulmonary tuberculosis in 18-year-old...,./all_data/train/radiology/images/PMC2974222_k...
65448,ROCO_81822,"MRI brain with gadolinium, coronal view, show...",./all_data/train/radiology/images/PMC3532764_A...


In [12]:
# LLM output
f = open('./all_data/llm_result.txt', "r")
contents = f.read()
contents = contents.replace("\n", "")
json_data = json.loads(contents)

llm_df = pd.DataFrame(json_data)

llm_df = llm_df.drop('index', axis=1)

llm_df = llm_df[llm_df['relationship'].apply(lambda x: re.search(r'\w', str(x)) is not None)]
llm_df = llm_df.reset_index(drop=True)

llm_df

Unnamed: 0,id,relationship,summary
0,ROCO_00002,with a mass of homogeneous attenuation (necro...,The patient has undergone a CT scan which sho...
1,ROCO_00003,The patient developed pain and discomfort on ...,The patient developed pain and discomfort on ...
2,ROCO_00007,The given UMLS semantic types are 1. Radiogra...,There is no single UMLS semantic type that re...
3,ROCO_00008,or edema. There are no other abnormalities.\n...,"as follows: 1. Intellectual Product , 2. Phar..."
4,ROCO_00009,The patient was treated with K-wire and cannu...,\nThe UMLS concept DIAGNOSIS has many children...
...,...,...,...
796,ROCO_01327,Pseudoaneurysm is a Pathologic Function with ...,Pseudoaneurysm is a Pathologic Function with ...
797,ROCO_01329,The answer is Causative Agent: Bacterium.,\nCausative Agent: Bacterium
798,ROCO_01330,and ascites with dilated veins at the periphe...,No new information.
799,ROCO_01332,".\nThe UMLS semantic types are 1. Body Part, O...",The given UMLS semantic types are 1. Diagnost...


In [13]:
# Filter ones that have been training by the LLM
data = data.merge(llm_df, on='id')
data

Unnamed: 0,id,caption,image_path,relationship,summary
0,ROCO_00002,Computed tomography scan in axial view showin...,./all_data/train/radiology/images/PMC4083729_A...,with a mass of homogeneous attenuation (necro...,The patient has undergone a CT scan which sho...
1,ROCO_00003,Bacterial contamination occurred after comple...,./all_data/train/radiology/images/PMC2837471_I...,The patient developed pain and discomfort on ...,The patient developed pain and discomfort on ...
2,ROCO_00007,Plain abdomen x-ray: Multiple air levels at t...,./all_data/train/radiology/images/PMC4917066_a...,The given UMLS semantic types are 1. Radiogra...,There is no single UMLS semantic type that re...
3,ROCO_00008,A 3-year-old child with visual difficulties. ...,./all_data/train/radiology/images/PMC4805615_1...,or edema. There are no other abnormalities.\n...,"as follows: 1. Intellectual Product , 2. Phar..."
4,ROCO_00009,Showing the subtrochanteric fracture in the p...,./all_data/train/radiology/images/PMC2584650_1...,The patient was treated with K-wire and cannu...,\nThe UMLS concept DIAGNOSIS has many children...
...,...,...,...,...,...
796,ROCO_01327,Control angiography showed total exclusion of...,./all_data/train/radiology/images/PMC4396546_C...,Pseudoaneurysm is a Pathologic Function with ...,Pseudoaneurysm is a Pathologic Function with ...
797,ROCO_01329,Abdominal CT finding. Enterocutaneous fistula...,./all_data/train/radiology/images/PMC4316223_i...,The answer is Causative Agent: Bacterium.,\nCausative Agent: Bacterium
798,ROCO_01330,(Case 2) Post operative CT scan showing persi...,./all_data/train/radiology/images/PMC3589860_J...,and ascites with dilated veins at the periphe...,No new information.
799,ROCO_01332,Post-operative chest X-ray image of the same ...,./all_data/train/radiology/images/PMC4262879_W...,".\nThe UMLS semantic types are 1. Body Part, O...",The given UMLS semantic types are 1. Diagnost...


In [14]:
from sklearn.model_selection import train_test_split

# Split data into train, test, and valid datasets
train_data, valid_test_data = train_test_split(data, test_size=test_valid_percentage/100, random_state=42)
valid_data, test_data = train_test_split(valid_test_data, test_size=0.5, random_state=42)

# Reset index
train_data = train_data.reset_index(drop=True)
valid_data = valid_data.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)

In [15]:
print("Train data shape: ", train_data.shape)
print("Valid data shape: ", valid_data.shape)
print("Test data shape: ", test_data.shape)

Train data shape:  (640, 5)
Valid data shape:  (80, 5)
Test data shape:  (81, 5)


In [16]:
# Select n% of data
train_data = train_data.sample(frac=train_data_percentage/100, random_state=42)
valid_data = valid_data.sample(frac=valid_data_percentage/100, random_state=42)
test_data = test_data.sample(frac=test_data_percentage/100, random_state=42)

In [19]:
# Reset index
train_data = train_data.reset_index(drop=True)
valid_data = valid_data.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)

print("Train data shape: ", train_data.shape)
print("Valid data shape: ", valid_data.shape)
print("Test data shape: ", test_data.shape)

Train data shape:  (128, 5)
Valid data shape:  (32, 5)
Test data shape:  (32, 5)


In [17]:
from datasets import Dataset, DatasetDict

# Convert DataFrame to Hugging Face dataset dictionary format
train_data_dict = Dataset.from_pandas(train_data)
valid_data_dict = Dataset.from_pandas(valid_data)
test_data_dict = Dataset.from_pandas(test_data)

dataset_dict = DatasetDict({
    'train': train_data_dict,
    'validation': valid_data_dict,
    'test': test_data_dict
})

print(dataset_dict)

DatasetDict({
    train: Dataset({
        features: ['id', 'caption', 'image_path', 'relationship', 'summary', '__index_level_0__'],
        num_rows: 128
    })
    validation: Dataset({
        features: ['id', 'caption', 'image_path', 'relationship', 'summary', '__index_level_0__'],
        num_rows: 32
    })
    test: Dataset({
        features: ['id', 'caption', 'image_path', 'relationship', 'summary', '__index_level_0__'],
        num_rows: 32
    })
})


In [None]:
# from PIL import Image

# # text preprocessing step
# def tokenization_fn(captions, max_target_length):
#     """Run tokenization on captions."""
#     labels = tokenizer(captions, 
#                       padding="max_length", 
#                       max_length=max_target_length).input_ids

#     return labels

# # image preprocessing step
# def preprocess_images(image_paths):
#     processed_images = []
#     for image_path in image_paths:
#         image = Image.open(image_path)
#         if image.mode != "RGB":
#             image = image.convert("RGB")
#         processed_images.append(image)
#     return processed_images

# def feature_extraction_fn(image_paths, check_image=True):
#     if check_image:
#         images = preprocess_images(image_paths)
#     else:
#         images = [Image.open(image_file) for image_file in image_paths]

#     encoder_inputs = feature_extractor(images=images, return_tensors="np")

#     return encoder_inputs.pixel_values

# # def feature_extraction_fn(image_paths, check_image=True):
# #     """
# #     Run feature extraction on images
# #     If `check_image` is `True`, the examples that fails during `Image.open()` will be caught and discarded.
# #     Otherwise, an exception will be thrown.
# #     """
# #     model_inputs = {}

# #     if check_image:
# #         images = []
# #         to_keep = []
# #         for image_file in image_paths:
# #             try:
# #                 img = Image.open(image_file)
# #                 images.append(img)
# #                 to_keep.append(True)
# #             except Exception:
# #                 to_keep.append(False)
# #     else:
# #         images = [Image.open(image_file) for image_file in image_paths]

# #     encoder_inputs = feature_extractor(images=images, return_tensors="np")

# #     return encoder_inputs.pixel_values

# def preprocess_fn(examples, max_target_length, check_image = True):
#     """Run tokenization + image feature extraction"""
#     image_paths = examples['image_path']
#     captions = examples['caption']    
    
#     model_inputs = {}
#     # This contains image path column
#     model_inputs['labels'] = tokenization_fn(captions, max_target_length)
#     model_inputs['pixel_values'] = feature_extraction_fn(image_paths, check_image=check_image)

#     return model_inputs

In [None]:
from PIL import Image

class ImageCapatioingDataset(torch.utils.data.Dataset):
    def __init__(self, ds, ds_type, max_target_length):
        self.ds = ds
        self.max_target_length = max_target_length
        self.ds_type = ds_type

    def __getitem__(self, idx):
        image_path = self.ds[self.ds_type]['image_path'][idx]
        caption = self.ds[self.ds_type]['caption'][idx]
        model_inputs = dict()
        model_inputs['labels'] = self.tokenization_fn(caption, self.max_target_length)
        model_inputs['pixel_values'] = self.feature_extraction_fn(image_path)
        return model_inputs

    def __len__(self):
        return len(self.ds[self.ds_type])
    
    # text preprocessing step
    def tokenization_fn(self, caption, max_target_length):
        """Run tokenization on caption."""
        labels = tokenizer(caption, 
                          padding="max_length", 
                          max_length=max_target_length).input_ids

        return labels

    # image preprocessing step
    def feature_extraction_fn(self, image_path):
        image = Image.open(image_path)
        if image.mode != "RGB":
            image = image.convert("RGB")

        encoder_inputs = feature_extractor(images=image, return_tensors="np")

        return encoder_inputs.pixel_values[0]

In [None]:
# processed_dataset = dataset_dict.map(
#     function=preprocess_fn,
#     batched=True,
#     fn_kwargs={"max_target_length": 128},
#     remove_columns=dataset_dict['train'].column_names
# )

In [None]:
# processed_dataset

In [None]:
train_ds = ImageCapatioingDataset(dataset_dict, 'train', 64)
eval_ds = ImageCapatioingDataset(dataset_dict, 'validation', 64)

## Define seq2seq training argumentsPermalink

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    output_dir="./image-captioning-output",
)

## Define metric

In [None]:
import evaluate
metric = evaluate.load("rouge")

In [None]:
import numpy as np

ignore_pad_token_for_loss = True

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds,
                                                     decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result

## Training

In [None]:
from transformers import default_data_collator

# # instantiate trainer
# trainer = Seq2SeqTrainer(
#     model=model,
#     tokenizer=feature_extractor,
#     args=training_args,
#     compute_metrics=compute_metrics,
#     train_dataset=processed_dataset['train'],
#     eval_dataset=processed_dataset['validation'],
#     data_collator=default_data_collator,
# )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds, 
    eval_dataset=eval_ds, 
    data_collator=default_data_collator,
)


In [None]:
trainer.train()

In [None]:
trainer.save_model("./image-captioning-output")
tokenizer.save_pretrained("./image-captioning-output")

## Inference

In [None]:
from transformers import pipeline
image_captioner = pipeline("image-to-text", model="./image-captioning-output")

In [None]:
for index, row in test_data.iterrows():
    caption = row['caption']
    image_path = row['image_path']
    print(caption)
    print()
    image_captioner("sample_image.png")
    print("\n----------\n")

In [None]:

# # Define your ImageCaptionDataset class for loading and preprocessing the data
# class ImageCaptionDataset(Dataset):
#     def __init__(self, image_paths, captions, transform):
#         self.image_paths = image_paths
#         self.captions = captions
#         self.transform = transform

#     def __len__(self):
#         return len(self.image_paths)

#     def __getitem__(self, index):
#         image_path = self.image_paths[index]
#         caption = self.captions[index]
#         image = self.transform(image_path)
#         return image, caption

# # Set your hyperparameters and configurations
# image_dir = "path/to/images"
# caption_file = "path/to/captions.txt"
# batch_size = 32
# embedding_dim = 256
# hidden_dim = 512
# learning_rate = 0.001
# num_epochs = 10

# # Load and preprocess your dataset
# def preprocess_image(image_path):
#     image = Image.open(image_path).convert("RGB")
#     image = image.resize((224, 224))
#     image = transforms.ToTensor()(image)
#     image = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))(image)
#     return image

# image_paths = [...]  # List of image file paths
# captions = [...]  # List of corresponding captions

# transform = preprocess_image
# dataset = ImageCaptionDataset(image_paths, captions, transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# # Initialize the ViT feature extractor and tokenizer
# feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
# tokenizer = AutoTokenizer.from_pretrained("google/vit-base-patch16-224")

# # Initialize the ViT-based image captioning model
# model = ViTForImageCaptioning.from_pretrained("google/vit-base-patch16-224")

# # Define your loss function and optimizer
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)