In [1]:
import sys
import os

# Put the sibling ``src`` directory at the front of the import search path so
# that project modules living there can be imported by the cells below.
# The path is resolved relative to the current working directory
# (presumably the notebook's own folder -- confirm when running from elsewhere).
notebook_dir = os.getcwd()
src_dir = os.path.abspath(os.path.join(notebook_dir, os.pardir, 'src'))
sys.path.insert(0, src_dir)

In [2]:
import json
import random
from agent import RecommendationAgent

In [3]:
# Build the recommendation agent and register the domains and visualization
# names it works with. Registration order is preserved exactly: later cells
# look visualizations up by position in ``agent.visualizations``.
# NOTE(review): n_statistics=25 presumably matches the length of each sample's
# ``statistics`` entry -- confirm against RecommendationAgent.
agent = RecommendationAgent(n_statistics=25)

for domain_name in (
    'healthcare',
    'asset performance management',
    'cyber security',
    'e-commerce',
):
    agent.add_domain(domain_name)

for visualization_name in (
    'nodelink topology driven',
    'nodelink attribute driven faceting',
    'nodelink attribute driven positioning',
    'adjacency matrix',
    'quilts',
    'biofabric',
    'treemap',
    'sunburst',
):
    agent.add_visualization(visualization_name)

In [4]:
# Load the training and testing samples from the ``data`` directory next to
# this notebook's folder. Both files are parsed as JSON; the cells below
# iterate over the results and read the 'domain', 'statistics' and
# 'visualization' keys of every sample.
#
# The encoding is pinned to UTF-8 (the encoding JSON interchange uses) so the
# result does not depend on the platform's locale default.
with open('../data/training.json', 'r', encoding='utf-8') as training_file:
    training = json.load(training_file)

with open('../data/testing.json', 'r', encoding='utf-8') as testing_file:
    testing = json.load(testing_file)

In [5]:
# Train the agent one labelled sample at a time and, every ``test_interval``
# samples, measure how often its recommendation matches the label on a small
# random subset of the testing data.
#
# Results are collected in two parallel lists:
#   intervals  -- number of training samples processed at each evaluation
#   accuracies -- fraction of correct recommendations at that point
test_interval = 100      # evaluate after every `test_interval` training samples
test_sample_size = 10    # upper bound on testing samples drawn per evaluation
accuracies = []
intervals = []

# Dedicated, seeded generator: the evaluation subsets are identical on every
# run, and the global ``random`` state is left untouched.
# NOTE(review): the agent itself may still be stochastic -- not visible here.
eval_rng = random.Random(0)

for i, sample in enumerate(training):
    domain = sample['domain']
    statistics = sample['statistics']
    # Despite its name, `prediction` is the labelled visualization taken from
    # the training sample, not something the agent produced.
    prediction = sample['visualization']
    state_id = agent.state_id(statistics)
    # Raises ValueError if the label was never registered on the agent.
    recommendation_index = agent.visualizations.index(prediction)
    # reward / require_feedback are passed through as-is; their semantics are
    # defined by RecommendationAgent (outside this notebook cell).
    agent.update_q_value(domain, state_id, recommendation_index, reward=5, require_feedback=True)

    # Only evaluate on every `test_interval`-th training sample.
    if (i + 1) % test_interval != 0:
        continue

    # Never request more samples than exist (random sampling would raise),
    # and skip the evaluation entirely when there is nothing to test on
    # (avoids a division by zero below).
    test_samples = eval_rng.sample(testing, min(test_sample_size, len(testing)))
    if not test_samples:
        continue

    correct_recommendations = 0
    for test_sample in test_samples:
        test_domain = test_sample['domain']
        test_statistics = test_sample['statistics']
        label = test_sample['visualization']
        test_state_id = agent.state_id(test_statistics)
        # Second element of the returned pair is not needed here.
        test_recommended_index, _ = agent.recommend_visualization(test_domain, test_state_id)
        if agent.visualizations[test_recommended_index] == label:
            correct_recommendations += 1

    accuracy = correct_recommendations / len(test_samples)
    accuracies.append(accuracy)
    intervals.append(i + 1)
    # Prints the zero-based index `i` (99, 199, ...), whereas `intervals`
    # stores the one-based count `i + 1`.
    print('Round: ', i, '', accuracy)


Round:  99  0.1
Round:  199  0.0
Round:  299  0.2
Round:  399  0.3
Round:  499  0.1
Round:  599  0.1
Round:  699  0.1
Round:  799  0.1
Round:  899  0.2
Round:  999  0.2
Round:  1099  0.3
Round:  1199  0.0
Round:  1299  0.1
Round:  1399  0.0
Round:  1499  0.0
Round:  1599  0.1
Round:  1699  0.0
Round:  1799  0.0
Round:  1899  0.1
Round:  1999  0.1
Round:  2099  0.1
Round:  2199  0.2
Round:  2299  0.1
Round:  2399  0.1
Round:  2499  0.0
Round:  2599  0.0
Round:  2699  0.0
Round:  2799  0.2
Round:  2899  0.2
Round:  2999  0.0
Round:  3099  0.1
Round:  3199  0.1
Round:  3299  0.1
Round:  3399  0.0
Round:  3499  0.0
Round:  3599  0.2
Round:  3699  0.0
Round:  3799  0.0
Round:  3899  0.3
Round:  3999  0.1
Round:  4099  0.1
Round:  4199  0.0
Round:  4299  0.2
Round:  4399  0.1
Round:  4499  0.0
Round:  4599  0.0
Round:  4699  0.0
Round:  4799  0.0
Round:  4899  0.1
Round:  4999  0.0
Round:  5099  0.2
Round:  5199  0.1
Round:  5299  0.1
Round:  5399  0.0
Round:  5499  0.1
Round:  5599  0.2
Rou