Skip to content

Commit

Permalink
improve graphviz visualization for join
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 1, 2022
1 parent 062242f commit e24fdd5
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 23 deletions.
24 changes: 12 additions & 12 deletions polars/polars-core/src/frame/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,12 +504,12 @@ where
let splitted_a = split_ca(a, n_threads).unwrap();
let splitted_b = split_ca(b, n_threads).unwrap();
match (
left.has_validity(),
right.has_validity(),
left.null_count() == 0,
right.null_count() == 0,
left.chunks.len(),
right.chunks.len(),
) {
(false, false, 1, 1) => {
(true, true, 1, 1) => {
let keys_a = splitted_a
.iter()
.map(|ca| ca.cont_slice().unwrap())
Expand All @@ -520,7 +520,7 @@ where
.collect::<Vec<_>>();
hash_join_tuples_inner(keys_a, keys_b, swap)
}
(false, false, _, _) => {
(true, true, _, _) => {
let keys_a = splitted_a
.iter()
.map(|ca| ca.into_no_null_iter().collect::<Vec<_>>())
Expand Down Expand Up @@ -578,12 +578,12 @@ where
let splitted_a = split_ca(left, n_threads).unwrap();
let splitted_b = split_ca(right, n_threads).unwrap();
match (
left.has_validity(),
right.has_validity(),
left.null_count(),
right.null_count(),
left.chunks.len(),
right.chunks.len(),
) {
(false, false, 1, 1) => {
(0, 0, 1, 1) => {
let keys_a = splitted_a
.iter()
.map(|ca| ca.cont_slice().unwrap())
Expand All @@ -594,7 +594,7 @@ where
.collect::<Vec<_>>();
hash_join_tuples_left(keys_a, keys_b)
}
(false, false, _, _) => {
(0, 0, _, _) => {
let keys_a = splitted_a
.iter()
.map(|ca| ca.into_no_null_iter().collect_trusted::<Vec<_>>())
Expand Down Expand Up @@ -751,8 +751,8 @@ where
let splitted_a = split_ca(a, n_partitions).unwrap();
let splitted_b = split_ca(b, n_partitions).unwrap();

match (a.has_validity(), b.has_validity()) {
(false, false) => {
match (a.null_count(), b.null_count()) {
(0, 0) => {
let iters_a = splitted_a
.iter()
.map(|ca| ca.into_no_null_iter())
Expand Down Expand Up @@ -802,8 +802,8 @@ impl HashJoin<BooleanType> for BooleanChunked {
let splitted_a = split_ca(a, n_partitions).unwrap();
let splitted_b = split_ca(b, n_partitions).unwrap();

match (a.has_validity(), b.has_validity()) {
(false, false) => {
match (a.null_count(), b.null_count()) {
(0, 0) => {
let iters_a = splitted_a
.iter()
.map(|ca| ca.into_no_null_iter())
Expand Down
12 changes: 8 additions & 4 deletions polars/polars-lazy/src/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ impl LogicalPlan {
let (mut branch, id) = id;
match self {
Union { inputs, .. } => {
let current_node = format!("UNION [{:?}]", (branch, id));
self.write_dot(acc_str, prev_node, &current_node, id)?;
for input in inputs {
let current_node = format!("UNION [{:?}]", (branch, id));
self.write_dot(acc_str, prev_node, &current_node, id)?;
input.dot(acc_str, (branch, id + 1), &current_node)?;
branch += 1;
}
Expand Down Expand Up @@ -299,8 +299,12 @@ impl LogicalPlan {
right_on,
..
} => {
let current_node =
format!("JOIN left {:?}; right: {:?} [{}]", left_on, right_on, id);
let current_node = format!(
r#"JOIN
left {:?};
right: {:?} [{}]"#,
left_on, right_on, id
);
self.write_dot(acc_str, prev_node, &current_node, id)?;
input_left.dot(acc_str, (branch + 10, id + 1), &current_node)?;
input_right.dot(acc_str, (branch + 20, id + 1), &current_node)
Expand Down
18 changes: 11 additions & 7 deletions py-polars/polars/internals/lazy_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,18 @@ def show_graph(
show = False

if show and _in_notebook():
from IPython.display import SVG, display

dot = self._ldf.to_dot(optimized)
svg = subprocess.check_output(
["dot", "-Nshape=box", "-Tsvg"], input=f"{dot}".encode()
)
return display(SVG(svg))
try:
from IPython.display import SVG, display

dot = self._ldf.to_dot(optimized)
svg = subprocess.check_output(
["dot", "-Nshape=box", "-Tsvg"], input=f"{dot}".encode()
)
return display(SVG(svg))
except Exception:
raise ImportError(
"Graphviz dot binary should be on your PATH and matplotlib should be installed to show graph."
)
try:
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
Expand Down

0 comments on commit e24fdd5

Please sign in to comment.