In [None]:
import os
import sys
import time

sys.path.append("../src/Metamodel/")

from utils import EarlyStopping

import torch
import torch.nn as nn
import numpy as np

from data.dataloader import create_train_val_loader

from models.depthwiseNet import DepthNet

from sklearn.metrics import confusion_matrix, auc, roc_auc_score, f1_score, classification_report
from torchmetrics.classification import MulticlassAUROC

import matplotlib.pyplot as plt
import seaborn as sns

import torch.backends.cudnn as cudnn
import random

# from pytorchtools import EarlyStopping

# Make runs reproducible: seed every RNG this notebook touches exactly once and
# force deterministic cuDNN kernels (disables the autotuner, which can pick
# different algorithms run-to-run).
SEED = 0

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Seeds every visible GPU; a no-op when CUDA is not available.
torch.cuda.manual_seed_all(SEED)

cudnn.benchmark = False
cudnn.deterministic = True

In [None]:
# Root data directory handed to create_train_val_loader.
DATAPATH = "../data/new"

# Preferred GPU. Constructing torch.device("cuda:N") never fails, but moving a
# model onto a non-existent ordinal raises "invalid device ordinal", so fall
# back to the first GPU when this machine has fewer devices, and to CPU when
# CUDA is unavailable altogether.
GPU_INDEX = 2
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif GPU_INDEX < torch.cuda.device_count():
    device = torch.device(f"cuda:{GPU_INDEX}")
else:
    device = torch.device("cuda:0")
# Build the network and move it to the selected device. It produces 3 output
# logits per sample (consumed by argmax / CrossEntropyLoss downstream).
# NOTE(review): lengths/patch_size=30 presumably match the length=30 windows
# produced by create_train_val_loader, and in_chans=2 the number of input
# signals — confirm against DepthNet.
model = DepthNet(
    lengths=30,
    patch_size=30,
    in_chans=2,
    embed_dim=256,
    norm_layer=None,
    output_dim=3,
).to(device)

In [None]:
# Checkpoint under evaluation. Other checkpoints previously used here:
#   "../singlelog/total_3class.pt"
#   "../src/Metamodel/log/best_model_cg.pt"
MODELPATH = "../src/Metamodel/log/best_model_t1t2.pt"

clients_results_meta = []  # not referenced in this cell

# map_location remaps tensors saved on another device onto `device`.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints;
# consider weights_only=True if the installed torch version supports it.
pretrained = torch.load(MODELPATH, map_location=device)
model.load_state_dict(pretrained)

In [None]:
# Evaluation settings shared by the per-client loop.
weightedType = "macro"  # averaging mode for both AUROC and F1
batch_size = 256        # batch size passed to create_train_val_loader

criterion = nn.CrossEntropyLoss().to(device)
auroc = MulticlassAUROC(num_classes=3, average=weightedType).to(device)

# One list per metric, keyed by metric name.
test_result_dic = {
    metric_name: []
    for metric_name in ("loss", "acc", "f1", "auc", "confusion_matrix")
}

In [None]:
# Evaluate the loaded checkpoint on every client. Each client's training set is
# split 80/20 with a fixed-seed generator, and metrics are computed on the 20%
# part only. Per-client results are appended to test_result_dic.
NUM_CLIENTS = 10
EVAL_BATCH_SIZE = 32

# Start from empty lists so re-running this cell does not append duplicates.
for metric_values in test_result_dic.values():
    metric_values.clear()

model.eval()
valid_loss = 0.0

for i in range(NUM_CLIENTS):
    client_num = i + 1
    print(f"Client {client_num} Test")
    PATH = os.path.join(DATAPATH, f"c{client_num}_data.csv")  # not used below

    # Re-seeded per client so each client's 80/20 split is reproducible.
    generator = torch.Generator()
    generator.manual_seed(0)

    train_set, _ = create_train_val_loader(DATAPATH, batch_size, length=30,
                                           meta_train_client_idx_lst=[client_num], FLtrain=True)
    n_train = int(len(train_set) * 0.8)
    train_data, test_data = torch.utils.data.random_split(
        train_set, [n_train, len(train_set) - n_train], generator=generator
    )

    if len(test_data) == 0:
        # Avoid dividing by zero below when a client has too little data.
        print(f"Client {client_num}: empty test split, skipping")
        continue

    # Not used in this cell.
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=EVAL_BATCH_SIZE, shuffle=True, drop_last=False
    )
    # No shuffling for evaluation: order does not affect the metrics, and it
    # keeps test_pred / test_real in a reproducible order.
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=EVAL_BATCH_SIZE, shuffle=False, drop_last=False
    )

    test_pred = []
    test_real = []
    test_proba = []
    correct = 0
    test_loss = 0.0

    with torch.no_grad():
        for data in test_loader:
            x_data, stage = data[0].to(device), data[1].to(device)

            pred_value, _ = model(x_data)
            pred = torch.argmax(pred_value, dim=1)
            # Softmax, not sigmoid: the scores fed to the multiclass AUROC must
            # be one probability distribution over the classes per sample.
            pred_proba = torch.softmax(pred_value, dim=1)

            correct += torch.sum(pred == stage).item()

            # criterion returns the batch mean; weight by batch size so the
            # final value is a true per-sample mean even if the last batch is short.
            loss = criterion(pred_value, stage)
            test_loss += loss.item() * len(stage)

            test_pred.extend(pred.cpu().numpy())
            test_real.extend(stage.cpu().numpy())
            test_proba.extend(pred_proba.cpu().numpy())

    n_test = len(test_loader.dataset)

    # Reset first so the metric reflects this client only (its state would
    # otherwise accumulate across clients). Inputs go to the metric's device.
    # NOTE: `auc` shadows sklearn.metrics.auc imported at the top of the notebook.
    auroc.reset()
    auroc.update(torch.as_tensor(np.stack(test_proba), device=device),
                 torch.as_tensor(np.asarray(test_real), device=device))
    auc = auroc.compute()

    acc = correct / n_test
    valid_loss = test_loss / n_test
    f1score = f1_score(test_real, test_pred, average=weightedType)

    test_result_dic["loss"].append(valid_loss)
    test_result_dic["acc"].append(acc)
    test_result_dic["f1"].append(f1score)
    test_result_dic["auc"].append(float(auc))
    test_result_dic["confusion_matrix"].append(
        confusion_matrix(test_real, test_pred, labels=list(range(auroc.num_classes)))
    )

    print(f"(Test) Loss: {round(valid_loss, 3)}, AUC: {round(float(auc),3)}, ACC: {round(acc,3)}, F1score: {round(f1score, 3)}")
    print("################################################################################################")