In [2]:
from OFA_mbv3_extended.utils.flops_counter import profile
from torch import nn
from torch.nn import functional as F
from ofa.utils.layers import ResidualBlock, IdentityLayer

In [3]:
input_size = (2, 5, 5)  # presumably (channels, height, width) for profile(); 2 matches conv1 in_channels — TODO confirm no batch dim expected

In [4]:
class MyNet(nn.Module):
    """Plain three-layer conv stack used as the FLOPs/params baseline.

    conv1 is a 3x3 conv (2 -> 3 channels); conv2 and conv3 are 1x1 convs
    (3 -> 3 channels). All use stride 1 and "same" padding, so the spatial
    size is preserved. Every conv is followed by a ReLU; there are no skip
    connections.
    """

    def __init__(self):
        super().__init__()
        # Built in the order conv1, conv2, conv3 so parameter initialisation
        # draws from the RNG in the same sequence as before.
        self.conv1 = nn.Conv2d(2, 3, kernel_size=3, stride=1, padding="same")
        self.conv2 = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding="same")
        self.conv3 = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding="same")

    def forward(self, x):
        """Apply conv -> ReLU for conv1, conv2, conv3 in turn."""
        for layer in (self.conv1, self.conv2, self.conv3):
            x = F.relu(layer(x))
        return x

In [5]:
class MyNetDense(nn.Module):
    """The same three convs as MyNet, with dense-style skip additions.

    The first block's activation is added to the second conv's output
    before its ReLU. Both earlier activations are added to the third conv's
    output before its ReLU. Because the skips are plain additions, no extra
    parameters are introduced.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 3, kernel_size=3, stride=1, padding="same")
        self.conv2 = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding="same")
        self.conv3 = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding="same")

    def forward(self, x):
        """Run the three blocks, summing earlier activations into later ones."""
        first = F.relu(self.conv1(x))
        second = F.relu(self.conv2(first) + first)
        # Summed left to right: (conv3 + first) + second.
        return F.relu(self.conv3(second) + first + second)


In [6]:
base_net = MyNet()  # baseline: three convs, no skip connections

In [7]:
# profile() appears to return (flops, params): 81 equals the total conv weight+bias count (57 + 12 + 12).
print(profile(base_net, input_size))

(1800.0, 81.0)


In [8]:
dense_net = MyNetDense()  # same convs as the baseline, plus additive dense skips

In [9]:
# The skip additions add no parameters; compare the reported numbers with the baseline.
print(profile(dense_net, input_size))

(1800.0, 81.0)


In [10]:
class SimilOFA(nn.Module):
    """Three-conv stack wrapped in OFA ``ResidualBlock`` modules.

    r1 wraps the 3x3 conv (2 -> 3 channels) and passes ``None`` as its
    shortcut. r2 and r3 wrap 1x1 convs (3 -> 3 channels) and use an
    ``IdentityLayer`` as the shortcut (presumably an identity skip — see the
    OFA implementation). A ReLU is applied after each block.
    """

    def __init__(self):
        super().__init__()
        # Convs are created first and in order, then the shortcuts, so the
        # RNG is consumed in the same sequence as before.
        stem_conv = nn.Conv2d(2, 3, kernel_size=3, stride=1, padding="same")
        mid_conv = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding="same")
        last_conv = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding="same")
        mid_skip = IdentityLayer(mid_conv.in_channels, mid_conv.out_channels)
        last_skip = IdentityLayer(last_conv.in_channels, last_conv.out_channels)
        self.r1 = ResidualBlock(stem_conv, None)
        self.r2 = ResidualBlock(mid_conv, mid_skip)
        self.r3 = ResidualBlock(last_conv, last_skip)

    def forward(self, x):
        """Apply block -> ReLU for r1, r2, r3 in turn."""
        for block in (self.r1, self.r2, self.r3):
            x = F.relu(block(x))
        return x

In [11]:
simil = SimilOFA()  # OFA ResidualBlock-based variant

In [12]:
# Profile the ResidualBlock-wrapped variant for comparison with the plain and dense nets.
print(profile(simil, input_size))

(1800.0, 81.0)


In [13]:
class MyNetSeq(nn.Module):
    """MyNet-equivalent network expressed as three ``nn.Sequential`` stages.

    Each stage is a Conv2d (stride 1, "same" padding) followed by an
    in-place ReLU: s1 is 3x3 with 2 -> 3 channels; s2 and s3 are 1x1 with
    3 -> 3 channels.
    """

    @staticmethod
    def _conv_relu(in_ch, out_ch, k):
        # Build one Conv2d + in-place ReLU stage.
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=1, padding="same"),
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        super().__init__()
        self.s1 = self._conv_relu(2, 3, 3)
        self.s2 = self._conv_relu(3, 3, 1)
        self.s3 = self._conv_relu(3, 3, 1)

    def forward(self, x):
        """Pass the input through s1, s2, s3 in order."""
        for stage in (self.s1, self.s2, self.s3):
            x = stage(x)
        return x

In [14]:
net_seq = MyNetSeq()  # nn.Sequential-based variant of the baseline

In [15]:
# Profile the nn.Sequential variant just built (net_seq); profiling `simil`
# here would only repeat the previous measurement.
print(profile(net_seq, input_size))

(1800.0, 81.0)


In [17]:
# modules() yields the root module first, then each submodule depth-first
# (Sequential, then its Conv2d and ReLU); print one repr per line.
print(*net_seq.modules(), sep="\n")

MyNetSeq(
  (s1): Sequential(
    (0): Conv2d(2, 3, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU(inplace=True)
  )
  (s2): Sequential(
    (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1), padding=same)
    (1): ReLU(inplace=True)
  )
  (s3): Sequential(
    (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1), padding=same)
    (1): ReLU(inplace=True)
  )
)
Sequential(
  (0): Conv2d(2, 3, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (1): ReLU(inplace=True)
)
Conv2d(2, 3, kernel_size=(3, 3), stride=(1, 1), padding=same)
ReLU(inplace=True)
Sequential(
  (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1), padding=same)
  (1): ReLU(inplace=True)
)
Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1), padding=same)
ReLU(inplace=True)
Sequential(
  (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1), padding=same)
  (1): ReLU(inplace=True)
)
Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1), padding=same)
ReLU(inplace=True)
