In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchsummary import summary
# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#====================================================================
# model definitions: separable conv and the Xception entry/middle/exit flows
#====================================================================
class sep_conv2d(nn.Module):
    """Depthwise-separable 3x3 convolution with batch norm after each stage.

    Pipeline: per-channel 3x3 depthwise conv (stride 1, padding 1, so the
    spatial size is kept) -> BatchNorm -> 1x1 pointwise conv mapping
    ``in_channels`` to ``out_channels`` -> BatchNorm. No activation is applied
    inside this module.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, bias=False):
        super().__init__()
        # Registration order (depthwise -> bat1 -> pointwise -> bat2) is kept
        # fixed so state_dict keys and default-init RNG consumption match.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=in_channels,  # one 3x3 filter per input channel
            bias=bias,
        )
        self.bat1 = nn.BatchNorm2d(in_channels)
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.bat2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        for stage in (self.depthwise, self.bat1, self.pointwise, self.bat2):
            x = stage(x)
        return x

class EntryFlow(nn.Module):
    """Xception entry flow: a two-conv stem followed by three residual blocks.

    Channels go 3 -> 32 -> 64 in the stem, then 128, 256 and 728 through the
    residual blocks. In each residual block the main path ends in a stride-2
    max-pool and the skip path is a stride-2 1x1 conv, so every block halves
    the spatial size (rounding up) before the two paths are summed.
    """

    def __init__(self):
        super().__init__()

        # Stem: 3x3 stride-2 conv, then an unpadded 3x3 conv, each with BN + ReLU.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        # conv1 already ends with a ReLU, so the first main path starts directly
        # with a separable conv; the later main paths start with a ReLU.
        # Creation order (main path, then shortcut, per stage) is preserved.
        self.conv2_resi = self._residual_path(64, 128, leading_relu=False)
        self.conv2_shortcut = self._shortcut(64, 128)
        self.conv3_resi = self._residual_path(128, 256, leading_relu=True)
        self.conv3_shortcut = self._shortcut(128, 256)
        self.conv4_resi = self._residual_path(256, 728, leading_relu=True)
        self.conv4_shortcut = self._shortcut(256, 728)

    @staticmethod
    def _residual_path(in_ch, out_ch, leading_relu):
        """Build [ReLU] -> sep_conv -> BN -> ReLU -> sep_conv -> BN -> max-pool."""
        layers = [nn.ReLU()] if leading_relu else []
        layers += [
            sep_conv2d(in_ch, out_ch),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            sep_conv2d(out_ch, out_ch),
            nn.BatchNorm2d(out_ch),
            nn.MaxPool2d(3, stride=2, padding=1),
        ]
        return nn.Sequential(*layers)

    @staticmethod
    def _shortcut(in_ch, out_ch):
        """Build the skip path: stride-2 1x1 conv (default bias) followed by BN."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 1, stride=2, padding=0),
            nn.BatchNorm2d(out_ch),
        )

    def forward(self, x):
        out = self.conv1(x)
        stages = (
            (self.conv2_resi, self.conv2_shortcut),
            (self.conv3_resi, self.conv3_shortcut),
            (self.conv4_resi, self.conv4_shortcut),
        )
        for main_path, skip_path in stages:
            out = main_path(out) + skip_path(out)
        return out
    
class MiddleFlow(nn.Module):
    """One Xception middle-flow block.

    Three repeats of ReLU -> sep_conv2d(728, 728) -> BatchNorm form the main
    path; an identity skip is added to it, so the input shape is preserved.
    """

    def __init__(self):
        super().__init__()

        layers = []
        for _ in range(3):
            layers.extend([nn.ReLU(), sep_conv2d(728, 728), nn.BatchNorm2d(728)])
        self.conv_resi = nn.Sequential(*layers)

        # An empty Sequential returns its input unchanged: identity shortcut.
        self.conv_shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.conv_resi(x)
        return residual + self.conv_shortcut(x)
    
class ExitFlow(nn.Module):
    """Xception exit flow.

    A residual block (728 -> 1024 channels, stride-2 on both paths) is followed
    by two separable convs (1024 -> 1536 -> 2048), each with BN + ReLU, and a
    global average pool. The output therefore has 2048 channels and a 1x1
    spatial size.
    """

    def __init__(self):
        super().__init__()

        main_layers = [
            nn.ReLU(),
            sep_conv2d(728, 1024),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            sep_conv2d(1024, 1024),
            nn.BatchNorm2d(1024),
            nn.MaxPool2d(3, stride=2, padding=1),
        ]
        self.conv1_resi = nn.Sequential(*main_layers)

        # Projection skip path: a stride-2 1x1 conv so its output size matches
        # the max-pooled main path.
        self.conv1_shortcut = nn.Sequential(
            nn.Conv2d(728, 1024, 1, stride=2, padding=0),
            nn.BatchNorm2d(1024),
        )

        self.conv2 = nn.Sequential(
            sep_conv2d(1024, 1536),
            nn.BatchNorm2d(1536),
            nn.ReLU(),
            sep_conv2d(1536, 2048),
            nn.BatchNorm2d(2048),
            nn.ReLU(),
        )

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        merged = self.conv1_resi(x) + self.conv1_shortcut(x)
        features = self.conv2(merged)
        return self.avg_pool(features)

class Xception(nn.Module):
    """Xception image classifier.

    Pipeline: EntryFlow -> ``num_middle_blocks`` stacked MiddleFlow blocks ->
    ExitFlow (ending in global average pooling to 2048 channels) -> a linear
    classifier.

    Args:
        num_classes: size of the output logit vector (default 1000).
        num_middle_blocks: number of stacked MiddleFlow blocks (default 8).
    """

    def __init__(self, num_classes=1000, num_middle_blocks=8):
        super().__init__()

        self.a = EntryFlow()
        self.b = self._make_middle_flow(num_middle_blocks)
        self.c = ExitFlow()
        self.linear = nn.Linear(2048, num_classes)

        # Runs after all submodules exist, as before, so the init order is unchanged.
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every conv, batch-norm and linear layer in place.

        Convs get Xavier-normal weights and zero bias; BatchNorm gets weight 1
        and bias 0; Linear gets N(0, 0.01) weights and zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.a(x)
        out = self.b(out)
        out = self.c(out)
        # Collapse everything after the batch dimension. Unlike
        # view(out.size(0), -1), flatten also handles a zero-size batch.
        out = torch.flatten(out, 1)
        return self.linear(out)

    def _make_middle_flow(self, num_blocks=8):
        """Stack ``num_blocks`` MiddleFlow blocks named ``middle_block_{i}``."""
        middle = nn.Sequential()
        for i in range(num_blocks):
            middle.add_module(f'middle_block_{i}', MiddleFlow())
        return middle
# Build the network on the selected device and print a per-layer summary for
# a single 3x299x299 input.
model = Xception()
model = model.to(device)
summary(model, input_size=(3, 299, 299), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 150, 150]             864
       BatchNorm2d-2         [-1, 32, 150, 150]              64
              ReLU-3         [-1, 32, 150, 150]               0
            Conv2d-4         [-1, 64, 148, 148]          18,432
       BatchNorm2d-5         [-1, 64, 148, 148]             128
              ReLU-6         [-1, 64, 148, 148]               0
            Conv2d-7         [-1, 64, 148, 148]             576
       BatchNorm2d-8         [-1, 64, 148, 148]             128
            Conv2d-9        [-1, 128, 148, 148]           8,192
      BatchNorm2d-10        [-1, 128, 148, 148]             256
       sep_conv2d-11        [-1, 128, 148, 148]               0
      BatchNorm2d-12        [-1, 128, 148, 148]             256
             ReLU-13        [-1, 128, 148, 148]               0
           Conv2d-14        [-1, 128, 1