In [None]:
import nni
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10
import nni.common.blob_utils
# Override the module-level NNI_BLOB attribute of nni.common.blob_utils with a mirror URL.
# NOTE(review): presumably this is the base URL NNI fetches hosted assets from --
# confirm, and confirm that the mirror is trusted and reachable.
nni.common.blob_utils.NNI_BLOB = 'https://repo.dailylime.kr/mirror/nni'
from nni.nas.evaluator.pytorch import DataLoader
from nni.nas.hub.pytorch import DARTS as DartsSpace



# Define model search space

In [None]:
from nni.nas.hub.pytorch import ProxylessNAS
# Instantiate the DARTS search space that the search will explore.
model_space = DartsSpace(
    # initial number of filters (channels) in the model
    width=16,
    # total count of stacked cells
    num_cells=8,
    # input-resolution hint; 'cifar' corresponds to 32x32 images
    dataset='cifar',
)


In [None]:
# Forwarded as `fast_dev_run` to both Classification evaluators in this notebook.
# Keep it False for a full search and a full retraining run.
fast_dev_run = False


In [None]:
import numpy as np
from nni.nas.evaluator.pytorch import Classification
from torch.utils.data import SubsetRandomSampler

# Normalization statistics (one value per image channel) used by both pipelines.
# They are defined before the transform pipelines so that a fresh
# "Restart Kernel -> Run All" does not raise NameError when Normalize(...) is built.
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

# Training-time pipeline: random crop (32x32, padding=4) and random horizontal flip
# as augmentation, followed by tensor conversion and normalization.
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# Evaluation-time pipeline: no augmentation, only tensor conversion + normalization.
valid_steps = [
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]
transform_valid = transforms.Compose(valid_steps)

# CIFAR-10 with train=False (the held-out split), stored under ./data.
valid_data = nni.trace(CIFAR10)(
    root='./data',
    train=False,
    download=True,
    transform=transform_valid,
)
valid_loader = DataLoader(valid_data, batch_size=256, num_workers=6)

train_data = nni.trace(CIFAR10)(root='./data', train=True, download=True, transform=transform)

# Split the training set into two disjoint halves for the search phase: the first
# half feeds `search_train_loader`, the second half feeds `search_valid_loader`.
# (`indices` / `split` used to be commented out, so the samplers below raised NameError.)
# NOTE: the permutation is not seeded, so the split differs between runs.
num_samples = len(train_data)
indices = np.random.permutation(num_samples)
split = num_samples // 2

search_train_loader = DataLoader(
    train_data, batch_size=256, num_workers=6,
    sampler=SubsetRandomSampler(indices[:split]),
)

# `indices` index into `train_data`, so the second half must be drawn from
# `train_data` as well. Sampling `valid_data` with them would use out-of-range
# indices and would also leak the held-out test set into the search.
search_valid_loader = DataLoader(
    train_data, batch_size=352, num_workers=6,
    sampler=SubsetRandomSampler(indices[split:]),
)

# Evaluator for the architecture-search phase; it consumes the two search loaders.
evaluator = Classification(
    # optimization hyper-parameters
    learning_rate=1e-3, weight_decay=1e-4,
    # data sources for the search
    train_dataloaders=search_train_loader, val_dataloaders=search_valid_loader,
    # run length and debug switch
    max_epochs=10, fast_dev_run=fast_dev_run,
    num_classes=10,
)

# define search strategy

In [None]:
from nni.nas.strategy import DARTS as DartsStrategy

# Build the DARTS search strategy using only its default constructor arguments.
strategy = DartsStrategy()

In [None]:
import os

# Set NNI_CONFIG_DIR for this process before the experiment is created.
# NOTE(review): hard-coded, user-specific absolute path -- adjust for your environment.
os.environ.update({"NNI_CONFIG_DIR": "/scratch/pt2295/Assign_IDLS/PROJECT"})

# run experiment

In [None]:
from nni.nas.experiment import NasExperiment

# Combine the search space, evaluator and strategy defined earlier into one experiment.
experiment = NasExperiment(model_space, evaluator, strategy)
# Launch the experiment.
# NOTE(review): presumably this blocks until the search has finished -- confirm.
experiment.run()

In [None]:
# Export the top models in 'dict' format and keep the first entry.
top_models = experiment.export_top_models(formatter='dict')
exported_arch = top_models[0]

# Show the selected architecture as the cell output.
exported_arch

In [None]:
import json
import os

# Persist the exported architecture as JSON.
# The target directory is created first: open(..., "w") raises FileNotFoundError
# when "exported_arch/" does not exist yet (e.g. on a fresh checkout).
os.makedirs("exported_arch", exist_ok=True)
with open("exported_arch/best_darts.json", "w") as outfile:
    json.dump(exported_arch, outfile)

In [None]:
from nni.nas.space import model_context

# Re-instantiate the search space inside model_context(exported_arch), using the same
# constructor arguments as `model_space`.
# NOTE(review): presumably this yields a fixed model with the exported choices
# applied -- confirm against the NNI documentation.
with model_context(exported_arch):
    final_model = DartsSpace(width=16, num_cells=8, dataset='cifar')

In [None]:
# Display the instantiated model as the cell output.
final_model  # this is just the architecture

In [None]:
# Loader over the full original training set, used to retrain the final model.
# shuffle=True re-draws the sample order every epoch; without it the loader
# replays the dataset in the same order (identical batches) in every epoch.
train_loader = DataLoader(train_data, batch_size=256, num_workers=6, shuffle=True)

# retrain model with best architecture

In [None]:
max_epochs = 50

# Evaluator for retraining the selected architecture: trained from `train_loader`
# and validated on `valid_loader`.
evaluator = Classification(
    learning_rate=1e-3, weight_decay=1e-4,
    train_dataloaders=train_loader, val_dataloaders=valid_loader,
    max_epochs=max_epochs,
    num_classes=10,
    # ONNX export is switched off for this run
    export_onnx=False,
    # expected to be False for a full training run
    fast_dev_run=fast_dev_run,
)

# Train final_model with the evaluator configured above.
evaluator.fit(final_model)

In [None]:
import os

# Save the retrained weights. The output directory is created first so that
# torch.save does not fail when "exported_arch/" is missing.
os.makedirs("exported_arch", exist_ok=True)
torch.save(final_model.state_dict(), 'exported_arch/best_darts_model.pt')

In [None]:
# Display the model's state dict as the cell output.
# NOTE(review): the output may be very long; consider showing only the keys.
final_model.state_dict()

In [None]:
# NOTE(review): `parameters` is referenced without being called, so this cell displays
# the attribute itself rather than the parameters -- confirm whether
# `final_model.parameters()` (or a parameter count) was intended.
final_model.parameters