In [None]:
# Reload edited modules automatically before running code, so changes to the
# imported helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import numpy as np
import sys
import os
import collections
import itertools
import pickle
from scipy.integrate import quad
from matplotlib import pyplot as plt
from scipy.special import logit, expit
from scipy.stats import norm

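# Random Walk Autocorrelation (RWA)
 - This notebook computes the random walk autocorrelation on the NAS-Bench-201 datasets and on synthetic walks driven by arbitrary probability density functions. The RWA at lag $k$ is the correlation between the value at the current step of a random walk and the value $k$ steps earlier.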

### Compute the RWA from PDFs

In [None]:
# first, define a few PDFs

def sample(v, std=.35, dist='normal'):
    """Draw one random point from the neighborhood distribution of ``v``.

    Args:
        v: centre of the neighborhood (the current point of the walk).
        std: spread of the neighborhood. For ``'lipschitz'`` it is the
            half-width of the interval; for ``'normal'`` it is the standard
            deviation of the normal before truncation to [0, 1].
        dist: ``'uniform'`` (uniform on [0, 1], ignores ``v``),
            ``'lipschitz'`` (uniform on [v - std, v + std] clipped to [0, 1]),
            or ``'normal'`` (normal with mean ``v`` truncated to [0, 1]).

    Returns:
        A float sample.

    Raises:
        ValueError: if ``dist`` is unknown, or if ``dist == 'normal'`` and
            ``v`` lies outside [0, 1] (the rejection envelope below is only
            valid there).
    """
    if dist == 'uniform':
        return np.random.rand()
    if dist == 'lipschitz':
        return np.random.uniform(max(0, v - std), min(1, v + std))
    if dist == 'normal':
        if not 0 <= v <= 1:
            raise ValueError('normal sampling needs v in [0, 1], got {!r}'.format(v))
        # Rejection sampling with a uniform proposal on [0, 1]: the truncated
        # normal peaks at u = v, so pdf(v, v) bounds the density everywhere.
        # A rejected proposal must be retried -- drawing from a different
        # distribution instead would bias the samples.
        peak = pdf(v, v, dist='normal', std=std)
        while True:
            u = np.random.rand()
            if np.random.rand() * peak < pdf(u, v, dist='normal', std=std):
                return u
    raise ValueError('unknown dist: {!r}'.format(dist))
        
def pdf(u, v, dist='normal', std=.35):
    """Density at ``u`` of the neighborhood distribution of ``v``.

    Every supported distribution lives on [0, 1], so the density is 0 for
    ``u`` outside that interval.

    Args:
        u: point at which to evaluate the density.
        v: centre of the neighborhood.
        dist: ``'uniform'`` (uniform on [0, 1]), ``'lipschitz'`` (uniform on
            [v - std, v + std] clipped to [0, 1]) or ``'normal'`` (normal with
            mean ``v`` and standard deviation ``std``, truncated to [0, 1]).
        std: spread of the neighborhood (see ``dist``).

    Returns:
        The density value.

    Raises:
        ValueError: if ``dist`` is not one of the supported names.
    """
    if dist not in ('uniform', 'lipschitz', 'normal'):
        raise ValueError('unknown dist: {!r}'.format(dist))
    if not 0 <= u <= 1:
        return 0
    if dist == 'uniform':
        return 1
    if dist == 'lipschitz':
        # The support must be clipped exactly like the normalising width,
        # otherwise the density integrates to more than 1 near the borders.
        lo, hi = max(0, v - std), min(1, v + std)
        return 1 / (hi - lo) if lo <= u <= hi else 0
    # Truncated normal: renormalise by the probability mass inside [0, 1].
    return norm.pdf(u, v, std) / (norm.cdf(1, v, std) - norm.cdf(0, v, std))

In [None]:
def sample_constrained(cell, std, low=0, high=1, dist='normal'):
    """Sample a neighbor of ``cell`` lying strictly inside (low, high).

    Up to 200 candidates are drawn with ``sample``; if none of them falls in
    the open interval, the walk stays put and ``cell`` is returned unchanged.
    """
    attempts = 0
    while attempts < 200:
        candidate = sample(cell, std=std, dist=dist)
        if low < candidate < high:
            return candidate
        attempts += 1
    return cell

def rwa_from_pdf(trials=100000,
                size=36,
                std=.35,
                low=0,
                high=1,
                dist='normal',
                start=.25):
    """Compute the random walk autocorrelation (RWA) of a synthetic 1-D walk.

    The walk starts at ``start``; each step moves to a neighbor drawn by
    ``sample_constrained`` (neighborhood ``dist`` with spread ``std``, kept
    strictly inside (low, high)). After a burn-in of ``size - 1`` steps,
    ``trials`` further steps are taken and the last ``size`` values are
    recorded at every step.

    Args:
        trials: number of recorded walk steps.
        size: window length; lags 0 .. size - 1 are measured.
        std: neighborhood spread passed to the sampler.
        low, high: open bounds the walk must stay within.
        dist: neighborhood distribution name (default keeps the original
            truncated-normal behavior).
        start: initial point of the walk.

    Returns:
        xs: ``sqrt(lag)`` for lags ``size - 1, ..., 1, 0``.
        corr: Pearson correlation between the current value and the value
            ``lag`` steps earlier, in the same order as ``xs``.
    """
    cell = start
    # maxlen makes every append drop the oldest value once the window is full.
    window = collections.deque([cell], maxlen=size)
    for _ in range(size - 1):
        cell = sample_constrained(cell, std=std, low=low, high=high, dist=dist)
        window.append(cell)

    # Integer reporting interval: a float modulo (t % (trials / 10)) prints on
    # every trial when trials < 10 or is not a multiple of 10.
    report_every = max(1, trials // 10)
    autocorrs = np.zeros((size, trials, 2))
    for t in range(trials):
        if t % report_every == 0:
            print('trial', t)
        cell = sample_constrained(cell, std=std, low=low, high=high, dist=dist)
        window.append(cell)
        # column 0: current value repeated; column 1: the value at each lag
        autocorrs[:, t, 0] = window[-1]
        autocorrs[:, t, 1] = np.array(window)

    corr = [np.corrcoef(autocorrs[i, :, 0], autocorrs[i, :, 1])[1, 0]
            for i in range(size)]
    xs = [np.power(size - i - 1, 1/2) for i in range(size)]
    return xs, corr


### Compute the RWA on the NAS-Bench-201 datasets

In [None]:
# Make the NAS-Bench-201 helper code importable. The default checkout locations
# can be overridden with the BANANAS_PATH / AUTODL_LIB_PATH environment variables.
for _lib_path in (os.environ.get('BANANAS_PATH', '~/naszilla/bananas'),
                  os.environ.get('AUTODL_LIB_PATH', '~/AutoDL-Projects/lib/')):
    _lib_path = os.path.expanduser(_lib_path)
    # Skip paths already present so re-running the cell stays idempotent.
    if _lib_path not in sys.path:
        sys.path.append(_lib_path)

from nas_bench_201.cell import Cell
from nas_201_api import NASBench201API as API

In [None]:
def _val_loss(cell, nasbench, dataset):
    """Validation loss of ``cell`` (a dict of ``Cell`` kwargs) on ``dataset``."""
    return Cell(**cell).get_val_loss(nasbench, dataset=dataset)


def pert(cell, nasbench, low=0, high=100, dataset=None):
    """Return a random neighbor of ``cell`` whose validation loss is in (low, high).

    Args:
        cell: architecture as a dict of keyword arguments for ``Cell``.
        nasbench: the loaded NAS-Bench-201 API object.
        low, high: open bounds on the validation loss of the returned neighbor.
        dataset: dataset used for the loss lookup. ``None`` keeps the original
            behavior of reading the notebook-level ``dataset`` variable; pass
            it explicitly to avoid depending on that global.

    Returns:
        A perturbed cell in the format returned by ``Cell.perturb``. If no
        neighbor within the bounds is found in 200 tries, prints 'failed' and
        returns an unconstrained perturbation.
    """
    if dataset is None:
        dataset = globals()['dataset']
    for _ in range(200):
        perturbed = Cell(**cell).perturb(nasbench)
        # Look the loss up once per candidate rather than once per comparison.
        if low < _val_loss(perturbed, nasbench, dataset) < high:
            return perturbed
    print('failed')
    return Cell(**cell).perturb(nasbench)


def random_walk(nasbench,
                trials=10000,
                size=36,
                dataset='cifar10',
                save=False,
                low=0,
                high=100):
    """Compute the random walk autocorrelation (RWA) on a NAS-Bench-201 dataset.

    A random walk over architectures is run with ``pert``, keeping each visited
    architecture's validation loss inside (low, high). After a burn-in of
    ``size - 1`` steps, ``trials`` more steps are taken and the last ``size``
    losses are recorded at each step.

    Args:
        nasbench: the loaded NAS-Bench-201 API object.
        trials: number of recorded walk steps.
        size: window length; lags 0 .. size - 1 are measured.
        dataset: NAS-Bench-201 dataset name used for every loss lookup.
        save: unused; kept for backward compatibility.
        low, high: loss bounds. If ``high < 1`` both are treated as quantiles
            of the sorted losses loaded from '{dataset}_losses.pkl'.

    Returns:
        xs: ``sqrt(lag)`` for lags ``size - 1, ..., 1, 0``.
        corr: correlation between the current loss and the loss ``lag`` steps
            earlier, in the same order as ``xs``.
    """
    if high < 1:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # trusted, locally generated loss files.
        with open('{}_losses.pkl'.format(dataset), 'rb') as f:
            losses, _ = pickle.load(f)
        losses.sort()
        limits = [losses[0], losses[-1]]
        # The original hard-coded 15625 (presumably the NAS-Bench-201 cell
        # count); len(losses) gives the same indices for a full loss file.
        n_losses = len(losses)
        low, high = losses[int(low * n_losses)], losses[int(high * n_losses)]
        print('limits', limits)
        print('scaled limits', low, high)

    # Start from a random architecture whose loss lies within [low, high].
    cell = Cell.random_cell(nasbench)
    loss = _val_loss(cell, nasbench, dataset)
    while loss < low or loss > high:
        cell = Cell.random_cell(nasbench)
        loss = _val_loss(cell, nasbench, dataset)

    # The window holds losses (not architectures); maxlen drops the oldest.
    window = collections.deque([loss], maxlen=size)
    for _ in range(size - 1):
        cell = pert(cell, nasbench, low=low, high=high, dataset=dataset)
        window.append(_val_loss(cell, nasbench, dataset))

    report_every = max(1, trials // 10)
    autocorrs = np.zeros((size, trials, 2))
    for t in range(trials):
        if t % report_every == 0:
            print('trial', t)
        cell = pert(cell, nasbench, low=low, high=high, dataset=dataset)
        window.append(_val_loss(cell, nasbench, dataset))
        # column 0: current loss repeated; column 1: the loss at each lag
        autocorrs[:, t, 0] = window[-1]
        autocorrs[:, t, 1] = np.array(window)

    corr = [np.corrcoef(autocorrs[i, :, 0], autocorrs[i, :, 1])[1, 0]
            for i in range(size)]
    xs = [np.power(size - i - 1, 1/2) for i in range(size)]
    return xs, corr
    

In [None]:
# generate synthetic data: RWA curves of truncated-normal walks with several
# neighborhood widths, plotted against sqrt(lag)
np.random.seed(0)  # make the synthetic walks reproducible
rwa_normals = {}
fig, ax = plt.subplots()
for std in [.3, .35, .4]:
    print('starting', std)
    xs, corr = rwa_from_pdf(std=std, trials=10000)
    rwa_normals[std] = corr
    # plot against the xs returned by rwa_from_pdf (`data` was never defined)
    ax.plot(xs, corr, label='normal pdf, std={}'.format(std))
ax.set_xlabel('sqrt(lag)')
ax.set_ylabel('autocorrelation')
ax.set_title('RWA of synthetic random walks')
ax.legend();


In [None]:
# Download the NAS-Bench-201 benchmark file first, then point this path at it.
nasbench_path = os.path.expanduser('~/path/to/NAS-Bench-201-v1_0-e61699.pth')
nasbench = API(nasbench_path)

In [None]:
# Compute the RWA curve on each NAS-Bench-201 dataset. Because high < 1,
# random_walk treats low/high as quantiles of that dataset's sorted losses.
datasets = ['ImageNet16-120', 'cifar100', 'cifar10']
corrs = {}
for dataset in datasets:
    _, corr = random_walk(
        nasbench,
        dataset=dataset,
        save=False,
        trials=10000,
        low=.1,
        high=.9,
    )
    corrs[dataset] = corr
    print('finished', dataset)
