In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from resnet_utils import BasicBlock, Bottleneck, _resnet

In [None]:
# NOTE(review): the block below is disabled code kept inside a bare string
# literal, so running this cell defines none of these functions. Because the
# string is the last expression in its cell, Jupyter echoes the whole string
# as the cell's output.
# NOTE(review): if this is ever re-enabled, resnet_feature_model() reads
# `model.classifier`, but the ResNets used in select_model() expose their
# head as `model.fc`. The `[:-3]` slice also looks copied from VGG-style code.
# Confirm against resnet_utils._resnet before relying on it.
"""def resnet18(pretrained=True, progress=True, **kwargs):
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=True, progress=True, **kwargs):
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=True, progress=True, **kwargs):
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=True, progress=True, **kwargs):
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=True, progress=True, **kwargs):
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnet_feature_model():
    model = resnet18()
    model = nn.Sequential(*list(model.classifier.children())[:-3])
    return model"""

In [None]:
# VGG-style convolutional classifier ("VGGish"); also the fallback model in select_model.
class VGGish(nn.Module):
    """Small VGG-style CNN: LayerNorm -> five conv blocks -> two-layer MLP head.

    Every convolution uses stride 1 with size-preserving padding, and each of
    the four 2x2 max-pools halves the spatial size (floor mode). The flattened
    feature entering the head is therefore ``64 * (H // 16) * (W // 16)``.

    Args:
        num_classes: Number of output scores. Defaults to 7.
        input_shape: ``(C, H, W)`` of one input image. LayerNorm normalises
            over exactly this shape, so inputs must match it. Defaults to
            ``(3, 242, 224)``, which yields a ``64 * 15 * 14`` flattened feature.
        dropout: Dropout probability applied before the first linear layer.
            Defaults to 0.65.

    Raises:
        ValueError: if H or W is smaller than 16, because the pooled feature
            map would then be empty.
    """

    def __init__(self, num_classes=7, input_shape=(3, 242, 224), dropout=0.65):
        super().__init__()
        channels, height, width = input_shape
        # Four floor-mode 2x2 max-pools shrink each spatial dim by a factor of 16.
        pooled_h, pooled_w = height // 16, width // 16
        if pooled_h == 0 or pooled_w == 0:
            raise ValueError(
                f"input_shape {tuple(input_shape)} is too small: H and W must be >= 16"
            )

        self.layer_norm = nn.LayerNorm([channels, height, width])
        # conv1 deliberately keeps the Tanh activation of the original design.
        self.conv1 = self._conv_block(channels, 16, 7, nn.Tanh)
        self.conv2 = self._conv_block(16, 32, 5, nn.ReLU)
        self.conv3 = self._conv_block(32, 32, 3, nn.ReLU)
        self.conv4 = self._conv_block(32, 64, 3, nn.ReLU)
        # The last block has no pooling, so only four pools affect the size.
        self.conv5 = self._conv_block(64, 64, 3, nn.ReLU, pool=False)

        self.classifier = nn.Sequential(nn.Flatten(),
                                        nn.Dropout(p=dropout),
                                        nn.Linear(64 * pooled_h * pooled_w, 64),
                                        nn.ReLU(),
                                        nn.Linear(64, num_classes))

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, activation, pool=True):
        """Conv2d (stride 1, size-preserving padding) -> BatchNorm2d -> activation [-> MaxPool2d(2)]."""
        layers = [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                            kernel_size=(kernel_size, kernel_size), stride=(1, 1),
                            padding=kernel_size // 2),
                  nn.BatchNorm2d(out_channels),
                  activation()]
        if pool:
            layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)

    def forward(self, input_data):
        """Map ``input_data`` of shape ``(N, *input_shape)`` to scores of shape ``(N, num_classes)``.

        No softmax is applied; the returned values are unnormalised scores.
        Apply ``F.softmax(..., dim=1)`` externally if probabilities are needed.
        """
        x = self.layer_norm(input_data)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return self.classifier(x)

In [None]:
def select_model(model_name, pretrained, verbose=True):
    """Build the classifier named ``model_name`` with a 7-way output head.

    Args:
        model_name: ``"VGGish"``, ``"resnet18"``, ``"resnet50"`` or
            ``"resnet101"``. Any other value prints a notice and falls back
            to ``VGGish``.
        pretrained: For the ResNet variants, forwarded as ``pretrained=`` to
            ``torch.hub.load`` for the ``pytorch/vision:v0.10.0`` hub entry.
            Ignored for ``VGGish``.
        verbose: When True (default), print the full model after building it.

    Returns:
        The constructed ``nn.Module``. For ResNets, ``model.fc`` is replaced by
        a fresh ``nn.Linear`` producing 7 outputs.
    """
    # Must match the output size of VGGish's classifier head.
    num_classes = 7
    hub_models = ("resnet18", "resnet50", "resnet101")

    if model_name == "VGGish":
        print("VGG !!!")
        model = VGGish()
    elif model_name in hub_models:
        print(f"{model_name.upper()} !!!")
        model = torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=pretrained)
        # Size the new head from the existing one: the pooled feature width
        # differs by depth (512 for resnet18, 2048 for resnet50/101), so a
        # hard-coded 2048 breaks resnet18's forward pass.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        print("Not a valid model")
        print("Default VGG model used")
        model = VGGish()

    if verbose:
        print(model)

    return model