In [1]:
import os
import sys

sys.path.append("../")

import pandas as pd

import numpy as np
import torch
import torch.nn as nn

from src.Metamodel.models.depthwiseNet import DepthNet
from src.Metamodel.data.dataloader import create_train_val_loader

from sklearn.metrics import confusion_matrix, auc, roc_curve, roc_auc_score, f1_score


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = DepthNet(lengths=30, patch_size=1, in_chans=5, embed_dim=256, norm_layer=None, output_dim=3).to(device)

savemodel = torch.load("../src/Metamodel/log/results_2.pt")
model.load_state_dict(savemodel["model"])

DepthNet is used...


<All keys matched successfully>

In [3]:
savemodel["hyperparameter"].keys()

odict_keys(['l2.0', 'l2.1', 'l2.2', 'l2.3', 'l2.4', 'l2.5', 'l2.6', 'l2.7', 'l2.8', 'l2.9', 'l2.10', 'l2.11', 'l2.12', 'l2.13', 'l2.14', 'l2.15'])

In [4]:
batch_size = 512
database = "../data/padding"

train_loader, valid_loader = create_train_val_loader(database, batch_size, length=30, meta_train_client_idx_lst=[1])

#######################################
Train DataLoader
x_data shape:  torch.Size([945, 5, 30])
y_data shape:  torch.Size([945]) class_num 3
#######################################
Validation DataLoader
x_data shape:  torch.Size([1661, 5, 30])
y_data shape:  torch.Size([1661]) class_num 3
#######################################


In [5]:
criterion = nn.CrossEntropyLoss().to(device)

In [6]:
model.eval()

correct = 0
valid_loss = 0.0

model.eval()
val_pred = []
val_real = []

with torch.no_grad():
    for data in train_loader:
        x_data, stage = data[0].to(device), data[1].to(device)
        
        # pred_value, _ = model(x_data)
        pred_value, _ = model(x_data)
        pred_class = torch.argmax(pred_value, dim=1)
        loss = criterion(pred_value, stage)
        
        valid_loss += loss.item()
        correct += (pred_class == stage).sum().item()
        
        val_pred.extend(pred_class.detach().cpu().numpy())
        val_real.extend(stage.detach().cpu().numpy())
    
    acc = correct / len(valid_loader.dataset)
    f1 = f1_score(val_real, val_pred, average="macro")
    fpr, tpr, thresholds = roc_curve(val_pred, val_real, pos_label=2)
    
    # roc_score = roc_auc_score(val_real, val_pred, multi_class="ovr", average="macro")
    print(f"(Valid) Loss: {round(valid_loss, 3)}, Acc: {round(acc, 3)}, AUC: {round(auc(fpr, tpr), 3)}, F1: {round(f1, 3)}")
    # print(f"(Valid) Epoch: {epoch}, Loss: {round(valid_loss, 3)}, Accuracy: {round(acc, 3)},") # AUC: {round(roc_score, 3)}")
    
    print(confusion_matrix(val_real, val_pred))

(Valid) Loss: 2.87, Acc: 0.283, AUC: 0.461, F1: 0.278
[[ 13 159   7]
 [  5 450   2]
 [  1 301   7]]


In [7]:
epochs=40

In [8]:
lr = 0.0001
# optimizer_outer = utils.create_optimizer(optimizer_outer, model.parameters(), {"lr": lr})
optimizer_outer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.1)
criterion = nn.CrossEntropyLoss().to(device)

best_acc = 0
best_f1 = 0
best_auc = 0
for epoch in range(epochs):
    epoch += 1
    model.train()
    running_loss = 0.0
    
    for data in train_loader:            
        x_data, stage = data[0].to(device), data[1].to(device)
        
        optimizer_outer.zero_grad()
        # pred_value, _ = model(x_data)
        pred_value, _ = model(x_data)
        loss = criterion(pred_value, stage)
        loss.backward()
        optimizer_outer.step()
        
        running_loss += loss.item()
        
    running_loss /= len(train_loader)
    
    print(f"(Train) Epoch: {epoch}, Loss: {round(running_loss, 3)}")
    
    if epoch % 1 == 0:
        print("<< Validation >>")
        correct = 0
        valid_loss = 0.0
        
        model.eval()
        val_pred = []
        val_real = []
        with torch.no_grad():
            for data in train_loader:
                x_data, stage = data[0].to(device), data[1].to(device)
                
                # pred_value, _ = model(x_data)
                pred_value, _ = model(x_data)
                pred_class = torch.argmax(pred_value, dim=1)
                loss = criterion(pred_value, stage)
                
                valid_loss += loss.item()
                correct += (pred_class == stage).sum().item()
                
                val_pred.extend(pred_class.detach().cpu().numpy())
                val_real.extend(stage.detach().cpu().numpy())
            
            acc = correct / len(valid_loader.dataset)
            f1 = f1_score(val_real, val_pred, average="macro")
            fpr, tpr, thresholds = roc_curve(val_pred, val_real, pos_label=1)
            
            if best_acc < acc:
                best_acc = acc
                # torch.save(model.state_dict(), "./model/depthwiseNet.pth")
                
            if best_f1 < f1:
                best_f1 = f1
                
            if best_auc < auc(fpr, tpr):
                best_auc = auc(fpr, tpr)
                # torch.save(model.state_dict(), "./model/depthwiseNet.pth")
            
            # roc_score = roc_auc_score(val_real, val_pred, multi_class="ovr", average="macro")
            print(f"(Valid) Epoch: {epoch}, Loss: {round(valid_loss, 3)}, Acc: {round(acc, 3)}, AUC: {round(auc(fpr, tpr), 3)}, F1: {round(f1, 3)}")
            # print(f"(Valid) Epoch: {epoch}, Loss: {round(valid_loss, 3)}, Accuracy: {round(acc, 3)},") # AUC: {round(roc_score, 3)}")
            
            print(confusion_matrix(val_real, val_pred))
            print("##########################################################")
print("Best ACC: ", best_acc)
print("Best f1: ", best_f1)
print("Best AUC: ", best_auc)

(Train) Epoch: 1, Loss: 1.232
<< Validation >>
(Valid) Epoch: 1, Loss: 5.872, Acc: 0.192, AUC: nan, F1: 0.202
[[ 10   2 167]
 [  3   1 453]
 [  0   1 308]]
##########################################################
(Train) Epoch: 2, Loss: 0.998
<< Validation >>
(Valid) Epoch: 2, Loss: 9.895, Acc: 0.192, AUC: nan, F1: 0.2
[[  9   1 169]
 [  0   2 455]
 [  0   1 308]]
##########################################################
(Train) Epoch: 3, Loss: 0.892
<< Validation >>
(Valid) Epoch: 3, Loss: 10.19, Acc: 0.192, AUC: nan, F1: 0.2
[[  9   1 169]
 [  1   2 454]
 [  0   1 308]]
##########################################################
(Train) Epoch: 4, Loss: 0.831
<< Validation >>
(Valid) Epoch: 4, Loss: 8.898, Acc: 0.194, AUC: nan, F1: 0.212
[[ 12   0 167]
 [  0   3 454]
 [  0   1 308]]
##########################################################
(Train) Epoch: 5, Loss: 0.791
<< Validation >>
(Valid) Epoch: 5, Loss: 7.392, Acc: 0.194, AUC: nan, F1: 0.212
[[ 12   1 166]
 [  0   3 454]
 [  



(Train) Epoch: 6, Loss: 0.753
<< Validation >>
(Valid) Epoch: 6, Loss: 6.143, Acc: 0.196, AUC: nan, F1: 0.219
[[ 14   0 165]
 [  0   3 454]
 [  0   1 308]]
##########################################################
(Train) Epoch: 7, Loss: 0.725
<< Validation >>
(Valid) Epoch: 7, Loss: 5.207, Acc: 0.197, AUC: nan, F1: 0.219
[[ 12   2 165]
 [  0   7 450]
 [  0   1 308]]
##########################################################
(Train) Epoch: 8, Loss: 0.714
<< Validation >>
(Valid) Epoch: 8, Loss: 4.486, Acc: 0.201, AUC: nan, F1: 0.235
[[ 15   1 163]
 [  0  11 446]
 [  0   1 308]]
##########################################################
(Train) Epoch: 9, Loss: 0.697
<< Validation >>
(Valid) Epoch: 9, Loss: 3.862, Acc: 0.211, AUC: nan, F1: 0.262
[[ 16   4 159]
 [  0  27 430]
 [  0   1 308]]
##########################################################
(Train) Epoch: 10, Loss: 0.678
<< Validation >>
(Valid) Epoch: 10, Loss: 3.354, Acc: 0.231, AUC: nan, F1: 0.307
[[ 17  12 150]
 [  0  58 399



(Train) Epoch: 14, Loss: 0.644
<< Validation >>
(Valid) Epoch: 14, Loss: 1.843, Acc: 0.404, AUC: nan, F1: 0.579
[[ 22 114  43]
 [  1 434  22]
 [  0  94 215]]
##########################################################
(Train) Epoch: 15, Loss: 0.638
<< Validation >>
(Valid) Epoch: 15, Loss: 1.775, Acc: 0.38, AUC: nan, F1: 0.537
[[ 22 125  32]
 [  0 450   7]
 [  0 149 160]]
##########################################################
(Train) Epoch: 16, Loss: 0.625
<< Validation >>
(Valid) Epoch: 16, Loss: 1.783, Acc: 0.361, AUC: nan, F1: 0.504
[[ 24 135  20]
 [  0 450   7]
 [  0 184 125]]
##########################################################
(Train) Epoch: 17, Loss: 0.621
<< Validation >>
(Valid) Epoch: 17, Loss: 1.821, Acc: 0.341, AUC: nan, F1: 0.465
[[ 27 141  11]
 [  0 452   5]
 [  0 222  87]]
##########################################################
(Train) Epoch: 18, Loss: 0.608
<< Validation >>
(Valid) Epoch: 18, Loss: 1.876, Acc: 0.323, AUC: nan, F1: 0.421
[[ 29 148   2]
 [  0 



(Valid) Epoch: 21, Loss: 1.81, Acc: 0.323, AUC: nan, F1: 0.425
[[ 34 144   1]
 [  0 455   2]
 [  1 260  48]]
##########################################################
(Train) Epoch: 22, Loss: 0.584
<< Validation >>
(Valid) Epoch: 22, Loss: 1.737, Acc: 0.332, AUC: nan, F1: 0.453
[[ 36 142   1]
 [  0 454   3]
 [  1 247  61]]
##########################################################
(Train) Epoch: 23, Loss: 0.576
<< Validation >>
(Valid) Epoch: 23, Loss: 1.762, Acc: 0.329, AUC: nan, F1: 0.444
[[ 37 141   1]
 [  0 455   2]
 [  1 254  54]]
##########################################################
(Train) Epoch: 24, Loss: 0.569
<< Validation >>
(Valid) Epoch: 24, Loss: 1.826, Acc: 0.317, AUC: nan, F1: 0.406
[[ 39 139   1]
 [  0 457   0]
 [  1 278  30]]
##########################################################
(Train) Epoch: 25, Loss: 0.574
<< Validation >>
(Valid) Epoch: 25, Loss: 1.951, Acc: 0.308, AUC: nan, F1: 0.378
[[ 39 140   0]
 [  0 457   0]
 [  1 292  16]]
#######################



(Train) Epoch: 29, Loss: 0.539
<< Validation >>
(Valid) Epoch: 29, Loss: 2.346, Acc: 0.304, AUC: nan, F1: 0.363
[[ 48 131   0]
 [  1 456   0]
 [ 10 298   1]]
##########################################################
(Train) Epoch: 30, Loss: 0.531
<< Validation >>
(Valid) Epoch: 30, Loss: 2.431, Acc: 0.307, AUC: nan, F1: 0.373
[[ 52 127   0]
 [  1 456   0]
 [ 14 293   2]]
##########################################################
(Train) Epoch: 31, Loss: 0.527
<< Validation >>
(Valid) Epoch: 31, Loss: 2.621, Acc: 0.3, AUC: nan, F1: 0.348
[[ 41 138   0]
 [  0 457   0]
 [  7 301   1]]
##########################################################
(Train) Epoch: 32, Loss: 0.522
<< Validation >>
(Valid) Epoch: 32, Loss: 2.757, Acc: 0.299, AUC: nan, F1: 0.341
[[ 38 141   0]
 [  0 457   0]
 [  4 304   1]]
##########################################################
(Train) Epoch: 33, Loss: 0.516
<< Validation >>
(Valid) Epoch: 33, Loss: 2.734, Acc: 0.303, AUC: nan, F1: 0.36
[[ 46 133   0]
 [  0 45



(Train) Epoch: 36, Loss: 0.496
<< Validation >>
(Valid) Epoch: 36, Loss: 2.47, Acc: 0.308, AUC: nan, F1: 0.375
[[ 53 126   0]
 [  0 457   0]
 [ 12 296   1]]
##########################################################
(Train) Epoch: 37, Loss: 0.479
<< Validation >>
(Valid) Epoch: 37, Loss: 2.555, Acc: 0.298, AUC: nan, F1: 0.338
[[ 35 144   0]
 [  0 457   0]
 [  2 304   3]]
##########################################################
(Train) Epoch: 38, Loss: 0.477
<< Validation >>
(Valid) Epoch: 38, Loss: 2.553, Acc: 0.3, AUC: nan, F1: 0.348
[[ 36 143   0]
 [  0 457   0]
 [  1 302   6]]
##########################################################
(Train) Epoch: 39, Loss: 0.47
<< Validation >>
(Valid) Epoch: 39, Loss: 2.561, Acc: 0.308, AUC: nan, F1: 0.375
[[ 50 129   0]
 [  0 457   0]
 [ 12 293   4]]
##########################################################
(Train) Epoch: 40, Loss: 0.465
<< Validation >>
(Valid) Epoch: 40, Loss: 2.701, Acc: 0.308, AUC: nan, F1: 0.417
[[107  72   0]
 [ 55 402

