In [15]:
import io
import spacy
from collections import Counter
from collections import defaultdict 
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision
import torch
from  torchtext.vocab import vocab
from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from PIL import Image


In [2]:
IMAGES_PATH = './flickr8/Images'
CAPTIONS_PATH = './flickr8/captions.txt'

In [None]:
!python -m spacy download en_core_web_sm

In [4]:
class FlickrDataset(Dataset):
    def __init__(self, captions_path, images_path, transform=transforms.ToTensor()):
        self.tokenizer = get_tokenizer('spacy', language='en_core_web_sm')
        self.img_caption_dict = self._load_img_caption_dict(captions_path)
        self.caption_vocab = self._get_vocab(self.img_caption_dict)
        self.image_paths = [f"{images_path}/{img_name}" for img_name in self.img_caption_dict.keys()]
        self.image_captions = list(self.img_caption_dict.values())
        self.transform = transform
        
    def _get_vocab(self, img_caption_dict):
        counter = Counter()
        for img_key in img_caption_dict:
            for caption in img_caption_dict[img_key]:
                counter.update(self.tokenizer(caption))
        caption_vocab = vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
        caption_vocab.lookup_token(100)
        caption_vocab.set_default_index(caption_vocab["<unk>"])
        return caption_vocab
    
    def _load_img_caption_dict(self, captions_path):
        img_capt_dict = defaultdict(list)
        with open(captions_path, 'r') as captions_file:
            for line in captions_file.readlines():
                if line.startswith("image"):
                    # header
                    continue
                
                else:
                    current_line = line.split(',')
                    img = current_line[0]
                    capt = current_line[1]
                    img_capt_dict[img].append(capt)
                    
        return img_capt_dict        

    
    def __len__(self):
        return len(self.img_caption_dict)
    
    def __getitem__(self, index):
        image = Image.open(self.image_paths[index]).convert("RGB")
        image_tensor = self.transform(image)
        
        data = []
        for caption in self.image_captions[index]:
            tokens = self.tokenizer(caption)
            tensor = torch.cat([
                torch.tensor([self.caption_vocab['<bos>']]),
                torch.tensor([self.caption_vocab[token] for token in tokens]),
                torch.tensor([self.caption_vocab['<eos>']])
            ])
            data.append(tensor)
                
        
        return image_tensor, pad_sequence(data, padding_value=self.caption_vocab['<pad>'])
        
    

In [5]:
dataset = FlickrDataset(CAPTIONS_PATH, IMAGES_PATH, transform=transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    ))


In [38]:
dataset.__getitem__(9)[1].size()

torch.Size([19, 5])

In [63]:
class Collate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = torch.stack([item[0] for item in batch], dim=0)
        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx)

        return imgs, targets


In [64]:
loader = DataLoader(
    dataset=dataset,
    batch_size=10,
    shuffle=True,
    pin_memory=True,
    collate_fn=Collate(pad_idx=dataset.caption_vocab['<pad>']),
)

Reference: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/image_captioning/model.py

In [83]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_size, train=False):
        super(EncoderCNN, self).__init__()
        self.use_pretrained = not train
        self.inception = models.inception_v3(pretrained=self.use_pretrained, aux_logits=True) # TODO - aux_logits=False giving issues
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        # make sure output of cnn model is embed size
        
        for name, param in self.inception.named_parameters():
            if "fc.weight" in name or "fc.bias" in name:
                param.requires_grad = True
            else:
                param.requires_grad = train
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
                
        
        # self.main = nn.Sequential(
        #     inception,
        #     nn.ReLU(),
        #     nn.Dropout(0.5)
        # )

    def forward(self, images):
        inception_output = self.inception(images)
        return self.dropout(self.relu(inception_output[1]))


In [84]:
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(DecoderRNN, self).__init__()
        
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(0.5)
        
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.dropout(self.embed(captions))
        print(embeddings.size())
        print(features.unsqueeze(0).size())
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs

In [85]:
class CNNtoRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(CNNtoRNN, self).__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        features = self.encoder(images)
        print(features.size())
        print(captions.size())
        outputs = self.decoder(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=50):
        result_caption = []

        with torch.no_grad():
            encoded_image = self.encoder(image).unsqueeze(0)
            states = None

            for _ in range(max_length):
                hiddens, states = self.decoder.lstm(encoded_image, states)
                output = self.decoder.linear(hiddens.squeeze(0))
                predicted = output.argmax(1)
                predicted_word = vocabulary[predicted.item()]
                result_caption.append(predicted_word)
                encoded_image = self.decoder.embed(predicted).unsqueeze(0)

                if predicted_word == "<eos>":
                    break

        return result_caption

In [60]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

In [86]:
model = CNNtoRNN(embed_size=256, hidden_size=256, vocab_size=len(dataset.caption_vocab), num_layers=1).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=dataset.caption_vocab['<pad>'])
optimizer = optim.Adam(model.parameters(), lr=3e-4)

In [None]:
model.train()
for epoch in range(2):
    for idx, (imgs, captions) in enumerate(loader):
        imgs = imgs.to(device)
        captions = captions.to(device)
        outputs = model(imgs, captions[:-1])
        print(f"output size: {outputs.size()}")
        loss = criterion(
            outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
        )
        print(f"loss: {loss}")
        optimizer.zero_grad()
        loss.backward(loss)
        optimizer.step()
        