Merge pull request #122 from mrT23/master
TResNet models
rwightman committed Apr 13, 2020
2 parents e15f979 + 8a63c1a commit ebf82b8
Showing 5 changed files with 409 additions and 0 deletions.
1 change: 1 addition & 0 deletions timm/models/__init__.py
@@ -17,6 +17,7 @@
from .dla import *
from .hrnet import *
from .sknet import *
from .tresnet import *

from .registry import *
from .factory import create_model
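
With the new import in place, the TResNet models become available through timm's existing create_model factory. A minimal sketch; 'tresnet_m' is an assumed entry-point name, since the new timm/models/tresnet.py (the bulk of this commit's 409 added lines) is not shown in this capture:

import torch
from timm.models import create_model

# 'tresnet_m' is assumed to be one of the entry points registered in tresnet.py
model = create_model('tresnet_m', pretrained=False, num_classes=1000)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
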
2 changes: 2 additions & 0 deletions timm/models/layers/__init__.py
@@ -16,3 +16,5 @@
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .anti_aliasing import AntiAliasDownsampleLayer
from .space_to_depth import SpaceToDepthModule
61 changes: 61 additions & 0 deletions timm/models/layers/anti_aliasing.py
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class AntiAliasDownsampleLayer(nn.Module):
def __init__(self, remove_aa_jit: bool = False, filt_size: int = 3, stride: int = 2,
channels: int = 0):
super(AntiAliasDownsampleLayer, self).__init__()
if not remove_aa_jit:
self.op = DownsampleJIT(filt_size, stride, channels)
else:
self.op = Downsample(filt_size, stride, channels)

def forward(self, x):
return self.op(x)


@torch.jit.script
class DownsampleJIT(object):
def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0):
self.stride = stride
self.filt_size = filt_size
self.channels = channels

        assert self.filt_size == 3
        assert stride == 2
        # 3x3 binomial (blur) filter: outer product of [1, 2, 1], normalized to sum to 1
        a = torch.tensor([1., 2., 1.])
        filt = (a[:, None] * a[None, :]).clone().detach()
        filt = filt / torch.sum(filt)
        # one copy of the filter per channel for a depthwise (grouped) conv;
        # created directly on the GPU in half precision
        self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half()

def __call__(self, input: torch.Tensor):
        # the filter starts out in half precision; cast it to match float inputs
        if input.dtype != self.filt.dtype:
            self.filt = self.filt.float()
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
return F.conv2d(input_pad, self.filt, stride=2, padding=0, groups=input.shape[1])


class Downsample(nn.Module):
def __init__(self, filt_size=3, stride=2, channels=None):
super(Downsample, self).__init__()
self.filt_size = filt_size
self.stride = stride
        self.channels = channels

        assert self.filt_size == 3
        # same 3x3 binomial (blur) filter as the JIT variant, kept in full precision
        a = torch.tensor([1., 2., 1.])
        filt = (a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)

        # registered as a buffer so it follows .to()/.cuda() and lands in the state_dict
        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

def forward(self, input):
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
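
AntiAliasDownsampleLayer implements blur-then-stride downsampling: the fixed 3x3 binomial filter is applied depthwise (groups equal to the channel count) before the stride-2 step, which suppresses the aliasing that plain strided pooling or convolution introduces. A minimal sketch of the non-JIT path, which runs on CPU in full precision (the JIT variant above builds its filter with .cuda().half()):

import torch
from timm.models.layers import AntiAliasDownsampleLayer

aa = AntiAliasDownsampleLayer(remove_aa_jit=True, filt_size=3, stride=2, channels=64)
x = torch.randn(2, 64, 56, 56)
y = aa(x)       # reflect-pad, blur depthwise, then stride 2
print(y.shape)  # torch.Size([2, 64, 28, 28])
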
53 changes: 53 additions & 0 deletions timm/models/layers/space_to_depth.py
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn


class SpaceToDepth(nn.Module):
def __init__(self, block_size=4):
super().__init__()
assert block_size == 4
self.bs = block_size

def forward(self, x):
N, C, H, W = x.size()
x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
return x


@torch.jit.script
class SpaceToDepthJit(object):
def __call__(self, x: torch.Tensor):
        # block_size is hard-coded to 4 so the whole op can be TorchScript-compiled for speed
N, C, H, W = x.size()
x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs)
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
return x


class SpaceToDepthModule(nn.Module):
def __init__(self, remove_model_jit=False):
super().__init__()
if not remove_model_jit:
self.op = SpaceToDepthJit()
else:
self.op = SpaceToDepth()

def forward(self, x):
return self.op(x)


class DepthToSpace(nn.Module):
    def __init__(self, block_size):
super().__init__()
self.bs = block_size

def forward(self, x):
N, C, H, W = x.size()
x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
return x
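
SpaceToDepth moves each 4x4 spatial block into the channel dimension, turning an (N, C, H, W) tensor into (N, 16*C, H/4, W/4); DepthToSpace applies the inverse permutation. A small round-trip sketch using the plain (non-JIT) modules defined above:

import torch
from timm.models.layers.space_to_depth import SpaceToDepth, DepthToSpace

x = torch.randn(2, 3, 224, 224)
s2d = SpaceToDepth(block_size=4)
d2s = DepthToSpace(block_size=4)

y = s2d(x)
print(y.shape)  # torch.Size([2, 48, 56, 56])

# the inverse permutation recovers the original tensor exactly
assert torch.equal(d2s(y), x)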
