Skip to content

Commit

Permalink
rekt visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Feb 12, 2016
1 parent 7a57474 commit d59d246
Show file tree
Hide file tree
Showing 6 changed files with 472 additions and 20 deletions.
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
JSON
389 changes: 389 additions & 0 deletions notebooks/Test_Visualization.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/MCTS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ include("vanilla.jl")
include("dpw_solver.jl")
include("dpw.jl")

#include("visualization.jl")
include("visualization.jl")

end # module
4 changes: 2 additions & 2 deletions src/dpw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function simulate(dpw::DPWPolicy,s::State,d::Int)
# TODO: reimplement this as a loop instead of a recursion?

# This function returns the reward for one iteration of MCTSdpw
if d == 0
if d == 0 || isterminal(dpw.mdp, s)
return 0.0 # XXX is this right or should it be a rollout?
end
if !haskey(dpw.T,s) # if state is not yet explored, add it to the set of states, perform a rollout
Expand Down Expand Up @@ -113,7 +113,7 @@ function simulate(dpw::DPWPolicy,s::State,d::Int)
end
sanode.V[sp].N += 1
else # sample from transition states proportional to their occurence in the past
warn("sampling states: |V|=$(length(sanode.V)), N=$(sanode.N)")
# warn("sampling states: |V|=$(length(sanode.V)), N=$(sanode.N)")
rn = rand(dpw.solver.rng, 1:sanode.N) # this is where Jon's bug was (I think)
cnt = 0
local sp
Expand Down
77 changes: 65 additions & 12 deletions src/tree_vis.js
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@

// ************** Generate the tree diagram *****************
// var margin = {top: 20, right: 120, bottom: 20, left: 120},
var margin = {top: 20, right: 120, bottom: 20, left: 120},
var margin = {top: 20, right: 120, bottom: 80, left: 120},
width = $("#treevis").width() - margin.right - margin.left,
height = 600 - margin.top - margin.bottom;
height = 600 - margin.top - margin.bottom;
// TODO make height a parameter of TreeVisualizer

var i = 0,
duration = 750,
root;

var tree = d3.layout.tree()
.size([height, width]);
.size([width, height]);

var diagonal = d3.svg.diagonal()
.projection(function(d) { return [d.y, d.x]; });
var diagonal = d3.svg.diagonal();
//.projection(function(d) { return [d.y, d.x]; });
// uncomment above to make the tree go horizontally

var svg = d3.select("#treevis").append("svg")
.attr("width", width + margin.right + margin.left)
Expand All @@ -24,7 +26,7 @@ var svg = d3.select("#treevis").append("svg")
console.log("tree data:");
console.log(treeData[rootID]);
root = createDisplayNode(treeData[rootID]);
root.x0 = height / 2;
root.x0 = width / 2;
root.y0 = 0;
update(root);
console.log("tree should appear");
Expand All @@ -33,7 +35,12 @@ function createDisplayNode(nd) {
var dnode = {"dataID":nd.id,
"children":null,
"_children":null,
"info":nd.info};
"tag":nd.tag,
"type":nd.type,
"N":nd.N};
if (nd.type=="action") {
dnode.Q = nd.Q;
}
return dnode;
}

Expand All @@ -53,6 +60,15 @@ function initializeChildren(d) {
}
}

function tooltip(d) {
var tt = d.tag + "\n" +
"id: " + d.dataID + "\n" +
"N: " + d.N;
if (d.type=="action") {
tt += "\nQ: " + d.Q;
}
return tt;
}
/*
function collapse(d) {
if ("children" in d && d.children) {
Expand All @@ -70,7 +86,13 @@ function update(source) {
links = tree.links(nodes);

// Normalize for fixed-depth.
nodes.forEach(function(d) { d.y = d.depth * 180; });
// nodes.forEach(function(d) { d.y = d.depth * 180; });

/*
var newHeight = height;
nodes.forEach(function(d) { if (d.y > newHeight) {newHeight = d.y;} });
svg.attr("height", height + margin.top + margin.bottom);
*/

// Update the nodes…
var node = svg.selectAll("g.node")
Expand All @@ -79,24 +101,55 @@ function update(source) {
// Enter any new nodes at the parent's previous position.
var nodeEnter = node.enter().append("g")
.attr("class", "node")
.attr("transform", function(d) { return "translate(" + source.y0 + "," + source.x0 + ")"; })
.attr("transform", function(d) { return "translate(" + source.x0 + "," + source.y0 + ")"; })
.on("click", click);

nodeEnter.append("circle")
.attr("r", 1e-6)
.style("fill", function(d) { return d._children ? "lightsteelblue" : "#fff"; });

/*
nodeEnter.append("text")
.attr("x", function(d) { return d.children || d._children ? -13 : 13; })
.attr("dy", ".35em")
.attr("text-anchor", function(d) { return d.children || d._children ? "end" : "start"; })
.text(function(d) { return d.info; })
.text(function(d) { return d.tag; })
.style("fill-opacity", 1e-6);
*/

/*
nodeEnter.append("text")
.attr("y", 25)
.attr("text-anchor", "middle")
.text(function(d) { return d.tag + " N: " + d.N + (d.type=="action"? " Q: " + d.Q.toPrecision(4):""); })
.style("fill-opacity", 1e-6);
*/
var tbox = nodeEnter.append("text")
.attr("y", 25)
.attr("text-anchor", "middle")
.style("fill-opacity", 1e-6);

tbox.append("tspan")
.text( function(d) { return d.tag; } );

tbox.append("tspan")
.attr("dy","1.2em")
.attr("x",0)
.text( function(d) {return "N: " + d.N;} );

tbox.append("tspan")
.attr("dy","1.2em")
.attr("x",0)
.text( function(d) { if (d.type=="action") {return " Q: " + d.Q.toPrecision(4);}});

// tooltip
nodeEnter.append("title").text(tooltip)


// Transition nodes to their new position.
var nodeUpdate = node.transition()
.duration(duration)
.attr("transform", function(d) { return "translate(" + d.y + "," + d.x + ")"; });
.attr("transform", function(d) { return "translate(" + d.x + "," + d.y + ")"; });

nodeUpdate.select("circle")
.attr("r", 10)
Expand All @@ -108,7 +161,7 @@ function update(source) {
// Transition exiting nodes to the parent's new position.
var nodeExit = node.exit().transition()
.duration(duration)
.attr("transform", function(d) { return "translate(" + source.y + "," + source.x + ")"; })
.attr("transform", function(d) { return "translate(" + source.x + "," + source.y + ")"; })
.remove();

nodeExit.select("circle")
Expand Down
19 changes: 14 additions & 5 deletions src/visualization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ type TreeVisualizer{PolicyType}
init_state
end

node_tag(s) = string(s)

function create_json(visualizer::TreeVisualizer{DPWPolicy})
local root_id
next_id = 1
Expand All @@ -17,7 +19,9 @@ function create_json(visualizer::TreeVisualizer{DPWPolicy})
node_dict[next_id] = sd = Dict("id"=>next_id,
"type"=>:state,
"children_ids"=>Array(Int,0),
"info"=>"$next_id $s N:$(sn.N)")
"tag"=>node_tag(s),
"N"=>sn.N
)
if s == visualizer.init_state
root_id = next_id
end
Expand All @@ -29,7 +33,10 @@ function create_json(visualizer::TreeVisualizer{DPWPolicy})
node_dict[next_id] = Dict("id"=>next_id,
"type"=>:action,
"children_ids"=>Array(Int,0),
"info"=>"$next_id $a N:$(san.N), Q:$(@sprintf("%.2g", san.Q))")
"tag"=>node_tag(a),
"N"=>san.N,
"Q"=>san.Q
)
push!(sd["children_ids"], next_id)
sa_dict[(s,a)] = next_id
next_id += 1
Expand All @@ -45,9 +52,11 @@ function create_json(visualizer::TreeVisualizer{DPWPolicy})
push!(sad["children_ids"], s_dict[sp])
else
node_dict[next_id] = Dict("id"=>next_id,
"type"=>:state,
"children_ids"=>Array(Int,0),
"info"=>"$next_id $s N:0")
"type"=>:state,
"children_ids"=>Array(Int,0),
"tag"=>node_tag(sp),
"N"=>0
)
push!(sad["children_ids"], next_id)
next_id += 1
end
Expand Down

0 comments on commit d59d246

Please sign in to comment.