In [None]:
# default_exp flow_embedding_training_utils
from nbdev.showdoc import *
import numpy as np
import matplotlib.pyplot as plt
import torch
import directed_graphs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Visualization Utils for the Flow Embedder

To keep the code maximally separated, we're splitting the code of 05c's flow embedder into separate embedding and visualization schemes.

This notebook will house barebones code to visualize the losses and embedded points of the multiscale flow embedder (MFE) concurrently with training, as well as tools for saving the embedding trainings and creating GIFs from them.

A general philosophy here is to avoid printing ad nauseam: wherever possible, information is condensed into the form that summarizes it most efficiently (e.g. a plot rather than a stream of printed numbers).

We'll start with the core of the flow embedder's visualization: plotting the embedded points, overlaid with a grid of flow arrows.

In [None]:
# export
from directed_graphs.multiscale_flow_embedder import compute_grid
# torch.has_cuda only reports that torch was *built* with CUDA; check that a GPU is actually usable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def visualize_points(embedded_points, flow_artist, labels = None, device = device):
    """Scatter-plot the embedded points and overlay the flow field on a grid around them.

    Args:
        embedded_points: torch tensor of embedded points; columns 0 and 1 are plotted
            as x and y (presumably shape (n_points, 2) — TODO confirm).
        flow_artist: callable applied to the grid returned by ``compute_grid``;
            columns 0 and 1 of its output are used as the arrow directions.
        labels: optional per-point values passed to ``plt.scatter(c=...)`` to colour
            the points; a legend is built from the resulting colour mapping.
        device: device the points are moved to before the grid is computed.

    Displays the figure with ``plt.show()`` and returns None.
    """
    # Tensors (possibly on an accelerator) must be on the CPU before matplotlib sees them.
    if torch.is_tensor(labels):
        labels = labels.detach().cpu()
    # computes grid around points
    # TODO: This might create CUDA errors (if flow_artist lives on a different device than `device`)
    grid = compute_grid(embedded_points.to(device))
    uv = flow_artist(grid).detach().cpu()
    grid_cpu = grid.detach().cpu()
    points = embedded_points.detach().cpu()
    if labels is not None:
        sc = plt.scatter(points[:, 0], points[:, 1], c=labels)
        # A bare plt.legend() finds no labelled artists here; derive entries from the colour mapping.
        plt.legend(*sc.legend_elements())
    else:
        plt.scatter(points[:, 0], points[:, 1])
    plt.suptitle("Flow Embedding")
    # quiver(X, Y, U, V): arrows located at (X, Y) pointing along (U, V).
    plt.quiver(grid_cpu[:, 0], grid_cpu[:, 1], uv[:, 0], uv[:, 1])
    # Display all open figures.
    plt.show()

In [None]:
def visualize_loss(self, loss_type = "total"):
    """Plot the recorded loss history of a flow embedder.

    Args:
        self: object with a ``losses`` dict mapping loss names (e.g. "diffusion",
            "reconstruction", "smoothness") to per-step histories. Entries may be
            torch tensors or plain numbers.
        loss_type: name of a single recorded loss to plot; "total" for the per-step
            sum of all recorded losses (a recorded "total" history is used as-is if
            present); or "all" to plot every recorded loss with a legend.

    The x-axis spans the longest recorded history; shorter histories are padded with 0.

    Raises:
        KeyError: if ``loss_type`` is not "all", "total", or a recorded loss name.
    """
    def _to_plottable(value):
        # Tensors may require grad or live on the GPU; plain numbers are used as-is.
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return value

    n_steps = max((len(history) for history in self.losses.values()), default=0)
    x = list(range(n_steps))
    losses = {
        key: [_to_plottable(history[i]) if i < len(history) else 0 for i in x]
        for key, history in self.losses.items()
    }
    if "total" not in losses:
        # Sum the component losses step by step so the default "total" view has data.
        losses["total"] = [sum(step) for step in zip(*losses.values())]

    if loss_type == "all":
        for key in self.losses:
            plt.plot(x, losses[key])
        plt.legend(list(self.losses.keys()), loc='upper right')
        plt.title("loss")
    else:
        if loss_type not in losses:
            raise KeyError(
                f"Unknown loss_type {loss_type!r}; expected 'all', 'total' or one of {list(self.losses)}"
            )
        plt.plot(x, losses[loss_type])
        plt.title(loss_type)

# Training Scheme

This class will instantiate the Flow Embedder with preset parameters and train it for a fixed number of epochs (10,000 by default), calling these visualization functions, like swappable modules, every 100 epochs.

The base class will take a list of visualization functions as input, and will iterate over them whenever it is called to visualize.

Subsequent notebooks (05c0n) will inherit from this class.

In [None]:
# export
import torch.nn as nn
import torch
from directed_graphs.multiscale_flow_embedder import MultiscaleDiffusionFlowEmbedder

class FETrainer(nn.Module):
  """Train a MultiscaleDiffusionFlowEmbedder in chunks, running visualizations between chunks.

  Visualization functions are swappable: after every chunk each one is called as
  ``viz_f(X=..., flows=..., losses=...)`` with the three values returned by the
  embedder's ``fit``.
  """

  def __init__(self, X, flows, visualization_functions=None, ts=(1, 2, 4, 8),
               sigma_graph=0.5, flow_strength_graph=5,
               epochs_between_visualization=100, total_epochs=10000):
    """
    Args:
      X: input points, forwarded to MultiscaleDiffusionFlowEmbedder.
      flows: flows, forwarded to MultiscaleDiffusionFlowEmbedder.
      visualization_functions: callables invoked after each training chunk;
        defaults to ``[visualize_points]``.
      ts, sigma_graph, flow_strength_graph: forwarded to the embedder.
      epochs_between_visualization: training steps per chunk.
      total_epochs: total number of training steps across all chunks.
    """
    # nn.Module must be initialised before sub-modules can be assigned as attributes.
    super().__init__()
    if visualization_functions is None:
      # NOTE(review): visualize_points takes (embedded_points, flow_artist, ...), not the
      # X/flows/losses keywords passed by visualize(); this default raises TypeError until
      # an adapter is written. TODO confirm the intended visualization contract.
      visualization_functions = [visualize_points]
    self.vizfiz = list(visualization_functions)
    self.FE = MultiscaleDiffusionFlowEmbedder(
      X = X,
      flows = flows,
      ts = ts,
      sigma_graph = sigma_graph,
      flow_strength_graph = flow_strength_graph,
    )
    self.epochs_between_visualization = epochs_between_visualization
    self.total_epochs = total_epochs

  def fit(self):
    """Train for ``total_epochs`` steps, visualizing after every chunk.

    Each chunk is ``epochs_between_visualization`` steps; a final shorter chunk
    covers any remainder.

    Raises:
      ValueError: if ``epochs_between_visualization`` is not positive.
    """
    if self.epochs_between_visualization <= 0:
      raise ValueError("epochs_between_visualization must be positive")
    remaining = self.total_epochs
    while remaining > 0:
      n_steps = min(self.epochs_between_visualization, remaining)
      emb_X, emb_flows, losses = self.FE.fit(n_steps = n_steps)
      self.visualize(emb_X, emb_flows, losses)
      remaining -= n_steps

  def visualize(self, X, flows, losses):
    """Call every registered visualization function with the latest embedder outputs."""
    for viz_f in self.vizfiz:
      viz_f(X = X, flows = flows, losses = losses)
      

  