In [1]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

from helpers import load_distilbert
import json
from sacred import Experiment
from sacred.observers import MongoObserver, FileStorageObserver

import numpy as np
from baselines import ZeroBaselineFactory
from evaluators import ProportionalityEvaluator
from helpers import load_albert_v2, load_imdb_albert_lig_data, extract_token_ids_and_attributions, load_distilbert
from attribution_methods import RandomAttributionValues
from tqdm import tqdm

# hill climbing search of a one-dimensional objective function
from numpy import asarray
from numpy import arange
from numpy.random import randn
from numpy.random import rand
from numpy.random import seed
from matplotlib import pyplot
from evaluators import ProportionalityEvaluator

In [2]:
data = load_imdb_albert_lig_data()
model = load_distilbert(from_notebook=1)

Reusing dataset thermostat (/home/tim/.cache/huggingface/datasets/thermostat/imdb-albert-lig/1.0.1/0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b)


Loading Thermostat configuration: imdb-albert-lig


In [3]:
observations, lig_attributions = zip(*[extract_token_ids_and_attributions(d) for d in data[:1000]])
del data

evaluator = ProportionalityEvaluator(model=model, baseline_factory=ZeroBaselineFactory)

array([0.63955045, 0.36044952], dtype=float32)

In [10]:
def hillclimbing(objective, bounds, n_iterations, step_size):
	# generate an initial point
	solution = bounds[:, 0] + rand(len(bounds)) * (bounds[:, 1] - bounds[:, 0])
	# evaluate the initial point
	solution_eval = objective(solution)
	# run the hill climb
	solutions = list()
	solutions.append(solution)
	for i in range(n_iterations):
		# take a step
		candidate = solution + randn(len(bounds)) * step_size
		# evaluate candidate point
		candidte_eval = objective(candidate)
		# check if we should keep the new point
		if candidte_eval <= solution_eval:
			# store the new point
			solution, solution_eval = candidate, candidte_eval
			# keep track of solutions
			solutions.append(solution)
			# report progress
			print('>%d = %.5f' % (i, solution_eval))
	return [solution, solution_eval, solutions]

In [17]:
# seed the pseudorandom number generator
seed(5)
# define data point
observation = observations[0]
lig_tpn = evaluator.compute_tpn(observation, lig_attributions[0])
# define range for input
bounds = asarray([(0,1) for _ in range(len(observation))])
# define the total iterations
n_iterations = 20
# define the maximum step size
step_size = 0.1
#define objective function
objective = lambda x: evaluator.compute_tpn(observation=observation, attribution_values=x)
# perform the hill climbing search
best, score, solutions = hillclimbing(objective, bounds, n_iterations, step_size)
print('Done!')
print(f'LIG score was {lig_tpn}. Hillclimbing reached {score} after {n_iterations}.')


>0 = 0.03923
>5 = 0.03671
>6 = 0.03581
>7 = 0.03241
>11 = 0.02932
>16 = 0.02322
>18 = 0.02219
>19 = 0.01975
Done!
LIG score was 0.032406174991904214. Hillclimbing reached 0.019751162526147063 after 8.


In [14]:
solutions



[array([0.22199317, 0.87073231, 0.20671916, 0.91861091, 0.48841119,
        0.61174386, 0.76590786, 0.51841799, 0.2968005 , 0.18772123,
        0.08074127, 0.7384403 , 0.44130922, 0.15830987, 0.87993703,
        0.27408646, 0.41423502, 0.29607993, 0.62878791, 0.57983781,
        0.5999292 , 0.26581912, 0.28468588, 0.25358821, 0.32756395,
        0.1441643 , 0.16561286, 0.96393053, 0.96022672, 0.18841466,
        0.02430656, 0.20455555, 0.69984361, 0.77951459, 0.02293309,
        0.57766286, 0.00164217, 0.51547261, 0.63979518, 0.9856244 ,
        0.2590976 , 0.80249689, 0.87048309, 0.92274961, 0.00221421,
        0.46948837, 0.98146874, 0.3989448 , 0.81373248, 0.5464565 ,
        0.77085409, 0.48493107, 0.02911156, 0.08652569, 0.11145381,
        0.25124511, 0.96491529, 0.63176605, 0.8166602 , 0.566082  ,
        0.63535621, 0.81190239, 0.92668262, 0.91262676, 0.82481072,
        0.09420273, 0.36104842, 0.03550903, 0.54635835, 0.79614272,
        0.0511428 , 0.18866774, 0.36547777, 0.24