In [1]:
import os
import cv2
import torch, torchvision
import copy
from torchsummary import summary

# Run on the GPU when CUDA is available, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
# Loading Data
import pandas as pd
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from utils.customDataset import FashionDataset

# Annotation files consumed by FashionDataset
shape_file = './labels/shape/shape_anno_all.txt'
fabric_file = './labels/texture/fabric_ann.txt'
pattern_file = './labels/texture/pattern_ann.txt'

dataset = FashionDataset(shape_file, fabric_file, pattern_file, '../data/images_224x329', mode='train')

# Split into train, val sets.
# The training size is derived from the dataset length instead of being
# hard-coded, so the split keeps working if the number of samples changes.
val_size = 2000
if len(dataset) <= val_size:
    raise ValueError(
        f"Dataset has {len(dataset)} samples; need more than {val_size} to hold out a validation set"
    )
# A seeded generator keeps the train/val membership identical across kernel
# restarts, so validation numbers from different runs are comparable.
split_seed = 0
train, val = torch.utils.data.random_split(
    dataset,
    [len(dataset) - val_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)
train_loader = DataLoader(dataset = train, batch_size = 64, shuffle = True, num_workers=8, pin_memory=True)
val_loader = DataLoader(dataset = val, batch_size = 64, shuffle = False)

# Sanity check: pull one batch and show the image / label tensor shapes
image, label = next(iter(train_loader))
print(image.shape)
print(label.shape)

torch.Size([64, 3, 329, 224])
torch.Size([64, 18])


In [3]:
# Swin-T backbone, requested with the 'IMAGENET1K_V1' pretrained weights
backbone = torchvision.models.swin_t(
    weights='IMAGENET1K_V1',
)
# Replace the classification head with a pass-through module, so whatever
# reaches the head is returned unchanged.
backbone.head = torch.nn.Identity()

In [4]:
# Number of classes for each of the 18 attribute categories, grouped by kind.
attribute_classes = (
    [6, 5, 4, 3, 5, 3, 3, 3, 5, 8, 3, 3]  # Shape Attributes (12)
    + [8] * 3  # Fabric Attributes (3)
    + [8] * 3  # Color Attributes (3)
)

# Classifier with one fork per attribute category (18 with the default list)
class AttributeClassifier(torch.nn.Module):
    """Multi-fork classifier: one independent ``Linear`` layer per attribute.

    Args:
        class_counts: Iterable with the number of output classes for each
            fork. When omitted, the notebook-level ``attribute_classes`` list
            is used, which reproduces the original behaviour.
        in_features: Width of the feature vector fed to every fork.
            Defaults to 768, the value the forks were previously hard-wired to
            (presumably the backbone's feature width - confirm if the backbone
            is swapped).
    """

    def __init__(self, class_counts=None, in_features: int = 768) -> None:
        super().__init__()
        if class_counts is None:
            # Fall back to the notebook-level configuration.
            class_counts = attribute_classes
        # Forks are created in list order, so output i belongs to attribute i.
        self.forks = torch.nn.ModuleList(
            [
                torch.nn.Linear(in_features=in_features, out_features=class_count)
                for class_count in class_counts
            ]
        )

    def forward(self, x):
        """Apply every fork to ``x`` and return the per-fork outputs as a list."""
        return [fork(x) for fork in self.forks]

# Model definition
class ClassifierModel(torch.nn.Module):
    """Backbone followed by the multi-fork attribute classifier."""

    def __init__(self) -> None:
        super().__init__()
        # Uses the notebook-level `backbone` object itself (no copy is made).
        self.backbone = backbone
        self.classifier = AttributeClassifier()

    def forward(self, x):
        """Run ``x`` through the backbone, then through the classifier."""
        return self.classifier(self.backbone(x))

model = ClassifierModel().to(device)

# Freeze every weight, then re-enable gradients for the classifier forks only,
# so the backbone stays fixed and just the attribute forks are trained.
model.requires_grad_(False)
model.classifier.requires_grad_(True)
# NOTE(review): freezing gradients does not change train/eval mode; whether
# fit_classifier switches modes is not visible from this notebook - confirm.

# Summarise after freezing so the trainable-parameter figure in the report
# reflects what is actually trained rather than the whole network.
summary(model, (3, 329, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 82, 56]           4,704
           Permute-2           [-1, 82, 56, 96]               0
         LayerNorm-3           [-1, 82, 56, 96]             192
         LayerNorm-4           [-1, 82, 56, 96]             192
ShiftedWindowAttention-5           [-1, 82, 56, 96]               0
   StochasticDepth-6           [-1, 82, 56, 96]               0
         LayerNorm-7           [-1, 82, 56, 96]             192
            Linear-8          [-1, 82, 56, 384]          37,248
              GELU-9          [-1, 82, 56, 384]               0
          Dropout-10          [-1, 82, 56, 384]               0
           Linear-11           [-1, 82, 56, 96]          36,960
          Dropout-12           [-1, 82, 56, 96]               0
  StochasticDepth-13           [-1, 82, 56, 96]               0
SwinTransformerBlock-14           [

In [7]:
# Training the model
from utils.train_funcs import fit_classifier

# Hyperparameters for this run
epochs = 15
learning_rate = 1e-4

# Only the classifier's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(params=model.classifier.parameters(), lr=learning_rate)
loss_func = torch.nn.CrossEntropyLoss()

fit_classifier(
    model,
    attributes=attribute_classes,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_func=loss_func,
    optimizer=optimizer,
    device=device,
    epochs=epochs,
)


Epoch 1 train:   7%|██████▌                                                                                             | 42/634 [00:40<09:33,  1.03batch/s, accuracy=14, loss=11.6]


KeyboardInterrupt: 