# In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable
from torchvision import transforms
from PIL import Image

# Build the image encoder from a pre-trained ResNet-50.
pretrained_resnet = models.resnet50(pretrained=True)

# Keep every child module except the last one (for torchvision's ResNet this is
# the final classification layer), so the network outputs features, not logits.
modules = list(pretrained_resnet.children())[:-1]
resnet = nn.Sequential(*modules)

# The encoder is used for inference only: freeze its weights and switch it to
# evaluation mode.
resnet.requires_grad_(False)
resnet.eval()

# Image preprocessing pipeline, applied in order:
#   1. resize to 224x224 pixels,
#   2. convert the PIL image to a tensor,
#   3. normalize each of the three channels with the given mean and std
#      (these values match the ImageNet statistics documented for torchvision's
#      pre-trained models).
preprocessing_steps = [
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
]
transform = transforms.Compose(preprocessing_steps)

# LSTM-based captioning model
class CaptionGenerator(nn.Module):
    """LSTM decoder that scores vocabulary tokens for an image and a caption prefix.

    The image feature vector is fed to the LSTM as the first time step,
    followed by the embedded caption tokens.

    Args:
        embed_size: Size of each token embedding, which is also the LSTM input
            size.
        hidden_size: Size of the LSTM hidden state.
        vocab_size: Number of tokens in the vocabulary.
        feature_size: Size of the incoming image feature vectors. When given,
            a linear layer projects the features to ``embed_size`` before they
            are joined with the embeddings. When ``None`` (the default) no
            projection is created and the features must already have
            ``embed_size`` entries, which is the original behavior.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, feature_size=None):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # Created last, and only on request, so that models built without
        # ``feature_size`` keep the same parameters and state_dict keys.
        self.feature_proj = (
            nn.Linear(feature_size, embed_size) if feature_size is not None else None
        )

    def forward(self, features, captions):
        """Return vocabulary scores for every position of the decoded sequence.

        Args:
            features: Tensor of shape ``(batch, feature_size)`` when a
                projection was configured, otherwise ``(batch, embed_size)``.
            captions: Integer tensor of shape ``(batch, seq_len)`` holding
                token indices.

        Returns:
            Tensor of shape ``(batch, seq_len + 1, vocab_size)``. Position 0
            corresponds to the image features, position ``i + 1`` to caption
            token ``i``.

        Raises:
            RuntimeError: Raised by ``torch.cat`` when the (projected) feature
                size differs from ``embed_size``.
        """
        if self.feature_proj is not None:
            features = self.feature_proj(features)
        embeddings = self.embed(captions)
        # Prepend the image as one extra time step in front of the caption
        # tokens; this requires both tensors to share the last dimension.
        lstm_input = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        lstm_out, _ = self.lstm(lstm_input)
        return self.linear(lstm_out)

# Read the sample image from disk, force three RGB channels, run it through
# the preprocessing pipeline and add a leading batch dimension of size 1.
image_path = 'path/to/your/image.jpg'
rgb_image = Image.open(image_path).convert('RGB')
image = transform(rgb_image).unsqueeze(0)

# Run the frozen ResNet encoder without tracking gradients, then flatten every
# dimension after the batch dimension into a single feature vector per image.
with torch.no_grad():
    encoder_output = resnet(image)
    image_features = encoder_output.view(encoder_output.shape[0], -1)

# Load your vocabulary
# Note: In a real scenario, you would need a vocabulary that maps words to
# indices (and indices back to words for decoding).
vocab_size = 10000  # Adjust based on your vocabulary size
START_TOKEN = 1  # Index fed to the decoder as the first caption token
END_TOKEN = 2  # Index that stops generation when predicted

# Create the captioning model
embed_size = 256
hidden_size = 512
caption_model = CaptionGenerator(embed_size, hidden_size, vocab_size)
caption_model.eval()  # Inference only

# CaptionGenerator.forward concatenates the image features with the token
# embeddings along the time axis, so both need the same last dimension.
# Project the flattened encoder features to embed_size when they differ.
# NOTE: like caption_model, this layer is randomly initialized here; load
# trained weights for both to obtain meaningful captions.
feature_projection = None
if image_features.size(1) != embed_size:
    feature_projection = nn.Linear(image_features.size(1), embed_size)

# Generate a caption with greedy decoding
max_caption_length = 20
caption_tokens = torch.tensor([[START_TOKEN]])  # Grows by one token per step

with torch.no_grad():
    if feature_projection is None:
        decoder_features = image_features
    else:
        decoder_features = feature_projection(image_features)

    for _ in range(max_caption_length):
        caption_output = caption_model(decoder_features, caption_tokens)
        # The last time step holds the scores for the token that follows the
        # current prefix; pick the highest-scoring index.
        predicted_token = caption_output[0, -1].argmax().item()

        if predicted_token == END_TOKEN:
            break

        next_token = torch.tensor([[predicted_token]])
        caption_tokens = torch.cat((caption_tokens, next_token), dim=1)

# Convert the predicted indices to words (replace with your vocabulary).
# NOTE: your_vocabulary_index_to_word_mapping is a placeholder that is not
# defined in this snippet; it must exist before this line runs.
generated_indices = caption_tokens[0, 1:].tolist()  # Exclude the start token
predicted_caption = ' '.join(
    your_vocabulary_index_to_word_mapping[token] for token in generated_indices
)

print("Predicted Caption:", predicted_caption)
