#### File: TrainNetwork.ipynb
- This notebook trains a convolutional neural network on the images in the training set.
- The main function `trainNetwork()` is at the bottom. **Make sure every code cell above `trainNetwork()` has been run before calling the main function.**
- Simply run the cells in order, from top to bottom.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torchvision 
from torchvision.datasets import ImageFolder
import torch.optim as optim
import os
from xlwt import Workbook

In [None]:
class Network(nn.Module):
    """Small LeNet-style CNN classifier.

    Two stages of 5x5 convolution + ReLU + 2x2 max-pooling are followed by
    three fully connected layers. The first fully connected layer expects
    16 * 5 * 5 = 400 features, which is what a 3-channel 32x32 image
    produces; any other spatial size raises a shape error in ``fc`` instead
    of being silently reshaped.

    Args:
        num_classes: number of output logits (stored as ``output_size``).
            Defaults to 12, the value this notebook originally hard-coded.
    """

    def __init__(self, num_classes=12):
        # nn.Module's bookkeeping must be set up before any attribute is
        # assigned on the instance.
        super(Network, self).__init__()
        self.output_size = num_classes

        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(16 * 5 * 5, 120)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, self.output_size)

    def forward(self, x):
        """Return raw class logits of shape (batch, output_size).

        ``x`` is a (batch, 3, H, W) tensor, or a single (3, H, W) image that
        is treated as a batch of one. No softmax is applied; the training
        loss (nn.CrossEntropyLoss) applies log-softmax itself.
        """
        if x.dim() == 3:
            # Single unbatched image -> batch of one (matches the old
            # view(-1, 400) result shape of (1, output_size)).
            x = x.unsqueeze(0)
        x = self.pool(functional.relu(self.conv1(x)))
        x = self.pool(functional.relu(self.conv2(x)))
        # Keep the batch dimension intact. view(-1, 400) could regroup
        # features across samples when the input size is wrong.
        x = torch.flatten(x, 1)
        x = functional.relu(self.fc(x))
        x = functional.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
def initExcelTable_train():
    """Create an xlwt workbook for per-epoch training results.

    The single sheet, 'Sheet 1', gets a header row with 'epoch' in column 0
    and 'loss' in column 1.

    Returns:
        (workbook, sheet) so callers can append rows and save the workbook.
    """
    workbook = Workbook()
    sheet = workbook.add_sheet('Sheet 1')
    for column, title in enumerate(('epoch', 'loss')):
        sheet.write(0, column, title)
    return workbook, sheet

def getTrainLoader(batchsize, trainsetdir, image_size=None):
    """Build a shuffling DataLoader over an ImageFolder training set.

    Each image is optionally resized, then converted to a tensor. Finally it
    is normalised per channel with mean 0.5 and std 0.5, which maps
    ToTensor's [0, 1] range to [-1, 1].

    Args:
        batchsize: mini-batch size.
        trainsetdir: root directory of the ImageFolder dataset.
        image_size: optional size for ``transforms.Resize``. An int resizes
            the shorter edge; a (h, w) pair resizes exactly, e.g. (32, 32)
            for ``Network``. ``None`` (the default) keeps the stored size.

    Returns:
        (trainloader, number_of_images_in_trainset)
    """
    steps = []
    if image_size is not None:
        steps.append(torchvision.transforms.Resize(image_size))
    steps += [
        torchvision.transforms.ToTensor(),
        # One value per RGB channel (ImageFolder's default loader yields RGB).
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    custom_transform = torchvision.transforms.Compose(steps)

    train = ImageFolder(root=trainsetdir, transform=custom_transform)
    trainloader = torch.utils.data.DataLoader(
        train,
        batch_size=batchsize,
        shuffle=True,
    )
    # the second return value is the number of images in the trainset
    return trainloader, len(train)

In [None]:
# nn.Conv2d(ni, no, f, s): ni = number of input channels, no = number of output channels,
# f = convolution kernel size (5 in this network), s = stride (default 1). A stride
# greater than 1 skips positions, so the kernel is not applied at every pixel.
# nn.CrossEntropyLoss() / F.cross_entropy(logits, y) applies log-softmax followed by the
# negative log-likelihood internally, so the network must output raw logits:
# do NOT apply softmax yourself in forward().

def trainNetwork(networkname, batchsize, epochs, learning_rate = 1e-2, lossfnresult = True, consoledebug = False,
                trainsetdir = './trainset', resultfile = 'train_result.xls'):
    """Train a fresh ``Network`` with plain SGD and save its weights.

    Args:
        networkname: file name for the saved ``state_dict``. A relative name
            is resolved against the current directory; an absolute path is
            used as-is.
        batchsize: mini-batch size passed to ``getTrainLoader``.
        epochs: number of passes over the training set.
        learning_rate: SGD learning rate.
        lossfnresult: if True, write one row per epoch to an .xls workbook.
            The row holds the epoch number and the sum of that epoch's
            per-mini-batch losses.
        consoledebug: if True, print the average loss of every 20 mini-batches.
        trainsetdir: ImageFolder root directory of the training set.
        resultfile: path of the .xls loss workbook (only written when
            ``lossfnresult`` is True); an existing file is overwritten.
    """
    wb, sheet1 = initExcelTable_train()
    net = Network()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=learning_rate)
    trainloader, _ = getTrainLoader(batchsize=batchsize, trainsetdir=trainsetdir)

    net.train()
    for epoch in range(epochs):
        window_loss = 0.0  # loss summed over the current 20-batch console window
        epoch_loss = 0.0   # loss summed over every mini-batch of this epoch
        for i, (inputs, labels) in enumerate(trainloader):
            optimizer.zero_grad()
            logits = net(inputs)
            loss = loss_function(logits, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # read the scalar once, reuse it twice
            window_loss += batch_loss
            epoch_loss += batch_loss
            if consoledebug and i % 20 == 19:    # print every 20 mini-batches, for debug
                print('Epoch: %d, after mini-batch: %5d loss: %.6f' % (epoch + 1, i + 1, window_loss / 20))
                window_loss = 0.0
        if lossfnresult:
            sheet1.write(epoch + 1, 0, epoch + 1)
            sheet1.write(epoch + 1, 1, epoch_loss)

    # os.path.join (unlike './' + name) keeps an absolute path absolute.
    torch.save(net.state_dict(), os.path.join('.', networkname))

    if lossfnresult:
        # xlwt's Workbook.save() overwrites an existing file, so no separate
        # delete-then-save step is needed.
        wb.save(resultfile)

    print('Finished training.')

#### Main function: trainNetwork()
- parameters:
  - networkname: name of the file in which the trained network's parameters (state_dict) are saved
  - batchsize: mini-batch size (required)
  - epochs: number of passes over the training set (required)
  - learning_rate (optional, default is 0.01): SGD learning rate
  - lossfnresult (optional, default is True): set it to True to record the summed loss of each epoch in an Excel file (train_result.xls)
  - consoledebug (optional, default is False): set it to True to print the average loss every 20 mini-batches; this may print a lot of lines
  - trainsetdir (optional, default is './trainset'): set it to your own dataset directory if necessary.

In [None]:
# Train for 30 epochs on mini-batches of 30 images, printing progress, and
# save the learned weights to result.pth.
trainNetwork(
    networkname="result.pth",
    batchsize=30,
    epochs=30,
    consoledebug=True,
)