In [1]:
import numpy as np
np.set_printoptions(precision=1)
import time
# import tensorflow as tf
import matplotlib.pylab as plt

from modules.utils import load_cifar10
# from modules.cnn_with_spectral_parameterization import CNN_Spectral_Param
# from modules.cnn_with_spectral_pooling import CNN_Spectral_Pool
from modules.image_generator import ImageGenerator

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
import pytorch_fft.fft.autograd as fft

% matplotlib inline
% load_ext autoreload
% autoreload 2

In [2]:
# Load CIFAR-10. The call below requests 5 training batches (the log underneath
# prints "getting batch 1" .. "getting batch 5"), so the training set is NOT
# subsetted here.
# NOTE(review): this comment previously said that, in the interest of training time,
# only 1 of 5 cifar10 batches was used. The rationale given was that the important
# part of the experiment is comparing the rates of convergence of training accuracy,
# so subsetting the training data for both spectral and spatial models shouldn't
# change the relationship between their train-accuracy convergences. That statement
# no longer matches the argument passed below -- confirm which setting is intended.
# channels_last=False: presumably returns channel-first arrays -- TODO confirm
# against modules.utils.load_cifar10.
xtrain, ytrain, xtest, ytest = load_cifar10(5, channels_last=False)

file already downloaded..
getting batch 1
getting batch 2
getting batch 3
getting batch 4
getting batch 5


In [3]:
# Sanity check: display the shapes of the train/test images and labels.
xtrain.shape, ytrain.shape, xtest.shape, ytest.shape

((50000, 3, 32, 32), (50000,), (10000, 3, 32, 32), (10000,))

## 1. Spectral Pooling Layer

In [4]:
def _forward_spectral_pool(images, filter_size):
    assert (torch.ge(filter_size, 3)).all()
    assert images.size()[-1] == images.size()[-2] and images.size()[-1] >= 3
    
    if int(filter_size) % 2 == 1:
        n = int((filter_size - 1)/2)
        top_left = images[:, :, :n+1, :n+1]
        top_right = images[:, :, :n+1, -n:]
        bottom_left = images[:, :, -n:, :n+1]
        bottom_right = images[:, :, -n:, -n:]
        top_combined = torch.cat([top_left, top_right], dim=-1)
        bottom_combined = torch.cat([bottom_left, bottom_right], dim=-1)
        all_together = torch.cat([top_combined, bottom_combined], dim=-2)
    
    else:
        n = int(filter_size / 2)
        top_left = images[:, :, :n, :n]
        top_middle = torch.unsqueeze(0.5**0.5 * (images[:, :, :n, n] + images[:, :, :n, -n]), -1)
        top_right = images[:, :, :n, -(n-1):]
        middle_left = torch.unsqueeze(0.5**0.5 * (images[:, :, n, :n] + images[:, :, -n, :n]), -2)
        middle_middle = torch.unsqueeze(torch.unsqueeze(0.5 * 
                                    (images[:, :, n, n] + images[:, :, n, -n] + images[:, :, -n, n] + images[:, :, -n, -n]), 
                                    -1), -1)
        middle_right = torch.unsqueeze(0.5**0.5 * (images[:, :, n, -(n-1):] + images[:, :, -n, -(n-1):]), -2)
        bottom_left = images[:, :, -(n-1):, :n]
        bottom_middle = torch.unsqueeze(0.5 ** 0.5 * (images[:, :, -(n-1):, n] + images[:, :, -(n-1):, -n]), -1)
        bottom_right = images[:, :, -(n-1):, -(n-1):]
        top_combined = torch.cat([top_left, top_middle, top_right], dim=-1)
        middle_combined = torch.cat([middle_left, middle_middle, middle_right], dim=-1)
        bottom_combined = torch.cat([bottom_left, bottom_middle, bottom_right], dim=-1)
        all_together = torch.cat([top_combined, middle_combined, bottom_combined], dim=-2)
        
    return all_together
    

class SpectralPool(Module):
    """Pooling layer that down-samples by cropping in the frequency domain.

    The input is transformed with a 2-D FFT, the real and imaginary parts are
    each cropped to ``filter_size`` x ``filter_size`` by
    ``_forward_spectral_pool``, and the result is transformed back with the
    inverse FFT. Only the real part of the inverse transform is returned.

    Args:
        filter_size: side length of the pooled output (must satisfy the
            constraints checked by ``_forward_spectral_pool``).
    """

    def __init__(self, filter_size):
        super(SpectralPool, self).__init__()
        # Stored as a one-element IntTensor, the form handed to
        # _forward_spectral_pool.
        self.filter_size = torch.IntTensor(1).fill_(filter_size)
        self.fft = fft.Fft2d()
        self.ifft = fft.Ifft2d()

    def forward(self, input):
        """Return the spectrally pooled version of ``input`` (real part only)."""
        # zeros_like already allocates on the same device as `input`; an extra
        # .cuda() would move the zeros to the *current* CUDA device and could
        # mismatch an input that lives on a different GPU.
        in_re, in_im = self.fft(input, torch.zeros_like(input))
        trans_re = _forward_spectral_pool(in_re, self.filter_size)
        trans_im = _forward_spectral_pool(in_im, self.filter_size)
        # The imaginary part of the inverse transform is discarded.
        out_re, _ = self.ifft(trans_re, trans_im)

        return out_re

## 2. Convolutional Layer with Spectral Parameters

In [5]:
class SpectralParam(Module):
    """2-D convolution whose kernel is parameterised in the frequency domain.

    A spatial kernel of shape ``(out_channels, in_channels, kernel_size,
    kernel_size)`` is Xavier-uniform initialised and transformed with a 2-D
    FFT; the resulting real and imaginary parts are the learnable parameters
    ``weight_re`` and ``weight_im``. On every forward pass the parameters are
    inverse-transformed and the real part is used as the ``F.conv2d`` weight.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: side length of the (square) spatial kernel.
        stride, padding, dilation, groups: forwarded unchanged to ``F.conv2d``.
        bias: if True, learn an additive bias of shape ``(out_channels,)``;
            if False, no bias is used.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(SpectralParam, self).__init__()
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        self.ifft = fft.Ifft2d()

        # The spatial kernel is allocated directly on the GPU and only used to
        # derive the initial spectral parameters.
        weight = torch.Tensor(out_channels, in_channels, kernel_size, kernel_size).cuda()
        nn.init.xavier_uniform(weight)
        weight_re, weight_im = fft.fft2(weight, torch.zeros_like(weight).cuda())

        self.weight_re = nn.Parameter(weight_re, requires_grad=True)
        self.weight_im = nn.Parameter(weight_im, requires_grad=True)

        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels), requires_grad=True)
            nn.init.normal(self.bias)
        else:
            # Do NOT assign `self.bias = None` first: register_parameter raises
            # KeyError when a plain attribute of that name already exists.
            # After registration `self.bias` evaluates to None.
            self.register_parameter('bias', None)

    def forward(self, input):
        """Convolve ``input`` with the kernel recovered from the spectral parameters."""
        # Only the real part of the inverse transform is used as the kernel.
        weight, _ = self.ifft(self.weight_re, self.weight_im)
#         weight, _ = self.ifft(self.weight_re, torch.zeros_like(self.weight_re).cuda())
        result = F.conv2d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

        return result

## 3. Build Spectral CNN

In [6]:
class Net(nn.Module):
    """Spectral CNN: six SpectralParam conv + SpectralPool stages and a 1x1 head.

    Each of the first six stages applies a ``SpectralParam`` convolution
    (padding ``(kernel_size - 1) // 2``), a ReLU, and a ``SpectralPool`` that
    reduces the feature map to 25, 19, 15, 11, 8 and finally 4 pixels per
    side. Two 1x1 ``SpectralParam`` convolutions then map 288 channels to 10
    outputs, and a 4x4 average pool collapses the remaining spatial grid.

    Args:
        kernel_size: side length of the kernels in the first six conv layers.
    """

    def __init__(self, kernel_size):
        super(Net, self).__init__()
        self.conv1 = SpectralParam(3, 128, kernel_size, padding=(kernel_size-1)//2)
        self.pool1 = SpectralPool(filter_size=25)

        self.conv2 = SpectralParam(128, 160, kernel_size, padding=(kernel_size-1)//2)
        self.pool2 = SpectralPool(filter_size=19)

        self.conv3 = SpectralParam(160, 192, kernel_size, padding=(kernel_size-1)//2)
        self.pool3 = SpectralPool(filter_size=15)

        self.conv4 = SpectralParam(192, 224, kernel_size, padding=(kernel_size-1)//2)
        self.pool4 = SpectralPool(filter_size=11)

        self.conv5 = SpectralParam(224, 256, kernel_size, padding=(kernel_size-1)//2)
        self.pool5 = SpectralPool(filter_size=8)

        self.conv6 = SpectralParam(256, 288, kernel_size, padding=(kernel_size-1)//2)
        self.pool6 = SpectralPool(filter_size=4)

        # 1x1 convolutions acting as the classifier head.
        self.conv7 = SpectralParam(288, 288, kernel_size=1, padding=0)
        self.conv8 = SpectralParam(288, 10, kernel_size=1, padding=0)

        # Matches the 4x4 grid left by pool6.
        self.avg = nn.AvgPool2d(4)

    def forward(self, x):
        """Return class scores of shape ``(batch, 10)`` for input batch ``x``."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        x = self.pool4(F.relu(self.conv4(x)))
        x = self.pool5(F.relu(self.conv5(x)))
        x = self.pool6(F.relu(self.conv6(x)))
        x = self.conv8(F.relu(self.conv7(x)))

        pooled = self.avg(x)
        # Flatten everything but the batch dimension. A bare torch.squeeze
        # would also drop the batch dimension when the batch holds a single
        # sample, yielding shape (10,) instead of (1, 10).
        return pooled.view(pooled.size(0), -1)

## 4. Train and Test

In [7]:
# Hyper-parameters for the experiment.
kernel_size = 3
batch_size = 128
learning_rate = 1e-3
weight_decay = 1e-3
total_epoch = 100
# Number of trailing training samples held out for validation.
val_size = 4096

if __name__ == '__main__':
    net = Net(kernel_size).cuda()

    best_val = 0
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Multiply the learning rate by 0.8 every 20 scheduler steps (one step per epoch).
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.8)

    # The last `val_size` training samples form the validation split.
    img_gen = ImageGenerator(xtrain[:-val_size], ytrain[:-val_size])
    val_gen = ImageGenerator(xtrain[-val_size:], ytrain[-val_size:])

    generator = img_gen.next_batch_gen(batch_size)
    val_generator = val_gen.next_batch_gen(batch_size)

    # Integer number of batches per pass; any remainder is not visited.
    iters = int((xtrain.shape[0] - val_size) / batch_size)
    val_iters = int(val_size / batch_size)

    for epoch in range(total_epoch):
        # NOTE(review): stepping the scheduler at the start of the epoch is the
        # pre-1.1 PyTorch convention; newer versions expect it after the
        # epoch's optimizer steps.
        scheduler.step()

        # train
        loss_iter = []
        acc_iter = []
        train_seen = 0
        for itr in range(iters):

            X_batch, y_batch = next(generator)
            inputs = Variable(torch.Tensor(X_batch).cuda())
            labels = Variable(torch.LongTensor(y_batch).cuda())
            optimizer.zero_grad()

            outputs = net.forward(inputs)

            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            _, predict = torch.max(outputs.data, 1)

            loss_iter.append(loss.data.cpu().numpy()[0])
            acc_iter.append(predict.eq(labels.data).cpu().sum())
            train_seen += labels.size(0)

        train_loss = np.mean(loss_iter)
        # Divide by the samples actually evaluated, not the full split size:
        # `iters` drops the remainder batch, so the split size over-counts.
        train_acc = np.sum(acc_iter) / train_seen

        # validation
        val_iter = []
        val_seen = 0
        for itr in range(val_iters):
            X_batch, y_batch = next(val_generator)
            inputs = Variable(torch.Tensor(X_batch).cuda())
            labels = Variable(torch.LongTensor(y_batch).cuda())
            outputs = net.forward(inputs)

            _, predict = torch.max(outputs.data, 1)

            val_iter.append(predict.eq(labels.data).cpu().sum())
            val_seen += labels.size(0)

        val_acc = np.sum(val_iter) / val_seen

        # Keep only the weights with the best validation accuracy so far.
        if best_val < val_acc:
            best_val = val_acc
            torch.save(net.state_dict(), 'checkpoint.pth.tar')

        print('epoch: %d  train loss: %.3f  train acc: %.3f  val acc: %.3f  best val acc: %.3f' %
              (epoch + 1, train_loss, train_acc, val_acc, best_val))

    # test the network using the best checkpoint
    testnet = Net(kernel_size).cuda()
    testnet.load_state_dict(torch.load('checkpoint.pth.tar'))

    test_gen = ImageGenerator(xtest, ytest)
    generator = test_gen.next_batch_gen(batch_size)
    iters = int(xtest.shape[0] / batch_size)
    test_iter = []
    test_seen = 0
    for itr in range(iters):
        # Draw from the TEST generator; pulling from val_generator here would
        # report validation accuracy under the "test acc" label.
        X_batch, y_batch = next(generator)
        inputs = Variable(torch.Tensor(X_batch).cuda())
        labels = Variable(torch.LongTensor(y_batch).cuda())
        outputs = testnet.forward(inputs)

        _, predict = torch.max(outputs.data, 1)

        test_iter.append(predict.eq(labels.data).cpu().sum())
        test_seen += labels.size(0)

    # As above: normalise by the samples evaluated, since the remainder of
    # xtest beyond iters * batch_size is never visited.
    test_acc = np.sum(test_iter) / test_seen

    print('test acc: %.3f' % (test_acc))

epoch: 1  train loss: 57.371  train acc: 0.302  val acc: 0.400  best val acc: 0.400
epoch: 2  train loss: 1.438  train acc: 0.484  val acc: 0.520  best val acc: 0.520
epoch: 3  train loss: 1.215  train acc: 0.567  val acc: 0.615  best val acc: 0.615
epoch: 4  train loss: 1.057  train acc: 0.626  val acc: 0.624  best val acc: 0.624
epoch: 5  train loss: 0.945  train acc: 0.667  val acc: 0.643  best val acc: 0.643
epoch: 6  train loss: 0.844  train acc: 0.705  val acc: 0.664  best val acc: 0.664
epoch: 7  train loss: 0.763  train acc: 0.734  val acc: 0.689  best val acc: 0.689
epoch: 8  train loss: 0.712  train acc: 0.749  val acc: 0.695  best val acc: 0.695
epoch: 9  train loss: 0.657  train acc: 0.770  val acc: 0.721  best val acc: 0.721
epoch: 10  train loss: 0.604  train acc: 0.789  val acc: 0.730  best val acc: 0.730
epoch: 11  train loss: 0.564  train acc: 0.803  val acc: 0.715  best val acc: 0.730
epoch: 12  train loss: 0.533  train acc: 0.812  val acc: 0.693  best val acc: 0.730


MemoryError: 

Training was too slow, so I did not continue the run.