In [None]:
import os

# Select Gaudi module 0 via HLS_MODULE_ID, before the Habana/optimum imports below run.
# NOTE(review): presumably this pins the notebook to a single Gaudi card — confirm in Habana docs.
os.environ.update({"HLS_MODULE_ID": "0"})

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torchvision.io import ImageReadMode, read_image
from transformers.models.bridgetower.modeling_bridgetower import BridgeTowerForContrastiveLearning
from transformers import Trainer, TrainingArguments
from transformers import AutoImageProcessor
import torch
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode, to_grayscale, to_tensor

from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments

# Set up the data

In [None]:
# Load the "matching" configuration of the New Yorker caption contest dataset from the Hub
DATASET_ID = "jmhessel/newyorker_caption_contest"
DATASET_CONFIG = "matching"
dset = load_dataset(DATASET_ID, DATASET_CONFIG)

In [None]:
# Get the training dataset (the raw "train" split of the DatasetDict loaded above)
train_dataset = dset["train"]

In [None]:
# Inspect the raw train split (its features and number of rows)
dset["train"]

In [None]:
# Peek at one raw training example: print its text description, then display the cartoon.
k = 199
sample_row = dset["train"][k]
print(sample_row["image_description"])
sample_row["image"]

# Load pre-trained model and tokenizer

In [None]:
# Load the pre-trained contrastive model and its tokenizer from the same Hub checkpoint
model_name_or_path = "BridgeTower/bridgetower-large-itm-mlm-itc"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = BridgeTowerForContrastiveLearning.from_pretrained(model_name_or_path)

# Tokenize the train dataset captions

In [None]:
# Tokenize captions using the tokenizer
def tokenize_captions(examples):
    """Add fixed-length token ids and attention masks for a batch of captions.

    Meant for ``Dataset.map(batched=True)``: ``examples`` maps column names to
    per-batch values; the "image_description" texts are tokenized (padded and
    truncated to 128 tokens) and the batch is returned with two new columns,
    "input_ids" and "attention_mask".
    """
    encoded = tokenizer(
        list(examples["image_description"]),
        max_length=128,
        padding="max_length",
        truncation=True,
    )
    examples["input_ids"], examples["attention_mask"] = encoded.input_ids, encoded.attention_mask
    return examples

In [None]:
# Tokenize captions of the whole train split.
# Map over the raw split in `dset` rather than over `train_dataset`, so re-running this
# cell is idempotent: after a first run `train_dataset` no longer has the
# "image_description" column that `tokenize_captions` reads, and mapping it again fails.
train_dataset = dset["train"].map(
    function=tokenize_captions,
    batched=True,
    # Keep only "image"; "input_ids"/"attention_mask" are added by tokenize_captions.
    remove_columns=[col for col in dset["train"].column_names if col != "image"],
    desc="Running tokenizer on train dataset",
)

In [None]:
# Inspect the tokenized dataset: "image" plus the new "input_ids" and "attention_mask" columns
train_dataset

# Image preprocessing

In [None]:
# Convert image to grayscale and tensor
def get_image(image_or_path):
    """Return a 3-channel grayscale float tensor for a PIL image.

    Despite the parameter name, ``to_grayscale`` only accepts PIL images, so a
    file path is not handled here. The gray channel is replicated three times so
    the result has three channels like an RGB image; ``to_tensor`` then yields a
    float (C, H, W) tensor with values in [0, 1].
    """
    grayscale_pil = to_grayscale(image_or_path, num_output_channels=3)
    return to_tensor(grayscale_pil)

In [None]:
# Preprocess an image tensor: Resize, CenterCrop, ConvertImageDtype and Normalize
class Transform(torch.nn.Module):
    """Resize, center-crop, cast to float and normalize an image tensor.

    The steps are chained in an ``nn.Sequential`` of torchvision transforms.
    ``Normalize`` only accepts tensors, so the input must already be a tensor
    (e.g. the output of ``get_image``), not a PIL image.

    Args:
        image_size: target side length; the shorter edge is resized to this value
            and the result is center-cropped to ``image_size x image_size``.
        mean: per-channel means passed to ``Normalize``.
        std: per-channel standard deviations passed to ``Normalize``.
    """

    def __init__(self, image_size, mean, std):
        super().__init__()
        self.transforms = torch.nn.Sequential(
            # A one-element size list resizes the shorter edge, keeping the aspect ratio.
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            # No-op for float input such as get_image's output; casts integer tensors to float.
            ConvertImageDtype(torch.float),
            Normalize(mean, std),
        )

    def forward(self, x) -> torch.Tensor:
        """Apply the preprocessing pipeline to ``x``.

        ``x`` must be an image tensor — assumed (C, H, W) as produced by
        ``get_image``. A ``PIL.Image.Image`` is not accepted, because
        ``ConvertImageDtype`` and ``Normalize`` operate on tensors only.
        """
        # Preprocessing is not part of the trained model; skip autograd bookkeeping.
        with torch.no_grad():
            x = self.transforms(x)
        return x

In [None]:
# The image processor is only used for its normalization statistics (mean/std);
# the actual preprocessing is done by the Transform pipeline built here.
image_processor = AutoImageProcessor.from_pretrained(model_name_or_path)
image_size = model.config.vision_config.image_size

image_transformations = Transform(
    image_size=image_size,
    mean=image_processor.image_mean,
    std=image_processor.image_std,
)

In [None]:
# Apply image transformations to the images in the examples
def transform_images(examples):
    """Batched transform for ``Dataset.set_transform``.

    Converts every entry of ``examples["image"]`` with ``get_image`` and runs it
    through ``image_transformations``, storing the resulting tensors under the
    new "pixel_values" key. The same batch object is returned.
    """
    examples["pixel_values"] = [
        image_transformations(get_image(raw_image)) for raw_image in examples["image"]
    ]
    return examples

In [None]:
# Register transform_images to run lazily on each batch whenever rows are accessed
# (__getitem__), instead of preprocessing and storing every image up front.
train_dataset.set_transform(transform=transform_images)

# Training configuration

In [None]:
# The function to use to form a batch from a list of elements of train_dataset
def collate_fn(examples):
    """Collate dataset rows into keyword arguments for the model's forward pass.

    Each row must carry a "pixel_values" tensor (added by the image transform)
    plus "input_ids" and "attention_mask" sequences whose lengths match across
    rows (the tokenizer pads them to a fixed max_length). ``return_loss=True``
    asks the model to compute its contrastive loss.
    """
    def as_long_tensor(field):
        # Assumed to be Python lists of ints here; batch them as int64.
        return torch.tensor([row[field] for row in examples], dtype=torch.long)

    batch = {"pixel_values": torch.stack([row["pixel_values"] for row in examples])}
    batch["input_ids"] = as_long_tensor("input_ids")
    batch["attention_mask"] = as_long_tensor("attention_mask")
    batch["return_loss"] = True
    return batch

In [None]:
# Define training arguments

num_train_epochs = 1
per_device_train_batch_size = 8
use_lazy_mode = True  # run the HPU in Habana lazy mode
learning_rate = 5e-05
seed = 42  # equals the TrainingArguments default; stated explicitly for reproducibility

training_args = GaudiTrainingArguments(
    output_dir="test_trainer",
    # Keep the raw "image" column: transform_images builds pixel_values from it on the fly,
    # and "image" is not an argument of the model's forward pass, so it would otherwise be dropped.
    remove_unused_columns=False,
    use_habana=True,
    use_lazy_mode=use_lazy_mode,
    use_hpu_graphs_for_inference=True,
    gaudi_config_name="Habana/clip",  # Gaudi config loaded from the Hugging Face Hub below
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    learning_rate=learning_rate,
    seed=seed,
    report_to=[],  # no experiment-tracking integrations
    logging_steps=50,
)

In [None]:
# Show the full resolved training configuration
training_args

In [None]:
# Load the Gaudi (HPU) configuration named in the training arguments
gaudi_config = GaudiConfig.from_pretrained(training_args.gaudi_config_name)

In [None]:
# Initialize the trainer: GaudiTrainer runs the training loop on Habana HPUs.
trainer = GaudiTrainer(
    model=model,
    gaudi_config=gaudi_config,
    args=training_args,
    train_dataset=train_dataset,
    # Builds pixel_values / input_ids / attention_mask batches with return_loss=True.
    data_collator=collate_fn,
)

# Train the model 😃

In [None]:
# Train the model from scratch (no checkpoint is resumed)
train_result = trainer.train()

# Test the model

In [None]:
# First example used for the before/after comparison: its description and cartoon.
INDX_EXAMPLE_1 = 0
example_1_raw = dset["train"][INDX_EXAMPLE_1]
print(example_1_raw["image_description"])
example_1_raw["image"]

In [None]:
# Second example used for the before/after comparison: its description and cartoon.
INDX_EXAMPLE_2 = 5
example_2_raw = dset["train"][INDX_EXAMPLE_2]
print(example_2_raw["image_description"])
example_2_raw["image"]

In [None]:
# Score captions against images with the fine-tuned model.
# Set the device to HPU (the model is expected to live there after training).
device = torch.device("hpu")

ex_1 = train_dataset[INDX_EXAMPLE_1]  # Get example 1 from the train dataset
ex_2 = train_dataset[INDX_EXAMPLE_2]  # Get example 2 from the train dataset


def to_model_inputs(examples, device):
    """Collate preprocessed `train_dataset` rows into model inputs on `device`.

    Reuses `collate_fn` (the batching used for training) but drops its
    `return_loss` flag, since only the embeddings are needed here.
    """
    batch = collate_fn(examples)
    batch.pop("return_loss")
    return {name: tensor.to(device) for name, tensor in batch.items()}


# Example 1's caption paired with example 1's own image.
encoding_1 = to_model_inputs([ex_1], device)
# Both examples in one batch, so every caption can be scored against every image.
encoding_both = to_model_inputs([ex_1, ex_2], device)

fine_tuned_model = trainer.model
# Disable dropout for inference; the model may still be in training mode after trainer.train().
fine_tuned_model.eval()

# Pass the examples to the fine-tuned model (no gradients are needed for inference)
with torch.no_grad():
    outputs_1 = fine_tuned_model(**encoding_1)
    outputs_both = fine_tuned_model(**encoding_both)

# Dot product of the text and image embeddings returned by the model.
logits_text_to_image_1 = torch.matmul(outputs_1['text_embeds'], outputs_1['image_embeds'].t()).to('cpu')
print(f"logits_text_to_image_1: {logits_text_to_image_1}")

# Rows: captions of examples 1 and 2; columns: their images. After contrastive
# fine-tuning, the diagonal (matching pairs) is expected to score highest.
logits_text_to_image_both = torch.matmul(outputs_both['text_embeds'], outputs_both['image_embeds'].t()).to('cpu')
print(f"logits_text_to_image_both:\n{logits_text_to_image_both}")
print(f"logits_text_to_image_1: {logits_text_to_image_1}")

In [None]:
# Pass the same example to the baseline (pre-trained, not fine-tuned) model.
# `model` is the same object the trainer fine-tuned in place, so reload the original
# checkpoint to get a true baseline for comparison.
baseline_model = BridgeTowerForContrastiveLearning.from_pretrained(model_name_or_path).to(device)
baseline_model.eval()  # disable dropout for inference

with torch.no_grad():
    baseline_outputs_1 = baseline_model(**encoding_1)
baseline_logits_text_to_image_1 = torch.matmul(
    baseline_outputs_1['text_embeds'], baseline_outputs_1['image_embeds'].t()
).to('cpu')
print(f"baseline logits_text_to_image_1: {baseline_logits_text_to_image_1}")