# Attribute Prediction Model Training

This notebook contains the code to initialize and train the attribute prediction models, which are used as the backbone for captioning. We train two models for comparison: a Swin Transformer (`swin_t`, our proposed model) and a ResNet-34 baseline.

In [1]:
import os
import numpy as np
import cv2
import torch, torchvision
import copy
from torchsummary import summary
from utils.train_funcs import fit_classifier

# Prefer the GPU whenever PyTorch reports one; otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(device)

cuda


In [2]:
#labelled_data = np.load('./labels/full_data.npy', allow_pickle=True)
# Read the train / validation split arrays from disk.
# NOTE(review): allow_pickle=True unpickles Python objects and can execute
# arbitrary code, so only point these loads at files you trust.
train_data, val_data = (
    np.load(split_path, allow_pickle=True)
    for split_path in ('./train_data.npy', './validation_data.npy')
)

# Show the shape of each split, one per line.
print(train_data.shape, val_data.shape, sep='\n')

(38519, 19)
(4024, 19)


In [3]:
# Loading Data
import pandas as pd
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from utils.customDataset import FashionDataset

# Both splits read their images from the same directory; only `mode` differs.
# NOTE(review): what `mode` changes inside FashionDataset is not visible in
# this notebook - confirm against utils/customDataset.py.
image_dir = '../data/images_224x329'
train_dataset = FashionDataset(train_data, image_dir, mode='train')
val_dataset = FashionDataset(val_data, image_dir, mode='val')

# Batched loaders: the training loader shuffles, the validation loader keeps
# dataset order.
# NOTE(review): the validation loader uses the default worker / pin_memory
# settings, unlike the training loader - confirm this is intentional.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)

# Pull a single training batch to sanity-check the tensor shapes.
images, labels = next(iter(train_loader))
print(images.shape, labels.shape, sep='\n')

torch.Size([64, 3, 329, 224])
torch.Size([64, 18])


In [4]:
# Swin-T backbone initialised from the IMAGENET1K_V1 pretrained weights.
# The weights are given through the enum, in the same way as the ResNet-34
# baseline defined later in this notebook.
backbone = torchvision.models.swin_t(
    weights=torchvision.models.Swin_T_Weights.IMAGENET1K_V1
)
# Replace the classification head with a pass-through so the backbone output
# is handed unchanged to our own attribute classifier.
backbone.head = torch.nn.Identity()

In [5]:
# Defining the classifier (Swin Transformer model)
# Number of classes for each of the 18 attribute categories.
# NOTE(review): presumably in the same order as the 18 label columns produced
# by the data loader - confirm against the dataset definition.
attribute_classes = [
    *(6, 5, 4, 3, 5, 3, 3, 3, 5, 8, 3, 3),  # 12 shape attributes
    *(8, 8, 8),                             # 3 fabric attributes
    *(8, 8, 8),                             # 3 color attributes
]

# Classifier with one fork per attribute category (18 with the default
# `attribute_classes` list).
class AttributeClassifier(torch.nn.Module):
    """Multi-head linear classifier: one independent ``Linear`` fork per attribute.

    Args:
        in_features: Size of the feature vector each fork receives.
        class_counts: Optional iterable giving the number of output classes of
            each fork. When omitted, the module-level ``attribute_classes``
            list is used, which keeps existing ``AttributeClassifier(n)``
            calls working unchanged.
    """

    def __init__(self, in_features, class_counts=None) -> None:
        super().__init__()
        if class_counts is None:
            # Fall back to the notebook-wide attribute definition.
            class_counts = attribute_classes
        # Forks are created in order and stored under `forks`, so state-dict
        # keys stay `forks.<i>.weight` / `forks.<i>.bias`.
        self.forks = torch.nn.ModuleList([
            torch.nn.Linear(in_features=in_features, out_features=class_count)
            for class_count in class_counts
        ])

    def forward(self, x):
        """Apply every fork to ``x`` and return the per-attribute outputs.

        Returns:
            A list with one tensor per fork, in fork order; each is the output
            of that fork's ``Linear`` layer applied to ``x``.
        """
        return [fork(x) for fork in self.forks]

# Model definition
class ClassifierModel(torch.nn.Module):
    """Feature backbone followed by the multi-fork attribute classifier.

    Args:
        backbone: Module applied first to the input.
        backbone_out_features: Feature size passed to ``AttributeClassifier``
            as its ``in_features``.
    """

    def __init__(self, backbone, backbone_out_features) -> None:
        super().__init__()
        self.backbone = backbone
        self.classifier = AttributeClassifier(backbone_out_features)

    def forward(self, x):
        """Run ``x`` through the backbone, then through the classifier."""
        features = self.backbone(x)
        return self.classifier(features)

# Wrap the Swin backbone (768 output features) with the attribute classifier
# and move the whole model to the selected device.
transformer_attribute_model = ClassifierModel(backbone, 768).to(device)

# Freeze every backbone parameter; the classifier forks stay trainable.
transformer_attribute_model.backbone.requires_grad_(False)

summary(transformer_attribute_model, (3, 329, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 82, 56]           4,704
           Permute-2           [-1, 82, 56, 96]               0
         LayerNorm-3           [-1, 82, 56, 96]             192
         LayerNorm-4           [-1, 82, 56, 96]             192
ShiftedWindowAttention-5           [-1, 82, 56, 96]               0
   StochasticDepth-6           [-1, 82, 56, 96]               0
         LayerNorm-7           [-1, 82, 56, 96]             192
            Linear-8          [-1, 82, 56, 384]          37,248
              GELU-9          [-1, 82, 56, 384]               0
          Dropout-10          [-1, 82, 56, 384]               0
           Linear-11           [-1, 82, 56, 96]          36,960
          Dropout-12           [-1, 82, 56, 96]               0
  StochasticDepth-13           [-1, 82, 56, 96]               0
SwinTransformerBlock-14           [

In [None]:
# Stage 1 (transformer): only the classifier forks are handed to the
# optimizer, so the backbone is left untouched.

epochs = 5
learning_rate = 1e-3
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    transformer_attribute_model.classifier.parameters(),
    lr=learning_rate,
)

# NOTE(review): `name` presumably determines the checkpoint file name (a later
# cell loads './models/transformer_frozen_attribute_model.pth') - confirm in
# utils/train_funcs.py.
(
    train_loss_history,
    train_acc_history,
    val_loss_history,
    val_acc_history,
) = fit_classifier(
    transformer_attribute_model,
    train_loader=train_loader,
    val_loader=val_loader,
    attributes=attribute_classes,
    optimizer=optimizer,
    loss_func=loss_func,
    epochs=epochs,
    device=device,
    name='transformer_frozen',
)


Epoch 1 train: 100%|█| 602/602 [09:17<00:00,  1.08batch/s, accuracy=0.8, loss=9.


Epoch 1 train loss: 10.17539510071075 train accuracy: 0.7912400960922241


Epoch 1 val: 100%|█| 63/63 [01:06<00:00,  1.05s/batch, accuracy=0.655, loss=17.5


Epoch 1 val loss: 10.5379984497313 val accuracy: 0.7836039066314697
--------------------


Epoch 2 train:  27%|▎| 165/602 [02:34<06:45,  1.08batch/s, accuracy=0.806, loss=

In [6]:
# Baseline model: ResNet-34 CNN initialised from the IMAGENET1K_V1 weights.
resnet_backbone = torchvision.models.resnet34(
    weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1
)

# Replace the final fully connected layer with a pass-through so the backbone
# output goes straight to our own classifier.
resnet_backbone.fc = torch.nn.Identity()

# Wrap the backbone (512 output features) with the attribute classifier.
resnet_attribute_model = ClassifierModel(resnet_backbone, 512)

# Freeze every backbone parameter, then move the model to the device.
resnet_attribute_model.backbone.requires_grad_(False)
resnet_attribute_model.to(device)

summary(resnet_attribute_model, (3, 329, 224))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 165, 112]           9,408
       BatchNorm2d-2         [-1, 64, 165, 112]             128
              ReLU-3         [-1, 64, 165, 112]               0
         MaxPool2d-4           [-1, 64, 83, 56]               0
            Conv2d-5           [-1, 64, 83, 56]          36,864
       BatchNorm2d-6           [-1, 64, 83, 56]             128
              ReLU-7           [-1, 64, 83, 56]               0
            Conv2d-8           [-1, 64, 83, 56]          36,864
       BatchNorm2d-9           [-1, 64, 83, 56]             128
             ReLU-10           [-1, 64, 83, 56]               0
       BasicBlock-11           [-1, 64, 83, 56]               0
           Conv2d-12           [-1, 64, 83, 56]          36,864
      BatchNorm2d-13           [-1, 64, 83, 56]             128
             ReLU-14           [-1, 64,

In [None]:
# Stage 1 (baseline): only the classifier forks are handed to the optimizer,
# so the ResNet backbone is left untouched.

epochs = 5
learning_rate = 1e-3
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    resnet_attribute_model.classifier.parameters(),
    lr=learning_rate,
)

# NOTE(review): `name` presumably determines the checkpoint file name (a later
# cell loads './models/resnet_frozen_attribute_model.pth') - confirm in
# utils/train_funcs.py.
(
    train_loss_history_baseline,
    train_acc_history_baseline,
    val_loss_history_baseline,
    val_acc_history_baseline,
) = fit_classifier(
    resnet_attribute_model,
    train_loader=train_loader,
    val_loader=val_loader,
    attributes=attribute_classes,
    optimizer=optimizer,
    loss_func=loss_func,
    epochs=epochs,
    device=device,
    name='resnet_frozen',
)

In [11]:
# Load the previously saved frozen-stage checkpoints into both models.
# map_location=device remaps the stored tensors onto the device selected at
# the top of the notebook, so a checkpoint written on a GPU machine also loads
# when the notebook falls back to the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust.
transformer_attribute_model.load_state_dict(
    torch.load(
        './models/transformer_frozen_attribute_model.pth',
        map_location=device,
    )['model_state_dict']
)

resnet_attribute_model.load_state_dict(
    torch.load(
        './models/resnet_frozen_attribute_model.pth',
        map_location=device,
    )['model_state_dict']
)

<All keys matched successfully>

In [12]:
# Fine-tune the transformer model: partial unfreeze.

# Start from a fully frozen model...
transformer_attribute_model.requires_grad_(False)

# ...then re-enable gradients for the backbone feature modules from index 6
# onward, the backbone's norm and avgpool modules, and the classifier forks.
swin_backbone = transformer_attribute_model.backbone
trainable_parts = (
    swin_backbone.features[6:],
    swin_backbone.norm,
    swin_backbone.avgpool,
    transformer_attribute_model.classifier,
)
for part in trainable_parts:
    part.requires_grad_(True)

summary(transformer_attribute_model, (3, 329, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 82, 56]           4,704
           Permute-2           [-1, 82, 56, 96]               0
         LayerNorm-3           [-1, 82, 56, 96]             192
         LayerNorm-4           [-1, 82, 56, 96]             192
ShiftedWindowAttention-5           [-1, 82, 56, 96]               0
   StochasticDepth-6           [-1, 82, 56, 96]               0
         LayerNorm-7           [-1, 82, 56, 96]             192
            Linear-8          [-1, 82, 56, 384]          37,248
              GELU-9          [-1, 82, 56, 384]               0
          Dropout-10          [-1, 82, 56, 384]               0
           Linear-11           [-1, 82, 56, 96]          36,960
          Dropout-12           [-1, 82, 56, 96]               0
  StochasticDepth-13           [-1, 82, 56, 96]               0
SwinTransformerBlock-14           [

In [None]:
# Stage 2 (transformer): optimise only the parameters that currently require
# gradients, at a lower learning rate than the frozen stage.
epochs = 5
learning_rate = 1e-4
loss_func = torch.nn.CrossEntropyLoss()

custom_params_list = [
    weight
    for weight in transformer_attribute_model.parameters()
    if weight.requires_grad
]
optimizer = torch.optim.AdamW(custom_params_list, lr=learning_rate)

# initial_epoch=5 continues the epoch numbering after the five frozen-stage
# epochs (the training log for this cell starts at "Epoch 6").
(
    train_loss_history,
    train_acc_history,
    val_loss_history,
    val_acc_history,
) = fit_classifier(
    transformer_attribute_model,
    train_loader=train_loader,
    val_loader=val_loader,
    attributes=attribute_classes,
    optimizer=optimizer,
    loss_func=loss_func,
    epochs=epochs,
    initial_epoch=5,
    device=device,
    name='transformer_semi_frozen',
)


Epoch 6 train: 100%|█| 602/602 [11:05<00:00,  1.10s/batch, accuracy=0.849, loss=


Epoch 6 train loss: 6.91100790905505 train accuracy: 0.8566133379936218


Epoch 6 val: 100%|█| 63/63 [01:05<00:00,  1.05s/batch, accuracy=0.736, loss=12.4


Epoch 6 val loss: 7.409150043965334 val accuracy: 0.8483681082725525
--------------------


Epoch 7 train: 100%|█| 602/602 [11:05<00:00,  1.10s/batch, accuracy=0.879, loss=


Epoch 7 train loss: 5.993897678330976 train accuracy: 0.8767880797386169


Epoch 7 val:  35%|▎| 22/63 [00:23<00:43,  1.05s/batch, accuracy=0.837, loss=8.28

In [None]:
# Fine-tune the ResNet model: train only layer4 and the classifier forks.

# Start from a fully frozen model...
for param in resnet_attribute_model.parameters():
    param.requires_grad = False

# ...then re-enable gradients for every parameter whose qualified name
# mentions the classifier or layer4.
for name, param in resnet_attribute_model.named_parameters():
    if "classifier" in name or "layer4" in name:
        param.requires_grad = True

# The summary input size is (3, 329, 224), matching the batch shape printed by
# the data-loading cell and the other summary calls in this notebook.
summary(resnet_attribute_model, (3, 329, 224))

In [None]:
# Stage 2 (baseline): optimise only the parameters that currently require
# gradients, at a lower learning rate than the frozen stage.
epochs = 5
learning_rate = 1e-4
loss_func = torch.nn.CrossEntropyLoss()

custom_params_list = [
    weight
    for weight in resnet_attribute_model.parameters()
    if weight.requires_grad
]
optimizer = torch.optim.AdamW(custom_params_list, lr=learning_rate)

# NOTE(review): initial_epoch=5 presumably continues the epoch numbering after
# the five frozen-stage epochs - confirm in utils/train_funcs.py.
(
    train_loss_history_baseline,
    train_acc_history_baseline,
    val_loss_history_baseline,
    val_acc_history_baseline,
) = fit_classifier(
    resnet_attribute_model,
    train_loader=train_loader,
    val_loader=val_loader,
    attributes=attribute_classes,
    optimizer=optimizer,
    loss_func=loss_func,
    epochs=epochs,
    initial_epoch=5,
    device=device,
    name='resnet_semi_frozen',
)

In [None]:
# Make every parameter of the transformer attribute model trainable again
# (backbone and classifier alike).

for weight in transformer_attribute_model.parameters():
    weight.requires_grad_(True)

In [None]:
# Stage 3 (transformer): every model parameter is handed to the optimizer.

epochs = 5
learning_rate = 1e-4
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    transformer_attribute_model.parameters(),
    lr=learning_rate,
)

# NOTE(review): initial_epoch=10 presumably continues the epoch numbering
# after the two earlier five-epoch stages - confirm in utils/train_funcs.py.
(
    train_loss_history,
    train_acc_history,
    val_loss_history,
    val_acc_history,
) = fit_classifier(
    transformer_attribute_model,
    train_loader=train_loader,
    val_loader=val_loader,
    attributes=attribute_classes,
    optimizer=optimizer,
    loss_func=loss_func,
    epochs=epochs,
    initial_epoch=10,
    device=device,
    name='transformer_unfreeze',
)

In [None]:
# Make every parameter of the ResNet attribute model trainable again
# (backbone and classifier alike).

for weight in resnet_attribute_model.parameters():
    weight.requires_grad_(True)

# Optional inspection of the model (left disabled):
#summary(resnet_attribute_model, (3, 329, 224))

In [None]:
# Stage 3 (baseline): every model parameter is handed to the optimizer.

epochs = 5
learning_rate = 1e-4
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    resnet_attribute_model.parameters(),
    lr=learning_rate,
)

# NOTE(review): initial_epoch=10 presumably continues the epoch numbering
# after the two earlier five-epoch stages - confirm in utils/train_funcs.py.
(
    train_loss_history_baseline,
    train_acc_history_baseline,
    val_loss_history_baseline,
    val_acc_history_baseline,
) = fit_classifier(
    resnet_attribute_model,
    train_loader=train_loader,
    val_loader=val_loader,
    attributes=attribute_classes,
    optimizer=optimizer,
    loss_func=loss_func,
    epochs=epochs,
    initial_epoch=10,
    device=device,
    name='resnet_unfreeze',
)