In [None]:
# default_exp embed
from nbdev.showdoc import *
import numpy as np
import matplotlib.pyplot as plt
import torch
import FRED
from FRED.embed import *
# Choose the compute device. CUDA always wins when present; the Apple-silicon
# MPS backend is only considered on PyTorch 1.13 builds; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.__version__[:4] == '1.13' and torch.has_mps:
    device = torch.device('mps')
else:
    device = torch.device("cpu")
print("Using device", device)
%load_ext autoreload
%autoreload 2

Using device cpu


# 03 Embedder
> At the heart of FRED: the flow embedder. 

FRED's embedder is pretty simple. Given a directed graph along with the coordinates that gave rise to the nodes, FRED embeds it into a lower-dimensional space with an autoencoder. FRED also draws a vector field on the embedding space, to endow the embedded points with a sense of flow -- recreating the flows over the directed graph from which they came. FRED is rewarded for drawing arrows such that the flows in the embedding space mimic the flows in the ambient space. And, to give the visualization desirable properties, FRED is given bonus points for drawing flows that are as smooth as possible -- and also for placing the points in such a way that they resemble the directed diffusion map. The result is an embedding of the points and velocities that "respects the flow" by incorporating flow information into the placement of points.

This can also be done *without* the coordinates of the nodes -- e.g. when we have an abstract directed graph, unencumbered by physical coordinates. In this case, a GNN serves as the graph embedder that creates embedding coordinates from the input graph. This is implemented as a separate network.

# Manifold with Flow Embedder


In [None]:
# export
import torch
import torch.nn as nn
from FRED.data_processing import affinity_matrix_from_pointset_to_pointset

class ManifoldFlowEmbedder(torch.nn.Module):
    """Embed points with an autoencoder and draw a flow field on the embedding.

    The module owns three sub-networks: an ``embedder``/``decoder`` pair built
    by ``auto_encoder`` and a ``flowArtist`` built by ``flow_artist``. A call
    to ``forward`` embeds ``data["X"]``, evaluates the flow artist at the
    embedded points, and returns a dictionary of the individual loss terms
    selected by ``loss_weights``.

    Parameters
    ----------
    embedding_dimension : int
        Passed to ``flow_artist`` as ``dim``.
    embedder_shape : list of int, optional
        Passed to ``auto_encoder``. Defaults to ``[3, 4, 8, 4, 2]``; a fresh
        list is created for every instance.
    device : torch.device
        Forwarded to ``auto_encoder``, ``flow_artist`` and the smoothness loss.
    sigma, flow_strength : float
        Stored on the instance. NOTE(review): the methods visible here do not
        read them (the ``kld`` term uses literal values) -- confirm intent.
    num_negative_samples : int
        Number of random indices drawn for the "contrastive flow loss" term.
    smoothness_grid : bool
        Forwarded to ``smoothness_of_vector_field`` as ``use_grid``.
    """

    def __init__(
        self,
        embedding_dimension=2,
        embedder_shape=None,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        sigma=0.5,
        flow_strength=0.5,
        num_negative_samples = 20,
        smoothness_grid=True,
    ):
        super().__init__()
        # Build the default per call so no list object is shared between
        # instances (the previous signature used a mutable list default).
        if embedder_shape is None:
            embedder_shape = [3, 4, 8, 4, 2]
        self.device = device
        self.embedding_dimension = embedding_dimension
        # embedding parameters
        self.sigma = sigma
        self.flow_strength = flow_strength
        self.smoothness_grid = smoothness_grid
        self.num_negative_samples = num_negative_samples
        # Initialize autoencoder and flow artist
        self.embedder, self.decoder = auto_encoder(embedder_shape, device=self.device)
        self.flowArtist = flow_artist(dim=self.embedding_dimension, device=self.device)
        # training ops
        self.KLD = nn.KLDivLoss(reduction="batchmean", log_target=False)
        self.MSE = nn.MSELoss()
        # self.KLD = homemade_KLD # when running on mac
        self.epsilon = 1e-6  # set zeros to eps

    def loss(self, data, loss_weights):
        """Compute every loss term whose weight is present and non-zero.

        Reads ``self.embedded_points`` and ``self.embedded_flows``, which are
        set by ``forward``; call ``forward`` first (or use it directly).

        Parameters
        ----------
        data : mapping
            Batch dictionary. The keys read depend on the enabled terms:
            ``"X"`` (reconstruction); ``"precomputed distances"`` (distance
            regularization); ``"distance to neighbors"``,
            ``"distance to farbors"``, ``"center point idxs"``,
            ``"neighbor idxs"``, ``"farbor idxs"`` (distance regularization
            v2); ``"neighbors"`` (flow neighbor loss);
            ``"num flow neighbors"`` (contrastive flow loss); ``"P"`` (kld);
            ``"center point idxs"``, ``"neighbor idxs"`` (contrastive loss v2).
        loss_weights : mapping
            Maps loss names to weights. Here the weights act only as on/off
            switches; the returned values are not multiplied by them
            (presumably the caller applies the weighting -- confirm).

        Returns
        -------
        dict
            Loss name -> loss value for each enabled term.
        """

        def enabled(name):
            # A term is active only when its key exists and its weight is non-zero.
            return name in loss_weights and loss_weights[name] != 0

        losses = {}

        # Autoencoder reconstruction of the input coordinates.
        if enabled("reconstruction"):
            X_reconstructed = self.decoder(self.embedded_points)
            losses["reconstruction"] = self.MSE(X_reconstructed, data["X"])

        # Distance regularization against precomputed distances.
        if enabled("distance regularization"):
            losses["distance regularization"] = precomputed_distance_loss(
                data["precomputed distances"], self.embedded_points
            )

        if enabled("distance regularization v2"):
            losses["distance regularization v2"] = precomputed_distance_lossV2(
                embedded_points=self.embedded_points,
                near_distances_precomputed=data['distance to neighbors'],
                far_distances_precomputed=data['distance to farbors'],
                center_point_idxs=data['center point idxs'],
                neighbor_idxs=data['neighbor idxs'],
                farbor_idxs=data['farbor idxs'],
            )

        # Flow neighbor loss.
        if enabled("flow neighbor loss"):
            losses["flow neighbor loss"] = flow_neighbor_loss(
                data["neighbors"],
                self.embedded_points,
                self.embedded_flows,
            )

        # Negative-sampling (contrastive) loss.
        if enabled("contrastive flow loss"):
            # Draw random indices from [num flow neighbors, number of points)
            # and pair each with index 0, giving a 2 x num_negative_samples
            # index tensor. Presumably point 0 is the center point and the
            # leading rows are its flow neighbors -- TODO confirm with the
            # data loader.
            # NOTE(review): torch.randint raises when
            # data["num flow neighbors"] >= len(self.embedded_points).
            row = torch.zeros(self.num_negative_samples).long()
            negative_sample_idxs = torch.randint(
                data["num flow neighbors"],
                len(self.embedded_points),
                (1, self.num_negative_samples),
            )[0]
            not_neighbors = torch.vstack([row, negative_sample_idxs])
            losses["contrastive flow loss"] = contrastive_flow_loss(
                not_neighbors,
                self.embedded_points,
                self.embedded_flows
            )

        # Smoothness regularization of the learned vector field.
        if enabled("smoothness"):
            losses["smoothness"] = smoothness_of_vector_field(
                self.embedded_points,
                self.flowArtist,
                device=self.device,
                grid_width=20,
                use_grid=self.smoothness_grid,
            )

        if enabled("kld"):
            # NOTE(review): sigma and flow_strength are literals here rather
            # than self.sigma / self.flow_strength -- confirm this is intended.
            A = affinity_matrix_from_pointset_to_pointset(self.embedded_points, self.embedded_points, self.embedded_flows, sigma=0.5, flow_strength=1)
            # L1-normalize each row of the affinity matrix before comparing with data["P"].
            P = torch.nn.functional.normalize(A, p=1, dim=1)
            losses['kld'] = kl_divergence_loss(P, data["P"])

        if enabled("contrastive loss v2"):
            losses['contrastive loss v2'] = contrastive_flow_loss_V2(
                self.embedded_points,
                self.embedded_flows,
                center_point_idxs = data["center point idxs"],
                neighbor_idxs = data["neighbor idxs"],
            )

        return losses

    def forward(self, data, loss_weights):
        """Embed ``data["X"]``, evaluate the flow field there, and return the losses.

        The embedded points and flows are cached on the instance as
        ``embedded_points`` and ``embedded_flows`` so that ``loss`` (and
        callers) can read them afterwards.
        """
        self.embedded_points = self.embedder(data["X"])
        self.embedded_flows = self.flowArtist(self.embedded_points)
        return self.loss(data, loss_weights)