# Training _The Cannon_ Using SFH Library

In [1]:
import numpy as np
from astropy.io import fits
import AnniesLasso as tc
import matplotlib.pyplot as plt
from AnniesLasso.thecannon.vectorizer.polynomial import PolynomialVectorizer
from AnniesLasso.thecannon.model import CannonModel

In [2]:
# Load the tables containing the training set labels, and the spectra.
training_set = fits.getdata("OUTPUTS/sfh_2000_10_20250826_144749_weights.fits")

flux = np.load("OUTPUTS/sfh_2000_10_20250826_144749_spectra.npy")
ivar = np.load("OUTPUTS/sfh_2000_10_20250826_144749_invvar.npy")
wav = np.load("OUTPUTS/sfh_2000_10_20250826_144749_wavelength.npy")

FileNotFoundError: [Errno 2] No such file or directory: 'OUTPUTS/sfh_2000_10_20250826_144749_weights.fits'

In [None]:
# Define softmax function for use throughout the notebook
def softmax(x, axis=1):
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

# Check for NaNs or Infs in data arrays
def check_data_for_nans_infs(*arrays):
    for i, arr in enumerate(arrays):
        if np.isnan(arr).any():
            print(f"Warning: NaNs found in array {i}")
        if np.isinf(arr).any():
            print(f"Warning: Infs found in array {i}")

check_data_for_nans_infs(training_set, flux, ivar)

In [None]:
# Define training and test sets from library of 2000 spectra
q = np.random.randint(0, 10, len(training_set)) % 10
test_set = (q == 1)
train_set = (q != 1)

In [None]:
# Create a vectorizer that defines our model form.
vectorizer = PolynomialVectorizer(terms=("1", "2", "3", "4", "5", "6", "7", "8", "9", "10"))

In [None]:
vectorizer.label_names

('1', '2', '3', '4', '5', '6', '7', '8', '9', '10')

In [None]:
train_set = np.asarray(train_set, dtype=bool)
model = CannonModel(training_set[train_set], flux[train_set], ivar[train_set],
                    vectorizer=vectorizer, dispersion=wav)

# Train the model!
model.train()

2025-08-26 16:54:55,132 [INFO] Training 10-label CannonModel with 1797 stars and 4334 pixels/star
stty: 'standard input': Inappropriate ioctl for device
2025-08-26 16:54:55,172 [DEBUG] Couldn't get screen size. Progressbar may look odd.
stty: 'standard input': Inappropriate ioctl for device
2025-08-26 16:54:55,172 [DEBUG] Couldn't get screen size. Progressbar may look odd.


[                                                                                                    ]   0% (18/4334)                          



[                                                                                                    ]   1% (40/4334)                          



[=                                                                                                   ]   2% (79/4334)                          



[====                                                                                                ]   4% (189/4334)                          



[====                                                                                                ]   5% (198/4334)                          



[=====                                                                                               ]   5% (217/4334; ~3m until finished)      









































































































































































(array([[ 1.17664077e+00,  2.45471187e-02,  2.59078707e-02, ...,
         -2.30849544e-03,  7.70976049e-03,  7.72862328e-03],
        [ 1.08634328e+00,  1.90683543e-02,  2.02575023e-02, ...,
         -5.55836305e-03,  3.19759952e-03,  3.54186419e-03],
        [ 1.00335251e+00,  1.37133372e-02,  1.47424128e-02, ...,
         -8.82793648e-03, -1.00631981e-03, -2.31035114e-04],
        ...,
        [ 8.18757786e-01, -3.00859788e-02, -3.37131380e-02, ...,
          8.14426982e-04, -6.22295371e-03, -6.60476752e-03],
        [ 8.70296565e-01, -2.66186789e-02, -2.99237490e-02, ...,
          2.53358364e-03, -3.33227847e-03, -3.55133078e-03],
        [ 8.46324424e-01, -5.76100796e-03, -5.85292319e-03, ...,
         -4.10607860e-03, -3.83361737e-03, -3.78063062e-03]]),
 array([0., 0., 0., ..., 0., 0., 0.]),
 [{'grad': array([-8.69868931e-08, -1.70622141e-06, -9.47950107e-07, -7.45723986e-06,
           5.41293739e-06, -2.14922438e-06, -5.02850542e-06,  3.85084722e-06,
           5.92809318e-06,

In [None]:
validation_set_labels = model.test(
    flux[test_set], ivar[test_set])

stty: 'standard input': Inappropriate ioctl for device
2025-08-26 16:55:46,716 [DEBUG] Couldn't get screen size. Progressbar may look odd.
2025-08-26 16:55:46,717 [INFO] Running test step on 203 spectra
2025-08-26 16:55:46,717 [INFO] Running test step on 203 spectra


[                                                                                                    ]   0% (1/203)                          



In [None]:
validation_set_labels[2]

({'fvec': array([0.01181772, 0.00861301, 0.00673115, ..., 0.00302824, 0.00210074,
         0.01295808]),
  'nfev': 15,
  'njev': 10,
  'fjac': array([[-1.46398350e+02, -1.33001336e+02, -1.33313012e+02, ...,
          -3.79753322e-04, -5.12376384e-04, -3.20433013e-03],
         [-1.35084545e+02,  7.84138654e+00, -1.61697983e+01, ...,
          -8.45166070e-03, -3.23893770e-02, -8.06177595e-02],
         [-1.35401103e+02, -1.42953330e-02, -6.01079851e+00, ...,
          -8.79135033e-02, -7.28504860e-02,  5.94917399e-02],
         ...,
         [-1.26257087e+02,  7.02640704e-01,  1.48807261e-01, ...,
          -7.52021888e-02, -9.43118232e-02, -2.34982068e-01],
         [-1.22285740e+02,  1.57046066e+00, -1.42615467e+00, ...,
           1.14063212e-01,  1.43700420e-01,  1.11873391e-01],
         [-1.31925011e+02,  6.57315177e-02, -5.51878162e+00, ...,
          -1.57248626e-01, -1.05304369e-01,  4.37723020e-02]]),
  'ipvt': array([ 5,  3,  2,  8,  4,  7,  9,  6, 10,  1], dtype=int32),
  '

In [None]:
# Apply softmax transformation to ensure all labels are positive and sum to 1

# Ensure validation_set_labels is a numpy array
if isinstance(validation_set_labels, tuple):
    validation_set_labels = validation_set_labels[0]

if hasattr(validation_set_labels, 'dtype') and validation_set_labels.dtype.fields is not None:
    # Structured array: convert to regular ndarray
    validation_set_labels = validation_set_labels.view(np.float64).reshape(len(validation_set_labels), -1)

# If you trained on 9 parameters, append a column of zeros for the 10th
if validation_set_labels.shape[1] == 9:
    logits = np.hstack([validation_set_labels, np.zeros((validation_set_labels.shape[0], 1))])
else:
    logits = validation_set_labels

labels_simplex = softmax(logits, axis=1)

# labels_simplex now contains 10 positive values per object that sum to 1

In [None]:
# Plotting: compare predicted and true labels in simplex space

# For the true labels, append a column of zeros if using only 9 parameters
if training_set[test_set].shape[1] == 9:
    true_logits = np.hstack([training_set[test_set], np.zeros((training_set[test_set].shape[0], 1))])
else:
    true_logits = training_set[test_set]

true_labels_simplex = softmax(true_logits, axis=1)

# Define label names for plotting
label_names = [f"label_{i+1}" for i in range(labels_simplex.shape[1])]

for i, label_name in enumerate(label_names):
    # fig, ax = plt.subplots()
    x = true_labels_simplex[:, i]
    y = labels_simplex[:, i]
    abs_diff = np.abs(y - x)
    # ax.scatter(x, y, facecolor="k")

    # limits = np.array([ax.get_xlim(), ax.get_ylim()])
    # ax.set_xlim(limits.min(), limits.max())
    # ax.set_ylim(limits.min(), limits.max())

    # ax.set_title(f"{label_name}: {np.mean(abs_diff):.2f}")
    print(f"{label_name}: {np.mean(abs_diff):.5f}")

label_1: 0.03641
label_2: 0.03711
label_3: 0.00323
label_4: 0.00310
label_5: 0.00408
label_6: 0.00515
label_7: 0.00374
label_8: 0.00606
label_9: 0.00396
label_10: 0.00862


In [None]:
labels_simplex

array([[0.15913446, 0.06617076, 0.11022674, ..., 0.09381309, 0.09254088,
        0.09737515],
       [0.13148326, 0.08484146, 0.09612859, ..., 0.09834795, 0.09355952,
        0.09842919],
       [0.11303218, 0.08169583, 0.09829652, ..., 0.09761165, 0.10205544,
        0.09960294],
       ...,
       [0.14153675, 0.06906771, 0.09777181, ..., 0.09591999, 0.09947179,
        0.10469732],
       [0.13834766, 0.06880946, 0.10157462, ..., 0.09741664, 0.09308973,
        0.10168769],
       [0.10153415, 0.09296831, 0.10839247, ..., 0.09827995, 0.09762404,
        0.09278577]])