diff --git a/README.md b/README.md
index 218897d..9c79b20 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ Following [[C. Trabelsi et al., International Conference on Learning Representat
 
 ## Syntax and usage
 The syntax is meant to mirror that of the standard real-valued functions and modules from PyTorch.
-The names are the same as in `nn.modules` and `nn.functional` except that they start with `Complex`, e.g. `ComplexRelu`, `ComplexMaxPool2D`...
+The names are the same as in `nn.modules` and `nn.functional`, except that they start with `Complex` for modules, e.g. `ComplexConv2d`, `ComplexBatchNorm2d`, or with `complex_` for functions, e.g. `complex_relu`, `complex_max_pool2d`.
 The only usage difference is that the forward function takes two tensors, corresponding to the real and imaginary parts, and returns two tensors as well.
 
 ## BatchNorm
@@ -43,7 +43,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
-from complexLayers import ComplexBatchNorm2D
+from complexLayers import ComplexBatchNorm2d, ComplexConv2d, ComplexLinear
+from complexFunctions import complex_relu, complex_max_pool2d
 
 batch_size = 64
 trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
@@ -57,21 +58,32 @@ class ComplexNet(nn.Module):
 
     def __init__(self):
         super(ComplexNet, self).__init__()
-        self.conv1 = nn.Conv2d(1, 20, 5, 1)
-        self.bn = ComplexBatchNorm2D(20)
-        self.conv2 = nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = nn.Linear(4*4*50, 500)
-        self.fc2 = nn.Linear(500, 10)
+        self.conv1 = ComplexConv2d(1, 20, 5, 1)
+        self.bn = ComplexBatchNorm2d(20)
+        self.conv2 = ComplexConv2d(20, 50, 5, 1)
+        self.fc1 = ComplexLinear(4*4*50, 500)
+        self.fc2 = ComplexLinear(500, 10)
 
     def forward(self,x):
-        x = F.relu(self.conv1(x))
-        x = F.max_pool2d(x, 2, 2)
-        x,_ = self.bn(x,x)
-        x = F.relu(self.conv2(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4*4*50)
-        x = F.relu(self.fc1(x))
-        x = self.fc2(x)
+        xr = x
+        # initialize the imaginary part to zero
+        xi = torch.zeros(xr.shape, dtype=xr.dtype, device=xr.device)
+        xr, xi = self.conv1(xr, xi)
+        xr, xi = complex_relu(xr, xi)
+        xr, xi = complex_max_pool2d(xr, xi, 2, 2)
+
+        xr, xi = self.bn(xr, xi)
+        xr, xi = self.conv2(xr, xi)
+        xr, xi = complex_relu(xr, xi)
+        xr, xi = complex_max_pool2d(xr, xi, 2, 2)
+
+        xr = xr.view(-1, 4*4*50)
+        xi = xi.view(-1, 4*4*50)
+        xr, xi = self.fc1(xr, xi)
+        xr, xi = complex_relu(xr, xi)
+        xr, xi = self.fc2(xr, xi)
+        # take the magnitude of the complex output as the class score
+        x = torch.sqrt(torch.pow(xr, 2) + torch.pow(xi, 2))
         return F.log_softmax(x, dim=1)
 
 device = torch.device("cuda:3" )
@@ -87,6 +99,18 @@ def train(model, device, train_loader, optimizer, epoch):
         loss = F.nll_loss(output, target)
         loss.backward()
         optimizer.step()
+        if batch_idx % 1000 == 0:
+            print('Train Epoch: {:3} [{:6}/{:6} ({:3.0f}%)]\tLoss: {:.6f}'.format(
+                epoch,
+                batch_idx * len(data),
+                len(train_loader.dataset),
+                100. * batch_idx / len(train_loader),
+                loss.item())
+            )
+
+# Run training for 50 epochs
+for epoch in range(50):
+    train(model, device, train_loader, optimizer, epoch)
 ```
 
 ## Todo
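
As a quick illustration of the two-tensor convention described in the README text above, here is a minimal sketch. It assumes `complexLayers` and `complexFunctions` from this repository are importable, that `ComplexConv2d` takes the same constructor arguments as `nn.Conv2d` (as the example in the diff suggests), and that every complex layer or function accepts and returns a (real, imaginary) pair; the batch and image sizes are illustrative only.

```python
import torch
from complexLayers import ComplexConv2d
from complexFunctions import complex_relu

# A random batch of single-channel 28x28 images; the imaginary part starts at zero.
xr = torch.randn(8, 1, 28, 28)
xi = torch.zeros_like(xr)

conv = ComplexConv2d(1, 20, 5, 1)

# Each complex layer/function takes (real, imaginary) tensors and returns two tensors.
xr, xi = conv(xr, xi)
xr, xi = complex_relu(xr, xi)

# Assuming the layer behaves like its real counterpart, both parts come out
# with shape [8, 20, 24, 24] (28 - 5 + 1 = 24).
print(xr.shape, xi.shape)
```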