<a href="https://colab.research.google.com/github/tohkunhao/Paper-Implementation/blob/main/EfficientNET_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn

class EfficientNET(nn.Module):
    '''
    EfficientNet-B0 ... B7 image classifier.

    Based on the paper by
    Mingxing Tan, Quoc V. Le
    EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks.
    ICML'19, https://arxiv.org/abs/1905.11946

    Input resolution used in the paper for each scaled variant:
    B0:224, B1:240, B2:260, B3:300, B4:380, B5:456, B6:528, B7:600

    Args:
        in_channels: number of channels of the input image.
        num_classes: number of output classes (size of the final Linear layer).
        architecture: variant name, one of 'B0' ... 'B7' (default 'B0').
        BN_momentum: batch-norm moving-average decay in the paper's convention,
            running = BN_momentum * running + (1 - BN_momentum) * batch.
            PyTorch's BatchNorm weights the *new* batch statistic by `momentum`,
            so the value handed to the layers (and stored in self.BN_momentum)
            is 1 - BN_momentum.
        BN_eps: small constant added to the batch-norm variance to avoid division by zero.

    Raises:
        ValueError: if `architecture` is not one of 'B0' ... 'B7'.
    '''
    def __init__(self,in_channels,num_classes,architecture='B0',BN_momentum = 0.99,BN_eps =1e-3):
        super(EfficientNET,self).__init__()
        self.in_channels = in_channels
        # Width (channel) and depth (layer-count) multipliers for each variant.
        self.channelscale = {'B0':1,'B1':1,'B2':1.1,'B3':1.2,'B4':1.4,'B5':1.6,'B6':1.8,'B7':2}
        self.depthscale = {'B0':1,'B1':1.1,'B2':1.2,'B3':1.4,'B4':1.8,'B5':2.2,'B6':2.6,'B7':3.1}
        if architecture not in self.channelscale:
            raise ValueError(f"architecture must be one of {sorted(self.channelscale)}, got {architecture!r}")
        # Convert the paper's decay (e.g. 0.99) to PyTorch's momentum convention (e.g. 0.01).
        self.BN_momentum = 1 - BN_momentum
        self.BN_eps = BN_eps
        # Probability of keeping an MBConv residual branch during training (stochastic depth).
        self.survival_prob = 0.8
        # Classifier dropout rate, indexed by the variant number (B0 -> index 0, ..., B7 -> index 7).
        self.dropoutparams = [0.2,0.2,0.3,0.3,0.4,0.4,0.5,0.5]
        self.dropout = self.dropoutparams[int(architecture[1])]

        # Rows: conv type, num filters (out channels), num layers, kernel size, stride, padding.
        self.architecture = self._stage_table(self.channelscale[architecture],self.depthscale[architecture])
        head_channels = self.architecture[-1][1]

        self.EffNet = nn.Sequential(self.create_layers(),
                                    nn.AdaptiveAvgPool2d(1),
                                    nn.Flatten(),
                                    nn.Dropout(self.dropout),
                                    nn.Linear(head_channels,num_classes)).apply(self.init_weights)

    def _stage_table(self,width,depth):
        '''Return the stage layout for the given width/depth multipliers.

        Each row is [conv type, out channels, num layers, kernel size, stride, padding].
        Every channel count, including the stem (32) and head (1280) convs, is scaled
        by `width`; the MBConv layer counts are scaled by `depth`.
        '''
        ch = self.channelscaling
        dep = self.depthscaling
        return [['C',ch(width,32),1,3,2,1],
                ['MB1',ch(width,16),dep(depth,1),3,1,1],
                ['MB6',ch(width,24),dep(depth,2),3,2,1],
                ['MB6',ch(width,40),dep(depth,2),5,2,2],
                ['MB6',ch(width,80),dep(depth,3),3,2,1],
                ['MB6',ch(width,112),dep(depth,3),5,1,2],
                ['MB6',ch(width,192),dep(depth,4),5,2,2],
                ['MB6',ch(width,320),dep(depth,1),3,1,1],
                ['C',ch(width,1280),1,1,1,0]]

    def create_layers(self):
        '''Build the convolutional trunk described by self.architecture as one nn.Sequential.'''
        in_channels=self.in_channels
        layers=[]
        for stage in self.architecture:
            if stage[0]=='C': # regular convolution -> BN -> SiLU
                # bias is redundant: the following BatchNorm re-centres the output anyway
                layers.append(nn.Conv2d(in_channels,stage[1],stage[3],stride=stage[4],padding=stage[5],bias=False))
                in_channels = stage[1] # next layer's input channels = this layer's output channels
                layers.append(nn.BatchNorm2d(in_channels,momentum = self.BN_momentum,eps = self.BN_eps))
                layers.append(nn.SiLU())

            else: # stack of MBConv (MobileNetV2-style inverted bottleneck) blocks
                for layer_num in range(stage[2]):
                    layers.append(MBConv(in_channels,stage,layer_num,self.survival_prob,self.BN_momentum,self.BN_eps))
                    in_channels = stage[1]

        return nn.Sequential(*layers)

    def forward(self,x):
        '''Return class logits for the image batch x.'''
        return self.EffNet(x)

    #function to scale channels for B0~B7
    def channelscaling(self,scale,channels):
        '''Scale a channel count by `scale`, rounded to the nearest multiple of 8.

        The result is at least 8. It is bumped up by 8 if rounding down lost more
        than 10% of the scaled value.
        '''
        channels = channels*scale
        new_channels = max(8,int(channels+4)//8*8) # round to a multiple of 8
        if new_channels <0.9 * channels:
            new_channels += 8
        return int(new_channels)

    #function to scale layers for B0~B7
    def depthscaling(self,scale,layers):
        '''Scale a layer count by `scale`, rounding up.'''
        return int(math.ceil(scale*layers))

    #initialize weights
    def init_weights(self,m):
        '''Kaiming-initialise Linear and Conv2d weights (used with nn.Module.apply).'''
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight,a=np.sqrt(5),mode='fan_out',nonlinearity='leaky_relu')
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight,mode = 'fan_out',nonlinearity = 'relu')
    
class MBConv(nn.Module):#mobile inverted bottleneck
    '''
    Mobile inverted bottleneck block (MBConv) with squeeze-and-excitation.

    Layout: 1x1 expansion conv -> depthwise conv -> SE -> 1x1 projection conv.
    Every block except the first of its stage (layer_num == 0) adds a residual
    connection, and that branch is randomly dropped during training (stochastic depth).

    References:
    Mingxing Tan, Bo Chen, Ruoming Pang, Vijay Vasudevan, Mark Sandler, Andrew Howard, Quoc V. Le
    MnasNet: Platform-Aware Neural Architecture Search for Mobile
    arXiv:1807.11626v3

    Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
    MobileNetV2: Inverted Residuals and Linear Bottlenecks
    arXiv:1801.04381v4

    Args:
        in_channels: channels of the block input.
        stage_archi_list: stage description [conv type ('MB1' or 'MB6'), out channels,
            num layers, kernel size, stride, padding].
        layer_num: index of this block within its stage; only block 0 uses the stage stride.
        survival_probability: probability of keeping the residual branch during training.
        BN_momentum: momentum passed unchanged to every nn.BatchNorm2d (PyTorch convention).
        BN_eps: epsilon passed unchanged to every nn.BatchNorm2d.
    '''
    def __init__(self,in_channels,stage_archi_list,layer_num,survival_probability,BN_momentum, BN_eps):
        super(MBConv,self).__init__()
        # 'MB1' blocks use expansion factor 1; every other MB type uses 6.
        self.expansion = 1 if stage_archi_list[0] == 'MB1' else 6
        self.in_channels = in_channels
        self.out_channels = stage_archi_list[1]
        self.layers = layer_num # position of this block in its stage, not a layer count
        self.kernel_size = stage_archi_list[3]
        self.layer1stride = stage_archi_list[4]
        self.padding = stage_archi_list[5]
        self.BN_momentum = BN_momentum
        self.BNeps = BN_eps
        self.survival_prob = survival_probability
        self.mbconv = self.MBConvblock(self.in_channels,self.layers)

    def conv1x1block(self,in_channels,activation='SiLU',order = 'first'):
        '''1x1 conv + BatchNorm, followed by SiLU when activation == 'SiLU'.

        order='first' expands to in_channels * expansion channels;
        order='last' projects to self.out_channels.

        Raises:
            ValueError: if order is neither 'first' nor 'last'.
        '''
        if order == 'first':
            out_channels = in_channels * self.expansion
        elif order == 'last':
            out_channels = self.out_channels
        else:
            raise ValueError(f"order must be 'first' or 'last', got {order!r}")
        # bias is redundant: BatchNorm immediately re-centres the output
        layers = [nn.Conv2d(in_channels,out_channels,kernel_size=1,bias=False),
                  nn.BatchNorm2d(out_channels,momentum = self.BN_momentum,eps = self.BNeps)]
        if activation == 'SiLU':
            layers.append(nn.SiLU())
        return nn.Sequential(*layers)

    def depthwiseblock(self,in_channels,layer_num):
        '''Depthwise conv + BatchNorm + SiLU; only the first block of a stage uses the stage stride.'''
        stride = self.layer1stride if layer_num == 0 else 1
        # groups=in_channels makes the convolution depthwise (one filter per channel)
        return nn.Sequential(
            nn.Conv2d(in_channels,in_channels,kernel_size = self.kernel_size,stride = stride,
                      padding = self.padding,groups = in_channels,bias=False),
            nn.BatchNorm2d(in_channels,momentum = self.BN_momentum,eps = self.BNeps),
            nn.SiLU())

    def MBConvblock(self,in_channels,layer_num):
        '''Assemble expansion -> depthwise -> SE -> linear projection.'''
        expanded = in_channels * self.expansion
        # With expansion 1 the "expansion" conv would map C -> C channels and expand nothing,
        # so it is skipped. nn.Identity keeps the Sequential's indices stable.
        expand = self.conv1x1block(in_channels) if self.expansion != 1 else nn.Identity()
        # NOTE(review): the SE bottleneck is sized from the expanded channels here; some
        # EfficientNet implementations size it from the pre-expansion in_channels — confirm intent.
        return nn.Sequential(expand,
                             self.depthwiseblock(expanded,layer_num),
                             SEBlock(expanded),
                             self.conv1x1block(expanded,activation = None,order = 'last'))

    def StochasticDepth(self,x):
        '''Per-sample stochastic depth: during training, zero each sample's branch with
        probability 1 - survival_prob and rescale the kept ones by 1 / survival_prob.
        In eval mode, x is returned unchanged.'''
        if not self.training:
            return x
        # Draw the mask on x's own device/dtype so CPU and GPU models both work.
        keep = (torch.rand((x.shape[0],1,1,1),device=x.device) < self.survival_prob).to(x.dtype)
        return x / self.survival_prob * keep

    def forward(self,x):
        if self.layers == 0: # first block of the stage: channels/resolution may change, no residual
            return self.mbconv(x)
        # later blocks keep shape, so add the (stochastically dropped) residual branch
        return torch.add(self.StochasticDepth(self.mbconv(x)),x)

class SEBlock(nn.Module):
    '''
    Squeeze-and-Excitation block: multiplies each input channel by a learned gate in (0, 1).

    The gate is computed by global average pooling, a 1x1-conv bottleneck with
    SiLU, a 1x1 conv back to in_channels, and a sigmoid.

    Reference:
    Jie Hu, Li Shen, Samuel Albanie, Gang Sun, Enhua Wu
    Squeeze-and-Excitation Networks
    arXiv:1709.01507v4

    Args:
        in_channels: channels of the input feature map (the output has the same shape).
        se_ratio: fraction of in_channels kept in the bottleneck when
            squeezed_channels is not given (default 0.25).
        squeezed_channels: explicit bottleneck width; overrides se_ratio when given.
            The bottleneck is never narrower than 1 channel.
    '''
    def __init__(self,in_channels,se_ratio=0.25,squeezed_channels=None):
        super(SEBlock,self).__init__()
        self.SEratio = se_ratio
        self.in_channels = in_channels
        if squeezed_channels is None:
            squeezed_channels = int(self.SEratio*in_channels)
        self.squeezedchannels = max(1,squeezed_channels)
        self.squeezeexcite = self.SE()

    def SE(self):
        '''Build the gating branch: pool -> squeeze conv -> SiLU -> excite conv -> sigmoid.'''
        return nn.Sequential(nn.AdaptiveAvgPool2d(1),
                             nn.Conv2d(self.in_channels,self.squeezedchannels,kernel_size=1,bias=False),
                             nn.SiLU(),
                             nn.Conv2d(self.squeezedchannels,self.in_channels,kernel_size=1,bias=False),
                             nn.Sigmoid())

    def forward(self,x):
        # the (N, C, 1, 1) gate broadcasts over the spatial dimensions of x
        return torch.mul(self.squeezeexcite(x),x)