Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #122 from mrT23/master
TResNet models
- Loading branch information
Showing
5 changed files
with
409 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import torch | ||
import torch.nn.parallel | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class AntiAliasDownsampleLayer(nn.Module): | ||
def __init__(self, remove_aa_jit: bool = False, filt_size: int = 3, stride: int = 2, | ||
channels: int = 0): | ||
super(AntiAliasDownsampleLayer, self).__init__() | ||
if not remove_aa_jit: | ||
self.op = DownsampleJIT(filt_size, stride, channels) | ||
else: | ||
self.op = Downsample(filt_size, stride, channels) | ||
|
||
def forward(self, x): | ||
return self.op(x) | ||
|
||
|
||
@torch.jit.script | ||
class DownsampleJIT(object): | ||
def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0): | ||
self.stride = stride | ||
self.filt_size = filt_size | ||
self.channels = channels | ||
|
||
assert self.filt_size == 3 | ||
assert stride == 2 | ||
a = torch.tensor([1., 2., 1.]) | ||
|
||
filt = (a[:, None] * a[None, :]).clone().detach() | ||
filt = filt / torch.sum(filt) | ||
self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half() | ||
|
||
def __call__(self, input: torch.Tensor): | ||
if input.dtype != self.filt.dtype: | ||
self.filt = self.filt.float() | ||
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') | ||
return F.conv2d(input_pad, self.filt, stride=2, padding=0, groups=input.shape[1]) | ||
|
||
|
||
class Downsample(nn.Module): | ||
def __init__(self, filt_size=3, stride=2, channels=None): | ||
super(Downsample, self).__init__() | ||
self.filt_size = filt_size | ||
self.stride = stride | ||
self.channels = channels | ||
|
||
|
||
assert self.filt_size == 3 | ||
a = torch.tensor([1., 2., 1.]) | ||
|
||
filt = (a[:, None] * a[None, :]) | ||
filt = filt / torch.sum(filt) | ||
|
||
# self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) | ||
self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) | ||
|
||
def forward(self, input): | ||
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') | ||
return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class SpaceToDepth(nn.Module): | ||
def __init__(self, block_size=4): | ||
super().__init__() | ||
assert block_size == 4 | ||
self.bs = block_size | ||
|
||
def forward(self, x): | ||
N, C, H, W = x.size() | ||
x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) | ||
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) | ||
x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) | ||
return x | ||
|
||
|
||
@torch.jit.script | ||
class SpaceToDepthJit(object): | ||
def __call__(self, x: torch.Tensor): | ||
# assuming hard-coded that block_size==4 for acceleration | ||
N, C, H, W = x.size() | ||
x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) | ||
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) | ||
x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) | ||
return x | ||
|
||
|
||
class SpaceToDepthModule(nn.Module): | ||
def __init__(self, remove_model_jit=False): | ||
super().__init__() | ||
if not remove_model_jit: | ||
self.op = SpaceToDepthJit() | ||
else: | ||
self.op = SpaceToDepth() | ||
|
||
def forward(self, x): | ||
return self.op(x) | ||
|
||
|
||
class DepthToSpace(nn.Module): | ||
|
||
def __init__(self, block_size): | ||
super().__init__() | ||
self.bs = block_size | ||
|
||
def forward(self, x): | ||
N, C, H, W = x.size() | ||
x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) | ||
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) | ||
x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) | ||
return x |
Oops, something went wrong.