In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.transforms import v2 as transforms
#from src.data import MultiModalH5PyDataset
from torchmultimodal.modules.losses.contrastive_loss_with_temperature import ContrastiveLossWithTemperature

from omegaconf import OmegaConf
import lightning as pl
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.utilities import rank_zero_only


from transformers import (
    VisionTextDualEncoderModel,
    VisionTextDualEncoderProcessor,
    AutoImageProcessor,
    AutoTokenizer,
    BertConfig,
    ViTConfig,
    VisionTextDualEncoderConfig
)

from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [82]:
class CustomDataset(Dataset):
    """Image/caption pairs described by a JSON-lines split file.

    Every record of the split file must carry a ``'filename'`` entry (a string,
    or a list of strings of which the first is used) and a caption stored under
    one of the keys in ``_CAPTION_KEYS``.

    Args:
        split_file: Path of the JSON-lines file, read with ``load_json``.
        image_dir: Directory that the record filenames are relative to.
        processor: Optional callable invoked as
            ``processor(images=image, return_tensors="pt")``; its
            ``pixel_values`` (without the batch dimension) become the image.
        transform: Optional callable applied to the image after the processor.
    """

    # Caption keys tried in order: 'sentences 0' is the key this class has
    # always read, 'sentence 1' is the key written by ``split_sentences`` in
    # this notebook, and 'sentences' is the raw list-of-captions column.
    _CAPTION_KEYS = ('sentences 0', 'sentence 1', 'sentences')

    def __init__(self, split_file, image_dir, processor=None, transform=None):
        self.split = load_json(split_file)
        self.image_dir = image_dir
        self.processor = processor
        self.transform = transform

    def __len__(self):
        """Return the number of records in the split."""
        return len(self.split)

    @classmethod
    def _first_caption(cls, sample):
        """Return the first caption of ``sample``.

        Raises:
            KeyError: If none of the known caption keys is present.
        """
        for key in cls._CAPTION_KEYS:
            if key not in sample:
                continue
            value = sample[key]
            if isinstance(value, (list, tuple)):
                if not value:
                    # An empty caption list carries no caption; try the next key.
                    continue
                return value[0]
            return value
        raise KeyError(
            f"No caption found in sample; expected one of {cls._CAPTION_KEYS}, "
            f"got keys {sorted(sample)}"
        )

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for the record at position ``idx``."""
        sample = self.split[idx]
        img_filenames = sample['filename']
        img_filename = img_filenames[0] if isinstance(img_filenames, list) else img_filenames
        img_path = os.path.join(self.image_dir, img_filename)

        # Convert inside the context manager: ``convert`` returns an independent
        # image, so the underlying file handle can be released right away
        # instead of staying open for every sample that was ever loaded.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.processor:
            pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
            # Drop only the leading batch dimension added by the processor.
            img = pixel_values.squeeze(0)
        else:
            img = image

        if self.transform:
            img = self.transform(img)

        return img, self._first_caption(sample)

In [83]:
def get_datasets(dataset, data_dir, processor):
    """Create the training and validation ``CustomDataset`` objects.

    Args:
        dataset: Accepted for the caller's convenience; not read by this function.
        data_dir: Directory containing the ``train2014`` and ``val2014`` image folders.
        processor: Forwarded to each ``CustomDataset`` as its ``processor``.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple.
    """
    # The augmentation pipeline is built here but is not passed to the
    # datasets below (the ``transform`` argument is left at its default).
    augmentation_steps = [
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=.5, contrast=.2, saturation=.3, hue=.2),
        transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.)),
    ]
    image_transforms = transforms.Compose(augmentation_steps)

    def build_split(split_file, image_folder):
        # One dataset per split: split file in the working directory,
        # images below ``data_dir``.
        return CustomDataset(split_file, os.path.join(data_dir, image_folder), processor)

    train_dataset = build_split('coco_karpathy_train.json', "train2014")
    val_dataset = build_split('coco_karpathy_val.json', "val2014")
    return train_dataset, val_dataset

In [84]:
def get_dataloaders(train_dataset, val_dataset, **kwargs):
    """Wrap both datasets in ``DataLoader`` objects sharing the same options.

    Args:
        train_dataset: Dataset for the training loader.
        val_dataset: Dataset for the validation loader.
        **kwargs: Passed unchanged to both ``DataLoader`` constructors
            (for example ``batch_size``).

    Returns:
        A ``(train_dataloader, val_dataloader)`` tuple.
    """
    # No custom ``collate_fn`` is supplied unless the caller passes one in kwargs.
    loaders = tuple(
        DataLoader(split_dataset, **kwargs)
        for split_dataset in (train_dataset, val_dataset)
    )
    return loaders

In [5]:
def get_models(image_encoder_name, text_encoder_name):
    """Create a vision-text dual encoder and its matching processor.

    The model is constructed from default ``ViTConfig`` / ``BertConfig``
    objects rather than loaded with ``from_pretrained``; the two names are
    only used to load the image processor and the tokenizer.

    Args:
        image_encoder_name: Name passed to ``AutoImageProcessor.from_pretrained``.
        text_encoder_name: Name passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    dual_config = VisionTextDualEncoderConfig.from_vision_text_configs(
        vision_config=ViTConfig(),
        text_config=BertConfig(),
    )
    model = VisionTextDualEncoderModel(dual_config)

    processor = VisionTextDualEncoderProcessor(
        image_processor=AutoImageProcessor.from_pretrained(image_encoder_name),
        tokenizer=AutoTokenizer.from_pretrained(text_encoder_name, use_fast=False),
    )
    return model, processor

In [6]:
class LitMML(pl.LightningModule):
    """Lightning module that trains a vision-text dual encoder contrastively.

    Args:
        model: Module called as ``model(**tokens, pixel_values=images)`` whose
            output exposes ``image_embeds`` and ``text_embeds``.
        processor: Callable used as
            ``processor(text=..., return_tensors="pt", padding=True)`` to
            tokenize the captions.
        loss_cfg: Stored on the module; not read by the code in this class.
        optimizer_cfg: ``None`` (default Adam), an optimizer name string, or an
            object exposing ``name`` and optionally ``kwargs``.
        scheduler_cfg: ``None`` (no scheduler), a scheduler name string, or an
            object exposing ``name`` and optionally ``kwargs``.
    """

    # Optimizer names accepted in ``optimizer_cfg`` and the classes they select.
    _OPTIMIZERS = {
        "Adam": torch.optim.Adam,
        "SGD": torch.optim.SGD,
        "AdamW": torch.optim.AdamW,
        "Adagrad": torch.optim.Adagrad,
    }

    def __init__(
        self,
        model,
        processor,
        loss_cfg,
        optimizer_cfg,
        scheduler_cfg,
    ):
        super().__init__()

        self.model = model
        self.processor = processor

        self.loss_cfg = loss_cfg
        self.optimizer_cfg = optimizer_cfg
        self.scheduler_cfg = scheduler_cfg

        self.contrastive_loss = ContrastiveLossWithTemperature()

        self.save_hyperparameters(ignore=["model", "image_encoder", "text_encoder", "tokenizer"])

    def tokenize(self, text):
        """Tokenize ``text`` with padding and move the tensors to the module's device."""
        tokens = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)
        return tokens

    def forward(self, images, text):
        """Run the dual encoder on a batch of images and raw captions."""
        tokens = self.tokenize(text)
        return self.model(**tokens, pixel_values=images)

    def _contrastive_step(self, batch, log_name):
        """Compute the contrastive loss for ``batch`` and log it under ``log_name``."""
        images, text = batch

        tokens = self.tokenize(text)
        outputs = self.model(**tokens, pixel_values=images)

        # L2-normalise both embeddings before they are compared by the loss.
        image_out = torch.nn.functional.normalize(outputs.image_embeds, dim=-1)
        text_out = torch.nn.functional.normalize(outputs.text_embeds, dim=-1)

        loss = self.contrastive_loss(image_out, text_out).mean()
        self.log(log_name, loss, sync_dist=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one ``(images, text)`` batch; logged as ``loss/train``."""
        return self._contrastive_step(batch, "loss/train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one ``(images, text)`` batch; logged as ``loss/val``."""
        return self._contrastive_step(batch, "loss/val")

    @staticmethod
    def _name_and_kwargs(cfg):
        """Split a config value into ``(name, kwargs)``.

        ``None`` yields ``(None, {})``, a plain string is taken as the name
        with no kwargs, and any other object is read through its ``name`` and
        ``kwargs`` attributes. The kwargs are copied so that later ``pop``
        calls never modify the stored config.
        """
        if cfg is None:
            return None, {}
        if isinstance(cfg, str):
            return cfg, {}
        name = getattr(cfg, "name", None)
        kwargs = getattr(cfg, "kwargs", None)
        return name, (dict(kwargs) if kwargs else {})

    def configure_optimizers(self):
        """Build the optimizer, and optionally a scheduler, from the stored configs.

        Previously this method returned a default ``Adam`` on its first line,
        which left all config handling unreachable. That default is kept only
        for the case where no optimizer name is configured.

        Returns:
            The optimizer alone when no scheduler is configured, otherwise a
            dict with ``"optimizer"`` and ``"lr_scheduler"`` entries.

        Raises:
            ValueError: If the optimizer or scheduler name is not recognised.
        """
        optimizer_name, optimizer_kwargs = self._name_and_kwargs(self.optimizer_cfg)
        if optimizer_name is None:
            optimizer = torch.optim.Adam(self.parameters())
        elif optimizer_name in self._OPTIMIZERS:
            optimizer = self._OPTIMIZERS[optimizer_name](self.parameters(), **optimizer_kwargs)
        else:
            raise ValueError(
                f"Wrong optimizer name. Provided {optimizer_name} which doesn't exist"
            )

        scheduler_name, scheduler_kwargs = self._name_and_kwargs(self.scheduler_cfg)

        # check if scheduler is to be used
        if scheduler_name is None:
            print("No scheduler provided, using only optimizer")
            return optimizer

        if scheduler_name == "CosineAnnealingLR":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, **scheduler_kwargs
            )
        # The misspelled legacy names are still accepted so that existing
        # configs keep working; torch only provides the correctly spelled classes.
        elif scheduler_name in ("CosineAnnealingWarmRestarts", "CosineAnnealingLRWarmRestarts"):
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, **scheduler_kwargs
            )
        elif scheduler_name in ("ReduceLROnPlateau", "ReduceROnPlateau"):
            # ``monitor`` is consumed by Lightning, not by the scheduler itself.
            monitor_metric = scheduler_kwargs.pop("monitor")
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, **scheduler_kwargs
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "monitor": monitor_metric},
            }
        else:
            raise ValueError(
                f"Wrong scheduler name. Provided {scheduler_name} which doesn't exist"
            )

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
        }

In [7]:
# Create the dual-encoder model and the processor; the image processor and
# tokenizer are loaded from these two checkpoint names.
model, processor = get_models('google/vit-base-patch16-224', 'google-bert/bert-base-uncased')

In [85]:
# Load the "yerevann/coco-karpathy" dataset via `datasets.load_dataset`.
dataset = load_dataset("yerevann/coco-karpathy")
# Build the train/val datasets; images are read from below this local directory.
train_dataset, val_dataset = get_datasets(dataset, '/home/data/COCOcaptions', processor)

In [18]:
# Small batch size for quick interactive checks; both loaders share it.
train_dataloader, val_dataloader = get_dataloaders(train_dataset, val_dataset, batch_size=2)

In [10]:
# Instantiate the Lightning module for interactive use. Here loss_cfg and
# optimizer_cfg are plain strings and scheduler_cfg is None, not config objects.
net = LitMML(
        model,
        processor,
        loss_cfg='contrastive_like_clip',
        optimizer_cfg="AdamW",
        scheduler_cfg=None,
    )

In [91]:
# Display the dataset object's repr.
train_dataset

<__main__.CustomDataset at 0x7f161c66d4f0>

In [90]:
# Fetch the first batch from the training loader and display it.
# NOTE(review): the recorded output below is a dict with a 'pixel_values' key,
# while CustomDataset.__getitem__ returns an (img, caption) tuple — the output
# presumably comes from an earlier version of the loader; re-run to confirm.
batch = next(iter(train_dataloader))
batch

{'pixel_values': tensor([[[[-0.3569, -0.3725, -0.3176,  ...,  0.0196, -0.0039,  0.4588],
           [-0.3569, -0.3804, -0.3020,  ...,  0.0039,  0.0196,  0.4824],
           [-0.3882, -0.3961, -0.3255,  ...,  0.0667, -0.0196,  0.0902],
           ...,
           [-0.6392, -0.6784, -0.6863,  ...,  0.2392,  0.2314,  0.2000],
           [-0.6000, -0.6235, -0.6471,  ...,  0.2157,  0.2078,  0.2078],
           [-0.5059, -0.5137, -0.5529,  ...,  0.2157,  0.1765,  0.1922]],
 
          [[-0.6471, -0.6471, -0.6235,  ..., -0.2627, -0.2235,  0.1059],
           [-0.6941, -0.6863, -0.6863,  ..., -0.2471, -0.1922,  0.1608],
           [-0.6941, -0.6941, -0.7020,  ..., -0.1686, -0.2078, -0.1059],
           ...,
           [-0.8588, -0.8745, -0.8667,  ...,  0.3255,  0.2941,  0.2784],
           [-0.8431, -0.8353, -0.8196,  ...,  0.2863,  0.2784,  0.2706],
           [-0.7255, -0.7333, -0.7647,  ...,  0.2706,  0.2549,  0.2392]],
 
          [[-0.7882, -0.7804, -0.7647,  ..., -0.2549, -0.2235,  0.0196

In [31]:
# Caption (second element of the returned tuple) of the sample at index 300.
train_dataset[300][1]

'some cows walking across the grass and beside some trees '

In [32]:
# Separate ad-hoc loader over the training set with a batch size of 12.
train_loader = DataLoader(train_dataset, batch_size=12)

In [41]:
# Number of samples in the dataset behind the loader (not the number of batches).
len(train_loader.dataset)

82783

In [45]:
# Inspect the column schema of the train split.
dataset['train'].features

{'filepath': Value(dtype='string', id=None),
 'sentids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'filename': Value(dtype='string', id=None),
 'imgid': Value(dtype='int64', id=None),
 'split': Value(dtype='string', id=None),
 'sentences': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
 'cocoid': Value(dtype='int64', id=None),
 'url': Value(dtype='string', id=None)}

In [46]:
# Flattened copy of the train split; not referenced by any later visible cell.
flat_dataset = dataset['train'].flatten()

In [73]:
def split_sentences(sample):
    """Copy each caption of ``sample["sentences"]`` into its own column.

    The captions are stored under the 1-based keys ``"sentence 1"``,
    ``"sentence 2"``, ... The sample is updated in place and also returned.
    """
    for position, caption in enumerate(sample["sentences"], start=1):
        sample[f"sentence {position}"] = caption
    return sample

In [74]:
# Expand the caption list into one "sentence N" column per caption and drop
# the original "sentences" column.
new_dataset = dataset['train'].map(split_sentences, remove_columns=["sentences"])

In [76]:
# Persist the expanded train split to the local 'coco_karpathy' directory.
new_dataset.save_to_disk('coco_karpathy')

Saving the dataset (0/1 shards):   0%|          | 0/82783 [00:00<?, ? examples/s]

Saving the dataset (1/1 shards): 100%|██████████| 82783/82783 [00:01<00:00, 72267.23 examples/s]


In [68]:
# Export the expanded train split as a JSON file (one record per line, as the
# line-by-line read further down shows).
new_dataset.to_json('./coco_karpathy.json')

Creating json from Arrow format: 100%|██████████| 83/83 [00:02<00:00, 32.39ba/s]


48763993

In [77]:
# NOTE(review): import placed mid-notebook; consider moving it to the import cell.
from datasets import load_from_disk

# Reload the dataset that was saved with save_to_disk above.
my_dataset = load_from_disk('coco_karpathy')

In [78]:
# Display the reloaded dataset's columns and row count.
my_dataset

Dataset({
    features: ['filepath', 'sentids', 'filename', 'imgid', 'split', 'cocoid', 'url', 'sentence 1', 'sentence 2', 'sentence 3', 'sentence 4', 'sentence 5'],
    num_rows: 82783
})

In [62]:
def add_images(sample):
    """Open the training image named by ``sample['filename']`` and store it.

    The opened image is placed under ``sample['image']``; the sample is
    updated in place and also returned.
    """
    image_path = os.path.join('/home/data/COCOcaptions/train2014', sample['filename'])
    sample['image'] = Image.open(image_path)
    return sample

In [79]:
# Attach an opened image to every record.
# NOTE(review): the result of map() is not bound to a name here, and the
# recorded run below was stopped with a KeyboardInterrupt before finishing.
my_dataset.map(add_images)

Map:  16%|█▌        | 12959/82783 [01:44<09:21, 124.30 examples/s]


KeyboardInterrupt: 

In [67]:
# NOTE(review): `small_dataset` is not defined in any visible cell — presumably
# left over from a deleted cell; this fails on a fresh kernel. TODO confirm.
small_dataset

{'filepath': ['train2014', 'train2014', 'train2014', 'train2014', 'train2014'],
 'sentids': [[787980, 789366, 789888, 791316, 794853],
  [118034, 157682, 159179, 175826, 185540],
  [813979, 814735, 816446, 816950, 817379],
  [636271, 638152, 638287, 638446, 643555],
  [72465, 74946, 76104, 81696, 84192]],
 'filename': ['COCO_train2014_000000057870.jpg',
  'COCO_train2014_000000384029.jpg',
  'COCO_train2014_000000222016.jpg',
  'COCO_train2014_000000520950.jpg',
  'COCO_train2014_000000069675.jpg'],
 'imgid': [40504, 40505, 40506, 40507, 40508],
 'split': ['train', 'train', 'train', 'train', 'train'],
 'cocoid': [57870, 384029, 222016, 520950, 69675],
 'url': ['http://images.cocodataset.org/train2014/COCO_train2014_000000057870.jpg',
  'http://images.cocodataset.org/train2014/COCO_train2014_000000384029.jpg',
  'http://images.cocodataset.org/train2014/COCO_train2014_000000222016.jpg',
  'http://images.cocodataset.org/train2014/COCO_train2014_000000520950.jpg',
  'http://images.cocodata

In [80]:
import json

# Peek at the first record of the JSON-lines export to check its schema.
with open('coco_karpathy.json', 'r') as file:
    line = next(file, None)
    if line is not None:
        data = json.loads(line)
        print(data)

{'filepath': 'train2014', 'sentids': [787980, 789366, 789888, 791316, 794853], 'filename': 'COCO_train2014_000000057870.jpg', 'imgid': 40504, 'split': 'train', 'cocoid': 57870, 'url': 'http://images.cocodataset.org/train2014/COCO_train2014_000000057870.jpg', 'sentence 1': 'A restaurant has modern wooden tables and chairs.', 'sentence 2': 'A long restaurant table with rattan rounded back chairs.', 'sentence 3': 'a long table with a plant on top of it surrounded with wooden chairs ', 'sentence 4': 'A long table with a flower arrangement in the middle for meetings', 'sentence 5': 'A table is adorned with wooden chairs with blue accents.'}


In [55]:
# Smoke-test one training step outside of a Trainer. training_step requires
# a batch and a batch index, so a fresh batch is pulled from the loader here.
net.training_step(next(iter(train_dataloader)), batch_idx=0)

tensor(3.6100, grad_fn=<MeanBackward0>)

In [81]:
def load_json(filepath):
    """Read a JSON-lines file into a list of decoded objects.

    Each non-blank line of the file is parsed as one JSON document. Blank or
    whitespace-only lines (for example a trailing newline at the end of the
    file) are skipped instead of raising a decode error.

    Args:
        filepath: Path of the JSON-lines file to read.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Local import so the helper does not depend on another cell having
    # imported ``json`` first.
    import json

    # JSON text is UTF-8; do not rely on the platform's default encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in map(str.strip, file) if line]

In [8]:
def main(config):
    """Train ``LitMML`` on the COCO Karpathy captions with W&B logging.

    Args:
        config: Configuration object. This function reads
            ``config.lightning.seed``, ``config.lightning.trainer``,
            ``config.loss``, ``config.optimizer``, ``config.scheduler``,
            ``config.wandb`` and ``config.save_dir``; ``config`` is also passed
            to the helper functions below.
    """
    seed_everything(config.lightning.seed, workers=True)

    dataset = load_dataset("yerevann/coco-karpathy")

    # NOTE(review): the three helper calls below pass ``config`` as a leading
    # argument, which does not match the notebook's definitions
    # get_models(image_encoder_name, text_encoder_name),
    # get_datasets(dataset, data_dir, processor) and
    # get_dataloaders(train_dataset, val_dataset, **kwargs). The config keys
    # holding the encoder names, data directory and loader options are not
    # visible here — TODO confirm them and adapt these calls.
    model, processor = get_models(config)
    train_dataset, val_dataset = get_datasets(config, dataset, processor)
    train_dataloader, val_dataloader = get_dataloaders(config, train_dataset, val_dataset)

    net = LitMML(
        model,
        processor,
        loss_cfg=config.loss,
        optimizer_cfg=config.optimizer,
        scheduler_cfg=config.scheduler,
    )

    wandb_logger = WandbLogger(**config.wandb)

    # log the config on the master node
    if rank_zero_only.rank == 0:
        cfg_dict = OmegaConf.to_container(config, resolve=True)
        wandb_logger.experiment.config.update(cfg_dict)

    # The module logs its validation loss under "loss/val"; a "val_loss"
    # metric is never logged. The metric name contains a slash, so automatic
    # name insertion is disabled and the labels are written out by hand to
    # keep the slash out of the file name.
    ckpt_callback = ModelCheckpoint(
        every_n_epochs=2,
        dirpath=f"{config.save_dir}/ckpts/{wandb_logger.experiment.id}",
        filename="ckpt-epoch={epoch:02d}-val_loss={loss/val:.3f}",
        auto_insert_metric_name=False,
    )
    lr_callback = LearningRateMonitor(logging_interval="step")

    trainer = pl.Trainer(
        **config.lightning.trainer,
        logger=wandb_logger,
        callbacks=[ckpt_callback, lr_callback]
    )

    wandb_logger.watch(net)

    trainer.fit(net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
    #trainer.fit(net, train_dataloaders=train_dataloader)