In [1]:
import clip
import numpy as np
import torch
from torchmultimodal.modules.losses.contrastive_loss_with_temperature import ContrastiveLossWithTemperature
from torchmultimodal.transforms.clip_transform import CLIPTransform, CLIPImageTransform, CLIPTextTransform
from torchvision import transforms
from tqdm import tqdm

from modules.refcocog import RefCOCOg
from modules.refcocog import RefCOCOgSample
from modules.utilities import get_best_device


In [2]:
data_path = "dataset/refcocog"

dataset = RefCOCOg(ds_path=data_path, transform_img=CLIPImageTransform(), transform_txt=CLIPTextTransform())

# train_ds = RefCOCOg(ds_path=data_path, split='train')
# val_ds = RefCOCOg(ds_path=data_path, split='val')
# test_ds = RefCOCOg(ds_path=data_path, split='test')

print(f"[INFO] Dataset Size: {len(dataset)}")
# print(f"[INFO] train split:  {len(train_ds)}")
# print(f"[INFO] val split:    {len(val_ds)}")
# print(f"[INFO] test split:   {len(test_ds)}")


[INFO] Dataset Size: 49822


In [3]:
#@title Collate function to yield images and sentences from the DataLoader

def collate_fn(batch_):
    batch_ = [RefCOCOgSample(**sample) for sample in batch_]

    images, texts = list(), list()

    for sample in batch_:
        for sentence in sample.sentences:
            images.append(sample.img)
            texts.append(sentence)

    return torch.stack(images), torch.stack(texts)


In [4]:
#@title Main training cell

# hyperparameters
epochs = 10
batch_size = 64

# get best device
device = get_best_device()

# instantiate clip transform
clip_transform = CLIPTransform()

# instantiate the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# instantiate model. Here we use clip with vit-L as the image encoder
model, _ = clip.load("ViT-L/14", device=device)

model = model.float()

# define loss and other things needed for training
contrastive_loss = ContrastiveLossWithTemperature()
optim = torch.optim.AdamW(model.parameters(), lr=1e-5)

# define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# write your train loop

for n in range(epochs):
    print(f"[INFO] Epoch #{n}")
    epoch_losses = list()

    pbar = tqdm(dataloader, desc="[INFO] Loss ?????", leave=True)

    for batch in pbar:
        image, text = batch
        # image, text = clip_transform(image, text)

        image_embeddings, text_embeddings = model(image.to(device), text.to(device))

        loss = contrastive_loss(image_embeddings, text_embeddings)
        epoch_losses.append(loss.item())

        avg_loss = np.mean(epoch_losses)
        pbar.set_description("[INFO] Loss %.4f" % avg_loss)

        loss.backward()
        optimizer.step()


[INFO] Using MPS.
[INFO] Epoch #0


[INFO] Loss ?????:   0%|          | 0/1557 [00:06<?, ?it/s]


RuntimeError: MPS backend out of memory (MPS allocated: 16.20 GB, other allocations: 1.81 GB, max allowed: 18.13 GB). Tried to allocate 183.71 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).