In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16

In [None]:
def create_vgg_base(pretrained=True):
    """Build the SSD VGG16 backbone as an ``nn.ModuleList`` of 35 layers.

    Layers 0-29 are torchvision's VGG16 ``features`` up to and including the
    ReLU after conv5_3. pool3 (index 16) is swapped for a ``ceil_mode=True``
    pool so a 300x300 input yields a 38x38 conv4_3 map (75 -> 38 instead of 37).
    Appended after them, replacing the fully connected FC6/FC7, are:

        30: pool5  (3x3, stride 1, padding 1 -- keeps the spatial size)
        31: conv6  (3x3 atrous conv, dilation 6, 512 -> 1024) + 32: ReLU
        33: conv7  (1x1 conv, 1024 -> 1024)                  + 34: ReLU

    Detection sources used elsewhere: index 21 is conv4_3 (512 channels) and
    index -2 (33) is conv7 (1024 channels).

    Args:
        pretrained: forwarded to ``torchvision.models.vgg16``; controls whether
            conv1_1..conv5_3 start from pretrained weights. conv6/conv7 are
            always newly created layers.

    Returns:
        nn.ModuleList holding the backbone layers in execution order.
    """
    layers = list(vgg16(pretrained=pretrained).features)[:30]  # through conv5_3 + ReLU

    # Keep pool3 a pooling layer, but round up so odd sizes are not truncated.
    layers[16] = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

    layers += [
        nn.MaxPool2d(kernel_size=3, stride=1, padding=1),           # pool5
        nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),  # conv6 (atrous FC6)
        nn.ReLU(inplace=True),
        nn.Conv2d(1024, 1024, kernel_size=1),                        # conv7 (FC7 as conv)
        nn.ReLU(inplace=True),
    ]
    return nn.ModuleList(layers)

In [None]:
def add_extras(cfg=None, in_channels=1024):
    """Build the extra SSD feature layers that follow the VGG base.

    ``cfg`` is read left to right:

    * an int ``v`` adds an unpadded, stride-1 conv with ``v`` output channels;
    * ``'S'`` adds a stride-2, padding-1 conv whose output channel count is the
      *next* entry; that entry is consumed and does not become a layer itself.

    Kernel sizes alternate 1x1, 3x3, 1x1, ... over the layers actually built
    (not over cfg positions).

    Args:
        cfg: layer configuration. Defaults to
            ``[256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256]``.
        in_channels: channels of the tensor fed to the first extra layer
            (default 1024, the conv7 width of the VGG base).

    Returns:
        nn.ModuleList of ``nn.Conv2d``. For the default cfg this is 8 layers,
        whose odd-indexed members (1, 3, 5, 7) output 512, 256, 256, 256
        channels.

    Raises:
        ValueError: if ``'S'`` is the last cfg entry (it needs an output width).
    """
    if cfg is None:
        cfg = [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256]

    layers = []
    use_3x3 = False  # toggles once per *built* layer
    k = 0
    while k < len(cfg):
        if cfg[k] == 'S':
            if k + 1 >= len(cfg):
                raise ValueError("'S' must be followed by an output channel count")
            out_channels = cfg[k + 1]
            layers.append(nn.Conv2d(in_channels, out_channels,
                                    kernel_size=3 if use_3x3 else 1,
                                    stride=2, padding=1))
            k += 2  # the width after 'S' has been used as this layer's output
        else:
            out_channels = cfg[k]
            layers.append(nn.Conv2d(in_channels, out_channels,
                                    kernel_size=3 if use_3x3 else 1))
            k += 1
        in_channels = out_channels
        use_3x3 = not use_3x3

    return nn.ModuleList(layers)

In [None]:
def multibox(vgg_base, extra_layers, num_classes, num_boxes=4):
    """Build the localisation and confidence heads for every detection source.

    Sources, in order, are the VGG layers at indices 21 (conv4_3) and -2
    (conv7 / FC7), followed by every odd-indexed extra layer (1, 3, 5, ...).
    Each source gets one 3x3, padding-1 conv per head. For ``b`` boxes per
    location, the loc head emits ``4 * b`` channels (box offsets) and the conf
    head emits ``num_classes * b`` channels.

    Args:
        vgg_base: indexable backbone; entries 21 and -2 must expose
            ``out_channels``.
        extra_layers: sequence of extra convs; the odd-indexed ones are sources.
        num_classes: number of classes scored per box.
        num_boxes: boxes per feature-map location. Either an int applied to
            every source (default 4, the previously hard-coded value) or a
            sequence with one entry per source.

    Returns:
        ``(loc_layers, conf_layers)``: two ``nn.ModuleList``s in source order.

    Raises:
        ValueError: if ``num_boxes`` is a sequence whose length differs from
            the number of sources.
    """
    vgg_source = [21, -2]  # conv4_3 and conv7 (FC7) layer indices in the base
    source_channels = [vgg_base[idx].out_channels for idx in vgg_source]
    source_channels += [layer.out_channels for layer in extra_layers[1::2]]

    if isinstance(num_boxes, int):
        boxes_per_source = [num_boxes] * len(source_channels)
    else:
        boxes_per_source = list(num_boxes)
        if len(boxes_per_source) != len(source_channels):
            raise ValueError(
                f"num_boxes has {len(boxes_per_source)} entries but there are "
                f"{len(source_channels)} detection sources"
            )

    loc_layers = []
    conf_layers = []
    # Create loc/conf per source in the same interleaved order as before.
    for channels, boxes in zip(source_channels, boxes_per_source):
        loc_layers.append(nn.Conv2d(channels, boxes * 4, kernel_size=3, padding=1))
        conf_layers.append(nn.Conv2d(channels, boxes * num_classes, kernel_size=3, padding=1))

    return nn.ModuleList(loc_layers), nn.ModuleList(conf_layers)

In [None]:
class SSD(nn.Module):
    """Single Shot MultiBox Detector built on a VGG16 backbone.

    ``forward`` returns ``(loc, conf)`` shaped ``(N, P, 4)`` and
    ``(N, P, num_classes)``, where ``P`` is the total number of boxes predicted
    over all detection sources.

    Args:
        num_classes: number of classes scored for each box.
    """

    # vgg_base[:23] ends with the ReLU after conv4_3; its output is the first source.
    _CONV4_3_END = 23
    # The first 30 vgg_base layers come from torchvision's VGG16 (the ``[:30]``
    # slice in create_vgg_base). init_weights skips them so pretrained weights
    # are not overwritten.
    _NUM_PRETRAINED_VGG_LAYERS = 30

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.vgg_base = create_vgg_base()
        self.extras = add_extras()
        self.loc, self.conf = multibox(self.vgg_base, self.extras, num_classes)

        self.init_weights()

    def forward(self, x):
        """Run the detector on a batch.

        Args:
            x: image batch, presumably (N, 3, H, W) -- TODO confirm the expected
                input resolution.

        Returns:
            ``(loc, conf)`` with shapes ``(N, P, 4)`` and ``(N, P, num_classes)``.

        Raises:
            RuntimeError: if the number of collected feature maps differs from
                the number of loc/conf heads (``zip`` would otherwise silently
                drop predictions).
        """
        sources = []  # feature maps fed to the prediction heads

        for layer in self.vgg_base[:self._CONV4_3_END]:
            x = layer(x)
        # NOTE(review): the SSD paper rescales conv4_3 with an L2 normalization
        # layer before predicting; none is applied here -- confirm intended.
        sources.append(x)  # conv4_3 (after its ReLU)

        for layer in self.vgg_base[self._CONV4_3_END:]:
            x = layer(x)
        sources.append(x)  # end of the VGG base (conv7)

        # Extra layers: ReLU after each one, odd-indexed outputs are sources.
        for i, layer in enumerate(self.extras):
            x = F.relu(layer(x), inplace=True)
            if i % 2 == 1:
                sources.append(x)

        if len(sources) != len(self.loc) or len(sources) != len(self.conf):
            raise RuntimeError(
                f"{len(sources)} feature sources but {len(self.loc)} loc heads "
                f"and {len(self.conf)} conf heads"
            )

        batch = sources[0].size(0)
        # Move channels last so each location's box predictions are contiguous.
        loc = torch.cat(
            [head(feat).permute(0, 2, 3, 1).reshape(batch, -1)
             for feat, head in zip(sources, self.loc)],
            dim=1,
        )
        conf = torch.cat(
            [head(feat).permute(0, 2, 3, 1).reshape(batch, -1)
             for feat, head in zip(sources, self.conf)],
            dim=1,
        )

        return loc.view(batch, -1, 4), conf.view(batch, -1, self.num_classes)

    def init_weights(self):
        """Xavier-initialise the newly added convs, keeping pretrained VGG weights.

        Every ``nn.Conv2d`` in ``extras``, ``loc``, ``conf`` and in ``vgg_base``
        from index 30 onward gets ``xavier_uniform_`` weights and zero bias.
        The first 30 ``vgg_base`` layers are left untouched, so weights loaded
        through ``vgg16(pretrained=True)`` survive.
        """
        new_parts = [
            self.extras,
            self.loc,
            self.conf,
            self.vgg_base[self._NUM_PRETRAINED_VGG_LAYERS:],
        ]
        for part in new_parts:
            for m in part.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
