In [None]:
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 
# Copyright 2025 Anonymized Authors

# Licensed under the Apache License, Version 2.0 (the "License"); 
# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
"""
This experiment tests the generalization ability of GENE by comparing it against
regularized evolution and REINFORCE, as implemented in Dong et al. (2021), on the
three NATS-Bench topology-search-space datasets (cifar10, cifar100, ImageNet16-120).

Requirements: 

- This notebook requires torchvision, torch, tensorflow and numpy to be
  installed in the Python environment used to run it. The cells below also
  import matplotlib and nats_bench.

- This notebook requires the autodl submodule; see the setup instructions in the README.

- Make sure the search outputs are saved to the expected location, ../outputs/search-tss/.
 
"""
import matplotlib.pyplot as plt

import sys; sys.path.append('..')
from utils.plotting import visualize_curve
from utils.nas_utils import run_algorithm
from nats_bench import create

# Open the NATS-Bench API for the topology search space ("tss").
# The first argument is None, so no explicit benchmark file is given —
# presumably nats_bench then falls back to its default data location; verify
# against the nats_bench documentation if loading fails.
api = create(
    None,
    "tss",
    verbose=False,
    fast_mode=True,
)

In [None]:
# Search strategies to benchmark; every one is run on every dataset below.
algorithms = ["GENE", "regularized evolution", "reinforce"]

# Dataset name -> search budget forwarded to run_algorithm.
# NOTE(review): the budget's unit is defined by run_algorithm — confirm there.
datasets = {
    "cifar10": 200000,
    "cifar100": 400000,
    "ImageNet16-120": 120000,
}

# Last argument to run_algorithm; presumably the number of repeated runs per
# (algorithm, dataset) pair — TODO confirm against utils.nas_utils.
n = 1000

# Run every (algorithm, dataset) combination, algorithms in the outer loop.
for algorithm in algorithms:
    for dataset in datasets:
        budget = datasets[dataset]
        run_algorithm(algorithm, dataset, budget, n)

In [None]:
# exp4_4: compare GENE, regularized evolution (RE) and REINFORCE on all datasets.
# Draws one panel per dataset and saves the combined figure to nats_all.png.

# Per-dataset y-axis limits, passed to visualize_curve via config["limits"].
ylims = {
    "cifar10": (93.8, 94.5),
    "cifar100": (70, 73),
    "ImageNet16-120": (42, 46),
}

# Dataset -> search budget. Each panel is requested with the tag
# "<dataset>-T<budget>" (e.g. "cifar10-T200000"). The budgets mirror those used
# in the run cell, but are restated here so this cell can be re-run on its own
# without re-running the (expensive) searches.
plot_budgets = {
    "cifar10": 200000,
    "cifar100": 400000,
    "ImageNet16-120": 120000,
}

exp4_4 = {
    "data":
    {
        # legend label -> [run name, colour name]; the exact meaning of each
        # entry is defined by visualize_curve — NOTE(review): confirm there.
        "RE": ["R-EA-SS10", "Dark Gray"],
        "REINFORCE": ["REINFORCE-0.01", "Red Orange"],
        "GENE": ["GENE", "Dark Blue"],
    },
    "config":
    {
        "limits": ylims,
        # Presumably the same n passed to run_algorithm in the run cell — keep in sync.
        "n": 1000,
        # Key spelling kept as-is: visualize_curve presumably looks it up verbatim.
        "confidence_intervall": True,
        "pvalue": 0.05,
    },
}

# One panel per dataset instead of three copy-pasted calls. squeeze=False keeps
# `axes` 2-D even when there is a single dataset, so the loop always works.
n_panels = len(plot_budgets)
fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)
for ax, (dataset_name, search_budget) in zip(axes[0], plot_budgets.items()):
    visualize_curve(api, exp4_4, "{}-T{}".format(dataset_name, search_budget), ax)

fig.savefig('nats_all.png', dpi=500, bbox_inches='tight')
plt.show()