# Basic CNN

Usually, a basic CNN is sufficient to obtain good results on simple images. However, the images in our dataset are almost identical to one another: only the small cylinder's position changes while the rest of the image stays the same. This makes it almost impossible for the model to learn any useful features from the input data, which is consistent with what we observed in our autoencoder experiments.

Therefore, we use transfer learning to work around this issue. The CNN layers are first pretrained on the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). The model is then trained on the actual dataset.

Author: [Xiyan Su](mailto:tim.su@tum.de)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
from src.network import Model, PretrainedModel

%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [None]:
logs_base_dir = "./runs"
%tensorboard --logdir {logs_base_dir} --port 6006

In [None]:
# Hyperparameters shared by the pretraining run and the actual training run.
params = dict(
    hidden_channel=20,  # number of hidden channels in the CNN layers
    hidden_layer=100,   # FC setting -- presumably the hidden-layer width, not a layer count; confirm in src.network
    lr=1e-4,            # initial learning rate handed to the Adam optimizers
    batch_size=8,       # mini-batch size used by every DataLoader
)

In [None]:
# Select the compute device: the first CUDA GPU when one is available,
# the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Load the CIFAR-10 dataset that is used to pretrain the CNN layers.
# Every image is converted to a tensor and each of its three channels is
# normalized with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split: downloaded into ../data when missing, reshuffled each epoch.
cifer_training_set = torchvision.datasets.CIFAR10(
    root="../data",
    train=True,
    download=True,
    transform=transform,
)
cifer_trainloader = DataLoader(
    cifer_training_set,
    batch_size=params["batch_size"],
    shuffle=True,
    num_workers=0,
    drop_last=True,  # incomplete trailing batches are discarded
)

# Test split: serves as the validation set during pretraining, fixed order.
cifer_test_set = torchvision.datasets.CIFAR10(
    root='../data',
    train=False,
    download=True,
    transform=transform,
)
cifer_testloader = DataLoader(
    cifer_test_set,
    batch_size=params["batch_size"],
    shuffle=False,
    num_workers=0,
    drop_last=True,
)

In [None]:
#Initializing the model instance
# Alternative kept for reference: build the plain Model instead.
# model = Model(params)
# PretrainedModel is the network that the pretraining cell below trains on
# the CIFAR-10 loaders; it is configured from the shared `params` dict.
model = PretrainedModel(params)

In [None]:
# TensorBoard writer that receives the scalars of the CIFAR-10 pretraining run.
writer = SummaryWriter(log_dir='logs/pretrained_model')

In [None]:
#Loading pretrained model if already trained
# pretrained_params = torch.load('./saved_models/pretrained_model_params', map_location=torch.device(device))
# model.load_state_dict(pretrained_params)

In [None]:
#Pre-training the model using the CIFAR-10 dataset
#
# Trains for at most 99 epochs with Adam and cross-entropy loss. After every
# epoch the summed per-batch training and validation losses are written to
# TensorBoard, the weights are checkpointed whenever the validation loss
# improves, and training stops once the best epoch lies more than 5 epochs
# in the past.
import os

# NOTE(review): the "Matching state dict" cell loads
# "./saved_models/pretrained_model_params" (without "new/") -- confirm which
# location is intended before relying on the checkpoint written here.
pretrained_ckpt_path = "./saved_models/new/pretrained_model_params"
# torch.save does not create missing directories, so make sure it exists
# before an epoch of training is spent.
os.makedirs(os.path.dirname(pretrained_ckpt_path), exist_ok=True)

model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
loss = nn.CrossEntropyLoss()

# Starting from +inf lets the very first epoch count as an improvement, so a
# checkpoint exists even when epoch 1 turns out to be the best one.
best_loss = float("inf")
best_epoch = 0

for epoch in range(1, 100):
    model.train()
    epoch_loss = 0.0
    for inputs, labels in cifer_trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        step_loss = loss(outputs, labels)
        # Accumulate a plain float: summing the loss tensor itself would keep
        # the autograd graph of every batch alive until the epoch ends.
        epoch_loss += step_loss.item()
        step_loss.backward()
        optimizer.step()
    writer.add_scalar('Pretrained_training_loss/Epoch', epoch_loss, epoch)

    # Validation pass: no gradients are needed, so none are tracked or reset.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in cifer_testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            val_loss += loss(outputs, labels).item()
    writer.add_scalar('Pretrained_validation_loss/Epoch', val_loss, epoch)

    #Save the best model
    if val_loss < best_loss:
        best_loss = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), pretrained_ckpt_path)

    #Early stop while overfitting
    if epoch > (best_epoch + 5):
        print("Reached post-best limit, Training ended.")
        print(f'Best loss is {best_loss}')
        break
    print(f'[Epoch {epoch:3d}, training loss: {epoch_loss:.10f}, validation loss: {val_loss:.10f}]')

In [None]:
#Matching state dict
# Keep a handle on the CIFAR-10 network, then build the actual model and
# initialise it from the saved pretraining weights. The "model.7" and
# "model.9" entries are replaced by the freshly initialised values of the new
# model before loading (presumably the FC layers, whose shapes differ from the
# CIFAR-10 classifier -- confirm in src.network).
pretrained_model = model
pretrained_model_params = torch.load(
    "./saved_models/pretrained_model_params",
    map_location=torch.device(device),
)
model = Model(params)
fresh_state = model.state_dict()
for layer_key in ("model.7.weight", "model.7.bias", "model.9.weight", "model.9.bias"):
    pretrained_model_params[layer_key] = fresh_state[layer_key]
model.load_state_dict(pretrained_model_params)

In [None]:
#Loading the dataset from .pt files and dividing into training and validation sets
def _normalize_samples(samples):
    """Normalize the image of every [image, target] sample in place.

    Each image is cast to float, L2-normalized along dim 2 with
    ``F.normalize`` and then rescaled with ``(x - 0.5) * 2``. The target
    entry of the sample is left untouched.
    """
    for sample in samples:
        sample[0] = (F.normalize(sample[0].type(torch.float), dim=2) - 0.5) * 2


# Deserialize the file once and slice it twice instead of loading the same
# file from disk for each split.
full_dataset = torch.load('../data/training_data.pt')
training_dataset_normalized = full_dataset[0:1750]
val_dataset_normalized = full_dataset[1750:2000]
del full_dataset  # the two splits keep the samples they need alive

#Normalizing the datasets manually
_normalize_samples(training_dataset_normalized)
_normalize_samples(val_dataset_normalized)

#Load data into pytorch DataLoader
# NOTE(review): the training loader iterates in a fixed order (no
# shuffle=True) -- confirm that this is intended.
training_dataloader = DataLoader(training_dataset_normalized, batch_size=params["batch_size"], drop_last=True)
val_dataloader = DataLoader(val_dataset_normalized, batch_size=params["batch_size"], drop_last=True)

In [None]:
# TensorBoard writer that receives the scalars of the actual training run.
writer = SummaryWriter(log_dir='logs/cnn_model')

In [None]:
#Training
#
# Trains the model on the cylinder data with Adam and MSE loss for at most
# 2999 epochs. The summed validation loss is written to TensorBoard, the
# weights are checkpointed whenever it improves, and training stops once the
# best epoch lies more than 20 epochs in the past.
import os

#Loading weights from the saved model
# model.load_state_dict(torch.load("saved_models/cnn_model", map_location=torch.device(device)))
model.to(device)

#Defining the optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
loss = torch.nn.MSELoss()

# NOTE(review): the re-load cell reads "saved_models/cnn_model" (without
# "new/") -- confirm which location is intended.
cnn_ckpt_path = "saved_models/new/cnn_model"
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(cnn_ckpt_path), exist_ok=True)

# The training loss is summed over more batches than the validation loss;
# dividing by the ratio of the dataset sizes (1750 / 250 = 7 for the splits
# used in this notebook) makes the two printed numbers comparable.
train_val_ratio = len(training_dataloader.dataset) / len(val_dataloader.dataset)


def _prepare_batch(data):
    """Turn one DataLoader batch into float tensors on `device`.

    Returns:
        img: ``data[0]`` reshaped to (batch, -1, 120, 120).
        pos: the first two entries of ``data[1]`` stacked along dim 1 and
            multiplied by 10.
    """
    # Take the batch size from the tensor itself instead of hard-coding 8, and
    # use device/dtype arguments instead of torch.cuda.FloatTensor so that the
    # CPU fallback selected for `device` actually works.
    batch_size = data[0].shape[0]
    img = data[0].reshape((batch_size, -1, 120, 120)).to(device=device, dtype=torch.float)
    pos = 10 * torch.stack(data[1][0:2], dim=1).to(device=device, dtype=torch.float)
    return img, pos


# Starting from +inf lets the very first epoch count as an improvement, so a
# checkpoint exists even when epoch 1 turns out to be the best one.
best_loss = float("inf")
best_epoch = 0

for epoch in range(1, 3000):
    model.train()
    epoch_loss = 0

    for data in training_dataloader:
        img, pos = _prepare_batch(data)
        optimizer.zero_grad()
        output = model(img)
        step_loss = loss(output, pos)
        epoch_loss += step_loss.item()
        step_loss.backward()
        optimizer.step()

    # Validation pass: no gradients are needed, so none are tracked or reset.
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_dataloader:
            img, pos = _prepare_batch(data)
            output = model(img)
            val_loss += loss(output, pos).item()

    writer.add_scalar("Epoch_validation_loss/train", val_loss, epoch)

    #Save the best model
    if val_loss < best_loss:
        best_loss = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), cnn_ckpt_path)

    #Early stop while overfitting
    if epoch > (best_epoch + 20):
        print("Reached post-best limit, Training ended.")
        print(f'Best loss is {best_loss}')
        break

    print(f'[Epoch {epoch:4d}, training loss: {epoch_loss/train_val_ratio:.10f}, validation loss: {val_loss:.10f}]')

In [None]:
#if trained, re-load the model here.
model = Model(params)
model.to(device)
model.load_state_dict(torch.load("saved_models/cnn_model", map_location=torch.device(device)))
# The re-loaded model is only used for inference from here on; the training
# cell switches back with model.train() at the start of every epoch.
model.eval()

#datasets used for visualization
# Deserialize the file once and slice it twice instead of loading the same
# file from disk for each split.
visualization_data = torch.load('../data/training_data.pt', map_location=torch.device(device))
training_dataset = visualization_data[0:1750]
val_dataset = visualization_data[1750:2000]
del visualization_data  # the two splits keep the samples they need alive

In [None]:
#Helper functions to retrieve pixel-wise coordinates
def x_pixel(cylinder_y):
    """Map the cylinder's y coordinate to a horizontal pixel position.

    The mapping is linear and mirrored: 0.48 maps to pixel 0 and -0.48 maps
    to pixel 120 (up to floating-point rounding).
    """
    offset_from_edge = 0.48 - cylinder_y
    extent = 0.48 + 0.48
    return offset_from_edge * 120 / extent

def y_pixel(cylinder_x):
    """Map the cylinder's x coordinate to a vertical pixel position.

    Uses the same mirrored linear mapping as ``x_pixel``: 0.48 maps to
    pixel 0 and -0.48 maps to pixel 120 (up to floating-point rounding).
    """
    offset_from_edge = 0.48 - cylinder_x
    extent = 0.48 + 0.48
    return offset_from_edge * 120 / extent

#Visualizing the results
def plot_cylinder(model):
    """Show the first 8 validation images with the predicted position marked.

    Reads the module-level ``val_dataset`` and ``device``. Each image is
    normalized the same way as the training data, passed through ``model``
    as a (1, -1, 120, 120) float tensor, and the prediction (divided by 10
    to undo the target scaling used in training) is converted to pixel
    coordinates and drawn as a red circle on the raw image in a 2 x 4 grid.

    Args:
        model: the trained network; it is switched to eval mode for the
            duration of the call and restored to its previous mode afterwards.
    """
    fig = plt.figure(figsize=(80, 40))
    was_training = model.training
    model.eval()
    try:
        # Pure inference: no autograd bookkeeping is needed.
        with torch.no_grad():
            # zip with range(8) stops after the first 8 samples.
            for idx, data in zip(range(8), val_dataset):
                img = data[0]
                # Tensor.numpy() only works for CPU tensors, and the dataset
                # may have been loaded directly onto the GPU.
                img_np = img.cpu().numpy()
                img = (F.normalize(img.type(torch.float), dim=2) - 0.5) * 2
                # Tensor.to() returns a new tensor; the result must be kept.
                img = img.to(device)
                output = model(img.reshape((1, -1, 120, 120)).type(torch.float)).cpu() / 10
                output_np = output.numpy()
                pixel = (
                    int(np.around(x_pixel(output_np[0][1]))),
                    int(np.around(y_pixel(output_np[0][0]))),
                )
                fig.add_subplot(2, 4, idx + 1)
                plt.imshow(img_np)
                plt.scatter(pixel[0], pixel[1], linewidths=10, s=4000, facecolors='none', edgecolors='red')
    finally:
        model.train(was_training)
    plt.show()

In [None]:
# Draw the model's predicted cylinder positions on the first validation images.
plot_cylinder(model)