In [2]:
!pip install torch



In [5]:
import torch
import torch.nn as nn

In [6]:
class SimpleMultiTaskCNN(nn.Module):
    """Multi-task CNN: one shared convolutional backbone feeding six classifier heads.

    The backbone consists of four ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)``
    stages (3 -> 32 -> 64 -> 128 -> 256 channels), so every spatial dimension is
    halved four times. The feature map is then flattened and projected to a
    512-dimensional shared representation, from which six independent linear
    heads (``fc1`` .. ``fc6``) produce the logits of one category level each.

    Args:
        num_classes_list: Indexable sequence with at least six integers; entry
            ``i`` is the number of output classes of head ``i + 1``. Entries
            beyond the sixth are ignored.

    Raises:
        IndexError: If ``num_classes_list`` holds fewer than six entries.
    """

    def __init__(self, num_classes_list):
        super().__init__()

        # Fail early, with a readable message, before the large Linear layer is
        # allocated. IndexError is kept because that is what indexing a
        # too-short list raised before.
        if len(num_classes_list) < 6:
            raise IndexError(
                "num_classes_list must provide at least 6 class counts "
                f"(one per classifier head), got {len(num_classes_list)}"
            )

        # feature extraction block: 4 x (conv -> ReLU -> 2x2 max-pool)
        self.conv_block = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # fully connected layers; the Linear input size corresponds to a
        # 256 x 14 x 14 feature map (forward() pools to 14 x 14 to guarantee it)
        self.fc_block = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 14 * 14, 512),
            nn.ReLU()
        )

        # classifier layers for each category level
        self.fc1 = nn.Linear(512, num_classes_list[0])
        self.fc2 = nn.Linear(512, num_classes_list[1])
        self.fc3 = nn.Linear(512, num_classes_list[2])
        self.fc4 = nn.Linear(512, num_classes_list[3])
        self.fc5 = nn.Linear(512, num_classes_list[4])
        self.fc6 = nn.Linear(512, num_classes_list[5])

    def forward(self, x):
        """Compute the logits of all six heads for a batch of images.

        Args:
            x: Batched 3-channel image tensor of shape ``(N, 3, H, W)``.

        Returns:
            Tuple of six tensors ``(output1, ..., output6)``; the ``i``-th has
            shape ``(N, num_classes_list[i - 1])``.
        """
        # Shared features
        x = self.conv_block(x)
        # Bring the feature map to the 14 x 14 grid the first Linear layer was
        # sized for. For 224 x 224 inputs the map is already 14 x 14 and this is
        # an exact no-op; other input sizes used to fail with a shape mismatch.
        x = nn.functional.adaptive_avg_pool2d(x, (14, 14))
        x = self.fc_block(x)

        # separate classifier layers, one per category level
        output1 = self.fc1(x)
        output2 = self.fc2(x)
        output3 = self.fc3(x)
        output4 = self.fc4(x)
        output5 = self.fc5(x)
        output6 = self.fc6(x)

        return output1, output2, output3, output4, output5, output6

In [7]:
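# BatchNormCNN: same architecture as SimpleMultiTaskCNN, with two changes:
#   - batch norm added after every conv layer in the feature extraction block
#   - dropout added at the end of the fc block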

class BatchNormCNN(nn.Module):
    """Multi-task CNN with batch normalisation and dropout, feeding six heads.

    The backbone consists of four
    ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2)`` stages
    (3 -> 32 -> 64 -> 128 -> 256 channels), so every spatial dimension is halved
    four times. The feature map is flattened, projected to a 512-dimensional
    shared representation and passed through ``Dropout(0.5)``. Six independent
    linear heads (``fc1`` .. ``fc6``) then produce one set of logits each.

    Args:
        num_classes_list: Indexable sequence with at least six integers; entry
            ``i`` is the number of output classes of head ``i + 1``. Entries
            beyond the sixth are ignored.

    Raises:
        IndexError: If ``num_classes_list`` holds fewer than six entries.
    """

    def __init__(self, num_classes_list):
        super().__init__()

        # Fail early, with a readable message, before the large Linear layer is
        # allocated. IndexError is kept because that is what indexing a
        # too-short list raised before.
        if len(num_classes_list) < 6:
            raise IndexError(
                "num_classes_list must provide at least 6 class counts "
                f"(one per classifier head), got {len(num_classes_list)}"
            )

        # feature extraction block: 4 x (conv -> batch norm -> ReLU -> 2x2 max-pool)
        self.conv_block = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # fully connected layers; the Linear input size corresponds to a
        # 256 x 14 x 14 feature map (forward() pools to 14 x 14 to guarantee it)
        self.fc_block = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 14 * 14, 512),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        # one classifier head per entry of num_classes_list
        self.fc1 = nn.Linear(512, num_classes_list[0])
        self.fc2 = nn.Linear(512, num_classes_list[1])
        self.fc3 = nn.Linear(512, num_classes_list[2])
        self.fc4 = nn.Linear(512, num_classes_list[3])
        self.fc5 = nn.Linear(512, num_classes_list[4])
        self.fc6 = nn.Linear(512, num_classes_list[5])

    def forward(self, x):
        """Compute the logits of all six heads for a batch of images.

        Args:
            x: Batched 3-channel image tensor of shape ``(N, 3, H, W)``.

        Returns:
            Tuple of six tensors ``(output1, ..., output6)``; the ``i``-th has
            shape ``(N, num_classes_list[i - 1])``.
        """
        # shared features
        x = self.conv_block(x)
        # Bring the feature map to the 14 x 14 grid the first Linear layer was
        # sized for. For 224 x 224 inputs the map is already 14 x 14 and this is
        # an exact no-op; other input sizes used to fail with a shape mismatch.
        x = nn.functional.adaptive_avg_pool2d(x, (14, 14))
        x = self.fc_block(x)

        # separate classifier heads on top of the shared representation
        output1 = self.fc1(x)
        output2 = self.fc2(x)
        output3 = self.fc3(x)
        output4 = self.fc4(x)
        output5 = self.fc5(x)
        output6 = self.fc6(x)

        return output1, output2, output3, output4, output5, output6
