In [9]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")

import torch
from transformers import DistilBertModel, DistilBertTokenizer
from src.dataset import UnsplashDataset, FlickrDataset
from torch.utils.data import DataLoader
from torch import nn

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# dataset = UnsplashDataset("../../data/unsplash/photos.tsv*")
dataset = FlickrDataset(image_folder_path = "../../data/flickr-dataset/Images/", caption_path = "../../data/flickr-dataset/captions.txt")
train_loader = DataLoader(dataset, batch_size = 5, shuffle=True)    
img, label = next(iter(train_loader))

In [11]:
class TextEncoder(nn.Module):
    def __init__(self, model_name="distilbert-base-uncased", pretrained=True, trainable=True):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
            
        for p in self.model.parameters():
            p.requires_grad = trainable

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]

In [12]:
text_model = TextEncoder()

In [13]:
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [14]:
data = tokenizer(label, padding=True, truncation=True, max_length = 76, return_tensors="pt")

In [18]:
data

{'input_ids': tensor([[  101,  1037,  2879,  5629,  1010,  6729,  2058,  1037,  9540,  3561,
          2007,  2048,  2312, 25730,  2015,  1012,   102],
        [  101,  2048,  2273,  1999, 17072,  4133,  2012,  1037,  2795,  5948,
          5404,  1012,   102,     0,     0,     0,     0],
        [  101,  2048,  2308,  5102,  2007, 11228,  6961,  2058,  2037,  4641,
          2298, 14136,  2012,  1996,  8088,  1012,   102],
        [  101,  2274,  3057,  2006,  1037, 11485,  4536,  1012,   102,     0,
             0,     0,     0,     0,     0,     0,     0],
        [  101,  1037,  3598,  2447,  4324,  2010,  7151,  1998,  4152,  3201,
          2005,  1996,  2208,  1012,   102,     0,     0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 

In [16]:
text_model(data['input_ids'], data['attention_mask'])

tensor([[-0.0835, -0.1285, -0.2291,  ..., -0.0572,  0.3771, -0.0428],
        [ 0.2269,  0.0358, -0.0390,  ..., -0.1364,  0.4991, -0.0495],
        [ 0.2484, -0.0023, -0.0976,  ..., -0.1935,  0.4512,  0.1045],
        [-0.2170, -0.5539, -0.1784,  ..., -0.1827,  0.4049,  0.0035],
        [-0.3889, -0.2942, -0.0121,  ..., -0.0645,  0.2578,  0.2159]],
       grad_fn=<SliceBackward0>)