<a href="https://colab.research.google.com/github/zuhaib786/Disconinuity-Identification-in-Numerical-solutions-of-DEs/blob/main/HED_Detector.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torch
from torch import nn
from torch.nn import Linear, Conv2d, LeakyReLU, Sigmoid


In [37]:
class HED(nn.Module):
    """HED-style side-output network producing a per-pixel score map.

    The name presumably refers to Holistically-Nested Edge Detection; the
    notebook appears to use it to flag discontinuities in numerical
    solutions -- TODO confirm.

    Architecture (every convolution is 3x3 with ``padding='same'``, so the
    spatial size of the input is preserved throughout):

    * ``upperLane``: a chain of convolutions ``n -> channels[0] -> ... ->
      channels[-1]``.
    * ``middleLane``: one 1-channel side convolution per upper-lane stage.
    * ``conv``: fuses the concatenated (pre-sigmoid) side outputs into one
      map.

    ``forward`` averages the sigmoids of every side output and of the fused
    output.

    Args:
        input_dim: ``(height, width, n_channels)``. Only ``n_channels`` is
            used; height and width are unpacked for interface compatibility
            but are not needed because no layer changes the spatial size.
        channels: Output channel count of each upper-lane convolution. Any
            sequence of positive ints (list or tuple); it is copied.

    Raises:
        ValueError: if ``input_dim`` does not unpack into exactly three
            values, or ``channels`` is empty.
    """

    def __init__(self, input_dim, channels=(16, 16, 32, 32, 64, 64)):
        super().__init__()
        _height, _width, n = input_dim
        # Copy into a list: avoids a shared mutable default and lets callers
        # pass tuples (the list concatenation below needs a list).
        channels = list(channels)
        if not channels:
            raise ValueError("channels must contain at least one entry")

        # Stage i reads the previous stage's width (the input's n for i == 0).
        stage_inputs = [n] + channels[:-1]
        self.upperLane = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=a,
                    out_channels=b,
                    kernel_size=3,
                    padding="same",
                )
                for a, b in zip(stage_inputs, channels)
            ]
        )
        # One single-channel side output per upper-lane stage.
        self.middleLane = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=a,
                    out_channels=1,
                    kernel_size=3,
                    padding="same",
                )
                for a in channels
            ]
        )
        # Fuses the len(channels) side outputs into a single map.
        self.conv = nn.Conv2d(
            in_channels=len(channels),
            out_channels=1,
            kernel_size=3,
            padding="same",
        )

    def forward(self, batch):
        """Return the averaged score map for ``batch``.

        Args:
            batch: Tensor of shape ``(N, n_channels, H, W)`` (channels-first,
                as ``nn.Conv2d`` expects).

        Returns:
            Tensor of shape ``(N, H, W)`` with values in ``[0, 1]``: the mean
            of ``len(channels) + 1`` sigmoid maps (one per side output plus
            the fused output).
        """
        features = []
        x = batch
        for idx, layer in enumerate(self.upperLane):
            x = layer(x)
            if idx == 0:
                # NOTE(review): only the first stage is followed by a
                # non-linearity, so later stages are affine in stage 1's
                # output. Canonical HED activates after every convolution;
                # confirm whether this is intentional before changing it
                # (doing so would alter the behaviour of trained weights).
                x = nn.functional.leaky_relu(x, negative_slope=0.2)
            features.append(x)

        side_logits = [side(feat) for side, feat in zip(self.middleLane, features)]
        side_probs = [torch.sigmoid(logit) for logit in side_logits]
        # The fusion conv consumes the raw side logits, not their sigmoids.
        fused = torch.sigmoid(self.conv(torch.cat(side_logits, dim=1)))

        maps = torch.cat(side_probs + [fused], dim=1)
        return maps.mean(dim=1)



In [38]:
model = HED(input_dim=(11, 11, 3))  # (height, width, channels); only channels=3 affects the layers

In [39]:
print(repr(model))  # nn.Module's repr lists every registered submodule

HED(
  (upperLane): ModuleList(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
  )
  (middleLane): ModuleList(
    (0): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (2): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (3): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (4): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (5): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=same)
  )
  (conv): Conv2d(6, 1, kernel_size=(3, 3), stride=(1, 1), paddi

In [41]:
from torchsummary import summary

# torchsummary's `device` argument defaults to "cuda" and (in its
# implementation) builds the probe tensors on CUDA whenever CUDA is available,
# without moving the model. Match it to where the model's weights live so this
# cell also works on a GPU runtime.
summary_device = "cuda" if next(model.parameters()).is_cuda else "cpu"
# input_size is channels-first (C, H, W), matching Conv2d's expected layout.
summary(model, input_size=(3, 11, 11), device=summary_device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 12, 12]             448
            Conv2d-2           [-1, 16, 12, 12]           2,320
            Conv2d-3           [-1, 32, 12, 12]           4,640
            Conv2d-4           [-1, 32, 12, 12]           9,248
            Conv2d-5           [-1, 64, 12, 12]          18,496
            Conv2d-6           [-1, 64, 12, 12]          36,928
            Conv2d-7            [-1, 1, 12, 12]             145
            Conv2d-8            [-1, 1, 12, 12]             145
            Conv2d-9            [-1, 1, 12, 12]             289
           Conv2d-10            [-1, 1, 12, 12]             289
           Conv2d-11            [-1, 1, 12, 12]             577
           Conv2d-12            [-1, 1, 12, 12]             577
           Conv2d-13            [-1, 1, 12, 12]              55
Total params: 74,157
Trainable params: 