From 2cee9039a87cc72408824e925a4f557747db8b2a Mon Sep 17 00:00:00 2001
From: "sebastien.popoff@espci.fr" <sebastien.popoff@espci.fr>
Date: Sun, 19 May 2019 13:29:43 +0200
Subject: [PATCH] syntax modifications to better match the PyTorch one

---
 complexFunctions.py |  6 +++---
 complexLayers.py    | 48 +++++++++++++++++++++++++++++----------------
 2 files changed, 34 insertions(+), 20 deletions(-)

diff --git a/complexFunctions.py b/complexFunctions.py
index 394e974..2392ce2 100644
--- a/complexFunctions.py
+++ b/complexFunctions.py
@@ -17,9 +17,9 @@ def complex_relu(input_r,input_i):
 #    assert(input_r.size() == input_i.size())
     return relu(input_r), relu(input_i)
 
-def complex_max_pool(input_r,input_i,kernel_size, stride, padding,
-                     dilation, ceil_mode, return_indices):
+def complex_max_pool2d(input_r,input_i,kernel_size, stride=None, padding=0,
+                       dilation=1, ceil_mode=False, return_indices=False):
     return max_pool2d(input_r, kernel_size, stride, padding, dilation,
                       ceil_mode, return_indices), \
            max_pool2d(input_i, kernel_size, stride, padding, dilation,
-                      ceil_mode, return_indices)
\ No newline at end of file
+                      ceil_mode, return_indices)
diff --git a/complexLayers.py b/complexLayers.py
index 73885a4..61c3877 100644
--- a/complexLayers.py
+++ b/complexLayers.py
@@ -12,7 +12,7 @@ import torch
 from torch.nn import Module, Parameter, init, Sequential
 from torch.nn import Conv2d, Linear, BatchNorm1d, BatchNorm2d
 
-from complexFunctions import complex_relu, complex_max_pool
+from complexFunctions import complex_relu, complex_max_pool2d
 
 class ComplexSequential(Sequential):
     def forward(self, input_r, input_t):
@@ -20,11 +20,11 @@ def forward(self, input_r, input_t):
             input_r, input_t = module(input_r, input_t)
         return input_r, input_t
 
-class ComplexMaxPool(Module):
+class ComplexMaxPool2d(Module):
 
     def __init__(self,kernel_size, stride= None, padding = 0,
                  dilation = 1, return_indices = False, ceil_mode = False):
-        super(ComplexMaxPool,self).__init__()
+        super(ComplexMaxPool2d,self).__init__()
         self.kernel_size = kernel_size
         self.stride = stride
         self.padding = padding
@@ -33,7 +33,7 @@ def __init__(self,kernel_size, stride= None, padding = 0,
         self.return_indices = return_indices
 
     def forward(self,input_r,input_i):
-        return complex_max_pool(input_r,input_i,kernel_size = self.kernel_size,
+        return complex_max_pool2d(input_r,input_i,kernel_size = self.kernel_size,
                                 stride = self.stride, padding = self.padding,
                                 dilation = self.dilation, ceil_mode = self.ceil_mode,
                                 return_indices = self.return_indices)
@@ -44,12 +44,13 @@ def forward(self,input_r,input_i):
         return complex_relu(input_r,input_i)
 
 
-class ComplexConv2D(Module):
+class ComplexConv2d(Module):
 
-    def __init__(self,in_channels, out_channels, kernel_size=3, stride=1, padding = 1):
-        super(ComplexConv2D, self).__init__()
-        self.conv_r = Conv2d(in_channels, out_channels, kernel_size, stride, padding)
-        self.conv_i = Conv2d(in_channels, out_channels, kernel_size, stride, padding)
+    def __init__(self,in_channels, out_channels, kernel_size=3, stride=1, padding = 0,
+                 dilation=1, groups=1, bias=True):
+        super(ComplexConv2d, self).__init__()
+        self.conv_r = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
+        self.conv_i = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
 
     def forward(self,input_r, input_i):
 #        assert(input_r.size() == input_i.size())
@@ -57,37 +58,50 @@ def forward(self,input_r, input_i):
                self.conv_r(input_i)+self.conv_i(input_r)
 
 
-class ComplexFC(Module):
+class ComplexLinear(Module):
 
     def __init__(self, in_features, out_features):
-        super(ComplexFC, self).__init__()
+        super(ComplexLinear, self).__init__()
         self.fc_r = Linear(in_features, out_features)
         self.fc_i = Linear(in_features, out_features)
 
     def forward(self,input_r, input_i):
         return self.fc_r(input_r)-self.fc_i(input_i), \
                self.fc_r(input_i)+self.fc_i(input_r)
+
+class NaiveComplexBatchNorm1d(Module):
+    '''
+    Naive approach to complex batch norm, perform batch norm independently on real and imaginary part.
+    '''
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, \
+                 track_running_stats=True):
+        super(NaiveComplexBatchNorm1d, self).__init__()
+        self.bn_r = BatchNorm1d(num_features, eps, momentum, affine, track_running_stats)
+        self.bn_i = BatchNorm1d(num_features, eps, momentum, affine, track_running_stats)
+
+    def forward(self,input_r, input_i):
+        return self.bn_r(input_r), self.bn_i(input_i)
 
-class NaiveComplexBatchNorm2D(Module):
+class NaiveComplexBatchNorm2d(Module):
     '''
     Naive approach to complex batch norm, perform batch norm independently on real and imaginary part.
     '''
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, \
                  track_running_stats=True):
-        super(NaiveComplexBatchNorm2D, self).__init__()
+        super(NaiveComplexBatchNorm2d, self).__init__()
         self.bn_r = BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
         self.bn_i = BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
 
     def forward(self,input_r, input_i):
         return self.bn_r(input_r), self.bn_i(input_i)
 
-class NaiveComplexBatchNorm1D(Module):
+class NaiveComplexBatchNorm1d(Module):
     '''
     Naive approach to complex batch norm, perform batch norm independently on real and imaginary part.
     '''
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, \
                  track_running_stats=True):
-        super(NaiveComplexBatchNorm1D, self).__init__()
+        super(NaiveComplexBatchNorm1d, self).__init__()
         self.bn_r = BatchNorm1d(num_features, eps, momentum, affine, track_running_stats)
         self.bn_i = BatchNorm1d(num_features, eps, momentum, affine, track_running_stats)
 
@@ -137,7 +151,7 @@ def reset_parameters(self):
         init.zeros_(self.weight[:,2])
         init.zeros_(self.bias)
 
-class ComplexBatchNorm2D(_ComplexBatchNorm):
+class ComplexBatchNorm2d(_ComplexBatchNorm):
 
     def forward(self, input_r, input_i):
         assert(input_r.size() == input_i.size())
@@ -215,7 +229,7 @@ def forward(self, input_r, input_i):
 
         return input_r, input_i
 
-class ComplexBatchNorm1D(_ComplexBatchNorm):
+class ComplexBatchNorm1d(_ComplexBatchNorm):
 
     def forward(self, input_r, input_i):
         assert(input_r.size() == input_i.size())
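
After this commit the modules follow the torch.nn naming scheme (Conv2d, Linear, MaxPool2d, BatchNorm1d/2d), and every layer consumes and returns a (real, imaginary) pair of tensors. Note that complexLayers.py now defines NaiveComplexBatchNorm1d twice: once as the newly added class and once through the 1D to 1d rename; the two bodies are identical, so the second definition silently replaces the first. Below is a minimal usage sketch of the renamed API; the layer sizes and the 28x28 inputs are illustrative assumptions, not code from the repository.

import torch
from complexLayers import ComplexConv2d, ComplexLinear, ComplexMaxPool2d
from complexFunctions import complex_relu

class ComplexNet(torch.nn.Module):
    def __init__(self):
        super(ComplexNet, self).__init__()
        # padding=1 keeps the 28x28 spatial size; the 2x2 pool halves it to 14x14
        self.conv = ComplexConv2d(1, 8, kernel_size=3, stride=1, padding=1)
        self.pool = ComplexMaxPool2d(kernel_size=2)
        self.fc = ComplexLinear(8 * 14 * 14, 10)

    def forward(self, x_r, x_i):
        # each layer takes and returns the (real, imaginary) pair
        x_r, x_i = self.conv(x_r, x_i)
        x_r, x_i = complex_relu(x_r, x_i)
        x_r, x_i = self.pool(x_r, x_i)
        x_r = x_r.view(x_r.size(0), -1)
        x_i = x_i.view(x_i.size(0), -1)
        return self.fc(x_r, x_i)

x_r = torch.randn(4, 1, 28, 28)   # real part of a dummy batch
x_i = torch.randn(4, 1, 28, 28)   # imaginary part
out_r, out_i = ComplexNet()(x_r, x_i)
print(out_r.shape, out_i.shape)   # torch.Size([4, 10]) torch.Size([4, 10])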