# Captioning Model

This notebook contains the code to train the captioning models (decoders). We load the encoders that were pre-trained on the 18-attribute data, and we compare an LSTM decoder with a (refining) transformer decoder architecture.

In [None]:
import utils.load_funcs
import json
import torch,torchvision
from torch import nn
from torchsummary import summary
# Use the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(device)

In [None]:
# Load Data
train_loader, val_loader = utils.load_funcs.get_data_loaders()
# Pull one batch from the training loader to sanity-check what it yields.
first_batch = next(iter(train_loader))
images, labels, captions = first_batch
for batch_part in (images, labels):
    print(batch_part.shape)
print(captions, captions.shape)

In [None]:
# Preview the dataset's caption table (presumably a pandas DataFrame - confirm).
# Kept as the cell's final bare expression so the notebook displays the result.
train_loader.dataset.captions.head()

In [None]:
# Get the tokenizer and vocab dictionary
tokenizer = train_loader.dataset.tokenizer
# The tokenizer config stores 'index_word' as a JSON string (index -> word).
# Invert it into word -> index and shift every index down by one.
index_to_word = json.loads(tokenizer.get_config()['index_word'])
vocab = {}
for raw_index, word in index_to_word.items():
    vocab[word] = int(raw_index) - 1
print(vocab)
print('Vocab Length: ', len(vocab))

In [None]:
# Define Classes for Encoder (Classifier)/Decoder
class AttributeClassifier(torch.nn.Module):
    """Multi-head classifier: one independent linear head ("fork") per attribute.

    Args:
        in_features: Size of the feature vector fed to every head.
        class_counts: Optional iterable with the number of classes for each
            head. When omitted, the module-level ``attribute_classes`` list is
            used, which must be defined before the classifier is instantiated.
    """

    def __init__(self, in_features, class_counts=None) -> None:
        super().__init__()
        if class_counts is None:
            # Backward-compatible fallback to the notebook-level configuration.
            class_counts = attribute_classes
        # The attribute name `forks` is part of the state-dict key layout.
        self.forks = torch.nn.ModuleList(
            torch.nn.Linear(in_features=in_features, out_features=class_count)
            for class_count in class_counts
        )

    def forward(self, x):
        """Apply every head to ``x`` and return the outputs as a list, in head order."""
        return [fork(x) for fork in self.forks]

class ClassifierModel(torch.nn.Module):
    """Feature-extracting backbone followed by an AttributeClassifier head.

    Args:
        backbone: Module applied to the input first.
        backbone_out_features: Feature size passed to AttributeClassifier as
            its ``in_features``.
    """

    def __init__(self, backbone, backbone_out_features) -> None:
        super().__init__()
        self.backbone = backbone
        self.classifier = AttributeClassifier(backbone_out_features)

    def forward(self, x):
        """Return the classifier's output for the backbone features of ``x``."""
        features = self.backbone(x)
        return self.classifier(features)

# Define LSTM Decoder
class DecoderRNN(nn.Module):
    """LSTM-cell caption decoder conditioned on an image feature vector.

    At every time step the input to the LSTM cell is the image feature vector
    concatenated with the embedding of a token; the hidden state is projected
    to vocabulary logits.

    Args:
        embed_size: Dimension of the token embeddings.
        feature_size: Dimension of the image feature vector.
        hidden_size: Hidden/cell state size of the LSTM cell.
        vocab_size: Number of tokens; size of the embedding table and of the
            output logits.
        num_layers: Accepted for interface compatibility; the decoder is built
            from a single ``nn.LSTMCell`` and does not use this value.
        max_caption_length: Caption length the decoder is sized for. The
            decoder always produces ``max_caption_length - 1`` steps.
    """

    def __init__(self, embed_size, feature_size, hidden_size, vocab_size, num_layers=1,
                 max_caption_length=109):
        super().__init__()

        # define the properties
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.max_caption_length = max_caption_length

        # embedding layer
        self.embed = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embed_size)

        # lstm cell
        self.lstm_cell = nn.LSTMCell(input_size=embed_size + feature_size, hidden_size=hidden_size)

        # output fully connected layer
        self.fc_out = nn.Linear(in_features=self.hidden_size, out_features=self.vocab_size)

    def forward(self, features, captions, mode='train'):
        """Produce vocabulary logits for every decoding step.

        Args:
            features: Image features of shape ``(batch, feature_size)``.
            captions: Token indices of shape ``(batch, length)``. In train
                mode ``length`` must be at least ``max_caption_length - 1``
                (teacher forcing). In any other mode only column 0 (the start
                token) is read.
            mode: ``'train'`` for teacher forcing; anything else runs greedy
                decoding that feeds back the arg-max token of the previous step.

        Returns:
            Tensor of shape ``(batch, max_caption_length - 1, vocab_size)``.
        """
        batch_size = features.size(0)
        # Follow the input's device instead of relying on a global variable.
        run_device = features.device
        captions = captions.to(run_device)
        num_steps = self.max_caption_length - 1

        # init the hidden and cell states to zeros
        hidden_state = torch.zeros((batch_size, self.hidden_size), device=run_device)
        cell_state = torch.zeros((batch_size, self.hidden_size), device=run_device)

        # define the output tensor placeholder
        outputs = torch.zeros((batch_size, num_steps, self.vocab_size), device=run_device)

        if mode == 'train':
            # Teacher forcing: the ground-truth token at step t is the input at step t.
            embeddings = self.embed(captions.int())
            tiled_features = features.unsqueeze(1).expand(-1, embeddings.shape[1], -1)
            step_inputs = torch.cat((tiled_features, embeddings), dim=-1)
            for t in range(num_steps):
                hidden_state, cell_state = self.lstm_cell(
                    step_inputs[:, t, :], (hidden_state, cell_state)
                )
                outputs[:, t, :] = self.fc_out(hidden_state)
        else:
            # Greedy decoding: start from the supplied first token, then feed
            # back the most likely token of the previous step. The loop is
            # bounded by the output buffer length so it always terminates.
            prev_tokens = captions[:, 0].int()
            for t in range(num_steps):
                # features stays 2-D here so it can be concatenated with the
                # 2-D embedding of a single token per sample.
                step_input = torch.cat((features, self.embed(prev_tokens)), dim=-1)
                hidden_state, cell_state = self.lstm_cell(
                    step_input, (hidden_state, cell_state)
                )
                logits = self.fc_out(hidden_state)
                outputs[:, t, :] = logits
                prev_tokens = torch.argmax(logits, dim=-1)
        return outputs

# Define Full Captioning Model Class which has a encoder+decoder
class CaptionModel(nn.Module):
    """Full captioning model: an image encoder followed by a caption decoder.

    Args:
        encoder: Callable mapping a batch of images to feature vectors.
        decoder: Callable invoked as ``decoder(features, captions, mode)``.
        vocab: Mapping from token string to index; must contain ``'sos'``.
    """

    def __init__(self, encoder, decoder, vocab) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.vocab = vocab

    def forward(self, images, captions, mode='train'):
        """Encode ``images`` and decode captions.

        In ``'train'`` mode the given ``captions`` are passed to the decoder
        unchanged. In any other mode ``captions`` is ignored and replaced by a
        ``(batch, 1)`` tensor holding the start-token index.
        """
        features = self.encoder(images)
        if mode != 'train':
            start_token = self.vocab['sos']
            # Create the start tokens on the same device as the images so the
            # decoder does not hit a CPU/GPU mismatch.
            captions = torch.full(
                (images.shape[0], 1),
                fill_value=start_token,
                dtype=torch.long,
                device=images.device,
            )
        return self.decoder(features, captions, mode)

In [None]:
# Initialize decoder
# NOTE: num_layers is forwarded to DecoderRNN, which accepts it but builds a
# single LSTMCell as defined in this notebook.
decoder_settings = {
    'embed_size': 512,
    'feature_size': 768,
    'hidden_size': 512,
    'vocab_size': len(vocab),
    'num_layers': 3,
}
LSTM_decoder = DecoderRNN(**decoder_settings)

In [None]:
# Check number of parameters
# Count only the elements of parameters that take part in gradient updates.
pytorch_total_params = 0
for param in LSTM_decoder.parameters():
    if param.requires_grad:
        pytorch_total_params += param.numel()
print(pytorch_total_params)

In [None]:
# Load trained encoder(s)
# Number of classes for each of the 18 attribute heads, grouped by family.
attribute_classes = [
    6, 5, 4, 3, 5, 3, 3, 3, 5, 8, 3, 3, #Shape Attributes
    8, 8, 8, #Fabric Attributes
    8, 8, 8 #Color Attributes
]

# Swin-T backbone with its classification head replaced by a pass-through,
# wrapped together with the multi-head attribute classifier.
backbone = torchvision.models.swin_t()
backbone.head = torch.nn.Identity()
transformer_encoder = ClassifierModel(backbone, 768)
# We load the transformer attribute prediction model which had ~0.9 accuracy
# map_location='cpu' lets a checkpoint saved from a GPU session load on a
# CPU-only machine; load_state_dict copies the values into the model's own
# parameters.
checkpoint = torch.load(
    './models/transformer_unfreeze_attribute_model.pth',
    map_location='cpu',
)
transformer_encoder.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Drop Classifier Head and just keep feature extractor (backbone)
transformer_encoder = transformer_encoder.backbone
# Freeze params
transformer_encoder.requires_grad_(False)
transformer_caption_model = CaptionModel(
    encoder=transformer_encoder,
    decoder=LSTM_decoder,
    vocab=vocab,
)
print(transformer_caption_model)

In [None]:
# Training the model
from utils.train_funcs import fit

# Training hyper-parameters
epochs = 3
learning_rate = 1e-3
loss_func = torch.nn.CrossEntropyLoss()
# Only the decoder's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(transformer_caption_model.decoder.parameters(), lr=learning_rate)

# fit comes from utils.train_funcs; the arguments below are positional, so
# their order must match that function's signature.
fit(
    transformer_caption_model,
    train_loader,
    val_loader,
    vocab,
    optimizer,
    loss_func,
    epochs,
    device,
    name='rnn_decoder'
)