In [2]:
import torch
import matplotlib.pyplot as plt
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F

In [5]:
# Stock torchvision captions dataset over the train2017 images and caption
# annotations; each image is passed through the PILToTensor transform.
cap = dset.CocoCaptions(root = '/workspaces/datasets/MS_COCO/train2017',
                        annFile = '/workspaces/datasets/MS_COCO/annotations_trainval2017/annotations/captions_train2017.json',
                        transform=transforms.PILToTensor())

loading annotations into memory...
Done (t=0.99s)
creating index...
index created!


In [6]:
# Sanity check: dataset length plus one example (image, captions) pair.
print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample

# `img` is a tensor (its size is printed); `target` is printed as-is and, for
# this dataset, is the list of caption strings attached to the image.
print("Image Size: ", img.size())
print(target)

Number of samples:  118287
Image Size:  torch.Size([3, 425, 640])
['A zebra grazing on lush green grass in a field.', 'Zebra reaching its head down to ground where grass is. ', 'The zebra is eating grass in the sun.', 'A lone zebra grazing in some green grass.', 'a Zebra grazing on grass in a green open field.']


In [7]:
from typing import Any, Callable, List, Optional, Tuple
import torchvision

class CustomCocoCaptions(torchvision.datasets.CocoDetection):
    """MS COCO captions dataset that yields ONE randomly chosen caption per image.

    Unlike ``torchvision.datasets.CocoCaptions``, whose target is the list of
    every caption attached to an image, this dataset draws a single caption at
    random each time a sample is loaded, so the target is a plain ``str``.
    The draw uses ``torch.randint`` and therefore follows torch's global RNG.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.transforms as transforms
            cap = CustomCocoCaptions(root = 'dir where images are',
                                     annFile = 'json annotation file',
                                     transform=transforms.PILToTensor())

            img, caption = cap[3]  # 4th image with one of its captions
            print(caption)         # a single string; may differ between loads
    """

    def __init__(
        self,
        root: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, annFile, transform, target_transform, transforms)

    def _load_target(self, id: int) -> str:
        """Return one caption, picked uniformly at random, for image ``id``.

        Args:
            id: Image id forwarded unchanged to the parent ``_load_target``.

        Returns:
            The ``"caption"`` string of one randomly selected annotation.

        Raises:
            RuntimeError: from ``torch.randint`` when the image has no
                annotations (the sampling range is then empty).
        """
        # An image might have more than one description.
        captions = [ann["caption"] for ann in super()._load_target(id)]

        # Draw the index as a plain Python int: indexing a list with a tensor
        # only works through the implicit one-element ``__index__`` conversion.
        cap_idx = int(torch.randint(low=0, high=len(captions), size=(1,)).item())

        return captions[cap_idx]

In [9]:
# Image preprocessing for training: random 224x224 crop, then conversion to a tensor.
# Bound to `image_transform` rather than `transforms` so it does not shadow the
# `torchvision.transforms` module imported under that name at the top of the
# notebook; with the old name, re-running the earlier `transforms.PILToTensor()`
# cell fails because `transforms` has become a Compose object.
image_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.ToTensor(),
])

# Rebind `cap` to the single-random-caption dataset; the DataLoader cell reads `cap`.
cap = CustomCocoCaptions(root = '/workspaces/datasets/MS_COCO/train2017',
                        annFile = '/workspaces/datasets/MS_COCO/annotations_trainval2017/annotations/captions_train2017.json',
                        transform=image_transform)

loading annotations into memory...
Done (t=0.84s)
creating index...
index created!


In [10]:
# Checkpoint name shared by the tokenizer here and by AutoModel.from_pretrained
# in a later cell, so the text encoder and its tokenizer stay in sync.
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

In [7]:
def collate_fn(batch):
    """Split (image, caption) samples into a list of images and one tokenized text batch.

    The images come back as a plain Python list in sample order (stacking is
    left to the caller). The captions are tokenized together by the
    notebook-level `tokenizer` with padding and truncation enabled, returned as
    PyTorch tensors, and moved onto the notebook-level `device`.
    """
    images = [sample[0] for sample in batch]
    text = [sample[1] for sample in batch]

    encoded = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    return images, encoded.to(device)

In [8]:
# Samples per batch; read by the DataLoader cell and by the training loop.
batch_size = 32
# Divides the similarity logits before the cross-entropy loss in the training loop.
temperature = 0.2
# Number of full passes over the loader.
epochs = 4

In [9]:
# Training loader. Shuffling is enabled so each epoch sees differently composed
# batches: the loss in this notebook contrasts every caption against the other
# images of the same batch, so a fixed order would reuse the same negatives
# every epoch. Loading stays in the main process (num_workers=0) because
# collate_fn moves the tokenized text onto `device`.
# The final batch may be smaller than `batch_size` (drop_last=False).
loader = torch.utils.data.DataLoader(
    cap,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
    pin_memory=False,
    drop_last=False,
)

In [10]:

# Prefer the GPU when one is available. `device` is also read inside collate_fn
# (defined in an earlier cell), so this cell must run before the loader is iterated.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Text encoder, loaded from the same checkpoint name as the tokenizer.
model = AutoModel.from_pretrained(model_ckpt).to(device)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
# Image encoder: ConvNeXt-Small built with a 768-wide output. No pretrained
# weights are requested in this call.
vision_encoder = torchvision.models.convnext_small(num_classes=768)
# Replace entry 2 of the classifier with a 768 -> 2048 -> 768 MLP (GELU in between).
# NOTE(review): presumably index 2 is the final Linear of torchvision's ConvNeXt head — confirm.
vision_encoder.classifier[2] = torch.nn.Sequential(
    torch.nn.Linear(in_features=768, out_features=2048),
    torch.nn.GELU(),
    torch.nn.Linear(in_features=2048, out_features=768)
)
vision_encoder = vision_encoder.to(device)

In [12]:
def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a 1-D tensor, in row-major order.

    Trick: after dropping the last element, the remaining n*n - 1 flattened
    entries reshape to (n - 1, n + 1), where every row begins with a diagonal
    element of `x`. Cutting off that first column leaves only the
    off-diagonal entries.
    """
    rows, cols = x.shape
    assert rows == cols
    all_but_last = x.flatten()[:-1]
    diag_first = all_but_last.view(rows - 1, rows + 1)
    without_diag = diag_first[:, 1:]
    return without_diag.flatten()

In [13]:
# Loss object applied to the temperature-scaled logits in the training loop.
cross_entropy_fn = torch.nn.CrossEntropyLoss() 

In [14]:
# Optimize BOTH encoders. The training loss is computed from the text encoder
# (`model`) and the image encoder (`vision_encoder`); passing only
# `model.parameters()` would leave the image encoder at its initial weights
# and never clear its accumulated gradients.
optimizer = torch.optim.AdamW(
    list(model.parameters()) + list(vision_encoder.parameters()),
    lr=0.0001,
)

In [None]:
class NTXentLoss(torch.nn.Module):
    """NT-Xent contrastive loss over L2-normalized embeddings with a temperature.

    ``forward`` expects ``2 * batch_size`` local embeddings per call: rows
    ``[0, batch_size)`` and rows ``[batch_size, 2 * batch_size)`` are the two
    views, and row ``i`` is the positive of row ``batch_size + i`` (and vice
    versa). The masks are built once for exactly that size, so a smaller final
    batch is not supported.

    NOTE: the original notes say the data is assumed to come from loaders
    built by an ``init_data`` method in ``data_manager.py``, which is not part
    of this notebook.
    """

    def __init__(
        self,
        batch_size,
        world_size,
        rank,
        temperature,
        device
    ):
        """
        :param batch_size: num. local original images per batch
        :param world_size: total number of workers in network
        :param rank: rank in network
        :param temperature: temp. param
        :param device: device to map tensors onto
        """
        # Required before the module can be called: sets up nn.Module internals.
        super().__init__()
        self.temperature = temperature
        self.batch_size = batch_size

        local_rows = 2 * batch_size
        total_images = local_rows * world_size
        # Column at which this worker's own embeddings start in the gathered buffer.
        offset = rank * local_rows

        pos_mask = torch.zeros(local_rows, total_images, device=device)
        diag_mask = torch.ones(local_rows, total_images, device=device)

        first_view = torch.arange(batch_size, device=device)
        second_view = first_view + batch_size
        # Positive pairs: row i <-> row batch_size + i of the local block.
        pos_mask[first_view, offset + second_view] = 1.
        pos_mask[second_view, offset + first_view] = 1.
        # Zero out each embedding's similarity with itself.
        diag_mask[first_view, offset + first_view] = 0.
        diag_mask[second_view, offset + second_view] = 0.

        self.pos_mask = pos_mask
        self.diag_mask = diag_mask

    @staticmethod
    def _gather_detached(z):
        """Return the detached embeddings of all workers, concatenated in rank order.

        Stands in for the ``AllGather`` helper the original code referenced
        but never defined. Without an initialized ``torch.distributed``
        process group (the single-process notebook case) this is just
        ``z.detach()``.
        """
        import torch.distributed as dist

        z = z.detach()
        if not (dist.is_available() and dist.is_initialized()):
            return z
        world_size = dist.get_world_size()
        if world_size == 1:
            return z
        gathered = [torch.zeros_like(z) for _ in range(world_size)]
        dist.all_gather(gathered, z.contiguous())
        return torch.cat(gathered, dim=0)

    def forward(self, z):
        """Compute the loss for local embeddings ``z`` (one row per sample).

        :param z: 2-D tensor with ``2 * batch_size`` rows
        :return: scalar loss tensor
        """
        # Step 1: normalize embeddings
        z = torch.nn.functional.normalize(z)

        # Step 2: gather embeddings from all workers (no gradient flows through the buffer)
        z_buffer = self._gather_detached(z)

        # Step 3: compute similarity between local embeddings and all others,
        # dropping each embedding's similarity with itself
        exp_cs = torch.exp(z @ z_buffer.T / self.temperature) * self.diag_mask

        # Step 4: separate positive sample from negatives and compute loss
        pos = torch.sum(exp_cs * self.pos_mask, dim=1)
        diag = torch.sum(exp_cs, dim=1)
        loss = - torch.sum(torch.log(pos.div(diag))) / (2.* self.batch_size)

        return loss.squeeze()

In [16]:
# Put both encoders in training mode. transformers' from_pretrained returns
# the model in eval mode, which leaves dropout disabled unless train() is called.
model.train()
vision_encoder.train()

for epoch in range(epochs):

    for images, tokens in loader:
        images = torch.stack(images, dim=0).to(device)
        output = model(**tokens)
        text_embedding = output.last_hidden_state[:,0] # use the hidden state from the CLS token
        visual_embedding = vision_encoder(images)

        text_embedding = F.normalize(text_embedding, dim=1)
        visual_embedding = F.normalize(visual_embedding, dim=1)

        # Entry (i, j): cosine similarity between caption i and image j.
        similarity_matrix = torch.matmul(text_embedding, visual_embedding.T)

        # Size everything from the actual batch: the loader keeps the final,
        # smaller batch, so reshaping with the global `batch_size` would fail there.
        n = similarity_matrix.shape[0]

        positives = torch.diagonal(similarity_matrix).view(n, 1)

        negatives = off_diagonal(similarity_matrix).view(n, n - 1)

        # Column 0 holds each caption's matching image, so the target class is always 0.
        logits = torch.cat([positives, negatives], dim=1)
        labels = torch.zeros(n, device=device, dtype=torch.long)  # positives are the 0-th
        loss = cross_entropy_fn(logits / temperature, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the last batch loss of the epoch as a plain float.
    print("Loss:", loss.item())

KeyboardInterrupt: 

In [17]:
# Display the most recent batch loss left behind by the training cell.
loss

tensor(3.1795, device='cuda:0', grad_fn=<NllLossBackward0>)

In [None]:
# Scratch check: tokenize a single sentence into PyTorch tensors and display the result.
inputs = tokenizer("I am a god!", return_tensors='pt')
inputs

In [None]:
# Move every tensor of the tokenizer output onto `device`; `inputs` is rebound to a plain dict.
inputs = {k: v.to(device) for k, v in inputs.items()}
inputs

In [None]:
# Run the text encoder on the single tokenized sentence and show the shape of its last hidden state.
output = model(**inputs)
output.last_hidden_state.shape

In [None]:
# Display the current value of `inputs` again.
inputs

In [12]:
# Toy inputs for a cross-entropy check: random 64x64 scores, with label i for row i.
# Note: this rebinds `logits` and `labels`, names the training cell also assigns.
logits = torch.rand(64,64)
labels = torch.arange(64)

In [13]:
import torch.nn.functional as F

In [15]:
# Cross-entropy of the random scores against the 0..63 labels.
F.cross_entropy(logits, labels)

tensor(4.2488)