In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter
import os
from torch.nn.utils.rnn import pad_sequence

In [2]:
# Download the NLTK 'punkt' package used for tokenization.
nltk.download('punkt')

# Use CUDA device 0 when CUDA is available, otherwise fall back to the CPU.
device = torch.device(type="cuda" if torch.cuda.is_available() else "cpu", index=0)

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
# Run configuration: every tunable value used by the cells below.
image_size = 224        # images are resized to image_size x image_size
embedding_dim = 256     # passed as embed_size to the encoder and decoder
hidden_dim = 512        # passed as hidden_size to the decoder LSTM
freq_threshold = 5      # minimum token count for a word to enter the vocabulary
batch_size = 32         # DataLoader batch size
num_epochs = 10         # number of passes over the data loader
learning_rate = 0.001   # Adam learning rate

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
# Preprocessing applied to every image: resize to a fixed square, convert to
# a tensor, then normalize each channel with the given mean / std.
_resize = transforms.Resize((image_size, image_size))
_normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
transform = transforms.Compose([_resize, transforms.ToTensor(), _normalize])

In [5]:
# Vocabulary class to build word-to-index and index-to-word mappings
class Vocabulary:
    """Bidirectional word <-> index mapping built from a list of sentences.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``; corpus words are numbered after them.

    Args:
        freq_threshold: Minimum number of occurrences (over all sentences
            passed to one ``build_vocabulary`` call) a token needs in order
            to receive its own index.
        tokenizer: Optional callable turning a lower-cased string into a list
            of tokens. When omitted, ``word_tokenize`` is used.
    """

    def __init__(self, freq_threshold, tokenizer=None):
        self.freq_threshold = freq_threshold
        # None means "use word_tokenize"; the lookup is deferred to _tokenize.
        self.tokenizer = tokenizer
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.itos)

    def _tokenize(self, text):
        """Lower-case ``text`` and split it with the configured tokenizer."""
        tokenize = word_tokenize if self.tokenizer is None else self.tokenizer
        return tokenize(text.lower())

    def build_vocabulary(self, sentence_list):
        """Add every token that occurs at least ``freq_threshold`` times.

        All sentences are counted first and the threshold is applied once
        afterwards, so the cost is linear in the number of tokens. New words
        receive consecutive indices in the order they were first seen.

        Args:
            sentence_list: Iterable of sentence strings.
        """
        frequencies = Counter()
        for sentence in sentence_list:
            frequencies.update(self._tokenize(sentence))

        # Continue after the highest index in use so that a repeated call
        # never overwrites an existing itos entry.
        idx = len(self.itos)
        for token, freq in frequencies.items():
            if freq >= self.freq_threshold and token not in self.stoi:
                self.stoi[token] = idx
                self.itos[idx] = token
                idx += 1

    def numericalize(self, text):
        """Convert ``text`` to a list of indices; unknown tokens map to ``<UNK>``."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self._tokenize(text)]


In [6]:
# Dataset locations
captions_file = '/kaggle/input/flickr-8k-images/Flickr8k/Flickr8k_text/Flickr8k.token.txt'
img_folder = '/kaggle/input/flickr-8k-images/Flickr8k/Flickr8k_Dataset/Flicker8k_Dataset'

# Collect the second tab-separated field (the caption text) of every line
# that has one; lines without a tab are ignored.
with open(captions_file, 'r') as caption_fh:
    split_rows = (row.strip().split('\t') for row in caption_fh)
    captions_list = [fields[1] for fields in split_rows if len(fields) > 1]

# Build the word <-> index mapping from all collected captions.
vocab = Vocabulary(freq_threshold)
vocab.build_vocabulary(captions_list)


In [7]:
%%writefile token.txt
""

Overwriting token.txt


In [8]:
import os

def preprocess_images(captions_file, img_folder, output_file='/kaggle/working/token.txt'):
    """
    Write a copy of the captions file that only references existing images.

    Each input line is expected to look like ``<image_id>\\t<caption>``, where
    ``<image_id>`` may carry a ``#...`` suffix after the image file name.
    A line is kept when the image file name exists inside ``img_folder``.
    Blank lines and lines without a tab separator are skipped silently;
    lines whose image is missing are reported on stdout and dropped.

    The input file itself is never modified.

    Args:
        captions_file (str): Path to the file containing captions.
        img_folder (str): Path to the folder containing images.
        output_file (str): Path the filtered captions are written to.

    Returns:
        None
    """
    filtered_captions = []

    # Load captions and check for image existence
    with open(captions_file, 'r') as file:
        for line in file:
            # Split on the first tab only, so a tab inside the caption text
            # cannot break the two-value unpacking below.
            parts = line.strip().split('\t', 1)
            if len(parts) != 2:
                continue  # blank or malformed line

            img_id, caption = parts
            img_name = img_id.split('#')[0]  # Strip out caption ID if present
            img_path = os.path.join(img_folder, img_name)

            if os.path.exists(img_path):
                filtered_captions.append(f"{img_id}\t{caption}")
            else:
                print(f"Image not found: {img_name}. Skipping this entry.")

    # Entries are newline-joined, so the file has no trailing newline.
    with open(output_file, 'w') as file:
        file.write('\n'.join(filtered_captions))
    print("Updated captions file saved.")

# Filter the captions against the image folder. The entries whose image
# exists are written to /kaggle/working/token.txt.
preprocess_images(captions_file=captions_file, img_folder=img_folder)


Image not found: 2258277193_586949ec62.jpg.1. Skipping this entry.
Image not found: 2258277193_586949ec62.jpg.1. Skipping this entry.
Image not found: 2258277193_586949ec62.jpg.1. Skipping this entry.
Image not found: 2258277193_586949ec62.jpg.1. Skipping this entry.
Image not found: 2258277193_586949ec62.jpg.1. Skipping this entry.
Updated captions file saved.


In [10]:
# From here on, read captions from the filtered file written by preprocess_images.
captions_file = '/kaggle/working/token.txt'

In [11]:
# Sanity check: vocabulary size and the first caption that was read.
print(len(vocab))
print(captions_list[0])

3005
A child in a pink dress is climbing up a set of stairs in an entry way .


In [12]:
class Flickr8kDataset(Dataset):
    """Image/caption dataset with one sample per distinct image.

    The captions file is expected to contain lines of the form
    ``<image_name>[#suffix]\\t<caption>``. All captions of an image are kept
    in ``self.captions``, but ``__getitem__`` only uses the first one.

    Args:
        img_folder: Directory containing the image files.
        captions_file: Path to the tab-separated captions file.
        transform: Optional callable applied to the loaded RGB image.
        vocab: Object providing ``stoi`` and ``numericalize`` (required by
            ``__getitem__``).
    """

    def __init__(self, img_folder, captions_file, transform=None, vocab=None):
        self.img_folder = img_folder
        self.transform = transform
        self.vocab = vocab
        self.captions = self.load_captions(captions_file)
        # Fixed index -> image id table, so __getitem__ does not have to
        # rebuild list(self.captions.keys()) on every access.
        self.img_ids = list(self.captions)

    def load_captions(self, captions_file):
        """Return ``{image_name: [caption, ...]}`` parsed from ``captions_file``.

        Blank lines and lines without a tab separator are skipped.
        """
        captions = {}
        with open(captions_file, 'r') as file:
            for line in file:
                # Split on the first tab only so a tab inside the caption
                # text stays part of the caption.
                parts = line.strip().split('\t', 1)
                if len(parts) != 2:
                    continue
                img, caption = parts
                img_id = img.split('#')[0]
                captions.setdefault(img_id, []).append(caption)
        return captions

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return ``(image, caption_indices)`` or ``None`` if the image is missing.

        ``caption_indices`` is a plain list starting with ``<SOS>`` and ending
        with ``<EOS>``. ``None`` is returned (after printing a warning) when
        the image file does not exist, so that a collate function which drops
        ``None`` entries can skip the sample. An out-of-range ``idx`` raises
        ``IndexError``.
        """
        img_id = self.img_ids[idx]
        img_path = os.path.join(self.img_folder, img_id)

        # Raising here would abort the whole DataLoader iteration instead of
        # skipping the sample, so signal the missing file with None.
        if not os.path.exists(img_path):
            print(f"Warning: File {img_path} does not exist.")
            return None

        # Load image and apply transformations
        image = Image.open(img_path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        # Only the first caption of the image is used.
        caption = self.captions[img_id][0]
        caption = [self.vocab.stoi["<SOS>"]] + self.vocab.numericalize(caption) + [self.vocab.stoi["<EOS>"]]

        return image, caption  # Return caption as a list, not a tensor

# Collate function that tolerates skipped (None) samples
def collate_fn(batch, pad_idx=None):
    """Combine ``(image, caption)`` samples into batched tensors.

    Args:
        batch: List of samples; each is either ``None`` (skipped sample) or a
            pair ``(image_tensor, caption_index_list)``.
        pad_idx: Value used to pad shorter captions. When omitted, the
            ``<PAD>`` index of the module-level ``vocab`` is used.

    Returns:
        Tuple ``(images, captions)``: the images stacked along a new first
        dimension, and a ``torch.long`` tensor of captions padded to the
        longest caption in the batch (batch first).

    Raises:
        RuntimeError: If no sample remains after dropping ``None`` entries.
    """
    if pad_idx is None:
        pad_idx = vocab.stoi["<PAD>"]

    samples = [sample for sample in batch if sample is not None]
    if not samples:
        # torch.stack would fail on an empty list with a less helpful message.
        raise RuntimeError("collate_fn received a batch with no valid samples")

    images = torch.stack([img for img, _ in samples])
    captions = pad_sequence(
        [torch.tensor(caption, dtype=torch.long) for _, caption in samples],
        batch_first=True,
        padding_value=pad_idx,
    )
    return images, captions



In [13]:
# Encoder model
class Encoder(nn.Module):
    """Image encoder: ResNet-50 backbone followed by a linear projection.

    The backbone is ``models.resnet50`` with its last child module removed.
    Its flattened output is projected to ``embed_size`` by ``self.fc`` and
    normalized by ``self.bn``.

    Args:
        embed_size: Size of the returned feature vector.
        train_cnn: When False (default) the backbone parameters are frozen
            (``requires_grad=False``), so no gradients are computed for them.
            Pass True to leave them trainable.
    """

    def __init__(self, embed_size, train_cnn=False):
        super(Encoder, self).__init__()
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=` - confirm the installed version before switching.
        resnet = models.resnet50(pretrained=True)
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        # Only fc and bn are handed to the optimizer in this notebook, so
        # backpropagating through the backbone would be wasted work.
        for param in self.resnet.parameters():
            param.requires_grad_(train_cnn)
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Return one ``embed_size``-dimensional feature vector per image."""
        features = self.resnet(images)
        features = features.view(features.size(0), -1)  # flatten per sample
        features = self.bn(self.fc(features))
        return features

# Decoder model
class Decoder(nn.Module):
    """LSTM caption decoder that takes the image feature as its first input step.

    Args:
        embed_size: Size of the token embeddings (and of the image feature).
        hidden_size: Hidden size of the LSTM.
        vocab_size: Number of tokens; also the size of the output logits.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Return logits with one more time step than ``captions`` has.

        The image feature is inserted as step 0 in front of the embedded
        caption tokens before the sequence is run through the LSTM.
        """
        image_step = features.unsqueeze(1)
        token_steps = self.embed(captions)
        lstm_input = torch.cat((image_step, token_steps), 1)
        hidden_states, _ = self.lstm(lstm_input)
        return self.linear(hidden_states)

# Dataset and shuffled batch loader
dataset = Flickr8kDataset(
    img_folder,
    captions_file,
    transform=transform,
    vocab=vocab,
)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

# Models, placed on the selected device
encoder = Encoder(embed_size=embedding_dim).to(device)
decoder = Decoder(
    embed_size=embedding_dim,
    hidden_size=hidden_dim,
    vocab_size=len(vocab),
).to(device)

# Loss that skips padded target positions
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<PAD>"])

# Optimize the whole decoder plus the encoder's projection and batch-norm
# layers; the ResNet backbone is not handed to the optimizer.
params = [
    *decoder.parameters(),
    *encoder.fc.parameters(),
    *encoder.bn.parameters(),
]
optimizer = optim.Adam(params, lr=learning_rate)



In [14]:
# Put both networks in training mode explicitly, so that re-running this cell
# after an eval() call elsewhere still trains with BatchNorm in training mode.
encoder.train()
decoder.train()

for epoch in range(num_epochs):
    for i, (images, captions) in enumerate(data_loader):
        images, captions = images.to(device), captions.to(device)

        # Forward pass through encoder
        features = encoder(images)

        # The decoder inserts the image feature as step 0, so feeding the
        # caption without its last token (seq_len - 1 tokens) produces
        # seq_len output steps, one per token of the full caption.
        outputs = decoder(features, captions[:, :-1])

        outputs = outputs.reshape(-1, outputs.shape[2])  # (batch_size * seq_len, vocab_size)
        targets = captions.reshape(-1)  # (batch_size * seq_len,)

        # Calculate loss, ignoring <PAD> tokens
        loss = criterion(outputs, targets)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Step [{i}/{len(data_loader)}], Loss: {loss.item():.4f}")

print("Training completed!")


Epoch [0/10], Step [0/253], Loss: 8.0075
Epoch [0/10], Step [100/253], Loss: 3.5962
Epoch [0/10], Step [200/253], Loss: 3.1139
Epoch [1/10], Step [0/253], Loss: 3.1001
Epoch [1/10], Step [100/253], Loss: 2.9630
Epoch [1/10], Step [200/253], Loss: 2.8545
Epoch [2/10], Step [0/253], Loss: 2.3913
Epoch [2/10], Step [100/253], Loss: 2.7825
Epoch [2/10], Step [200/253], Loss: 2.4289
Epoch [3/10], Step [0/253], Loss: 2.3332
Epoch [3/10], Step [100/253], Loss: 2.1646
Epoch [3/10], Step [200/253], Loss: 2.8367
Epoch [4/10], Step [0/253], Loss: 2.1732
Epoch [4/10], Step [100/253], Loss: 2.2393
Epoch [4/10], Step [200/253], Loss: 2.2535
Epoch [5/10], Step [0/253], Loss: 1.9444
Epoch [5/10], Step [100/253], Loss: 2.0729
Epoch [5/10], Step [200/253], Loss: 1.8761
Epoch [6/10], Step [0/253], Loss: 1.7802
Epoch [6/10], Step [100/253], Loss: 1.6981
Epoch [6/10], Step [200/253], Loss: 1.8455
Epoch [7/10], Step [0/253], Loss: 1.5268
Epoch [7/10], Step [100/253], Loss: 1.6378
Epoch [7/10], Step [200/253

In [None]:
import torch
from PIL import Image
import torchvision.transforms as transforms

def load_image(image_path, transform):
    """Open ``image_path`` as RGB, apply ``transform`` and return a one-image batch.

    The result is moved to the module-level ``device``.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    single_batch = transform(rgb_image).unsqueeze(0)  # prepend the batch dimension
    return single_batch.to(device)

def generate_caption(encoder, decoder, image_tensor, vocab, max_length=20):
    """Greedily decode a caption for a single image.

    Both models are switched to eval mode for the duration of the call
    (BatchNorm cannot run in training mode on a batch of one) and their
    previous modes are restored afterwards.

    Args:
        encoder: Module mapping ``image_tensor`` to a feature tensor.
        decoder: Module called as ``decoder(features, caption_tensor)`` and
            returning logits of shape ``(1, steps, vocab_size)``.
        image_tensor: Batch containing exactly one preprocessed image.
        vocab: Object providing ``stoi`` and ``itos`` mappings.
        max_length: Maximum number of tokens to generate.

    Returns:
        str: The generated tokens joined by spaces. ``<SOS>`` is excluded;
        ``<EOS>`` is included when it was generated.
    """
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()

    eos_idx = vocab.stoi["<EOS>"]
    # Initialize the caption with the start token <SOS>
    caption = [vocab.stoi["<SOS>"]]

    try:
        with torch.no_grad():
            # Pass the features on as the encoder returns them: the decoder
            # adds the time dimension itself, and adding it here as well
            # makes its torch.cat fail on mismatched ranks.
            features = encoder(image_tensor)

            for _ in range(max_length):
                # Re-run the decoder on the whole prefix; build the tensor on
                # the same device as the features.
                caption_tensor = torch.tensor(caption, device=features.device).unsqueeze(0)
                outputs = decoder(features, caption_tensor)

                # Most likely token at the last time step
                predicted_idx = outputs.argmax(2)[:, -1].item()
                caption.append(predicted_idx)

                # If the predicted word is <EOS>, stop generating
                if predicted_idx == eos_idx:
                    break
    finally:
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)

    # Convert word indices to words, excluding the leading <SOS>
    return " ".join(vocab.itos[idx] for idx in caption[1:])

# Path to the image you want to caption
image_path = '/kaggle/input/flickr-8k-images/Flickr8k/Flickr8k_Dataset/Flicker8k_Dataset/1001773457_577c3a7d70.jpg'

# Load and preprocess the image. This reuses the `transform` pipeline defined
# in the image-transformations cell instead of redefining an identical copy.
image_tensor = load_image(image_path, transform)

# Generate the caption
caption = generate_caption(encoder, decoder, image_tensor, vocab)
print("Generated Caption:", caption)
