In [1]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline

import pickle
import math
import time
import re

import numpy as np
import os

In [2]:
from agent import Minus_Agent
from state import State
from game import Minus_Text_Game
import util

In [3]:
# sqrt(2). evaluate() passes this same value as the `c=` keyword to
# Minus_Agent -- presumably the search's exploration constant (TODO confirm
# against agent.py).
c = math.sqrt(2)

## Load dataset

In [4]:
# Two-class subset of 20 Newsgroups. Both the 'train' and 'test' subsets are
# fetched with identical options: headers, footers and quoted text removed.
categories = ['alt.atheism', 'soc.religion.christian']
_fetch_options = dict(
    remove=('headers', 'footers', 'quotes'),
    categories=categories,
)
newsgroups_train = fetch_20newsgroups(subset='train', **_fetch_options)
newsgroups_test = fetch_20newsgroups(subset='test', **_fetch_options)

In [5]:
# Convert raw text to TF-IDF vectors. The vectorizer is fitted on the training
# posts only (fit_transform); the test posts are then transformed with that
# already-fitted vectorizer, so no test text takes part in the fit.
vectorizer = TfidfVectorizer()
vectors_train = vectorizer.fit_transform(newsgroups_train.data)
vectors_test = vectorizer.transform(newsgroups_test.data)

## Black Box Model

In [6]:
# Load the pre-trained black-box classifier from disk.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load model files that come from a trusted source.
filename = './models/newsgroup_model.sav'
with open(filename, 'rb') as model_file:  # ensures the handle is closed
    model = pickle.load(model_file)

In [7]:
# Put the fitted vectorizer in front of the loaded classifier so that `model`
# can be called directly on raw strings (as the evaluation code below does).
# NOTE(review): this rebinds `model` to a pipeline that contains the previous
# `model`, so re-running this cell without re-running the load cell first
# nests the pipeline inside another pipeline.
model = make_pipeline(vectorizer, model)

## Evaluation

In [8]:
def preprocess_sample(sample):
    """Collapse every run of characters that are neither ASCII letters nor
    spaces into a single space and return the resulting string.

    Existing spaces are kept as they are, so the result may contain
    consecutive spaces.
    """
    non_letter_run = re.compile(r'[^a-zA-Z ]+')
    return non_letter_run.sub(' ', sample)

In [9]:
def mask(sample, predict, explanation):
    """Delete explanation words from `sample` until the prediction leaves `target`.

    The sample is padded with one space on each side and cleaned with
    preprocess_sample(). Words are then removed in the order given by
    `explanation`; after each removal the text is re-classified. Removal stops
    as soon as the predicted class differs from the module-level `target`, or
    when the explanation is exhausted.

    Parameters
    ----------
    sample : str
        Raw text to mask.
    predict : callable
        Called with a one-element list of strings; np.argmax of its return
        value is compared against `target`.
    explanation : iterable of (word, value) pairs
        Words to remove, in removal order. Only the word of each pair is used.
        The iterable is not modified.

    Returns
    -------
    (int, str)
        (number of words removed, masked text) if the prediction changed, or
        (-1, masked text) if it is still `target` after all words were tried.
        If the initial prediction already differs from `target`, the count
        is 0.
    """
    padded = " " + sample + " "
    masked = preprocess_sample(padded)
    arg = np.argmax(predict([padded]))
    n_actions = 0
    for word, _ in explanation:
        if arg != target:
            break
        # Remove " word" wherever it is followed by a space. The lookahead
        # leaves that trailing space in place, so it can serve as the leading
        # space of an immediately repeated occurrence ("word word") and every
        # occurrence is removed. re.escape makes the word match literally.
        masked = re.sub(" " + re.escape(word) + "(?= )", "", masked)
        arg = np.argmax(predict([masked]))
        n_actions += 1
    if arg == target:
        return -1, masked
    return n_actions, masked

In [10]:
def evaluate(data, episodes=5000, n_edges=1, exploration=math.sqrt(2)):
    """Explain one sample with the Minus agent and score the explanation.

    Builds a Minus_Text_Game from `data`, `model.predict_proba` and the
    module-level `target`, runs a Minus_Agent on it, and uses the first element
    returned by agent.get_best_path_as_list() as the ordered explanation that
    is handed to mask(). The change in log odds between the original and the
    masked text is then obtained from util.change_in_log_odds.

    Parameters
    ----------
    data : str
        Raw text of the sample to explain.
    episodes : int, optional
        Forwarded to agent.run(episodes=...). Defaults to 5000.
    n_edges : int, optional
        Forwarded to agent.run(n_edges=...). Defaults to 1.
    exploration : float, optional
        Forwarded to Minus_Agent(c=...). Defaults to sqrt(2).

    Returns
    -------
    (n_actions, change)
        `n_actions` is the first value returned by mask() (-1 when masking did
        not change the prediction); `change` is the value returned by
        util.change_in_log_odds.

    Prints the action count, the change and the elapsed wall-clock time in
    minutes.
    """
    start_t = time.time()
    game = Minus_Text_Game(data, model.predict_proba, target)
    agent = Minus_Agent(game, c=exploration)
    agent.run(episodes=episodes, n_edges=n_edges)

    # Only the first of the four returned values is needed here.
    ranks, _, _, _ = agent.get_best_path_as_list()
    n_actions, masked = mask(data, model.predict_proba, ranks)
    # Change in log odds between the original and the masked text.
    change = util.change_in_log_odds(model.predict_proba, data, masked, target, end, lime=True)

    end_t = time.time()
    print("n_actions")
    print(n_actions)
    print("change")
    print(change)
    print("time")
    print((end_t - start_t)/60)
    return n_actions, change

In [11]:
# `target` is the class index that mask(), evaluate() and the evaluation loop
# compare predictions and labels against. `end` is passed to
# util.change_in_log_odds right after `target` -- presumably the opposing
# class index (TODO confirm against util.py).
target, end = 0, 1

In [12]:
# Run evaluate() on test samples whose true label and predicted label are both
# `target`, collecting one entry per evaluated sample in `ns` and `cs`.
MAX_SCANNED = 500      # look at no more than this many test samples
MAX_EVALUATED = 50     # stop once the qualifying-sample counter reaches this
SKIPPED_SAMPLE = 41    # qualifying sample with this counter value is skipped
                       # NOTE(review): reason for the skip is not recorded here

ns = []  # first value returned by evaluate() (number of removed words)
cs = []  # second value returned by evaluate() (change in log odds)
j = 0    # counter over qualifying samples (also counts the skipped one)
for i in range(min(MAX_SCANNED, len(newsgroups_test.data))):
    data = newsgroups_test.data[i]
    t = newsgroups_test.target[i]
    y = model.predict([data])[0]  # predict returns one label per input
    if t != target or y != target:
        continue
    if j == SKIPPED_SAMPLE:
        j += 1
        continue
    print("Sample: " + str(j))
    # Bound to `change`, not `c`, so the module-level `c` defined near the top
    # of the notebook is not overwritten.
    n, change = evaluate(data)
    ns.append(n)
    cs.append(change)
    j += 1

    if j >= MAX_EVALUATED:
        break

2021-09-28 12:20:09,577 - agent - INFO - XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
2021-09-28 12:20:09,579 - agent - INFO - Round:	0
2021-09-28 12:20:09,579 - agent - INFO - XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX


Sample: 1


KeyboardInterrupt: 

In [None]:
# Persist the collected results as two pickle files: `ns` and `cs`.
# NOTE(review): `idx` is not defined in any of the cells shown here; it must be
# set (it becomes part of both file names) before this cell is run.
results_dir = './results/newsgroup'
# open(..., 'wb') does not create missing directories, so make sure it exists.
os.makedirs(results_dir, exist_ok=True)
with open(results_dir + '/mg_' + str(idx) + '.pkl', 'wb') as f:
    pickle.dump(ns, f)
with open(results_dir + '/mg_change_' + str(idx) + '.pkl', 'wb') as f:
    pickle.dump(cs, f)