# Attribute Prediction Model Training

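This notebook contains the code to initialize and train the attribute prediction models, which serve as the backbone for the captioning model. We train two models for comparison: a Swin Transformer (`swin_t`, our proposed model) and a ResNet-34 baseline. Each model is trained in three phases: with the backbone frozen, with only the last backbone stage unfrozen, and finally with all layers unfrozen.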

In [1]:
import os
import numpy as np
import cv2
import torch, torchvision
import copy
from torchsummary import summary
from utils.train_funcs import fit_classifier


# Reproducibility: seed the global RNGs used for classifier-head initialisation and any other
# torch/numpy randomness (this does not by itself make CUDA kernels deterministic).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Train on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(device)

cuda


In [None]:
# Loading Data
import tensorflow as tf  # NOTE(review): `tf` appears unused in this notebook; importing TensorFlow is slow — consider dropping it
from utils.load_funcs import get_data_loaders

# Build the train/validation DataLoaders, then pull one training batch to sanity-check tensor shapes.
train_loader, val_loader = get_data_loaders()
images, labels, _ = next(iter(train_loader))
for batch_tensor in (images, labels):
    print(batch_tensor.shape)

In [2]:
# Swin-T backbone with ImageNet-1k weights (the enum member behind the 'IMAGENET1K_V1' string alias).
backbone = torchvision.models.swin_t(weights=torchvision.models.Swin_T_Weights.IMAGENET1K_V1)
# Swap the ImageNet classification head for a pass-through so the backbone outputs its feature vector.
backbone.head = torch.nn.Identity()

In [3]:
# Defining the classifier (Swin Transformer Model)
# Number of classes for each attribute category; the classifier builds one fork per entry.
attribute_classes = (
    [6, 5, 4, 3, 5, 3, 3, 3, 5, 8, 3, 3]  # Shape Attributes (12 categories)
    + [8] * 3                             # Fabric Attributes (3 categories)
    + [8] * 3                             # Color Attributes (3 categories)
)

# Classifier with one fork per attribute category (18 forks for the default attribute_classes list)
class AttributeClassifier(torch.nn.Module):
    """Multi-head linear classifier: one independent ``torch.nn.Linear`` fork per attribute category.

    Args:
        in_features: Width of the feature vector fed to every fork.
        class_counts: Number of classes for each attribute category, in fork order.
            When omitted, the notebook-level ``attribute_classes`` list is used, so the
            existing call ``AttributeClassifier(in_features)`` behaves exactly as before.
    """

    def __init__(self, in_features, class_counts=None) -> None:
        super().__init__()
        if class_counts is None:
            class_counts = attribute_classes
        # Plain list attribute: not part of state_dict, so saved checkpoints stay compatible.
        self.class_counts = list(class_counts)
        self.forks = torch.nn.ModuleList(
            torch.nn.Linear(in_features=in_features, out_features=count)
            for count in self.class_counts
        )

    def forward(self, x):
        """Return a list with one logits tensor per fork.

        For a 2-D input of shape (batch, in_features), fork ``i`` yields (batch, class_counts[i]).
        """
        return [fork(x) for fork in self.forks]

# Model definition: a backbone feature extractor followed by the multi-fork attribute head
class ClassifierModel(torch.nn.Module):
    """Chain a backbone with an AttributeClassifier.

    Args:
        backbone: Feature extractor whose output goes straight into the classifier
            (in this notebook its original classification head is replaced by Identity).
        backbone_out_features: Width of the backbone output, used as each fork's in_features.
    """

    def __init__(self, backbone, backbone_out_features) -> None:
        super().__init__()
        self.backbone = backbone
        self.classifier = AttributeClassifier(backbone_out_features)

    def forward(self, x):
        """Return the list of per-attribute logits produced by the classifier forks."""
        features = self.backbone(x)
        return self.classifier(features)

# Feature width of the Swin backbone (768 for swin_t), read from its final LayerNorm so the
# classifier input size follows the backbone variant instead of being hard-coded.
swin_out_features = backbone.norm.normalized_shape[0]
transformer_attribute_model = ClassifierModel(backbone, swin_out_features).to(device)

# Freeze all params in backbone; only the classifier forks train in the first phase.
transformer_attribute_model.backbone.requires_grad_(False)

summary(transformer_attribute_model, (3, 329, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 82, 56]           4,704
           Permute-2           [-1, 82, 56, 96]               0
         LayerNorm-3           [-1, 82, 56, 96]             192
         LayerNorm-4           [-1, 82, 56, 96]             192
ShiftedWindowAttention-5           [-1, 82, 56, 96]               0
   StochasticDepth-6           [-1, 82, 56, 96]               0
         LayerNorm-7           [-1, 82, 56, 96]             192
            Linear-8          [-1, 82, 56, 384]          37,248
              GELU-9          [-1, 82, 56, 384]               0
          Dropout-10          [-1, 82, 56, 384]               0
           Linear-11           [-1, 82, 56, 96]          36,960
          Dropout-12           [-1, 82, 56, 96]               0
  StochasticDepth-13           [-1, 82, 56, 96]               0
SwinTransformerBlock-14           [

In [None]:
# Phase 1 (transformer): only the classifier forks are optimised; the backbone stays frozen.
epochs = 5
learning_rate = 1e-3
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    params=transformer_attribute_model.classifier.parameters(),
    lr=learning_rate,
)

(
    train_loss_history,
    train_acc_history,
    val_loss_history,
    val_acc_history,
) = fit_classifier(
    transformer_attribute_model,
    name='transformer_frozen',
    device=device,
    epochs=epochs,
    loss_func=loss_func,
    optimizer=optimizer,
    attributes=attribute_classes,
    train_loader=train_loader,
    val_loader=val_loader,
)


In [None]:
# Defining the baseline model (ResNet-34 CNN)
resnet_backbone = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)

# Read the feature width (512 for ResNet-34) from the original head before removing it,
# so the classifier input size follows the backbone variant instead of being hard-coded.
resnet_out_features = resnet_backbone.fc.in_features
resnet_backbone.fc = torch.nn.Identity()

# Initialize Baseline Classifier and freeze the backbone weights.
# NOTE(review): requires_grad=False does not stop BatchNorm running statistics from updating
# if fit_classifier puts the model in train mode — confirm whether that is intended for the frozen phase.
resnet_attribute_model = ClassifierModel(resnet_backbone, resnet_out_features)
resnet_attribute_model.backbone.requires_grad_(False)
resnet_attribute_model.to(device)

summary(resnet_attribute_model, (3, 329, 224))


In [None]:
# Phase 1 (baseline): only the classifier forks are optimised; the ResNet backbone stays frozen.
epochs = 5
learning_rate = 1e-3
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    params=resnet_attribute_model.classifier.parameters(),
    lr=learning_rate,
)

(
    train_loss_history_baseline,
    train_acc_history_baseline,
    val_loss_history_baseline,
    val_acc_history_baseline,
) = fit_classifier(
    resnet_attribute_model,
    name='resnet_frozen',
    device=device,
    epochs=epochs,
    loss_func=loss_func,
    optimizer=optimizer,
    attributes=attribute_classes,
    train_loader=train_loader,
    val_loader=val_loader,
)

In [None]:
# Load the previous best saved models (presumably written by fit_classifier for
# name='transformer_frozen' / 'resnet_frozen' — confirm the path convention in utils.train_funcs).
# map_location lets checkpoints saved from a CUDA run be restored on a CPU-only kernel.
transformer_attribute_model.load_state_dict(
    torch.load('./models/transformer_frozen_attribute_model.pth', map_location=device)['model_state_dict']
)

resnet_attribute_model.load_state_dict(
    torch.load('./models/resnet_frozen_attribute_model.pth', map_location=device)['model_state_dict']
)

In [None]:
# Fine tune transformer model

# Freeze all parameters, then selectively re-enable gradients below.
transformer_attribute_model.requires_grad_(False)

# Unfreeze the last two entries of backbone.features (for torchvision's swin_t: the final
# patch-merging layer and the last Swin stage), the final LayerNorm, and the classifier forks.
# backbone.avgpool is an adaptive average pool with no parameters, so it needs no unfreezing.
transformer_attribute_model.backbone.features[6:].requires_grad_(True)
transformer_attribute_model.backbone.norm.requires_grad_(True)
transformer_attribute_model.classifier.requires_grad_(True)

summary(transformer_attribute_model, (3, 329, 224))

In [None]:
# Phase 2 (transformer): optimise every parameter that is currently unfrozen, at a lower LR.
epochs = 5
learning_rate = 1e-4
loss_func = torch.nn.CrossEntropyLoss()
custom_params_list = [p for p in transformer_attribute_model.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(custom_params_list, lr=learning_rate)

# initial_epoch=5 presumably continues epoch numbering after the 5 frozen-phase epochs.
(
    train_loss_history,
    train_acc_history,
    val_loss_history,
    val_acc_history,
) = fit_classifier(
    transformer_attribute_model,
    name='transformer_semi_frozen',
    device=device,
    initial_epoch=5,
    epochs=epochs,
    loss_func=loss_func,
    optimizer=optimizer,
    attributes=attribute_classes,
    train_loader=train_loader,
    val_loader=val_loader,
)


In [None]:
# Fine tune resnet model by unfreezing layer4 and classifier

# Freeze all parameters, then re-enable gradients for the last residual stage and the heads.
resnet_attribute_model.requires_grad_(False)

for name, param in resnet_attribute_model.named_parameters():
    if 'classifier' in name or 'layer4' in name:
        param.requires_grad = True

# Same input size as every other summary call in this notebook.
summary(resnet_attribute_model, (3, 329, 224))

In [None]:
# Phase 2 (baseline): optimise layer4 plus the classifier forks (the unfrozen params), at a lower LR.
epochs = 5
learning_rate = 1e-4
loss_func = torch.nn.CrossEntropyLoss()
custom_params_list = [p for p in resnet_attribute_model.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(custom_params_list, lr=learning_rate)

(
    train_loss_history_baseline,
    train_acc_history_baseline,
    val_loss_history_baseline,
    val_acc_history_baseline,
) = fit_classifier(
    resnet_attribute_model,
    name='resnet_semi_frozen',
    device=device,
    initial_epoch=5,
    epochs=epochs,
    loss_func=loss_func,
    optimizer=optimizer,
    attributes=attribute_classes,
    train_loader=train_loader,
    val_loader=val_loader,
)

In [None]:
# Unfreeze all layers of transformer attribute model.
# NOTE(review): training resumes from whatever weights the last semi-frozen epoch left in memory;
# unlike after the frozen phase, no saved checkpoint is reloaded first — confirm that is intended.
transformer_attribute_model.requires_grad_(True)

In [None]:
# Phase 3 (transformer): optimise every parameter of the model end to end.
epochs = 5
learning_rate = 1e-4
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    params=transformer_attribute_model.parameters(),
    lr=learning_rate,
)

(
    train_loss_history,
    train_acc_history,
    val_loss_history,
    val_acc_history,
) = fit_classifier(
    transformer_attribute_model,
    name='transformer_unfreeze',
    device=device,
    initial_epoch=10,
    epochs=epochs,
    loss_func=loss_func,
    optimizer=optimizer,
    attributes=attribute_classes,
    train_loader=train_loader,
    val_loader=val_loader,
)

In [None]:
# Show the module tree; the parameter names shown here are what the 'layer4'/'classifier' filter matches.
print(repr(resnet_attribute_model))

In [None]:
# Unfreeze all layers of resnet attribute model.
# NOTE(review): training resumes from whatever weights the last semi-frozen epoch left in memory;
# no saved checkpoint is reloaded first — confirm that is intended.
resnet_attribute_model.requires_grad_(True)

In [None]:
# Phase 3 (baseline): optimise every parameter of the ResNet model end to end.
epochs = 5
learning_rate = 1e-4
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    params=resnet_attribute_model.parameters(),
    lr=learning_rate,
)

(
    train_loss_history_baseline,
    train_acc_history_baseline,
    val_loss_history_baseline,
    val_acc_history_baseline,
) = fit_classifier(
    resnet_attribute_model,
    name='resnet_unfreeze',
    device=device,
    initial_epoch=10,
    epochs=epochs,
    loss_func=loss_func,
    optimizer=optimizer,
    attributes=attribute_classes,
    train_loader=train_loader,
    val_loader=val_loader,
)