In [15]:
import torch 
import torch.nn as nn
import math

In [2]:
def conv3_3(inplanes, outplanes, stride=1):
    """Build a 3x3 convolution with padding 1 and no bias.

    With padding 1 the spatial size is preserved for ``stride=1`` and becomes
    ``ceil(size / stride)`` otherwise.

    Args:
        inplanes: number of input channels.
        outplanes: number of output channels.
        stride: convolution stride.

    Returns:
        An ``nn.Conv2d`` module.
    """
    # The previous call omitted the required ``kernel_size`` argument, so
    # constructing the layer raised TypeError.
    return nn.Conv2d(in_channels=inplanes, out_channels=outplanes, kernel_size=3,
                     stride=stride, padding=1, bias=False)


In [3]:
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions, each followed by BatchNorm.

    Args:
        inplanes: channels of the input tensor.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        dilation: dilation of both convolutions. Padding equals the dilation,
            so only ``stride`` changes the spatial size.

    When ``stride > 1`` or ``inplanes != planes``, the input cannot be added
    to the output directly. The caller must then pass a ``residual`` of
    matching shape to :meth:`forward`.
    """

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=dilation, bias=False, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(planes)

        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, bias=False, dilation=dilation)
        # Previously written as ``nn,BatchNorm2d(planes)`` (comma), which
        # raised NameError.
        self.bn2 = nn.BatchNorm2d(planes)

        self.stride = stride

    def forward(self, x, residual=None):
        """Apply the block.

        Args:
            x: input feature map.
            residual: tensor added before the final ReLU. Defaults to ``x``.

        Returns:
            ``relu(bn2(conv2(relu(bn1(conv1(x))))) + residual)``.
        """
        # The default used to be False, so ``residual is None`` never held
        # and the identity shortcut was replaced by adding 0.
        if residual is None:
            residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        # Previously ``out.self.relu(out)``, which raised AttributeError.
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Out-of-place add, so the in-place ReLU below never modifies ``x``.
        out = out + residual
        out = self.relu(out)

        return out

In [4]:
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce, 3x3, 1x1 expand.

    The inner width is ``planes // expansion``. The block maps ``inplanes``
    input channels to ``planes`` output channels.

    Args:
        inplanes: channels of the input tensor.
        planes: channels of the output tensor.
        stride: stride of the 3x3 convolution.
        dilation: dilation (and padding) of the 3x3 convolution.

    When ``stride > 1`` or ``inplanes != planes``, the caller must pass a
    ``residual`` of matching shape to :meth:`forward`.
    """

    expansion = 2

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(Bottleneck, self).__init__()
        bottle_planes = planes // Bottleneck.expansion

        # 1x1 reduction to the bottleneck width. It previously produced
        # ``planes`` channels, which did not match bn1.
        self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(bottle_planes)

        self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3, stride=stride,
                               padding=dilation, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(bottle_planes)

        # 1x1 expansion back to ``planes``, so the output can be summed with
        # a ``planes``-channel residual.
        self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)

        self.relu = nn.ReLU(inplace=True)

        self.stride = stride

    def forward(self, x, residual=None):
        """Apply the block. ``residual`` defaults to ``x`` when omitted."""
        if residual is None:
            residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        # No activation before the sum, consistent with the other blocks:
        # ReLU is applied once, after adding the shortcut.
        out = out + residual

        out = self.relu(out)

        return out

In [5]:
class BottlenectX(nn.Module):
    """ResNeXt-style bottleneck block with a grouped 3x3 convolution.

    The inner width is ``planes * cardinality // 32``, and the 3x3
    convolution uses ``cardinality`` groups. The inner width must therefore
    be divisible by ``cardinality``.

    Args:
        inplanes: channels of the input tensor.
        planes: channels of the output tensor.
        stride: stride of the grouped 3x3 convolution.
        dilation: dilation (and padding) of the grouped 3x3 convolution.

    When ``stride > 1`` or ``inplanes != planes``, the caller must pass a
    ``residual`` of matching shape to :meth:`forward`.
    """

    expansion = 2
    cardinality = 32

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        # Previously ``super(BottlenectX).__init__()`` (missing ``self``).
        # ``nn.Module.__init__`` never ran, so assigning submodules failed.
        super(BottlenectX, self).__init__()
        cardinality = BottlenectX.cardinality

        bottle_planes = planes * cardinality // 32

        self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(bottle_planes)
        # ``stride`` was previously ignored here, so the block never downsampled.
        self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False,
                               groups=cardinality)
        self.bn2 = nn.BatchNorm2d(bottle_planes)
        self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x, residual=None):
        """Apply the block. ``residual`` defaults to ``x`` when omitted."""
        if residual is None:
            residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = out + residual
        out = self.relu(out)

        return out

In [10]:
class Root(nn.Module):
    """Aggregation node.

    Concatenates its inputs along dim 1, then applies conv, BatchNorm, an
    optional residual add and ReLU.

    Args:
        in_channels: total channels of all inputs after concatenation.
        out_channels: output channels.
        kernel_size: side of the square aggregation kernel. Padding is
            ``(kernel_size - 1) // 2``, so odd sizes keep the spatial size.
        residual: if true, the first input is added to the output before the
            ReLU. That input must then have ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        # The kernel size used to be hard-coded to 1 while the padding
        # followed ``kernel_size``. Any kernel_size > 1 then enlarged the
        # output spatially and broke the residual sum.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, bias=False,
                              padding=(kernel_size - 1) // 2)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.residual = residual

    def forward(self, *x):
        """Aggregate the given feature maps (they must share spatial size)."""
        children = x

        out = self.conv(torch.cat(children, 1))
        out = self.bn(out)

        if self.residual:
            out = out + children[0]

        out = self.relu(out)

        return out

In [13]:
class Tree(nn.Module):
    """Recursive aggregation tree from Deep Layer Aggregation.

    A tree with ``levels == 1`` holds two ``block`` instances joined by a
    :class:`Root`. A deeper tree holds two sub-trees of ``levels - 1``, so
    it contains ``2 ** levels`` blocks in total.

    Args:
        levels: depth of the tree.
        block: block class, constructed as ``block(in, out, stride, dilation=...)``
            and called as ``block(x)`` or ``block(x, residual)``.
        in_channels: channels entering the tree.
        out_channels: channels leaving the tree.
        stride: stride of the first block. The shortcut is max-pooled by the
            same factor.
        level_root: if true, the (downsampled) input is also fed to the root.
        root_dim: channels entering the root. 0 means ``2 * out_channels``.
        root_kernel_size: forwarded to :class:`Root`.
        dilation: forwarded to the blocks.
        root_residual: forwarded to :class:`Root`.
    """

    def __init__(self, levels, block, in_channels, out_channels, stride=1, level_root=False,
                 root_dim=0, root_kernel_size=1, dilation=1, root_residual=False):
        super(Tree, self).__init__()

        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels

        if levels == 1:
            self.tree1 = block(in_channels, out_channels, stride, dilation=dilation)
            self.tree2 = block(out_channels, out_channels, stride=1, dilation=dilation)
        else:
            self.tree1 = Tree(levels - 1, block, in_channels, out_channels, stride, root_dim=0,
                              root_kernel_size=root_kernel_size, dilation=dilation,
                              root_residual=root_residual)
            # tree2 consumes tree1's output, which is already downsampled, so
            # it uses the default stride 1 (passing ``stride`` downsampled twice).
            # Its root also receives tree1's output via ``children``, hence
            # the extra ``out_channels``. ``root_residual`` was previously
            # passed ``root_kernel_size`` by mistake.
            self.tree2 = Tree(levels - 1, block, out_channels, out_channels,
                              root_dim=root_dim + out_channels,
                              root_kernel_size=root_kernel_size, dilation=dilation,
                              root_residual=root_residual)

        if levels == 1:
            self.root = Root(root_dim, out_channels, root_kernel_size, root_residual)

        # forward() reads ``self.level_root``. The attribute was previously
        # assigned as ``root_level = root_leve``, which raised NameError.
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        self.levels = levels

        if stride > 1:
            self.downsample = nn.MaxPool2d(stride, stride=stride)
        if in_channels != out_channels:
            # 1x1 projection so the shortcut matches the block output channels.
            self.project = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels))

    def forward(self, x, residual=None, children=None):
        """Run the tree.

        Args:
            x: input feature map.
            residual: ignored. The shortcut is recomputed from ``x``. The
                parameter exists so a Tree can be called like a block
                (``self.tree1(x, residual)``).
            children: feature maps from enclosing levels to feed into the
                root. This list is appended to in place.

        Returns:
            The root's aggregated output.
        """
        children = [] if children is None else children

        # Previously ``else bottom``, which raised UnboundLocalError whenever
        # stride == 1.
        bottom = self.downsample(x) if self.downsample is not None else x
        residual = self.project(bottom) if self.project is not None else bottom

        if self.level_root:
            children.append(bottom)

        x1 = self.tree1(x, residual)

        if self.levels == 1:
            x2 = self.tree2(x1)
            return self.root(x2, x1, *children)

        children.append(x1)
        return self.tree2(x1, children=children)
        
        

In [18]:
class DLA(nn.Module):
    """Deep Layer Aggregation network.

    Args:
        levels: six ints. ``levels[0]`` and ``levels[1]`` are conv counts for
            the plain conv levels. ``levels[2:6]`` are the depths of the
            aggregation trees.
        channels: six ints, the output channels of levels 0-5.
        num_classes: output channels of the final 1x1 conv classifier.
        block: residual block class used inside the trees.
        residual_root: forwarded to each :class:`Tree` as ``root_residual``.
        return_levels: if true, ``forward`` returns the list of the six level
            outputs instead of class scores.
        pool_size: kernel size of the final average pooling.
        linear_root: accepted for interface compatibility, currently unused.
    """

    def __init__(self, levels, channels, num_classes=1000, block=BasicBlock, residual_root=False,
                 return_levels=False, pool_size=7, linear_root=False):
        super(DLA, self).__init__()
        self.channels = channels
        self.return_levels = return_levels
        # Previously ``self.num_class = num_class``, which raised NameError.
        self.num_classes = num_classes

        # The ReLU was previously nested inside BatchNorm2d(...), where it
        # was passed as ``eps`` instead of being a layer.
        self.base_layer = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
            nn.BatchNorm2d(channels[0]),
            nn.ReLU(inplace=True))

        self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
        self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)

        # Tree's keyword is ``root_residual``. Passing ``residual_root=``
        # raised TypeError.
        self.level2 = Tree(levels[2], block, channels[1], channels[2], 2,
                           level_root=False, root_residual=residual_root)
        self.level3 = Tree(levels[3], block, channels[2], channels[3], 2,
                           level_root=True, root_residual=residual_root)
        self.level4 = Tree(levels[4], block, channels[3], channels[4], 2,
                           level_root=True, root_residual=residual_root)
        self.level5 = Tree(levels[5], block, channels[4], channels[5], 2,
                           level_root=True, root_residual=residual_root)

        self.avgpool = nn.AvgPool2d(pool_size)

        self.fc = nn.Conv2d(channels[-1], num_classes, kernel_size=1, stride=1, padding=0,
                            bias=True)

        # He-style init for conv weights (std = sqrt(2 / fan_out)); BatchNorm
        # starts as the identity (weight 1, bias 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
        """Stack ``convs`` (3x3 conv, BatchNorm, ReLU) units.

        Only the first conv uses ``stride``. The rest use stride 1.
        """
        modules = []
        for i in range(convs):
            modules.extend([
                # Previously ``stride=1 if i==0 else 1``, which ignored ``stride``.
                nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1,
                          padding=dilation, bias=False, dilation=dilation),
                nn.BatchNorm2d(planes),
                # Previously ``inplace==True``, which raised NameError.
                nn.ReLU(inplace=True),
            ])
            inplanes = planes
        return nn.Sequential(*modules)

    def forward(self, x):
        """Run all six levels.

        Returns:
            The list of level outputs if ``return_levels`` is set. Otherwise
            the pooled classifier output, flattened to ``(batch, -1)``.
        """
        y = []
        x = self.base_layer(x)
        for i in range(6):
            x = getattr(self, "level{}".format(i))(x)
            y.append(x)

        if self.return_levels:
            return y

        x = self.avgpool(x)
        x = self.fc(x)
        x = x.view(x.size(0), -1)
        return x
            
