In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
from capsnet import CapsNet
from data_loader import Dataset
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
trainDir = "/home/trojan/Desktop/dimentia/dataset/data_2categ/data_PGGAN/train"
valDir = "/home/trojan/Desktop/dimentia/dataset/data_2categ/data_PGGAN/validation"

USE_CUDA = True if torch.cuda.is_available() else False
BATCH_SIZE = 16
N_EPOCHS = 30
LEARNING_RATE = 0.01
MOMENTUM = 0.9

In [3]:
class Config:
    def __init__(self, dataset='dementia'):
        if dataset == 'mnist':
            # CNN (cnn)
            self.cnn_in_channels = 1
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 9

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 9
            self.pc_num_routes = 32 * 6 * 6

            # Digit Capsule (dc)
            self.dc_num_capsules = 10
            self.dc_num_routes = 32 * 6 * 6
            self.dc_in_channels = 8
            self.dc_out_channels = 16

            # Decoder
            self.input_width = 28
            self.input_height = 28

        elif dataset == 'cifar10':
            # CNN (cnn)
            self.cnn_in_channels = 3
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 9

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 9
            self.pc_num_routes = 32 * 8 * 8

            # Digit Capsule (dc)
            self.dc_num_capsules = 10
            self.dc_num_routes = 32 * 8 * 8
            self.dc_in_channels = 8
            self.dc_out_channels = 16

            # Decoder
            self.input_width = 32
            self.input_height = 32

        elif dataset == 'dementia':
            # CNN (cnn)
            self.cnn_in_channels = 3
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 9

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 9
            self.pc_num_routes = 32 * 24 * 24

            # Digit Capsule (dc)
            self.dc_num_capsules = 2
            self.dc_num_routes = 32 * 24 * 24
            self.dc_in_channels = 8
            self.dc_out_channels = 16

            # Decoder
            self.input_width = 64
            self.input_height = 64

In [4]:
def train(model, optimizer, train_loader, epoch):
    capsule_net = model
    capsule_net.train()
    n_batch = len(list(enumerate(train_loader)))
    total_loss = 0
    TRAIN_LOSS = []
    for batch_id, (data, target) in enumerate(tqdm(train_loader)):

        target = torch.sparse.torch.eye(2).index_select(dim=0, index=target)
        data, target = Variable(data), Variable(target)

        if USE_CUDA:
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)
        loss.backward()
        optimizer.step()
        correct = sum(np.argmax(masked.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1))
        train_loss = loss.item()
        total_loss += train_loss
        if batch_id % 100 == 0:
            tqdm.write("Epoch: [{}/{}], Batch: [{}/{}], train accuracy: {:.6f}, loss: {:.6f}".format(
                epoch,
                N_EPOCHS,
                batch_id + 1,
                n_batch,
                correct / float(BATCH_SIZE),
                train_loss / float(BATCH_SIZE)
                ))
            
    
    tqdm.write('Epoch: [{}/{}], train loss: {:.6f}'.format(epoch,N_EPOCHS,total_loss / len(train_loader.dataset)))
    TRAIN_LOSS.append(total_loss / len(train_loader.dataset))
    
def validation(capsule_net, val_loader, epoch):
    capsule_net.eval()
    val_loss = 0
    correct = 0
    best_acc = 0
    VAL_LOSS = []
    for batch_id, (data, target) in enumerate(val_loader):

        target = torch.sparse.torch.eye(2).index_select(dim=0, index=target)
        data, target = Variable(data), Variable(target)

        if USE_CUDA:
            data, target = data.cuda(), target.cuda()

        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)

        val_loss += loss.item()
        correct += sum(np.argmax(masked.data.cpu().numpy(), 1) ==
                       np.argmax(target.data.cpu().numpy(), 1))

    val_acc = correct / len(val_loader.dataset)
    tqdm.write(
        "Epoch: [{}/{}], validation accuracy: {:.6f}, loss: {:.6f}".format(epoch, N_EPOCHS, val_acc,
                                                                  val_loss / len(val_loader)))
    VAL_LOSS.append(val_loss / len(val_loader))
    
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(capsule_net.state_dict(), './models/model.pth')

In [5]:
if __name__ == '__main__':
    torch.manual_seed(1)
    dataset = 'dementia'
    # dataset = 'mnist'
    config = Config(dataset)
    dementia = Dataset(dataset, BATCH_SIZE, trainDir, valDir)

    capsule_net = CapsNet(config)
    capsule_net = torch.nn.DataParallel(capsule_net)
    if USE_CUDA:
        capsule_net = capsule_net.cuda()
    capsule_net = capsule_net.module

    optimizer = torch.optim.Adam(capsule_net.parameters())

    for e in range(1, N_EPOCHS + 1):
        train(capsule_net, optimizer, dementia.train_loader, e)
        validation(capsule_net, dementia.val_loader, e)

  0%|          | 1/609 [00:01<13:05,  1.29s/it]

Epoch: [1/30], Batch: [1/609], train accuracy: 0.687500, loss: 0.056298


 17%|█▋        | 101/609 [01:02<05:09,  1.64it/s]

Epoch: [1/30], Batch: [101/609], train accuracy: 0.500000, loss: 0.034158


 33%|███▎      | 201/609 [02:03<04:09,  1.64it/s]

Epoch: [1/30], Batch: [201/609], train accuracy: 0.562500, loss: 0.025094


 49%|████▉     | 301/609 [03:04<03:07,  1.64it/s]

Epoch: [1/30], Batch: [301/609], train accuracy: 0.500000, loss: 0.024990


 66%|██████▌   | 401/609 [04:06<02:07,  1.63it/s]

Epoch: [1/30], Batch: [401/609], train accuracy: 0.687500, loss: 0.025027


 82%|████████▏ | 501/609 [05:07<01:06,  1.63it/s]

Epoch: [1/30], Batch: [501/609], train accuracy: 0.687500, loss: 0.025020


 99%|█████████▊| 601/609 [06:09<00:04,  1.64it/s]

Epoch: [1/30], Batch: [601/609], train accuracy: 0.687500, loss: 0.025201


100%|██████████| 609/609 [06:14<00:00,  1.63it/s]


Epoch: [1/30], train loss: 0.029304
Epoch: [1/30], validation accuracy: 0.530435, loss: 0.405311


  0%|          | 1/609 [00:00<06:25,  1.58it/s]

Epoch: [2/30], Batch: [1/609], train accuracy: 0.562500, loss: 0.025063


  2%|▏         | 13/609 [00:08<06:14,  1.59it/s]


KeyboardInterrupt: 

In [None]:
plt.figure(figsize=(10,5))
plt.title("Training and Validation Loss")
plt.plot(TRAIN_LOSS,label="train")
plt.plot(VAL_LOSS,label="validation")
plt.xlabel("epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()