# Distribution Analysis

In [None]:
from constants.data_constants import JORDAN_DATASET_FILEPATH, MAESTRO_DATASET_FILEPATH
from constants.real_time_constants import SLIDING_WINDOW_LEN, STRIDE
from data.jordan_dataset import JordanDataset
from data.maestro_dataset import MaestroDataset
from data.sliding_window import SlidingWindowDataset

## Base datasets (no windowing yet).
# Both Jordan splits share the same directory and id-splitting flag; the
# Maestro test split is constructed without that flag, as in the original.
# NOTE(review): the `id_` / `ood_` prefixes presumably mean in-distribution /
# out-of-distribution -- confirm against the dataset classes.
jordan_shared_kwargs = dict(
    data_dir=JORDAN_DATASET_FILEPATH,
    split_input_and_output_ids=True,
)
id_train_base_dataset = JordanDataset(
    split="train", name="id_train_dataset", **jordan_shared_kwargs
)
id_test_base_dataset = JordanDataset(
    split="validation", name="id_test_dataset", **jordan_shared_kwargs
)
ood_test_base_dataset = MaestroDataset(
    data_dir=MAESTRO_DATASET_FILEPATH,
    split="test",
    name="maestro_test_dataset",
)

## Sliding-window views over the base datasets.
# Each wrapper is built with window size k=SLIDING_WINDOW_LEN and step STRIDE
# (the original note describes these as chunks of 120 tokens).
window_shared_kwargs = dict(k=SLIDING_WINDOW_LEN, stride=STRIDE)
id_train_dataset = SlidingWindowDataset(
    base_dataset=id_train_base_dataset,
    name="id_train_dataset",
    **window_shared_kwargs,
)
id_test_dataset = SlidingWindowDataset(
    base_dataset=id_test_base_dataset,
    name="id_test_dataset",
    **window_shared_kwargs,
)
ood_test_dataset = SlidingWindowDataset(
    base_dataset=ood_test_base_dataset,
    name="ood_test_dataset",
    **window_shared_kwargs,
)

In [None]:
from constants.model_constants import JORDAN_MODEL_NAME, DEVICE
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
from utils.data_loading import collate_fn
from extract_layers.pooling_functions import pool_last_k_tokens

# Pooling callable built from the last-k-tokens factory with k=1.
pooling_function = pool_last_k_tokens(1)

# NOTE(review): hard-coded rather than read from the model config -- confirm
# it matches the checkpoint loaded below.
num_layers = 24

# Load the pretrained causal LM, then move it to the configured device.
model = AutoModelForCausalLM.from_pretrained(JORDAN_MODEL_NAME)
model = model.to(DEVICE)

# Unshuffled loader over the windowed training set, so batches follow the
# dataset's own ordering.
id_train_dataloader = DataLoader(
    id_train_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:
from extract_layers.extract_layers_main import extract_representations

# Layer whose pooled representations are analysed in the cells below.
target_layer = 12

# Request only the target layer; the result is indexed by that same layer
# number to retrieve its representations.
data = extract_representations(
    model,
    id_train_dataloader,
    pooling_function,
    layers=[target_layer],
)
layer_data = data[target_layer]
print(layer_data.shape)


In [None]:
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# PCA explained-variance analysis of the extracted layer representations.
#
# Produces (and leaves in the namespace for later cells): the standardized
# matrix `X`, the fitted `scaler` and `pca`, the per-component and cumulative
# explained-variance arrays, and the component counts for 95% / 99% variance.

# Convert to a NumPy array. Torch tensors are detached and moved to the CPU
# first: calling `.numpy()` directly raises for tensors that live on an
# accelerator or that are tracked by autograd.
if hasattr(layer_data, 'detach'):
    X = layer_data.detach().cpu().numpy()
elif hasattr(layer_data, 'numpy'):
    X = layer_data.numpy()
else:
    X = np.array(layer_data)

# The code below indexes X.shape[1], i.e. it assumes X is
# (n_samples, n_features) -- TODO confirm against extract_representations.
print(f"Data shape: {X.shape}")

# Standardize every feature to zero mean / unit variance before PCA.
scaler = StandardScaler()
X = scaler.fit_transform(X)

# Fit PCA with up to MAX_PCA_COMPONENTS components. PCA cannot have more
# components than min(n_samples, n_features), so both dimensions bound it.
# Note this is a truncated fit: the cumulative curve may stop short of 100%.
MAX_PCA_COMPONENTS = 100
n_components = min(MAX_PCA_COMPONENTS, X.shape[0], X.shape[1])
pca = PCA(n_components=n_components)
pca.fit(X)

# Explained variance ratio (proportion of total variance per component)
explained_variance_ratio = pca.explained_variance_ratio_  # Shape: (n_components,)

# Cumulative explained variance
cumulative_variance = np.cumsum(explained_variance_ratio)


def _components_for_threshold(cumulative, threshold):
    """Return (count, reached) for a cumulative-variance threshold.

    `count` is the smallest number of components whose cumulative variance is
    >= `threshold`. If the fitted components never reach the threshold,
    `count` is capped at len(cumulative) and `reached` is False.
    """
    # searchsorted gives the first index with cumulative[idx] >= threshold,
    # or len(cumulative) when no entry reaches it.
    first_idx = int(np.searchsorted(cumulative, threshold))
    reached = first_idx < len(cumulative)
    count = first_idx + 1 if reached else len(cumulative)
    return count, reached


n_components_95, reached_95 = _components_for_threshold(cumulative_variance, 0.95)
n_components_99, reached_99 = _components_for_threshold(cumulative_variance, 0.99)

# Human-readable counts that do not claim a threshold was met when it was not.
label_95 = f"{n_components_95}" if reached_95 else f">{n_components} (threshold not reached)"
label_99 = f"{n_components_99}" if reached_99 else f">{n_components} (threshold not reached)"

# Guard the "first 10" statistic against fits with fewer than 10 components.
top_k = min(10, len(cumulative_variance))

print("\nExplained Variance Statistics:")
print(f"  First component: {explained_variance_ratio[0]:.4f} ({explained_variance_ratio[0]*100:.2f}%)")
print(f"  First {top_k} components: {cumulative_variance[top_k - 1]:.4f} ({cumulative_variance[top_k - 1]*100:.2f}%)")
print(f"  Components for 95% variance: {label_95}")
print(f"  Components for 99% variance: {label_99}")
print(f"  Total variance explained by {n_components} components: {cumulative_variance[-1]:.4f} ({cumulative_variance[-1]*100:.2f}%)")

# Plot explained variance
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Explained variance ratio per component (first 50 at most)
n_shown = min(50, len(explained_variance_ratio))
axes[0].plot(range(1, n_shown + 1),
             explained_variance_ratio[:n_shown], 'b-', marker='o', markersize=3)
axes[0].set_xlabel('Principal Component')
axes[0].set_ylabel('Explained Variance Ratio')
axes[0].set_title('Explained Variance per Component')
axes[0].grid(True, alpha=0.3)

# Plot 2: Cumulative explained variance
axes[1].plot(range(1, len(cumulative_variance) + 1),
             cumulative_variance, 'r-', linewidth=2)
axes[1].axhline(y=0.95, color='g', linestyle='--', label='95% threshold')
axes[1].axhline(y=0.99, color='orange', linestyle='--', label='99% threshold')
# Only mark the crossing point when the threshold is actually reached;
# otherwise the vertical line would sit at a misleading position.
if reached_95:
    axes[1].axvline(x=n_components_95, color='g', linestyle=':', alpha=0.5)
if reached_99:
    axes[1].axvline(x=n_components_99, color='orange', linestyle=':', alpha=0.5)
axes[1].set_xlabel('Number of Components')
axes[1].set_ylabel('Cumulative Explained Variance')
axes[1].set_title('Cumulative Explained Variance')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# UMAP visualization in 2D and 3D (with fallback to t-SNE if UMAP unavailable)
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Try to import UMAP, fall back to t-SNE if the import fails
try:
    import umap
    USE_UMAP = True
    method_name = "UMAP"
except ImportError as e:
    print(f"UMAP import failed ({e}), falling back to t-SNE...")
    from sklearn.manifold import TSNE
    USE_UMAP = False
    method_name = "t-SNE"

# One seed for both the subsampling and the reducers, so the whole cell
# produces the same figure on every run.
RANDOM_SEED = 42
MAX_EMBED_SAMPLES = 5000

# Subsample data (the embeddings are computationally expensive).
# `indices` always holds the row numbers of X that ended up in X_subset.
n_samples = min(MAX_EMBED_SAMPLES, X.shape[0])
if X.shape[0] > n_samples:
    print(f"Subsampling to {n_samples} samples for {method_name} visualization...")
    rng = np.random.default_rng(RANDOM_SEED)
    # Sorted so the subset keeps the original row ordering of X.
    indices = np.sort(rng.choice(X.shape[0], size=n_samples, replace=False))
    X_subset = X[indices]
else:
    X_subset = X
    indices = np.arange(X.shape[0])

print(f"Running {method_name} on {X_subset.shape[0]} samples...")


def _make_reducer(n_dims):
    """Build the UMAP (or fallback t-SNE) reducer for `n_dims` output dimensions."""
    if USE_UMAP:
        return umap.UMAP(n_components=n_dims, random_state=RANDOM_SEED,
                         n_neighbors=15, min_dist=0.1)
    return TSNE(n_components=n_dims, random_state=RANDOM_SEED,
                perplexity=30, max_iter=1000)


# 2D embedding
print(f"Computing 2D {method_name}...")
reducer_2d = _make_reducer(2)
result_2d = reducer_2d.fit_transform(X_subset)

# 3D embedding
print(f"Computing 3D {method_name}...")
reducer_3d = _make_reducer(3)
result_3d = reducer_3d.fit_transform(X_subset)

# Plot embeddings
fig = plt.figure(figsize=(16, 6))

# 2D plot. Points are coloured by their row index in X (not by their
# position inside the subsample), which is what the colorbar label states.
ax1 = fig.add_subplot(121)
scatter = ax1.scatter(result_2d[:, 0], result_2d[:, 1],
                      c=indices, cmap='viridis',
                      alpha=0.6, s=10)
ax1.set_xlabel(f'{method_name} Component 1')
ax1.set_ylabel(f'{method_name} Component 2')
ax1.set_title(f'{method_name} Visualization (2D)')
ax1.grid(True, alpha=0.3)
plt.colorbar(scatter, ax=ax1, label='Sample Index')

# 3D plot, same colouring as the 2D plot
ax2 = fig.add_subplot(122, projection='3d')
scatter3d = ax2.scatter(result_3d[:, 0], result_3d[:, 1], result_3d[:, 2],
                        c=indices, cmap='viridis',
                        alpha=0.6, s=10)
ax2.set_xlabel(f'{method_name} Component 1')
ax2.set_ylabel(f'{method_name} Component 2')
ax2.set_zlabel(f'{method_name} Component 3')
ax2.set_title(f'{method_name} Visualization (3D)')
plt.colorbar(scatter3d, ax=ax2, label='Sample Index')

plt.tight_layout()
plt.show()

print(f"\n{method_name} completed!")
print(f"2D {method_name} shape: {result_2d.shape}")
print(f"3D {method_name} shape: {result_3d.shape}")


