#### DataLoader

In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pwd

/content


In [3]:
%cd /content/drive/MyDrive/Colab Notebooks/[Microdegree] Hypotension/code # change directory
!pwd

/content/drive/MyDrive/Colab Notebooks/[Microdegree] Hypotension/code
/content/drive/MyDrive/Colab Notebooks/[Microdegree] Hypotension/code


In [5]:
# !pip install torchsummary
# !pip install tensorboard
# !pip install wandb

In [4]:
import numpy as np
from collections import Counter
from tqdm import tqdm
from matplotlib import pyplot as plt

import os
import glob
import re
import pickle
import multiprocessing
import wandb
import argparse
from datetime import datetime
from pathlib import Path
import random
import json

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

from model_resnet1d import ResNet1D
from pytorchtools import EarlyStopping

# from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve, accuracy_score, precision_score, recall_score, f1_score # classification
from sklearn.metrics import mean_absolute_percentage_error, r2_score # regression

In [5]:
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

def increment_path(path, exist_ok=False):
    """ Automatically increment path, i.e. runs/exp --> runs/exp0, runs/exp1 etc.

    Args:
        path (str or pathlib.Path): f"{model_dir}/{args.name}".
        exist_ok (bool): whether increment path (increment if False).
    """
    path = Path(path)
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}*")
        matches = [re.search(rf"%s(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]
        n = max(i) + 1 if i else 2
        return f"{path}{n}"

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']


_criterion_entrypoints = {
    'binary_cross_entropy': nn.BCEWithLogitsLoss, # nn.BCELoss,
    'cross_entropy': nn.CrossEntropyLoss,
}

def criterion_entrypoint(criterion_name):
    return _criterion_entrypoints[criterion_name]

def is_criterion(criterion_name):
    return criterion_name in _criterion_entrypoints

def create_criterion(criterion_name, **kwargs):
    if is_criterion(criterion_name):
        create_fn = criterion_entrypoint(criterion_name)
        criterion = create_fn(**kwargs)
    else:
        raise RuntimeError('Unknown loss (%s)' % criterion_name)
    return criterion

now = datetime.now()
folder_name = now.strftime('%Y-%m-%d-%H:%M:%S')
parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=42, help='random seed (default: 42)')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 1)')
parser.add_argument('--batch_size', type=int, default=128, help='input batch size for training (default: 64)')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 1e-3)')
# parser.add_argument('--lr_decay_step', type=int, default=20, help='learning rate scheduler deacy step (default: 20)')
parser.add_argument('--criterion', type=str, default='cross_entropy', help='criterion type (default: cross_entropy)')
parser.add_argument('--log_interval', type=int, default=150, help='how many batches to wait before logging training status')
parser.add_argument('--model', type=str, default='resnet1d', help='model type (default: BaseModel)')
parser.add_argument('--name', default='exp', help='model save at {SM_MODEL_DIR}/{name}')
# parser.add_argument('--name', default='exp_'+folder_name, help='model save at {SM_MODEL_DIR}/{name}')
parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR', './model_clf'))

# args = parser.parse_args()
args, _ = parser.parse_known_args()
print(args)

Namespace(seed=42, epochs=100, batch_size=128, lr=0.001, criterion='cross_entropy', log_interval=150, model='resnet1d', name='exp', model_dir='./model')


In [6]:
model_dir = args.model_dir
# save_dir = increment_path(os.path.join(model_dir, args.name))
# save_dir = increment_path(os.path.join(model_dir, args.name+"_"+args.criterion+"_"+str(args.epochs)))
save_dir = increment_path(os.path.join(model_dir, args.criterion+"_epoch"+str(args.epochs)+"_batch"+str(args.batch_size)+"_"+args.name))

print(model_dir)
print(save_dir)

./model
model/cross_entropy_epoch100_batch128_exp


In [7]:
class PPGDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __getitem__(self, index):
        return (torch.as_tensor(self.data[index], dtype=torch.float), torch.as_tensor(self.label[index], dtype=torch.long)) # torch.long

    def __len__(self):
        return len(self.data)

In [8]:
# get dataset
with open('../data_clf/train_clf_scaled_x.pkl', 'rb') as f:
    # train_X = pickle.load(f)
    X = pickle.load(f)
    train_X = np.expand_dims(X, 1)
    train_X = torch.tensor(train_X, dtype = torch.float32)
with open('../data_clf/train_clf_y.pkl', 'rb') as f:
    # train_Y = pickle.load(f)
    Y = pickle.load(f)
    train_Y = np.expand_dims(Y, 1)
    train_Y = torch.tensor(train_Y, dtype = torch.float32)

with open('../data_clf/valid_clf_scaled_x.pkl', 'rb') as f:
    # valid_X = pickle.load(f)
    X = pickle.load(f)
    valid_X = np.expand_dims(X, 1)
    valid_X = torch.tensor(valid_X, dtype = torch.float32)
with open('../data_clf/valid_clf_y.pkl', 'rb') as f:
    # valid_Y = pickle.load(f)
    Y = pickle.load(f)
    valid_Y = np.expand_dims(Y, 1)
    valid_Y = torch.tensor(valid_Y, dtype = torch.float32)

with open('../data_clf/test_clf_scaled_x.pkl', 'rb') as f:
    # test_X = pickle.load(f)
    X = pickle.load(f)
    test_X = np.expand_dims(X, 1)
    test_X = torch.tensor(test_X, dtype = torch.float32)
with open('../data_clf/test_clf_y.pkl', 'rb') as f:
    # test_Y = pickle.load(f)
    Y = pickle.load(f)
    test_Y = np.expand_dims(Y, 1)
    test_Y = torch.tensor(test_Y, dtype = torch.float32)

train_dataset = PPGDataset(train_X, train_Y)
val_dataset = PPGDataset(valid_X, valid_Y)
test_dataset = PPGDataset(test_X, test_Y)

print(type(train_X), type(train_Y))
print(train_X.shape, train_Y.shape)
# print(len(train_X), len(train_Y))
# print(len(train_X[0]), train_X[0])
# print(train_Y[:20])
print()
print(type(valid_X), type(valid_Y))
print(valid_X.shape, valid_Y.shape)
# print(len(valid_X), len(valid_Y))
# print(len(valid_X[0]), valid_X[0])
# print(valid_Y[:20])
print()
print(type(test_X), type(test_Y))
print(test_X.shape, test_Y.shape)
# print(len(test_X), len(test_Y))
# print(len(test_X[0]), test_X[0])
# print(test_Y[:20])
print()

print(train_dataset)
print(val_dataset)
print(test_dataset)

# <class 'torch.Tensor'> <class 'torch.Tensor'>
# torch.Size([189429, 1, 3000]) torch.Size([189429, 1])

# <class 'torch.Tensor'> <class 'torch.Tensor'>
# torch.Size([62746, 1, 3000]) torch.Size([62746, 1])

# <class 'torch.Tensor'> <class 'torch.Tensor'>
# torch.Size([62746, 1, 3000]) torch.Size([62746, 1])

<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([189429, 1, 3000]) torch.Size([189429, 1])

<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([62746, 1, 3000]) torch.Size([62746, 1])

<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([60287, 1, 3000]) torch.Size([60287, 1])

<__main__.PPGDataset object at 0x7a3b56a12ad0>
<__main__.PPGDataset object at 0x7a3b56a12a70>
<__main__.PPGDataset object at 0x7a3b56a12b60>


In [9]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=multiprocessing.cpu_count() // 2,
    shuffle=True,
    pin_memory=use_cuda,
    drop_last=True,
    )

val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    num_workers=multiprocessing.cpu_count() // 2,
    shuffle=False,
    pin_memory=use_cuda,
    drop_last=True,
    )

test_dataloader = DataLoader(
    test_dataset,
    batch_size=args.batch_size, # 1024, #
    num_workers=multiprocessing.cpu_count() // 2,
    shuffle=False,
    pin_memory=use_cuda,
    drop_last=False,
)

print(device)
print(train_dataloader)
print(val_dataloader)
print(test_dataloader)

cuda
<torch.utils.data.dataloader.DataLoader object at 0x7a3c6821fa30>
<torch.utils.data.dataloader.DataLoader object at 0x7a3b56a13cd0>
<torch.utils.data.dataloader.DataLoader object at 0x7a3b56a128c0>


In [10]:
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "notebook name here"

# !pip install wandb -qqq
# import wandb
wandb.login()
# !wandb login --relogin

[34m[1mwandb[0m: Currently logged in as: [33mhyenagatha02[0m ([33msixseg[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [11]:
# -- wandb initialize with configuration
config={
    "epochs": args.epochs,
    "batch_size": args.batch_size,
    "learning_rate" : args.lr,
    "architecture" : args.model,
    "criterion" : args.criterion
}
wandb.init(entity='hyenagatha02', project="KAIST GSDS Microdegree - Hypotension", name = str(save_dir.split('/')[-1])+str(args.model)+str(args.epochs), config=config)

[34m[1mwandb[0m: Currently logged in as: [33mhyenagatha02[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
# make model
# device_str = "cuda"
# device = torch.device(device_str if torch.cuda.is_available() else "cpu")

## change the hyper-parameters for your own data
# (n_block, downsample_gap, increasefilter_gap) = (8, 1, 2)
# 34 layer (16*2+2): 16, 2, 4
# 98 layer (48*2+2): 48, 6, 12

model = ResNet1D(
    in_channels=1, # 3000,
    base_filters=128, # 128, # 64 for ResNet1D, 352 for ResNeXt1D
    kernel_size= 16, # kernel_size,
    stride=2, # stride,
    groups=32,
    n_block=16, # 48, # n_block=48,
    n_classes=2, # 4, 3
    downsample_gap=2, # 6, # downsample_gap,
    increasefilter_gap=4, # 12, # increasefilter_gap,
    use_do=True)
model.to(device)
print(train_X.shape, train_Y.shape)
summary(model, (train_X.shape[1], train_X.shape[2])) # device=device
# exit()

model.verbose = False # True
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.lr)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, gamma=0.5)
# loss_func = torch.nn.CrossEntropyLoss()
criterion = create_criterion(args.criterion)  # default: cross_entropy

logger = SummaryWriter(log_dir=save_dir)
with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
    json.dump(vars(args), f, ensure_ascii=False, indent=4)


torch.Size([189429, 1, 3000]) torch.Size([189429, 1])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1            [-1, 128, 3000]           2,176
   MyConv1dPadSame-2            [-1, 128, 3000]               0
       BatchNorm1d-3            [-1, 128, 3000]             256
              ReLU-4            [-1, 128, 3000]               0
            Conv1d-5            [-1, 128, 3000]           8,320
   MyConv1dPadSame-6            [-1, 128, 3000]               0
       BatchNorm1d-7            [-1, 128, 3000]             256
              ReLU-8            [-1, 128, 3000]               0
           Dropout-9            [-1, 128, 3000]               0
           Conv1d-10            [-1, 128, 3000]           8,320
  MyConv1dPadSame-11            [-1, 128, 3000]               0
       BasicBlock-12            [-1, 128, 3000]               0
      BatchNorm1d-13            [-1, 128, 3000]  

In [None]:
import warnings
warnings.filterwarnings('ignore') # "error", "ignore", "always", "default", "module" or "once"

# early_stopping
early_stopping = EarlyStopping(patience = 7, verbose = True)

# train
best_val_loss = np.inf
best_val_acc = 0
best_val_auc = 0
best_val_recall = 0

# for _ in tqdm(range(args.epochs), desc="epoch", leave=False):
for epoch in range(args.epochs):

    # train loop
    model.train()

    loss_value = 0
    matches = 0
    train_preds_by_batch = []
    train_labels_by_batch = []

    # prog_iter = tqdm(train_dataloader, desc="Training", leave=False)
    print("Training...")
    # for batch_idx, train_batch in enumerate(prog_iter):
    for batch_idx, train_batch in enumerate(train_dataloader):
        input_x, input_y = tuple(t.to(device) for t in train_batch)
        # input_x, input_y = tuple(t for t in train_batch)
        # input_x = input_x.to(device).float()
        # input_y = input_y.to(device).long()

        optimizer.zero_grad()
        # preds = model(input_x)
        outs = model(input_x) # torch.Size([128, 2])
        preds = torch.argmax(outs, dim=-1) # torch.Size([128])
        input_y = input_y.squeeze_() # torch.Size([128, 1]) -> torch.Size([128])

        # print('outs : ', type(outs), outs.dtype, outs.shape, outs)
        # print('preds : ', type(preds), preds.dtype, preds.shape, preds)
        # print('input_y : ', type(input_y), input_y.dtype, input_y.shape, input_y)

        # loss = criterion(preds, input_y) # binary_cross_entropy
        # loss.requires_grad_(True) # binary_cross_entropy
        loss = criterion(outs, input_y) # cross_entropy

        loss.backward()
        optimizer.step()

        loss_value += loss.item()
        matches += (preds == input_y).sum().item()

        train_preds_by_batch.extend(preds.cpu().numpy())
        train_labels_by_batch.extend(input_y.cpu().numpy())

        if (batch_idx + 1) % args.log_interval == 0:
            train_loss = loss_value / args.log_interval
            train_acc = matches / args.batch_size / args.log_interval

            train_auc = roc_auc_score(train_labels_by_batch, train_preds_by_batch)
            train_precision = precision_score(train_labels_by_batch, train_preds_by_batch)
            train_recall = recall_score(train_labels_by_batch, train_preds_by_batch)
            train_f1 = f1_score(train_labels_by_batch, train_preds_by_batch)

            current_lr = get_lr(optimizer)
            print(
                f"Epoch[{epoch+1}/{args.epochs}]({batch_idx + 1}/{len(train_dataloader)}) || "
                f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || training auc {train_auc:4.4f} || "
                f"training precision {train_precision:4.4f} || training recall {train_recall:4.4f} || training f1_score {train_f1:4.4f} || lr {current_lr}"
            )
            logger.add_scalar("Train/loss", train_loss, epoch * len(train_dataloader) + batch_idx)
            logger.add_scalar("Train/accuracy", train_acc, epoch * len(train_dataloader) + batch_idx)
            logger.add_scalar("Train/auc", train_auc, epoch * len(train_dataloader) + batch_idx)
            logger.add_scalar("Train/precision", train_precision, epoch * len(train_dataloader) + batch_idx)
            logger.add_scalar("Train/recall", train_recall, epoch * len(train_dataloader) + batch_idx)
            logger.add_scalar("Train/f1_score", train_f1, epoch * len(train_dataloader) + batch_idx)

            loss_value = 0
            matches = 0
            train_preds_by_batch = []
            train_labels_by_batch = []

    # logging wandb train phase
    wandb.log({
        'Train loss': train_loss,
        'Train acc': train_acc,
        'Train auc': train_auc,
        'Train precision' : train_precision,
        'Train recall': train_recall,
        'Train f1_score' : train_f1,
    })

    # scheduler.step(_)
    # scheduler.step(train_loss)

    # val loop
    with torch.no_grad():

        model.eval()
        val_loss_items = []
        val_acc_items = []

        # val_auc_items = []
        # val_precision_items = []
        # val_recall_items = []
        # val_f1_items = []
        # val_pass_count = 0 # for try ~ except auc

        all_val_preds = []
        all_val_labels = []

        # prog_iter_test = tqdm(val_dataloader, desc="Testing", leave=False)
        print()
        print("Calculating validation results...")
        # for batch_idx, val_batch in enumerate(prog_iter_test):
        for val_batch in val_dataloader:
            input_x, input_y = tuple(t.to(device) for t in val_batch)
            # input_x, input_y = tuple(t for t in val_batch)
            # input_x = input_x.to(device).float()
            # input_y = input_y.to(device).long()

            # preds = model(input_x)
            outs = model(input_x)
            preds = torch.argmax(outs, dim=-1)
            input_y = input_y.squeeze_()

            # print('outs : ', type(outs), outs.dtype, outs.shape, outs)
            # print('preds : ', type(preds), preds.dtype, preds.shape, preds)
            # print('input_y : ', type(input_y), input_y.dtype, input_y.shape, input_y)

            # loss_item = criterion(preds, input_y).item() # binary_cross_entropy
            loss_item = criterion(outs, input_y).item() # cross_entrypy
            acc_item = (input_y == preds).sum().item()

            all_val_preds.extend(preds.cpu().numpy())
            all_val_labels.extend(input_y.cpu().numpy())

            # auc_item = roc_auc_score(input_y.cpu().numpy(), preds.cpu().numpy()).item()
            # try:
            #     auc_item = roc_auc_score(input_y.cpu().numpy(), preds.cpu().numpy()).item()
            # except ValueError:
            #     auc_item = 0 # pass
            #     val_pass_count += 1
            # precision_item = precision_score(input_y.cpu().numpy(), preds.cpu().numpy()).item()
            # recall_item = recall_score(input_y.cpu().numpy(), preds.cpu().numpy()).item()
            # f1_item = f1_score(input_y.cpu().numpy(), preds.cpu().numpy()).item()

            val_loss_items.append(loss_item)
            val_acc_items.append(acc_item)

            # val_auc_items.append(auc_item)
            # val_precision_items.append(precision_item)
            # val_recall_items.append(recall_item)
            # val_f1_items.append(f1_item)

    val_loss = np.sum(val_loss_items) / len(val_dataloader)
    val_acc = np.sum(val_acc_items) / len(val_dataset)
    best_val_loss = min(best_val_loss, val_loss)

    # val_auc = np.sum(val_auc_items) / len(val_dataloader)
    # val_precision = np.sum(val_precision_items) / len(val_dataloader)
    # val_recall = np.sum(val_recall_items) / len(val_dataloader)
    # val_f1 = np.sum(val_f1_items) / len(val_dataloader)
    # best_val_acc = max(best_val_acc, val_acc)
    # best_val_auc = max(best_val_auc, val_auc)

    val_auc = roc_auc_score(all_val_labels, all_val_preds)
    val_precision = precision_score(all_val_labels, all_val_preds)
    val_recall = recall_score (all_val_labels, all_val_preds)
    val_f1 = f1_score (all_val_labels, all_val_preds)
    best_val_acc = max(best_val_acc, val_acc)
    best_val_auc = max(best_val_auc, val_auc)
    best_val_recall = max(best_val_recall, val_recall)

    # early stopping
    early_stopping(val_loss, model) # 현재 과적합 상황 추적
    if early_stopping.early_stop: # 조건 만족 시 조기 종료
        break
    
    if val_auc > best_val_acc:
        print(f"New best model for val AUC : {val_auc:4.4f}! saving the best model..")
        torch.save(model.state_dict(), f"{save_dir}/best.pth")
        best_val_auc = val_auc
    torch.save(model.state_dict(), f"{save_dir}/last.pth")
    print(
        f"[Val] acc : {val_acc:4.2%}, loss : {val_loss:4.4}, auc : {val_auc:4.4f}, recall : {val_recall:4.4f}, precision : {val_precision:4.4f}, f1_score : {val_f1:4.4f} || "
        f"Best acc : {best_val_acc:4.2%}, Best loss : {best_val_loss:4.4}, best AUC : {best_val_auc:4.4f}, best Recall : {best_val_recall:4.4f}"
    )
    logger.add_scalar("Val/loss", val_loss, epoch)
    logger.add_scalar("Val/accuracy", val_acc, epoch)
    logger.add_scalar("Val/auc", val_auc, epoch)
    logger.add_scalar("Val/precision", val_precision, epoch)
    logger.add_scalar("Val/recall", val_recall, epoch)
    logger.add_scalar("Val/f1_score", val_f1, epoch)
    print()

    # logging wandb valid phase
    wandb.log({
        'Valid loss': val_loss,
        'Valid acc': val_acc,
        'Valid auc': val_auc,
        'Valid precision' : val_precision,
        'Valid recall': val_recall,
        'Valid f1_score' : val_f1,
    })

wandb.finish()

In [None]:
# test loop
with torch.no_grad():
    
    model.eval()

    test_loss_items = []
    test_acc_items = []
    all_test_preds = []
    all_test_labels = []
    
    # prog_iter_test = tqdm(test_dataloader, desc="Testing", leave=False)
    print("Testing results...")
    # for batch_idx, test_batch in enumerate(prog_iter_test):
    for test_batch in test_dataloader:
        input_x, input_y = tuple(t.to(device) for t in test_batch)

        outs = model(input_x)
        preds = torch.argmax(outs, dim=-1)
        input_y = input_y.squeeze_()
        
        # loss_item = criterion(preds, input_y).item() # binary_cross_entropy
        loss_item = criterion(outs, input_y).item() # cross_entrypy
        acc_item = (input_y == preds).sum().item()

        test_loss_items.append(loss_item)
        test_acc_items.append(acc_item)

        all_test_preds.extend(preds.cpu().numpy())
        all_test_labels.extend(input_y.cpu().numpy())

test_loss = np.sum(test_loss_items) / len(test_dataloader)
test_acc = np.sum(test_acc_items) / len(test_dataset)
test_auc = roc_auc_score(all_test_labels, all_test_preds)
test_precision = precision_score(all_test_labels, all_test_preds)
test_recall = recall_score (all_test_labels, all_test_preds)
test_f1 = f1_score (all_test_labels, all_test_preds)

# test_loss_items = []
# test_acc_items = []

# for pred, label in zip(preds, test_Y.cpu().numpy()):
#     loss_item = criterion(pred, label).item()
#     acc_item = (pred == label).sum().item()

#     test_loss_items.append(loss_item)
#     test_acc_items.append(acc_item)

# test_loss = np.sum(test_loss_items) / len(test_dataloader)
# test_acc = np.sum(test_acc_items) / len(test_dataset)

# test_auc = roc_auc_score(all_test_labels, all_test_preds)
# test_precision = precision_score(all_test_labels, all_test_preds)
# test_recall = recall_score (all_test_labels, all_test_preds)
# test_f1 = f1_score (all_test_labels, all_test_preds)

print(
    f"[test] acc : {test_acc:4.2%}, loss: {test_loss:4.4}, auc : {test_auc:4.4f}, recall : {test_recall:4.4f}, precision : {test_precision:4.4f}, f1_score : {test_f1:4.4f} ||"
)

output_dir = os.environ.get('SM_OUTPUT_DATA_DIR', './output')
save_path = os.path.join(output_dir, f'./preds.pkl')
pickle.dump((all_test_preds, all_test_labels), open(save_path, 'wb'))
print(f"Inference Done! Inference result saved at {save_path}")
