<a href="https://colab.research.google.com/github/saqlainkazi690/Projects/blob/main/EfficientCapsNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


In [None]:
# Number of samples per batch used in this notebook.
batch_size = 17

In [None]:
class ConvLayer(nn.Module):
    """Convolutional stem: four unpadded, stride-2 Conv2d -> BatchNorm2d -> ReLU stages.

    With no padding, a square input of side ``s`` leaves each stage with side
    ``floor((s - kernel_size) / 2) + 1``.  For a 224x224 image the feature map
    therefore shrinks 224 -> 109 -> 53 -> 25 -> 11, and the layer returns a
    ``(batch, 128, 11, 11)`` tensor.  Inputs smaller than 63x63 cannot pass
    through all four stages.

    Args:
        in_channels: Number of channels of the input images.  Defaults to 3
            (RGB), which reproduces the original hard-coded behaviour.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        # Shape comments assume a (batch, in_channels, 224, 224) input.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=32, kernel_size=7, stride=2, padding=0
        )  # -> (batch, 32, 109, 109)
        self.conv2 = nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=5, stride=2, padding=0
        )  # -> (batch, 64, 53, 53)
        self.conv3 = nn.Conv2d(
            in_channels=64, out_channels=64, kernel_size=5, stride=2, padding=0
        )  # -> (batch, 64, 25, 25)
        self.conv4 = nn.Conv2d(
            in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=0
        )  # -> (batch, 128, 11, 11)

        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        self.bn3 = nn.BatchNorm2d(num_features=64)
        self.bn4 = nn.BatchNorm2d(num_features=128)

        self.reset_parameters()

    def forward(self, x):
        """Apply the four conv/batch-norm/ReLU stages in order and return the feature map."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, bn in stages:
            x = F.relu(bn(conv(x)))
        return x

    def reset_parameters(self):
        """Re-initialise every convolution weight with Kaiming-normal (ReLU gain).

        Only the Conv2d weights are touched; conv biases and batch-norm
        parameters keep their PyTorch defaults.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")



class DepthWiseConv(nn.Module):
    """Primary-capsule layer built from a single depthwise convolution.

    The convolution has 128 input channels, ``groups=128`` and no padding, so
    when its kernel covers the whole incoming feature map the output is
    ``(batch, num_capsules * dim_capsules, 1, 1)``.  That tensor is reshaped to
    ``(batch, num_capsules, dim_capsules)`` and squashed.

    Args:
        capsule_size: ``(num_capsules, dim_capsules)``.  Their product is the
            number of output channels of the grouped convolution.
        kernel_size: Kernel size of the depthwise convolution; it must equal
            the spatial size of the incoming feature map.
    """

    def __init__(self, capsule_size=(16, 8), kernel_size=11):
        super().__init__()
        self.num_capsules, self.dim_capsules = capsule_size
        self.dwc = nn.Conv2d(
            in_channels=128,
            out_channels=self.num_capsules * self.dim_capsules,
            kernel_size=kernel_size,
            groups=128,
            padding=0,
        )

    def forward(self, x):
        """Return squashed primary capsules of shape ``(batch, num_capsules, dim_capsules)``.

        Raises:
            ValueError: If the convolution does not reduce the feature map to
                1x1.  Reshaping with ``view(-1, ...)`` in that case would fold
                spatial positions into the batch and capsule axes without any
                error, so the mismatch is rejected explicitly.
        """
        x = self.dwc(x)
        if tuple(x.shape[-2:]) != (1, 1):
            raise ValueError(
                "DepthWiseConv expects the depthwise convolution to collapse the "
                f"feature map to 1x1, got spatial size {tuple(x.shape[-2:])}"
            )
        x = x.view(-1, self.num_capsules, self.dim_capsules)
        return self.squash(x)

    def squash(self, x):
        """Rescale each vector along the last axis to length ``1 - exp(-||x||)``.

        ``eps`` keeps both divisions finite for all-zero vectors.
        """
        eps = 1e-20
        norm = torch.linalg.norm(x, ord=2, dim=-1, keepdim=True)
        coef = 1 - 1 / (torch.exp(norm) + eps)
        unit = x / (norm + eps)
        return coef * unit

class FCCaps(nn.Module):
    """Fully-connected capsule layer with single-pass self-attention routing.

    Every input capsule is projected into the space of every output capsule;
    the projections are then combined with weights derived from their mutual
    agreement.

    Args:
        num_capsules: Number of output capsules (N).
        dim_capsules: Dimension of each output capsule (D).
        kernel_iniitalizer: Accepted for interface compatibility only; it is
            not read.  The weights are always Kaiming-normal initialised.
        input_shape: Shape of the incoming capsule tensor.  Only the last two
            entries ``(input_N, input_D)`` are used; the leading (batch) entry
            is ignored, so the default no longer depends on any global.
    """

    def __init__(self, num_capsules=10, dim_capsules=16, kernel_iniitalizer='he_normal', input_shape=(None, 16, 8)):
        super().__init__()
        self.num_capsules = num_capsules
        self.dim_capsules = dim_capsules
        input_num_cap = input_shape[-2]
        input_dim_cap = input_shape[-1]

        # Projection weights, laid out as (N, input_N, input_D, D).  The values
        # are overwritten by the initialiser, so no random fill is needed first.
        self.W = nn.Parameter(
            torch.empty(self.num_capsules, input_num_cap, input_dim_cap, self.dim_capsules)
        )
        nn.init.kaiming_normal_(self.W, mode='fan_in', nonlinearity='relu')

        # Additive routing bias, broadcast over batch: (N, input_N, 1).
        self.b = nn.Parameter(torch.zeros(self.num_capsules, input_num_cap, 1))

    def forward(self, x):
        """Route input capsules ``(batch, input_N, input_D)`` to ``(batch, N, D)``."""
        # u: prediction of every input capsule for every output capsule,
        # shape (batch, N, input_N, D).
        u = torch.einsum('...ji,kjiz->...kjz', x, self.W)

        # c: dot product of each prediction with the sum of all predictions
        # for the same output capsule, shape (batch, N, input_N, 1).
        c = torch.einsum('...ij,...kj->...i', u, u).unsqueeze(-1)
        c = c / (self.dim_capsules ** 0.5)
        # Softmax over dim 1, i.e. across the N output capsules for batched input.
        c = F.softmax(c, dim=1)
        c = c + self.b

        # s: weighted sum over the input capsules, shape (batch, N, D).
        s = torch.sum(u * c, dim=-2)

        return self.squash(s)

    def squash(self, x):
        """Rescale each vector along the last axis to length ``1 - exp(-||x||)``.

        ``eps`` keeps both divisions finite for all-zero vectors.
        """
        eps = 1e-20
        norm = torch.linalg.norm(x, ord=2, dim=-1, keepdim=True)
        coef = 1 - 1 / (torch.exp(norm) + eps)
        unit = x / (norm + eps)
        return coef * unit


class CapsMask(nn.Module):
    """Zero out every capsule except the selected one and flatten the result."""

    def forward(self, x, y_true=None):
        """Mask capsules ``x`` and return them flattened to ``(x.shape[0], -1)``.

        Args:
            x: Capsule tensor; the last axis holds the capsule vector and
                axis 1 indexes the capsules.
            y_true: Optional mask multiplied onto the capsules (broadcast over
                the last axis).  When omitted, a one-hot mask selecting the
                capsule with the largest Euclidean length is built instead.
        """
        if y_true is None:
            # Inference path: keep only the longest capsule of each sample.
            lengths = torch.sqrt(torch.sum(x**2, dim=-1))
            winners = torch.argmax(lengths, dim=1)
            selector = F.one_hot(winners, num_classes=lengths.shape[1])
        else:
            # Training path: the caller-provided mask is used as-is.
            selector = y_true

        kept = x * selector.unsqueeze(-1)
        return kept.view(x.shape[0], -1)

class EfficientCapsNet(nn.Module):
    """Capsule encoder: convolutional stem -> primary capsules -> routed capsules.

    Args:
        input_size: Accepted but not read; the sub-modules are always built
            with their default settings.
    """

    def __init__(self, input_size=(3, 224, 224)):
        super().__init__()
        self.convlayer = ConvLayer()
        self.primarycaps = DepthWiseConv()
        self.routingcaps = FCCaps()

    def forward(self, x):
        """Return the routed capsule tensor for ``x`` (capsule lengths are not computed here)."""
        features = self.convlayer(x)
        primary = self.primarycaps(features)
        return self.routingcaps(primary)

class ReconstructionNet(nn.Module):
    """Decoder that turns a flattened, masked capsule vector into an image.

    A linear layer expands the 160-element input (16 * 10) to a 512x7x7
    feature map; four stride-2 transposed convolutions then double the
    resolution each time (7 -> 14 -> 28 -> 56 -> 112) and a sigmoid maps the
    3-channel result into [0, 1].

    Args:
        input_size: Stored on ``self.input_size``; it does not influence the
            layers, so the output is always ``(batch, 3, 112, 112)``.
        num_classes: Accepted but not read.
        num_capsules: Accepted but not read.

    NOTE(review): the output is 112x112 while the default ``input_size`` is
    224x224 -- any loss comparing the two presumably needs a resize; confirm
    against the training loop.
    """

    def __init__(self, input_size=(3, 224, 224), num_classes=2, num_capsules=16):
        super().__init__()
        self.input_size = input_size

        # Fully connected expansion to a small feature map.
        stages = [
            nn.Linear(16 * 10, 512 * 7 * 7),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (512, 7, 7)),
        ]
        # Three up-sampling stages, each followed by a ReLU.
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            stages.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stages.append(nn.ReLU(inplace=True))
        # Final up-sampling to 3 channels; sigmoid keeps pixel values in [0, 1].
        stages.append(nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1))
        stages.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*stages)
        self.reset_parameters()

    def forward(self, x):
        """Decode ``x`` of shape ``(batch, 160)`` into a ``(batch, 3, 112, 112)`` image."""
        return self.decoder(x)

    def reset_parameters(self):
        """Kaiming-normal (fan-out, ReLU) weights and zero biases for all learnable layers."""
        for layer in self.decoder:
            if not isinstance(layer, (nn.Linear, nn.ConvTranspose2d)):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)



class FinalCapsNet(nn.Module):
    """Capsule encoder, capsule mask and reconstruction decoder chained together."""

    def __init__(self):
        super().__init__()
        self.efficient_capsnet = EfficientCapsNet()
        self.mask = CapsMask()
        self.generator = ReconstructionNet()

    def forward(self, x, y_true=None, mode='train'):
        """Encode ``x``, mask the capsules and return the reconstruction.

        Only the decoder output is returned; the capsule tensor produced by
        the encoder is not exposed to the caller.

        Args:
            x: Input batch passed to the capsule encoder.
            y_true: Mask handed to ``CapsMask`` in ``"train"`` mode; ignored
                in ``"eval"`` mode.
            mode: ``"train"`` masks with ``y_true``; ``"eval"`` lets the mask
                layer pick the capsules itself.

        Raises:
            ValueError: If ``mode`` is neither ``"train"`` nor ``"eval"``.
                Previously this surfaced as an ``UnboundLocalError`` after the
                encoder had already run.
        """
        if mode not in ("train", "eval"):
            raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")

        capsules = self.efficient_capsnet(x)
        if mode == "train":
            masked = self.mask(capsules, y_true)
        else:
            masked = self.mask(capsules)
        return self.generator(masked)

In [None]:
class MarginLoss(nn.Module):
    """Margin loss over per-class capsule scores.

    For each class the loss is
    ``T * max(0, m_pos - p)^2 + lambda_ * (1 - T) * max(0, p - m_neg)^2``,
    summed over dim 1 and averaged over the remaining elements.

    Args:
        m_pos: Margin that scores of present classes should exceed.
        m_neg: Margin that scores of absent classes should stay below.
        lambda_: Weight of the absent-class term.
    """

    def __init__(self, m_pos=0.9, m_neg=0.1, lambda_=0.5):
        super().__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_ = lambda_

    def forward(self, targets, digit_probs):
        """Return the scalar margin loss.

        Args:
            targets: Target indicator tensor (1 for a present class, 0 otherwise).
            digit_probs: Per-class scores with the same shape as ``targets``.

        Raises:
            ValueError: If the two tensors differ in shape.  The previous
                ``assert ... is not ...`` compared object identity, was always
                true, and so never caught a mismatch that would otherwise
                broadcast silently.
        """
        if targets.shape != digit_probs.shape:
            raise ValueError(
                "targets and digit_probs must have the same shape, got "
                f"{tuple(targets.shape)} and {tuple(digit_probs.shape)}"
            )
        present_losses = (
            targets * torch.clamp_min(self.m_pos - digit_probs, min=0.0) ** 2
        )
        absent_losses = (1 - targets) * torch.clamp_min(
            digit_probs - self.m_neg, min=0.0
        ) ** 2
        losses = present_losses + self.lambda_ * absent_losses
        return torch.mean(torch.sum(losses, dim=1))


class ReconstructionLoss(nn.Module):
    """Mean squared error between reconstructed and original images."""

    def forward(self, reconstructions, input_images):
        """Return the element-wise squared error averaged over all elements."""
        return F.mse_loss(reconstructions, input_images, reduction="mean")


class TotalLoss(nn.Module):
    """Margin loss plus a down-weighted reconstruction loss.

    Args:
        m_pos: Forwarded to ``MarginLoss``.
        m_neg: Forwarded to ``MarginLoss``.
        lambda_: Forwarded to ``MarginLoss``.
        recon_factor: Multiplier applied to the reconstruction loss before it
            is added to the margin loss.
    """

    def __init__(self, m_pos=0.9, m_neg=0.1, lambda_=0.5, recon_factor=0.0005):
        super().__init__()
        self.recon_factor = recon_factor
        self.margin_loss = MarginLoss(m_pos, m_neg, lambda_)
        self.recon_loss = ReconstructionLoss()

    def forward(self, input_images, targets, reconstructions, digit_probs):
        """Return ``margin_loss(targets, digit_probs) + recon_factor * recon_loss(reconstructions, input_images)``."""
        classification_term = self.margin_loss(targets, digit_probs)
        reconstruction_term = self.recon_loss(reconstructions, input_images)
        return classification_term + self.recon_factor * reconstruction_term


In [None]:
# Instantiate the full model (capsule encoder, capsule mask and reconstruction decoder).
ECN = FinalCapsNet()

In [None]:
!pip install torchInfo
from torchinfo import summary

Collecting torchInfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchInfo
Successfully installed torchInfo-1.8.0


In [None]:
# Build a per-layer torchinfo summary of ECN for a batch of 17 three-channel
# 224x224 inputs, with columns for input/output shapes, parameter counts and
# trainability; rows are labelled with the attribute (variable) names.
summary(
    ECN,
    input_size = (17,3,224,224),
    verbose=0,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"]
)

Layer (type (var_name))                       Input Shape          Output Shape         Param #              Trainable
FinalCapsNet (FinalCapsNet)                   [17, 3, 224, 224]    [17, 3, 112, 112]    --                   True
├─EfficientCapsNet (efficient_capsnet)        [17, 3, 224, 224]    [17, 10, 16]         --                   True
│    └─ConvLayer (convlayer)                  [17, 3, 224, 224]    [17, 128, 11, 11]    --                   True
│    │    └─Conv2d (conv1)                    [17, 3, 224, 224]    [17, 32, 109, 109]   4,736                True
│    │    └─BatchNorm2d (bn1)                 [17, 32, 109, 109]   [17, 32, 109, 109]   64                   True
│    │    └─Conv2d (conv2)                    [17, 32, 109, 109]   [17, 64, 53, 53]     51,264               True
│    │    └─BatchNorm2d (bn2)                 [17, 64, 53, 53]     [17, 64, 53, 53]     128                  True
│    │    └─Conv2d (conv3)                    [17, 64, 53, 53]     [17, 64, 25, 25]