# In [4]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence


class EncoderCNN(nn.Module):
    """Encode images into fixed-size embeddings with a pre-trained ResNet-152.

    The backbone is torchvision's ResNet-152 with its final child module (the
    ``fc`` ImageNet classifier) removed; a new linear layer projects the
    flattened backbone output to ``embed_size``.
    """

    def __init__(self, embed_size):
        """
        Args:
            embed_size: Length of the embedding vector produced per image.
        """
        super(EncoderCNN, self).__init__()
        resnet = self._load_pretrained_resnet152()
        # torchvision's ResNet registers `fc` as its last child; keep every
        # module before it and replace the classifier with our own projection.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)

    @staticmethod
    def _load_pretrained_resnet152():
        """Build ResNet-152 with the ImageNet checkpoint ``pretrained=True`` selected.

        torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``.
        ``IMAGENET1K_V1`` is the checkpoint the legacy flag maps to (``DEFAULT``
        would silently switch to the newer V2 weights). Older torchvision
        releases have no weights enum, so fall back to the legacy flag there.
        """
        weights_enum = getattr(models, "ResNet152_Weights", None)
        if weights_enum is None:
            return models.resnet152(pretrained=True)
        return models.resnet152(weights=weights_enum.IMAGENET1K_V1)

    def forward(self, images):
        """Return one ``embed_size`` vector per image.

        Args:
            images: Image batch fed to the ResNet backbone; presumably
                (batch, 3, H, W) with ImageNet normalization — confirm
                against the data pipeline.

        Returns:
            Tensor of shape (batch, embed_size).
        """
        features = self.resnet(images)
        # Collapse everything after the batch dim before the linear projection.
        features = torch.flatten(features, 1)
        return self.fc(features)


class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image embedding.

    The image embedding is fed as the first LSTM input step, followed by the
    caption token embeddings.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """
        Args:
            embed_size: Word-embedding size. Image features are concatenated
                with word embeddings along the time axis, so this must equal
                the image feature size.
            hidden_size: LSTM hidden-state size.
            vocab_size: Number of tokens in the vocabulary (logit width).
            num_layers: Number of stacked LSTM layers.
            max_seq_length: Number of tokens produced by ``sample``.
        """
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seq_length = max_seq_length

    def forward(self, features, captions, lengths):
        """Score every valid time step of a padded caption batch (teacher forcing).

        Args:
            features: Image embeddings, shape (batch, embed_size).
            captions: Padded token ids, shape (batch, seq_len).
            lengths: Valid length of each sequence *including* the prepended
                image step (so each value must be <= seq_len + 1). May be a
                Python sequence of ints or a 1-D tensor on any device.

        Returns:
            Logits of shape (sum(lengths), vocab_size). Rows follow the data
            layout of a PackedSequence: time-major over the batch sorted by
            decreasing length. Targets line up row-for-row with
            ``pack_padded_sequence(targets, lengths, batch_first=True,
            enforce_sorted=False).data``.
        """
        embeddings = self.embed(captions)
        # The image embedding is the first input step of every sequence.
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)

        # pack_padded_sequence needs a 1-D int64 CPU lengths tensor; this also
        # accepts lists and GPU tensors. enforce_sorted=False performs the
        # descending-length sort that packing requires.
        lengths = torch.as_tensor(lengths, dtype=torch.int64, device="cpu")
        packed = pack_padded_sequence(
            embeddings, lengths, batch_first=True, enforce_sorted=False
        )
        hiddens, _ = self.lstm(packed)
        return self.linear(hiddens.data)

    def sample(self, features, states=None):
        """Greedily decode ``self.max_seq_length`` token ids per image.

        Args:
            features: Image embeddings, shape (batch, embed_size).
            states: Optional initial LSTM state ``(h_0, c_0)``; ``None`` lets
                ``nn.LSTM`` start from zeros.

        Returns:
            Tensor of shape (batch, max_seq_length) holding the argmax token id
            at each step. Decoding always runs the full length; it does not
            stop at an end token.
        """
        sampled_ids = []
        inputs = features.unsqueeze(1)  # (batch, 1, embed_size): a single time step
        for _ in range(self.max_seq_length):
            hiddens, states = self.lstm(inputs, states)
            logits = self.linear(hiddens.squeeze(1))
            predicted = logits.argmax(dim=1)
            sampled_ids.append(predicted)
            # Feed the greedy prediction back in as the next input step.
            inputs = self.embed(predicted).unsqueeze(1)
        return torch.stack(sampled_ids, 1)


# Instantiate the encoder and decoder models
embed_size = 256
hidden_size = 512
vocab_size = 10000
num_layers = 1
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

# Smoke-test both models on dummy data.
# Run in eval mode under no_grad: a train-mode forward pass would update the
# ResNet's BatchNorm running statistics with statistics of random noise, and
# those buffers are part of the state_dict saved below.
encoder.eval()
decoder.eval()
with torch.no_grad():
    batch_size = 3
    seq_length = 4
    image_features = torch.randn(batch_size, 3, 224, 224)  # 4D input tensor
    captions = torch.randint(0, vocab_size, (batch_size, seq_length))
    # Lengths count the prepended image step, so each may be at most seq_length + 1.
    lengths = torch.randint(1, seq_length, (batch_size,))

    features = encoder(image_features)
    outputs = decoder(features, captions, lengths)
    print(outputs.shape)  # (sum(lengths), vocab_size)

    sampled_ids = decoder.sample(features)
    print(sampled_ids.shape)  # (batch_size, decoder.max_seq_length)
# Restore the default training mode for any later training cells.
encoder.train()
decoder.train()

# Save the encoder and decoder weights; torch.save does not create directories.
import os

os.makedirs('./model', exist_ok=True)
torch.save(encoder.state_dict(), './model/encoder.pkl')
torch.save(decoder.state_dict(), './model/decoder.pkl')

# Recorded output of the last run: torch.Size([9, 10000])
