In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
from capsnet import CapsNet
from data_loader import Dataset
from tqdm import tqdm


In [2]:
# --- Data locations ---------------------------------------------------------
# Absolute local paths to the training and validation image folders.
trainDir = "/home/trojan/Desktop/dimentia/dataset/data_2categ/data_PGGAN/train"
valDir = "/home/trojan/Desktop/dimentia/dataset/data_2categ/data_PGGAN/validation"

# --- Training hyper-parameters ----------------------------------------------
BATCH_SIZE = 8
N_EPOCHS = 30
# NOTE(review): the driver cell builds Adam with its default arguments, so
# these two values are not referenced by the visible code.
LEARNING_RATE = 0.01
MOMENTUM = 0.9

# --- Device selection -------------------------------------------------------
# Use the GPU whenever PyTorch reports one as available.
USE_CUDA = torch.cuda.is_available()

In [3]:
class Config:
    """Hyper-parameter bundle for one of the supported datasets.

    An instance is a plain attribute container; in this notebook it is passed
    to ``CapsNet(config)``.  How each field is interpreted is defined by
    ``capsnet.CapsNet`` and is not visible here.

    Args:
        dataset: One of ``'mnist'``, ``'cifar10'`` or ``'dementia'``
            (the default).

    Raises:
        ValueError: If ``dataset`` is not a supported name.  Previously an
            unknown name produced an object with no attributes at all, which
            only failed later with an unrelated ``AttributeError``.
    """

    # Only these values differ between datasets; every other field is shared.
    # ``num_routes`` is used for both ``pc_num_routes`` and ``dc_num_routes``.
    _PRESETS = {
        'mnist': {
            'cnn_in_channels': 1,
            'num_routes': 32 * 6 * 6,
            'dc_num_capsules': 10,
            'input_width': 28,
            'input_height': 28,
        },
        'cifar10': {
            'cnn_in_channels': 3,
            'num_routes': 32 * 8 * 8,
            'dc_num_capsules': 10,
            'input_width': 32,
            'input_height': 32,
        },
        'dementia': {
            'cnn_in_channels': 3,
            'num_routes': 32 * 6 * 6,
            'dc_num_capsules': 2,
            'input_width': 28,
            'input_height': 28,
        },
    }

    def __init__(self, dataset='dementia'):
        if dataset not in self._PRESETS:
            raise ValueError(
                "Unknown dataset {!r}; expected one of {}".format(
                    dataset, sorted(self._PRESETS)))
        preset = self._PRESETS[dataset]

        # CNN (cnn)
        self.cnn_in_channels = preset['cnn_in_channels']
        self.cnn_out_channels = 256
        self.cnn_kernel_size = 9

        # Primary Capsule (pc)
        self.pc_num_capsules = 8
        self.pc_in_channels = 256
        self.pc_out_channels = 32
        self.pc_kernel_size = 9
        self.pc_num_routes = preset['num_routes']

        # Digit Capsule (dc)
        self.dc_num_capsules = preset['dc_num_capsules']
        self.dc_num_routes = preset['num_routes']
        self.dc_in_channels = 8
        self.dc_out_channels = 16

        # Decoder
        self.input_width = preset['input_width']
        self.input_height = preset['input_height']

In [4]:
def train(model, optimizer, train_loader, epoch):
    """Run one training epoch and log progress through ``tqdm.write``.

    Args:
        model: Network to train.  Calling it on a batch must return a
            ``(output, reconstructions, masked)`` triple, and it must provide
            ``loss(data, output, target, reconstructions)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        train_loader: Iterable of ``(data, target)`` batches that supports
            ``len()`` and has a ``dataset`` attribute supporting ``len()``.
            ``target`` must hold integer class indices (it is used as an
            index tensor); labels are one-hot encoded with two classes.
        epoch: Epoch number, used only in log messages.

    Returns:
        None.  Reads the module-level ``USE_CUDA`` and ``N_EPOCHS``.
    """
    model.train()
    # The loader reports its own batch count; materialising every batch with
    # list(enumerate(...)) just to count them loaded the whole dataset an
    # extra time per epoch.
    n_batch = len(train_loader)
    total_loss = 0
    for batch_id, (data, target) in enumerate(tqdm(train_loader)):
        # One-hot encode the class indices: row i of the 2x2 identity matrix
        # is the encoding of class i.
        target = torch.eye(2).index_select(dim=0, index=target)

        if USE_CUDA:
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        output, reconstructions, masked = model(data)
        loss = model.loss(data, output, target, reconstructions)
        loss.backward()
        optimizer.step()

        train_loss = loss.item()
        total_loss += train_loss

        if batch_id % 100 == 0:
            # Normalise by the number of samples actually in this batch
            # (one one-hot row per label) so a short final batch is not
            # under-reported, and only pay for the device-to-host copy on
            # batches that are logged.
            batch_size = target.size(0)
            predicted = np.argmax(masked.detach().cpu().numpy(), 1)
            expected = np.argmax(target.cpu().numpy(), 1)
            correct = int((predicted == expected).sum())
            tqdm.write("Epoch: [{}/{}], Batch: [{}/{}], train accuracy: {:.6f}, loss: {:.6f}".format(
                epoch,
                N_EPOCHS,
                batch_id + 1,
                n_batch,
                correct / float(batch_size),
                train_loss / float(batch_size)
                ))
    tqdm.write('Epoch: [{}/{}], train loss: {:.6f}'.format(
        epoch, N_EPOCHS, total_loss / len(train_loader.dataset)))


def validation(capsule_net, val_loader, epoch):
    """Evaluate the network on the validation set and log accuracy and loss.

    Args:
        capsule_net: Network to evaluate.  Calling it on a batch must return
            a ``(output, reconstructions, masked)`` triple, and it must
            provide ``loss(data, output, target, reconstructions)``.
        val_loader: Iterable of ``(data, target)`` batches that supports
            ``len()`` and has a ``dataset`` attribute supporting ``len()``.
            ``target`` must hold integer class indices; labels are one-hot
            encoded with two classes.
        epoch: Epoch number, used only in the log message.

    Returns:
        None.  Reads the module-level ``USE_CUDA`` and ``N_EPOCHS``.
    """
    capsule_net.eval()
    val_loss = 0
    correct = 0
    # Evaluation never calls backward(), so disable autograd to avoid
    # building (and holding memory for) a graph on every batch.
    with torch.no_grad():
        for data, target in val_loader:
            # One-hot encode the class indices via rows of the 2x2 identity.
            target = torch.eye(2).index_select(dim=0, index=target)

            if USE_CUDA:
                data, target = data.cuda(), target.cuda()

            output, reconstructions, masked = capsule_net(data)
            loss = capsule_net.loss(data, output, target, reconstructions)

            val_loss += loss.item()
            predicted = np.argmax(masked.detach().cpu().numpy(), 1)
            expected = np.argmax(target.cpu().numpy(), 1)
            correct += int((predicted == expected).sum())

    # Accuracy is per sample; the loss is averaged over batches.
    # NOTE(review): train() divides its epoch loss by the dataset size
    # instead, so the two logged losses are on different scales - confirm
    # which normalisation is intended.
    tqdm.write(
        "Epoch: [{}/{}], validation accuracy: {:.6f}, loss: {:.6f}".format(
            epoch, N_EPOCHS, correct / len(val_loader.dataset), val_loss / len(val_loader)))

In [5]:
if __name__ == '__main__':
    # Fix the PyTorch RNG seed so runs start from the same state.
    torch.manual_seed(1)

    # train() and validation() one-hot encode labels with exactly two
    # classes, so only a two-class dataset works with this driver.
    dataset = 'dementia'
    config = Config(dataset)
    dementia = Dataset(dataset, BATCH_SIZE, trainDir, valDir)

    # The model was previously wrapped in torch.nn.DataParallel and then
    # immediately unwrapped again via `.module`, so training always ran on
    # the bare module; the wrap/unwrap pair is dropped because it had no
    # effect beyond suggesting multi-GPU training that never happened.
    capsule_net = CapsNet(config)
    if USE_CUDA:
        capsule_net = capsule_net.cuda()

    # NOTE(review): Adam is built with its default arguments; the
    # LEARNING_RATE and MOMENTUM constants are not passed in - confirm this
    # is intended.
    optimizer = torch.optim.Adam(capsule_net.parameters())

    for e in range(1, N_EPOCHS + 1):
        train(capsule_net, optimizer, dementia.train_loader, e)
        validation(capsule_net, dementia.val_loader, e)

  0%|          | 3/1218 [00:00<02:14,  9.02it/s]

Epoch: [1/30], Batch: [1/1218], train accuracy: 0.500000, loss: 0.112584


  8%|▊         | 103/1218 [00:06<01:14, 14.92it/s]

Epoch: [1/30], Batch: [101/1218], train accuracy: 0.625000, loss: 0.056708


 17%|█▋        | 203/1218 [00:13<01:06, 15.33it/s]

Epoch: [1/30], Batch: [201/1218], train accuracy: 0.750000, loss: 0.051568


 25%|██▍       | 303/1218 [00:19<01:00, 15.24it/s]

Epoch: [1/30], Batch: [301/1218], train accuracy: 0.750000, loss: 0.050634


 33%|███▎      | 403/1218 [00:26<00:54, 15.00it/s]

Epoch: [1/30], Batch: [401/1218], train accuracy: 0.875000, loss: 0.053285


 41%|████▏     | 503/1218 [00:33<00:46, 15.32it/s]

Epoch: [1/30], Batch: [501/1218], train accuracy: 0.500000, loss: 0.052383


 50%|████▉     | 603/1218 [00:39<00:40, 15.30it/s]

Epoch: [1/30], Batch: [601/1218], train accuracy: 0.625000, loss: 0.049497


 58%|█████▊    | 703/1218 [00:46<00:34, 15.13it/s]

Epoch: [1/30], Batch: [701/1218], train accuracy: 0.875000, loss: 0.050362


 66%|██████▌   | 803/1218 [00:53<00:27, 14.95it/s]

Epoch: [1/30], Batch: [801/1218], train accuracy: 1.000000, loss: 0.050029


 74%|███████▍  | 903/1218 [00:59<00:20, 15.23it/s]

Epoch: [1/30], Batch: [901/1218], train accuracy: 0.750000, loss: 0.049622


 82%|████████▏ | 1003/1218 [01:06<00:14, 15.21it/s]

Epoch: [1/30], Batch: [1001/1218], train accuracy: 0.625000, loss: 0.049456


 91%|█████████ | 1103/1218 [01:12<00:07, 15.16it/s]

Epoch: [1/30], Batch: [1101/1218], train accuracy: 0.875000, loss: 0.048763


 99%|█████████▉| 1203/1218 [01:19<00:00, 15.28it/s]

Epoch: [1/30], Batch: [1201/1218], train accuracy: 0.875000, loss: 0.050418


100%|██████████| 1218/1218 [01:20<00:00, 15.19it/s]


Epoch: [1/30], train loss: 0.052306


AttributeError: 'Dataset' object has no attribute 'test_loader'