In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

from nats_bench import create
from nats_bench.api_utils import ArchResults

from simulated_annealing.nats_bench import Operation, CellTopology, Benchmark, NatsBenchTopology

In [3]:
REPO_ROOT = Path().absolute().parent

In [4]:
BENCHMARK_NAME_MAP = {
    "topology": "NATS-tss-v1_0-3ffb9-simple",
}

In [5]:
SELECTED_BENCHMARK = "topology"

In [6]:
benchmark_dir = BENCHMARK_NAME_MAP[SELECTED_BENCHMARK]
benchmark_dir

'NATS-tss-v1_0-3ffb9-simple'

In [7]:
benchmark_path = REPO_ROOT / "models" /  benchmark_dir
benchmark_path

PosixPath('/home/tomaz/git/Politecnico/Extra/AI-Tech-Lab/simulated-annealing/models/NATS-tss-v1_0-3ffb9-simple')

In [8]:
api = create(
    str(benchmark_path),
    SELECTED_BENCHMARK,
    fast_mode=True,
)
api

[2024-06-15 08:29:40] Try to create the NATS-Bench (topology) api from /home/tomaz/git/Politecnico/Extra/AI-Tech-Lab/simulated-annealing/models/NATS-tss-v1_0-3ffb9-simple with fast_mode=True
[2024-06-15 08:29:41] Create NATS-Bench (topology) done with 0/15625 architectures avaliable.


NATStopology(0/15625 architectures, fast_mode=True, file=NATS-tss-v1_0-3ffb9-simple)

In [9]:
architecture_result: ArchResults = api.query_by_index(0)
architecture_result

[2024-06-15 08:29:41] Call query_by_index with arch_index=0, dataname=None, hp=12
Call query_meta_info_by_index with arch_index=0, hp=12
[2024-06-15 08:29:41] Call clear_params with archive_root=/home/tomaz/git/Politecnico/Extra/AI-Tech-Lab/simulated-annealing/models/NATS-tss-v1_0-3ffb9-simple and index=0


ArchResults(arch-index=0, arch=|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|, 5 runs, clear=True)

In [10]:
architecture_result.get_metrics("cifar10", "ori-test")

{'iepoch': 11.0,
 'loss': 0.8653452147483826,
 'accuracy': 69.55,
 'cur_time': 1.0195916947864352,
 'all_time': 12.235100337437222}

In [11]:
architecture_result.query("cifar10")

{111: ResultsCount(cifar10, arch=|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|, FLOP=15.65M, Param=0.129MB, seed=0111, 1 eval-sets: [ori-test])}

In [12]:
api.simulate_train_eval(0, "cifar10")

[2024-06-15 08:29:41] Call query_index_by_arch with arch=0
[2024-06-15 08:29:41] Call the get_more_info function with index=0, dataset=cifar10-valid, iepoch=None, hp=12, and is_random=True.
[2024-06-15 08:29:41] Call query_index_by_arch with arch=0
[2024-06-15 08:29:41] Call _prepare_info with index=0 skip because it is in arch2infos_dict
[2024-06-15 08:29:41] Call the get_latency function with index=0, dataset=cifar10, and hp=12.
[2024-06-15 08:29:41] Call the get_cost_info function with index=0, dataset=cifar10, and hp=12.
[2024-06-15 08:29:41] Call _prepare_info with index=0 skip because it is in arch2infos_dict
Call query_meta_info_by_index with arch_index=0, hp=12
[2024-06-15 08:29:41] Call _prepare_info with index=0 skip because it is in arch2infos_dict


(64.01599998291016, 0.0139359758611311, 89.2020952247438, 89.2020952247438)

In [13]:
api.get_more_info(0, "cifar10")

[2024-06-15 08:29:41] Call the get_more_info function with index=0, dataset=cifar10, iepoch=None, hp=12, and is_random=True.
[2024-06-15 08:29:41] Call query_index_by_arch with arch=0
[2024-06-15 08:29:41] Call _prepare_info with index=0 skip because it is in arch2infos_dict


{'train-loss': 0.8180574755477905,
 'train-accuracy': 71.08,
 'train-per-time': 14.442185997962952,
 'train-all-time': 173.30623197555542,
 'comment': 'In this dict, train-loss/accuracy/time is the metric on the train+valid sets of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train+valid sets by 12 epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.',
 'test-loss': 0.8653452147483826,
 'test-accuracy': 69.55,
 'test-per-time': 1.0195916947864352,
 'test-all-time': 12.235100337437222}

In [14]:
architecture_result.arch_str

'|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|'

In [15]:
topology = CellTopology(
    Operation.avg_pool_3x3,
    Operation.nor_conv_1x1,
    Operation.skip_connect,
    Operation.nor_conv_1x1,
    Operation.skip_connect,
    Operation.skip_connect,
)
topology

CellTopology(edge_0_to_1=<Operation.avg_pool_3x3: 4>, edge_0_to_2=<Operation.nor_conv_1x1: 2>, edge_1_to_2=<Operation.skip_connect: 1>, edge_0_to_3=<Operation.nor_conv_1x1: 2>, edge_1_to_3=<Operation.skip_connect: 1>, edge_2_to_3=<Operation.skip_connect: 1>)

In [16]:
str(topology)

'|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|'

In [17]:
architecture_result.arch_str == str(topology)

True

In [18]:
topology_index = api.query_index_by_arch(str(topology))
topology_index

[2024-06-15 08:29:41] Call query_index_by_arch with arch=|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|


0

In [19]:
api.get_more_info(topology_index, Benchmark.CIFAR10)

[2024-06-15 08:29:41] Call the get_more_info function with index=0, dataset=Benchmark.CIFAR10, iepoch=None, hp=12, and is_random=True.
[2024-06-15 08:29:41] Call query_index_by_arch with arch=0
[2024-06-15 08:29:41] Call _prepare_info with index=0 skip because it is in arch2infos_dict


{'train-loss': 1.0051389471244812,
 'train-accuracy': 63.92000001708984,
 'train-per-time': 7.221092998981476,
 'train-all-time': 86.65311598777771,
 'comment': 'In this dict, train-loss/accuracy/time is the metric on the train set of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train set by 12 epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.',
 'valid-loss': 1.0478905376434327,
 'valid-accuracy': 62.71199997802734,
 'valid-per-time': 2.5489792369660877,
 'valid-all-time': 30.587750843593053,
 'test-loss': 1.0519337089538574,
 'test-accuracy': 62.82,
 'test-per-time': 1.0195916947864352,
 'test-all-time': 12.235100337437222}

In [20]:
bench = NatsBenchTopology(benchmark_path, Benchmark.CIFAR10)
bench

<simulated_annealing.nats_bench.NatsBenchTopology at 0x7f6dd474c140>

In [21]:
result = bench.query(topology)
result

ArchitectureResult(index=0, train=Metrics(loss=0.9745367371559143, accuracy=64.97199997802734, time_per_epoch=7.221092998981476, time=86.65311598777771), val=Metrics(loss=1.020442234992981, accuracy=64.01599998291016, time_per_epoch=2.5489792369660877, time=30.587750843593053), test=Metrics(loss=1.0203236518859864, accuracy=63.66, time_per_epoch=1.0195916947864352, time=12.235100337437222))