In [2]:
import os 
import string
import torch 
# import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torchvision import transforms # wud just be using the pytorch library directly
from nltk.tokenize import word_tokenize
from torch.utils.data import Dataset, DataLoader
import numpy as np 
from collections import Counter
"""
The flickr8k Dataset consists of 8k images and the captions given to them
The problem is essentialy mapping the Image to text space 
    So we need to find the semantics of the caption and map the image(which is transformed into a vectorised form) 
"""
# preprocessing of the dataset here
# There are two types of files 
# One is the image file and the other is the caption file 
# In captions file we have (image_name,txt)
# We have to separate the image name and the caption differently
def load_dataset(filename):
    file = open(filename, 'r')
    text = file.read()
    file.close()
    
    dataset = list()
    for line in text.split('\n'):
        if len(line)<1:
            continue
        identifier = line.split(',')[0]
        dataset.append(identifier)
    return set(dataset)


dataset = load_dataset('archive/captions.txt')
print(len(dataset))

8092


In [3]:
def load_descriptions(filename, dataset):
    with open(filename, 'r') as file:
        text = file.read()
    descriptions = {}
    for line in text.split('\n'):
        tokens = line.split(',',1)
        if len(line) < 1:
            continue
        image_id, image_desc = tokens[0], tokens[1:]
        if image_id in dataset:
            if image_id not in descriptions:
                descriptions[image_id] = []
            desc = 'startseq ' + ' '.join(image_desc) + ' endseq'
            descriptions[image_id] = desc
    return descriptions



descriptions = load_descriptions('archive/captions.txt', dataset)


In [4]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to C:\Users\Sambhram
[nltk_data]     Shetty\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [36]:
def build_vocab(descriptions):
    all_desc = [desc for desc in descriptions.values()]
    counter = Counter()
    for desc in all_desc:
        tokens = word_tokenize(desc)
        counter.update(tokens)
    words = [word for word , count in counter.items() if count>=1]
    word_to_index = {word: idx+1 for idx, word in enumerate(words)}
    word_to_index['<pad>'] = 0
    index_to_word = {idx: word for word, idx in word_to_index.items()}
    return word_to_index, index_to_word
word_to_index, index_to_word = build_vocab(descriptions)
vocab_size = len(word_to_index)
print(word_to_index, vocab_size)
max_length = max(len(word_tokenize(desc)) for desc in descriptions.values())
print(max_length)

42


In [276]:
import torchvision.models as models
from PIL import Image
class FlickrDataset(Dataset):
    def __init__(self, image_dir, descriptions, word_to_index, transform):
        self.image_dir = image_dir
        self.descriptions = descriptions
        self.word_to_index = word_to_index
        self.transform = transform
        self.image_ids = list(descriptions.keys())
        self.vgg16 = models.vgg16( weights='DEFAULT').to('cuda:0')
        self.vgg16_features = self.vgg16.features
        self.vgg16.eval()
    
    def __len__(self):
        return len(self.image_ids)
    
    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        description = self.descriptions[image_id]
        image_path = os.path.join(self.image_dir , image_id)
        if not os.path.exists(image_path):
            print(f'File not found: {image_path}')
            return None
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
            image = image.to('cuda:0')
        with torch.no_grad():
            image_features = self.vgg16_features(image.unsqueeze(0))
        
        tokens = word_tokenize(description)
        target = [self.word_to_index[word] for word in tokens]
        target = torch.tensor(target)
        return image_features, target
    
transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(), 
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
# Have to use custom collate function because of the padding 

def custom_collate_fn(batch):
    image_features = [item[0] for item in batch if item !=None]
    targets = [item[1] for item in batch if item!=None]
    image_features = torch.stack(image_features, dim=0)
    targets = pad_sequence(targets, batch_first=True, padding_value=0)
    return image_features, targets
image_dir = 'archive/Images'
dataset = FlickrDataset(image_dir, descriptions, word_to_index, transform)
dataloader = DataLoader(dataset, batch_size = 32, shuffle=True, collate_fn=custom_collate_fn)



In [277]:
# Now the Glove embeddings 

def load_glove(filename):
    embeddings_index = {}
    with open(filename, 'r', encoding='utf-8') as file:
        for line in file:
            values = line.split()
            word = values[0]
            coefs = np.array(values[1:], dtype='float32')
            embeddings_index[word] = coefs
    return embeddings_index

embeddings_index = load_glove('glove.6B.200d.txt/glove.6B.200d.txt')

embedding_dim = 200
embedding_matrix = np.zeros((vocab_size, embedding_dim))
for word, i in word_to_index.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        embedding_matrix[i] = embedding_vector
    
embedding_matrix = torch.tensor(embedding_matrix).to('cuda:0')

In [278]:
#Making the model
import torch.nn as nn
import torch.optim as optim
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(512*7*7, 256) # B,256
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x  = self.dropout(self.fc(x))
        x = self.relu(x)
        return x

class Decoder(nn.Module):
    def __init__(self, embedding_matrix, hidden_size, vocab_size, max_length):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)
        self.lstm = nn.LSTM(embedding_dim + 256, hidden_size, batch_first = True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)
        self.max_length = max_length
    
    def forward(self, features, captions):
        embeddings = self.embedding(captions)

        features = features.unsqueeze(1) * torch.ones((embeddings.shape[1],1)).cuda()
        embeddings = torch.cat((features, embeddings), 2)
        embeddings = embeddings.to(torch.float32)
        hiddens , _ = self.lstm(embeddings)
        outputs = self.fc(hiddens)
        return outputs
    
encoder = Encoder().to('cuda:0')
decoder = Decoder(embedding_matrix, 256, vocab_size, max_length).to('cuda:0')

criteria = nn.CrossEntropyLoss(ignore_index=0).to('cuda:0')
optimizer =optim.Adam(list(encoder.parameters()) + list(decoder.parameters()))

In [279]:
def train_model(encoder, decoder, dataloader, criteria, optimizer, num_epochs = 2):
    
    encoder.train()
    decoder.train()
    batch_shape = list()
    i = 0
    for epoch in range(num_epochs):
        for batch in iter(dataloader):
            i = i+1
            image_features, captions = batch
            captions = captions.to('cuda:0')
            images = image_features.unsqueeze(1).to('cuda:0')
            targets = captions[:, 1:]
            captions = captions[:,:-1]
            features = encoder(images).to('cuda:0')
            outputs = decoder(features, captions)
            outputs = outputs.view(-1, vocab_size)
            targets = targets.contiguous().view(-1)
            loss = criteria(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(i)
            torch.save(encoder.state_dict(), 'encoder.pth')
            torch.save(decoder.state_dict(), 'decoder.pth')
        print(f'Epoch, {epoch+1}/{num_epochs}, Loss:{loss.item():.4f}')

train_model(encoder, decoder, dataloader, criteria, optimizer = optimizer)
torch.save(encoder.state_dict(), 'encoder.pth')
torch.save(decoder.state_dict(), 'decoder.pth')


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
