In [None]:
# default_exp trainers
from nbdev.showdoc import *
import numpy as np
import matplotlib.pyplot as plt
import torch
import FRED
if torch.__version__[:4] == '1.13': # If using pytorch with MPS, use Apple silicon GPU acceleration
    device = torch.device("cuda" if torch.cuda.is_available() else 'mps' if torch.has_mps else "cpu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device", device)
%load_ext autoreload
%autoreload 2

# 04b Visualizations for Training
> Simple functions to visualize progress during network training.

This notebook houses barebones code to visualize FRED's losses and embedded points concurrently with training, as well as tools for saving visualizations and creating GIFs from embedding training runs.

A general philosophy here is to avoid printing ad nauseam: wherever possible, information is condensed into the most compact form that still summarizes it efficiently.

We'll start with the core of the flow embedder: visualizing the embedded points over a grid of flow-field arrows.

In [None]:
# export
import torch
from FRED.embed import compute_grid

# Module-level default device for the exported helpers below:
# CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def visualize_points(
    embedded_points,
    flow_artist,
    labels=None,
    device=device,
    title="FRED's Embedding",
    save=False,
    save_dir="visualizations",
    **kwargs,
):
    """Scatter the embedded points and overlay the flow field as arrows.

    Parameters
    ----------
    embedded_points : torch.Tensor
        Embedded points; columns 0 and 1 are plotted as x and y
        (presumably shape (n_points, >=2) -- confirm against callers).
    flow_artist : callable
        Evaluated on the grid returned by ``compute_grid``; columns 0 and 1 of
        its output are used as the arrow components (u, v).
    labels : array-like, optional
        Per-point colour values passed to ``plt.scatter(c=...)``.
    device : torch.device
        Device on which the grid is built and ``flow_artist`` is evaluated.
    title : str
        File stem used when saving (``{save_dir}/{title}.jpg``).
        NOTE(review): the on-figure title is always "Flow Embedding".
    save : bool
        If True, write the figure to disk; otherwise display it.
    save_dir : str
        Directory for saved figures; created if it does not exist.
    **kwargs
        Accepted for call-site compatibility; unused.
    """
    import os

    # Imported locally: the `# export` cells visible in this notebook never
    # import pyplot at module level, so the exported module would otherwise
    # hit a NameError on `plt`.
    import matplotlib.pyplot as plt

    # Build a grid around the points and evaluate the flow field on it.
    # TODO: This might create CUDA errors
    grid = compute_grid(embedded_points.to(device)).to(device)
    uv = flow_artist(grid).detach().cpu()
    grid_cpu = grid.detach().cpu()
    x, y = grid_cpu[:, 0], grid_cpu[:, 1]
    u, v = uv[:, 0], uv[:, 1]

    points_cpu = embedded_points.detach().cpu()
    if labels is not None:
        plt.scatter(points_cpu[:, 0], points_cpu[:, 1], c=labels)
    else:
        plt.scatter(points_cpu[:, 0], points_cpu[:, 1])
    plt.suptitle("Flow Embedding")
    # quiver(X, Y, U, V): arrows located at (X, Y) pointing along (U, V).
    plt.quiver(x, y, u, v)
    try:
        if save:
            # savefig does not create missing directories.
            os.makedirs(save_dir, exist_ok=True)
            plt.savefig(os.path.join(save_dir, f"{title}.jpg"))
        else:
            plt.show()
    finally:
        # Always release the figure so repeated calls during training do not
        # accumulate open figures, even if saving fails.
        plt.close()


In [None]:
# export
def save_embedding_visualization(
    embedded_points,
    flow_artist,
    labels=None,
    device=device,
    title="FRED's Embedding",
    **kwargs
):
    """Render the embedding / flow-field figure and save it instead of showing it.

    Thin wrapper around ``visualize_points`` with ``save=True``; ``title`` is
    used as the saved file's name. Extra ``**kwargs`` are accepted but not
    forwarded.
    """
    plot_options = dict(labels=labels, device=device, title=title, save=True)
    visualize_points(
        embedded_points=embedded_points, flow_artist=flow_artist, **plot_options
    )

In [None]:
# export
def collate_loss(
    provided_losses,
    weights,
    prior_losses=None,
    loss_type="total",
):
    """Append this step's weighted loss values to a running per-loss history.

    Parameters
    ----------
    provided_losses : dict
        Maps loss name -> loss for the current step (e.g. diffusion,
        reconstruction, smoothness). Values are expected to be torch tensors;
        a value without ``.detach()`` (e.g. None) is recorded as 0.
    weights : dict
        Maps loss name -> multiplier. A loss with no weight is recorded as 0.
    prior_losses : dict of list, optional
        History returned by a previous call. When None, a fresh history is
        created with one list per key of ``provided_losses`` plus an empty
        ``"total"`` list.
    loss_type : str
        Currently unused; kept for interface compatibility.

    Returns
    -------
    dict of list
        ``prior_losses`` (mutated in place when supplied) with one entry
        appended for every key of ``provided_losses``.
    """
    if prior_losses is None:
        prior_losses = {key: [] for key in provided_losses}
        # NOTE(review): "total" is initialised but never filled here; it only
        # receives values if provided_losses itself contains a "total" key.
        prior_losses["total"] = []
    for key, loss in provided_losses.items():
        # setdefault: a loss that first appears after the history was created
        # gets its own list instead of raising KeyError.
        history = prior_losses.setdefault(key, [])
        try:
            history.append(loss.detach().cpu().numpy() * weights[key])
        except (KeyError, AttributeError, TypeError):
            # KeyError: no weight for this loss; AttributeError: the loss is not
            # a tensor (e.g. None); TypeError: value/weight cannot be multiplied
            # or converted. Record 0 so every history stays aligned by step.
            history.append(0)
    return prior_losses
