In [None]:
import os

import matplotlib.pyplot as plt
import torch

In [None]:
def read_embed_db(embed_dir):
    """Load per-class embedding tensors from the ``.pt`` files in ``embed_dir``.

    Each file is expected to hold an ``(N + 1, embed_dim)`` tensor. The first
    row of every tensor is dropped, leaving an ``(N, embed_dim)`` tensor of
    embeddings for that class.

    Args:
        embed_dir: Directory containing one ``<label>.pt`` file per class.

    Returns:
        dict mapping class label (the file name up to its first ``'.'``) to
        the loaded tensor without its first row. Entries are inserted in
        sorted file-name order.
    """
    embed_db = {}
    # os.listdir order is arbitrary; sort so the result is deterministic.
    for fname in sorted(os.listdir(embed_dir)):
        path = os.path.join(embed_dir, fname)
        # Only regular .pt files are embeddings. Anything else (notes, hidden
        # files, sub-directories such as .ipynb_checkpoints) would make
        # torch.load fail, so skip it.
        if not fname.endswith('.pt') or not os.path.isfile(path):
            continue
        # NOTE: torch.load unpickles the file; only use on trusted directories.
        embed_db[fname.split('.')[0]] = torch.load(path)[1:]
    return embed_db

def plot_embeddings(input_map, labels=None, title=None):
    """Scatter-plot a 2-D PCA projection of per-class embeddings.

    The embeddings of all selected classes are concatenated, principal
    directions are estimated with ``torch.pca_lowrank``, and every embedding
    is projected onto the first two directions. One scatter series per class
    is drawn on a new 10x10 inch figure with a legend.

    Args:
        input_map: Mapping from class label to a 2-D tensor with one
            embedding per row (e.g. the result of ``read_embed_db``).
        labels: Iterable of class labels to plot. ``None`` (the default)
            plots every class in ``input_map``.
        title: Optional figure title; no title is set when ``None``.

    Raises:
        KeyError: If a requested label is not present in ``input_map``.
    """
    if labels is None:
        labels = list(input_map)
    embed_db = {label: input_map[label] for label in labels}

    # Compute a 2d projection of the embeddings using PCA. pca_lowrank
    # centers the data internally by default, while the projection below uses
    # the raw embeddings; this only shifts the plot by a constant offset and
    # leaves the relative positions of the points unchanged.
    all_embeds = torch.cat(list(embed_db.values()))
    _, _, V = torch.pca_lowrank(all_embeds, niter=10)
    proj = torch.matmul(all_embeds, V[:, :2])

    # Rows of proj are in the same order as the concatenation above, so each
    # class owns a contiguous slice of len(embeds) rows.
    plt.figure(figsize=(10, 10))
    start = 0
    for label, embeds in embed_db.items():
        stop = start + len(embeds)
        points = proj[start:stop]
        plt.scatter(points[:, 0].tolist(), points[:, 1].tolist(), label=label)
        start = stop
    if title is not None:
        plt.title(title)
    plt.legend()

# Load every class's embeddings from the scratch directory, then plot a
# hand-picked subset of classes.
embed_dir = './tmp'
embed_db = read_embed_db(embed_dir)
plot_embeddings(
    embed_db,
    labels=[
        'Bark',
        'Cello',
        'Flute',
        'Cowbell',
        'Gunshot_or_gunfire',
    ],
)