# UMAP Visualization: PKS vs non-PKS

Loads precomputed UMAP embeddings and plots them as a scatter plot, colored by label (PKS vs non-PKS).
The embeddings are expected in `pks_nonpks_umap_embeddings.parquet`, which is generated by the UMAP script.

In [None]:
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Locate the precomputed UMAP embeddings. The absolute path is tried first;
# the repo-relative path (resolved against the current working directory)
# is the fallback.
abs_path = Path('/Users/yashchainani/Desktop/PythonProjects/ContrastiveGNNs/data/processed/pks_nonpks_umap_embeddings.parquet')
rel_path = Path('..') / 'data' / 'processed' / 'pks_nonpks_umap_embeddings.parquet'
candidates = (abs_path, rel_path)
in_path = next((p for p in candidates if p.exists()), None)
if in_path is None:
    # Fail early and name every location tried, instead of letting the
    # reader raise an error that mentions only the relative fallback.
    tried = ', '.join(str(p) for p in candidates)
    raise FileNotFoundError(f'UMAP embeddings not found; tried: {tried}')
print('Reading embeddings from:', in_path)
df = pd.read_parquet(in_path)
# Last expression of the cell: displays the first rows in the notebook.
df.head()


In [None]:
# Verify the embedding columns exist, then derive a `label` column whose
# values are exactly 'PKS' or 'non-PKS'.
required = {'umap_1', 'umap_2'}
missing = required - set(df.columns)
if missing:
    raise ValueError(f'Missing required columns: {missing}')

if 'is_pks' in df.columns:
    # Map the flag onto the two label strings rather than stringifying it:
    # astype(str) would produce values such as 'True'/'False', which do not
    # match the 'PKS'/'non-PKS' keys of the palette used when plotting.
    # NOTE(review): assumes is_pks is boolean or 0/1 with no missing
    # values - confirm against the script that writes the parquet file.
    df['label'] = df['is_pks'].astype(bool).map({True: 'PKS', False: 'non-PKS'})
elif 'source' in df.columns:
    # Keep 'PKS' where source is exactly 'PKS'; everything else is 'non-PKS'.
    df['label'] = df['source'].astype(str).where(df['source'] == 'PKS', 'non-PKS')
else:
    raise ValueError('Expected either is_pks or source column')

# Last expression of the cell: displays the per-label counts in the notebook.
df['label'].value_counts()


In [None]:
# Cap the number of plotted points to keep rendering fast.
# MAX_POINTS = None disables the cap and plots every row.
MAX_POINTS = 300000  # adjust or set to None for all points
needs_subsample = MAX_POINTS is not None and len(df) > MAX_POINTS
# A fixed random_state makes the subsample reproducible across runs.
df_plot = df.sample(n=MAX_POINTS, random_state=42) if needs_subsample else df
if needs_subsample:
    print(f'Subsampled to {len(df_plot):,} points for plotting')
# Last expression of the cell: displays the number of rows to be plotted.
len(df_plot)


In [None]:
# Scatter the 2-D UMAP coordinates, one colour per label, then write the
# figure to disk before showing it.
palette = {'PKS': '#1f77b4', 'non-PKS': '#ff7f0e'}
out_png = Path('..') / 'data' / 'processed' / 'pks_nonpks_umap_viz.png'

fig, ax = plt.subplots(figsize=(9, 7))
# Small, borderless, semi-transparent markers keep dense regions readable.
sns.scatterplot(
    data=df_plot,
    x='umap_1',
    y='umap_2',
    hue='label',
    palette=palette,
    s=6,
    linewidth=0,
    alpha=0.5,
    ax=ax,
)
ax.set_title('UMAP of ECFP4 fingerprints: PKS vs non-PKS')
# markerscale enlarges the legend markers relative to the plotted points.
ax.legend(title='Label', markerscale=2)
fig.tight_layout()

# Create the output directory if it does not exist yet.
out_png.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_png, dpi=200)
print('Saved figure to', out_png)
plt.show()
