complex example
wavefrontshaping committed May 19, 2019
1 parent f0099fc commit b439d57
Showing 1 changed file, README.md, with 39 additions and 15 deletions.
@@ -21,7 +21,7 @@ Following [[C. Trabelsi et al., International Conference on Learning Representations
## Syntax and usage

The syntax is meant to mirror that of the standard real-valued functions and modules from PyTorch.
-The names are the same as in `nn.modules` and `nn.functional` except that they start with `Complex`, e.g. `ComplexRelu`, `ComplexMaxPool2D`...
+The names are the same as in `nn.modules` and `nn.functional`, except that modules start with `Complex`, e.g. `ComplexRelu`, `ComplexMaxPool2D`, and functions start with `complex_`, e.g. `complex_relu`, `complex_max_pool2d`.
The only usage difference is that the forward function takes two tensors, corresponding to the real and imaginary parts, and likewise returns two tensors.
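For instance, a complex layer maps a pair of real tensors to a pair of real tensors. A minimal sketch (not part of this commit), using the `ComplexLinear` layer imported in the example below:

```python
import torch
from complexLayers import ComplexLinear

fc = ComplexLinear(128, 64)   # in_features, out_features, as in nn.Linear
xr = torch.randn(32, 128)     # real part of a batch of inputs
xi = torch.randn(32, 128)     # imaginary part
yr, yi = fc(xr, xi)           # forward takes and returns two tensors
```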

## BatchNorm
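As a minimal usage sketch (not part of this commit), `ComplexBatchNorm2d` follows the same two-tensor convention; the shapes below match its use in the network further down:

```python
import torch
from complexLayers import ComplexBatchNorm2d

bn = ComplexBatchNorm2d(20)        # 20 complex feature maps
xr = torch.randn(8, 20, 12, 12)    # real part
xi = torch.randn(8, 20, 12, 12)    # imaginary part
xr, xi = bn(xr, xi)                # normalized real and imaginary parts
```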
@@ -43,7 +43,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
-from complexLayers import ComplexBatchNorm2D
+from complexLayers import ComplexBatchNorm2d, ComplexConv2d, ComplexLinear
+from complexFunctions import complex_relu, complex_max_pool2d

batch_size = 64
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
@@ -57,21 +58,32 @@ class ComplexNet(nn.Module):

    def __init__(self):
        super(ComplexNet, self).__init__()
-        self.conv1 = nn.Conv2d(1, 20, 5, 1)
-        self.bn = ComplexBatchNorm2D(20)
-        self.conv2 = nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = nn.Linear(4*4*50, 500)
-        self.fc2 = nn.Linear(500, 10)
+        self.conv1 = ComplexConv2d(1, 20, 5, 1)
+        self.bn = ComplexBatchNorm2d(20)
+        self.conv2 = ComplexConv2d(20, 50, 5, 1)
+        self.fc1 = ComplexLinear(4*4*50, 500)
+        self.fc2 = ComplexLinear(500, 10)

    def forward(self, x):
-        x = F.relu(self.conv1(x))
-        x = F.max_pool2d(x, 2, 2)
-        x, _ = self.bn(x, x)
-        x = F.relu(self.conv2(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4*4*50)
-        x = F.relu(self.fc1(x))
-        x = self.fc2(x)
+        xr = x
+        # set the imaginary part to zero
+        xi = torch.zeros(xr.shape, dtype=xr.dtype, device=xr.device)
+        xr, xi = self.conv1(xr, xi)
+        xr, xi = complex_relu(xr, xi)
+        xr, xi = complex_max_pool2d(xr, xi, 2, 2)
+
+        xr, xi = self.bn(xr, xi)
+        xr, xi = self.conv2(xr, xi)
+        xr, xi = complex_relu(xr, xi)
+        xr, xi = complex_max_pool2d(xr, xi, 2, 2)
+
+        xr = xr.view(-1, 4*4*50)
+        xi = xi.view(-1, 4*4*50)
+        xr, xi = self.fc1(xr, xi)
+        xr, xi = complex_relu(xr, xi)
+        xr, xi = self.fc2(xr, xi)
+        # take the magnitude of the complex output
+        x = torch.sqrt(torch.pow(xr, 2) + torch.pow(xi, 2))
+        return F.log_softmax(x, dim=1)

device = torch.device("cuda:3")
@@ -87,6 +99,18 @@ def train(model, device, train_loader, optimizer, epoch):
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
+        if batch_idx % 1000 == 0:
+            print('Train Epoch: {:3} [{:6}/{:6} ({:3.0f}%)]\tLoss: {:.6f}'.format(
+                epoch,
+                batch_idx * len(data),
+                len(train_loader.dataset),
+                100. * batch_idx / len(train_loader),
+                loss.item())
+            )

+# Run training for 50 epochs
+for epoch in range(50):
+    train(model, device, train_loader, optimizer, epoch)
```
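As a quick sanity check (not part of the commit), the model can be run on a dummy MNIST-shaped batch; the conv/pool arithmetic (28 → 24 → 12 → 8 → 4) matches the `4*4*50` flattening in `forward`:

```python
model = ComplexNet().to(device)
dummy = torch.randn(4, 1, 28, 28).to(device)  # batch of 4 fake MNIST images
out = model(dummy)
print(out.shape)  # torch.Size([4, 10]): log-probabilities over the 10 digits
```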
## Todo
