In [1]:
# Imports
import os
import re
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from collections import defaultdict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms

from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
# Config
IMAGE_DIR = "C:/Users/Idris/OneDrive/Desktop/ML PROJECT 3/Flickr8k_Dataset/Images"
CAPTION_FILE = "C:/Users/Idris/OneDrive/Desktop/ML PROJECT 3/Flickr8k_Dataset/captions.txt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
EMBEDDING_DIM = 256
HIDDEN_DIM = 512
NUM_EPOCHS = 5
MAX_LEN = 35


In [3]:
# 1. Charger les captions
def load_captions(filepath):
    captions = defaultdict(list)
    with open(filepath, 'r', encoding='utf-8') as f:
        lines = f.readlines()[1:]
        for line in lines:
            image_id, caption = line.strip().split(',', 1)
            caption = re.sub(r'[^a-zA-Z0-9 ]+', '', caption.lower())
            captions[image_id].append(f"<start> {caption} <end>")
    return captions

captions_dict = load_captions(CAPTION_FILE)

In [4]:
# 2. Build vocabulary
class Vocabulary:
    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        return text.lower().split()

    def build_vocab(self, sentence_list):
        frequencies = defaultdict(int)
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer(sentence):
                frequencies[word] += 1
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized = self.tokenizer(text)
        return [
            self.stoi.get(token, self.stoi["<unk>"])
            for token in tokenized
        ]

In [5]:
# 3. Custom Dataset
class FlickrDataset(Dataset):
    def __init__(self, captions_dict, image_dir, transform=None, freq_threshold=5):
        self.image_dir = image_dir
        self.transform = transform
        self.image_ids = list(captions_dict.keys())
        self.captions = []
        for k in self.image_ids:
            for cap in captions_dict[k]:
                self.captions.append((k, cap))

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab([cap for _, cap in self.captions])

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, index):
        image_id, caption = self.captions[index]
        img_path = os.path.join(self.image_dir, image_id)
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        numericalized_caption = [self.vocab.stoi["<start>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<end>"])

        return image, torch.tensor(numericalized_caption)

In [6]:
# 4. Collate Function
class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        captions = [item[1] for item in batch]
        captions = nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=self.pad_idx)
        return imgs, captions

# 5. Transforms and Loader
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = FlickrDataset(captions_dict, IMAGE_DIR, transform=transform)
pad_idx = dataset.vocab.stoi["<pad>"]
vocab = dataset.vocab

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=MyCollate(pad_idx))

In [7]:
import pickle

VOCAB_PATH = "C:/Users/Idris/OneDrive/Desktop/ML PROJECT 3/vocab.pkl"

# Sauvegarde du vocabulaire
with open(VOCAB_PATH, 'wb') as f:
    pickle.dump(vocab, f)

print(f"✅ Vocabulaire sauvegardé dans : {VOCAB_PATH}")


✅ Vocabulaire sauvegardé dans : C:/Users/Idris/OneDrive/Desktop/ML PROJECT 3/vocab.pkl


In [8]:
# 6. Encoder CNN + Decoder RNN
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.cnn = models.resnet18(pretrained=True)
        for param in self.cnn.parameters():
            param.requires_grad = False
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features, embed_size)

    def forward(self, images):
        return self.cnn(images)

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions[:, :-1])
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)
        return outputs


In [9]:
# 7. Full Model
class ImageCaptioningModel(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size):
        super().__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

    def forward(self, images, captions):
        features = self.encoder(images)
        return self.decoder(features, captions)

In [10]:
# 8. Entraînement

NUM_EPOCHS = 15  # ⬅️ Spécifie ici le nombre d'époques souhaité

model = ImageCaptioningModel(EMBEDDING_DIM, HIDDEN_DIM, len(dataset.vocab)).to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

model.train()
for epoch in range(NUM_EPOCHS):
    total_loss = 0
    for idx, (imgs, caps) in enumerate(tqdm(loader)):
        imgs, caps = imgs.to(DEVICE), caps.to(DEVICE)

        outputs = model(imgs, caps)  # (batch, seq_len, vocab_size)

        # Corriger la forme pour que outputs et targets soient alignés
        outputs = outputs[:, :-1, :]  # ignore prédiction après le dernier token
        targets = caps[:, 1:]         # ignore <SOS>

        # Aplatir pour CrossEntropyLoss
        outputs = outputs.reshape(-1, outputs.shape[2])  # (batch * seq_len, vocab_size)
        targets = targets.reshape(-1)                    # (batch * seq_len)

        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {total_loss/len(loader):.4f}")



100%|██████████| 1265/1265 [18:48<00:00,  1.12it/s]


Epoch [1/15], Loss: 3.8946


100%|██████████| 1265/1265 [18:30<00:00,  1.14it/s]


Epoch [2/15], Loss: 3.4339


100%|██████████| 1265/1265 [18:38<00:00,  1.13it/s]


Epoch [3/15], Loss: 3.2161


100%|██████████| 1265/1265 [18:29<00:00,  1.14it/s]


Epoch [4/15], Loss: 3.0725


100%|██████████| 1265/1265 [18:30<00:00,  1.14it/s]


Epoch [5/15], Loss: 2.9618


100%|██████████| 1265/1265 [18:29<00:00,  1.14it/s]


Epoch [6/15], Loss: 2.8659


100%|██████████| 1265/1265 [17:31<00:00,  1.20it/s]


Epoch [7/15], Loss: 2.7804


100%|██████████| 1265/1265 [16:41<00:00,  1.26it/s]


Epoch [8/15], Loss: 2.7013


100%|██████████| 1265/1265 [16:29<00:00,  1.28it/s]


Epoch [9/15], Loss: 2.6245


100%|██████████| 1265/1265 [16:28<00:00,  1.28it/s]


Epoch [10/15], Loss: 2.5509


100%|██████████| 1265/1265 [16:23<00:00,  1.29it/s]


Epoch [11/15], Loss: 2.4820


100%|██████████| 1265/1265 [16:22<00:00,  1.29it/s]


Epoch [12/15], Loss: 2.4127


100%|██████████| 1265/1265 [16:12<00:00,  1.30it/s]


Epoch [13/15], Loss: 2.3455


100%|██████████| 1265/1265 [16:01<00:00,  1.32it/s]


Epoch [14/15], Loss: 2.2788


100%|██████████| 1265/1265 [16:03<00:00,  1.31it/s]

Epoch [15/15], Loss: 2.2178





In [11]:
# 9. Sauvegarder le modèle
torch.save(model.state_dict(), "model.pth")
print("✅ Modèle sauvegardé sous model.pth")

✅ Modèle sauvegardé sous model.pth
