In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

import os, sys, csv

import sys
sys.path.append("./sngp/")

%load_ext autoreload
%autoreload 2
# Importing our custom module(s)

from sngp.model import RFFGP_Reg
from sngp.loss import square_loss
from sngp.train import train_model

# Pick the compute device once for the whole notebook.
# Only CUDA and CPU are considered; Apple MPS is intentionally left out.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print("Using CUDA:", torch.cuda.get_device_name(device))
else:
    print("Using CPU")

print("Final device:", device)

### Saving SNGP figures

In [None]:
#Based on gp_vs_sngp/gaussian_process/true_gp.py lines 375-459
#fig_path = /results/SNGP/dataset_nsamples/ or /results/SNGP/dataset_nsamples_seed/
def save_fig_sngp(sngp, X_test, dataset, fig_path, seed=42):
    np.random.seed(seed)
    
    # ---- Determine extended range ----
    x_test_1d = X_test.squeeze(-1)
    x_min = x_test_1d.min().item()
    x_max = x_test_1d.max().item()
    x_range = x_max - x_min
    
    # Extend by 10% on each side
    x_extended_min = x_min - 0.1*x_range
    x_extended_max = x_max + 0.1*x_range
    
    #Commented out the following:
    '''
    # ---- Generate additional noisy test points in extended regions ----
    # Sample ~50 points in each extended region
    n_extra_left = 50
    n_extra_right = 50
    
    x_extra_left = np.random.uniform(x_extended_min, x_min, n_extra_left)
    x_extra_right = np.random.uniform(x_max, x_extended_max, n_extra_right)
    x_extra = np.concatenate([x_extra_left, x_extra_right])
    
    # Generate noisy y values for these extra points
    if dataset == "Sin":
        y_extra = getSin(x_extra, noise=0.1)
    elif dataset == "CrazySin":
        y_extra = getCrazySin(x_extra, noise=0.1)
    else:
        raise ValueError(f"Unknown dataset: {dataset}")
    
    # Convert to torch and combine with original test set
    X_extra = torch.tensor(x_extra, dtype=torch.float32).unsqueeze(-1)
    X_test_extended = torch.cat([X_test, X_extra], dim=0)
    '''
    
    # ---- Get predictions on extended test set ----
    samples, mean, var = sngp.predict(X_test)      #Old: mean, var, cov = sngp.predict(X_test_extended)
    std = torch.sqrt(var.clamp(min=1e-6))
    
    # Sort everything for plotting
    x_extended_1d = X_test.squeeze(-1)              #Old: x_extended_1d = X_test_extended.squeeze(-1)
    idx_extended = torch.argsort(x_extended_1d)
    x_extended_sorted = x_extended_1d[idx_extended]
    mean_sorted = mean[idx_extended]
    std_sorted = std[idx_extended]
    
    # ---- Generate clean data for true function (dense for smooth line) ----
    x_clean = np.linspace(x_extended_min, x_extended_max, 500)
    
    if dataset == "Sin":
        y_clean = getSin(x_clean, noise=0.0)
    elif dataset == "CrazySin":
        y_clean = getCrazySin(x_clean, noise=0.0)
    else:
        raise ValueError(f"Unknown dataset: {dataset}")
    
    # ---- Plotting ----
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    
    # Confidence band (now extends to cover clean function range)
    ax.fill_between(
        x_extended_sorted.detach().numpy(),
        (mean_sorted - 2*std_sorted).detach().numpy(),
        (mean_sorted + 2*std_sorted).detach().numpy(),
        alpha=0.3, color='blue', label='95% confidence (±2σ)'
    )
    
    # Posterior mean
    ax.plot(x_extended_sorted.detach().numpy(), mean_sorted.detach().numpy(), 
            'b-', linewidth=2, label='Posterior mean μ*')
    
    # True function (dotted line)
    ax.plot(x_clean, y_clean, 
            'k--', linewidth=1, label='True function')
    
    ax.set_xlabel('x', fontsize=12)
    ax.set_ylabel('y', fontsize=12)
    ax.set_title('SNGP Regression: Full Closed Form', fontsize=14)
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    ax.set_xlim(x_extended_sorted.min().item(), x_extended_sorted.max().item())
    ax.set_ylim(-3, 3)
    
    plt.tight_layout()
    plt.savefig(fig_path, dpi=600)
    plt.close()

In [None]:
##Meant to save rank vs. all metrics for a certain dataset (Sin or CrazySin) of size = n_samples in new file
#Potentially at end of train.py or test.py in SNGP
import os

shared_dir = "/results/SNGP/dataset_nsamples_seed/"

# Append summary to shared file "likelyRank.csv"
with open(os.path.join(shared_dir, "likelyRank.csv"), "a") as f:
    f.write(f"{model.rank},{log_likelihood}\n")

In [None]:
#Get plots of log_likelihood vs. rank fo dataset_nsample
import csv

rank = []
LLHood = []

with open("/results/SNGP/dataset_nsamples_seed/likelyRank.csv", newline="") as f:
    reader = csv.DictReader(f)
    for row in reader:
        epochs.append(int(row["epoch"]))
        losses.append(float(row["loss"]))

plt.plot(epochs, losses)
plt.xlabel("Model Rank (Train Set Percentage)")
plt.ylabel("Log Likelihood")
plt.title("Log Likelihood vs Rank")
plt.show()