In [1]:
import numpy as np
import optuna
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter
%load_ext autoreload
%autoreload 2


In [2]:
def plot_with_sliding_window_variance(ax, x, y, xlabel, ylabel, title, window_size=5):
    """Scatter ``y`` against ``x`` on ``ax`` and overlay a linear trend with a spread band.

    What gets drawn depends on the number of points:

    * fewer than 5 points: scatter only;
    * exactly 5 points: scatter plus a least-squares line;
    * more than 5 points: scatter, line, and a shaded band of +/- one smoothed
      sliding-window standard deviation of ``y`` around the line.

    Args:
        ax: Axes-like object providing ``scatter``, ``plot``, ``fill_between``,
            ``legend`` and the usual label/title/grid setters.
        x: Sequence of x values (converted with ``np.array``).
        y: Sequence of y values, same length as ``x``.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        title: Axes title.
        window_size: Number of neighbouring points (in x order) used for the
            local standard deviation. The window is centred on each point and
            always contains at least that point.

    Returns:
        None. Everything is drawn on ``ax``.
    """
    # Ensure x, y are numpy arrays for easier manipulation
    x, y = np.array(x), np.array(y)

    # Scatter plot
    ax.scatter(x, y)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)

    # Sort x and y by x to ensure correct plotting
    sort_idx = np.argsort(x)
    x, y = x[sort_idx], y[sort_idx]
    n_points = len(x)
    if n_points < 5:
        ax.legend(['models'])
        return

    # Linear fit
    slope, intercept = np.polyfit(x, y, 1)
    fit_line = np.polyval([slope, intercept], x)
    ax.plot(x, fit_line, color='red')
    if n_points <= 5:
        ax.legend(['models', 'trend'])
        return

    # Local spread: standard deviation over a window centred on each point.
    # The slice end is `i + half + 1` so the window is symmetric and never
    # empty (an empty slice would make the variance NaN).
    half = max(int(window_size) // 2, 0)
    sigma = np.array([
        np.sqrt(np.var(y[max(0, i - half):i + half + 1]))
        for i in range(n_points)
    ])

    # Smooth the spread to avoid overly noisy bands. The filter window is
    # capped at 100 samples and then forced to be odd; n_points > 5 here, so
    # the window is at least 5 and stays larger than polyorder.
    window_length = min(100, n_points)
    if window_length % 2 == 0:
        window_length -= 1
    sigma_smoothed = savgol_filter(sigma, window_length=window_length, polyorder=3)

    # Plot sliding window spread around the line
    ax.fill_between(x, fit_line - sigma_smoothed, fit_line + sigma_smoothed, color='red', alpha=0.2)

    ax.legend(['models', 'trend', 'variance'])

def bar_encoder(ax, encoders, encoder_positions, *args, skip_labels=False):
    """Draw grouped bars of per-encoder metric means with min-to-max error bars.

    For every ``(metric, label)`` pair one bar per encoder is drawn whose
    height is the mean of the metric over the samples belonging to that
    encoder. A black error bar spans the minimum to the maximum of those
    samples, and each bar is annotated with its maximum (and its average when
    the two differ).

    Args:
        ax: Axes-like object providing ``bar``, ``errorbar``, ``annotate``,
            ``set_xticks``, ``set_xticklabels`` and ``legend``.
        encoders: Sequence of encoder identifiers, one bar group each. They
            must be strings unless ``skip_labels`` is true, because the tick
            labels are built with ``str.replace``.
        encoder_positions: Sequence with one encoder identifier per sample,
            aligned element-wise with every metric.
        *args: ``(metric, label)`` pairs. ``metric`` holds one value per
            sample; ``label`` is the legend entry for that set of bars.
            At least one pair is required.
        skip_labels: When true, the x tick labels are left untouched.

    Returns:
        None. Everything is drawn on ``ax``.
    """
    n_bars = len(args)
    n_groups = len(encoders)
    width = 0.8 / n_bars  # Calculate the width for each bar

    # Work on arrays so that `==` is element-wise and boolean masks index
    # correctly even when the caller passes plain Python lists.
    encoder_positions = np.asarray(encoder_positions)

    # Iterate through each metric provided
    for i, (metric, label) in enumerate(args):
        metric = np.asarray(metric)
        averages = []  # Mean metric value per encoder
        error = []  # Distances from the mean down to the min and up to the max
        max_vals = []

        # Calculate statistics for each encoder
        for encoder in encoders:
            filtered_metric = metric[encoder_positions == encoder]
            mean_val = np.mean(filtered_metric)
            min_val = np.min(filtered_metric)
            max_val = np.max(filtered_metric)
            averages.append(mean_val)
            error.append([mean_val - min_val, max_val - mean_val])
            max_vals.append(max_val)

        # Calculate positions for the current set of bars
        positions = np.arange(n_groups) + i * width

        # Plot the bars; the error bars below represent the min to max range
        bars = ax.bar(positions, averages, width, label=label, alpha=0.7)
        error = np.array(error).T  # Transpose to the (2, N) layout used for asymmetric yerr
        # Only the first error-bar set carries the legend label; an
        # underscore-prefixed label keeps the others out of the legend so
        # 'min to max' is listed once instead of once per metric.
        ax.errorbar(positions, averages, yerr=error, fmt='none', ecolor='black', capsize=5, alpha=0.7,
                    label='min to max' if i == 0 else '_nolegend_')

        # Annotate each bar with its max value, plus the average when it differs
        for bar, avg, maximum in zip(bars, averages, max_vals):
            if avg == maximum:
                text = f'{maximum:.2f}'
            else:
                text = f'Max: {maximum:.2f}\nAvg: {avg:.2f}'
            ax.annotate(text,
                        xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                        xytext=(0, 3),  # 3 points vertical offset
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=8)

    # Centre each tick under its group of bars
    ax.set_xticks(np.arange(n_groups) + width * (n_bars - 1) / 2)
    if not skip_labels:
        # Break long encoder names after '_' and '.' so the labels wrap
        ax.set_xticklabels([e.replace("_", "_\n").replace(".", ".\n") for e in encoders])
    ax.legend()

In [16]:
# Study-name prefix and the SQLite database that backs the Optuna storage.
study_name = "sub"
db = "sqlite:///optuna_v3.db"
storage = optuna.storages.RDBStorage(url=db)
# Names of every study found in the storage; the bare expression below
# displays them as the cell output.
study_names = [stored.study_name for stored in storage.get_all_studies()]
study_names

['subVPSDE']

In [41]:
# Gather the trials of every study whose name starts with `study_name`.
trials = []
for sn in study_names:
    if not sn.startswith(study_name):
        continue
    study = optuna.load_study(study_name=sn, storage=storage)
    trials.extend(study.trials)

print(f"loaded {len(trials)} trials")


# One user-attribute mapping per trial; no trial is filtered out here.
# Earlier variant, kept for reference:
# values = list([t.user_attrs | t.params |{"auc": (t.values[1] + t.values[2])/2} for t in trials if t.values is not None])
values = [trial.user_attrs for trial in trials]

print(len(values), len(trials))

loaded 50 trials
50 50


In [42]:
from collections import defaultdict

# encoder -> list of (index into `values`, mean FPR_95, mean AUC),
# one entry per trial that reports every metric of interest.
scores = defaultdict(list)
for encoder in ['swin', 'deit', 'repvgg', 'resnet50d', 'bit']:

    dataset = 'imagenet'

    # Keys a trial must contain for this encoder to be scored.
    intrests = [
                f'{encoder}_{dataset}_nearood_imagenet-o_AUC',
                f'{encoder}_{dataset}_nearood_imagenet-o_FPR_95',

                f'{encoder}_{dataset}_farood_inaturalist_AUC',
                f'{encoder}_{dataset}_farood_inaturalist_FPR_95',

                f'{encoder}_{dataset}_farood_openimageo_AUC',
                f'{encoder}_{dataset}_farood_openimageo_FPR_95',

                f'{encoder}_{dataset}_farood_textures_AUC',
                f'{encoder}_{dataset}_farood_textures_FPR_95',
                ]
    auc_keys = [name for name in intrests if name.endswith("AUC")]
    fpr_keys = [name for name in intrests if name.endswith("FPR_95")]

    # Keep only the trials that report every key, remembering their position.
    full_values_auc = []
    full_values_fpr = []
    for idx, attrs in enumerate(values):
        if all(name in attrs for name in intrests):
            full_values_auc.append((idx, [attrs[name] for name in auc_keys]))
            full_values_fpr.append((idx, [attrs[name] for name in fpr_keys]))

    print(f"encoder {encoder} has {len(full_values_auc)} values")
    if not full_values_auc:
        continue

    # Both lists are filled together, so the paired indices are expected to
    # match; a mismatch is reported and the pair is skipped.
    for (i, fpr_list), (j, auc_list) in zip(full_values_fpr, full_values_auc):
        if i == j:
            scores[encoder].append((i, np.array(fpr_list).mean(), np.array(auc_list).mean()))
        else:
            print("ERROR", i, j)

# Show the five highest-AUC trials per encoder, with both metrics scaled by 100.
for k, entries in scores.items():
    sorted_v = sorted(entries, key=lambda entry: entry[2], reverse=True)
    print(f"{k:10}>", end=" ")
    for i, fpr, auc in sorted_v[:5]:
        fpr, auc = fpr * 100, auc * 100
        print(f" {i:2}: ({auc=:.2f} {fpr=:.2f}) ", end=" ")
    print()
        




encoder swin has 35 values
encoder deit has 39 values
encoder repvgg has 37 values
encoder resnet50d has 36 values
encoder bit has 37 values
swin      >  37: (auc=90.85 fpr=45.12)   33: (auc=90.79 fpr=44.83)    3: (auc=90.71 fpr=46.31)   24: (auc=90.55 fpr=46.19)    4: (auc=90.43 fpr=46.65)  
deit      >  36: (auc=82.88 fpr=80.52)   24: (auc=82.47 fpr=79.93)    3: (auc=82.35 fpr=80.30)   26: (auc=82.16 fpr=80.60)    4: (auc=82.13 fpr=78.50)  
repvgg    >  36: (auc=83.40 fpr=64.92)    3: (auc=82.78 fpr=66.05)   10: (auc=82.73 fpr=66.51)    8: (auc=82.58 fpr=65.74)   30: (auc=82.50 fpr=65.79)  
resnet50d >  24: (auc=85.87 fpr=61.54)   10: (auc=85.87 fpr=61.34)    3: (auc=85.59 fpr=62.12)   26: (auc=85.31 fpr=61.63)    4: (auc=84.93 fpr=63.46)  
bit       >  24: (auc=95.30 fpr=22.05)   36: (auc=95.23 fpr=22.11)   10: (auc=95.03 fpr=22.74)   30: (auc=94.97 fpr=22.82)    3: (auc=94.96 fpr=22.88)  


In [60]:
# Print encoders from highest to lowest AUC, each followed by the mean and
# the raw values of every entry in its configuration mapping.
# NOTE(review): `encoder_best_auc` is not defined in the cells visible here;
# it is read as a mapping of encoder -> (auc, conf) -- confirm where it is built.
# Sort on the AUC alone: sorting on the whole (auc, conf) tuple falls back to
# comparing the `conf` dicts whenever two AUCs tie, and dicts cannot be ordered.
for encoder, (auc, conf) in sorted(encoder_best_auc.items(), key=lambda item: item[1][0], reverse=True):
    print(f"{encoder}: {auc:.1%}")
    for k, v in conf.items():
        print(f"    {k}: {np.mean(v):.3f}, {v}")

vit: 91.2%
    n_epochs: 366.667, [500, 300, 300]
    bottleneck_channels: 853.333, [256, 1792, 512]
    num_res_blocks: 6.000, [9, 6, 3]
    time_embed_dim: 341.333, [512, 256, 256]
    dropout: 0.222, [0.20507666213092263, 0.24317537435908987, 0.21856901360140613]
    lr: 0.004, [1.2184911425255087e-06, 0.003413265727580557, 0.009921410750025034]
    likelihood_weighting: 1.000, [True, True, True]
    reduce_mean: 0.333, [False, True, False]
    beta_min: 0.453, [0.7707111985158176, 0.04509669122298443, 0.5434724083688577]
    beta_max: 15.000, [15, 20, 10]
dinov2: 86.6%
    n_epochs: 200.000, [100, 300, 200]
    bottleneck_channels: 1109.333, [512, 1792, 1024]
    num_res_blocks: 7.667, [3, 11, 9]
    time_embed_dim: 512.000, [256, 1024, 256]
    dropout: 0.289, [0.19047268055467803, 0.323222360975637, 0.35343530173117726]
    lr: 0.003, [0.00827849803876892, 6.139565305858912e-05, 0.00036565887340390897]
    likelihood_weighting: 0.667, [True, False, True]
    reduce_mean: 0.333, [