In [None]:
# default_exp trainers
from nbdev.showdoc import *
import numpy as np
import matplotlib.pyplot as plt
import torch
import FRED
from FRED.trainers import *
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, then CPU.
# MPS support is feature-detected rather than keyed on the version string '1.13',
# so newer PyTorch releases also use MPS and builds without the backend fall back
# to CPU instead of raising.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device", device)
%load_ext autoreload
%autoreload 2

Using device mps
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# 04 Flow Embedder Trainers
> Wrappers that train different variants of FRED while producing bespoke visualizations. Useful for comparing multiple models and for training en masse, e.g. on clusters.

In [None]:
# export
import torch.nn as nn
import torch
import time
import datetime
import FRED
from tqdm import trange
import glob
from PIL import Image
import os
import ipywidgets as widgets
import base64
import matplotlib.pyplot as plt

# Module-wide default device: CUDA when a GPU is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


class FETrainer(object):
    """Train a ``MultiscaleDiffusionFlowEmbedder`` in chunks, visualizing after each chunk.

    Training is split into runs of ``epochs_between_visualization`` steps.
    After every run each function in ``vizfiz`` is called with the current
    embedding, and the run's losses are merged into ``self.losses`` via
    ``collate_loss``. Output files live under ``visualizations/<timestamp>/``.

    Attributes set in ``__init__`` (``title``, ``total_epochs``,
    ``epochs_between_visualization``, ``vizfiz``) may be overridden after
    construction to customize a run.
    """

    def __init__(self, X, flows, labels, device=device):
        """Build the embedder and create this run's output directory.

        Args:
            X: Passed to ``MultiscaleDiffusionFlowEmbedder`` as ``X``.
            flows: Passed to ``MultiscaleDiffusionFlowEmbedder`` as ``flows``.
            labels: Forwarded to every visualization function.
            device: Device handed to the embedder and used for ``.to(device)``.
        """
        # Visualization callbacks invoked after every training run.
        self.vizfiz = [
            save_embedding_visualization,
        ]
        self.FE = MultiscaleDiffusionFlowEmbedder(
            X=X,
            flows=flows,
            ts=(1, 2, 4, 8),
            sigma_graph=0.5,
            flow_strength_graph=5,
            device=device,
            use_embedding_grid=False,
        ).to(device)
        self.losses = None
        self.labels = labels
        self.title = "Vanilla MFE"
        self.epochs_between_visualization = 100
        self.total_epochs = 10000
        self.timestamp = datetime.datetime.now().isoformat()
        # makedirs (not mkdir) so a missing parent "visualizations" directory
        # does not abort construction; exist_ok tolerates a pre-existing folder.
        os.makedirs(f"visualizations/{self.timestamp}", exist_ok=True)

    def fit(self):
        """Run training in chunks, visualizing and collating losses after each.

        Performs ``total_epochs // epochs_between_visualization`` runs of
        ``self.FE.fit``; any remainder epochs are not trained. On completion
        the last embedding and flow artist returned by the embedder are stored
        as ``self.embedded_points`` and ``self.flow_artist``.

        Raises:
            ValueError: If the settings yield zero training runs, in which
                case there would be no embedding to store.
        """
        num_training_runs = self.total_epochs // self.epochs_between_visualization
        if num_training_runs < 1:
            raise ValueError(
                f"total_epochs ({self.total_epochs}) must be at least "
                f"epochs_between_visualization ({self.epochs_between_visualization})"
            )
        for epoch_num in trange(num_training_runs):
            emb_X, flow_artist, losses = self.FE.fit(
                n_steps=self.epochs_between_visualization
            )
            # The timestamp prefix places the output inside this run's folder.
            title = f"{self.timestamp}/{self.title} Epoch {epoch_num:03d}"
            self.visualize(emb_X, flow_artist, losses, title)
            self.losses = collate_loss(provided_losses=losses, prior_losses=self.losses)
        self.embedded_points = emb_X
        self.flow_artist = flow_artist

    def visualize(self, embedded_points, flow_artist, losses, title):
        """Call every function in ``self.vizfiz`` with the current training state.

        Each callback receives ``embedded_points``, ``flow_artist``, ``losses``,
        ``title``, ``labels`` and ``FE`` as keyword arguments.
        """
        for viz_f in self.vizfiz:
            viz_f(
                embedded_points=embedded_points,
                flow_artist=flow_artist,
                losses=losses,
                title=title,
                labels=self.labels,
                FE=self.FE,
            )

    def training_gif(self, duration=10):
        """Assemble the saved ``.jpg`` frames into a GIF and display it inline.

        Frames are read from ``visualizations/<timestamp>/*.jpg`` in sorted
        filename order and written to ``visualizations/<timestamp>/<title>.gif``.

        Args:
            duration: Forwarded to Pillow's GIF writer as ``duration``.

        Raises:
            FileNotFoundError: If no ``.jpg`` frames exist for this run.
        """
        out_dir = f"visualizations/{self.timestamp}"
        file_names = sorted(glob.glob(f"{out_dir}/*.jpg"))
        if not file_names:
            raise FileNotFoundError(
                f"No .jpg frames found in {out_dir}; run fit() before training_gif()."
            )
        gif_path = f"{out_dir}/{self.title}.gif"
        frames = [Image.open(image) for image in file_names]
        try:
            # append_images holds only the frames AFTER the first; passing the
            # whole list would show the first frame twice.
            frames[0].save(
                gif_path,
                format="GIF",
                append_images=frames[1:],
                save_all=True,
                duration=duration,
                loop=0,
            )
        finally:
            # Release the file handles held by the opened images.
            for frame in frames:
                frame.close()
        with open(gif_path, "rb") as gif_file:
            b64 = base64.b64encode(gif_file.read()).decode("ascii")
        # `display` is the global provided by the Jupyter/IPython session;
        # this method is intended to be called from a notebook.
        display(widgets.HTML(f'<img src="data:image/gif;base64,{b64}" />'))

    def visualize_embedding(self):
        """Plot the final embedding via ``visualize_points``.

        Uses ``self.embedded_points`` and ``self.flow_artist``, which are set
        by ``fit``.
        """
        visualize_points(
            embedded_points=self.embedded_points,
            flow_artist=self.flow_artist,
            labels=self.labels,
            title=self.title,
        )

    def visualize_loss(self, loss_type="all"):
        """Plot recorded losses with matplotlib.

        Args:
            loss_type: ``"all"`` plots every entry of ``self.losses`` with a
                legend; any other value is used as a key into ``self.losses``
                and only that series is plotted.

        Raises:
            RuntimeError: If no losses have been recorded yet.
        """
        if self.losses is None:
            raise RuntimeError("No losses recorded yet; call fit() first.")
        if loss_type == "all":
            for key in self.losses.keys():
                plt.plot(self.losses[key])
            plt.legend(self.losses.keys(), loc="upper right")
            plt.title("loss")
        else:
            plt.plot(self.losses[loss_type])
            plt.title(loss_type)


ModuleNotFoundError: No module named 'directed_graphs'