## 출처: https://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/attention/CoTAttention.py

## download DALI

In [1]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Wed_Jul_22_19:09:09_PDT_2020
Cuda compilation tools, release 11.0, V11.0.221
Build cuda_11.0_bu.TC445_37.28845127_0


In [2]:
!pip install torchvision
!pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda110

Looking in indexes: https://pypi.org/simple, https://developer.download.nvidia.com/compute/redist
Collecting nvidia-dali-cuda110
  Downloading https://developer.download.nvidia.com/compute/redist/nvidia-dali-cuda110/nvidia_dali_cuda110-1.4.0-2575285-py3-none-manylinux2014_x86_64.whl (789.4 MB)
[K     |████████████████████████████████| 789.4 MB 8.3 kB/s 
[?25hInstalling collected packages: nvidia-dali-cuda110
Successfully installed nvidia-dali-cuda110-1.4.0


## library load

In [3]:
import numpy as np
import torch
from torch import flatten, nn
from torch.nn import init
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn import functional as F
from torchsummary import summary
import math
from torch import Tensor
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from string import Template
import cupy
from collections import namedtuple
from torchvision import transforms
import torch.optim as optim
import time

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## download data

In [5]:
!wget -cq https://s3.amazonaws.com/content.udacity-data.com/courses/nd188/flower_data.zip \
  && unzip -qq flower_data.zip

In [6]:
import glob
!ls ./flower_data/

train  valid


In [7]:
from nvidia.dali.pipeline import Pipeline
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy

## CoT attention implementation
## 출처:https://github.com/xmu-xiaoma666/External-Attention-pytorch

In [8]:
class CoTAttention(nn.Module):

    def __init__(self, dim=512,kernel_size=3):
        super().__init__()
        self.dim=dim
        self.kernel_size=kernel_size

        self.key_embed=nn.Sequential(
            nn.Conv2d(dim,dim,kernel_size=kernel_size,padding=kernel_size//2,groups=4,bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed=nn.Sequential(
            nn.Conv2d(dim,dim,1,bias=False),
            nn.BatchNorm2d(dim)
        )

        factor=4
        self.attention_embed=nn.Sequential(
            nn.Conv2d(2*dim,2*dim//factor,1,bias=False),
            nn.BatchNorm2d(2*dim//factor),
            nn.ReLU(),
            nn.Conv2d(2*dim//factor,kernel_size*kernel_size*dim,1)
        )


    def forward(self, x):
        bs,c,h,w=x.shape
        k1=self.key_embed(x) #bs,c,h,w
        v=self.value_embed(x).view(bs,c,-1) #bs,c,h,w

        y=torch.cat([k1,x],dim=1) #bs,2c,h,w
        att=self.attention_embed(y) #bs,c*k*k,h,w
        att=att.reshape(bs,c,self.kernel_size*self.kernel_size,h,w)
        att=att.mean(2,keepdim=False).view(bs,c,-1) #bs,c,h*w
        k2=F.softmax(att,dim=-1)*v
        k2=k2.view(bs,c,h,w)


        return k1+k2

In [None]:
if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    cot = CoTAttention(dim=512,kernel_size=3)
    output=cot(input)
    print(output.shape)

torch.Size([50, 512, 7, 7])


## ResNet50 implementation

## 출처:https://github.com/weiaicunzai/pytorch-cifar100/blob/master/models/resnet.py

## 출처:https://deep-learning-study.tistory.com/534

In [9]:
class BottleNeck(nn.Module):
    """Residual block for resnet over 50 layers
    """
    expansion = 4
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )

        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))

In [10]:
class Backbone(nn.Module):

    def __init__(self, block, num_block, num_classes=102):
        super().__init__()

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        
        #we use a different inputsize than the original paper
        #so conv2_x's stride is 1
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """make resnet layers(by layer i didnt mean this 'layer' was the
        same as a neuron netowork layer, ex. conv layer), one layer may
        contain more than one residual block
        Args:
            block: block type, basic block or bottle neck block
            out_channels: output depth channel number of this layer
            num_blocks: how many blocks per layer
            stride: the stride of the first block of this layer
        Return:
            return a resnet layer
        """

        # we have num_block blocks per layer, the first block
        # could be 1 or 2, other blocks would always be 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        return output

In [9]:
def resnet50():
    """ return a ResNet 50 object
    """
    return Backbone(BottleNeck, [3, 4, 6, 3])

In [None]:
# check input
model1 = resnet50().to(device)
x = torch.randn(3, 3, 224, 224).to(device)
output = model1(x)
print(output.size())

torch.Size([3, 10])


In [None]:
summary(model1, (3, 224, 224), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,

## replace 3*3convolution with CoT attention

In [11]:
class CoT_BottleNeck(nn.Module):

    expansion = 4
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            CoTAttention(out_channels,kernel_size=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * CoT_BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * CoT_BottleNeck.expansion),
        )

        self.shortcut = nn.Sequential()

        if in_channels != out_channels * CoT_BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * CoT_BottleNeck.expansion, stride=1, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * CoT_BottleNeck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))

In [12]:
def cotnet50():
    return Backbone(CoT_BottleNeck, [3, 4, 6, 3])

In [None]:
#check input
model2 = cotnet50().to(device)
x = torch.randn(3, 3, 224, 224).to(device)
output = model2(x)
print(output.size())

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


torch.Size([3, 10])


In [None]:
summary(model2, (3, 224, 224), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]           9,216
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11           [-1, 64, 56, 56]           4,096
      BatchNorm2d-12           [-1, 64, 56, 56]             128
           Conv2d-13           [-1, 32, 56, 56]           4,096
      BatchNorm2d-14           [-1, 32,

## official implementation
## 출처:https://github.com/JDAI-CV/CoTNet/blob/master/models/cotnet.py

## prepare requirements

In [14]:
CUDA_NUM_THREADS = 1024

kernel_loop = '''
#define CUDA_KERNEL_LOOP(i, n)                        \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
      i < (n);                                       \
      i += blockDim.x * gridDim.x)
'''

def GET_BLOCKS(N):
    return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS

_aggregation_zeropad_forward_kernel = kernel_loop + '''
extern "C"
__global__ void aggregation_zeropad_forward_kernel(
const ${Dtype}* bottom_data, const ${Dtype}* weight_data, ${Dtype}* top_data) {
  CUDA_KERNEL_LOOP(index, ${nthreads}) {
    const int n = index / ${weight_heads} / ${input_channels} / ${top_height} / ${top_width};
    const int head = (index / ${top_width} / ${top_height} / ${input_channels}) % ${weight_heads};
    const int c = (index / ${top_width} / ${top_height}) % ${input_channels};
    const int h = (index / ${top_width}) % ${top_height};
    const int w = index % ${top_width};
    ${Dtype} value = 0;
    for (int kh = 0; kh < ${kernel_h}; ++kh) {
      for (int kw = 0; kw < ${kernel_w}; ++kw) {
        const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h};
        const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w};
        if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) {
          const int offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in;
          const int offset_weight = (((n * ${weight_heads} + head) * ${weight_channels} + c % ${weight_channels}) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w;
          value += weight_data[offset_weight] * bottom_data[offset_bottom];
        }
      }
    }
    top_data[index] = value;
  }
}
'''

_aggregation_zeropad_input_backward_kernel = kernel_loop + '''
extern "C"
__global__ void aggregation_zeropad_input_backward_kernel(
    const ${Dtype}* const top_diff, const ${Dtype}* const weight_data, ${Dtype}* bottom_diff) {
  CUDA_KERNEL_LOOP(index, ${nthreads}) {
    const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width};
    const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels};
    const int h = (index / ${bottom_width}) % ${bottom_height};
    const int w = index % ${bottom_width};
    ${Dtype} value = 0;
    for (int head = 0; head < ${weight_heads}; ++head) {
        for (int kh = 0; kh < ${kernel_h}; ++kh) {
          for (int kw = 0; kw < ${kernel_w}; ++kw) {
            const int h_out_s = h + ${pad_h} - kh * ${dilation_h};
            const int w_out_s = w + ${pad_w} - kw * ${dilation_w};
            if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) {
              const int h_out = h_out_s / ${stride_h};
              const int w_out = w_out_s / ${stride_w};
              if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) {
                const int offset_top = (((n * ${weight_heads} + head) * ${input_channels} + c) * ${top_height} + h_out) * ${top_width} + w_out;
                const int offset_weight = (((n * ${weight_heads} + head) * ${weight_channels} + c % ${weight_channels}) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out;
                value += weight_data[offset_weight] * top_diff[offset_top];
              }
            }
          }
        }
    }
    bottom_diff[index] = value;
  }
}
'''

_aggregation_zeropad_weight_backward_kernel = kernel_loop + '''
extern "C"
__global__ void aggregation_zeropad_weight_backward_kernel(
    const ${Dtype}* const top_diff, const ${Dtype}* const bottom_data, ${Dtype}* weight_diff) {
  CUDA_KERNEL_LOOP(index, ${nthreads}) {
    const int n = index / ${weight_heads} / ${weight_channels} / ${top_height} / ${top_width};
    const int head = (index / ${top_width} / ${top_height} / ${weight_channels}) % ${weight_heads};
    const int c = (index / ${top_width} / ${top_height}) % ${weight_channels};
    const int h = (index / ${top_width}) % ${top_height};
    const int w = index % ${top_width};
    for (int kh = 0; kh < ${kernel_h}; ++kh) {
      for (int kw = 0; kw < ${kernel_w}; ++kw) {
        const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h};
        const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w};
        const int offset_weight = (((n * ${weight_heads} + head) * ${weight_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w;
        ${Dtype} value = 0;
        if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) {
          for (int cc = c; cc < ${input_channels}; cc += ${weight_channels}) {
            const int offset_bottom = ((n * ${input_channels} + cc) * ${bottom_height} + h_in) * ${bottom_width} + w_in;
            const int offset_top = (((n * ${weight_heads} + head) * ${input_channels} + cc) * ${top_height} + h) * ${top_width} + w;
            value += bottom_data[offset_bottom] * top_diff[offset_top];
          }
        }
        weight_diff[offset_weight] = value;
      }
    }
  }
}
'''

In [15]:
Stream = namedtuple('Stream', ['ptr'])

def Dtype(t):
    if isinstance(t, torch.cuda.FloatTensor):
        return 'float'
    elif isinstance(t, torch.cuda.DoubleTensor):
        return 'double'

In [16]:
@cupy.memoize(for_each_device=True)
def load_kernel(kernel_name, code, **kwargs):
    code = Template(code).substitute(**kwargs)
    kernel_code = cupy.cuda.compile_with_cache(code)
    return kernel_code.get_function(kernel_name)

In [17]:
class AggregationZeropad(Function):
    @staticmethod
    def forward(ctx, input, weight, kernel_size, stride, padding, dilation):
        kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation)
        ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation
        assert input.dim() == 4 and input.is_cuda and weight.is_cuda
        batch_size, input_channels, input_height, input_width = input.size()
        _, weight_heads, weight_channels, weight_kernels, weight_height, weight_width = weight.size()
        output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1)
        output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1)
        assert output_height * output_width == weight_height * weight_width
        output = input.new(batch_size, weight_heads * input_channels, output_height, output_width)
        n = output.numel()
        if not input.is_contiguous():
            input = input.detach().clone()
        if not weight.is_contiguous():
            weight = weight.detach().clone()

        with torch.cuda.device_of(input):
            f = load_kernel('aggregation_zeropad_forward_kernel', _aggregation_zeropad_forward_kernel, Dtype=Dtype(input), nthreads=n,
                            num=batch_size, input_channels=input_channels, 
                            weight_heads=weight_heads, weight_channels=weight_channels,
                            bottom_height=input_height, bottom_width=input_width,
                            top_height=output_height, top_width=output_width,
                            kernel_h=kernel_size[0], kernel_w=kernel_size[1],
                            stride_h=stride[0], stride_w=stride[1],
                            dilation_h=dilation[0], dilation_w=dilation[1],
                            pad_h=padding[0], pad_w=padding[1])
            f(block=(CUDA_NUM_THREADS, 1, 1),
              grid=(GET_BLOCKS(n), 1, 1),
              args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
        ctx.save_for_backward(input, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation
        input, weight = ctx.saved_tensors
        assert grad_output.is_cuda
        if not grad_output.is_contiguous():
            grad_output = grad_output.contiguous()
        batch_size, input_channels, input_height, input_width = input.size()
        _, weight_heads, weight_channels, weight_kernels, weight_height, weight_width = weight.size()
        output_height, output_width = grad_output.size()[2:]
        grad_input, grad_weight = None, None
        opt = dict(Dtype=Dtype(grad_output),
                   num=batch_size, input_channels=input_channels, 
                   weight_heads=weight_heads, weight_channels=weight_channels,
                   bottom_height=input_height, bottom_width=input_width,
                   top_height=output_height, top_width=output_width,
                   kernel_h=kernel_size[0], kernel_w=kernel_size[1],
                   stride_h=stride[0], stride_w=stride[1],
                   dilation_h=dilation[0], dilation_w=dilation[1],
                   pad_h=padding[0], pad_w=padding[1])
        with torch.cuda.device_of(input):
            if ctx.needs_input_grad[0]:
                grad_input = input.new(input.size())
                n = grad_input.numel()
                opt['nthreads'] = n
                f = load_kernel('aggregation_zeropad_input_backward_kernel', _aggregation_zeropad_input_backward_kernel, **opt)
                f(block=(CUDA_NUM_THREADS, 1, 1),
                  grid=(GET_BLOCKS(n), 1, 1),
                  args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()],
                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
            if ctx.needs_input_grad[1]:
                grad_weight = weight.new(weight.size())
                n = grad_weight.numel() // weight.shape[3]
                opt['nthreads'] = n
                f = load_kernel('aggregation_zeropad_weight_backward_kernel', _aggregation_zeropad_weight_backward_kernel, **opt)
                f(block=(CUDA_NUM_THREADS, 1, 1),
                  grid=(GET_BLOCKS(n), 1, 1),
                  args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()],
                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
        return grad_input, grad_weight, None, None, None, None

def aggregation_zeropad(input, weight, kernel_size=3, stride=1, padding=0, dilation=1):
    assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[2] == 0)
    if input.is_cuda:
        out = AggregationZeropad.apply(input, weight, kernel_size, stride, padding, dilation)
    else:
        #raise NotImplementedError
        out = AggregationZeropad.apply(input.cuda(), weight.cuda(), kernel_size, stride, padding, dilation)
        torch.cuda.synchronize()
        out = out.cpu()
    return out

In [18]:
class LocalConvolution(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        pad_mode: int = 0,
    ):
        super(LocalConvolution, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.pad_mode = pad_mode

    def forward(self, input: Tensor, weight: Tensor):
        #if self.pad_mode == 0:
        out = aggregation_zeropad(
            input, 
            weight, 
            kernel_size=self.kernel_size, 
            stride=self.stride, 
            padding=self.padding, 
            dilation=self.dilation)
        #else:
        #  out = aggregation_refpad(
        #    input, 
        #    weight, 
        #    kernel_size=self.kernel_size, 
        #    stride=self.stride, 
        #    padding=self.padding, 
        #    dilation=self.dilation)  
        return out

In [19]:
def swish(x, inplace: bool = False):
    """Swish - Described in: https://arxiv.org/abs/1710.05941
    """
    return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())


class Swish(nn.Module):
    def __init__(self, inplace: bool = False):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return swish(x, self.inplace)

## define CoT Layer

In [20]:
class CotLayer(nn.Module):
    def __init__(self, dim, kernel_size):
        super(CotLayer, self).__init__()

        self.dim = dim
        self.kernel_size = kernel_size

        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, self.kernel_size, stride=1, padding=self.kernel_size//2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True)
        )

        share_planes = 8
        factor = 2
        self.embed = nn.Sequential(
            nn.Conv2d(2*dim, dim//factor, 1, bias=False),
            nn.BatchNorm2d(dim//factor),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim//factor, pow(kernel_size, 2) * dim // share_planes, kernel_size=1),
            nn.GroupNorm(num_groups=dim // share_planes, num_channels=pow(kernel_size, 2) * dim // share_planes)
        )

        self.conv1x1 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, dilation=1, bias=False),
            nn.BatchNorm2d(dim)
        )

        self.local_conv = LocalConvolution(dim, dim, kernel_size=self.kernel_size, stride=1, padding=(self.kernel_size - 1) // 2, dilation=1)
        self.bn = nn.BatchNorm2d(dim)
        self.act = Swish(inplace=True)

        reduction_factor = 4
        self.radix = 2
        attn_chs = max(dim * self.radix // reduction_factor, 32)
        self.se = nn.Sequential(
            nn.Conv2d(dim, attn_chs, 1),
            nn.BatchNorm2d(attn_chs),
            nn.ReLU(inplace=True),
            nn.Conv2d(attn_chs, self.radix*dim, 1)
        )

    def forward(self, x):
        k = self.key_embed(x)
        qk = torch.cat([x, k], dim=1)
        b, c, qk_hh, qk_ww = qk.size()

        w = self.embed(qk)
        w = w.view(b, 1, -1, self.kernel_size*self.kernel_size, qk_hh, qk_ww)
        
        x = self.conv1x1(x)
        x = self.local_conv(x, w)
        x = self.bn(x)
        x = self.act(x)

        B, C, H, W = x.shape
        x = x.view(B, C, 1, H, W)
        k = k.view(B, C, 1, H, W)
        x = torch.cat([x, k], dim=2)

        x_gap = x.sum(dim=2)
        x_gap = x_gap.mean((2, 3), keepdim=True)
        x_attn = self.se(x_gap)
        x_attn = x_attn.view(B, C, self.radix)
        x_attn = F.softmax(x_attn, dim=2)
        out = (x * x_attn.reshape((B, C, self.radix, 1, 1))).sum(dim=2)
        
        return out.contiguous()

## define bottleneck

In [21]:
#using ResNet bottleneck

class CoT_Bottleneck(nn.Module):

    expansion = 4
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            CotLayer(out_channels,kernel_size=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * CoT_Bottleneck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * CoT_Bottleneck.expansion),
        )

        self.shortcut = nn.Sequential()

        if in_channels != out_channels * CoT_Bottleneck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * CoT_Bottleneck.expansion, stride=1, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * CoT_Bottleneck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))

In [17]:
# official bottleneck

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, cardinality=1, base_width=64,
                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
                 attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
        super(Bottleneck, self).__init__()

        width = int(math.floor(planes * (base_width / 64)) * cardinality)
        first_planes = width // reduce_first
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation
        use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)

        self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
        self.bn1 = norm_layer(first_planes)
        self.act1 = act_layer(inplace=True)

        if stride > 1:
            self.avd = nn.AvgPool2d(3, 2, padding=1)
        else:
            self.avd = None
        
        self.conv2 = CotLayer(width, kernel_size=3) if cardinality == 1 else CoXtLayer(width, kernel_size=3)

        #self.conv2 = nn.Conv2d(
        #    first_planes, width, kernel_size=3, stride=1 if use_aa else stride,
        #    padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
        #self.bn2 = norm_layer(width)
        #self.act2 = act_layer(inplace=True)
        #self.aa = aa_layer(channels=width, stride=stride) if use_aa else None

        self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
        self.bn3 = norm_layer(outplanes)

        #self.se = create_attn(attn_layer, outplanes)
        self.se = None

        self.act3 = act_layer(inplace=True)
        self.shortcut = nn.Sequential()

        if stride != 1 or inplanes != outplanes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inplanes, outplanes, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(outplanes)
            )

        self.stride = stride
        self.dilation = dilation
        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x):
        residual = x

        x = self.conv1(x)
        x = self.bn1(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act1(x)

        if self.avd is not None:
            x = self.avd(x)

        x = self.conv2(x)
        #x = self.bn2(x)
        #if self.drop_block is not None:
        #    x = self.drop_block(x)
        #x = self.act2(x)
        #if self.aa is not None:
        #    x = self.aa(x)

        x = self.conv3(x)
        x = self.bn3(x)
        if self.drop_block is not None:
            x = self.drop_block(x)

        if self.se is not None:
            x = self.se(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        residual = self.shortcut(residual)

        x += residual
        x = self.act3(x)

        return x

In [22]:
def cotnet50v2(): #Resnet bottleneck + official CoTlayer
    return Backbone(CoT_Bottleneck, [3, 4, 6, 3])

In [18]:
def cotnet50v3(): #official CoTNet50
    return Backbone(Bottleneck, [3, 4, 6, 3])

In [None]:
# check input

model3 = cotnet50v2().to(device)
x = torch.randn(3, 3, 224, 224).to(device)
output = model3(x)
print(output.size())

torch.Size([3, 10])


In [None]:
# check summary
summary(model3, (3, 224, 224), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]           9,216
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11           [-1, 32, 56, 56]           4,096
      BatchNorm2d-12           [-1, 32, 56, 56]              64
             ReLU-13           [-1, 32, 56, 56]               0
           Conv2d-14           [-1, 72,

In [None]:
# check input

model4 = cotnet50v3().to(device)
x = torch.randn(3, 3, 224, 224).to(device)
output = model4(x)
print(output.size())

torch.Size([3, 10])


In [None]:
# check summary
summary(model4, (3, 224, 224), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]           9,216
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11           [-1, 32, 56, 56]           4,096
      BatchNorm2d-12           [-1, 32, 56, 56]              64
             ReLU-13           [-1, 32, 56, 56]               0
           Conv2d-14           [-1, 72,

## train using DALI

In [31]:
# define model

model1 = resnet50().to(device) #ResNet50
model2 = cotnet50v3().to(device) #official CoTNet50

In [14]:
# define DALI pipeline

@pipeline_def
def dali_mixed_pipeline():
    image_dir = "./flower_data/train"
    jpegs, labels = fn.readers.file(name="Reader", file_root=image_dir, random_shuffle=True)
    images = fn.decoders.image(jpegs, device="mixed")
    images= fn.normalize(images.gpu(), mean=0, stddev=255)  # same as divide by 255
    images = fn.resize(images.gpu(), resize_x=224, resize_y=224)
    return images, labels.gpu()

torch_transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Resize((224, 224)),
])

In [15]:
#build pipeline

batch_size=4
num_threads=2
pipe = dali_mixed_pipeline(batch_size=4,num_threads=2,device_id=0)
pipe.build()

In [16]:
#DALI dataloader creation

dali_dl = DALIGenericIterator(
  pipelines=pipe, 
  output_map=['image', 'label'], 
  size=pipe.epoch_size("Reader"), 
  last_batch_padded=True,
  last_batch_policy=LastBatchPolicy.PARTIAL, 
  auto_reset=True
)



In [17]:
#define basic train function

def train(dl, model, num_epochs=5):
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters())

  start = time.time()
  for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(dl):
      # DALI
      if type(dl) == DALIGenericIterator:
        inputs = data[0]["image"].permute(0, 3, 1, 2).to(device)
        labels = data[0]["label"].long().view(-1).to(device)
      # Torch
      else:
        inputs, labels = data[0].to(device), data[1].to(device)
      optimizer.zero_grad()
      outputs = model(inputs)
      
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      running_loss += loss.item()
      
    print('epoch : %d loss: %.3f' %(epoch + 1, running_loss / (i+1)))
  print('Finished Training')
  print(np.round(time.time() - start, 2), '(seconds)')

#define basic test function

@torch.no_grad()
def test(dl, model):
  start = time.time()
  correct = 0
  total = 0
  for i, data in enumerate(dl):
    # DALI
    if type(dl) == DALIGenericIterator:
      inputs = data[0]["image"].permute(0, 3, 1, 2).to(device)
      labels = data[0]["label"].long().view(-1).to(device)
    # Torch
    else:
      inputs, labels = data[0].to(device), data[1].to(device)
    outputs = model(inputs)
    _, preds = torch.max(outputs, 1)
    correct += (preds == labels).sum()
    total += labels.shape[0]
  print(f"Acc: {correct/total*100.0:.2f}%")
  print(f"Elapsed time: {(time.time() - start):.2f} sec")

In [36]:
#train
print("ResNet50")
train(dali_dl, model1)
print("CoTNet50")
train(dali_dl, model2)

ResNet50


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


epoch : 1 loss: 4.179
epoch : 2 loss: 3.520
epoch : 3 loss: 3.190
epoch : 4 loss: 2.911
epoch : 5 loss: 2.640
Finished Training
1868.76 (seconds)
CoTNet50
epoch : 1 loss: 4.317
epoch : 2 loss: 3.555
epoch : 3 loss: 3.057
epoch : 4 loss: 2.686
epoch : 5 loss: 2.276
Finished Training
1108.34 (seconds)


In [37]:
# Test
print("ResNet50")
test(dali_dl, model1)
print("CoTNet50")
test(dali_dl, model2)

ResNet50
Acc: 32.74%
Elapsed time: 149.35 sec
CoTNet50
Acc: 44.83%
Elapsed time: 68.27 sec
