# t-SNE Visualizer for SustainVision Checkpoints

Use this notebook to project model embeddings to 2D with t-SNE. You can compare checkpoints trained with different objectives (e.g., cross-entropy vs SimCLR).

**Workflow**
1. Set `CHECKPOINT_PATH` (and, if needed, `DATASET`, `NUM_SAMPLES`, and `DEVICE`) in the configuration cell below.
2. Run all cells.
3. Inspect the scatter plot to see how well classes separate.



In [None]:
# Cell 1: Configuration
# Notebook-wide settings; the cells below read these names directly.
CHECKPOINT_PATH = "resnet18_simclr_checkpoints/resnet18_simclr_model.pt"  # path to saved checkpoint; a relative path is resolved against the current working directory
DATASET = "cifar10"  # dataset name handed to build_classification_dataloaders
NUM_SAMPLES = 2000  # upper bound on the number of embeddings collected for t-SNE
DEVICE = "cuda:1"  # requested torch device string; CPU is used when CUDA is unavailable

In [None]:
# Cell 2: Imports
import sys
from pathlib import Path

sys.path.insert(0, str(Path.cwd()))

import torch
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

from sustainvision.config import TrainingConfig
from sustainvision.data import build_classification_dataloaders
from sustainvision.training import _build_model

# Report the installed PyTorch version and whether CUDA can be used, so device
# problems are visible before any model or data is loaded.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")



In [None]:
# Cell 3: Device summary
def describe_devices():
    """Print the requested device and, when CUDA is available, per-GPU memory usage.

    Reads the notebook-level ``DEVICE`` setting. Returns nothing; all output
    goes to stdout.
    """
    print("=== Device Summary ===")
    print(f"Requested device: {DEVICE}")

    # Guard clause: nothing more to report on a CPU-only machine.
    if not torch.cuda.is_available():
        print("CUDA not available; CPU only.")
        return

    bytes_per_gib = 1024 ** 3
    num_gpus = torch.cuda.device_count()
    print(f"CUDA devices: {num_gpus}")
    for gpu_index in range(num_gpus):
        gpu_name = torch.cuda.get_device_name(gpu_index)
        total_gb = torch.cuda.get_device_properties(gpu_index).total_memory / bytes_per_gib
        # mem_get_info() reports on the current device, so select it first.
        with torch.cuda.device(gpu_index):
            free_bytes, total_bytes = torch.cuda.mem_get_info()
            used_gb = (total_bytes - free_bytes) / bytes_per_gib
            print(f"  GPU {gpu_index}: {gpu_name} | used {used_gb:.2f} GB / {total_gb:.2f} GB")

describe_devices()

# Resolve the device to run on. CPU is used when CUDA is unavailable or when a
# non-CUDA device string was requested. A GPU index beyond the number of
# visible devices (e.g. "cuda:1" on a single-GPU machine) would otherwise only
# fail later, at the first `.to(device)`, so it is replaced by cuda:0 here with
# a warning.
if torch.cuda.is_available() and DEVICE.startswith("cuda"):
    device = torch.device(DEVICE)
    visible_gpus = torch.cuda.device_count()
    # device.index is None for a bare "cuda", which means "current device".
    if device.index is not None and device.index >= visible_gpus:
        print(f"[warn] {DEVICE} is not available ({visible_gpus} CUDA device(s) found); falling back to cuda:0.")
        device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"\nUsing device: {device}")


In [None]:
# Cell 4: Load checkpoint and dataset
# Joining onto the working directory anchors a relative CHECKPOINT_PATH there;
# an absolute CHECKPOINT_PATH replaces the left-hand side and is kept as given.
checkpoint_path = Path.cwd() / CHECKPOINT_PATH
if not checkpoint_path.exists():
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

print(f"Loading checkpoint from {checkpoint_path}")
# NOTE(review): torch.load unpickles the file; only open checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)

# Rebuild the training configuration stored alongside the weights, if any.
if "config" not in checkpoint:
    print("[warn] No config stored in checkpoint. Using defaults.")
    config = TrainingConfig()
else:
    config = TrainingConfig(**checkpoint["config"])

# Build the training dataloader for DATASET, reusing the seed and image size
# recorded in the checkpoint's config (image size defaults to 224). No
# validation split is requested and the second return value is discarded.
# NOTE(review): `contrastive=True` is passed for simclr/supcon checkpoints; if
# that makes the loader yield several augmented views per sample, the embedding
# cell's `inputs.to(device)` needs a single tensor — confirm against
# sustainvision.data.
train_loader, _, num_classes = build_classification_dataloaders(
    DATASET,
    batch_size=128,
    num_workers=2,
    val_split=0.0,
    seed=config.seed,
    project_root=Path.cwd(),
    image_size=config.hyperparameters.get("image_size", 224),
    contrastive=config.loss_function in {"simclr", "supcon"},
)

# Recreate the architecture named in the config and move it to the selected
# device; projection_dim falls back to 128 when the config does not set it.
model = _build_model(
    config.model,
    num_classes=num_classes,
    image_size=config.hyperparameters.get("image_size", 224),
    projection_dim=config.hyperparameters.get("projection_dim", 128),
).to(device)

# The weights may be stored under "model_state" or "model"; when neither key is
# present the checkpoint object itself is treated as the state dict.
state_key = next((key for key in ("model_state", "model") if key in checkpoint), None)
state_dict = checkpoint[state_key] if state_key else checkpoint

# strict=False tolerates key mismatches, which means a checkpoint that does not
# match the architecture would otherwise load "successfully" while leaving
# layers at their random initial values. Surface any mismatch so a meaningless
# t-SNE plot is not mistaken for a real result.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"[warn] {len(load_result.missing_keys)} model parameter(s) not found in checkpoint "
        f"(left at initial values), e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"[warn] {len(load_result.unexpected_keys)} checkpoint key(s) not used by the model, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )

model.eval()
print(f"Model: {config.model} ({config.loss_function})")
print(f"Classes: {num_classes}")
print(f"Model device: {next(model.parameters()).device}")



In [None]:
# Cell 5: Collect embeddings (limited by NUM_SAMPLES)
# Runs batches through the model without gradients until at least NUM_SAMPLES
# samples have been gathered, then truncates to exactly NUM_SAMPLES.
target_samples = NUM_SAMPLES
embeddings = []
labels = []
num_collected = 0  # running total, so batch lengths are not re-summed every iteration

with torch.no_grad():
    for inputs, batch_labels in tqdm(train_loader, desc="Embedding batches"):
        # NOTE(review): if the loader yields several views per sample (as a
        # contrastive loader might — confirm), embed the first view only.
        # A plain tensor batch is unaffected.
        if isinstance(inputs, (list, tuple)):
            inputs = inputs[0]
        inputs = inputs.to(device)
        if hasattr(model, "backbone"):
            # Prefer backbone features over the full model's output.
            feats = model.backbone(inputs)
            if isinstance(feats, tuple):
                feats = feats[0]
            if feats.ndim > 2:
                # Collapse any trailing dimensions to one vector per sample.
                feats = torch.flatten(feats, 1)
        else:
            feats, _ = model(inputs)
        embeddings.append(feats.cpu())
        labels.append(batch_labels)

        num_collected += len(batch_labels)
        if num_collected >= target_samples:
            break

# torch.cat on an empty list fails with an unhelpful message; say what happened.
if not embeddings:
    raise RuntimeError("The dataloader produced no batches; no embeddings were collected.")

embeddings = torch.cat(embeddings, dim=0)[:target_samples]
labels = torch.cat(labels, dim=0)[:target_samples]

print(f"Collected embeddings: {embeddings.shape}")

In [None]:
# Cell 6: Run t-SNE
emb_np = embeddings.numpy()
print("Running t-SNE (this may take a while for large NUM_SAMPLES)...")
# scikit-learn requires perplexity < number of samples; clamp it so a small
# NUM_SAMPLES does not abort the run. With the default 2000 samples this is 30.
perplexity = min(30, len(emb_np) - 1)
# The iteration count is left at the library default of 1000: the keyword was
# renamed from `n_iter` to `max_iter` in scikit-learn 1.5 and the old name was
# later removed, so passing either name explicitly breaks on some versions.
tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42, verbose=1)
emb_2d = tsne.fit_transform(emb_np)
print("t-SNE finished.")

In [None]:
# Cell 7: Plot
labels_np = labels.numpy()
# Size the palette to the classes actually present in the sample: seaborn
# rejects (or warns about) a palette list whose length differs from the number
# of hue levels, which happens when not every class was sampled. "tab10" only
# has 10 distinct colours, so switch to an evenly spaced palette beyond that.
num_present = len(np.unique(labels_np))
palette = sns.color_palette("tab10" if num_present <= 10 else "husl", n_colors=num_present)

plt.figure(figsize=(12, 8))
sns.scatterplot(
    x=emb_2d[:, 0],
    y=emb_2d[:, 1],
    hue=labels_np,
    palette=palette,
    s=30,
    alpha=0.7,
    linewidth=0
)
plt.title(f"t-SNE: {config.model} ({config.loss_function})")
plt.xlabel("t-SNE dimension 1")
plt.ylabel("t-SNE dimension 2")
# Place the legend outside the axes so it does not cover any points.
plt.legend(title="Class", bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
plt.show()

print(f"Samples visualized: {len(emb_2d)}")
print(f"Device: {device}")

## Notes
- Try different checkpoints (cross-entropy vs SimCLR) to compare clusters.
- Adjust `NUM_SAMPLES` to trade off speed vs detail.
- Experiment with `perplexity` in the t-SNE cell for different scales.
- Change `DEVICE` to another GPU (e.g., `cuda:0`) or to `cpu` if needed.