<a href="https://colab.research.google.com/github/veerendra12/CS598-DL4H-Project/blob/main/notebooks/ModelFactory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import import_ipynb

from Utils import get_device, load_checkpoint
from Configuration import CONFIG

import torch.nn as nn
from torchvision import models

import torch
import torch.nn as nn

import torchvision
from torchvision import models
import torchvision.transforms.functional as TF
import torch.nn.functional as F

In [None]:
def get_densenet121_model():
    """Build an ImageNet-pretrained DenseNet-121 with a per-class sigmoid head.

    The stock classifier is swapped for ``Linear(1024, CONFIG['NUM_CLASSES'])``
    followed by ``Sigmoid``, and the network is moved to the device returned
    by ``get_device()``.

    Returns:
        The DenseNet-121 ``nn.Module`` on the target device.
    """
    target_device = get_device()
    n_classes = CONFIG['NUM_CLASSES']

    net = models.densenet121(pretrained=True)
    head = nn.Sequential(
        nn.Linear(1024, n_classes),
        nn.Sigmoid(),
    )
    net.classifier = head
    net.to(target_device)
    return net

In [None]:
def get_fem1_model():
    """Return the frozen FEM1 feature extractor.

    A DenseNet-121 from ``get_densenet121_model`` is restored from the
    checkpoint at ``CONFIG['FEM1_BEST_MODEL']`` and put in eval mode. Its
    classifier is replaced by an identity, so the forward pass yields the
    1024-wide vector that used to feed the classification head. Every
    parameter is then frozen (``requires_grad = False``).
    """
    extractor = get_densenet121_model()
    load_checkpoint(extractor, CONFIG['FEM1_BEST_MODEL'])
    extractor.eval()

    # Strip the classification head; nn.Identity ignores its argument.
    extractor.classifier = nn.Identity(1024)

    # Equivalent to setting requires_grad = False on each parameter.
    extractor.requires_grad_(False)
    return extractor

In [None]:
def get_fem2_model():
    """Return the frozen FEM2 feature extractor.

    Same recipe as FEM1, but restored from ``CONFIG['FEM2_BEST_MODEL']``:
    - load a DenseNet-121 from ``get_densenet121_model`` and switch it to
      eval mode;
    - replace the classifier with an identity, so the 1024-wide pre-head
      vector is returned;
    - freeze all parameters.
    """
    extractor = get_densenet121_model()
    load_checkpoint(extractor, CONFIG['FEM2_BEST_MODEL'])
    extractor.eval()

    # Strip the classification head; nn.Identity ignores its argument.
    extractor.classifier = nn.Identity(1024)

    # Equivalent to setting requires_grad = False on each parameter.
    extractor.requires_grad_(False)
    return extractor

In [None]:
class DoubleConv(nn.Module):
    """Two stacked ``3x3 Conv2d -> BatchNorm2d -> ReLU`` stages.

    Both convolutions use stride 1 and padding 1, so spatial size is kept.
    The first conv maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. All six layers live in one ``nn.Sequential`` named
    ``conv``, so checkpoint keys are ``conv.0`` ... ``conv.5``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # One conv/bn/relu stage per input width: first in_channels, then out_channels.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    """U-Net encoder/decoder built from ``DoubleConv`` stages.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the final 1x1 convolution. No
            activation is applied to the output.
        features: channel widths of the encoder stages, shallowest first. Any
            non-empty sequence is accepted. The decoder mirrors these widths,
            and the bottleneck uses ``2 * features[-1]``.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(
            self, in_channels=3, out_channels=1, features=(64, 128, 256, 512),
    ):
        # Immutable tuple default instead of a shared mutable list.
        super(UNET, self).__init__()
        features = tuple(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        # Attribute names and registration order are unchanged so existing
        # checkpoints (state_dict keys ups.*, downs.*, bottleneck.*, final_conv.*)
        # still load.
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: each stage maps the previous width to the next one.
        stage_in = in_channels
        for feature in features:
            self.downs.append(DoubleConv(stage_in, feature))
            stage_in = feature

        # Decoder, deepest level first: a 2x transposed-conv upsample, then a
        # DoubleConv over [skip, upsampled] (hence feature * 2 input channels).
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run encoder, bottleneck and decoder; return the ``final_conv`` output.

        An upsampled map may not match its skip connection's shape, e.g. when
        a spatial size is odd and the 2x2/stride-2 max-pool rounded it down.
        In that case it is resized to the skip's spatial size before
        concatenation.
        """
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # The decoder consumes skips deepest-first.
        skip_connections = skip_connections[::-1]

        # self.ups alternates [ConvTranspose2d, DoubleConv] for each level.
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)

In [None]:
class SDFN(nn.Module):
    """Fusion classifier on top of two frozen DenseNet-121 feature extractors.

    ``forward`` does the following:
    - passes the same input through the FEM1 and FEM2 extractors;
    - concatenates their 1024-wide outputs along dim 1;
    - applies one linear layer followed by a softmax over dim 1.

    Args:
        num_classes: number of output classes. Defaults to 14, the value that
            was previously hard-coded.

    NOTE(review): the DenseNet heads built by ``get_densenet121_model`` use
    Sigmoid (independent per-class scores), while this head uses softmax.
    Confirm softmax is intended for this task and matches the loss used in
    training.
    """

    # Width of each extractor's output (input width of the removed DenseNet head).
    FEATURE_DIM = 1024

    def __init__(self, num_classes=14):
        super(SDFN, self).__init__()
        self.fem1_model = get_fem1_model()
        self.fem2_model = get_fem2_model()

        self.fc = nn.Linear(2 * self.FEATURE_DIM, num_classes)

    def train(self, mode=True):
        """Set the training mode, but keep both frozen extractors in eval mode.

        The extractors' parameters have ``requires_grad=False``. However,
        nn.Module.train would flip them back to training mode. Any
        BatchNorm layers inside would then normalize with batch statistics
        and keep updating their running statistics, so the "frozen" features
        would drift. ``eval()`` calls ``train(False)``, so it also goes
        through here.
        """
        super(SDFN, self).train(mode)
        self.fem1_model.eval()
        self.fem2_model.eval()
        return self

    def forward(self, x):
        x1 = self.fem1_model(x)
        x2 = self.fem2_model(x)
        fused = torch.cat((x1, x2), dim=1)
        # Explicit dim=1 (class axis) replaces the deprecated implicit-dim
        # softmax, which already picked dim=1 for 2-D input.
        return F.softmax(self.fc(fused), dim=1)

In [None]:
def get_sdfn_model():
    """Construct an ``SDFN`` fusion model and return it on ``get_device()``."""
    target = get_device()
    # nn.Module.to moves parameters in place and returns the module itself.
    return SDFN().to(target)

In [None]:
def load_segmentation_model():
    """Return the segmentation U-Net, restored and ready for inference.

    Steps:
    - build a ``UNET`` with 3 input channels and 1 output channel;
    - move it to ``get_device()``;
    - load the checkpoint at ``CONFIG['SEGMENTATION_BEST_MODEL']``;
    - switch it to eval mode.
    """
    target = get_device()
    segmenter = UNET(in_channels=3, out_channels=1)
    segmenter.to(target)
    load_checkpoint(segmenter, CONFIG['SEGMENTATION_BEST_MODEL'])
    # eval() returns the module itself.
    return segmenter.eval()