# 01c Plotting Utils
> Handy functions for visualizing our datasets together with the vector fields we assign to them. Mostly built on matplotlib, with nicer syntax for using `.quiver`.

In [None]:
# default_exp datasets
# hide
from nbdev.showdoc import *
import numpy as np
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Pointset Plotting

Functions to visualize pointsets and manifolds with flow.

In [None]:
# export
import matplotlib.pyplot as plt


def plot_directed_2d(X, flows, labels=None, mask_prob=0.5, cmap="viridis", ax=None):
    """Scatter a 2D pointset and overlay arrows for a random subset of its flow.

    Parameters
    ----------
    X : array
        Point coordinates; columns 0 and 1 are plotted.
    flows : array
        Flow vectors with one row per row of ``X``; columns 0 and 1 are the
        arrow components drawn at the corresponding point.
    labels : array-like, optional
        Per-point colour values passed to ``scatter(c=...)``. When given,
        points are drawn opaque and arrows faint; otherwise the reverse.
    mask_prob : float
        An arrow is drawn for a point when a uniform ``[0, 1)`` draw exceeds
        this value, so roughly a ``1 - mask_prob`` fraction of arrows appears.
    cmap : str
        Colormap used for ``labels``.
    ax : matplotlib Axes, optional
        Axes to draw into. If omitted, a new figure is created and shown.
    """
    # Decide up front: `ax` is reassigned below, so checking `ax is None`
    # at the end would always be False and the figure would never be shown.
    show = ax is None
    num_nodes = X.shape[0]
    alpha_points, alpha_arrows = (0.1, 1) if labels is None else (1, 0.1)
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot()
    ax.scatter(X[:, 0], X[:, 1], marker=".", c=labels, cmap=cmap, alpha=alpha_points)
    # Randomly thin the arrows so dense pointsets stay readable.
    mask = np.random.rand(num_nodes) > mask_prob
    ax.quiver(X[mask, 0], X[mask, 1], flows[mask, 0], flows[mask, 1], alpha=alpha_arrows)
    ax.set_aspect("equal")
    if show:
        plt.show()


In [None]:
# export
def plot_origin_3d(ax, xlim, ylim, zlim):
    """Draw the x, y and z axes through the origin on a 3D Axes.

    Each axis is a semi-transparent black segment that spans the given
    ``[lo, hi]`` extent along its own coordinate and stays at 0 on the
    other two coordinates.
    """
    segments = (
        (xlim, [0, 0], [0, 0]),
        ([0, 0], ylim, [0, 0]),
        ([0, 0], [0, 0], zlim),
    )
    for xs, ys, zs in segments:
        ax.plot(xs, ys, zs, color="k", alpha=0.5)


def plot_directed_3d(X, flow, labels=None, mask_prob=0.5, cmap="viridis", origin=False, ax=None):
    """Scatter a 3D pointset and overlay arrows for a random subset of its flow.

    Parameters
    ----------
    X : array
        Point coordinates; columns 0, 1 and 2 are plotted.
    flow : array
        Flow vectors with one row per row of ``X``; columns 0-2 are the arrow
        components drawn at the corresponding point.
    labels : array-like, optional
        Per-point colour values passed to ``scatter(c=...)``. When given,
        points are drawn opaque and arrows faint; otherwise the reverse.
    mask_prob : float
        An arrow is drawn for a point when a uniform ``[0, 1)`` draw exceeds
        this value, so roughly a ``1 - mask_prob`` fraction of arrows appears.
    cmap : str
        Colormap used for ``labels``.
    origin : bool
        If True, also draw coordinate axes through the origin, spanning the
        min/max extent of ``X`` along each dimension.
    ax : matplotlib 3D Axes, optional
        Axes to draw into. If omitted, a new 3D figure is created and shown.
    """
    # Decide up front: `ax` is reassigned below, so checking `ax is None`
    # at the end would always be False and the figure would never be shown.
    show = ax is None
    num_nodes = X.shape[0]
    alpha_points, alpha_arrows = (0.1, 1) if labels is None else (1, 0.1)
    # Randomly thin the arrows so dense pointsets stay readable.
    mask = np.random.rand(num_nodes) > mask_prob
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(projection="3d")
    if origin:
        plot_origin_3d(
            ax,
            xlim=[X[:, 0].min(), X[:, 0].max()],
            ylim=[X[:, 1].min(), X[:, 1].max()],
            zlim=[X[:, 2].min(), X[:, 2].max()],
        )
    ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker=".", c=labels, cmap=cmap, alpha=alpha_points)
    ax.quiver(
        X[mask, 0],
        X[mask, 1],
        X[mask, 2],
        flow[mask, 0],
        flow[mask, 1],
        flow[mask, 2],
        alpha=alpha_arrows,
        length=0.5,
    )
    if show:
        plt.show()


For general 3d manifolds (like diffusion maps), this `plot_3d` function is handy.

In [None]:
# export
# For plotting 2D and 3D graphs
import plotly.express as px
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def plot_3d(
    X,
    distribution=None,
    title="",
    lim=None,
    use_plotly=False,
    colorbar=False,
    cmap="viridis",
):
    """Scatter-plot a 3D pointset inside a symmetric cube and show it.

    Parameters
    ----------
    X : array
        Point coordinates; columns 0, 1 and 2 are plotted.
    distribution : array-like, optional
        Per-point colour values; all zeros when omitted.
    title : str
        Figure title.
    lim : float, optional
        Half-width of the plotted cube, applied as ``[-lim, lim]`` on every
        axis. Defaults to the largest Euclidean row norm of ``X``.
    use_plotly : bool
        Render an interactive plotly figure instead of matplotlib.
    colorbar : bool
        Add a colorbar (matplotlib path only).
    cmap : str
        Matplotlib colormap (matplotlib path only).
    """
    colors = np.zeros(len(X)) if distribution is None else distribution
    bound = np.max(np.linalg.norm(X, axis=1)) if lim is None else lim

    if not use_plotly:
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection="3d")
        ax.axes.set_xlim3d(left=-bound, right=bound)
        ax.axes.set_ylim3d(bottom=-bound, top=bound)
        ax.axes.set_zlim3d(bottom=-bound, top=bound)
        points = ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=colors, cmap=cmap)
        ax.set_title(title)
        if colorbar:
            fig.colorbar(points, ax=ax)
        plt.show()
        return

    frame = pd.DataFrame(
        data={"x": X[:, 0], "y": X[:, 1], "z": X[:, 2], "colors": colors}
    )
    figure = px.scatter_3d(
        frame,
        x="x",
        y="y",
        z="z",
        color="colors",
        title=title,
        range_x=[-bound, bound],
        range_y=[-bound, bound],
        range_z=[-bound, bound],
    )
    figure.show()


# Graph Plotting

Functions to visualize (small) graphs, directed and undirected.

In [None]:
# export graph_datasets
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx


def visualize_graph(data, is_networkx=False, to_undirected=False, ax=None):
    """Draw a small graph with a reproducible (seed=42) spring layout.

    Parameters
    ----------
    data : graph
        A networkx graph when ``is_networkx`` is True; otherwise an object
        converted with ``torch_geometric.utils.to_networkx``.
    is_networkx : bool
        Skip the conversion and draw ``data`` directly.
    to_undirected : bool
        Forwarded to ``to_networkx`` when converting.
    ax : matplotlib Axes, optional
        Axes to draw into. If omitted, draws on the current pyplot figure
        and calls ``plt.show()``.
    """
    if is_networkx:
        graph = data
    else:
        graph = to_networkx(data, to_undirected=to_undirected)
    layout = nx.spring_layout(graph, seed=42)

    draw_kwargs = {"pos": layout, "arrowsize": 20, "node_color": "#adade0"}
    if ax is not None:
        draw_kwargs["ax"] = ax
    nx.draw_networkx(graph, **draw_kwargs)

    if ax is None:
        plt.show()


In [None]:
# export graph_datasets
import torch
import matplotlib.pyplot as plt
from torch_geometric.utils import to_dense_adj


def visualize_heatmap(edge_index, order_ind=None, cmap="copper", ax=None):
    """Plot a graph's dense adjacency matrix as a heatmap.

    Parameters
    ----------
    edge_index : torch.Tensor
        Edge index accepted by ``torch_geometric.utils.to_dense_adj``; only
        the first graph of the resulting batch (index 0) is plotted.
    order_ind : index tensor/array, optional
        Node permutation applied to both rows and columns before plotting,
        e.g. to group nodes that share a label.
    cmap : str
        Colormap for the heatmap.
    ax : matplotlib Axes, optional
        Axes to draw into. If omitted, draws with pyplot and calls
        ``plt.show()``.
    """
    dense_adj = to_dense_adj(edge_index)[0]
    if order_ind is not None:
        dense_adj = dense_adj[order_ind, :][:, order_ind]
    # matplotlib converts its input via NumPy, which fails for tensors that
    # live on an accelerator; bring the matrix to host memory first.
    dense_adj = dense_adj.detach().cpu().numpy()
    if ax is not None:
        ax.imshow(dense_adj, cmap=cmap)
    else:
        plt.imshow(dense_adj, cmap=cmap)
        plt.show()


We also make a version of `visualize_heatmap` that doesn't rely on PyTorch Geometric; it builds the dense adjacency matrix with NumPy instead.

In [None]:
# export
def visualize_edge_index(edge_index, order_ind=None, cmap="copper", ax=None):
    """Plot the adjacency matrix described by an edge index as a heatmap.

    NumPy-only counterpart of ``visualize_heatmap``.

    Parameters
    ----------
    edge_index : array-like of shape (2, num_edges)
        Row 0 holds source node indices and row 1 target indices. NumPy
        arrays, nested lists and torch tensors (on any device) are accepted.
        Nodes are taken to be ``0 .. max index``; an empty edge index yields
        an empty matrix.
    order_ind : index array, optional
        Node permutation applied to both rows and columns before plotting.
    cmap : str
        Colormap for the heatmap.
    ax : matplotlib Axes, optional
        Axes to draw into. If omitted, draws with pyplot and calls
        ``plt.show()``.
    """
    if hasattr(edge_index, "cpu"):
        # torch tensors (possibly on GPU) must be on the host for NumPy.
        edge_index = edge_index.cpu()
    edge_index = np.asarray(edge_index)
    num_nodes = int(edge_index.max()) + 1 if edge_index.size else 0
    row, col = edge_index
    dense_adj = np.zeros((num_nodes, num_nodes))
    # Vectorised scatter; repeated edges still produce an entry of 1.
    dense_adj[row, col] = 1
    if order_ind is not None:
        dense_adj = dense_adj[order_ind, :][:, order_ind]
    if ax is not None:
        ax.imshow(dense_adj, cmap=cmap)
    else:
        plt.imshow(dense_adj, cmap=cmap)
        plt.show()

## Gallery display

The following function takes a list of `(name, data, plotting function, is_3d)` tuples and plots each one in its own subplot of a grid, displaying a gallery of datasets.

In [None]:
# export
import matplotlib.pyplot as plt

def display_galary(vizset, ncol=4):
    """Lay out several visualizations as subplots of one grid figure.

    Parameters
    ----------
    vizset : sequence of (name, data, vizcall, is3d) tuples
        For each entry a subplot is created (3D when ``is3d`` is truthy),
        ``vizcall(data, ax)`` draws into it, and ``name`` becomes its title.
    ncol : int
        Number of grid columns; rows are added as needed. Each cell is
        allotted 4 x 3 inches of figure size.
    """
    nrow = int(np.ceil(len(vizset) / ncol))
    fig = plt.figure(figsize=(4 * ncol, 3 * nrow))
    for position, entry in enumerate(vizset, start=1):
        name, data, vizcall, is3d = entry
        projection = "3d" if is3d else None
        ax = fig.add_subplot(nrow, ncol, position, projection=projection)
        vizcall(data, ax)
        ax.set_title(name, y=1.0)

In [None]:
# export
def display_flow_galary(dataset, ncol=4):
    """Show a gallery of 3D pointsets with their flows.

    Parameters
    ----------
    dataset : iterable of (name, data) pairs
        ``data[0]``, ``data[1]`` and ``data[2]`` are passed to
        ``plot_directed_3d`` as the points, flow and labels respectively.
    ncol : int
        Number of grid columns.
    """
    def _draw(sample, ax):
        plot_directed_3d(sample[0], sample[1], sample[2], mask_prob=0.5, ax=ax)

    vizset = [(name, data, _draw, True) for name, data in dataset]
    display_galary(vizset, ncol)

In [None]:
# export graph_datasets
from FRED.datasets import display_galary
import torch
def display_heatmap_galary(dataset, ncol=4):
    """Show a gallery of adjacency heatmaps, one subplot per graph.

    Parameters
    ----------
    dataset : iterable of (name, data) pairs
        Each ``data`` must expose ``edge_index`` and ``y``. When ``y`` is not
        None, nodes are reordered by ``argsort`` of its last column before
        plotting.
    ncol : int
        Number of grid columns.
    """
    def _draw(graph, ax):
        edges = graph.edge_index
        if graph.y is None:
            order = None
        else:
            order = torch.argsort(graph.y[:, -1])
        visualize_heatmap(edges, order_ind=order, ax=ax)

    vizset = [(name, data, _draw, False) for name, data in dataset]
    display_galary(vizset, ncol)

In [None]:
# export graph_datasets
def display_graph_galary(dataset, ncol=4):
    """Show a gallery of graph drawings, one subplot per graph.

    Parameters
    ----------
    dataset : iterable of (name, data) pairs
        Each ``data`` is drawn with ``visualize_graph`` (default conversion
        options).
    ncol : int
        Number of grid columns.
    """
    # `display_galary` unpacks 4-tuples (name, data, vizcall, is3d) and calls
    # `vizcall(data, ax)`, so the entries must carry `data` and a two-argument
    # callable. Receiving `data` as a parameter (instead of closing over the
    # loop variable) also makes each subplot draw its own graph rather than
    # the last one in `dataset`.
    vizset = [
        (name, data, lambda data, ax: visualize_graph(data, ax=ax), False)
        for name, data in dataset
    ]
    display_galary(vizset, ncol)