## Import Libraries and Instiate Deconv Module

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.fft import fft2, ifft2
import torchvision
import torchvision.transforms as T
from torchvision import io
from torch.utils.data import Dataset, DataLoader
from torchvision import models

import os
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [12]:
class Deconv2D(nn.Module):
    def __init__(self, shape=(2, 4)):
        super(Deconv2D, self).__init__()
        self.w_flat = nn.Parameter(data=torch.zeros(shape[0]*shape[1]-1),
                                   requires_grad=True)
        self.h_shape = shape

    def forward(self, x):
        w = nn.functional.pad(self.w_flat, (1, 0), value=1)

        w = torch.reshape(w, self.h_shape)
        hm1 = nn.functional.pad(w, (0, x.size(-1)-w.size(-1), 0, x.size(-2)-w.size(-2)))

        gm1f = 1/fft2(hm1)

        gm2f = torch.flip(gm1f, (0,))
        gm2f = torch.roll(gm1f, shifts=1, dims=0)

        gm3f = torch.flip(gm1f, (1,))
        gm3f = torch.roll(gm1f, shifts=1, dims=1)

        gm4f = torch.flip(gm1f, (0, 1))
        gm4f = torch.roll(gm1f, shifts=(1, 1), dims=(0,1))

        gmf = gm1f*gm2f*gm3f*gm4f

        ymf = gmf*fft2(x)

        y = ifft2(ymf).real

        return y

## Load CIFAR Dataset

In [13]:
from torchvision import datasets

transform = T.Compose([T.ToTensor(),
                       T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

cifar_data = datasets.CIFAR100("/content/", download=True, transform=transform)
data_len = len(cifar_data)

Files already downloaded and verified


In [14]:
dataloader = DataLoader(cifar_data, batch_size=32, shuffle=True)

## Create training function

In [15]:
def train_model(model, criterion, optimizer, dataloader, data_len, num_epochs=3):
    history = {'loss':[], 'accuracy':[]}
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1:2d}/{num_epochs}')

        model.train()
        running_loss = 0.0
        running_correct = 0

        for X, labels in tqdm(dataloader):

            X = X.to(device)
            labels = labels.to(device)

            with torch.set_grad_enabled(True):
                outputs = model(X)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            running_loss += loss
            running_correct += torch.sum(preds==labels.data)
        
        epoch_loss = running_loss
        epoch_acc = running_correct/data_len

        history['loss'].append(epoch_loss.item())
        history['accuracy'].append(epoch_acc.item())

        print('Loss: {:.4f}, Acc: {:.3f}'.format(epoch_loss, epoch_acc))
    return history

In [16]:
class MyResNet18(nn.Module):
    """
    This module serves as a wrapper for ResNet18 PyTorch implementation.
    This allows you to create a ResNet18 model with the correct number of 
    output features quickly!
    """
    def __init__(self,  first_layer="deconv", num_ftrs=100, channels=3):
        super(MyResNet18, self).__init__()
        if first_layer == "deconv":
            self.deconv = Deconv2D((4, 4))
        else:
            self.deconv = nn.Identity()

        model = models.resnet18(weights='DEFAULT')
        # set fully-connected layer to correct number of features
        res_ftrs = model.fc.in_features
        model.fc = nn.Linear(res_ftrs, num_ftrs)   

        # if channels is not equal to 3 (ie RGB) then change input layer   
        if channels != 3:
            model.conv1 = torch.nn.Conv2d(channels, 64, kernel_size=(7, 7),
                                          stride=(2, 2), padding=(3, 3),
                                          bias=False)
        self.resnet18 = model

    def forward(self, x):
        x = self.deconv(x)
        x = self.resnet18(x)
        return x

## Train model with no deconv layer

In [17]:
model = MyResNet18("deconv")

model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

history = train_model(model, criterion, optimizer, dataloader, data_len, num_epochs=10)
print(model.deconv.w_flat)

Epoch  1/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 4416.0483, Acc: 0.308
Epoch  2/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 2998.2625, Acc: 0.483
Epoch  3/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 2422.5239, Acc: 0.569
Epoch  4/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 2019.8202, Acc: 0.633
Epoch  5/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 1672.0231, Acc: 0.687
Epoch  6/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 1399.4591, Acc: 0.736
Epoch  7/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 1157.9497, Acc: 0.778
Epoch  8/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 966.1988, Acc: 0.812
Epoch  9/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 812.7243, Acc: 0.838
Epoch 10/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 679.0031, Acc: 0.863
Parameter containing:
tensor([ 0.0543,  0.0032,  0.0039,  0.0482,  0.0101, -0.0025,  0.0019,  0.0040,
        -0.0079, -0.0019,  0.0020,  0.0028,  0.0035, -0.0014,  0.0043],
       device='cuda:0', requires_grad=True)


## Train model with deconv layer

In [18]:
model = MyResNet18("conv")

model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

history_replace = train_model(model, criterion, optimizer, dataloader, data_len, num_epochs=10)

Epoch  1/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 4403.5703, Acc: 0.309
Epoch  2/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 3011.6411, Acc: 0.481
Epoch  3/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 2435.4004, Acc: 0.567
Epoch  4/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 2029.2737, Acc: 0.632
Epoch  5/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 1687.5574, Acc: 0.687
Epoch  6/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 1401.1614, Acc: 0.736
Epoch  7/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 1156.7188, Acc: 0.775
Epoch  8/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 962.8721, Acc: 0.812
Epoch  9/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 816.2945, Acc: 0.837
Epoch 10/10


  0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 689.6357, Acc: 0.862
