In [22]:
import os
import math
import numpy as np
import pandas as pd
import itertools
from IPython.display import clear_output, display

!pip install matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

!pip install plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

!pip install sklearn
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE



### Select Dataset and VLM

In [34]:
from similarity_metrics import Similarity
from dataset import DatasetHandler

# Choose the dataset to analyse by editing DATASET_NAME.
# Names this notebook has been wired up for: "moma_act", "moma_sact", "kinetics_100".
DATASET_NAME = "moma_act"

# NOTE(review): "test" is presumably the split name; the meaning of 10 is not
# visible from this notebook -- confirm against DatasetHandler.
dataset = DatasetHandler(DATASET_NAME, "test", 10)

dataset.moma.momaapi.lookup._read_anns() took 1.106520175933838 sec
dataset.moma.momaapi.statistics._read_statistics() took 0.0006163120269775391 sec


In [35]:
# Select the VLM wrapper from the name of the active conda environment.
# `.get` is used so that a missing variable is reported by the explicit error
# below instead of an uninformative KeyError.
ENV = os.environ.get("CONDA_DEFAULT_ENV")
if ENV == "VLM_MILES":
    # Imported lazily: the wrapper is only expected to be importable inside its own environment
    # (presumably each VLM has a dedicated conda env -- confirm).
    from MILES.wrapper import MILES_SimilarityVLM
    vlm = MILES_SimilarityVLM()
else:
    raise ValueError(
        f"No VLM is configured for conda environment {ENV!r}; "
        "expected CONDA_DEFAULT_ENV to be 'VLM_MILES'."
    )

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


######USING ATTENTION STYLE:  frozen-in-time


In [36]:
from classifier import SubVideoAverageFewShotClassifier

# Duration of each subvideo segment handed to the classifier
# (presumably seconds, given the "s_segments" suffix used when naming the plot directory -- confirm).
SUBVIDEO_SEGMENT_DURATION = 1

classifier = SubVideoAverageFewShotClassifier(
    vlm,
    subvideo_segment_duration=SUBVIDEO_SEGMENT_DURATION,
)

### Load Embeddings

In [37]:
# Freeze one ordering of the categories; every per-category list built below
# is index-aligned with `category_names`.
category_names = [*dataset.data_dict.keys()]
category_paths = [dataset.data_dict[category] for category in category_names]

# One text embedding per category name.
text_embeds_per_category = list(map(vlm.get_text_embeds, category_names))

# [category][video] -> embeddings of that video's subvideo segments.
subvid_embeds_per_category = [
    list(map(classifier.get_subvideo_embeds, video_paths))
    for video_paths in category_paths
]

# [category][video] -> embedding of the whole video.
vid_embeds_per_category = [
    list(map(vlm.get_video_embeds, video_paths))
    for video_paths in category_paths
]

### Embedding Statistics

In [38]:
# Similarity callable used by the VLM; looked up once instead of once per video.
similarity_fn = vlm.default_similarity_metric()


def _mean_self_similarity(subvid_embeds):
    """Average of the similarity values between one video's subvideo embeddings and themselves.

    NOTE(review): the same array is passed as both arguments, so if the metric
    returns a full pairwise matrix the average also includes each subvideo
    compared with itself -- confirm this is intended.
    """
    embeds = np.array(subvid_embeds)
    return np.mean(similarity_fn(embeds, embeds))


# [category][video] -> mean similarity among that video's subvideo embeddings.
intra_subvid_similarity = [
    [_mean_self_similarity(subvid_embeds) for subvid_embeds in category_subvid_embeds]
    for category_subvid_embeds in subvid_embeds_per_category
]

# [category] -> mean of the per-video values above.
mean_category_intra_subvid_similarity = [
    np.mean(category_intra_subvid_similarity)
    for category_intra_subvid_similarity in intra_subvid_similarity
]

# Per-video table: one column per category, one row per video index.
# Building from Series lets pandas pad shorter categories with NaN and keeps a
# float dtype; it also works with a single category, where the previous
# `max(*[...])` call raised a TypeError.
temp = pd.DataFrame(
    {
        name: pd.Series(category_intra_subvid_similarity, dtype=float)
        for name, category_intra_subvid_similarity in zip(category_names, intra_subvid_similarity)
    },
    columns=category_names,
)
display(temp)

# Per-category summary row (rendered as the cell's output).
pd.DataFrame([mean_category_intra_subvid_similarity], columns=category_names)

Unnamed: 0,beauty salon service,drive-thru ordering,physical therapy,reception service,table tennis game
0,0.583817,0.70201,0.905029,0.811988,0.731826
1,0.812259,0.681857,0.784008,0.803887,0.576599
2,0.866064,0.656108,0.825429,0.779539,0.812658
3,0.781305,0.705206,0.673036,0.88393,0.865076
4,0.48188,0.928694,0.808298,0.715765,0.697708
...,...,...,...,...,...
100,,,,,0.701937
101,,,,,0.743071
102,,,,,0.831335
103,,,,,0.935701


Unnamed: 0,beauty salon service,drive-thru ordering,physical therapy,reception service,table tennis game
0,0.740081,0.743451,0.722313,0.769837,0.841255


### Compute T-SNE Embeddings

In [6]:
# Flatten every embedding (category text, whole video, subvideo) into one list so
# that a single T-SNE fit projects them all together. While stacking, record
# the row each embedding lands on so the projection can be split back up later.

stacked_embeddings = []
text_stacked_indices = []     # [category] -> row
vid_stacked_indices = []      # [category][video] -> row
subvid_stacked_indices = []   # [category][video][subvideo] -> row

next_index = 0
for text_embed, vid_embeds_per_path, subvid_embeds_per_path in zip(
    text_embeds_per_category, vid_embeds_per_category, subvid_embeds_per_category
):
    # The category's text embedding occupies a single row.
    text_stacked_indices.append(next_index)
    stacked_embeddings.append(text_embed)
    next_index += 1

    # One row per whole-video embedding.
    vid_count = len(vid_embeds_per_path)
    vid_stacked_indices.append(list(range(next_index, next_index + vid_count)))
    stacked_embeddings.extend(vid_embeds_per_path)
    next_index += vid_count

    # One row per subvideo embedding, grouped by the video it belongs to.
    rows_per_video = []
    for subvid_embeds in subvid_embeds_per_path:
        subvid_count = len(subvid_embeds)
        rows_per_video.append(list(range(next_index, next_index + subvid_count)))
        stacked_embeddings.extend(subvid_embeds)
        next_index += subvid_count
    subvid_stacked_indices.append(rows_per_video)

stacked_embeddings = np.array(stacked_embeddings)



# T-SNE is stochastic; a fixed seed keeps the projection (and every plot derived
# from it) identical across kernel restarts.
TSNE_RANDOM_STATE = 0

# Translate the VLM's similarity metric into a distance that sklearn understands.
default_metric = vlm.default_similarity_metric()
if default_metric is Similarity.COSINE:
    sklearn_metric = "cosine"
elif default_metric is Similarity.DOT:
    # NOTE: This is imperfect. No distance metric can match dot-product ordering without violating triangle inequality
    # (For any 2 vectors which aren't directly opposite each other, a third vector exists with arbitrarily-high similarity to both)
    def sklearn_metric(a, b):
        """Map the dot-product similarity of two vectors to a positive pseudo-distance.

        Each vector gets a leading axis before being handed to Similarity.DOT
        (presumably it expects batched inputs and returns a single-element
        result here -- confirm).
        """
        return math.exp(-Similarity.DOT(a[None, :], b[None, :]))
else:
    raise ValueError("Unknown equivalent sklearn metric name")

sne_embeddings = TSNE(
    n_components=2,
    metric=sklearn_metric,
    random_state=TSNE_RANDOM_STATE,
).fit_transform(stacked_embeddings)



# Unstack SNE embeddings into original fixed order of embeddings
def _sne_rows(row_indices):
    """Return, as a list, the rows of `sne_embeddings` at the given stacked indices."""
    return [sne_embeddings[row] for row in row_indices]


# [category] -> projected text embedding
text_sne_embeds_per_category = _sne_rows(text_stacked_indices)

# [category][video] -> projected whole-video embedding
vid_sne_embeds_per_category = [
    _sne_rows(video_rows) for video_rows in vid_stacked_indices
]

# [category][video][subvideo] -> projected subvideo embedding
subvid_sne_embeds_per_category = [
    [_sne_rows(subvideo_rows) for subvideo_rows in rows_per_video]
    for rows_per_video in subvid_stacked_indices
]



### Interactive SNE Embedding Plots

In [7]:
# Class-specific plots
# One HTML figure per category. Inside it, one subplot per video showing that
# video's subvideo path plotted against all class text embeddings.
PLOT_DIR = f"subvideo_visualizations/subvideo_plots.{vlm.__class__.__name__}.{classifier.subvideo_segment_duration}s_segments.{dataset.id()}"
os.makedirs(PLOT_DIR, exist_ok=True)

N_COLS = 4
FIG_WIDTH = 1600

# One fixed colour per kind of trace: the first four entries of the D3 palette,
# in the order the traces are added to each subplot.
INCORRECT_TEXT_COLOR, SUBVID_COLOR, VID_COLOR, CORRECT_TEXT_COLOR = px.colors.qualitative.D3[:4]


def _embed_trace(x, y, labels, color, size, mode="markers"):
    """Build a scatter trace whose hover tooltip shows only the per-point label.

    Args:
        x, y: point coordinates in the projected space.
        labels: one hover label per point.
        color: marker colour.
        size: marker size.
        mode: plotly scatter mode ("markers" or "markers+lines").
    """
    return go.Scatter(
        mode=mode,
        x=x,
        y=y,
        text=labels,
        hovertemplate="%{text}<extra></extra>",
        marker_size=size,
        marker_opacity=0.8,
        marker_color=color,
    )


for cat_name, text_sne_embed, path_per_vid, vid_sne_embeds_per_vid, subvid_sne_embeds_per_vid in zip(category_names, text_sne_embeds_per_category, category_paths, vid_sne_embeds_per_category, subvid_sne_embeds_per_category):
    SUBPLOT_COUNT = len(path_per_vid)
    if SUBPLOT_COUNT == 0:
        # Nothing to draw; also avoids dividing by N_ROWS == 0 below.
        continue

    N_ROWS = math.ceil(SUBPLOT_COUNT / N_COLS)

    # Text embeddings of every *other* class are identical in each subplot of
    # this figure, so gather them once per category.
    other_classes = [
        (other_cat_name, other_text_sne_embed)
        for other_cat_name, other_text_sne_embed in zip(category_names, text_sne_embeds_per_category)
        if other_cat_name != cat_name
    ]
    other_x = [other_text_sne_embed[0] for _, other_text_sne_embed in other_classes]
    other_y = [other_text_sne_embed[1] for _, other_text_sne_embed in other_classes]
    other_labels = [other_cat_name + " (incorrect)" for other_cat_name, _ in other_classes]

    fig = make_subplots(rows=N_ROWS, cols=N_COLS,
                        vertical_spacing=0.2 / N_ROWS, horizontal_spacing=0.2 / N_COLS,
                        subplot_titles=[path.split("/")[-1] for path in path_per_vid])

    for i in range(SUBPLOT_COUNT):
        # Subplots are filled row by row; plotly rows/cols are 1-based.
        row = i // N_COLS + 1
        col = i % N_COLS + 1

        subvid_sne_embeds = np.array(subvid_sne_embeds_per_vid[i])
        if subvid_sne_embeds.size == 0:
            # A video without subvideo embeddings yields a 1-D empty array;
            # give it two columns so the slicing below stays valid.
            subvid_sne_embeds = np.empty((0, 2))
        vid_sne_embed = vid_sne_embeds_per_vid[i]

        # Incorrect class text embeds
        fig.add_trace(
            _embed_trace(other_x, other_y, other_labels, INCORRECT_TEXT_COLOR, size=20),
            row=row, col=col,
        )

        # Subvideo embed path (points joined in subvideo order)
        fig.add_trace(
            _embed_trace(
                subvid_sne_embeds[:, 0],
                subvid_sne_embeds[:, 1],
                [f"Subvideo Embed {j}" for j in range(len(subvid_sne_embeds))],
                SUBVID_COLOR,
                size=10,
                mode="markers+lines",
            ),
            row=row, col=col,
        )

        # Video Embed
        fig.add_trace(
            _embed_trace([vid_sne_embed[0]], [vid_sne_embed[1]], ["Full Video Embed"], VID_COLOR, size=10),
            row=row, col=col,
        )

        # Correct class text embed (added last, after the other three traces)
        fig.add_trace(
            _embed_trace(text_sne_embed[:1], text_sne_embed[1:], [cat_name + " (correct)"], CORRECT_TEXT_COLOR, size=20),
            row=row, col=col,
        )

    fig.update_layout(width=FIG_WIDTH, height=FIG_WIDTH / N_COLS * N_ROWS, showlegend=False, title=cat_name)

    fig.write_html(f"{PLOT_DIR}/{cat_name.replace(' ', '_')}.html")
    fig.show()
    # NOTE(review): clearing right after show() removes the figure from the cell
    # output, leaving only the HTML file -- presumably to keep the notebook small; confirm.
    clear_output()