<a href="https://colab.research.google.com/github/ajaykurani/DataAugmentation/blob/master/Feature_Space_Aug.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Mount Google Drive so files under "/content/drive/My Drive/" (e.g. six.png,
# which is read further down) are accessible from this Colab runtime.
drive.mount('/content/drive')

In [None]:
import torch
import torchvision
from torch import nn, optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage
import seaborn as sns
import cv2
import math
import numbers
import random
from torchvision.utils import save_image
from torchsummary import summary
from google.colab import files

In [None]:
# --- Optimisation schedule -------------------------------------------------
n_epochs = 20          # number of passes over the training set
learning_rate = 1e-4   # intended Adam learning rate; keep the optimizer's lr in sync

# --- DataLoader batch sizes ------------------------------------------------
batch_size_train = 64
batch_size_test = 1000

# --- Logging ---------------------------------------------------------------
log_interval = 50      # print training progress every `log_interval` batches

In [None]:
# Shared MNIST preprocessing: PIL image -> single-channel float tensor in [0, 1],
# then Normalize((0.5,), (0.5,)) maps it to [-1, 1]. Previously this Compose
# was written out three times; one definition keeps train/test consistent.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('/files/', train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_train, shuffle=True)

# The test split is built once and shared by `test_loader` and the `inp` lookup below.
test_data = torchvision.datasets.MNIST('/files/', train=False, download=True,
                                       transform=mnist_transform)

test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size_test, shuffle=True)

# Reference image consumed by `dot` inside Net.forward: the first test sample's
# only channel, as a 2-D float32 numpy array.
inp, _ = test_data[0]
inp = inp[0].numpy()

In [None]:
'''
Information on interpolation here: https://medium.com/@wenrudong/what-is-opencvs-inter-area-actually-doing-282a626a09b3
'''

def dot(inp, x):
  """Score a stack of 2-D feature maps against a reference image.

  `inp` is resized with cv2 INTER_AREA to the spatial size of the maps, then
  matrix-multiplied with every map x[b][f]; the products are summed.

  Args:
    inp: 2-D numpy image.
    x: numpy array indexed as x[b][f] -> 2-D map (in Net.forward it is the
       (batch, channels, H, W) conv1 output). NOTE(review): `image @ map` is a
       matrix product, so it only works for square maps (H == W).

  Returns:
    float64 array of shape (H, W) equal to sum over b, f of resize(inp) @ x[b][f].
  """
  h, w = x.shape[-2], x.shape[-1]
  # cv2.resize takes dsize as (width, height).
  image = cv2.resize(inp, (w, h), interpolation=cv2.INTER_AREA)
  # Matrix products are linear: sum_{b,f} image @ x[b][f] == image @ sum_{b,f} x[b][f].
  # One matmul replaces B*F of them; accumulate in float64 like the original
  # running sum (which also no longer shadows the builtin `sum`).
  return image.astype(np.float64) @ x.sum(axis=(0, 1), dtype=np.float64)

class Net(nn.Module):
    '''
    Adapted from https://github.com/sksq96/pytorch-summary.
    More information on layers available here: https://towardsdatascience.com/pytorch-layer-dimensions-what-sizes-should-they-be-and-why-4265a41e01fd
    '''
    def __init__(self):
        super(Net, self).__init__()
        # For a 1x28x28 input: conv1 (5x5) -> 64x24x24, max-pool -> 64x12x12,
        # conv2 (1x1) -> 128x12x12, max-pool -> 128x6x6 = 4608 features for fc1.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=1)
        self.fc1 = nn.Linear(4608, 100)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        '''
        Feature-space augmentation after the first conv layer:
        (1) rotate every conv1 feature map by 15, 30, ..., 90 degrees;
        (2) score each rotated stack against the module-level reference image
            `inp` with `dot`, and keep a rotation when its score is >= the
            currently kept stack's score at more than half of the positions
            (a greedy, sequential choice);
        (3) pass the kept feature maps on to conv2 and the classifier head.

        NOTE(review): `inp` is one fixed test image (a module-level global), not
        the batch being processed.
        NOTE(review): rotation happens in numpy, so when a rotation is kept the
        tensor fed to conv2 is a new leaf and conv1 gets no gradient from this
        batch; a differentiable rotation (e.g. grid_sample) would avoid that.
        '''
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        device = x.device
        base = x.detach().cpu().numpy()

        best = base
        best_score = dot(inp, best)  # cached; only changes when `best` changes
        for deg in range(15, 91, 15):
            # Every candidate is rotated from the *unrotated* conv1 output.
            rotated = base.copy()
            for b in range(rotated.shape[0]):
                for f in range(rotated.shape[1]):
                    rotated[b][f] = ndimage.rotate(base[b][f], deg, reshape=False)
            score = dot(inp, rotated)
            # Keep it if it scores >= the current choice at a strict majority of positions.
            if np.count_nonzero(score >= best_score) > score.size // 2:
                best, best_score = rotated, score

        if best is not base:
            # Stay on the input's device; the original hard-coded `.cuda()`,
            # which broke CPU inference (e.g. after `network.cpu()`).
            x = torch.from_numpy(best).to(device)

        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def loss_function(self, out, target):
        """Mean cross-entropy between logits `out` and class indices `target`."""
        return F.cross_entropy(out, target)

In [None]:
def init_weights(m):
    """Initialise Linear/Conv2d layers: Xavier-uniform weights, biases set to 0.01.

    Meant for `nn.Module.apply`, which calls it on every submodule; modules of
    any other type are left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        torch.nn.init.xavier_uniform_(m.weight)
        # Layers created with bias=False have m.bias == None (the original crashed on them).
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)

# Build the model, apply the custom initialisation and move it to the GPU.
network = Net()
network.apply(init_weights)
network.cuda()

# Read the learning rate from the hyperparameter cell instead of repeating the
# literal 1e-4 here, so changing `learning_rate` actually takes effect.
optimizer = optim.Adam(network.parameters(), lr=learning_rate)

# Per-layer summary for a single-channel 28x28 input.
summary(network, (1, 28, 28))

In [None]:
def train(epoch):
  """Run one optimisation epoch over `train_loader`.

  Uses the module-level `network`, `optimizer` and `train_loader`, moves each
  batch to CUDA and prints the current batch loss every `log_interval` batches.
  """
  network.train()
  for step, batch in enumerate(train_loader):
    inputs, labels = (tensor.cuda() for tensor in batch)

    optimizer.zero_grad()
    loss = network.loss_function(network(inputs), labels)
    loss.backward()
    optimizer.step()

    if step % log_interval:
      continue
    seen = step * len(inputs)
    percent = 100. * step / len(train_loader)
    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
      epoch, seen, len(train_loader.dataset), percent, loss.item()))

In [None]:
def test():
  """Evaluate `network` on `test_loader` and print average loss and accuracy.

  The reported loss is the per-sample mean cross-entropy over the whole test set.
  """
  network.eval()
  test_loss = 0.0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data = data.cuda()
      target = target.cuda()
      # (The old `target.view(batch_size_test)` was a no-op for full batches
      # and would crash on a shorter final batch, so it is dropped.)
      output = network(data)
      # loss_function returns the batch *mean*; weight it by the batch size so
      # dividing by the dataset size below gives a true per-sample average.
      test_loss += network.loss_function(output, target).item() * target.size(0)
      pred = output.argmax(dim=1)
      # .item() so `correct` is a plain int, not a CUDA tensor repr when printed.
      correct += (pred == target).sum().item()
  num_samples = len(test_loader.dataset)
  test_loss /= num_samples
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, num_samples,
    100. * correct / num_samples))

In [None]:
# Evaluate the untrained model once as a baseline, then alternate one training
# epoch with one evaluation pass. Epochs are numbered from 1 for the log output.
test()
for epoch in range(1, n_epochs + 1):
  train(epoch)
  test()

In [None]:
# Persist only the learned parameters (the state_dict), not the pickled module,
# to 'network_fsaug.pth' in the current working directory.
torch.save(obj=network.state_dict(), f='network_fsaug.pth')

In [None]:
# For each test batch, show its first image and print the model's prediction for it.
network.eval()
with torch.no_grad():  # inference only: don't build autograd graphs for every batch
  for data, target in test_loader:
    im = torch.squeeze(data[0])
    plt.imshow(im.numpy())
    plt.show()
    output = network(data.cuda())
    pred = output.argmax(dim=1, keepdim=True)
    print("Prediction: " + str(pred[0].item()))

In [None]:
'''
Adapted from https://towardsdatascience.com/visualizing-convolution-neural-networks-using-pytorch-3dfa8443e74e.
More information on feature visualization: https://distill.pub/2017/feature-visualization/
'''

def imshow(img, title, batch_size=1,
           mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
  """Undo per-channel normalisation on a CHW image tensor and display it.

  Args:
    img: CHW tensor (e.g. an image grid) normalised with `mean`/`std`.
    title: figure title.
    batch_size: number of images laid side by side in `img`; the figure is
      4 inches wide per image. The original read an undefined global
      `batch_size` here and raised NameError.
    mean, std: per-channel statistics to undo. The defaults are the ImageNet
      values the original hard-coded; pass (0.5,), (0.5,) for this notebook's
      normalised MNIST tensors.
  """
  std_correction = np.asarray(std).reshape(-1, 1, 1)
  mean_correction = np.asarray(mean).reshape(-1, 1, 1)
  npimg = np.multiply(img.detach().cpu().numpy(), std_correction) + mean_correction
  npimg = np.transpose(npimg, (1, 2, 0))
  if npimg.shape[-1] == 1:
    # matplotlib rejects (H, W, 1) arrays; show single-channel images as 2-D.
    npimg = npimg[..., 0]
  plt.figure(figsize = (batch_size * 4, 4))
  plt.axis("off")
  plt.imshow(npimg)
  plt.title(title)
  plt.show()

def plot_filters_single_channel_big(t):
    """Draw every 2-D kernel of a 4-D weight tensor as one grayscale heatmap.

    The tensor t[out, in, kh, kw] is tiled into an (out*kh) x (in*kw) mosaic
    (row blocks = output filters, column blocks = input channels). The mosaic
    is transposed before plotting, so output filters run along the x-axis.
    """
    mosaic_rows = t.shape[0] * t.shape[2]
    mosaic_cols = t.shape[1] * t.shape[3]

    weights = np.array(t.numpy(), np.float32)
    # (out, in, kh, kw) -> (out, kh, in, kw) so each kernel row lands contiguously.
    mosaic = weights.transpose((0, 2, 1, 3)).reshape(mosaic_rows, mosaic_cols).T

    _, axis = plt.subplots(figsize=(mosaic_cols / 10, mosaic_rows / 200))
    sns.heatmap(mosaic, xticklabels=False, yticklabels=False,
                cmap='gray', ax=axis, cbar=False)
  
def plot_filters_single_channel(t):
    """Plot every 2-D kernel t[i, j] of a 4-D weight tensor in its own subplot.

    Each kernel is standardised (zero mean, unit std), shifted by 0.5 and
    clipped to [0, 1] for display. Subplots are laid out 12 per row and
    titled with the output-filter index i.

    Args:
        t: weight tensor indexed as t[out_channel, in_channel] -> 2-D kernel
           (e.g. a Conv2d ``weight``).
    """
    n_out, n_in = t.shape[0], t.shape[1]
    nplots = n_out * n_in
    ncols = 12
    # Ceiling division; the original `1 + nplots // ncols` added an empty row
    # whenever nplots was an exact multiple of ncols.
    nrows = max(1, math.ceil(nplots / ncols))
    kernels = np.asarray(t.detach().cpu().numpy(), np.float32)

    count = 0
    fig = plt.figure(figsize=(ncols, nrows))

    for i in range(n_out):
        for j in range(n_in):
            count += 1
            ax1 = fig.add_subplot(nrows, ncols, count)
            kernel = kernels[i, j]
            spread = np.std(kernel)
            # A constant kernel (always the case for 1x1 kernels such as conv2's)
            # has std 0; dividing by it produced all-NaN (blank) images.
            if spread > 0:
                kernel = (kernel - np.mean(kernel)) / spread
            else:
                kernel = np.zeros_like(kernel)
            kernel = np.clip(kernel + 0.5, 0, 1)
            ax1.imshow(kernel)
            ax1.set_title(str(i))
            ax1.axis('off')
            ax1.set_xticklabels([])
            ax1.set_yticklabels([])

    plt.tight_layout()
    plt.show()

def plot_filters_multi_channel(t):
    """Plot each kernel t[i] of a 4-D weight tensor as a colour image.

    Each kernel (C, kh, kw) is standardised, shifted by 0.5, clipped to [0, 1],
    transposed to (kh, kw, C) and shown 12 per row. The figure is also saved
    to 'myimage.png'. NOTE(review): matplotlib needs C of 3 or 4 here;
    `plot_weights` only calls this for 3-input-channel layers.
    """
    num_kernels = t.shape[0]
    num_cols = 12
    # One grid cell per kernel. The original used num_rows = num_kernels,
    # giving a figure num_kernels inches tall that was almost entirely empty.
    num_rows = max(1, math.ceil(num_kernels / num_cols))

    fig = plt.figure(figsize=(num_cols, num_rows))

    for i in range(num_kernels):
        ax1 = fig.add_subplot(num_rows, num_cols, i + 1)

        npimg = np.array(t[i].detach().cpu().numpy(), np.float32)
        spread = np.std(npimg)
        # Guard against constant kernels (std 0), which would yield NaNs.
        if spread > 0:
            npimg = (npimg - np.mean(npimg)) / spread
        else:
            npimg = np.zeros_like(npimg)
        npimg = np.clip(npimg + 0.5, 0, 1)
        npimg = npimg.transpose((1, 2, 0))  # CHW -> HWC for imshow
        ax1.imshow(npimg)
        ax1.axis('off')
        ax1.set_title(str(i))
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])

    # Lay out first so the saved file matches what is displayed
    # (the original saved before calling tight_layout).
    plt.tight_layout()
    plt.savefig('myimage.png', dpi=100)
    plt.show()

def plot_weights(model, layer_num, single_channel = True, collated = False):
  """Visualise the weights of `model.conv1` (layer_num=1) or `model.conv2` (layer_num=2).

  single_channel=True draws every 2-D kernel separately, or as one collated
  heatmap when collated=True. single_channel=False draws kernels as colour
  images and is only accepted for 3-input-channel layers. Unsupported requests
  print a message and return None.
  """
  layer_attr = 'conv1' if layer_num == 1 else 'conv2' if layer_num == 2 else None
  if layer_attr is None:
    print("Convolutional layer not found")
    return
  layer = getattr(model, layer_attr)

  if not isinstance(layer, nn.Conv2d):
    print("Can only visualize layers which are convolutional")
    return

  weight_tensor = layer.weight.data
  if single_channel:
    plotter = plot_filters_single_channel_big if collated else plot_filters_single_channel
    plotter(weight_tensor)
  elif weight_tensor.shape[1] == 3:
    plot_filters_multi_channel(weight_tensor)
  else:
    print("Can only plot weights with three channels with single channel = False")

### **Convolutional Filter Visualizations**

In [None]:
# Module.cpu() moves the parameters in place (and returns the module), so the
# network stays on the CPU for the following cells, which feed it CPU tensors.
network.cpu()
# Show each conv1 kernel in its own subplot.
plot_weights(network, 1, single_channel=True)

### **Feature Maps**

In [None]:
'''
Adapted from https://debuggercafe.com/visualizing-filters-and-feature-maps-in-convolutional-neural-networks-using-pytorch/.
'''

img = cv2.imread("/content/drive/My Drive/six.png", 0)  # 0 -> load as grayscale
if img is None:
  # cv2.imread reports a missing/unreadable file by returning None, not by raising.
  raise FileNotFoundError("/content/drive/My Drive/six.png")
# Invert intensities; presumably the source is dark ink on white, and MNIST
# digits are light on dark.
img = cv2.bitwise_not(img)

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    # Same normalisation as the MNIST training data; without it the network
    # sees inputs in [0, 1] instead of the [-1, 1] range it was trained on.
    transforms.Normalize((0.5,), (0.5,)),
])
img = transform(img)

plt.imshow(img.squeeze(0))
plt.show()

img = img.unsqueeze(0)  # add a batch dimension

# Run conv1/conv2 directly (no rotation step) on whatever device the network
# lives on; keep results on the CPU for plotting.
device = next(network.parameters()).device
with torch.no_grad():
  first = F.relu(F.max_pool2d(network.conv1(img.to(device)), 2))
  second = F.relu(F.max_pool2d(network.conv2(first), 2))
results = [first.cpu(), second.cpu()]

outputs = results

In [None]:
# Plot the conv1 feature maps (after max-pool + ReLU) of the image above in a square grid.
num_layer = 0
plt.figure(figsize=(10, 10))
layer_viz = outputs[num_layer][0, :, :, :]
layer_viz = layer_viz.data

filters = []

# Derive the grid from the channel count (64 for conv1) instead of assuming it.
n = layer_viz.shape[0]
sq = math.ceil(math.sqrt(n))

# `fmap` rather than `filter`, which shadowed the builtin at module level.
for i, fmap in enumerate(layer_viz):
    plt.subplot(sq, sq, i + 1)
    filters.append(fmap)
    plt.imshow(fmap.cpu(), cmap='gray')
    plt.axis("off")

plt.show()

In [None]:
# Plot the conv2 feature maps (after max-pool + ReLU) 16 per row.
num_layer = 1
plt.figure(figsize=(10, 10))
layer_viz = outputs[num_layer][0, :, :, :]
layer_viz = layer_viz.data

# The original set n = 256 and hard-coded 8 rows; conv2 has 128 output
# channels, so derive the grid from the tensor instead.
n = layer_viz.shape[0]
sq = 16  # columns per row
nrows = math.ceil(n / sq)

for i, fmap in enumerate(layer_viz):
    plt.subplot(nrows, sq, i + 1)
    plt.imshow(fmap.cpu(), cmap='gray')
    plt.axis("off")

plt.show()

### **Maximum Activation Visualizations**


In [None]:
'''
Adapted from https://discuss.pytorch.org/t/visualize-feature-map/29597.
'''

# Layer name -> most recent output recorded by a hook from get_activation().
activation = {}
def get_activation(name):
    """Build a forward hook that records a module's output under `name`.

    Register the returned callable with `module.register_forward_hook(...)`.
    Each forward pass overwrites activation[name] with the module output,
    detached from the autograd graph. The hook returns None, so the output
    itself is not modified.
    """
    def _store_output(_module, _inputs, output):
        activation[name] = output.detach()
    return _store_output

# MNIST training split with the same ToTensor + Normalize((0.5,), (0.5,))
# preprocessing as the training loader; the hook cells below read dataset[0].
dataset = torchvision.datasets.MNIST(
    root='/files/',
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]),
)

In [None]:
# Capture the raw conv1 output (before max-pool/ReLU) for the first training
# image and plot each channel.
hook_handle = network.conv1.register_forward_hook(get_activation('conv1'))
data, _ = dataset[0]
data.unsqueeze_(0)  # add a batch dimension
try:
    with torch.no_grad():
        output = network(data.to(next(network.parameters()).device))
finally:
    # Remove the hook so re-running this cell doesn't stack duplicate hooks.
    hook_handle.remove()

n = 64
sq = int(math.sqrt(n))

act = activation['conv1'].squeeze().cpu()

fig, axarr = plt.subplots(sq, sq, figsize=(10,10))
for idx in range(act.size(0)):
    axarr[idx // sq][idx % sq].axis('off')
    axarr[idx // sq][idx % sq].imshow(act[idx], cmap='viridis')

In [None]:
# Capture the raw conv2 output (before max-pool/ReLU) for the first training
# image and plot each channel, 16 per row.
hook_handle = network.conv2.register_forward_hook(get_activation('conv2'))
data, _ = dataset[0]

data.unsqueeze_(0)  # add a batch dimension
try:
    with torch.no_grad():
        output = network(data.to(next(network.parameters()).device))
finally:
    # Remove the hook so re-running this cell doesn't stack duplicate hooks.
    hook_handle.remove()

act = activation['conv2'].squeeze().cpu()

# Derive the grid from the channel count (128 for conv2) instead of hard-coding 8x16.
n = act.size(0)
ncols = 16
nrows = math.ceil(n / ncols)

# squeeze=False keeps axarr 2-D even if only one row is needed.
fig, axarr = plt.subplots(nrows, ncols, figsize=(10,10), squeeze=False)
for idx in range(n):
    axarr[idx // ncols][idx % ncols].axis('off')
    axarr[idx // ncols][idx % ncols].imshow(act[idx], cmap='viridis')