# Captioning Model

This notebook will contain the code to train the LSTM captioning models (decoder). We load the swin-t encoder that was pre-trained on the 18-attribute data from the previous notebook. 

In [1]:
import utils.load_funcs
import json
import torch,torchvision
from torch import nn
from torchsummary import summary
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

print(device)

cuda


In [2]:
# Load Data
train_loader, val_loader = utils.load_funcs.get_data_loaders()
images, labels, captions = next(iter(train_loader))
print(images.shape)
print(labels.shape)
print(captions, captions.shape)

torch.Size([64, 3, 329, 224])
torch.Size([64, 18])
tensor([[15,  2, 49,  ...,  0,  0,  0],
        [15, 53, 49,  ...,  0,  0,  0],
        [15,  2, 82,  ...,  0,  0,  0],
        ...,
        [15,  2, 33,  ...,  0,  0,  0],
        [15,  2, 64,  ...,  0,  0,  0],
        [15,  2, 33,  ...,  0,  0,  0]]) torch.Size([64, 109])


In [3]:
train_loader.dataset.captions.head()

Unnamed: 0,0,sequence
MEN-Denim-id_00000080-01_7_additional.jpg,The lower clothing is of long length. The fabr...,"[15, 2, 78, 44, 3, 34, 35, 54, 1, 2, 5, 3, 10,..."
MEN-Denim-id_00000089-01_7_additional.jpg,"His tank top has sleeves cut off, cotton fabri...","[15, 53, 20, 45, 12, 30, 67, 68, 25, 10, 5, 8,..."
MEN-Denim-id_00000089-02_7_additional.jpg,"His sweater has long sleeves, cotton fabric an...","[15, 53, 49, 12, 35, 30, 25, 10, 5, 8, 84, 7, ..."
MEN-Denim-id_00000089-03_7_additional.jpg,"His shirt has short sleeves, cotton fabric and...","[15, 53, 17, 12, 51, 30, 25, 10, 5, 8, 28, 14,..."
MEN-Denim-id_00000089-04_7_additional.jpg,"The sweater the person wears has long sleeves,...","[15, 2, 49, 2, 26, 9, 12, 35, 30, 25, 47, 5, 3..."


In [4]:
# Get the tokenizer and vocab dictionary
tokenizer = train_loader.dataset.tokenizer
vocab = json.loads(tokenizer.get_config()['index_word'])
vocab = {v: int(k)-1 for k, v in vocab.items()}
print(vocab)
print('Vocab Length: ', len(vocab))

{'unk': 0, '.': 1, 'the': 2, 'is': 3, 'a': 4, 'fabric': 5, 'with': 6, 'patterns': 7, 'and': 8, 'wears': 9, 'cotton': 10, 'her': 11, 'has': 12, 'this': 13, 'color': 14, 'sos': 15, 'eos': 16, 'shirt': 17, 'on': 18, 'there': 19, 'tank': 20, 'it': 21, 'neckline': 22, 'ring': 23, 'an': 24, ',': 25, 'person': 26, 'accessory': 27, 'pure': 28, 'solid': 29, 'sleeves': 30, 'wearing': 31, 'lady': 32, 'female': 33, 'of': 34, 'long': 35, 'pants': 36, 'are': 37, 'three': 38, 'finger': 39, 'wrist': 40, 'graphic': 41, 'point': 42, 'shorts': 43, 'clothing': 44, 'top': 45, 'sleeve': 46, 'its': 47, 'woman': 48, 'sweater': 49, 'denim': 50, 'short': 51, 'in': 52, 'his': 53, 'length': 54, 'crew': 55, 'round': 56, 'chiffon': 57, 't': 58, 'neck': 59, 'neckwear': 60, 'trousers': 61, 'outer': 62, 'hat': 63, 'upper': 64, 'no': 65, 'sleeveless': 66, 'cut': 67, 'off': 68, 'knitting': 69, 'suspenders': 70, 'pattern': 71, 'lapel': 72, 'floral': 73, 'medium': 74, 'v': 75, 'shape': 76, 'head': 77, 'lower': 78, 'other'

We define our LSTM Decoder and Caption Model classes. The LSTM Decoder we choose has an embedding layer of size 512 with 3 LSTM layers, and each LSTM layer has a hidden dimension of 512. Our feature extraction backbone (encoder) has the classifier head removed, which results in an output feature map of size (1x768).

In [5]:
# Define Classes for Encoder (Classifier)/Decoder
class AttributeClassifier(torch.nn.Module):
    def __init__(self, in_features) -> None:
        super().__init__()
        self.forks = torch.nn.ModuleList()
        for class_count in attribute_classes:
            fork = torch.nn.Linear(in_features=in_features, out_features=class_count)
            self.forks.append(fork)
    
    def forward(self, x):
        out = []
        for index,fork in enumerate(self.forks):
            out_fork = fork(x) #Classification
            out.append(out_fork)
        return out

class ClassifierModel(torch.nn.Module):
    def __init__(self, backbone, backbone_out_features) -> None:
        super().__init__()
        self.backbone = backbone
        self.classifier = AttributeClassifier(backbone_out_features)
    
    def forward(self, x):
        out = self.backbone(x)
        out = self.classifier(out)
        return out

# Define LSTM Decoder
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, feature_size, hidden_size, vocab_size):
        super(DecoderRNN, self).__init__()
        
        # define the properties
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        
        # embedding layer
        self.embed = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embed_size)
        
        # lstm cells
        self.lstm_cell_1 = nn.LSTMCell(input_size=embed_size+feature_size, hidden_size=hidden_size)
        self.lstm_cell_2 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
        self.lstm_cell_3 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
    
        # output fully connected layer
        self.fc_out = nn.Linear(in_features=self.hidden_size, out_features=self.vocab_size)
        
        # dropout layer
        self.dropout = nn.Dropout(0.4)
    
    def forward(self, features, captions, mode='train'):
        # batch size
        batch_size = features.size(0)
        features = torch.unsqueeze(features, dim=1)
        # init the hidden and cell states to zeros
        hidden_state_1 = torch.zeros((batch_size, self.hidden_size)).to(device, non_blocking=True)
        cell_state_1 = torch.zeros((batch_size, self.hidden_size)).to(device, non_blocking=True)
        hidden_state_2 = torch.zeros((batch_size, self.hidden_size)).to(device, non_blocking=True)
        cell_state_2 = torch.zeros((batch_size, self.hidden_size)).to(device, non_blocking=True)
        hidden_state_3 = torch.zeros((batch_size, self.hidden_size)).to(device, non_blocking=True)
        cell_state_3 = torch.zeros((batch_size, self.hidden_size)).to(device, non_blocking=True)
        max_caption_length = 109
        
        # define the output tensor placeholder
        outputs = torch.zeros((batch_size, max_caption_length - 1, self.vocab_size)).to(device, non_blocking=True)
        # Embedding the captions
        embeddings = self.embed(captions.int())
        # Concat Embeddings with features
        embeddings = torch.cat((features.expand((-1, embeddings.shape[1], -1)), embeddings), dim = -1) #shape = (batch_size, 95, 768+512=1280)
        # Pass the caption word by word in train mode
        if mode == 'train':
            #embeddings = torch.roll(embeddings, shifts=-1, dims=-1)
            for t in range(outputs.size(1)):
                hidden_state_1, cell_state_1 = self.lstm_cell_1(embeddings[:, t, :], (hidden_state_1, cell_state_1))
                hidden_state_1 = self.dropout(hidden_state_1)
                hidden_state_2, cell_state_2 = self.lstm_cell_2(hidden_state_1, (hidden_state_2, cell_state_2))
                hidden_state_2 = self.dropout(hidden_state_2)
                hidden_state_3, cell_state_3 = self.lstm_cell_3(hidden_state_2, (hidden_state_3, cell_state_3))
                hidden_state_3 = self.dropout(hidden_state_3)
                out = self.fc_out(hidden_state_3)
                # build the output tensor
                outputs[:, t, :] = out
        # In test mode, we generate until length = max_caption_length
        else:
            t = 0
            while t < max_caption_length - 1:
                # First time step - feed <sos> token
                if t == 0:
                    hidden_state_1, cell_state_1 = self.lstm_cell_1(embeddings[:, 0, :], (hidden_state_1, cell_state_1))
                    hidden_state_1 = self.dropout(hidden_state_1)
                    hidden_state_2, cell_state_2 = self.lstm_cell_2(hidden_state_1, (hidden_state_2, cell_state_2))
                    hidden_state_2 = self.dropout(hidden_state_2)
                    hidden_state_3, cell_state_3 = self.lstm_cell_3(hidden_state_2, (hidden_state_3, cell_state_3))
                    hidden_state_3 = self.dropout(hidden_state_3)
                else:
                    prev_output = outputs[:, t-1, :]
                    prev_output = torch.argmax(prev_output, dim=-1)
                    prev_output = self.embed(prev_output.int())
                    prev_output = torch.cat((features.squeeze(dim=1), prev_output), dim=-1)                    
                    hidden_state_1, cell_state_1 = self.lstm_cell_1(prev_output, (hidden_state_1, cell_state_1))
                    hidden_state_1 = self.dropout(hidden_state_1)
                    hidden_state_2, cell_state_2 = self.lstm_cell_2(hidden_state_1, (hidden_state_2, cell_state_2))
                    hidden_state_2 = self.dropout(hidden_state_2)
                    hidden_state_3, cell_state_3 = self.lstm_cell_3(hidden_state_2, (hidden_state_3, cell_state_3))
                    hidden_state_3 = self.dropout(hidden_state_3)
                out = self.fc_out(hidden_state_3)
                outputs[:, t, :] = out
                t += 1
        return outputs

# Define Full Captioning Model Class which has a encoder+decoder
class CaptionModel(nn.Module):
    def __init__(self, encoder, decoder, vocab) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.vocab = vocab
    
    def forward(self, images, captions, mode='train'):
        features = self.encoder(images)
        out = self.decoder(features, captions, mode)
        return out

In [6]:
# Initialize LSTM decoder
LSTM_decoder = DecoderRNN(embed_size=512, feature_size=768, hidden_size=512, vocab_size=len(vocab))

In [7]:
# Check number of parameters
pytorch_total_params = sum(p.numel() for p in LSTM_decoder.parameters() if p.requires_grad)
print(pytorch_total_params)

7988333


In [8]:
# Load trained encoder(s)
attribute_classes = [
    6, 5, 4, 3, 5, 3, 3, 3, 5, 8, 3, 3, #Shape Attributes
    8, 8, 8, #Fabric Attributes
    8, 8, 8 #Color Attributes
]

backbone = torchvision.models.swin_t()
backbone.head = torch.nn.Identity()
transformer_encoder = ClassifierModel(backbone, 768)
# We load the transformer attribute prediction model which had ~0.9 accuracy
transformer_encoder.load_state_dict(
    torch.load('./models/transformer_unfreeze_attribute_model.pth')['model_state_dict']
)

<All keys matched successfully>

In [9]:
# Drop Classifier Head and just keep feature extractor (backbone)
transformer_encoder = transformer_encoder.backbone
# Freeze params
for param in transformer_encoder.parameters():
    param.requires_grad = False
transformer_caption_model = CaptionModel(transformer_encoder, LSTM_decoder, vocab)
print(transformer_caption_model)

CaptionModel(
  (encoder): SwinTransformer(
    (features): Sequential(
      (0): Sequential(
        (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        (1): Permute()
        (2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      )
      (1): Sequential(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): ShiftedWindowAttention(
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (proj): Linear(in_features=96, out_features=96, bias=True)
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
          (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (0): Linear(in_features=96, out_features=384, bias=True)
            (1): GELU(approximate=none)
            (2): Dropout(p=0.0, inplace=False)
            (3): Linear(in_features=384, out_features=96, bias=True)
            (4): Dropout(p=0.0, inplac

In [None]:
# Training the LSTM model
from utils.train_funcs import fit

epochs = 5
learning_rate = 1e-3
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(transformer_caption_model.decoder.parameters(), lr=learning_rate)

fit(
    transformer_caption_model,
    train_loader,
    val_loader,
    vocab,
    optimizer,
    loss_func,
    epochs,
    device,
    name='rnn_decoder'
)

Epoch 1 train: 100%|█| 602/602 [31:49<00:00,  3.17s/batch, BLEU=0.2


Epoch 1 train loss: 36.58778645268772 train BLEU: 0.22891239821910858


Epoch 1 val:  22%|▏| 14/63 [00:30<01:43,  2.11s/batch, BLEU=0.199, 

## Results

Our Swin-LSTM captioning model achieves a BLEU4 of 0.3028 and 0.2020 on the train and validation sets respectively. We find that these are considerably high scores considering that we use the smallest version of Swin Transformer and a 8M parameter LSTM. We print out some of the predictions on samples from the validation set, and find that the model is able to produce high-quality captions from the image features. The model sometimes even produces sentences for fashion attributes which are not in the ground truth caption, but are actually present in the image.

We first display screenshots from tensorboard for BLEU4 scores of the model:

In [46]:
from IPython.display import Image

Image(url="./tensorboard_screens/LSTM_train.png", width=600, height=600)

In [47]:
Image(url="./tensorboard_screens/LSTM_val.png", width=600, height=600)

In [11]:
# Perform some example captioning
images, labels, captions = next(iter(val_loader))

In [26]:
import numpy as np
from utils.train_funcs import get_predictions,seq2text

images = images.to(device)
captions = captions.to(device)
start_token = vocab['sos']
captions_sos = torch.full((images.shape[0],1), fill_value=start_token).to(device, non_blocking=True)
transformer_caption_model.eval()
transformer_caption_model.to(device)
outputs = transformer_caption_model(images, captions_sos, 'test')
preds = get_predictions(outputs, shape=(images.shape[0],outputs.shape[1]), device=device)

hypothesis, reference = seq2text(preds[0], captions[0], vocab)
val_data = np.load('labels/validation_data.npy', allow_pickle=True)

print('Predicted caption: \n', ' '.join(hypothesis).split('eos')[0])
print('*'*30)
print('Actual caption: \n', ' '.join(reference))
Image(url="../data/images_224x329/"+val_data[0,0], width=300, height=300)

Predicted caption: 
 the tank top this woman wears has no sleeves and it is with cotton fabric and graphic patterns . the neckline of the tank top is round . this woman wears a three point shorts , with denim fabric and pure color patterns . there is an accessory on her wrist . 
******************************
Actual caption: 
 the tank shirt this female wears has sleeves cut off , its fabric is cotton , and it has solid color patterns . the tank shirt has a suspenders neckline . this female wears a three point shorts , with denim fabric and pure color patterns . this person has neckwear . there is an accessory on her wrist . eos


In [28]:
hypothesis, reference = seq2text(preds[10], captions[10], vocab)
print('Predicted caption: \n', ' '.join(hypothesis).split('eos')[0])
print('*'*30)
print('Actual caption: \n', ' '.join(reference))
Image(url="../data/images_224x329/"+val_data[10,0], width=300, height=300)

Predicted caption: 
 the t shirt this woman wears has short sleeves and it is with cotton fabric and pure color patterns . the neckline of the t shirt is v shape . 
******************************
Actual caption: 
 her t shirt has short sleeves , cotton fabric and solid color patterns . it has a v shape neckline . eos


In [39]:
hypothesis, reference = seq2text(preds[24], captions[24], vocab)
print('Predicted caption: \n', ' '.join(hypothesis).split('eos')[0])
print('*'*30)
print('Actual caption: \n', ' '.join(reference))
Image(url="../data/images_224x329/"+val_data[24,0], width=300, height=300)

Predicted caption: 
 the upper clothing has short sleeves , cotton fabric and pure color patterns . the neckline of it is v shape . the lower clothing is of long length . the fabric is denim and it has pure color patterns . 
******************************
Actual caption: 
 the guy wears a short sleeve shirt with pure color patterns . the shirt is with cotton fabric . it has a v shape neckline . the trousers the guy wears is of long length . the trousers are with denim fabric and pure color patterns . eos


### Training Model with ImageNet Weights Encoder

To show that our encoder performs well when trained to predict fashion attributes, we compare it with an encoder loaded with pretrained ImageNet-1K weights. All other training steps are identical:

In [42]:
# Load encoder with ImageNet weights
transformer_encoder_imgnet = torchvision.models.swin_t(weights='IMAGENET1K_V1')
transformer_encoder_imgnet.head = torch.nn.Identity()
# Initialize another LSTM decoder
LSTM_decoder_imgnet = DecoderRNN(embed_size=512, feature_size=768, hidden_size=512, vocab_size=len(vocab))
# Freeze params in encoder
for param in transformer_encoder_imgnet.parameters():
    param.requires_grad = False
transformer_imgnet_model = CaptionModel(transformer_encoder_imgnet, LSTM_decoder_imgnet, vocab)

In [None]:
# Train the model with ImageNet Swin-T backbone
epochs = 5
learning_rate = 1e-3
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(transformer_imgnet_model.decoder.parameters(), lr=learning_rate)

fit(
    transformer_imgnet_model,
    train_loader,
    val_loader,
    vocab,
    optimizer,
    loss_func,
    epochs,
    device,
    name='rnn_decoder_imgnet_encoder'
)

Epoch 1 train:  94%|▉| 564/602 [29:50<02:00,  3.16s/batch, BLEU=0.2

In [48]:
# Performing captioning using the model with ResNet weights in the backbone
transformer_imgnet_model.eval()
transformer_imgnet_model.to(device)
outputs = transformer_imgnet_model(images, captions_sos, 'test')
preds = get_predictions(outputs, shape=(images.shape[0],outputs.shape[1]), device=device)

In [49]:
hypothesis, reference = seq2text(preds[0], captions[0], vocab)

print('Predicted caption: \n', ' '.join(hypothesis).split('eos')[0])
print('*'*30)
print('Actual caption: \n', ' '.join(reference))
Image(url="../data/images_224x329/"+val_data[0,0], width=300, height=300)

Predicted caption: 
 the tank shirt this female wears has no sleeves and its fabric is cotton . the pattern of it is graphic . it has a suspenders neckline . this woman wears a three point pants , with denim fabric and solid color patterns . there is an accessory on her wrist . there is an accessory in his her neck . this woman is wearing a ring on her finger . 
******************************
Actual caption: 
 the tank shirt this female wears has sleeves cut off , its fabric is cotton , and it has solid color patterns . the tank shirt has a suspenders neckline . this female wears a three point shorts , with denim fabric and pure color patterns . this person has neckwear . there is an accessory on her wrist . eos


In [50]:
hypothesis, reference = seq2text(preds[10], captions[10], vocab)
print('Predicted caption: \n', ' '.join(hypothesis).split('eos')[0])
print('*'*30)
print('Actual caption: \n', ' '.join(reference))
Image(url="../data/images_224x329/"+val_data[10,0], width=300, height=300)

Predicted caption: 
 the upper clothing has short sleeves , cotton fabric and solid color patterns . it has a crew neckline . 
******************************
Actual caption: 
 her t shirt has short sleeves , cotton fabric and solid color patterns . it has a v shape neckline . eos


In [51]:
hypothesis, reference = seq2text(preds[24], captions[24], vocab)
print('Predicted caption: \n', ' '.join(hypothesis).split('eos')[0])
print('*'*30)
print('Actual caption: \n', ' '.join(reference))
Image(url="../data/images_224x329/"+val_data[24,0], width=300, height=300)

Predicted caption: 
 the upper clothing has short sleeves , cotton fabric and pure color patterns . it has a round neckline . the lower clothing is of long length . the fabric is cotton and it has pure color patterns . 
******************************
Actual caption: 
 the guy wears a short sleeve shirt with pure color patterns . the shirt is with cotton fabric . it has a v shape neckline . the trousers the guy wears is of long length . the trousers are with denim fabric and pure color patterns . eos


The model with the resnet-weights encoder achieves a BLEU-4 of 0.18 on the validation set. It is evident that the quality of the captions is not as good as the model with the encoder trained to predict attributes as seen from the above examples.