In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# List every file available under the read-only Kaggle input directory.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(root, name))

# You can write up to 5GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch 
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import matplotlib.pylab as plt
import numpy as np

In [None]:
def show_data(data_sample, image_size=None):
    """Display one ``(image, label)`` dataset sample as a grayscale image.

    Parameters
    ----------
    data_sample : sequence
        ``(image, label)`` as returned by indexing a torchvision dataset. The
        image tensor must hold ``image_size ** 2`` elements.
    image_size : int, optional
        Side length of the square image. Defaults to the notebook-level
        ``IMAGE_SIZE`` constant.
    """
    image, label = data_sample[0], data_sample[1]
    side = IMAGE_SIZE if image_size is None else image_size
    plt.imshow(image.numpy().reshape(side, side), cmap='gray')
    # torchvision's (Fashion)MNIST returns the label as a plain Python int,
    # which has no .item(); int() accepts both an int and a 0-d tensor.
    plt.title('y = ' + str(int(label)))
    

In [None]:
def plot_channels(W):
    """Plot every 2-D kernel slice of a 4-D convolution weight.

    Output channels (``W.shape[0]``) are drawn as rows and input channels
    (``W.shape[1]``) as columns. All panels share one colour scale so kernels
    can be compared directly.

    Parameters
    ----------
    W : torch.Tensor or array-like
        Weight indexed as ``W[out, in, :, :]``, e.g. ``conv.weight`` or a
        ``state_dict`` entry.
    """
    if torch.is_tensor(W):
        # Parameters require grad and may live on the GPU; matplotlib needs a
        # plain CPU tensor it can turn into a NumPy array.
        W = W.detach().cpu()
    n_out, n_in = W.shape[0], W.shape[1]
    w_min = W.min().item()
    w_max = W.max().item()
    # squeeze=False keeps `axes` 2-D even when n_out or n_in is 1.
    fig, axes = plt.subplots(n_out, n_in, squeeze=False)
    fig.subplots_adjust(hspace=0.1)

    # plot outputs as rows, inputs as columns
    for out_index in range(n_out):
        for in_index in range(n_in):
            ax = axes[out_index, in_index]
            ax.imshow(W[out_index, in_index, :, :], vmin=w_min, vmax=w_max, cmap='seismic')
            ax.set_yticklabels([])
            ax.set_xticklabels([])

    plt.show()

In [None]:
def plot_parameters(W, number_rows=1, name="", i=0):
    """Plot all kernels of a conv weight that read from input channel ``i``.

    Parameters
    ----------
    W : torch.Tensor
        Weight indexed as ``W[out, in, :, :]``, e.g. ``model.cnn1.weight``.
    number_rows : int
        Number of subplot rows. The column count is chosen so that every
        kernel gets a panel.
    name : str
        Figure title.
    i : int
        Index of the input channel whose kernels are shown.
    """
    kernels = W.data[:, i, :, :].cpu()
    n_filters = kernels.shape[0]
    w_min = kernels.min().item()
    w_max = kernels.max().item()
    # Ceiling division so all kernels fit even when n_filters is not a
    # multiple of number_rows (or is smaller than it).
    number_cols = -(-n_filters // number_rows)
    # squeeze=False keeps `axes` an array even for a 1x1 grid.
    fig, axes = plt.subplots(number_rows, number_cols, squeeze=False)
    fig.subplots_adjust(hspace=0.4)

    for k, ax in enumerate(axes.flat):
        if k < n_filters:
            # Set the label for the sub-plot.
            ax.set_xlabel("kernel:{0}".format(k + 1))

            # Plot the image.
            ax.imshow(kernels[k, :], vmin=w_min, vmax=w_max, cmap='seismic')
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            # Hide the leftover cells of a partially filled grid.
            ax.axis('off')
    plt.suptitle(name, fontsize=10)
    plt.show()

In [None]:

# Side length (pixels) every image is resized to. The CNN below assumes 16
# by default (two 2x2 poolings leave a 4x4 feature map).
IMAGE_SIZE = 16

# Resize each image to IMAGE_SIZE x IMAGE_SIZE, then convert it to a tensor.
composed = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

In [None]:
# Training split, downloaded into ./projectdata on first use.
train_dataset = dsets.FashionMNIST(root='./projectdata', train=True, download=True, transform=composed)

In [None]:
# Held-out test split, used for per-epoch validation accuracy.
validation_dataset = dsets.FashionMNIST(root='./projectdata', train=False, transform=composed, download=True)

In [None]:
# Preview the first training sample and its label.
show_data(data_sample=train_dataset[0])

In [None]:
# Preview the second training sample and its label.
show_data(data_sample=train_dataset[1])

In [None]:
class CNN(nn.Module):
    """Two-block convolutional classifier with batch normalisation.

    Each block is conv(5x5, padding=2) -> BatchNorm2d -> ReLU -> 2x2 max-pool
    -> dropout. A single linear layer followed by BatchNorm1d then produces
    one logit per class.

    Parameters
    ----------
    out_1, out_2 : int
        Channel counts of the first and second convolution.
    p : float
        Dropout probability applied after each block.
    image_size : int
        Side length of the square single-channel input (at least 4).
        Defaults to 16, matching this notebook's ``IMAGE_SIZE``.
    number_of_classes : int
        Number of output logits.
    """

    def __init__(self, out_1=16, out_2=32, p=0.25, image_size=16, number_of_classes=10):
        super(CNN, self).__init__()
        self.drop1 = nn.Dropout(p=p)
        self.drop2 = nn.Dropout(p=p)

        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=out_1, kernel_size=5, padding=2)
        self.conv1_bn = nn.BatchNorm2d(out_1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.cnn2 = nn.Conv2d(in_channels=out_1, out_channels=out_2, kernel_size=5, stride=1, padding=2)
        self.conv2_bn = nn.BatchNorm2d(out_2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        # padding=2 keeps the spatial size through each 5x5 conv and each pool
        # floor-halves it, so the flattened map is out_2 * (image_size // 4)^2.
        pooled_size = image_size // 4
        self.fc1 = nn.Linear(out_2 * pooled_size * pooled_size, number_of_classes)
        self.fc_bn = nn.BatchNorm1d(number_of_classes)

    def forward(self, x):
        """Return class logits for ``x`` of shape ``(batch, 1, H, W)``.

        The trailing BatchNorm1d means a batch of size 1 only works in eval
        mode.
        """
        # Block 1: conv -> BN -> ReLU -> pool -> dropout
        x = self.cnn1(x)
        x = self.conv1_bn(x)
        x = torch.relu(x)
        x = self.maxpool1(x)
        x = self.drop1(x)
        # Block 2: conv -> BN -> ReLU -> pool -> dropout
        x = self.cnn2(x)
        x = self.conv2_bn(x)
        x = torch.relu(x)
        x = self.maxpool2(x)
        x = self.drop2(x)
        # Classifier head
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc_bn(x)
        return x

In [None]:
def train_model(model, train_loader, validation_loader, optimizer, n_epochs=4, criterion=None):
    """Train ``model`` and measure validation accuracy after every epoch.

    Parameters
    ----------
    model : torch.nn.Module
        Network mapping an input batch to class logits.
    train_loader, validation_loader : iterable
        Yield ``(inputs, targets)`` batches, e.g. ``torch.utils.data.DataLoader``.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    n_epochs : int
        Number of passes over ``train_loader``.
    criterion : callable, optional
        Loss ``criterion(logits, targets)``. Defaults to ``nn.CrossEntropyLoss()``.

    Returns
    -------
    tuple(list, list)
        ``(accuracy_list, loss_list)``: per-epoch validation accuracy, and the
        per-epoch *sum* of the training batch losses.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    accuracy_list = []
    loss_list = []

    for epoch in range(n_epochs):
        model.train()
        temp_loss = 0.0
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            temp_loss += loss.item()
        loss_list.append(temp_loss)

        # Perform a prediction on the validation data. no_grad avoids building
        # an autograd graph for the (large) validation batches.
        model.eval()
        correct = 0
        n_test = 0
        with torch.no_grad():
            for x_test, y_test in validation_loader:
                yhat = model(x_test).argmax(dim=1)
                correct += (yhat == y_test).sum().item()
                n_test += y_test.size(0)
        accuracy = correct / n_test if n_test else float('nan')
        accuracy_list.append(accuracy)

        print(
                'Train Epoch: ', epoch, ' \n loss= ', temp_loss, ' accuracy= ', accuracy, '\n')

    return accuracy_list, loss_list

In [None]:
# Create the model and the cross-run history only if they do not exist yet.
# On a fresh kernel this defines everything later cells use; re-running the
# training cells afterwards keeps training the same model and growing the same
# history.
if 'model' not in globals():
    torch.manual_seed(0)  # reproducible weight init, dropout and shuffling
    model = CNN(out_1=16, out_2=32)
if 'absolute_accuracy_list' not in globals():
    absolute_accuracy_list = []
if 'absolute_loss_list' not in globals():
    absolute_loss_list = []

In [None]:
criterion = nn.CrossEntropyLoss()
learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Reshuffle the training set every epoch so SGD does not see the same batch
# sequence each time; validation order does not affect accuracy.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=5000)

In [None]:
# Train for 40 epochs; returns per-epoch validation accuracy and summed loss.
accuracy_list_normal, loss_list_normal = train_model(
    model=model,
    train_loader=train_loader,
    validation_loader=validation_loader,
    optimizer=optimizer,
    n_epochs=40,
)

In [None]:
# Loss (left axis) and validation accuracy (right axis) for the run just
# trained. The x-range comes from this run's own history, not the cumulative one.
xlim = len(loss_list_normal)
fig, ax1 = plt.subplots()
ax1.set_title('Current training run')
ax1.set_xlim([0, max(xlim, 1)])  # avoid identical limits for an empty run
ax1.set_ylim(bottom=0)  # let the top autoscale so no loss values are clipped
color = 'tab:red'
ax1.plot(loss_list_normal, color=color)
ax1.set_xlabel('epoch', color=color)
ax1.set_ylabel('Cost', color=color)
ax1.tick_params(axis='y', color=color, labelcolor=color)

ax2 = ax1.twinx()
color = 'tab:blue'
ax2.set_ylabel('accuracy', color=color)
ax2.plot(accuracy_list_normal, color=color)
ax2.tick_params(axis='y', color=color, labelcolor=color)

In [None]:
# Persist the weights. torch.save writes PyTorch's own (pickle-based) format;
# despite the ".h5" suffix the file is not HDF5.
torch.save(model.state_dict(), '/kaggle/working/model.h5')

# Append this run to the cumulative history exactly once, so re-running this
# cell does not duplicate the run. Each train_model call returns new list
# objects, so an identity check identifies the run.
if globals().get('_last_merged_run') is not accuracy_list_normal:
    absolute_accuracy_list = absolute_accuracy_list + accuracy_list_normal
    absolute_loss_list = absolute_loss_list + loss_list_normal
    _last_merged_run = accuracy_list_normal

In [None]:
# Per-epoch summed training loss from the most recent train_model run.
loss_list_normal

In [None]:
# Number of epochs recorded so far in the cumulative accuracy history.
len(absolute_accuracy_list)

In [None]:
# Number of epochs recorded so far in the cumulative loss history.
len(absolute_loss_list)

In [None]:
# Cumulative loss (left axis) and validation accuracy (right axis) across
# every training run appended to the absolute_* history lists.
xlim = len(absolute_loss_list)
fig, ax1 = plt.subplots()
ax1.set_title('All training runs')
ax1.set_xlim([0, max(xlim, 1)])  # avoid identical limits for an empty history
ax1.set_ylim(bottom=0)  # let the top autoscale so no loss values are clipped
color = 'tab:red'
ax1.plot(absolute_loss_list, color=color)
ax1.set_xlabel('epoch', color=color)
ax1.set_ylabel('Cost', color=color)
ax1.tick_params(axis='y', color=color, labelcolor=color)

ax2 = ax1.twinx()
color = 'tab:blue'
ax2.set_ylabel('accuracy', color=color)
ax2.plot(absolute_accuracy_list, color=color)
ax2.tick_params(axis='y', color=color, labelcolor=color)

