In [1]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torch.optim import Adam
import glob

In [2]:
def data_processing():
    transformar = transforms.Compose([
        transforms.Resize((150, 150)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5]
        ),
    ])

    train_loader = DataLoader(
        torchvision.datasets.ImageFolder(train_path, transform = transformar),
        batch_size = 256,
        shuffle = True
    )
    test_loader = DataLoader(
        torchvision.datasets.ImageFolder(test_path, transform = transformar),
        batch_size = 256,
        shuffle = True
    )
    return train_loader, test_loader


In [3]:
class ConvNet(nn.Module):
    def __init__(self, num_class = 6):
        super(ConvNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels = 3, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features = 12)
        self.relu1 = nn.ReLU()

        self.pool = nn.MaxPool2d(kernel_size=2)

        self.conv2 = nn.Conv2d(in_channels = 12, out_channels=20, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(in_channels = 20, out_channels = 32, kernel_size=3, stride=1, padding=1)
        self.batch_norm3 = nn.BatchNorm2d(num_features = 32)
        self.relu3 = nn.ReLU()

        self.fc = nn.Linear(in_features=32*75*75, out_features=num_class)

    
    def forward(self, input):
        output = self.conv1(input)
        output = self.bn1(output)
        output = self.relu1(output)

        output = self.pool(output)

        output = self.conv2(output)
        output = self.relu2(output)

        output = self.conv3(output)
        output = self.batch_norm3(output)
        output = self.relu3(output)

        output = output.view(-1, 32*75*75)
        return self.fc(output)


In [101]:
def train():
    train_accuracy = 0.0
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs  = model(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        # train_loss += loss.cpu().data * images.size(0)
        _, prediction = torch.max(outputs.data, 1)
        train_accuracy += int(torch.sum(prediction == labels.data))

    return train_accuracy


In [102]:
def test():
    test_accuracy = 0.0
    for i, (images, labels) in enumerate(test_loader):
        outputs = model(images)
        _, prediction = torch.max(outputs.data, 1) # get catagory id
        test_accuracy += int(torch.sum(prediction == labels.data))

    return test_accuracy


In [103]:
total_epoch = 15
best_accuracy = 0.0
train_path = "./data/seg_train"
test_path = "./data/seg_test"
train_count = len(glob.glob(train_path + '/**/*.jpg'))
test_count = len(glob.glob(test_path + '/**/*.jpg'))
model = ConvNet(num_class = 6)
train_loader, test_loader = data_processing()
loss_function = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

for i in range(total_epoch):
    train_accuracy = train()
    model.eval()
    test_accuracy = test()
    # train_accuracy = train_accuracy / train_count
    test_accuracy = test_accuracy / test_count

    print('num of epoch = ', i, ' test_accuracy = ', test_accuracy)

    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), 'best_checkpoint.model')
        best_accuracy = test_accuracy


num of epoch =  0  test_accuracy =  0.6766666666666666
num of epoch =  1  test_accuracy =  0.7416666666666667
num of epoch =  2  test_accuracy =  0.777
num of epoch =  3  test_accuracy =  0.769
num of epoch =  4  test_accuracy =  0.758
num of epoch =  5  test_accuracy =  0.7516666666666667
num of epoch =  6  test_accuracy =  0.759
num of epoch =  7  test_accuracy =  0.7663333333333333
num of epoch =  8  test_accuracy =  0.772
num of epoch =  9  test_accuracy =  0.771
num of epoch =  10  test_accuracy =  0.7656666666666667
num of epoch =  11  test_accuracy =  0.757
num of epoch =  12  test_accuracy =  0.757
num of epoch =  13  test_accuracy =  0.7596666666666667
num of epoch =  14  test_accuracy =  0.765


In [None]:
# pred
