In [None]:
import os
import sys

sys.path.append("../")

import pandas as pd
import random
import matplotlib.pyplot as plt


import numpy as np
import torch
import torch.nn as nn

from src.Metamodel.models.depthwiseNet import DepthNet
from src.Metamodel.data.dataloader import create_train_val_loader
from torch.utils.data import random_split, DataLoader

from sklearn.metrics import confusion_matrix, auc, roc_curve, roc_auc_score, f1_score, classification_report

# Seed every RNG this notebook can draw from so that reruns are comparable.
SEED = 0

# Dedicated torch generator object. Seeding it does NOT seed torch's global
# default generator, which is what weight init, dropout and default-shuffling
# DataLoaders use.
# NOTE(review): nothing in the visible cells consumes `generator`; presumably it
# is meant to be passed to random_split / DataLoader -- confirm.
generator = torch.Generator()
generator.manual_seed(SEED)

random.seed(SEED)         # Python stdlib RNG
np.random.seed(SEED)      # NumPy legacy global RNG
torch.manual_seed(SEED)   # torch global RNG

# Uncomment for deterministic cuDNN kernels (may slow training down).
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

from natsort import natsorted

In [None]:
# Pick the compute device. GPU 3 is preferred, but fall back to GPU 0 when the
# host exposes fewer than four GPUs (a hard-coded "cuda:3" would otherwise fail
# at `.to(device)`), and to the CPU when CUDA is unavailable.
PREFERRED_GPU = 3
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU if torch.cuda.device_count() > PREFERRED_GPU else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Build the network and move it to the selected device.
model = DepthNet(lengths=30, patch_size=30, in_chans=2, embed_dim=256, norm_layer=None, output_dim=3).to(device)

# Load the pretrained weights onto the same device.
# Alternative checkpoint: "../src/Metamodel/log/Base_3class_normal.pt"
# NOTE(security): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
savemodel = torch.load("../src/Metamodel/log/best_model_t1t2_origin.pt", map_location=device)
model.load_state_dict(savemodel)

# Client id(s) for this run; in the cells shown here it is referenced only by a
# commented-out checkpoint filename.
client = [5]

In [None]:
# Evaluate the model, as loaded from the checkpoint, on the client test loader.
# NOTE(review): `client_test_loader` is not created in any cell shown here; it
# must exist in the kernel before this cell runs -- confirm where it is built.

# Defined locally so this cell does not depend on a later cell having been run
# first (the training cell re-creates an identical criterion).
criterion = nn.CrossEntropyLoss().to(device)

# Class treated as "positive" for the one-vs-rest ROC curve.
POSITIVE_CLASS = 2

model.eval()

correct = 0
valid_loss = 0.0

val_pred = []   # predicted class index per sample
val_real = []   # ground-truth class index per sample
val_prob = []   # softmax probability of POSITIVE_CLASS per sample

with torch.no_grad():
    for data in client_test_loader:
        x_data, stage = data[0].to(device), data[1].to(device)

        # The model returns a pair; only the first element is used here.
        pred_value, _ = model(x_data)
        pred_class = torch.argmax(pred_value, dim=1)
        loss = criterion(pred_value, stage)

        valid_loss += loss.item()
        correct += (pred_class == stage).sum().item()

        positive_prob = torch.softmax(pred_value, dim=1)[:, POSITIVE_CLASS]
        val_pred.extend(pred_class.cpu().numpy())
        val_real.extend(stage.cpu().numpy())
        val_prob.extend(positive_prob.cpu().numpy())

# Report the mean loss per batch so the number does not scale with the number
# of batches and is comparable with the per-batch training loss.
valid_loss /= len(client_test_loader)

acc = correct / len(client_test_loader.dataset)
f1 = f1_score(val_real, val_pred, average="macro")

# roc_curve expects (y_true, y_score): ground truth first, then a continuous
# score for the positive class -- not the hard class predictions.
fpr, tpr, thresholds = roc_curve(val_real, val_prob, pos_label=POSITIVE_CLASS)

print(f"(Valid) Loss: {round(valid_loss, 3)} // ACC {round(acc, 3)}")
print("---------------------------------------------------")
print("AUC:", round(auc(fpr, tpr), 3))
print(confusion_matrix(val_real, val_pred))
print(classification_report(val_real, val_pred))
print("---------------------------------------------------")

In [None]:
# Fine-tuning hyper-parameters consumed by the training cell:
#   epochs -- number of passes of the training loop
#   lr     -- learning rate handed to Adam
#   decay  -- weight_decay handed to Adam
epochs, lr, decay = 50, 1e-5, 1e-6

In [None]:

# Fine-tune `model` on the client training loader and validate on the client
# test loader after every epoch, tracking the best accuracy / macro-F1 / AUC.
# NOTE(review): `client_train_loader` and `client_test_loader` are not created
# in any cell shown here; they must exist in the kernel before this cell runs.

optimizer_outer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)
criterion = nn.CrossEntropyLoss().to(device)

# Class treated as "positive" for the one-vs-rest ROC/AUC.
POSITIVE_CLASS = 2

best_acc = 0
best_f1 = 0
best_auc = 0

learning_curve = []   # mean training loss per batch, one entry per epoch
valid_curve = []      # mean validation loss per batch, one entry per epoch
acc_curve = []        # validation accuracy, one entry per epoch

for epoch in range(1, epochs + 1):
    # ---- training pass -------------------------------------------------
    model.train()
    running_loss = 0.0

    for data in client_train_loader:
        x_data, stage = data[0].to(device), data[1].to(device)

        optimizer_outer.zero_grad()
        # The model returns a pair; only the first element is used here.
        pred_value, _ = model(x_data)
        loss = criterion(pred_value, stage)
        loss.backward()
        optimizer_outer.step()

        running_loss += loss.item()

    running_loss /= len(client_train_loader)
    learning_curve.append(running_loss)

    print(f"(Train) Epoch: {epoch}, Loss: {round(running_loss, 3)}")

    # ---- validation pass (every epoch) ---------------------------------
    print("<< Validation >>")
    correct = 0
    valid_loss = 0.0

    model.eval()

    val_pred = []    # predicted class index per sample
    val_real = []    # ground-truth class index per sample
    val_score = []   # raw model outputs per sample
    val_prob = []    # softmax probability of POSITIVE_CLASS per sample

    with torch.no_grad():
        for data in client_test_loader:
            x_data, stage = data[0].to(device), data[1].to(device)

            pred_value, _ = model(x_data)
            pred_class = torch.argmax(pred_value, dim=1)
            loss = criterion(pred_value, stage)

            valid_loss += loss.item()
            correct += (pred_class == stage).sum().item()

            positive_prob = torch.softmax(pred_value, dim=1)[:, POSITIVE_CLASS]
            val_pred.extend(pred_class.cpu().numpy())
            val_score.extend(pred_value.cpu().numpy())
            val_prob.extend(positive_prob.cpu().numpy())
            val_real.extend(stage.cpu().numpy())

    # Average over batches, exactly like the training loss, so both curves
    # share a scale when plotted together.
    valid_loss /= len(client_test_loader)
    valid_curve.append(valid_loss)

    acc = correct / len(client_test_loader.dataset)
    acc_curve.append(acc)

    # Macro average: micro-F1 is identical to accuracy for single-label
    # multi-class predictions, which would make best_f1 a copy of best_acc.
    f1 = f1_score(val_real, val_pred, average="macro")

    # roc_curve takes (y_true, y_score) with a continuous positive-class score.
    fpr, tpr, _ = roc_curve(val_real, val_prob, pos_label=POSITIVE_CLASS)
    epoch_auc = auc(fpr, tpr)

    if best_acc < acc:
        best_acc = acc

    if best_f1 < f1:
        best_f1 = f1
        # torch.save(model.state_dict(), f"../src/Metamodel/log/Base_3class_transfer_client{client}.pth")

    # A NaN AUC (no positive samples in the split) never compares greater,
    # so it cannot overwrite the running best.
    if best_auc < epoch_auc:
        best_auc = epoch_auc

    print("---------------------------------------------------")
    print("AUC:", round(epoch_auc, 3))
    print(confusion_matrix(val_real, val_pred))
    print(classification_report(val_real, val_pred))
    print("---------------------------------------------------")

print("Best ACC: ", best_acc)
print("Best f1: ", best_f1)
print("Best AUC: ", best_auc)

In [None]:
# Validation accuracy recorded after each epoch of the training cell.
fig, ax = plt.subplots(figsize=(40, 20))
# One entry is appended per epoch, so index i corresponds to epoch i + 1.
ax.plot(range(1, len(acc_curve) + 1), acc_curve)
ax.set_title("Validation accuracy per epoch", fontsize=40)
ax.set_xlabel("Epoch", fontsize=30)
ax.set_ylabel("Accuracy", fontsize=30)
# Large tick labels on both axes: the default size is unreadable at 40x20 in.
ax.tick_params(axis="both", labelsize=30)
plt.show()

In [None]:
# Training and validation loss recorded once per epoch by the training cell.
fig, ax = plt.subplots(figsize=(40, 20))
# Label each curve so the two lines can be told apart in the legend.
ax.plot(range(1, len(learning_curve) + 1), learning_curve, label="Training loss")
ax.plot(range(1, len(valid_curve) + 1), valid_curve, label="Validation loss")
ax.set_title("Loss per epoch", fontsize=40)
ax.set_xlabel("Epoch", fontsize=30)
ax.set_ylabel("Loss", fontsize=30)
# Large tick labels on both axes: the default size is unreadable at 40x20 in.
ax.tick_params(axis="both", labelsize=30)
ax.legend(fontsize=30)
plt.show()