In [None]:
import os
import sys
import time

sys.path.append("../src/Metamodel/")

import torch
import torch.nn as nn

from data.dataloader import create_train_val_loader

from models.depthwiseNet import DepthNet

from sklearn.metrics import confusion_matrix, auc, roc_curve, roc_auc_score, f1_score

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Seed PyTorch's global RNG so runs are repeatable (seeding might slow down the model).
seed = 0
# Index of the client whose data is loaded below and shown in the plot titles.
# NOTE(review): the original comment on the seed line read "Client 9" while num is 6 —
# confirm which client is intended.
num = 6
torch.manual_seed(seed)

# Prefer GPU 3, but fall back to GPU 0 on hosts that expose fewer CUDA devices
# (a hard-coded "cuda:3" fails with an invalid device ordinal there), and to the
# CPU when CUDA is not available at all.
preferred_gpu = 3
if torch.cuda.is_available():
    gpu_index = preferred_gpu if preferred_gpu < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
print("Device: ", device)

epochs = 200
batch_size = 512
# Alternative data location used previously: "../../data/padding"
database = "../data/new"

# Load the data for the selected client. The two returned objects are split and
# wrapped in DataLoaders by the following cells.
train_set, validation_set = create_train_val_loader(
    database,
    batch_size,
    length=30,
    meta_train_client_idx_lst=[num],
    FLtrain=True,
)

# Build the classifier and move it to the selected device.
model = DepthNet(
    lengths=30,
    patch_size=30,
    in_chans=2,
    embed_dim=256,
    norm_layer=None,
    output_dim=3,
).to(device)

In [None]:
# Split train_set 80/20 into a training subset and a held-out subset.
# The generator is seeded with 42 so the split is the same on every run.
n_total = len(train_set)
n_train = int(n_total * 0.8)
split_rng = torch.Generator().manual_seed(42)
train, test = torch.utils.data.random_split(
    train_set,
    [n_train, n_total - n_train],
    generator=split_rng,
)

In [None]:
loader_cls = torch.utils.data.DataLoader

# Only the training batches are reshuffled each epoch; the two evaluation
# loaders iterate in a fixed order.
train_loader = loader_cls(train, shuffle=True, batch_size=batch_size)
test_loader, valid_loader = (
    loader_cls(subset, shuffle=False, batch_size=batch_size)
    for subset in (test, validation_set)
)

In [None]:
# ---- Optimisation setup ----
lr = 0.001
optimizer_outer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.001)
# Halve the learning rate every 40 epochs. `last_epoch=-1` and `verbose=False`
# were defaults and are omitted (`verbose` is deprecated in recent PyTorch
# releases); the schedule itself is unchanged.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer_outer, step_size=40, gamma=0.5)
criterion = nn.CrossEntropyLoss().to(device)

# Best scores seen over all evaluation rounds.
best_acc = 0
best_f1 = 0

for epoch in range(1, epochs + 1):
    # ---- Train for one epoch ----
    model.train()
    running_loss = 0.0

    for data in train_loader:
        x_data, stage = data[0].to(device), data[1].to(device)

        optimizer_outer.zero_grad()
        # The model returns a pair; only the first element (class scores) is used.
        pred_value, _ = model(x_data)
        loss = criterion(pred_value, stage)
        loss.backward()
        optimizer_outer.step()

        running_loss += loss.item()

    # Mean training loss per batch (kept for inspection; not printed).
    running_loss /= len(train_loader)
    scheduler.step()

    # Evaluate only on every 10th epoch.
    if epoch % 10 != 0:
        continue

    # ---- Evaluate ----
    # NOTE(review): the output is labelled "Validation" but the metrics are
    # computed on test_loader; valid_loader is not used in this cell — confirm
    # that this is intended.
    print("<< Validation >>")
    correct = 0
    valid_loss = 0.0
    val_pred = []   # predicted class per sample
    val_real = []   # ground-truth class per sample
    val_score = []  # softmax probability of class 1 per sample (for the ROC curve)

    model.eval()
    with torch.no_grad():
        for data in test_loader:
            x_data, stage = data[0].to(device), data[1].to(device)

            pred_value, _ = model(x_data)
            pred_class = torch.argmax(pred_value, dim=1)
            loss = criterion(pred_value, stage)

            valid_loss += loss.item()
            correct += (pred_class == stage).sum().item()

            val_pred.extend(pred_class.cpu().numpy())
            val_real.extend(stage.cpu().numpy())
            val_score.extend(torch.softmax(pred_value, dim=1)[:, 1].cpu().numpy())

    # Average over batches so this loss is on the same scale as running_loss
    # (previously the per-batch losses were summed but never normalised).
    valid_loss /= max(len(test_loader), 1)
    acc = correct / len(test_loader.dataset)
    f1 = f1_score(val_real, val_pred, average="macro")
    # One-vs-rest ROC for class 1. roc_curve expects (y_true, y_score); the
    # earlier call passed the predictions as y_true and hard labels as scores.
    # The result is not consumed in this cell but stays available as
    # notebook-level variables.
    fpr, tpr, thresholds = roc_curve(val_real, val_score, pos_label=1)

    if best_acc < acc:
        best_acc = acc
        # Optional checkpoint: torch.save(model.state_dict(), "./model/depthwiseNet.pth")

    if best_f1 < f1:
        best_f1 = f1
        # Optional checkpoint: torch.save(model.state_dict(), "./log/Base_3class.pth")

    print(f"(Valid) Epoch: {epoch}, Loss: {round(valid_loss, 3)}, Acc: {round(acc, 3)}, F1: {round(f1, 3)}")

    # Confusion matrix for this evaluation round; the figure is closed after
    # display so figures do not accumulate across rounds.
    fig, ax = plt.subplots()
    sns.heatmap(
        confusion_matrix(val_real, val_pred),
        annot=True,
        fmt="d",
        cmap="Blues",
        cbar=False,
        ax=ax,
    )
    ax.set_title(f"confusion_matrix (client {num})")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    plt.show()
    plt.close(fig)

print("Best ACC: ", best_acc)
print("Best f1: ", best_f1)