In [8]:
import torch
import numpy as np
from matplotlib import pyplot as plt
import datetime
import json
import sys
sys.path.insert(0, '../lib')
import models
import datasets
import dataset_utils

In [9]:
with open('../reports/parameters_cnn.json', 'r') as file:
    args_dict = json.load(file)

for i in range(len(args_dict)):
    current_time = datetime.datetime.now().strftime('%d-%m-%Y_%H:%M:%S')
    dataset = datasets.Genes('../data', k=args_dict[i]['k'], genes_dict=args_dict[i]['genes_dict'])
    dataset, test_set = dataset_utils.split_dataset(dataset, test_size=0.1, shuffle=True)

    train_labels = torch.from_numpy(np.array(dataset.y)).cuda()
    test_labels = torch.from_numpy(np.array(test_set.y))

    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=5)
    labels_loader = torch.utils.data.DataLoader(train_labels, batch_size=5)

    model = models.TaxonomyCNN(dataset, kernel_size=args_dict[i]['kernel_size']).train().cuda()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args_dict[i]['lr'])

    losses = []
    accuracies = []
    for epoch in range(args_dict[i]['epochs']):
        print('Epoch: ' + str(epoch + 1) + '/' + str(args_dict[i]['epochs']))
        model = model.train().cuda()

        for x, y in zip(dataset_loader, labels_loader):
            optimizer.zero_grad()

            prediction = model(x.cuda())
            loss = criterion(prediction, y)

            loss.backward()
            optimizer.step()

        losses.append(loss.item())
        print('Loss: ' + str(losses[-1]))

        test_acc, y_pred = models.evaluate_model(model.cpu(), test_set, test_labels)
        print('Test accuracy: ' + str(test_acc) + '\n')
        accuracies.append(test_acc)

    test_acc, y_pred = models.evaluate_model(model.cpu(), test_set, test_labels)
    print('Final test accuracy: ' + str(test_acc) + '\n')

    cm = models.construct_confusion_matrix(y_true=test_labels.numpy(), y_pred=y_pred)
    models.plot_confusion_matrix(cm=cm, classes=dataset.genes_dict)
    plt.savefig('../reports/cm/cnn_' + current_time + '.png')

    torch.save({'state_dict': model.state_dict(), 'loss': losses, 'accuracies': accuracies, 'args_dict': args_dict[i]},
               '../results/cnn_' + current_time + '.pt')

Epoch: 1/100
Loss: 0.6018974184989929
Test accuracy: 0.603

Epoch: 2/100


KeyboardInterrupt: 