<a href="https://colab.research.google.com/github/yangxuan8/EVSIvsLSTM/blob/main/run_baseline_26_02_2024.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')
import sys
sys.path.append('/content/drive/MyDrive')

Mounted at /content/drive


In [2]:
!pip install pytorch_lightning
import torch
import pickle
import argparse
import numpy as np
import torch.nn as nn
from torchmetrics import AUROC, Accuracy
from torch.utils.data import DataLoader, random_split, TensorDataset
import pytorch_lightning as pl
import pandas as pd
sys.path.append('/content/drive/MyDrive/DIME_main/experiments/hev')
import feature_groups
sys.path.append('/content/drive/MyDrive/DIME_main')
import dime
from dime.data_utils import HEVDataset, get_group_matrix, get_xy
from dime import MaskingPretrainer
from dime.utils import StaticMaskLayer1d, ConcreteMask, get_confidence, MaskLayerGrouped, get_mlp_network
import torch.optim as optim
from tqdm import tqdm
from baseline_models.base_model import BaseModel
sys.path.append('/content/drive/MyDrive/DIME_main/experiments')
from baselines import eddi, pvae, iterative, dfs, cae

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.2.0.post0-py3-none-any.whl (800 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m800.9/800.9 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.3.1-py3-none-any.whl (840 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.4/840.4 kB[0m [31m39.0 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.10.1-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics, pytorch_lightning
Successfully installed lightning-utilities-0.10.1 pytorch_lightning-2.2.0.post0 torchmetrics-1.3.1


In [3]:
# Set up command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--method', type=str, default='dime', choices=['cae', 'eddi', 'dfs', 'dime', 'fully_supervised'])
parser.add_argument('--use_feature_costs', default=False, action="store_true")
parser.add_argument('--num_trials', type=int, default=3)

_StoreAction(option_strings=['--num_trials'], dest='num_trials', nargs=None, const=None, default=3, type=<class 'int'>, choices=None, required=False, help=None, metavar=None)

In [5]:
if __name__ == '__main__':
    hev_feature_names = feature_groups.hev_feature_names
    hev_feature_groups = feature_groups.hev_feature_groups
    auc_metric = AUROC(task='multiclass', num_classes=3)
    acc_metric = Accuracy(task='multiclass', num_classes=3)

    # Parse args
    args = parser.parse_known_args()[0]
    device = torch.device('cuda', args.gpu)
    num_trials = args.num_trials
    cols_to_drop = []
    if cols_to_drop is not None:
        hev_feature_names = [item for item in hev_feature_names if str(hev_feature_names.index(item)) not in cols_to_drop]
    # Load dataset
    dataset = HEVDataset(data_dir=1, cols_to_drop=cols_to_drop)
    d_in = dataset.X.shape[1]  # 32
    d_out = len(np.unique(dataset.Y))  # 3

    # Get features and groups
    feature_groups_dict, feature_groups_mask = get_group_matrix(hev_feature_names, hev_feature_groups)
    feature_group_indices = {i : key for i, key in enumerate(feature_groups_dict.keys())}
    feat_to_ind = {key: i for i, key in enumerate(hev_feature_names)}

    num_groups = len(feature_groups_mask)  # 32
    print("Num groups=", num_groups)
    print("Num features=", d_in)

    # Split dataset
    train_dataset, val_dataset, test_dataset = random_split(dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(0))
    daataset_dict = dict(train_dataset=train_dataset, val_dataset=val_dataset, test_dataset=test_dataset)
    f = open('/content/drive/MyDrive/dataset/data/dataset.pkl', "wb", pickle.HIGHEST_PROTOCOL)
    pickle.dump(daataset_dict, f)

    print(f'Train samples = {len(train_dataset)}, val samples = {len(val_dataset)}, test samples = {len(test_dataset)}')

    # Find mean/variance for normalizing
    x, y = get_xy(train_dataset)
    mean = np.mean(x, axis=0)
    std = np.clip(np.std(x, axis=0), 1e-3, None)

    # Normalize via the original dataset
    if args.method == 'eddi':
        dataset.X = (dataset.X - mean)/std
    else:
        dataset.X = dataset.X - mean

    # Set up data loaders.
    train_dataloader = DataLoader(
        train_dataset, batch_size=128, shuffle=True, pin_memory=True,
        drop_last=True, num_workers=4)

    val_dataloader = DataLoader(
        val_dataset, batch_size=128, shuffle=False, pin_memory=True, drop_last=True, num_workers=4)

    test_dataloader = DataLoader(
        test_dataset, batch_size=128, shuffle=False, pin_memory=True, drop_last=True, num_workers=4)


    mask_layer = MaskLayerGrouped(append=True, group_matrix=torch.tensor(feature_groups_mask))
    num_features = [1, 3, 5, 10, 15, 20, 25, 30, 32]
    use_feature_costs = False
    feature_costs = None
    if args.use_feature_costs:
        feature_cost_df = pd.read_csv("data/feature_list_hev-nw.csv")
        feature_costs = [feature_cost_df[feature_cost_df['Feature Name'] == feature]['Cost (Hours)'].item() for feature in list(feature_groups_dict.keys())]
        use_feature_costs = True

    for trial in range(num_trials):

        results_dict = {
            'acc': {},
            'features': {}
        }


        if args.method == 'eddi':
            # Train PVAE.
            bottleneck = 16
            hidden = 128
            dropout = 0.3
            encoder = get_mlp_network(d_in + num_groups, bottleneck * 2)
            decoder = get_mlp_network(bottleneck, d_in)

            pv = pvae.PVAE(encoder, decoder, mask_layer, 20, 'gaussian').to(device)
            pv.fit(
                train_dataloader,
                val_dataloader,
                lr=1e-3,
                nepochs=10,
                verbose=False)

            # Train masked predictor.
            model = get_mlp_network(d_in + num_groups, d_out)
            sampler = None
            # if trial == 0:
            sampler = iterative.UniformSampler(get_xy(train_dataset)[0])  # TODO don't actually need sampler
            iterative_selector = iterative.IterativeSelector(model, mask_layer, sampler).to(device)
            iterative_selector.fit(
                train_dataloader,
                val_dataloader,
                lr=1e-3,
                nepochs=10,
                loss_fn=nn.CrossEntropyLoss(),
                patience=5,
                verbose=True)

            # Set up EDDI feature selection object.
            eddi_selector = eddi.EDDI(pv, model, mask_layer, feature_costs=feature_costs).to(device)

            # Evaluate.
            metrics_dict, cost_dict = eddi_selector.evaluate_multiple(test_dataloader, num_features, auc_metric, verbose=False)
            for num in num_features:
                acc = metrics_dict[num]
                results_dict['acc'][num] = acc
                print(f'Num = {num}, Acc = {100*acc:.2f}')

            print(results_dict)
            print(cost_dict)
            with open(f'/content/drive/MyDrive/dataset/results/hev_{args.method}_trial_{trial}_feature_costs_{use_feature_costs}.pkl', 'wb') as f:
                pickle.dump(results_dict, f)

            with open(f'/content/drive/MyDrive/dataset/results/hev_costs_{args.method}_trial_{trial}_feature_costs_{use_feature_costs}.pkl', 'wb') as f:
                pickle.dump(cost_dict, f)

        if args.method == 'dfs':
            max_features = 32

            # Prepare networks.
            predictor = get_mlp_network(d_in + num_groups, d_out)
            selector = get_mlp_network(d_in + num_groups, num_groups)

            # Pretrain predictor

            pretrain = MaskingPretrainer(
                predictor,
                mask_layer,
                lr=1e-3,
                loss_fn=nn.CrossEntropyLoss(),
                val_loss_fn=auc_metric)

            trainer = pl.Trainer(max_epochs=10)
            trainer.fit(pretrain, train_dataloader, val_dataloader)

            """
            pretrain = MaskingPretrainer(predictor, mask_layer,).to(device)

            pretrain.fit(
                train_dataset,
                val_dataset,
                lr=1e-3,
                val_loss_fn=auc_metric,
                val_loss_mode='max',
                nepochs=10,
                loss_fn=nn.CrossEntropyLoss(),
                patience=5,
                verbose=True)
            """


            # Train selector and predictor jointly.
            gdfs = dfs.GreedyDynamicSelection(selector, predictor, mask_layer).to(device)
            gdfs.fit(
                train_dataloader,
                val_dataloader,
                lr=1e-3,
                nepochs=10,
                max_features=max_features,
                loss_fn=nn.CrossEntropyLoss(),
                patience=3,
                verbose=True)

            # Evaluate.
            for num in num_features:
                auroc_list = []
                acc_list = []

                auroc, acc = gdfs.evaluate(test_dataloader, num, (auc_metric, acc_metric))
                #results_dict['acc'][num] = acc
                #print(f'Num = {num}, Acc = {100*acc:.2f}')
                auroc_list.append(auroc)
                acc_list.append(acc)
                print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')

            with open(f'/content/drive/MyDrive/dataset/results/hev_{args.method}_trial_{trial}.pkl', 'wb') as f:
                pickle.dump(results_dict, f)

            # Save model
            gdfs.cpu()
            torch.save(gdfs, f'/content/drive/MyDrive/dataset/results/hev_{args.method}_trial_{trial}.pt')

        if args.method == 'dime':
            max_features = 32

            # Prepare networks.
            hidden = 128
            dropout = 0.3
            predictor = nn.Sequential(
                      nn.Linear(d_in + num_groups, hidden),
                      nn.ReLU(),
                      nn.Dropout(dropout),
                      nn.Linear(hidden, hidden),
                      nn.ReLU(),
                      nn.Dropout(dropout),
                      nn.Linear(hidden, d_out))

            selector = nn.Sequential(
                      nn.Linear(d_in + num_groups, hidden),
                      nn.ReLU(),
                      nn.Dropout(dropout),
                      nn.Linear(hidden, hidden),
                      nn.ReLU(),
                      nn.Dropout(dropout),
                      nn.Linear(hidden, num_groups))

            # Pretrain predictor

            pretrain = MaskingPretrainer(
                predictor,
                mask_layer,
                lr=1e-3,
                loss_fn=nn.CrossEntropyLoss(),
                val_loss_fn=auc_metric)

            trainer = pl.Trainer(max_epochs=30)
            trainer.fit(pretrain, train_dataloader, val_dataloader)

            # Train selector and predictor jointly.
            print("Training CMI estimator")
            print("-"*8)
            gdfs = dfs.GreedyDynamicSelection(selector, predictor, mask_layer).to(device)
            gdfs.fit(
                train_dataloader,
                val_dataloader,
                lr=1e-3,
                nepochs=30,
                max_features=max_features,
                #eps=0.1,
                loss_fn=nn.CrossEntropyLoss(),
                patience=5,
                #feature_costs=feature_costs,
                verbose=True)

            # Evaluate.
            for num in num_features:
                auroc_list = []
                acc_list = []

                #acc = gdfs.evaluate(test_dataloader, num, auc_metric)
                #results_dict['acc'][num] = acc
                #print(f'Num = {num}, Acc = {100*acc:.2f}')
                auroc, acc = gdfs.evaluate(test_dataloader, num, (auc_metric, acc_metric))
                auroc_list.append(auroc)
                acc_list.append(acc)
                print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')


            with open(f'/content/drive/MyDrive/dataset/results/hev_{args.method}_trial_{trial}.pkl', 'wb') as f:
                pickle.dump(results_dict, f)

            # Save model
            gdfs.cpu()
            torch.save(gdfs, f'/content/drive/MyDrive/dataset/results/hev_{args.method}_trial_{trial}.pt')

        # Train with full input
        if args.method == 'fully_supervised':
            model = get_mlp_network(d_in, d_out).to(device)
            opt = optim.Adam(model.parameters(), lr=1e-3)
            criterion = torch.nn.CrossEntropyLoss()
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                    opt, mode='min', factor=0.2, patience=5,
                    min_lr=1e-5, verbose=True)

            num_bad_epochs = 0
            early_stopping_epochs = 6

            for epoch in range(100):
                model.train()
                train_batch_loss = 0
                val_batch_loss = 0
                val_pred_list = []
                val_y_list = []

                for i, (x, y) in enumerate(tqdm(train_dataloader)):
                    x = x.to(device)
                    y = y.to(device)

                    pred = model(x)
                    train_loss = criterion(pred, y)
                    train_batch_loss += train_loss.item()
                    train_loss.backward()
                    opt.step()
                    model.zero_grad()

                model.eval()

                with torch.no_grad():
                    for i, (x, y) in enumerate(tqdm(val_dataloader)):
                        x = x.to(device)
                        y = y.to(device)

                        pred = model(x)
                        val_loss = criterion(pred, y)
                        val_batch_loss += val_loss.item()
                        val_pred_list.append(pred.cpu())
                        val_y_list.append(y.cpu())

                    scheduler.step(val_batch_loss/len(val_dataloader))
                    val_loss = val_batch_loss/len(val_dataloader)
                    # Check if best model.
                    if val_loss == scheduler.best:
                        # best_model = deepcopy(model)
                        num_bad_epochs = 0
                    else:
                        num_bad_epochs += 1

                    # Early stopping.
                    if num_bad_epochs > early_stopping_epochs:
                        print(f'Stopping early at epoch {epoch+1}')
                        break

                print(f"Epoch: {epoch}, Train Loss: {train_batch_loss/len(train_dataloader)},"
                      + f"Val Loss: {val_batch_loss/len(val_dataloader)},"
                      + f"Val Performance: {auc_metric(torch.cat(val_pred_list), torch.cat(val_y_list))}")

            print("Evaluating on test set")

            model.eval()
            confidence_list = []
            test_pred_list = []
            test_y_list = []
            for i, (x, y) in enumerate(tqdm(test_dataloader)):
                x = x.to(device)
                y = y.to(device)

                pred = model(x)
                test_pred_list.append(pred.cpu())
                test_y_list.append(y.cpu())

                confidence_list.append(get_confidence(pred.cpu()))

            print(f"Test Performance:{auc_metric(torch.cat(test_pred_list), torch.cat(test_y_list))}")
            with open('/content/drive/MyDrive/dataset/confidence.npy', 'wb') as f:
                np.save(f, np.array(torch.cat(confidence_list).detach().numpy()))


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name        | Type             | Params
-------------------------------------------------
0 | model       | Sequential       | 25.2 K
1 | mask_layer  | MaskLayerGrouped | 0     
2 | loss_fn     | CrossEntropyLoss | 0     
3 | val_loss_fn | MulticlassAUROC  | 0     
-------------------------------------------------
25.2 K    Trainable params
0         Non-trainable params
25.2 K    Total params
0.101     Total estimated model params size (MB)


Index(['faultNumber', 'VelocityRef:1', '<xdot>', '<BattSoc>', '<BattPwr>',
       '<Cltch1State>', '<Cltch2State>', '<BattV>', '<TransGear>', '<EngSpd>',
       '<IntkVlvLift>', '<EngTrq>', '<ThrPosPct>', '<WgAreaPct>',
       '<EgrVlvAreaPct>', '<VarCompRatioPos>', '<Acc>', '<Dec>', '<IgSw>',
       '<Chrg>', 'TransGear', 'BrkCmd', 'Cltch1Cmd', '<MotTrq>', '<StartTrq>',
       'StartCmd', 'MotTrqCmd', 'BattCrnt:1', 'MotPwrElec:1', 'MotPwrMech:1',
       'IntkVlvLiftCmd', 'FuelMainSoi', 'FuelFlw'],
      dtype='object')
Num groups= 32
Num features= 32
Train samples = 9749, val samples = 1219, test samples = 1218


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 00004: reducing learning rate of group 0 to 2.0000e-04.


Validation: |          | 0/? [00:00<?, ?it/s]

Training CMI estimator
--------
Starting training with temp = 1.0000



100%|██████████| 76/76 [00:11<00:00,  6.34it/s]
100%|██████████| 9/9 [00:01<00:00,  8.58it/s]


--------Epoch 1 (1 total)--------
Val loss = 2.2617, Zero-temp loss = 2.2604



100%|██████████| 76/76 [00:11<00:00,  6.46it/s]
100%|██████████| 9/9 [00:01<00:00,  8.74it/s]


--------Epoch 2 (2 total)--------
Val loss = 3.1551, Zero-temp loss = 3.1578



100%|██████████| 76/76 [00:12<00:00,  5.96it/s]
100%|██████████| 9/9 [00:01<00:00,  8.15it/s]


--------Epoch 3 (3 total)--------
Val loss = 6.9663, Zero-temp loss = 6.9700



100%|██████████| 76/76 [00:12<00:00,  6.32it/s]
100%|██████████| 9/9 [00:01<00:00,  8.29it/s]


--------Epoch 4 (4 total)--------
Val loss = 2.2339, Zero-temp loss = 2.2368



100%|██████████| 76/76 [00:11<00:00,  6.43it/s]
100%|██████████| 9/9 [00:01<00:00,  8.08it/s]


--------Epoch 5 (5 total)--------
Val loss = 2.2344, Zero-temp loss = 2.2369



100%|██████████| 76/76 [00:11<00:00,  6.36it/s]
100%|██████████| 9/9 [00:01<00:00,  8.39it/s]


--------Epoch 6 (6 total)--------
Val loss = 2.4833, Zero-temp loss = 2.4853



100%|██████████| 76/76 [00:11<00:00,  6.34it/s]
100%|██████████| 9/9 [00:01<00:00,  8.61it/s]


--------Epoch 7 (7 total)--------
Val loss = 2.4386, Zero-temp loss = 2.4391



100%|██████████| 76/76 [00:12<00:00,  6.24it/s]
100%|██████████| 9/9 [00:01<00:00,  8.41it/s]


--------Epoch 8 (8 total)--------
Val loss = 2.7690, Zero-temp loss = 2.7702



100%|██████████| 76/76 [00:11<00:00,  6.42it/s]
100%|██████████| 9/9 [00:01<00:00,  8.29it/s]


--------Epoch 9 (9 total)--------
Val loss = 2.6787, Zero-temp loss = 2.6799



100%|██████████| 76/76 [00:11<00:00,  6.33it/s]
100%|██████████| 9/9 [00:01<00:00,  7.80it/s]


--------Epoch 10 (10 total)--------
Val loss = 3.4468, Zero-temp loss = 3.4481

Epoch 00010: reducing learning rate of group 0 to 2.0000e-04.


100%|██████████| 76/76 [00:12<00:00,  6.33it/s]
100%|██████████| 9/9 [00:01<00:00,  8.29it/s]


--------Epoch 11 (11 total)--------
Val loss = 3.3119, Zero-temp loss = 3.3135

Stopping temp = 1.0000 at epoch 11

Starting training with temp = 0.5623



100%|██████████| 76/76 [00:12<00:00,  6.29it/s]
100%|██████████| 9/9 [00:01<00:00,  7.79it/s]


--------Epoch 1 (12 total)--------
Val loss = 2.9923, Zero-temp loss = 2.9928



100%|██████████| 76/76 [00:11<00:00,  6.39it/s]
100%|██████████| 9/9 [00:01<00:00,  7.83it/s]


--------Epoch 2 (13 total)--------
Val loss = 5.4067, Zero-temp loss = 5.4066



100%|██████████| 76/76 [00:11<00:00,  6.40it/s]
100%|██████████| 9/9 [00:01<00:00,  7.74it/s]


--------Epoch 3 (14 total)--------
Val loss = 4.5835, Zero-temp loss = 4.5837



100%|██████████| 76/76 [00:11<00:00,  6.35it/s]
100%|██████████| 9/9 [00:01<00:00,  7.79it/s]


--------Epoch 4 (15 total)--------
Val loss = 5.0447, Zero-temp loss = 5.0443



100%|██████████| 76/76 [00:11<00:00,  6.38it/s]
100%|██████████| 9/9 [00:01<00:00,  7.93it/s]


--------Epoch 5 (16 total)--------
Val loss = 4.2704, Zero-temp loss = 4.2699



100%|██████████| 76/76 [00:11<00:00,  6.35it/s]
100%|██████████| 9/9 [00:01<00:00,  8.28it/s]


--------Epoch 6 (17 total)--------
Val loss = 6.6022, Zero-temp loss = 6.6016



100%|██████████| 76/76 [00:11<00:00,  6.36it/s]
100%|██████████| 9/9 [00:01<00:00,  8.49it/s]


--------Epoch 7 (18 total)--------
Val loss = 5.9023, Zero-temp loss = 5.9016

Epoch 00007: reducing learning rate of group 0 to 2.0000e-04.


100%|██████████| 76/76 [00:12<00:00,  5.89it/s]
100%|██████████| 9/9 [00:00<00:00, 10.81it/s]


--------Epoch 8 (19 total)--------
Val loss = 5.7588, Zero-temp loss = 5.7580

Stopping temp = 0.5623 at epoch 8

Starting training with temp = 0.3162



100%|██████████| 76/76 [00:12<00:00,  6.11it/s]
100%|██████████| 9/9 [00:00<00:00, 11.45it/s]


--------Epoch 1 (20 total)--------
Val loss = 2.3858, Zero-temp loss = 2.3854



100%|██████████| 76/76 [00:12<00:00,  6.23it/s]
100%|██████████| 9/9 [00:00<00:00, 11.44it/s]


--------Epoch 2 (21 total)--------
Val loss = 5.3220, Zero-temp loss = 5.3220



100%|██████████| 76/76 [00:12<00:00,  6.21it/s]
100%|██████████| 9/9 [00:00<00:00, 11.74it/s]


--------Epoch 3 (22 total)--------
Val loss = 5.3582, Zero-temp loss = 5.3580



100%|██████████| 76/76 [00:12<00:00,  6.27it/s]
100%|██████████| 9/9 [00:00<00:00, 11.63it/s]


--------Epoch 4 (23 total)--------
Val loss = 5.0712, Zero-temp loss = 5.0708



100%|██████████| 76/76 [00:12<00:00,  6.20it/s]
100%|██████████| 9/9 [00:00<00:00, 11.27it/s]


--------Epoch 5 (24 total)--------
Val loss = 6.6405, Zero-temp loss = 6.6405



100%|██████████| 76/76 [00:12<00:00,  6.32it/s]
100%|██████████| 9/9 [00:01<00:00,  8.26it/s]


--------Epoch 6 (25 total)--------
Val loss = 8.6094, Zero-temp loss = 8.6097



100%|██████████| 76/76 [00:12<00:00,  6.03it/s]
100%|██████████| 9/9 [00:00<00:00, 11.66it/s]


--------Epoch 7 (26 total)--------
Val loss = 3.9662, Zero-temp loss = 3.9663

Epoch 00007: reducing learning rate of group 0 to 2.0000e-04.


100%|██████████| 76/76 [00:12<00:00,  6.20it/s]
100%|██████████| 9/9 [00:00<00:00, 11.53it/s]


--------Epoch 8 (27 total)--------
Val loss = 3.5099, Zero-temp loss = 3.5101

Stopping temp = 0.3162 at epoch 8

Starting training with temp = 0.1778



100%|██████████| 76/76 [00:12<00:00,  6.26it/s]
100%|██████████| 9/9 [00:00<00:00, 11.36it/s]


--------Epoch 1 (28 total)--------
Val loss = 1.6132, Zero-temp loss = 1.6132



100%|██████████| 76/76 [00:12<00:00,  6.20it/s]
100%|██████████| 9/9 [00:00<00:00, 11.56it/s]


--------Epoch 2 (29 total)--------
Val loss = 2.7769, Zero-temp loss = 2.7769



100%|██████████| 76/76 [00:12<00:00,  6.28it/s]
100%|██████████| 9/9 [00:00<00:00, 11.47it/s]


--------Epoch 3 (30 total)--------
Val loss = 6.5258, Zero-temp loss = 6.5259



100%|██████████| 76/76 [00:12<00:00,  6.21it/s]
100%|██████████| 9/9 [00:00<00:00, 11.39it/s]


--------Epoch 4 (31 total)--------
Val loss = 5.3294, Zero-temp loss = 5.3292



100%|██████████| 76/76 [00:12<00:00,  6.24it/s]
100%|██████████| 9/9 [00:00<00:00, 11.43it/s]


--------Epoch 5 (32 total)--------
Val loss = 2.7799, Zero-temp loss = 2.7798



100%|██████████| 76/76 [00:12<00:00,  6.25it/s]
100%|██████████| 9/9 [00:00<00:00, 11.54it/s]


--------Epoch 6 (33 total)--------
Val loss = 3.9066, Zero-temp loss = 3.9065



100%|██████████| 76/76 [00:12<00:00,  6.26it/s]
100%|██████████| 9/9 [00:00<00:00, 11.55it/s]


--------Epoch 7 (34 total)--------
Val loss = 2.4868, Zero-temp loss = 2.4870

Epoch 00007: reducing learning rate of group 0 to 2.0000e-04.


100%|██████████| 76/76 [00:12<00:00,  6.28it/s]
100%|██████████| 9/9 [00:00<00:00, 11.34it/s]


--------Epoch 8 (35 total)--------
Val loss = 2.1843, Zero-temp loss = 2.1843

Stopping temp = 0.1778 at epoch 8

Starting training with temp = 0.1000



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:00<00:00, 11.49it/s]


--------Epoch 1 (36 total)--------
Val loss = 2.4703, Zero-temp loss = 2.4704



100%|██████████| 76/76 [00:12<00:00,  6.28it/s]
100%|██████████| 9/9 [00:00<00:00, 11.36it/s]


--------Epoch 2 (37 total)--------
Val loss = 1.9184, Zero-temp loss = 1.9184



100%|██████████| 76/76 [00:12<00:00,  6.14it/s]
100%|██████████| 9/9 [00:00<00:00, 10.76it/s]


--------Epoch 3 (38 total)--------
Val loss = 2.2237, Zero-temp loss = 2.2237



100%|██████████| 76/76 [00:12<00:00,  6.24it/s]
100%|██████████| 9/9 [00:00<00:00, 11.34it/s]


--------Epoch 4 (39 total)--------
Val loss = 1.2175, Zero-temp loss = 1.2175



100%|██████████| 76/76 [00:12<00:00,  6.17it/s]
100%|██████████| 9/9 [00:00<00:00, 10.41it/s]


--------Epoch 5 (40 total)--------
Val loss = 1.2085, Zero-temp loss = 1.2085



100%|██████████| 76/76 [00:12<00:00,  6.23it/s]
100%|██████████| 9/9 [00:00<00:00, 11.53it/s]


--------Epoch 6 (41 total)--------
Val loss = 0.7713, Zero-temp loss = 0.7714



100%|██████████| 76/76 [00:12<00:00,  6.27it/s]
100%|██████████| 9/9 [00:00<00:00, 11.40it/s]


--------Epoch 7 (42 total)--------
Val loss = 0.7771, Zero-temp loss = 0.7771



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 11.66it/s]


--------Epoch 8 (43 total)--------
Val loss = 0.8065, Zero-temp loss = 0.8066



100%|██████████| 76/76 [00:12<00:00,  6.25it/s]
100%|██████████| 9/9 [00:00<00:00, 11.49it/s]


--------Epoch 9 (44 total)--------
Val loss = 0.6294, Zero-temp loss = 0.6294



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00, 11.46it/s]


--------Epoch 10 (45 total)--------
Val loss = 0.6102, Zero-temp loss = 0.6103



100%|██████████| 76/76 [00:12<00:00,  6.25it/s]
100%|██████████| 9/9 [00:00<00:00, 11.28it/s]


--------Epoch 11 (46 total)--------
Val loss = 0.5989, Zero-temp loss = 0.5990



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00, 11.35it/s]


--------Epoch 12 (47 total)--------
Val loss = 0.5878, Zero-temp loss = 0.5879



100%|██████████| 76/76 [00:12<00:00,  6.28it/s]
100%|██████████| 9/9 [00:00<00:00, 11.05it/s]


--------Epoch 13 (48 total)--------
Val loss = 0.5894, Zero-temp loss = 0.5891



100%|██████████| 76/76 [00:12<00:00,  6.28it/s]
100%|██████████| 9/9 [00:00<00:00, 11.02it/s]


--------Epoch 14 (49 total)--------
Val loss = 0.5911, Zero-temp loss = 0.5908



100%|██████████| 76/76 [00:12<00:00,  6.20it/s]
100%|██████████| 9/9 [00:00<00:00, 11.55it/s]


--------Epoch 15 (50 total)--------
Val loss = 0.5902, Zero-temp loss = 0.5903



100%|██████████| 76/76 [00:12<00:00,  5.92it/s]
100%|██████████| 9/9 [00:00<00:00, 11.28it/s]


--------Epoch 16 (51 total)--------
Val loss = 0.5902, Zero-temp loss = 0.5902



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00, 11.34it/s]


--------Epoch 17 (52 total)--------
Val loss = 0.5879, Zero-temp loss = 0.5879



100%|██████████| 76/76 [00:12<00:00,  6.26it/s]
100%|██████████| 9/9 [00:00<00:00, 11.22it/s]


--------Epoch 18 (53 total)--------
Val loss = 0.5850, Zero-temp loss = 0.5851



100%|██████████| 76/76 [00:12<00:00,  6.22it/s]
100%|██████████| 9/9 [00:00<00:00, 11.15it/s]


--------Epoch 19 (54 total)--------
Val loss = 0.5891, Zero-temp loss = 0.5890



100%|██████████| 76/76 [00:12<00:00,  6.22it/s]
100%|██████████| 9/9 [00:00<00:00, 11.21it/s]


--------Epoch 20 (55 total)--------
Val loss = 0.5879, Zero-temp loss = 0.5879



100%|██████████| 76/76 [00:12<00:00,  6.22it/s]
100%|██████████| 9/9 [00:00<00:00, 11.19it/s]


--------Epoch 21 (56 total)--------
Val loss = 0.5831, Zero-temp loss = 0.5832



100%|██████████| 76/76 [00:12<00:00,  6.12it/s]
100%|██████████| 9/9 [00:00<00:00, 11.39it/s]


--------Epoch 22 (57 total)--------
Val loss = 0.5820, Zero-temp loss = 0.5819



100%|██████████| 76/76 [00:12<00:00,  6.23it/s]
100%|██████████| 9/9 [00:00<00:00, 11.16it/s]


--------Epoch 23 (58 total)--------
Val loss = 0.5834, Zero-temp loss = 0.5832



100%|██████████| 76/76 [00:12<00:00,  6.07it/s]
100%|██████████| 9/9 [00:00<00:00, 11.32it/s]


--------Epoch 24 (59 total)--------
Val loss = 0.5806, Zero-temp loss = 0.5805



100%|██████████| 76/76 [00:12<00:00,  6.26it/s]
100%|██████████| 9/9 [00:00<00:00, 11.45it/s]


--------Epoch 25 (60 total)--------
Val loss = 0.5844, Zero-temp loss = 0.5844



100%|██████████| 76/76 [00:12<00:00,  6.20it/s]
100%|██████████| 9/9 [00:00<00:00, 11.31it/s]


--------Epoch 26 (61 total)--------
Val loss = 0.5782, Zero-temp loss = 0.5781



100%|██████████| 76/76 [00:12<00:00,  6.24it/s]
100%|██████████| 9/9 [00:00<00:00, 11.51it/s]


--------Epoch 27 (62 total)--------
Val loss = 0.5842, Zero-temp loss = 0.5839



100%|██████████| 76/76 [00:12<00:00,  6.25it/s]
100%|██████████| 9/9 [00:00<00:00, 11.17it/s]


--------Epoch 28 (63 total)--------
Val loss = 0.5795, Zero-temp loss = 0.5794



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00, 11.24it/s]


--------Epoch 29 (64 total)--------
Val loss = 0.5736, Zero-temp loss = 0.5735



100%|██████████| 76/76 [00:12<00:00,  6.23it/s]
100%|██████████| 9/9 [00:00<00:00, 11.28it/s]

--------Epoch 30 (65 total)--------
Val loss = 0.5744, Zero-temp loss = 0.5742

Stopping temp = 0.1000 at epoch 30






Num = 1, AUROC = 79.67, Acc = 60.76
Num = 3, AUROC = 82.52, Acc = 64.32
Num = 5, AUROC = 83.85, Acc = 64.58
Num = 10, AUROC = 86.23, Acc = 67.19
Num = 15, AUROC = 87.58, Acc = 70.14
Num = 20, AUROC = 88.27, Acc = 71.01
Num = 25, AUROC = 88.92, Acc = 71.18
Num = 30, AUROC = 81.43, Acc = 64.32
Num = 32, AUROC = 50.00, Acc = 32.81


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name        | Type             | Params
-------------------------------------------------
0 | model       | Sequential       | 25.2 K
1 | mask_layer  | MaskLayerGrouped | 0     
2 | loss_fn     | CrossEntropyLoss | 0     
3 | val_loss_fn | MulticlassAUROC  | 0     
-------------------------------------------------
25.2 K    Trainable params
0         Non-trainable params
25.2 K    Total params
0.101     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 00004: reducing learning rate of group 0 to 2.0000e-04.


Validation: |          | 0/? [00:00<?, ?it/s]

Training CMI estimator
--------
Starting training with temp = 1.0000



100%|██████████| 76/76 [00:12<00:00,  6.23it/s]
100%|██████████| 9/9 [00:00<00:00, 11.18it/s]


--------Epoch 1 (1 total)--------
Val loss = 3.1063, Zero-temp loss = 3.1037



100%|██████████| 76/76 [00:12<00:00,  6.23it/s]
100%|██████████| 9/9 [00:00<00:00, 10.77it/s]


--------Epoch 2 (2 total)--------
Val loss = 4.2840, Zero-temp loss = 4.2839



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 11.04it/s]


--------Epoch 3 (3 total)--------
Val loss = 5.5847, Zero-temp loss = 5.5856



100%|██████████| 76/76 [00:12<00:00,  6.17it/s]
100%|██████████| 9/9 [00:00<00:00, 11.21it/s]


--------Epoch 4 (4 total)--------
Val loss = 6.7269, Zero-temp loss = 6.7289



100%|██████████| 76/76 [00:12<00:00,  6.01it/s]
100%|██████████| 9/9 [00:00<00:00, 11.13it/s]


--------Epoch 5 (5 total)--------
Val loss = 8.1456, Zero-temp loss = 8.1469



100%|██████████| 76/76 [00:12<00:00,  6.21it/s]
100%|██████████| 9/9 [00:00<00:00, 11.18it/s]


--------Epoch 6 (6 total)--------
Val loss = 9.8035, Zero-temp loss = 9.8063



100%|██████████| 76/76 [00:12<00:00,  6.18it/s]
100%|██████████| 9/9 [00:00<00:00, 11.06it/s]


--------Epoch 7 (7 total)--------
Val loss = 12.5690, Zero-temp loss = 12.5714

Epoch 00007: reducing learning rate of group 0 to 2.0000e-04.


100%|██████████| 76/76 [00:12<00:00,  6.20it/s]
100%|██████████| 9/9 [00:00<00:00, 11.40it/s]


--------Epoch 8 (8 total)--------
Val loss = 13.5973, Zero-temp loss = 13.5995

Stopping temp = 1.0000 at epoch 8

Starting training with temp = 0.5623



100%|██████████| 76/76 [00:12<00:00,  6.21it/s]
100%|██████████| 9/9 [00:00<00:00, 10.94it/s]


--------Epoch 1 (9 total)--------
Val loss = 6.9502, Zero-temp loss = 6.9505



100%|██████████| 76/76 [00:12<00:00,  5.96it/s]
100%|██████████| 9/9 [00:01<00:00,  8.41it/s]


--------Epoch 2 (10 total)--------
Val loss = 8.3020, Zero-temp loss = 8.3023



100%|██████████| 76/76 [00:12<00:00,  6.27it/s]
100%|██████████| 9/9 [00:00<00:00, 11.25it/s]


--------Epoch 3 (11 total)--------
Val loss = 11.4949, Zero-temp loss = 11.4952



100%|██████████| 76/76 [00:12<00:00,  5.93it/s]
100%|██████████| 9/9 [00:00<00:00, 11.03it/s]


--------Epoch 4 (12 total)--------
Val loss = 13.2701, Zero-temp loss = 13.2703



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 11.18it/s]


--------Epoch 5 (13 total)--------
Val loss = 18.5572, Zero-temp loss = 18.5577



100%|██████████| 76/76 [00:12<00:00,  6.11it/s]
100%|██████████| 9/9 [00:00<00:00, 10.54it/s]


--------Epoch 6 (14 total)--------
Val loss = 15.1223, Zero-temp loss = 15.1213



100%|██████████| 76/76 [00:13<00:00,  5.63it/s]
100%|██████████| 9/9 [00:00<00:00, 10.41it/s]


--------Epoch 7 (15 total)--------
Val loss = 18.5346, Zero-temp loss = 18.5341

Epoch 00007: reducing learning rate of group 0 to 2.0000e-04.


100%|██████████| 76/76 [00:13<00:00,  5.56it/s]
100%|██████████| 9/9 [00:00<00:00, 10.25it/s]


--------Epoch 8 (16 total)--------
Val loss = 20.5125, Zero-temp loss = 20.5112

Stopping temp = 0.5623 at epoch 8

Starting training with temp = 0.3162



100%|██████████| 76/76 [00:12<00:00,  6.03it/s]
100%|██████████| 9/9 [00:00<00:00, 11.08it/s]


--------Epoch 1 (17 total)--------
Val loss = 12.9231, Zero-temp loss = 12.9216



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:00<00:00, 11.26it/s]


--------Epoch 2 (18 total)--------
Val loss = 14.0326, Zero-temp loss = 14.0318



100%|██████████| 76/76 [00:12<00:00,  6.27it/s]
100%|██████████| 9/9 [00:00<00:00, 10.98it/s]


--------Epoch 3 (19 total)--------
Val loss = 5.6538, Zero-temp loss = 5.6531



100%|██████████| 76/76 [00:12<00:00,  6.24it/s]
100%|██████████| 9/9 [00:00<00:00, 11.17it/s]


--------Epoch 4 (20 total)--------
Val loss = 8.5490, Zero-temp loss = 8.5484



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:00<00:00, 11.07it/s]


--------Epoch 5 (21 total)--------
Val loss = 3.1416, Zero-temp loss = 3.1411



100%|██████████| 76/76 [00:12<00:00,  6.22it/s]
100%|██████████| 9/9 [00:00<00:00, 11.29it/s]


--------Epoch 6 (22 total)--------
Val loss = 1.9394, Zero-temp loss = 1.9392



100%|██████████| 76/76 [00:12<00:00,  6.11it/s]
100%|██████████| 9/9 [00:00<00:00, 11.04it/s]


--------Epoch 7 (23 total)--------
Val loss = 0.8478, Zero-temp loss = 0.8476



100%|██████████| 76/76 [00:12<00:00,  6.20it/s]
100%|██████████| 9/9 [00:00<00:00, 11.09it/s]


--------Epoch 8 (24 total)--------
Val loss = 0.7160, Zero-temp loss = 0.7158



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 11.02it/s]


--------Epoch 9 (25 total)--------
Val loss = 0.6315, Zero-temp loss = 0.6315



100%|██████████| 76/76 [00:12<00:00,  6.17it/s]
100%|██████████| 9/9 [00:00<00:00,  9.94it/s]


--------Epoch 10 (26 total)--------
Val loss = 0.6273, Zero-temp loss = 0.6272



100%|██████████| 76/76 [00:12<00:00,  6.17it/s]
100%|██████████| 9/9 [00:01<00:00,  7.57it/s]


--------Epoch 11 (27 total)--------
Val loss = 0.6327, Zero-temp loss = 0.6328



100%|██████████| 76/76 [00:12<00:00,  6.22it/s]
100%|██████████| 9/9 [00:01<00:00,  7.79it/s]


--------Epoch 12 (28 total)--------
Val loss = 0.6272, Zero-temp loss = 0.6272



100%|██████████| 76/76 [00:12<00:00,  6.25it/s]
100%|██████████| 9/9 [00:01<00:00,  7.87it/s]


--------Epoch 13 (29 total)--------
Val loss = 0.6225, Zero-temp loss = 0.6223



100%|██████████| 76/76 [00:12<00:00,  6.08it/s]
100%|██████████| 9/9 [00:01<00:00,  7.88it/s]


--------Epoch 14 (30 total)--------
Val loss = 0.6184, Zero-temp loss = 0.6184



100%|██████████| 76/76 [00:12<00:00,  6.23it/s]
100%|██████████| 9/9 [00:01<00:00,  8.23it/s]


--------Epoch 15 (31 total)--------
Val loss = 0.6162, Zero-temp loss = 0.6163



100%|██████████| 76/76 [00:12<00:00,  6.00it/s]
100%|██████████| 9/9 [00:00<00:00, 11.17it/s]


--------Epoch 16 (32 total)--------
Val loss = 0.6156, Zero-temp loss = 0.6157



100%|██████████| 76/76 [00:12<00:00,  6.17it/s]
100%|██████████| 9/9 [00:00<00:00, 11.10it/s]


--------Epoch 17 (33 total)--------
Val loss = 0.6115, Zero-temp loss = 0.6112



100%|██████████| 76/76 [00:12<00:00,  6.14it/s]
100%|██████████| 9/9 [00:00<00:00, 11.03it/s]


--------Epoch 18 (34 total)--------
Val loss = 0.6197, Zero-temp loss = 0.6198



100%|██████████| 76/76 [00:12<00:00,  5.85it/s]
100%|██████████| 9/9 [00:00<00:00, 11.16it/s]


--------Epoch 19 (35 total)--------
Val loss = 0.6082, Zero-temp loss = 0.6081



100%|██████████| 76/76 [00:12<00:00,  6.18it/s]
100%|██████████| 9/9 [00:00<00:00, 10.56it/s]


--------Epoch 20 (36 total)--------
Val loss = 0.6091, Zero-temp loss = 0.6091



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 10.96it/s]


--------Epoch 21 (37 total)--------
Val loss = 0.6156, Zero-temp loss = 0.6156



100%|██████████| 76/76 [00:12<00:00,  6.17it/s]
100%|██████████| 9/9 [00:00<00:00, 11.07it/s]


--------Epoch 22 (38 total)--------
Val loss = 0.6098, Zero-temp loss = 0.6097



100%|██████████| 76/76 [00:12<00:00,  6.12it/s]
100%|██████████| 9/9 [00:00<00:00, 11.14it/s]


--------Epoch 23 (39 total)--------
Val loss = 0.6102, Zero-temp loss = 0.6096



100%|██████████| 76/76 [00:12<00:00,  6.07it/s]
100%|██████████| 9/9 [00:00<00:00, 11.26it/s]


--------Epoch 24 (40 total)--------
Val loss = 0.6012, Zero-temp loss = 0.6010



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00, 11.09it/s]


--------Epoch 25 (41 total)--------
Val loss = 0.5924, Zero-temp loss = 0.5917



100%|██████████| 76/76 [00:12<00:00,  6.18it/s]
100%|██████████| 9/9 [00:00<00:00, 11.13it/s]


--------Epoch 26 (42 total)--------
Val loss = 0.6023, Zero-temp loss = 0.6017



100%|██████████| 76/76 [00:12<00:00,  6.12it/s]
100%|██████████| 9/9 [00:00<00:00, 11.16it/s]


--------Epoch 27 (43 total)--------
Val loss = 0.6037, Zero-temp loss = 0.6034



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 11.07it/s]


--------Epoch 28 (44 total)--------
Val loss = 0.6048, Zero-temp loss = 0.6041



100%|██████████| 76/76 [00:12<00:00,  6.27it/s]
100%|██████████| 9/9 [00:00<00:00, 11.13it/s]


--------Epoch 29 (45 total)--------
Val loss = 0.6048, Zero-temp loss = 0.6045



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 11.17it/s]


--------Epoch 30 (46 total)--------
Val loss = 0.5991, Zero-temp loss = 0.5983

Stopping temp = 0.3162 at epoch 30

Starting training with temp = 0.1778



100%|██████████| 76/76 [00:12<00:00,  6.21it/s]
100%|██████████| 9/9 [00:00<00:00, 11.02it/s]


--------Epoch 1 (47 total)--------
Val loss = 0.6016, Zero-temp loss = 0.6013



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 11.06it/s]


--------Epoch 2 (48 total)--------
Val loss = 0.6025, Zero-temp loss = 0.6021



100%|██████████| 76/76 [00:12<00:00,  6.21it/s]
100%|██████████| 9/9 [00:00<00:00, 11.02it/s]


--------Epoch 3 (49 total)--------
Val loss = 0.5974, Zero-temp loss = 0.5972



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00, 10.39it/s]


--------Epoch 4 (50 total)--------
Val loss = 0.5991, Zero-temp loss = 0.5987



100%|██████████| 76/76 [00:12<00:00,  6.23it/s]
100%|██████████| 9/9 [00:00<00:00, 11.22it/s]


--------Epoch 5 (51 total)--------
Val loss = 0.5946, Zero-temp loss = 0.5941



100%|██████████| 76/76 [00:12<00:00,  6.18it/s]
100%|██████████| 9/9 [00:00<00:00, 11.14it/s]


--------Epoch 6 (52 total)--------
Val loss = 0.5970, Zero-temp loss = 0.5965



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00, 10.89it/s]


--------Epoch 7 (53 total)--------
Val loss = 0.5996, Zero-temp loss = 0.5989



100%|██████████| 76/76 [00:12<00:00,  6.20it/s]
100%|██████████| 9/9 [00:00<00:00, 11.16it/s]


--------Epoch 8 (54 total)--------
Val loss = 0.5985, Zero-temp loss = 0.5981



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:00<00:00, 10.27it/s]


--------Epoch 9 (55 total)--------
Val loss = 0.5983, Zero-temp loss = 0.5980



100%|██████████| 76/76 [00:12<00:00,  6.24it/s]
100%|██████████| 9/9 [00:00<00:00,  9.38it/s]


--------Epoch 10 (56 total)--------
Val loss = 0.5948, Zero-temp loss = 0.5946



100%|██████████| 76/76 [00:12<00:00,  6.27it/s]
100%|██████████| 9/9 [00:01<00:00,  8.25it/s]


--------Epoch 11 (57 total)--------
Val loss = 0.5891, Zero-temp loss = 0.5887



100%|██████████| 76/76 [00:12<00:00,  6.23it/s]
100%|██████████| 9/9 [00:01<00:00,  7.81it/s]


--------Epoch 12 (58 total)--------
Val loss = 0.5918, Zero-temp loss = 0.5914



100%|██████████| 76/76 [00:12<00:00,  5.88it/s]
100%|██████████| 9/9 [00:01<00:00,  8.14it/s]


--------Epoch 13 (59 total)--------
Val loss = 0.5927, Zero-temp loss = 0.5922



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:01<00:00,  7.86it/s]


--------Epoch 14 (60 total)--------
Val loss = 0.5923, Zero-temp loss = 0.5919



100%|██████████| 76/76 [00:11<00:00,  6.34it/s]
100%|██████████| 9/9 [00:01<00:00,  7.90it/s]


--------Epoch 15 (61 total)--------
Val loss = 0.5909, Zero-temp loss = 0.5905



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00,  9.22it/s]


--------Epoch 16 (62 total)--------
Val loss = 0.5883, Zero-temp loss = 0.5881



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00, 10.79it/s]


--------Epoch 17 (63 total)--------
Val loss = 0.5928, Zero-temp loss = 0.5924



100%|██████████| 76/76 [00:12<00:00,  6.08it/s]
100%|██████████| 9/9 [00:00<00:00, 11.07it/s]


--------Epoch 18 (64 total)--------
Val loss = 0.5871, Zero-temp loss = 0.5865



100%|██████████| 76/76 [00:12<00:00,  6.20it/s]
100%|██████████| 9/9 [00:00<00:00, 11.00it/s]


--------Epoch 19 (65 total)--------
Val loss = 0.5908, Zero-temp loss = 0.5905



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:00<00:00, 10.33it/s]


--------Epoch 20 (66 total)--------
Val loss = 0.5874, Zero-temp loss = 0.5871



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:00<00:00, 11.07it/s]


--------Epoch 21 (67 total)--------
Val loss = 0.5885, Zero-temp loss = 0.5881



100%|██████████| 76/76 [00:12<00:00,  6.20it/s]
100%|██████████| 9/9 [00:00<00:00, 10.78it/s]


--------Epoch 22 (68 total)--------
Val loss = 0.5855, Zero-temp loss = 0.5851



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00, 11.11it/s]


--------Epoch 23 (69 total)--------
Val loss = 0.5824, Zero-temp loss = 0.5818



100%|██████████| 76/76 [00:12<00:00,  6.23it/s]
100%|██████████| 9/9 [00:00<00:00, 10.97it/s]


--------Epoch 24 (70 total)--------
Val loss = 0.5838, Zero-temp loss = 0.5835



100%|██████████| 76/76 [00:12<00:00,  6.09it/s]
100%|██████████| 9/9 [00:00<00:00, 11.12it/s]


--------Epoch 25 (71 total)--------
Val loss = 0.5857, Zero-temp loss = 0.5853



100%|██████████| 76/76 [00:12<00:00,  6.18it/s]
100%|██████████| 9/9 [00:00<00:00, 11.05it/s]


--------Epoch 26 (72 total)--------
Val loss = 0.5777, Zero-temp loss = 0.5773



100%|██████████| 76/76 [00:12<00:00,  6.14it/s]
100%|██████████| 9/9 [00:00<00:00, 11.14it/s]


--------Epoch 27 (73 total)--------
Val loss = 0.5825, Zero-temp loss = 0.5822



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 10.96it/s]


--------Epoch 28 (74 total)--------
Val loss = 0.5838, Zero-temp loss = 0.5834



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:00<00:00, 10.86it/s]


--------Epoch 29 (75 total)--------
Val loss = 0.5902, Zero-temp loss = 0.5897



100%|██████████| 76/76 [00:12<00:00,  6.12it/s]
100%|██████████| 9/9 [00:00<00:00, 11.24it/s]


--------Epoch 30 (76 total)--------
Val loss = 0.5851, Zero-temp loss = 0.5841

Stopping temp = 0.1778 at epoch 30

Starting training with temp = 0.1000



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00, 11.12it/s]


--------Epoch 1 (77 total)--------
Val loss = 0.5777, Zero-temp loss = 0.5775



100%|██████████| 76/76 [00:12<00:00,  6.10it/s]
100%|██████████| 9/9 [00:00<00:00,  9.98it/s]


--------Epoch 2 (78 total)--------
Val loss = 0.5820, Zero-temp loss = 0.5816



100%|██████████| 76/76 [00:12<00:00,  6.04it/s]
100%|██████████| 9/9 [00:00<00:00, 10.95it/s]


--------Epoch 3 (79 total)--------
Val loss = 0.5789, Zero-temp loss = 0.5785



100%|██████████| 76/76 [00:12<00:00,  6.11it/s]
100%|██████████| 9/9 [00:00<00:00, 10.75it/s]


--------Epoch 4 (80 total)--------
Val loss = 0.5901, Zero-temp loss = 0.5899



100%|██████████| 76/76 [00:12<00:00,  6.17it/s]
100%|██████████| 9/9 [00:00<00:00, 11.21it/s]


--------Epoch 5 (81 total)--------
Val loss = 0.5814, Zero-temp loss = 0.5811



100%|██████████| 76/76 [00:12<00:00,  6.18it/s]
100%|██████████| 9/9 [00:00<00:00, 10.84it/s]


--------Epoch 6 (82 total)--------
Val loss = 0.5830, Zero-temp loss = 0.5827



100%|██████████| 76/76 [00:12<00:00,  6.01it/s]
100%|██████████| 9/9 [00:01<00:00,  7.90it/s]


--------Epoch 7 (83 total)--------
Val loss = 0.5842, Zero-temp loss = 0.5837

Epoch 00007: reducing learning rate of group 0 to 2.0000e-04.


100%|██████████| 76/76 [00:12<00:00,  5.96it/s]
100%|██████████| 9/9 [00:00<00:00, 10.37it/s]


--------Epoch 8 (84 total)--------
Val loss = 0.5737, Zero-temp loss = 0.5736



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:01<00:00,  8.82it/s]


--------Epoch 9 (85 total)--------
Val loss = 0.5717, Zero-temp loss = 0.5713



100%|██████████| 76/76 [00:12<00:00,  6.22it/s]
100%|██████████| 9/9 [00:01<00:00,  7.85it/s]


--------Epoch 10 (86 total)--------
Val loss = 0.5715, Zero-temp loss = 0.5710



100%|██████████| 76/76 [00:12<00:00,  6.12it/s]
100%|██████████| 9/9 [00:01<00:00,  8.11it/s]


--------Epoch 11 (87 total)--------
Val loss = 0.5679, Zero-temp loss = 0.5676



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:01<00:00,  7.60it/s]


--------Epoch 12 (88 total)--------
Val loss = 0.5690, Zero-temp loss = 0.5683



100%|██████████| 76/76 [00:12<00:00,  6.21it/s]
100%|██████████| 9/9 [00:01<00:00,  7.84it/s]


--------Epoch 13 (89 total)--------
Val loss = 0.5685, Zero-temp loss = 0.5684



100%|██████████| 76/76 [00:12<00:00,  6.25it/s]
100%|██████████| 9/9 [00:00<00:00,  9.07it/s]


--------Epoch 14 (90 total)--------
Val loss = 0.5660, Zero-temp loss = 0.5655



100%|██████████| 76/76 [00:12<00:00,  6.17it/s]
100%|██████████| 9/9 [00:00<00:00, 10.98it/s]


--------Epoch 15 (91 total)--------
Val loss = 0.5650, Zero-temp loss = 0.5645



100%|██████████| 76/76 [00:12<00:00,  6.18it/s]
100%|██████████| 9/9 [00:00<00:00, 11.02it/s]


--------Epoch 16 (92 total)--------
Val loss = 0.5669, Zero-temp loss = 0.5665



100%|██████████| 76/76 [00:12<00:00,  6.00it/s]
100%|██████████| 9/9 [00:00<00:00, 11.19it/s]


--------Epoch 17 (93 total)--------
Val loss = 0.5648, Zero-temp loss = 0.5644



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 10.87it/s]


--------Epoch 18 (94 total)--------
Val loss = 0.5657, Zero-temp loss = 0.5654



100%|██████████| 76/76 [00:12<00:00,  6.14it/s]
100%|██████████| 9/9 [00:00<00:00, 10.86it/s]


--------Epoch 19 (95 total)--------
Val loss = 0.5632, Zero-temp loss = 0.5628



100%|██████████| 76/76 [00:12<00:00,  6.17it/s]
100%|██████████| 9/9 [00:00<00:00, 10.92it/s]


--------Epoch 20 (96 total)--------
Val loss = 0.5649, Zero-temp loss = 0.5645



100%|██████████| 76/76 [00:12<00:00,  6.18it/s]
100%|██████████| 9/9 [00:00<00:00, 10.75it/s]


--------Epoch 21 (97 total)--------
Val loss = 0.5632, Zero-temp loss = 0.5628



100%|██████████| 76/76 [00:12<00:00,  6.18it/s]
100%|██████████| 9/9 [00:00<00:00, 11.00it/s]


--------Epoch 22 (98 total)--------
Val loss = 0.5622, Zero-temp loss = 0.5617



100%|██████████| 76/76 [00:12<00:00,  6.13it/s]
100%|██████████| 9/9 [00:00<00:00, 10.95it/s]


--------Epoch 23 (99 total)--------
Val loss = 0.5624, Zero-temp loss = 0.5620



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 10.82it/s]


--------Epoch 24 (100 total)--------
Val loss = 0.5638, Zero-temp loss = 0.5633



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 10.91it/s]


--------Epoch 25 (101 total)--------
Val loss = 0.5620, Zero-temp loss = 0.5613



100%|██████████| 76/76 [00:12<00:00,  6.20it/s]
100%|██████████| 9/9 [00:00<00:00, 10.90it/s]


--------Epoch 26 (102 total)--------
Val loss = 0.5601, Zero-temp loss = 0.5596



100%|██████████| 76/76 [00:12<00:00,  6.17it/s]
100%|██████████| 9/9 [00:00<00:00, 11.14it/s]


--------Epoch 27 (103 total)--------
Val loss = 0.5600, Zero-temp loss = 0.5596



100%|██████████| 76/76 [00:12<00:00,  6.17it/s]
100%|██████████| 9/9 [00:00<00:00, 10.89it/s]


--------Epoch 28 (104 total)--------
Val loss = 0.5598, Zero-temp loss = 0.5594



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:00<00:00, 10.52it/s]


--------Epoch 29 (105 total)--------
Val loss = 0.5588, Zero-temp loss = 0.5584



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00, 10.97it/s]


--------Epoch 30 (106 total)--------
Val loss = 0.5601, Zero-temp loss = 0.5597

Stopping temp = 0.1000 at epoch 30

Num = 1, AUROC = 79.77, Acc = 58.33
Num = 3, AUROC = 84.36, Acc = 68.23
Num = 5, AUROC = 86.32, Acc = 69.27
Num = 10, AUROC = 86.43, Acc = 68.66
Num = 15, AUROC = 89.63, Acc = 73.44
Num = 20, AUROC = 88.94, Acc = 73.96
Num = 25, AUROC = 88.46, Acc = 72.22
Num = 30, AUROC = 87.40, Acc = 72.22
Num = 32, AUROC = 50.00, Acc = 32.81


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name        | Type             | Params
-------------------------------------------------
0 | model       | Sequential       | 25.2 K
1 | mask_layer  | MaskLayerGrouped | 0     
2 | loss_fn     | CrossEntropyLoss | 0     
3 | val_loss_fn | MulticlassAUROC  | 0     
-------------------------------------------------
25.2 K    Trainable params
0         Non-trainable params
25.2 K    Total params
0.101     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 00004: reducing learning rate of group 0 to 2.0000e-04.


Validation: |          | 0/? [00:00<?, ?it/s]

Training CMI estimator
--------
Starting training with temp = 1.0000



100%|██████████| 76/76 [00:13<00:00,  5.79it/s]
100%|██████████| 9/9 [00:00<00:00, 10.87it/s]


--------Epoch 1 (1 total)--------
Val loss = 3.2723, Zero-temp loss = 3.2714



100%|██████████| 76/76 [00:12<00:00,  6.13it/s]
100%|██████████| 9/9 [00:00<00:00, 10.67it/s]


--------Epoch 2 (2 total)--------
Val loss = 3.8437, Zero-temp loss = 3.8400



100%|██████████| 76/76 [00:12<00:00,  6.09it/s]
100%|██████████| 9/9 [00:00<00:00, 10.80it/s]


--------Epoch 3 (3 total)--------
Val loss = 3.6982, Zero-temp loss = 3.6979



100%|██████████| 76/76 [00:12<00:00,  6.06it/s]
100%|██████████| 9/9 [00:00<00:00, 10.79it/s]


--------Epoch 4 (4 total)--------
Val loss = 4.0172, Zero-temp loss = 4.0165



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:00<00:00, 10.32it/s]


--------Epoch 5 (5 total)--------
Val loss = 4.3792, Zero-temp loss = 4.3790



100%|██████████| 76/76 [00:12<00:00,  6.11it/s]
100%|██████████| 9/9 [00:00<00:00,  9.97it/s]


--------Epoch 6 (6 total)--------
Val loss = 4.9061, Zero-temp loss = 4.9060



100%|██████████| 76/76 [00:12<00:00,  6.17it/s]
100%|██████████| 9/9 [00:00<00:00, 10.89it/s]


--------Epoch 7 (7 total)--------
Val loss = 5.3639, Zero-temp loss = 5.3641

Epoch 00007: reducing learning rate of group 0 to 2.0000e-04.


100%|██████████| 76/76 [00:12<00:00,  6.12it/s]
100%|██████████| 9/9 [00:00<00:00, 10.88it/s]


--------Epoch 8 (8 total)--------
Val loss = 5.3718, Zero-temp loss = 5.3717

Stopping temp = 1.0000 at epoch 8

Starting training with temp = 0.5623



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 10.68it/s]


--------Epoch 1 (9 total)--------
Val loss = 4.4591, Zero-temp loss = 4.4591



100%|██████████| 76/76 [00:12<00:00,  6.07it/s]
100%|██████████| 9/9 [00:00<00:00, 10.20it/s]


--------Epoch 2 (10 total)--------
Val loss = 6.5289, Zero-temp loss = 6.5288



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:00<00:00, 10.67it/s]


--------Epoch 3 (11 total)--------
Val loss = 6.9138, Zero-temp loss = 6.9139



100%|██████████| 76/76 [00:12<00:00,  6.13it/s]
100%|██████████| 9/9 [00:00<00:00,  9.54it/s]


--------Epoch 4 (12 total)--------
Val loss = 5.4320, Zero-temp loss = 5.4321



100%|██████████| 76/76 [00:12<00:00,  6.11it/s]
100%|██████████| 9/9 [00:01<00:00,  8.01it/s]


--------Epoch 5 (13 total)--------
Val loss = 4.6166, Zero-temp loss = 4.6164



100%|██████████| 76/76 [00:12<00:00,  6.14it/s]
100%|██████████| 9/9 [00:01<00:00,  7.73it/s]


--------Epoch 6 (14 total)--------
Val loss = 4.8935, Zero-temp loss = 4.8933



100%|██████████| 76/76 [00:12<00:00,  6.11it/s]
100%|██████████| 9/9 [00:01<00:00,  7.68it/s]


--------Epoch 7 (15 total)--------
Val loss = 2.2537, Zero-temp loss = 2.2535



100%|██████████| 76/76 [00:12<00:00,  6.22it/s]
100%|██████████| 9/9 [00:01<00:00,  7.65it/s]


--------Epoch 8 (16 total)--------
Val loss = 20.3390, Zero-temp loss = 20.3395



100%|██████████| 76/76 [00:12<00:00,  6.20it/s]
100%|██████████| 9/9 [00:01<00:00,  8.75it/s]


--------Epoch 9 (17 total)--------
Val loss = 3.5047, Zero-temp loss = 3.5048



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:00<00:00, 10.90it/s]


--------Epoch 10 (18 total)--------
Val loss = 3.4634, Zero-temp loss = 3.4633



100%|██████████| 76/76 [00:12<00:00,  6.14it/s]
100%|██████████| 9/9 [00:00<00:00, 10.20it/s]


--------Epoch 11 (19 total)--------
Val loss = 3.4283, Zero-temp loss = 3.4285



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 10.56it/s]


--------Epoch 12 (20 total)--------
Val loss = 0.7749, Zero-temp loss = 0.7747



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:00<00:00, 10.62it/s]


--------Epoch 13 (21 total)--------
Val loss = 0.7142, Zero-temp loss = 0.7138



100%|██████████| 76/76 [00:12<00:00,  6.06it/s]
100%|██████████| 9/9 [00:00<00:00, 10.62it/s]


--------Epoch 14 (22 total)--------
Val loss = 1.3395, Zero-temp loss = 1.3395



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:00<00:00, 10.70it/s]


--------Epoch 15 (23 total)--------
Val loss = 0.6024, Zero-temp loss = 0.6026



100%|██████████| 76/76 [00:13<00:00,  5.74it/s]
100%|██████████| 9/9 [00:00<00:00, 10.70it/s]


--------Epoch 16 (24 total)--------
Val loss = 0.6046, Zero-temp loss = 0.6045



100%|██████████| 76/76 [00:12<00:00,  6.10it/s]
100%|██████████| 9/9 [00:00<00:00, 10.85it/s]


--------Epoch 17 (25 total)--------
Val loss = 0.6003, Zero-temp loss = 0.6001



100%|██████████| 76/76 [00:12<00:00,  6.07it/s]
100%|██████████| 9/9 [00:00<00:00, 10.60it/s]


--------Epoch 18 (26 total)--------
Val loss = 0.5980, Zero-temp loss = 0.5979



100%|██████████| 76/76 [00:12<00:00,  6.04it/s]
100%|██████████| 9/9 [00:00<00:00, 10.69it/s]


--------Epoch 19 (27 total)--------
Val loss = 0.5935, Zero-temp loss = 0.5931



100%|██████████| 76/76 [00:12<00:00,  6.13it/s]
100%|██████████| 9/9 [00:00<00:00, 10.25it/s]


--------Epoch 20 (28 total)--------
Val loss = 0.5954, Zero-temp loss = 0.5956



100%|██████████| 76/76 [00:12<00:00,  6.08it/s]
100%|██████████| 9/9 [00:00<00:00, 10.72it/s]


--------Epoch 21 (29 total)--------
Val loss = 0.7843, Zero-temp loss = 0.7841



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:00<00:00, 10.70it/s]


--------Epoch 22 (30 total)--------
Val loss = 1.1313, Zero-temp loss = 1.1314



100%|██████████| 76/76 [00:12<00:00,  6.08it/s]
100%|██████████| 9/9 [00:00<00:00, 10.66it/s]


--------Epoch 23 (31 total)--------
Val loss = 1.4261, Zero-temp loss = 1.4259



100%|██████████| 76/76 [00:12<00:00,  6.10it/s]
100%|██████████| 9/9 [00:00<00:00, 10.66it/s]


--------Epoch 24 (32 total)--------
Val loss = 0.7035, Zero-temp loss = 0.7035



100%|██████████| 76/76 [00:12<00:00,  6.04it/s]
100%|██████████| 9/9 [00:00<00:00, 10.58it/s]


--------Epoch 25 (33 total)--------
Val loss = 0.5971, Zero-temp loss = 0.5969

Epoch 00025: reducing learning rate of group 0 to 2.0000e-04.


100%|██████████| 76/76 [00:12<00:00,  6.10it/s]
100%|██████████| 9/9 [00:00<00:00, 10.28it/s]


--------Epoch 26 (34 total)--------
Val loss = 0.5824, Zero-temp loss = 0.5822



100%|██████████| 76/76 [00:12<00:00,  6.07it/s]
100%|██████████| 9/9 [00:01<00:00,  8.42it/s]


--------Epoch 27 (35 total)--------
Val loss = 0.5769, Zero-temp loss = 0.5768



100%|██████████| 76/76 [00:12<00:00,  6.13it/s]
100%|██████████| 9/9 [00:01<00:00,  7.23it/s]


--------Epoch 28 (36 total)--------
Val loss = 0.5782, Zero-temp loss = 0.5779



100%|██████████| 76/76 [00:12<00:00,  5.95it/s]
100%|██████████| 9/9 [00:01<00:00,  7.19it/s]


--------Epoch 29 (37 total)--------
Val loss = 0.5745, Zero-temp loss = 0.5743



100%|██████████| 76/76 [00:12<00:00,  6.16it/s]
100%|██████████| 9/9 [00:01<00:00,  7.56it/s]


--------Epoch 30 (38 total)--------
Val loss = 0.5750, Zero-temp loss = 0.5751

Stopping temp = 0.5623 at epoch 30

Starting training with temp = 0.3162



100%|██████████| 76/76 [00:12<00:00,  6.19it/s]
100%|██████████| 9/9 [00:01<00:00,  8.52it/s]


--------Epoch 1 (39 total)--------
Val loss = 0.5848, Zero-temp loss = 0.5849



100%|██████████| 76/76 [00:12<00:00,  6.09it/s]
100%|██████████| 9/9 [00:00<00:00,  9.71it/s]


--------Epoch 2 (40 total)--------
Val loss = 0.5803, Zero-temp loss = 0.5801



100%|██████████| 76/76 [00:12<00:00,  5.95it/s]
100%|██████████| 9/9 [00:00<00:00, 10.59it/s]


--------Epoch 3 (41 total)--------
Val loss = 0.5811, Zero-temp loss = 0.5811



100%|██████████| 76/76 [00:12<00:00,  6.08it/s]
100%|██████████| 9/9 [00:00<00:00, 10.71it/s]


--------Epoch 4 (42 total)--------
Val loss = 0.6136, Zero-temp loss = 0.6130



100%|██████████| 76/76 [00:12<00:00,  6.08it/s]
100%|██████████| 9/9 [00:00<00:00, 10.73it/s]


--------Epoch 5 (43 total)--------
Val loss = 0.5775, Zero-temp loss = 0.5773



100%|██████████| 76/76 [00:12<00:00,  6.06it/s]
100%|██████████| 9/9 [00:00<00:00, 10.57it/s]


--------Epoch 6 (44 total)--------
Val loss = 0.6543, Zero-temp loss = 0.6539



100%|██████████| 76/76 [00:12<00:00,  6.13it/s]
100%|██████████| 9/9 [00:00<00:00, 10.59it/s]


--------Epoch 7 (45 total)--------
Val loss = 0.5756, Zero-temp loss = 0.5753



100%|██████████| 76/76 [00:12<00:00,  6.14it/s]
100%|██████████| 9/9 [00:00<00:00, 10.20it/s]


--------Epoch 8 (46 total)--------
Val loss = 0.5623, Zero-temp loss = 0.5620



100%|██████████| 76/76 [00:12<00:00,  5.86it/s]
100%|██████████| 9/9 [00:01<00:00,  7.66it/s]


--------Epoch 9 (47 total)--------
Val loss = 0.5679, Zero-temp loss = 0.5677



100%|██████████| 76/76 [00:12<00:00,  6.13it/s]
100%|██████████| 9/9 [00:00<00:00, 10.57it/s]


--------Epoch 10 (48 total)--------
Val loss = 0.5724, Zero-temp loss = 0.5723



100%|██████████| 76/76 [00:12<00:00,  6.06it/s]
100%|██████████| 9/9 [00:00<00:00, 10.69it/s]


--------Epoch 11 (49 total)--------
Val loss = 0.5627, Zero-temp loss = 0.5622



100%|██████████| 76/76 [00:12<00:00,  6.05it/s]
100%|██████████| 9/9 [00:00<00:00, 10.62it/s]


--------Epoch 12 (50 total)--------
Val loss = 0.5615, Zero-temp loss = 0.5611



100%|██████████| 76/76 [00:12<00:00,  6.04it/s]
100%|██████████| 9/9 [00:00<00:00, 10.40it/s]


--------Epoch 13 (51 total)--------
Val loss = 0.5626, Zero-temp loss = 0.5625



100%|██████████| 76/76 [00:12<00:00,  6.13it/s]
100%|██████████| 9/9 [00:00<00:00, 10.55it/s]


--------Epoch 14 (52 total)--------
Val loss = 0.5605, Zero-temp loss = 0.5601



100%|██████████| 76/76 [00:12<00:00,  5.92it/s]
100%|██████████| 9/9 [00:00<00:00,  9.51it/s]


--------Epoch 15 (53 total)--------
Val loss = 0.5555, Zero-temp loss = 0.5551



100%|██████████| 76/76 [00:12<00:00,  6.02it/s]
100%|██████████| 9/9 [00:01<00:00,  8.92it/s]


--------Epoch 16 (54 total)--------
Val loss = 0.5659, Zero-temp loss = 0.5657



100%|██████████| 76/76 [00:12<00:00,  6.07it/s]
100%|██████████| 9/9 [00:01<00:00,  7.87it/s]


--------Epoch 17 (55 total)--------
Val loss = 0.5631, Zero-temp loss = 0.5631



100%|██████████| 76/76 [00:12<00:00,  6.14it/s]
100%|██████████| 9/9 [00:01<00:00,  7.38it/s]


--------Epoch 18 (56 total)--------
Val loss = 0.5752, Zero-temp loss = 0.5748



100%|██████████| 76/76 [00:12<00:00,  6.12it/s]
100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


--------Epoch 19 (57 total)--------
Val loss = 0.5592, Zero-temp loss = 0.5586



100%|██████████| 76/76 [00:12<00:00,  6.15it/s]
100%|██████████| 9/9 [00:01<00:00,  8.04it/s]


--------Epoch 20 (58 total)--------
Val loss = 0.5551, Zero-temp loss = 0.5547



100%|██████████| 76/76 [00:12<00:00,  6.08it/s]
100%|██████████| 9/9 [00:00<00:00, 10.62it/s]


--------Epoch 21 (59 total)--------
Val loss = 0.5573, Zero-temp loss = 0.5572



100%|██████████| 76/76 [00:12<00:00,  6.14it/s]
100%|██████████| 9/9 [00:00<00:00, 10.37it/s]


--------Epoch 22 (60 total)--------
Val loss = 0.5590, Zero-temp loss = 0.5589



100%|██████████| 76/76 [00:12<00:00,  6.12it/s]
100%|██████████| 9/9 [00:00<00:00, 10.62it/s]


--------Epoch 23 (61 total)--------
Val loss = 0.8948, Zero-temp loss = 0.8944



100%|██████████| 76/76 [00:12<00:00,  6.13it/s]
100%|██████████| 9/9 [00:00<00:00, 10.33it/s]


--------Epoch 24 (62 total)--------
Val loss = 1.3052, Zero-temp loss = 1.3047



100%|██████████| 76/76 [00:12<00:00,  6.08it/s]
100%|██████████| 9/9 [00:00<00:00, 10.52it/s]


--------Epoch 25 (63 total)--------
Val loss = 0.8505, Zero-temp loss = 0.8500



100%|██████████| 76/76 [00:12<00:00,  6.09it/s]
100%|██████████| 9/9 [00:00<00:00, 10.07it/s]


--------Epoch 26 (64 total)--------
Val loss = 1.0114, Zero-temp loss = 1.0111

Epoch 00026: reducing learning rate of group 0 to 2.0000e-04.


100%|██████████| 76/76 [00:12<00:00,  5.93it/s]
100%|██████████| 9/9 [00:00<00:00, 10.57it/s]


--------Epoch 27 (65 total)--------
Val loss = 0.7080, Zero-temp loss = 0.7077

Stopping temp = 0.3162 at epoch 27

Starting training with temp = 0.1778



100%|██████████| 76/76 [00:12<00:00,  5.89it/s]
100%|██████████| 9/9 [00:01<00:00,  7.36it/s]


--------Epoch 1 (66 total)--------
Val loss = 0.6031, Zero-temp loss = 0.6028



100%|██████████| 76/76 [00:12<00:00,  6.00it/s]
100%|██████████| 9/9 [00:00<00:00, 10.79it/s]


--------Epoch 2 (67 total)--------
Val loss = 0.5632, Zero-temp loss = 0.5628



100%|██████████| 76/76 [00:12<00:00,  6.14it/s]
100%|██████████| 9/9 [00:00<00:00, 10.61it/s]


--------Epoch 3 (68 total)--------
Val loss = 0.6016, Zero-temp loss = 0.6014



100%|██████████| 76/76 [00:12<00:00,  6.07it/s]
100%|██████████| 9/9 [00:00<00:00, 10.53it/s]


--------Epoch 4 (69 total)--------
Val loss = 0.9999, Zero-temp loss = 0.9999



100%|██████████| 76/76 [00:12<00:00,  6.13it/s]
100%|██████████| 9/9 [00:00<00:00, 10.60it/s]


--------Epoch 5 (70 total)--------
Val loss = 0.5610, Zero-temp loss = 0.5603



100%|██████████| 76/76 [00:12<00:00,  6.06it/s]
100%|██████████| 9/9 [00:00<00:00, 10.49it/s]


--------Epoch 6 (71 total)--------
Val loss = 0.5560, Zero-temp loss = 0.5559



100%|██████████| 76/76 [00:12<00:00,  6.12it/s]
100%|██████████| 9/9 [00:00<00:00, 10.39it/s]


--------Epoch 7 (72 total)--------
Val loss = 0.5525, Zero-temp loss = 0.5523



100%|██████████| 76/76 [00:12<00:00,  6.02it/s]
100%|██████████| 9/9 [00:00<00:00,  9.73it/s]


--------Epoch 8 (73 total)--------
Val loss = 0.5605, Zero-temp loss = 0.5601



100%|██████████| 76/76 [00:12<00:00,  6.10it/s]
100%|██████████| 9/9 [00:01<00:00,  8.15it/s]


--------Epoch 9 (74 total)--------
Val loss = 0.5562, Zero-temp loss = 0.5560



100%|██████████| 76/76 [00:12<00:00,  6.12it/s]
100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


--------Epoch 10 (75 total)--------
Val loss = 0.5564, Zero-temp loss = 0.5559



100%|██████████| 76/76 [00:12<00:00,  6.10it/s]
100%|██████████| 9/9 [00:01<00:00,  7.72it/s]


--------Epoch 11 (76 total)--------
Val loss = 0.5590, Zero-temp loss = 0.5584



100%|██████████| 76/76 [00:12<00:00,  6.14it/s]
100%|██████████| 9/9 [00:01<00:00,  7.30it/s]


--------Epoch 12 (77 total)--------
Val loss = 0.5566, Zero-temp loss = 0.5561



100%|██████████| 76/76 [00:12<00:00,  6.08it/s]
100%|██████████| 9/9 [00:00<00:00,  9.87it/s]


--------Epoch 13 (78 total)--------
Val loss = 0.5642, Zero-temp loss = 0.5639

Epoch 00013: reducing learning rate of group 0 to 2.0000e-04.


100%|██████████| 76/76 [00:12<00:00,  6.08it/s]
100%|██████████| 9/9 [00:00<00:00, 10.52it/s]


--------Epoch 14 (79 total)--------
Val loss = 0.5694, Zero-temp loss = 0.5693

Stopping temp = 0.1778 at epoch 14

Starting training with temp = 0.1000



100%|██████████| 76/76 [00:12<00:00,  6.05it/s]
100%|██████████| 9/9 [00:00<00:00, 10.45it/s]


--------Epoch 1 (80 total)--------
Val loss = 0.5539, Zero-temp loss = 0.5537



100%|██████████| 76/76 [00:12<00:00,  6.12it/s]
100%|██████████| 9/9 [00:00<00:00, 10.53it/s]


--------Epoch 2 (81 total)--------
Val loss = 0.5570, Zero-temp loss = 0.5569



100%|██████████| 76/76 [00:12<00:00,  6.08it/s]
100%|██████████| 9/9 [00:00<00:00, 10.48it/s]


--------Epoch 3 (82 total)--------
Val loss = 0.5578, Zero-temp loss = 0.5577



100%|██████████| 76/76 [00:12<00:00,  6.10it/s]
100%|██████████| 9/9 [00:00<00:00, 10.48it/s]


--------Epoch 4 (83 total)--------
Val loss = 0.5599, Zero-temp loss = 0.5598



100%|██████████| 76/76 [00:12<00:00,  6.12it/s]
100%|██████████| 9/9 [00:00<00:00,  9.91it/s]


--------Epoch 5 (84 total)--------
Val loss = 0.5611, Zero-temp loss = 0.5610



100%|██████████| 76/76 [00:12<00:00,  6.10it/s]
100%|██████████| 9/9 [00:00<00:00, 10.22it/s]


--------Epoch 6 (85 total)--------
Val loss = 0.5586, Zero-temp loss = 0.5585



100%|██████████| 76/76 [00:12<00:00,  6.08it/s]
100%|██████████| 9/9 [00:00<00:00, 10.50it/s]


--------Epoch 7 (86 total)--------
Val loss = 0.5610, Zero-temp loss = 0.5610

Epoch 00007: reducing learning rate of group 0 to 2.0000e-04.


100%|██████████| 76/76 [00:12<00:00,  6.05it/s]
100%|██████████| 9/9 [00:00<00:00, 10.41it/s]

--------Epoch 8 (87 total)--------
Val loss = 0.5549, Zero-temp loss = 0.5545

Stopping temp = 0.1000 at epoch 8






Num = 1, AUROC = 80.61, Acc = 61.89
Num = 3, AUROC = 82.04, Acc = 64.32
Num = 5, AUROC = 85.60, Acc = 70.05
Num = 10, AUROC = 87.59, Acc = 72.05
Num = 15, AUROC = 90.94, Acc = 77.17
Num = 20, AUROC = 91.68, Acc = 78.91
Num = 25, AUROC = 87.26, Acc = 73.70
Num = 30, AUROC = 84.44, Acc = 67.19
Num = 32, AUROC = 56.92, Acc = 40.10
