In [1]:
# module imports
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Subset
from dataset import BigGenomeDataset
from dataloaders import BatchedGenomeDataLoader
from models import EdeepVPP

In [13]:
# hyper parameters
max_epoch = 1000
batch_size = 100
vocab_size = 5  # A, C, G, T, N
embed_dim = 5
out_size = 2
seed = 0

# Seed torch's RNGs (all devices) so weight initialisation is reproducible
# across Restart & Run All. NOTE(review): if BatchedGenomeDataLoader shuffles
# with Python's `random` or numpy, those RNGs need seeding as well — confirm.
torch.manual_seed(seed)

In [3]:
# train data: CSV file -> dataset -> mini-batch loader
train_path = "../data/undersampled_train.csv"
train_set = BigGenomeDataset(train_path)
# `shuffle` is not passed, so this relies on BatchedGenomeDataLoader's default.
train_dataloader = BatchedGenomeDataLoader(
    train_set,
    batch_size=batch_size,
)

In [4]:
# criterion: sigmoid + binary cross-entropy on every output logit, averaged.
# No per-class weighting is applied (an experiment with weight=[100.0, 1.0]
# was disabled). The training CSV is an undersampled split, presumably to
# balance the classes instead of weighting the loss — confirm.
criterion = nn.BCEWithLogitsLoss(reduction="mean")

In [5]:
# Model placed on the GPU, optimised with Adam at its default settings.
model = EdeepVPP(vocab_size, embed_dim, out_size=out_size)
model = model.cuda()  # Module.cuda() moves parameters in place and returns the module
optimizer = optim.Adam(params=model.parameters())

In [None]:
# training
def _train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Returns NaN (instead of raising ZeroDivisionError) when the loader yields
    no batches, so an empty split is visible in the log rather than crashing.
    """
    total_loss, n_batches = 0.0, 0
    for x, t in dataloader:
        optimizer.zero_grad()
        y = model(x)
        loss = criterion(y, t)
        loss.backward()
        optimizer.step()
        # .item() returns a detached Python float; accumulating `loss.data`
        # (discouraged) kept a device tensor alive as the running sum.
        total_loss += loss.item()
        n_batches += 1
    return total_loss / n_batches if n_batches else float("nan")


log_interval = 10  # print the average loss every `log_interval` epochs

model.train()
for epoch in range(max_epoch):
    avg_loss = _train_one_epoch(model, train_dataloader, criterion, optimizer)

    if (epoch + 1) % log_interval == 0:
        print("| epoch %d | loss %f" % (epoch + 1, avg_loss))

| epoch 10 | loss 0.409069
| epoch 20 | loss 0.414495
| epoch 30 | loss 0.403536
| epoch 40 | loss 0.403734
| epoch 50 | loss 0.383433
| epoch 60 | loss 0.382669
| epoch 70 | loss 0.377067
| epoch 80 | loss 0.372685
| epoch 90 | loss 0.363985
| epoch 100 | loss 0.376610
| epoch 110 | loss 0.365012
| epoch 120 | loss 0.361953
| epoch 130 | loss 0.360827
| epoch 140 | loss 0.352873
| epoch 150 | loss 0.353417
| epoch 160 | loss 0.338176
| epoch 170 | loss 0.342708
| epoch 180 | loss 0.336188
| epoch 190 | loss 0.335437
| epoch 200 | loss 0.338080
| epoch 210 | loss 0.345089
| epoch 220 | loss 0.339793
| epoch 230 | loss 0.339074
| epoch 240 | loss 0.329357
| epoch 250 | loss 0.335011
| epoch 260 | loss 0.332027
| epoch 270 | loss 0.322769
| epoch 280 | loss 0.326410
| epoch 290 | loss 0.320955
| epoch 300 | loss 0.327184
| epoch 310 | loss 0.327551
| epoch 320 | loss 0.321032
| epoch 330 | loss 0.335036
| epoch 340 | loss 0.323620
| epoch 350 | loss 0.323205
| epoch 360 | loss 0.324974
|

In [None]:
# test data
# FIXME: this points at the *training* CSV, so every metric in the evaluation
# cell is measured on data the model was trained on (optimistic numbers).
# Point it at a held-out split.
test_path = "../data/undersampled_train.csv"
test_set = BigGenomeDataset(test_path)
test_dataloader = BatchedGenomeDataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=True,  # order is irrelevant: targets and predictions are collected together
)

In [None]:
# test
def _collect_predictions(model, dataloader):
    """Run ``model`` over every batch of ``dataloader`` and stack the results.

    Returns ``(targets, probabilities)``, where ``probabilities`` is
    ``sigmoid(model(x))`` per batch, both concatenated along dim 0.
    Batches are gathered in lists and concatenated once, instead of
    re-concatenating growing tensors on every batch (quadratic copying).

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    targets, probs = [], []
    for x, t in dataloader:
        targets.append(t)
        probs.append(torch.sigmoid(model(x)))
    if not targets:
        raise ValueError("dataloader yielded no batches")
    return torch.cat(targets, dim=0), torch.cat(probs, dim=0)


def _plot_curve(xs, ys, title, xlabel, ylabel):
    """Draw one titled, labelled, gridded line plot and show it."""
    fig, ax = plt.subplots()
    ax.plot(xs, ys, label=title)
    ax.set(title=title, xlabel=xlabel, ylabel=ylabel)
    ax.grid(True)
    plt.show()


model.eval()
with torch.no_grad():
    ts, ps = _collect_predictions(model, test_dataloader)

# Column 0 is treated as the positive class. NOTE(review): assumes the targets
# have two columns with column 0 as the positive-class indicator — confirm.
label = ts[:, 0].cpu().numpy()
probability = ps[:, 0].cpu().numpy()
# With out_size == 2, argmin is 1 exactly when column 0 has the higher
# probability (ties give 0), i.e. the hard positive-class prediction.
y_pred = ps.argmin(dim=1).cpu().numpy()

# confusion matrix (sklearn layout):
#                       Predicted
#                   Negative  Positive
#   Actual Negative     TN        FP
#          Positive     FN        TP
conf_matrix = metrics.confusion_matrix(label, y_pred)
print(conf_matrix)
print("正解率: ", metrics.accuracy_score(label, y_pred))  # accuracy
print("適合率: ", metrics.precision_score(label, y_pred))  # precision
print("再現率: ", metrics.recall_score(label, y_pred))  # recall
print("F1値: ", metrics.f1_score(label, y_pred))  # F1 score

# ROC-AUC. Threshold-sweeping curves need the continuous score; feeding the
# hard 0/1 y_pred collapses the curve to one operating point and skews the AUC.
fpr, tpr, roc_thresholds = metrics.roc_curve(label, probability)
_plot_curve(fpr, tpr, "ROC curve", 'False Positive Rate', 'True Positive Rate')
print("AUC-ROC: ", metrics.roc_auc_score(label, probability))

# PR-AUC (trapezoidal area under the precision-recall curve)
precision, recall, pr_thresholds = metrics.precision_recall_curve(label, probability)
auc = metrics.auc(recall, precision)
_plot_curve(recall, precision, "PR curve", "Recall", "Precision")
print("AUC-PR: ", auc)