In [71]:
import pandas as pd
import os
import spacy
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from PIL import Image
from pathlib import Path

In [72]:
## Optional: run this cell if you don't already have the Flickr8k images and captions (needs about 1 GB of disk space)
# import requests
# import zipfile
# dataset_source = "https://www.kaggle.com/datasets/e1cd22253a9b23b073794872bf565648ddbe4f17e7fa9e74766ad3707141adeb/download?datasetVersionNumber=1"

# print("downloading started")
# with open("archive.zip", "wb") as f:
#     response = requests.get(dataset_source)
#     f.write(response.content)
    
# with zipfile.ZipFile("archive.zip", "r") as zip_ref:
#     zip_ref.extractall(os.getcwd())
# print("file extracted")

# question = input("do you want to delete zip file? y/n:  ")    
# try:
#     if question == "y":
#         os.remove("archive.zip")
# except FileNotFoundError:
#     print("The system cannot find the file specified: 'archive.zip")    

In [73]:
# List every entry name in the image folder (names only, no directory prefix).
# os.listdir makes no ordering guarantee, so do not rely on the order shown.
images = os.listdir("flickr8k/images/")
images

['1000268201_693b08cb0e.jpg',
 '1001773457_577c3a7d70.jpg',
 '1002674143_1b742ab4b8.jpg',
 '1003163366_44323f5815.jpg',
 '1007129816_e794419615.jpg',
 '1007320043_627395c3d8.jpg',
 '1009434119_febe49276a.jpg',
 '1012212859_01547e3f17.jpg',
 '1015118661_980735411b.jpg',
 '1015584366_dfcec3c85a.jpg',
 '101654506_8eb26cfb60.jpg',
 '101669240_b2d3e7f17b.jpg',
 '1016887272_03199f49c4.jpg',
 '1019077836_6fc9b15408.jpg',
 '1019604187_d087bf9a5f.jpg',
 '1020651753_06077ec457.jpg',
 '1022454332_6af2c1449a.jpg',
 '1022454428_b6b660a67b.jpg',
 '1022975728_75515238d8.jpg',
 '102351840_323e3de834.jpg',
 '1024138940_f1fefbdce1.jpg',
 '102455176_5f8ead62d5.jpg',
 '1026685415_0431cbf574.jpg',
 '1028205764_7e8df9a2ea.jpg',
 '1030985833_b0902ea560.jpg',
 '103106960_e8a41d64f8.jpg',
 '103195344_5d2dc613a3.jpg',
 '103205630_682ca7285b.jpg',
 '1032122270_ea6f0beedb.jpg',
 '1032460886_4a598ed535.jpg',
 '1034276567_49bb87c51c.jpg',
 '104136873_5b5d41be75.jpg',
 '1042020065_fb3d3ba5ba.jpg',
 '1042590306_95dea

In [74]:
# One-column frame holding every image file name.
# DataFrame.set_index returns a new frame instead of modifying the caller, so
# a bare `df.set_index(...)` call would have no effect; `df` therefore keeps
# its default integer index here.
df = pd.DataFrame(images, columns=["image"])
df;

In [75]:
# lst = pd.read_csv("flickr8k/captions.txt")["caption"].to_list()
# for i in lst:
#     print(type(i))

In [76]:
# Same folder listing via pathlib: yields Path objects that include the
# directory prefix. The trailing `;` suppresses the cell output.
list(Path("flickr8k/images/").glob("*"));

In [77]:
# Load spaCy's small English pipeline; Vocabulary.tokenizer below relies on
# this module-level `nlp` object.
# Install the model once with: !python -m spacy download en_core_web_sm
nlp = spacy.load('en_core_web_sm')
# doc = spacy.lang.en.examples.sentences
# nlp(doc[0])

In [78]:
# for token in nlp(doc[0]):
#     print(token.text)

In [79]:
# [token.text.lower() for token in nlp.tokenizer(doc[0])]

In [80]:
from collections import defaultdict
class Vocabulary:
    """Word-level vocabulary mapping tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Regular words receive the next free ids when
    :meth:`build_vocabulary` is called.

    Attributes:
        itos: dict mapping id -> token string.
        stoi: dict mapping token string -> id.
        freq_threshold: number of occurrences a word needs before it is
            added to the vocabulary.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary that contains only the four special tokens.

        Args:
            freq_threshold: Minimum number of occurrences required for a
                word to be added by :meth:`build_vocabulary`.
        """
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>" : 0, "<SOS>" : 1, "<EOS>" : 2, "<UNK>" : 3}
        self.freq_threshold = freq_threshold

    @staticmethod
    def tokenizer(text):
        """Split ``text`` into lower-cased token strings.

        Uses the tokenizer of the module-level spaCy pipeline ``nlp``.
        """
        return [token.text.lower() for token in nlp.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Add every sufficiently frequent word of ``sentence_list``.

        A word is added at the moment its running count reaches
        ``freq_threshold``, so ids follow the order in which words cross
        the threshold.

        Args:
            sentence_list: Iterable of sentence strings.
        """
        # Counts start at 0 so that a word needs exactly `freq_threshold`
        # occurrences (a start value of 1 would admit it one occurrence early
        # and would never admit anything for a threshold of 1).
        frequency = defaultdict(int)
        # Continue after the highest id in use (4 on a fresh vocabulary) so a
        # repeated call cannot overwrite existing entries.
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer(sentence):
                frequency[word] += 1

                # `==` (not `>=`) ensures each word is registered only once
                # per call; the membership test covers repeated calls.
                if frequency[word] == self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Convert ``text`` into a list of token ids.

        Tokens that are not in the vocabulary map to the ``<UNK>`` id.

        Args:
            text: Sentence string to encode.

        Returns:
            list[int]: One id per token, in order.
        """
        unk_id = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_id) for token in self.tokenizer(text)]
        

In [81]:
class FlickrDataset(Dataset):
    """Image/caption pairs read from an image folder and a captions CSV.

    The CSV must provide an ``image`` column (file name relative to
    ``img_root``) and a ``caption`` column. A ``Vocabulary`` is built from
    all captions at construction time.
    """

    def __init__(self,img_root, caption_file, transform=None, freq_threshold=5):
        """Load the caption table and build the vocabulary.

        Args:
            img_root: Directory that contains the image files.
            caption_file: Path of the CSV file with ``image`` and
                ``caption`` columns.
            transform: Optional callable applied to each loaded PIL image.
            freq_threshold: Passed to ``Vocabulary``; minimum word count for
                a word to enter the vocabulary.
        """
        self.root = img_root
        self.df = pd.read_csv(caption_file)
        self.transform = transform

        self.img = self.df["image"]
        self.caption = self.df["caption"]

        # Initialize the vocabulary and fill it from the captions.
        # build_vocabulary works in place and returns None, so its result is
        # not stored.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.caption)

    def __len__(self):
        """Return the number of (image, caption) rows."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample.

        Returns:
            tuple: ``(img, caption_ids)`` where ``img`` is the (optionally
            transformed) RGB image and ``caption_ids`` is a 1-D tensor
            ``[<SOS>, token ids..., <EOS>]``.
        """
        image_name = self.img.iloc[idx]
        caption_text = self.caption.iloc[idx]

        # The context manager closes the file handle; convert() returns an
        # independent in-memory image and guarantees three channels so that
        # samples can be stacked into one batch.
        with Image.open(Path(self.root) / image_name) as image_file:
            img = image_file.convert("RGB")

        # Use the transform given to this dataset, not a notebook global.
        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption_text)
        # Terminate with the end-of-sequence marker.
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)
        
        
        

In [82]:
# import cv2 as cv
# image = cv.imread(str(Path("flickr8k/images/1000268201_693b08cb0e.jpg")))
# # image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
# cv.imshow("image",image)
# cv.waitKey(0)
# cv.destroyAllWindows()

In [83]:
class MyCollate:
    """Callable collate function for batches of ``(image, caption)`` pairs.

    Images are combined along a new leading batch dimension; captions are
    padded with ``pad_idx`` to the longest caption in the batch.
    """

    def __init__(self, pad_idx):
        """Remember the id used to pad short captions."""
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Collate ``batch`` into ``(images, captions)``.

        Args:
            batch: Sequence of samples; element 0 of each sample is an image
                tensor and element 1 is a 1-D caption tensor.

        Returns:
            tuple: image tensor with a leading batch dimension, and a caption
            tensor laid out as (longest caption, batch) because
            ``batch_first=False``.
        """
        stacked_images = torch.stack([sample[0] for sample in batch], dim=0)

        padded_captions = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=False,
            padding_value=self.pad_idx,
        )

        return stacked_images, padded_captions

In [84]:
def get_loader(
    img_dir,
    caption_file,
    transform,
    batch_size=32,
    shuffle=True,
    pin_memory=True
):
    """Build the Flickr dataset together with a DataLoader over it.

    Args:
        img_dir: Directory that contains the image files.
        caption_file: Path of the captions CSV file.
        transform: Callable applied to every loaded image.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the samples.
        pin_memory: Forwarded to ``DataLoader``.

    Returns:
        tuple: ``(dataset, loader)``.
    """
    dataset = FlickrDataset(img_dir, caption_file, transform=transform)

    # Captions are padded with the vocabulary's <PAD> id.
    collate = MyCollate(dataset.vocab.stoi["<PAD>"])

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=collate,
    )

    return dataset, loader

In [85]:
if __name__ == "__main__":
    # Preprocessing: resize every image to a fixed size, then convert it to a
    # tensor.
    # NOTE(review): the width 244 looks like a typo for 224 — confirm. It is
    # kept unchanged so the batch shape matches the recorded output.
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize((224,244)),
            torchvision.transforms.ToTensor(),
        ]
    )

    # Build the dataset and a loader with the default batch settings.
    dataset, loader = get_loader(
        img_dir="flickr8k/images/",
        caption_file="flickr8k/captions.txt",
        transform=transform,
    )

In [86]:
# Sanity check: pull a single batch from the loader and report its shapes.
# Captions come out as (longest caption, batch) because MyCollate pads with
# batch_first=False.
imgs, captions = next(iter(loader))

print(f"img shape: {imgs.shape} | captions: {captions.shape}")
    

img shape: torch.Size([32, 3, 224, 244]) | captions: torch.Size([25, 32])


In [87]:
# Number of full batches per epoch: caption rows floor-divided by 32, which
# is the default batch_size of get_loader.
len(pd.read_csv("flickr8k/captions.txt")["image"]) //32

1264

In [88]:
# Show the second CSV column (position 1) as a plain Python list; this is the
# column FlickrDataset reads as "caption".
# NOTE(review): this prints every row — consider slicing before sharing.
pd.read_csv("flickr8k/captions.txt").iloc[:, 1].tolist()

['A child in a pink dress is climbing up a set of stairs in an entry way .',
 'A girl going into a wooden building .',
 'A little girl climbing into a wooden playhouse .',
 'A little girl climbing the stairs to her playhouse .',
 'A little girl in a pink dress going into a wooden cabin .',
 'A black dog and a spotted dog are fighting',
 'A black dog and a tri-colored dog playing with each other on the road .',
 'A black dog and a white dog with brown spots are staring at each other in the street .',
 'Two dogs of different breeds looking at each other on the road .',
 'Two dogs on pavement moving toward each other .',
 'A little girl covered in paint sits in front of a painted rainbow with her hands in a bowl .',
 'A little girl is sitting in front of a large painted rainbow .',
 'A small girl in the grass plays with fingerpaints in front of a white canvas with a rainbow on it .',
 'There is a girl with pigtails sitting in front of a rainbow painting .',
 'Young girl with pigtails pain

In [89]:
# Alternative loading path: decode one image straight into a tensor and
# divide by 255; the displayed values lie between 0 and 1.
torchvision.io.read_image(str(Path("flickr8k/images/667626_18933d713e.jpg")))/255

tensor([[[0.9529, 0.9490, 0.9020,  ..., 0.4941, 0.4667, 0.5922],
         [0.8588, 0.8000, 0.9569,  ..., 0.6627, 0.6078, 0.4275],
         [0.9843, 0.8275, 0.7961,  ..., 0.5725, 0.5922, 0.5529],
         ...,
         [0.9059, 0.9176, 0.9412,  ..., 0.4980, 0.6980, 0.6235],
         [0.9216, 0.8902, 0.9373,  ..., 0.8078, 0.8549, 0.7686],
         [0.9765, 0.9373, 0.9294,  ..., 0.9216, 0.8275, 0.9961]],

        [[0.9569, 0.9569, 0.9176,  ..., 0.6314, 0.6000, 0.7216],
         [0.8667, 0.8196, 0.9804,  ..., 0.8000, 0.7412, 0.5569],
         [1.0000, 0.8627, 0.8353,  ..., 0.7176, 0.7333, 0.6941],
         ...,
         [0.9137, 0.9020, 0.8941,  ..., 0.5569, 0.7569, 0.6824],
         [0.8863, 0.8549, 0.8941,  ..., 0.8431, 0.8902, 0.8039],
         [0.8941, 0.8706, 0.8941,  ..., 0.9255, 0.8314, 1.0000]],

        [[1.0000, 1.0000, 0.9529,  ..., 0.6039, 0.6039, 0.7490],
         [0.9137, 0.8353, 0.9804,  ..., 0.7725, 0.7373, 0.5765],
         [1.0000, 0.8510, 0.8000,  ..., 0.6902, 0.7176, 0.