Commit 2cee903: syntax modifications to better match the PyTorch one
wavefrontshaping committed May 19, 2019
1 parent b439d57 commit 2cee903
Showing 2 changed files with 34 additions and 20 deletions.
6 changes: 3 additions & 3 deletions complexFunctions.py
@@ -17,9 +17,9 @@ def complex_relu(input_r,input_i):
    # assert(input_r.size() == input_i.size())
    return relu(input_r), relu(input_i)

-def complex_max_pool(input_r,input_i,kernel_size, stride, padding,
-                     dilation, ceil_mode, return_indices):
+def complex_max_pool2d(input_r,input_i,kernel_size, stride=None, padding=0,
+                       dilation=1, ceil_mode=False, return_indices=False):
    return max_pool2d(input_r, kernel_size, stride, padding, dilation,
                      ceil_mode, return_indices), \
           max_pool2d(input_i, kernel_size, stride, padding, dilation,
-                     ceil_mode, return_indices)
+                     ceil_mode, return_indices)
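
For reference, a minimal smoke test of the renamed function; the tensor shapes are illustrative, not part of the commit. Pooling is applied to the real and imaginary parts independently, so both outputs shrink the same way:

import torch
from complexFunctions import complex_max_pool2d

xr = torch.randn(1, 3, 8, 8)  # real part
xi = torch.randn(1, 3, 8, 8)  # imaginary part
# stride=None falls through to max_pool2d, which defaults it to kernel_size
yr, yi = complex_max_pool2d(xr, xi, kernel_size=2)
print(yr.shape, yi.shape)  # both torch.Size([1, 3, 4, 4])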
48 changes: 31 additions & 17 deletions complexLayers.py
@@ -12,19 +12,19 @@
import torch
from torch.nn import Module, Parameter, init, Sequential
from torch.nn import Conv2d, Linear, BatchNorm1d, BatchNorm2d
-from complexFunctions import complex_relu, complex_max_pool
+from complexFunctions import complex_relu, complex_max_pool2d

class ComplexSequential(Sequential):
    def forward(self, input_r, input_t):
        for module in self._modules.values():
            input_r, input_t = module(input_r, input_t)
        return input_r, input_t

-class ComplexMaxPool(Module):
+class ComplexMaxPool2d(Module):

    def __init__(self,kernel_size, stride= None, padding = 0,
                 dilation = 1, return_indices = False, ceil_mode = False):
-        super(ComplexMaxPool,self).__init__()
+        super(ComplexMaxPool2d,self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
@@ -33,7 +33,7 @@ def __init__(self,kernel_size, stride= None, padding = 0,
        self.return_indices = return_indices

    def forward(self,input_r,input_i):
-        return complex_max_pool(input_r,input_i,kernel_size = self.kernel_size,
+        return complex_max_pool2d(input_r,input_i,kernel_size = self.kernel_size,
                                  stride = self.stride, padding = self.padding,
                                  dilation = self.dilation, ceil_mode = self.ceil_mode,
                                  return_indices = self.return_indices)
@@ -44,50 +44,64 @@ def forward(self,input_r,input_i):
        return complex_relu(input_r,input_i)


-class ComplexConv2D(Module):
+class ComplexConv2d(Module):

-    def __init__(self,in_channels, out_channels, kernel_size=3, stride=1, padding = 1):
-        super(ComplexConv2D, self).__init__()
-        self.conv_r = Conv2d(in_channels, out_channels, kernel_size, stride, padding)
-        self.conv_i = Conv2d(in_channels, out_channels, kernel_size, stride, padding)
+    def __init__(self,in_channels, out_channels, kernel_size=3, stride=1, padding = 0,
+                 dilation=1, groups=1, bias=True):
+        super(ComplexConv2d, self).__init__()
+        self.conv_r = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
+        self.conv_i = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)

    def forward(self,input_r, input_i):
        # assert(input_r.size() == input_i.size())
        return self.conv_r(input_r)-self.conv_i(input_i), \
               self.conv_r(input_i)+self.conv_i(input_r)


-class ComplexFC(Module):
+class ComplexLinear(Module):

    def __init__(self, in_features, out_features):
-        super(ComplexFC, self).__init__()
+        super(ComplexLinear, self).__init__()
        self.fc_r = Linear(in_features, out_features)
        self.fc_i = Linear(in_features, out_features)

    def forward(self,input_r, input_i):
        return self.fc_r(input_r)-self.fc_i(input_i), \
               self.fc_r(input_i)+self.fc_i(input_r)
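
Both ComplexConv2d and ComplexLinear implement the complex product (a + ib)(c + id) = (ac - bd) + i(ad + bc) with two real-valued modules. A minimal numerical cross-check of that identity against ComplexLinear; shapes and tolerance are illustrative:

import torch
from complexLayers import ComplexLinear

fc = ComplexLinear(4, 2)
xr, xi = torch.randn(5, 4), torch.randn(5, 4)
yr, yi = fc(xr, xi)

# Explicit complex affine map. Note the two Linear biases combine as
# (b_r - b_i) on the real part and (b_r + b_i) on the imaginary part.
Wr, Wi = fc.fc_r.weight, fc.fc_i.weight
br, bi = fc.fc_r.bias, fc.fc_i.bias
ref_r = (xr @ Wr.t() + br) - (xi @ Wi.t() + bi)
ref_i = (xi @ Wr.t() + br) + (xr @ Wi.t() + bi)
assert torch.allclose(yr, ref_r, atol=1e-6)
assert torch.allclose(yi, ref_i, atol=1e-6)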

+class NaiveComplexBatchNorm1d(Module):
+    '''
+    Naive approach to complex batch norm: performs batch norm independently on the real and imaginary parts.
+    '''
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+                 track_running_stats=True):
+        super(NaiveComplexBatchNorm1d, self).__init__()
+        self.bn_r = BatchNorm1d(num_features, eps, momentum, affine, track_running_stats)
+        self.bn_i = BatchNorm1d(num_features, eps, momentum, affine, track_running_stats)
+
+    def forward(self,input_r, input_i):
+        return self.bn_r(input_r), self.bn_i(input_i)

-class NaiveComplexBatchNorm2D(Module):
+class NaiveComplexBatchNorm2d(Module):
    '''
    Naive approach to complex batch norm: performs batch norm independently on the real and imaginary parts.
    '''
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
-        super(NaiveComplexBatchNorm2D, self).__init__()
+        super(NaiveComplexBatchNorm2d, self).__init__()
        self.bn_r = BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
        self.bn_i = BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)

    def forward(self,input_r, input_i):
        return self.bn_r(input_r), self.bn_i(input_i)
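
As the docstring says, the naive variant normalizes each part with its own BatchNorm and its own running statistics, ignoring any correlation between the parts; judging by the three weight components per feature initialized in reset_parameters below, the full ComplexBatchNorm classes appear to treat the two parts jointly. A short usage sketch with illustrative sizes:

import torch
from complexLayers import NaiveComplexBatchNorm2d

bn = NaiveComplexBatchNorm2d(8)
xr, xi = torch.randn(4, 8, 16, 16), torch.randn(4, 8, 16, 16)
yr, yi = bn(xr, xi)
print(yr.mean(dim=(0, 2, 3)))  # per-channel means of the real part, ~0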

-class NaiveComplexBatchNorm1D(Module):
+class NaiveComplexBatchNorm1d(Module):
    '''
    Naive approach to complex batch norm: performs batch norm independently on the real and imaginary parts.
    '''
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
-        super(NaiveComplexBatchNorm1D, self).__init__()
+        super(NaiveComplexBatchNorm1d, self).__init__()
        self.bn_r = BatchNorm1d(num_features, eps, momentum, affine, track_running_stats)
        self.bn_i = BatchNorm1d(num_features, eps, momentum, affine, track_running_stats)

@@ -137,7 +151,7 @@ def reset_parameters(self):
        init.zeros_(self.weight[:,2])
        init.zeros_(self.bias)

-class ComplexBatchNorm2D(_ComplexBatchNorm):
+class ComplexBatchNorm2d(_ComplexBatchNorm):

    def forward(self, input_r, input_i):
        assert(input_r.size() == input_i.size())
@@ -215,7 +229,7 @@ def forward(self, input_r, input_i):
        return input_r, input_i


-class ComplexBatchNorm1D(_ComplexBatchNorm):
+class ComplexBatchNorm1d(_ComplexBatchNorm):

    def forward(self, input_r, input_i):
        assert(input_r.size() == input_i.size())
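Putting the renamed pieces together, a hypothetical end-to-end pass through ComplexSequential; layer sizes are illustrative and this snippet is not part of the commit:

import torch
from complexLayers import (ComplexSequential, ComplexConv2d,
                           NaiveComplexBatchNorm2d, ComplexMaxPool2d)

net = ComplexSequential(
    ComplexConv2d(1, 8, kernel_size=3, padding=1),
    NaiveComplexBatchNorm2d(8),
    ComplexMaxPool2d(kernel_size=2),
)
xr = torch.randn(2, 1, 16, 16)  # real part
xi = torch.randn(2, 1, 16, 16)  # imaginary part
yr, yi = net(xr, xi)
print(yr.shape, yi.shape)  # both torch.Size([2, 8, 8, 8])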
