#### Imports

In [1]:
import torch
import torch.nn as nn 
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from torchsummary import summary

#### Model Class:

In [None]:
# Run on the GPU whenever PyTorch reports one is available.
use_cuda = torch.cuda.is_available()
print(f"Use GPU? {use_cuda}")

num_train = 50_000  # total training images (40000 train + 10000 validation below)


class TreeLayer(nn.Module):
    """Tree sampling layer: elementwise weights followed by a partial sum.

    The learnable weight tensor has exactly the 8-D shape ``sizes``. In
    ``forward`` the input is multiplied elementwise (with broadcasting) by the
    weights, the product is summed over dims 2, 5, 6 and 7, and a bias of
    shape ``(1, sizes[1], sizes[3], sizes[4])`` is added. The layer applies no
    non-linearity itself.

    Args:
        sizes: 8-element shape of the weight tensor.
        activation: ``"relu"`` or ``"sigmoid"``; selects only the weight
            initialisation scheme.

    Raises:
        ValueError: if ``activation`` is not supported. (Previously the
            weights were then silently left as uninitialised memory.)
    """

    _SUPPORTED_ACTIVATIONS = ("relu", "sigmoid")

    def __init__(self, sizes, activation):
        super().__init__()
        if activation not in self._SUPPORTED_ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {self._SUPPORTED_ACTIVATIONS}"
            )

        self.sizes = sizes
        self.activation = activation

        self.weights = nn.Parameter(torch.empty(self.sizes))
        # One bias per (sizes[1], sizes[3], sizes[4]) output position; the
        # leading 1 broadcasts over the batch dimension.
        self.bias = nn.Parameter(
            torch.empty(1, self.sizes[1], self.sizes[3], self.sizes[4])
        )

        if self.activation == "relu":
            # NOTE(review): kaiming_normal_ computes fan_in as
            # size(1) * prod(size[2:]), not as the number of terms summed in
            # forward (sizes[2]*sizes[5]*sizes[6]*sizes[7]) -- confirm the
            # resulting (very small) std is intended.
            nn.init.kaiming_normal_(self.weights)
        else:  # "sigmoid"
            nn.init.normal_(self.weights, mean=0.0, std=1.0)
        nn.init.kaiming_normal_(self.bias)

    def forward(self, x):
        """Return ``sum_{dims 2,5,6,7}(x * weights) + bias``.

        ``x`` must broadcast against ``self.weights``. The result has the
        weights' shape with dims 2, 5, 6 and 7 removed (batch taken from ``x``).
        """
        weighted = x * self.weights
        return weighted.sum(dim=(2, 5, 6, 7)) + self.bias

class model(nn.Module):
    """Grouped conv -> 2x2 patch rearrangement -> TreeLayer -> linear classifier.

    Pipeline: ``conv1`` (5x5, ``num_groups`` groups) -> activation -> 2x2
    max-pool -> non-overlapping ``patch_size`` x ``patch_size`` patches via
    ``F.unfold`` -> rearranged into the 8-D layout ``tree1`` expects ->
    ``tree1`` -> activation -> ``fc1`` (10 outputs).

    The hard-coded ``grid_size`` of 7 matches a 3x32x32 input:
    32 -(5x5 conv)-> 28 -(2x2 max-pool)-> 14 -(2x2 unfold, stride 2)-> 7.

    Args:
        activation: ``"relu"`` or ``"sigmoid"`` (default ``"sigmoid"``, the
            value previously hard-coded). Selects both the non-linearity and
            the weight-initialisation scheme.

    Raises:
        ValueError: for an unsupported ``activation``. (Previously ``forward``
            failed with an unbound local instead.)
    """

    def __init__(self, activation="sigmoid"):
        super().__init__()
        if activation not in ("relu", "sigmoid"):
            raise ValueError(
                f"Unsupported activation {activation!r}; expected 'relu' or 'sigmoid'"
            )

        self.num_groups = 3
        self.num_filters_conv1 = 15
        self.num_filters_conv2 = 16
        self.activation = activation
        self.grid_size = 7   # side of the patch grid reaching tree1 (see class docstring)
        self.patch_size = 2  # side of each non-overlapping patch taken by F.unfold

        g, f1, f2 = self.num_groups, self.num_filters_conv1, self.num_filters_conv2
        n, k = self.grid_size, self.patch_size

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=f1 * g,
                               kernel_size=5, groups=g)
        self.tree1 = TreeLayer((1, f2, f1, g, n, n, k, k), self.activation)
        # tree1 outputs (B, f2, g, n); the group count was previously hard-coded as 3.
        self.fc1 = nn.Linear(n * f2 * g, 10)

        # Custom initialisation; only the weights are re-initialised, biases
        # keep PyTorch's defaults.
        self._init_weight(self.conv1.weight)
        self._init_weight(self.fc1.weight)

    def _init_weight(self, weight):
        """Initialise ``weight`` in place according to ``self.activation``."""
        if self.activation == "relu":
            nn.init.kaiming_normal_(weight)
        else:  # "sigmoid"
            nn.init.normal_(weight, mean=0.0, std=1.0)

    def _activate(self, t):
        """Apply the configured non-linearity."""
        return F.relu(t) if self.activation == "relu" else torch.sigmoid(t)

    def forward(self, x):
        g, f1, f2 = self.num_groups, self.num_filters_conv1, self.num_filters_conv2
        n, k = self.grid_size, self.patch_size

        out = self._activate(self.conv1(x))
        out = F.max_pool2d(out, 2)
        # (B, g*f1*k*k, n*n): one column per non-overlapping k x k patch.
        out = F.unfold(out, (k, k), stride=k)

        # -> (B, g*f1, n*n, k*k): patch index before in-patch pixel index.
        out = out.reshape(-1, g * f1, k * k, n * n).transpose(2, 3)
        # Split channels into (group, filter), since grouped-conv outputs are
        # group-major, then move to (B, 1, f1, g, n, n, k, k) for tree1.
        out = out.reshape(-1, g, 1, f1, n, n, k, k).transpose(1, 2).transpose(2, 3)

        out = self._activate(self.tree1(out))  # (B, f2, g, n)
        out = out.reshape(out.size(0), n * f2 * g)
        return self.fc1(out)

# Build one instance only to print a layer-by-layer summary for a 3x32x32 input.
my_model = model().cuda() if use_cuda else model()
summary(my_model, (3, 32, 32))

# Training hyper-parameters used by the data-loading and training cells.
num_epochs = 200
minibatch_size = 100
criterion = nn.CrossEntropyLoss()



#### Data preprocessing:

In [None]:
# CIFAR10 dataset

# Augmentation (random horizontal flip + up to 12.5% translation) is meant for
# the training split only.
CIFAR10_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=0, translate=(0.125, 0.125), fill=0),
    transforms.ToTensor(),
])

num_valid = 10000  # images held out of the training set for validation

train_dataset = torchvision.datasets.CIFAR10(root='../../data',
                                             train=True,
                                             transform=CIFAR10_transform,
                                             download=True)

# The same training images without augmentation. The validation subset is
# drawn from this copy so validation accuracy is measured on unmodified
# images, not on randomly flipped/shifted ones.
train_dataset_plain = torchvision.datasets.CIFAR10(root='../../data',
                                                   train=True,
                                                   transform=transforms.ToTensor())

test_dataset = torchvision.datasets.CIFAR10(root='../../data',
                                            train=False,
                                            transform=transforms.ToTensor())

# Split lengths are [40000, 10000] for the 50000-image training set, so the
# random permutation is identical to the previous hard-coded split.
train_dataset, valid_subset = random_split(
    train_dataset, [len(train_dataset) - num_valid, num_valid])
valid_dataset = torch.utils.data.Subset(train_dataset_plain, valid_subset.indices)

# Data loader
trainloader = DataLoader(dataset=train_dataset,
                         batch_size=minibatch_size,
                         shuffle=True)

valloader = DataLoader(dataset=valid_dataset,
                       batch_size=minibatch_size,
                       shuffle=False)

testloader = DataLoader(dataset=test_dataset,
                        batch_size=minibatch_size,
                        shuffle=False)

### main

In [None]:
my_model = model()
if use_cuda:
    my_model = my_model.cuda()

optimizer = optim.SGD(my_model.parameters(), lr=0.075, momentum=0.965,
                      weight_decay=0.00005, nesterov=True)
val_accuracy = torch.zeros(num_epochs)


def _accuracy(net, loader):
    """Return the fraction of samples in ``loader`` that ``net`` classifies correctly.

    Puts ``net`` in evaluation mode and runs without autograd bookkeeping.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            if use_cuda:
                images = images.cuda()
                labels = labels.cuda()
            _, predicted = torch.max(net(images), 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return float(correct) / total


for epoch in range(num_epochs):
    # Step decay by 0.8 every 20 epochs.
    # NOTE(review): the condition is also true at epoch 0, so training really
    # starts at lr = 0.075 * 0.8 = 0.06 -- confirm this is intended.
    if epoch % 20 == 0:
        for group in optimizer.param_groups:
            group['lr'] *= 0.8

    my_model.train()
    for images, labels in trainloader:
        if use_cuda:
            images = images.cuda()
            labels = labels.cuda()

        # feedforward
        output = my_model(images)
        optimizer.zero_grad()
        loss = criterion(output, labels)

        # backpropagation, then update the weights/parameters
        loss.backward()
        optimizer.step()

    # Validation accuracy
    val_accuracy[epoch] = _accuracy(my_model, valloader)
    print('Epoch {}: Val accuracy {:.4f}' .format(epoch, val_accuracy[epoch]))

   
  

#### Test Accuracy

In [None]:
# Test accuracy: evaluate in eval mode, with autograd disabled for the whole pass.
my_model.eval()
total = 0
correct = 0

with torch.no_grad():
    for images, labels in testloader:
        if use_cuda:
            images = images.cuda()
            labels = labels.cuda()

        outputs = my_model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

test_accuracy = float(correct) / total
print('Test accuracy {}' .format(test_accuracy))
