In [73]:
import os
import re

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

In [None]:
MODEL = "spa"
# Keep only samples whose metadata matches every key/value pair below.
FILTER = {
    # "object_class": "Car_Red",
    "orientation": "Orbit",
    "environment": "Forest"
}

FEATURES = "MEAN"  # token pooling: "CLS", "MEAN" or "CENTER" (applied at the end of this cell)
repr_dic = "/shared/results/common/kargin/unreal_engine/features/initial_test"

# NOTE(review): assumes each sample's token sequence ends with a 14x14 = 196 patch grid
# and that token 105 sits at (or next to) the grid centre — TODO confirm per MODEL.
N_PATCH_TOKENS = 196
CENTER_TOKEN = 105

# Per-sample metadata lists stored alongside the "features" tensor in the .pt file.
META_KEYS = ("orientation", "environment", "object_class", "number", "path")

raw_repr = torch.load(f"{repr_dic}/repr_{MODEL}.pt", weights_only=True)

# Indices of samples that satisfy every FILTER condition.
valid_indices = [
    i for i in range(len(raw_repr["object_class"]))
    if all(raw_repr[key][i] == value for key, value in FILTER.items())
]
if not valid_indices:
    # Fail here with a readable message instead of an opaque torch.cat([]) error below.
    raise ValueError(f"No samples in repr_{MODEL}.pt match FILTER={FILTER}")

subset = {
    "features": raw_repr["features"][valid_indices],
    **{key: [raw_repr[key][i] for i in valid_indices] for key in META_KEYS},
}

print(list(subset.keys()))
print(subset['features'].shape)
print(len(subset['environment']))

# Group row indices by (object_class, orientation, environment) in a single pass.
combo_to_indices = {}
for i, combo in enumerate(zip(subset['object_class'], subset['orientation'], subset['environment'])):
    combo_to_indices.setdefault(combo, []).append(i)

# Sorted (not a set) so row order -- and hence the t-SNE input and plot colours --
# is identical across kernel restarts; set iteration order over strings is not.
unique_combinations = sorted(combo_to_indices)
print("Unique combinations:", unique_combinations)

# Rebuild the feature tensor grouped by combination, each group ordered by frame number.
all_features = subset['features']
sorted_features = []
sorted_numbers = []
sorted_labels = []

for cls, orien, env in unique_combinations:
    # Python's sort is stable, so samples with equal numbers keep their file order.
    combo_indices = sorted(combo_to_indices[(cls, orien, env)], key=lambda i: subset['number'][i])

    sorted_features.append(all_features[combo_indices])
    sorted_numbers.extend(subset['number'][i] for i in combo_indices)
    sorted_labels.extend([f"{cls}_{orien}_{env}"] * len(combo_indices))

all_features = torch.cat(sorted_features, dim=0)
labels = sorted_labels

# all_features is indexed as (sample, token, channel) below.
if FEATURES == "CLS":
    features = all_features[:, 0, :]  # presumably token 0 is the CLS token — TODO confirm per model
elif FEATURES == "MEAN":
    features = all_features[:, -N_PATCH_TOKENS:, :].mean(1)  # mean over the trailing patch tokens
elif FEATURES == "CENTER":
    features = all_features[:, CENTER_TOKEN, :]
else:
    raise ValueError(f"Unknown FEATURES={FEATURES!r}; expected 'CLS', 'MEAN' or 'CENTER'")

MODEL += "(token=" + FEATURES + ")"

In [186]:
MODE = "TSNE"
# MODE = "PCA"

PERPLEXITY = 100
SEED = 0  # fixes t-SNE's random initialisation so re-running gives the same layout

features_np = features.cpu().numpy()

if MODE == "TSNE":
    # scikit-learn requires perplexity < n_samples; clamp so small filtered subsets
    # still run, and record the value actually used in the title.
    perplexity = min(PERPLEXITY, len(features_np) - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=SEED)
    features_2d = tsne.fit_transform(features_np)
    MODE += '(p=' + str(perplexity) + ")"
elif MODE == "PCA":
    pca = PCA(n_components=2)
    features_2d = pca.fit_transform(features_np)
else:
    raise ValueError(f"Unknown MODE={MODE!r}; expected 'TSNE' or 'PCA'")

df = pd.DataFrame({
    'x': features_2d[:, 0],
    'y': features_2d[:, 1],
    'label': labels,
})

# Rows within a label are already ordered by frame number (see loading cell).
# NOTE(review): assumes frames were sampled every 10 — TODO confirm; the real
# numbers are available as `sorted_numbers` from the loading cell.
df["frame_no"] = df.groupby('label').cumcount() * 10


def _subject_of(label: str) -> str:
    """Coarse subject category from a combination label, matched case-insensitively.

    NOTE(review): substring match over the whole label (class, orientation and
    environment); unmatched labels map to "?".
    """
    lowered = label.lower()
    if "human" in lowered:
        return "human"
    if "car" in lowered:
        return "car"
    if "bike" in lowered:
        return "bike"
    return "?"


# Anything not containing "Forest" is treated as city; anything not "Orbit" as ego-centric.
df["env"] = df["label"].apply(lambda x: "forest" if "Forest" in x else "city")
df["centric"] = df["label"].apply(lambda x: "obj" if "Orbit" in x else "ego")
df["subject"] = df["label"].apply(_subject_of)

In [None]:
TITLE = f"{MODEL}, {MODE}, filter=`{FILTER}`"

fig = px.scatter(
    df,
    x='x',
    y='y',
    color='label',
    hover_name='label',
    hover_data=['label', 'frame_no'],
    title=TITLE,
    labels={'label': 'Class'},
    width=1800,
    height=1200,
)

# px creates one trace per label. Connect each trace's points in row order (frame
# order, as sorted upstream) so every sequence reads as a trajectory; mark the first
# point with a diamond and the last with a square.
for trace in fig.data:
    n_points = len(trace.x)
    n_middle = max(n_points - 2, 0)
    # Slice to n_points so single-point traces don't get more markers than points.
    symbols = (['diamond'] + ['circle'] * n_middle + ['square'])[:n_points]
    sizes = ([9] + [6] * n_middle + [8])[:n_points]
    trace.update(
        mode='lines+markers',
        line=dict(
            width=1,
            shape='spline',   # smooth curves instead of straight segments
            smoothing=1.1,    # plotly accepts 0 (no smoothing) to 1.3
        ),
        marker=dict(
            symbol=symbols,
            size=sizes,
            line=dict(color='black', width=0.5),  # thin black outline
        ),
    )

# One text label per trace, placed at its first point (done once per trace, not per pair).
for trace in fig.data:
    fig.add_annotation(
        x=trace.x[0],
        y=trace.y[0],
        text=trace.name,
        showarrow=True,
        arrowcolor='gray',
        arrowhead=0,       # no arrowhead: a plain connector line
        arrowsize=1,
        arrowwidth=1,
        xanchor='left',
        ax=15,             # text sits 15 px to the right of the point
        ay=5,              # and 5 px below it (positive ay is downward)
        font=dict(
            color=trace.marker.color,  # same colour as the trace's markers
            size=13,
            family='Arial'
        ),
        opacity=0.7,
        bgcolor='rgba(255,255,255,0.1)',
    )

# Legend is hidden (labels are drawn as annotations); axes carry no meaningful units.
fig.update_layout(
    showlegend=False,
    legend_title_text='',
    legend=dict(font=dict(size=16)),
    margin=dict(l=20, r=20, t=50, b=20),
    coloraxis_showscale=False,
    xaxis_showticklabels=False,
    yaxis_showticklabels=False,
    xaxis_title=None,
    yaxis_title=None,
)

# The title embeds the FILTER dict repr (quotes, braces, backticks, possibly '/'),
# so reduce it to a conservative filename alphabet before writing.
file_stem = re.sub(r"[^\w\-.,=()]+", "_", TITLE).strip("_")
os.makedirs("figs_html", exist_ok=True)
fig.write_html(f"figs_html/{file_stem}.html")
# os.makedirs("figs_png", exist_ok=True); fig.write_image(f"figs_png/{file_stem}.png")
fig.show()