In [0]:
!curl -O https://storage.googleapis.com/nasbench/nasbench_only108.tfrecord

!git clone https://github.com/google-research/nasbench
!pip install ./nasbench

from nasbench import api

nasbench = api.NASBench('nasbench_only108.tfrecord')

In [0]:
import copy
import numpy as np
import matplotlib.pyplot as plt
import random


# Search-space dimensions used by this notebook.
NUM_VERTICES = 7   # vertices per cell, including the input and output vertex
MAX_EDGES = 9      # cap used when extracting the greedy edge matrix
EDGE_AGENTS = 21   # one per (i, j) pair with i < j: 7 * 6 / 2
OP_AGENTS = 5      # one per interior vertex: 7 - 2

# Operation labels. These names were referenced but never defined, which made
# this cell fail with NameError before ALLOWED_EDGES was assigned. The values
# are the same label strings the notebook maps operation indices to.
CONV3X3 = 'conv3x3-bn-relu'
CONV1X1 = 'conv1x1-bn-relu'
MAXPOOL3X3 = 'maxpool3x3'
ALLOWED_OPS = [CONV3X3, CONV1X1, MAXPOOL3X3]
ALLOWED_EDGES = [0, 1]

In [0]:
def softmax(x):
    """Return the softmax of ``x`` taken over all of its elements.

    The maximum of ``x`` is subtracted before exponentiating so that large
    inputs do not overflow; the result sums to 1.
    """
    shifted = x - np.max(x)
    weights = np.exp(shifted)
    total = weights.sum()
    return weights / total

In [0]:
# Score tables, all starting at 1 so that softmax over the last axis is uniform.
# agents_edge[i][j] holds two scores: index 0 = no edge i->j, index 1 = edge i->j.
agents_edge = np.ones((7, 7, 2))
# agents_op[k] holds three scores, one per operation index, for the k-th interior vertex.
agents_op = np.ones((5, 3))

# Values the samplers choose from.
elt_edge = list(range(2))
elt_op = list(range(3))

In [0]:
def sample_edge_matrix():
  """Sample a 7x7 strictly upper-triangular 0/1 adjacency matrix.

  Each entry (i, j) with i < j is drawn from ``elt_edge`` with probabilities
  ``softmax(agents_edge[i][j])``. Entries on and below the diagonal stay 0.

  Returns:
    An integer numpy array of shape (7, 7).
  """
  matrix = np.zeros([7, 7])
  for i in range(7):
    for j in range(i + 1, 7):
      # No `size` argument: the draw is a scalar rather than a length-1 array.
      # Assigning a length-1 array to a single element is deprecated in
      # recent NumPy releases.
      matrix[i][j] = np.random.choice(elt_edge, p=softmax(agents_edge[i][j]))

  return matrix.astype(int)

In [0]:
def sample_operations():
  """Sample one operation for each of the 5 interior vertices.

  The k-th operation index is drawn from ``elt_op`` with probabilities
  ``softmax(agents_op[k])``.

  Returns:
    A tuple ``(choice_op, ops)``: ``choice_op`` is a float array of length 5
    with the sampled operation indices, and ``ops`` is a new 7-element list
    ``['input', <5 labels from op_mapping>, 'output']``.
  """
  choice_op = np.zeros([5])
  for i in range(5):
    # No `size` argument: the draw is a scalar rather than a length-1 array.
    choice_op[i] = np.random.choice(elt_op, p=softmax(agents_op[i]))

  # Build a fresh list on every call instead of mutating a shared module-level
  # list, so a later call cannot change a result a caller still holds.
  sampled_ops = ['input'] + [op_mapping[int(c)] for c in choice_op] + ['output']

  return choice_op, sampled_ops

In [0]:
# Operation index -> operation label.
op_mapping = {
  0: 'conv1x1-bn-relu',
  1: 'conv3x3-bn-relu',
  2: 'maxpool3x3',
}

# Label list template: fixed 'input'/'output' ends around five placeholders.
ops = ['input'] + ['op'] * 5 + ['output']

In [0]:
def update_agents(matrix, choice_op, reward):
  """Add the reward to the score of every choice that was sampled.

  Args:
    matrix: 7x7 edge matrix; only entries (i, j) with i < j are read.
    choice_op: sequence of 5 sampled operation indices.
    reward: value added (divided by a scale of 1) to each chosen score.

  Updates the module-level ``agents_edge`` and ``agents_op`` in place.
  """
  edge_scale = 1
  op_scale = 1

  for src in range(7):
    for dst in range(src + 1, 7):
      picked_edge = int(matrix[src][dst])
      agents_edge[src][dst][picked_edge] += reward / edge_scale

  for slot in range(5):
    picked_op = int(choice_op[slot])
    agents_op[slot][picked_op] += reward / op_scale
  

In [0]:
def get_accuracy(matrix, ops):
  """Look up the test accuracy of the cell described by ``matrix`` and ``ops``.

  Args:
    matrix: adjacency matrix of the cell.
    ops: list of operation labels, one per vertex.

  Returns:
    The ``'test_accuracy'`` entry of the result of ``nasbench.query``.
  """
  # Use the `matrix` argument; the previous code referenced an undefined `mat`.
  cell = api.ModelSpec(matrix=matrix, ops=ops)
  data = nasbench.query(cell)
  return data['test_accuracy']

In [0]:
def get_manas_matrix(trained):
  """Return the edge matrix currently proposed by the edge agents.

  Args:
    trained: if true, build the matrix greedily from the agents' edge
      probabilities; otherwise draw random matrices until one is valid.

  Returns:
    An integer numpy array of shape (7, 7).

  Raises:
    Exception: in the untrained case, if none of the ``max_draws`` samples
      is accepted by ``nasbench.is_valid``.
  """
  if trained:
    # Probability of "edge present" for every (i, j) with i < j.
    edge_probs = np.zeros([7, 7])
    for i in range(7):
      for j in range(i + 1, 7):
        edge_probs[i][j] = softmax(agents_edge[i][j])[1]

    ranked = sorted(
        (edge_probs[i][j] for i in range(7) for j in range(i + 1, 7)),
        reverse=True)
    # Value at rank MAX_EDGES; only strictly larger probabilities are kept,
    # which bounds the number of selected edges by MAX_EDGES.
    min_val = ranked[MAX_EDGES]

    manas_matrix = np.zeros([7, 7])
    for i in range(7):
      for j in range(i + 1, 7):
        prob_edge_ij = edge_probs[i][j]
        if prob_edge_ij > min_val and prob_edge_ij > 0.5:
          manas_matrix[i][j] = 1

    # Same integer dtype as sample_edge_matrix() returns, so both branches
    # hand the same kind of matrix to the NASBench API.
    return manas_matrix.astype(int)

  for draw_i in range(max_draws):
    manas_matrix = sample_edge_matrix()
    _, ops = sample_operations()
    if nasbench.is_valid(api.ModelSpec(matrix=manas_matrix, ops=ops)):
      return manas_matrix

  # Previously the last (invalid) sample was returned silently here.
  raise Exception(
      'number of draws for edge matrix exceeded {}!'.format(max_draws))

In [0]:
def get_manas_ops():
  """Return the operation labels with the highest score per interior vertex.

  Returns:
    A 7-element list: 'input', then for each of the 5 rows of ``agents_op``
    the ``op_mapping`` label of its argmax, then 'output'.
  """
  best_labels = [op_mapping[int(np.argmax(agents_op[k]))] for k in range(5)]
  return ['input'] + best_labels + ['output']

In [0]:
def reset_agents():
  """Set every score in ``agents_edge`` and ``agents_op`` back to 1, in place."""
  for table in (agents_edge, agents_op):
    table[...] = 1

In [0]:
def manas_search():
  """Run ``n_steps`` rounds of sampling, reward update and greedy evaluation.

  Each step samples a valid cell, uses ``accuracy - mean_acc`` as the reward
  to update the agents, then records the accuracy of the cell currently
  proposed by ``get_manas_matrix`` / ``get_manas_ops``. The agents are not
  reset here; the caller decides the starting state.

  Returns:
    A list with one accuracy per step.

  Raises:
    Exception: if ``max_draws`` consecutive samples are all invalid.
  """
  manas_acc = []
  trained = False

  for _ in range(n_steps):
    # Rejection-sample until NASBench accepts the cell.
    for draw_i in range(max_draws):
      matrix = sample_edge_matrix()
      choice_op, ops = sample_operations()
      if nasbench.is_valid(api.ModelSpec(matrix=matrix, ops=ops)):
        break
    else:
      # Report the configured limit rather than a hard-coded number.
      raise Exception(
          'number of draws for edge matrix exceeded {}!'.format(max_draws))

    reward = get_accuracy(matrix, ops) - mean_acc
    update_agents(matrix, choice_op, reward)

    # The first step asks for the untrained (randomly sampled) matrix;
    # every later step asks for the greedy one.
    manas_matrix = get_manas_matrix(trained)
    manas_ops = get_manas_ops()
    manas_acc.append(get_accuracy(manas_matrix, manas_ops))
    trained = True

  return manas_acc

In [0]:
def random_search():
  """Evaluate ``n_steps`` valid cells sampled with freshly reset agents.

  The agents are reset first and never updated, so every score stays at 1
  and each choice is drawn with equal probability.

  Returns:
    A tuple ``(r_acc, r_acc_mean, r_acc_max)`` of equally long lists: the
    accuracy of each sample, the running mean and the running maximum.

  Raises:
    Exception: if ``max_draws`` consecutive samples are all invalid.
  """
  reset_agents()
  r_acc = []

  for _ in range(n_steps):
    # Rejection-sample until NASBench accepts the cell.
    for draw_i in range(max_draws):
      matrix = sample_edge_matrix()
      choice_op, ops = sample_operations()
      if nasbench.is_valid(api.ModelSpec(matrix=matrix, ops=ops)):
        break
    else:
      # Report the configured limit rather than a hard-coded number.
      raise Exception(
          'number of draws for edge matrix exceeded {}!'.format(max_draws))

    r_acc.append(get_accuracy(matrix, ops))

  # Running statistics over the first k+1 samples.
  r_acc_mean = [np.average(r_acc[:k + 1]) for k in range(len(r_acc))]
  r_acc_max = [np.max(r_acc[:k + 1]) for k in range(len(r_acc))]

  return r_acc, r_acc_mean, r_acc_max

In [0]:
# Experiment settings read as globals by the search functions.
n_steps = 50      # number of search steps per method
mean_acc = 0.9    # baseline subtracted from the accuracy to form the reward
max_draws = 20    # maximum sampling attempts when looking for a valid cell

# All sampling goes through np.random; seeding it makes a fresh
# Restart & Run All reproduce the same curves.
np.random.seed(0)

reset_agents()
manas_acc = manas_search()
r_acc, r_acc_mean, r_acc_max = random_search()

In [0]:
# Accuracy per step for random search (raw, running mean, running max) and MANAS.
steps = np.arange(1, n_steps + 1)

fig, ax1 = plt.subplots(figsize=(10, 8))
colors = ['tab:blue', 'blue', 'tab:red', 'Green']

ax1.set_xlabel('STEPS')
ax1.set_ylabel('ACC')

# (series, color, extra line options), in the order used by the legend.
curves = [
    (r_acc, colors[0], {'linestyle': '--', 'dashes': (5, 4)}),
    (r_acc_mean, colors[1], {}),
    (r_acc_max, colors[3], {}),
    (manas_acc, colors[2], {}),
]
for series, line_color, line_opts in curves:
    ax1.plot(steps, series, color=line_color, **line_opts)

ax1.legend(('RS', 'RS (mean)', 'RS (max)', 'MANAS'), loc='lower right')
ax1.set_facecolor("#ffffb3")

ax1.grid(True, color="#93a1a1", alpha=0.3)