In [1]:
import numpy as np
import pandas as pd
import torch
from fourier_conv import ConvFourier
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from utils import get_generator
from matplotlib import pyplot as plt

In [2]:
class Model(nn.Module):
    """Classifier built from four ``ConvFourier`` stages and a linear head.

    Every stage is ``ConvFourier -> Sigmoid -> MaxPool2d``. The channel,
    kernel, padding and pooling settings are the same ones ``ConvNet1`` uses
    with ``nn.Conv2d``, so the two networks differ only in the convolution
    layer type.

    Args:
        fourierPrecision: forwarded unchanged as ``sequence_length`` to each
            ``ConvFourier`` layer.
    """

    def __init__(self, fourierPrecision=6):
        super().__init__()
        # One row per stage:
        # (in_channels, out_channels, kernel_size, padding, pool_kernel, pool_stride)
        stage_specs = (
            (1, 20, 5, 4, 3, 1),
            (20, 15, 3, 1, 4, 2),
            (15, 10, 3, 1, 4, 2),
            (10, 5, 3, 1, 2, 2),
        )
        stages = [
            self._stage(*spec, precision=fourierPrecision) for spec in stage_specs
        ]
        # Attribute names l1..l4 are kept so parameter/state_dict names do not change.
        self.l1, self.l2, self.l3, self.l4 = stages
        # The head consumes 3*3*5 flattened features per sample.
        # NOTE(review): this assumes ConvFourier produces the same spatial
        # sizes as the equivalent nn.Conv2d would - TODO confirm.
        self.fc1 = nn.Linear(3 * 3 * 5, 10)
        self.fc5 = nn.LogSigmoid()

    @staticmethod
    def _stage(in_ch, out_ch, kernel, pad, pool_kernel, pool_stride, precision):
        """Build one ``ConvFourier -> Sigmoid -> MaxPool2d`` stage (conv stride is 1)."""
        return nn.Sequential(
            ConvFourier(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=kernel,
                stride=1,
                padding=pad,
                sequence_length=precision,
            ),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=pool_kernel, stride=pool_stride),
        )

    def forward(self, img):
        """Run the four stages, flatten per sample, then apply ``fc1`` and ``LogSigmoid``.

        Returns a tensor with one row of 10 values per input sample.
        """
        for stage in (self.l1, self.l2, self.l3, self.l4):
            img = stage(img)
        flat = img.reshape(img.size(0), -1)
        # NOTE(review): the notebook trains this output with nn.CrossEntropyLoss,
        # which applies log-softmax itself; confirm the LogSigmoid head is intended.
        return self.fc5(self.fc1(flat))


class ConvNet1(nn.Module):
    """Plain ``nn.Conv2d`` network: four conv stages and a linear head.

    Each stage is ``Conv2d -> Sigmoid -> MaxPool2d``. For a single-channel
    28x28 input the last stage yields a 5x3x3 feature map, which matches the
    ``3*3*5`` input width of ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # One row per stage:
        # (in_channels, out_channels, kernel_size, padding, pool_kernel, pool_stride)
        stage_specs = (
            (1, 20, 5, 4, 3, 1),
            (20, 15, 3, 1, 4, 2),
            (15, 10, 3, 1, 4, 2),
            (10, 5, 3, 1, 2, 2),
        )
        built = [self._stage(*spec) for spec in stage_specs]
        # Attribute names layer1..layer4 are kept so state_dict keys stay the same.
        self.layer1, self.layer2, self.layer3, self.layer4 = built
        self.fc1 = nn.Linear(3 * 3 * 5, 10)
        self.fc5 = nn.LogSigmoid()

    @staticmethod
    def _stage(in_ch, out_ch, kernel, pad, pool_kernel, pool_stride):
        """Build one ``Conv2d -> Sigmoid -> MaxPool2d`` stage (conv stride is 1)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=1, padding=pad),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=pool_kernel, stride=pool_stride),
        )

    def forward(self, x):
        """Run the four stages, flatten per sample, then apply ``fc1`` and ``LogSigmoid``.

        Returns a tensor of shape ``(x.size(0), 10)``.
        """
        out = x
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        flat = out.reshape(out.size(0), -1)
        return self.fc5(self.fc1(flat))
    

In [3]:
## Hyper-Parameters
batch_size = 2       # mini-batch size handed to get_generator
workers = 8          # handed to get_generator as num_workers
train_size = 10000   # handed to get_generator
test_size = 50       # handed to get_generator
lr = 3e-3            # learning rate of the optimizer below
wd = 1e-6            # NOTE(review): defined but not passed to any optimizer in this cell
epochs = 100         # number of passes made by the training loop

## Data Generators
train_gen, test_gen = get_generator(
    batch_size=batch_size,
    num_workers=workers,
    train_size=train_size,
    test_size=test_size,
)

## Fourier Conv Model
model = Model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr)

## Conv Model (baseline, currently disabled)
# convModel = ConvNet1()
# convCriterion = nn.CrossEntropyLoss()
# convOptimizer = torch.optim.Adam(convModel.parameters(), lr=lr)


In [4]:
# Keep the first batch yielded by train_gen as a fixed reference sample.
# Both names stay None when the generator yields nothing.
fixImg, fixLabel = None, None
for first_imgs, first_labels in train_gen:
    fixImg = first_imgs.float().unsqueeze(1)  # insert a new axis at position 1
    fixLabel = first_labels
    break

for epoch in range(epochs):
    for i, (img, labels) in enumerate(train_gen):
        # Cast to float and insert a new axis at position 1 before feeding the model.
        img = img.float().unsqueeze(1)

        # One optimisation step of the Fourier Conv model.
        preds = model(img)
        loss = criterion(preds, labels)
        if i % 3 == 0:
            # Progress line: epoch, batch index, current loss.
            print(epoch, f" {i} ", loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Baseline Conv model training (disabled; needs the commented-out
        # convModel / convCriterion / convOptimizer from the setup cell).
        # convPreds = convModel(img)
        # convLoss = convCriterion(convPreds, labels)
        # if i % 100 == 0:
        #     print(epoch, " " + str(i) + " ", convLoss.item())
        # convOptimizer.zero_grad()
        # convLoss.backward()
        # convOptimizer.step()

  out_tensor = torch.tensor(torch.zeros((batch_size,out_channels,out_height, out_width), dtype=torch.float32))
  out_tensor = torch.tensor(torch.zeros(input.shape), dtype= torch.float32)


KeyboardInterrupt: 