# 01c Plotting Utils
> Handy functions for visualizing our datasets together with the vector fields assigned to them. Mostly built on matplotlib, with a friendlier interface to `.quiver`.

In [None]:
# default_exp datasets
# hide
from nbdev.showdoc import *
import numpy as np
import matplotlib.pyplot as plt
import torch
import FlowNet
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Pointset Plotting

Functions to visualize point sets and manifolds together with their flow vectors.

In [None]:
# export
import matplotlib.pyplot as plt


def plot_directed_2d(X, flows, labels, mask_prob=0.5, cmap="viridis"):
    """Scatter a 2D point set and overlay a random subset of its flow arrows.

    Args:
        X: point coordinates; columns 0 and 1 are used as x and y.
        flows: per-point vectors; columns 0 and 1 give the arrow components.
        labels: values passed to ``scatter`` as ``c`` to color the points.
        mask_prob: a point's arrow is drawn when a uniform draw in [0, 1)
            exceeds this value, so roughly a ``1 - mask_prob`` fraction of the
            arrows is shown (keeps dense plots readable).
        cmap: matplotlib colormap name used for the point colors.
    """
    fig = plt.figure()
    axes = fig.add_subplot()
    axes.scatter(X[:, 0], X[:, 1], marker=".", c=labels, cmap=cmap)

    # Randomly thin out the arrows so the quiver plot does not become a smear.
    shown = np.random.rand(X.shape[0]) > mask_prob
    tail_x, tail_y = X[shown, 0], X[shown, 1]
    du, dv = flows[shown, 0], flows[shown, 1]
    axes.quiver(tail_x, tail_y, du, dv, alpha=0.1)

    axes.set_aspect("equal")
    plt.show()


In [None]:
# export
def plot_origin_3d(ax, xlim, ylim, zlim):
    """Draw the x, y and z coordinate axes on a 3D axes as faint black lines.

    Each axis is a segment through the origin: the x axis spans ``xlim`` at
    y = z = 0, the y axis spans ``ylim`` at x = z = 0, and the z axis spans
    ``zlim`` at x = y = 0.

    Args:
        ax: a 3D-capable axes object exposing ``plot(xs, ys, zs, **kwargs)``.
        xlim, ylim, zlim: two-element ``[start, end]`` extents for each axis.
    """
    on_axis = [0, 0]
    segments = (
        (xlim, on_axis, on_axis),
        (on_axis, ylim, on_axis),
        (on_axis, on_axis, zlim),
    )
    for xs, ys, zs in segments:
        ax.plot(xs, ys, zs, color="k", alpha=0.5)


def plot_directed_3d(X, flow, labels, mask_prob=0.5, cmap="viridis", origin=False):
    """Scatter a 3D point set and overlay a random subset of its flow arrows.

    Args:
        X: point coordinates; columns 0, 1 and 2 are used as x, y and z.
        flow: per-point vectors; columns 0, 1 and 2 give the arrow components.
        labels: values passed to ``scatter`` as ``c`` to color the points.
        mask_prob: a point's arrow is drawn when a uniform draw in [0, 1)
            exceeds this value, so roughly a ``1 - mask_prob`` fraction is shown.
        cmap: matplotlib colormap name used for the point colors.
        origin: if True, also draw the coordinate axes through the origin,
            each spanning the data's min..max range along that coordinate.
    """
    shown = np.random.rand(X.shape[0]) > mask_prob
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")

    if origin:
        extents = [[X[:, dim].min(), X[:, dim].max()] for dim in range(3)]
        plot_origin_3d(ax, xlim=extents[0], ylim=extents[1], zlim=extents[2])

    ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker=".", c=labels, cmap=cmap)

    # Arrow tails (x, y, z) followed by arrow components (u, v, w).
    tails = [X[shown, dim] for dim in range(3)]
    vectors = [flow[shown, dim] for dim in range(3)]
    ax.quiver(*tails, *vectors, alpha=0.1, length=0.5)
    plt.show()


For general 3D manifolds (such as diffusion-map embeddings), the `plot_3d` function below is handy.

In [None]:
# export
# For plotting 2D and 3D graphs
import plotly.express as px
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def _auto_lim(X):
    """Return a symmetric axis half-width that encloses every row of ``X``.

    The radius is the largest Euclidean row norm. When ``X`` is empty, every
    point sits at the origin, or the norms are not finite, return 1.0 instead
    so the plot never gets a zero-width or undefined axis range.
    """
    # ``initial=0.0`` makes an empty X yield 0 instead of raising ValueError.
    radius = np.max(np.linalg.norm(X, axis=1), initial=0.0)
    if not np.isfinite(radius) or radius <= 0:
        return 1.0
    return radius


def plot_3d(
    X,
    distribution=None,
    title="",
    lim=None,
    use_plotly=False,
    colorbar=False,
    cmap="viridis",
):
    """Scatter-plot a 3D point cloud inside the symmetric cube [-lim, lim]^3.

    Args:
        X: array of points; columns 0, 1 and 2 are the x, y and z coordinates.
        distribution: per-point values used to color the points. Defaults to
            all zeros (a single color).
        title: figure title.
        lim: half-width of every axis. Defaults to the largest row norm of
            ``X``, or 1.0 when that is zero or undefined (empty ``X``, all
            points at the origin, or NaN rows). An explicit value is used as-is.
        use_plotly: render an interactive plotly figure instead of matplotlib.
        colorbar: add a colorbar (matplotlib branch only).
        cmap: matplotlib colormap name (matplotlib branch only).

    Note:
        The plotly branch does not use ``colorbar`` or ``cmap``.
    """
    if distribution is None:
        distribution = np.zeros(len(X))
    if lim is None:
        lim = _auto_lim(X)
    if use_plotly:
        d = {"x": X[:, 0], "y": X[:, 1], "z": X[:, 2], "colors": distribution}
        df = pd.DataFrame(data=d)
        fig = px.scatter_3d(
            df,
            x="x",
            y="y",
            z="z",
            color="colors",
            title=title,
            range_x=[-lim, lim],
            range_y=[-lim, lim],
            range_z=[-lim, lim],
        )
        fig.show()
    else:
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection="3d")
        ax.axes.set_xlim3d(left=-lim, right=lim)
        ax.axes.set_ylim3d(bottom=-lim, top=lim)
        ax.axes.set_zlim3d(bottom=-lim, top=lim)
        im = ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=distribution, cmap=cmap)
        ax.set_title(title)
        if colorbar:
            fig.colorbar(im, ax=ax)
        plt.show()


# Graph Plotting

Functions to visualize small graphs, either as node-link drawings or as adjacency-matrix heatmaps.

In [None]:
# export
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx


def visualize_graph(data):
    """Draw a small PyTorch Geometric graph as a directed node-link diagram.

    The graph is converted with ``to_networkx`` (edge direction kept) and laid
    out with a spring layout seeded at 42, so repeated calls on the same graph
    place the nodes identically.
    """
    graph = to_networkx(data, to_undirected=False)
    layout = nx.spring_layout(graph, seed=42)
    nx.draw_networkx(graph, pos=layout, arrowsize=20, node_color="#adade0")
    plt.show()


In [None]:
# export
import torch
import matplotlib.pyplot as plt
from torch_geometric.utils import to_dense_adj


def visualize_heatmap(edge_index, order_ind=None, num_nodes=None):
    """Show a graph's dense adjacency matrix as a heatmap.

    Args:
        edge_index: edge list accepted by ``torch_geometric.utils.to_dense_adj``
            (presumably a ``(2, num_edges)`` integer tensor; confirm with callers).
        order_ind: optional index sequence used to reorder both rows and
            columns, e.g. to group nodes by cluster before plotting.
        num_nodes: optional total node count, forwarded as ``max_num_nodes``.
            Without it, ``to_dense_adj`` sizes the matrix from the largest
            index in ``edge_index``. Trailing isolated nodes are then dropped,
            and an ``order_ind`` covering every node would index out of range.
    """
    # Only the first (and only) graph of the batch dimension is plotted.
    dense_adj = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0]
    if order_ind is not None:
        dense_adj = dense_adj[order_ind, :][:, order_ind]
    # matplotlib cannot read CUDA tensors (or tensors tracking gradients),
    # so hand it a plain host-side NumPy array.
    plt.imshow(dense_adj.detach().cpu().numpy(), cmap="copper")
    plt.show()
