Skip to content

Commit

Permalink
Add multi-task classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Apr 12, 2023
1 parent d736a94 commit 9d1d9da
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions atomai/nets/reg_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,46 @@ def forward(self, x: torch.Tensor):
return x


class MultiTaskClassifierNet(nn.Module):
"""
Multi-task classifier with a custom backbone.
Args:
input_channels (int): The number of input channels.
num_classes (List[int]): A list containing the number of output classes for each task.
backbone_type (str, optional): The type of backbone architecture. Choose from "resnet", "vgg", or "mobilenet". Default is "resnet".
"""
def __init__(self, input_channels: int, num_tasks: int, num_classes: list[int], backbone_type: str = "resnet"):
super(MultiTaskClassifierNet, self).__init__()

# Create the backbone with adaptive pooling
self.backbone = CustomBackbone(input_channels, backbone_type)
# Create the output layers for each task
self.output_layers = nn.ModuleList([
nn.Sequential(
nn.Linear(self.backbone.in_features, n_classes),
nn.LogSoftmax(dim=1)
) for n_classes in num_classes
])
# Flatten layer
self.flatten = nn.Flatten()

def forward(self, x: torch.Tensor):
"""
Forward pass of the MultiTaskClassifierNet.
Args:
x (torch.Tensor): Input tensor with shape (batch_size, input_channels, height, width).
Returns:
List[torch.Tensor]: List of output tensors for each task with shape (batch_size, num_classes[i]).
"""
x = self.backbone(x)
x = self.flatten(x)
outputs = [output_layer(x) for output_layer in self.output_layers]
return outputs


def init_reg_model(out_dim, backbone_type, input_channels=1, **kwargs):
"""Initializes a regression model with a specified backbone type"""
net = RegressorNet(input_channels, out_dim, backbone_type)
Expand All @@ -107,3 +147,15 @@ def init_cls_model(num_classes, backbone_type, input_channels=1, **kwargs):
"nb_classes": num_classes
}
return net, meta_state_dict


def init_mtask_cls_model(num_classes, backbone_type, input_channels=1, **kwargs):
"""Initializes a regression model with a specified backbone type"""
net = ClassifierNet(input_channels, num_classes, backbone_type)
meta_state_dict = {
"model_type": "cls",
"backbone": backbone_type,
"in_channels": input_channels,
"nb_classes": num_classes # num_classes is a list with integers
}
return net, meta_state_dict

0 comments on commit 9d1d9da

Please sign in to comment.