In [1]:
import torch
import numpy as np
import sys
sys.path.insert(0, '../lib')
from models import BasicCNN, TaxonomyCNN, evaluate_model
from dataset_utils import train_test_split, split_dataset
from datasets import Genes, Mutations

In [4]:
args_dict = [{'batch_size': 5, 'epochs': 10, 'lr': 0.01}]

for i in range(len(args_dict)):
    dataset = Genes('../data', k=500)
    dataset, test_set = split_dataset(dataset, 0.1, True)

    train_labels = torch.from_numpy(np.array(dataset.y)).cuda()
    test_labels = torch.from_numpy(np.array(test_set.y))

    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=5)
    labels_loader = torch.utils.data.DataLoader(train_labels, batch_size=5)

    model = TaxonomyCNN().train().cuda()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args_dict[i]['lr'])

    losses = []
    loss = None
    for epoch in range(args_dict[i]['epochs']):
        print('Epoch: ' + str(epoch+1) + '/' + str(args_dict[i]['epochs']))
        model = model.train().cuda()

        for x, y in zip(dataset_loader, labels_loader):
            optimizer.zero_grad()

            prediction = model(x.cuda())
            loss = criterion(prediction, y)

            loss.backward()
            optimizer.step()

        losses.append(loss.item())
        print('Loss: ' + str(losses[-1]))
        test_acc = evaluate_model(model.cpu(), test_set, test_labels)
        print('Test accuracy: ' + str(test_acc) + '\n')

    torch.save({'state_dict': model.state_dict, 'loss': losses, 'args_dict': args_dict[i]}, 'taxonomy_cnn.pt')



Epoch: 1/10
Loss: 1.7923204898834229
Test accuracy: 0.17466666666666666

Epoch: 2/10
Loss: 1.7759990692138672
Test accuracy: 0.183

Epoch: 3/10
Loss: 1.8190256357192993
Test accuracy: 0.19866666666666666

Epoch: 4/10
Loss: 1.7375704050064087
Test accuracy: 0.2653333333333333

Epoch: 5/10
Loss: 1.682586908340454
Test accuracy: 0.2763333333333333

Epoch: 6/10
Loss: 1.7589813470840454
Test accuracy: 0.2856666666666667

Epoch: 7/10
Loss: 1.8001174926757812
Test accuracy: 0.25266666666666665

Epoch: 8/10
Loss: 1.7624294757843018
Test accuracy: 0.27266666666666667

Epoch: 9/10
Loss: 1.821192979812622
Test accuracy: 0.244

Epoch: 10/10
Loss: 1.7226572036743164
Test accuracy: 0.26266666666666666



In [None]:
    train_acc = evaluate_model(model.cpu(), dataset, train_labels.cpu())
    print('Final train accuracy: ' + str(train_acc))
    
    test_labels = torch.from_numpy(np.asarray(test_labels))
    test_acc = evaluate_model(model.cpu(), test_set, test_labels)
    print('Final test accuracy: ' + str(test_acc))