### Import libraries

In [None]:
import pandas as pd
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch
from models import *

### Random seeds, Torch and CUDA settings

In [None]:
# Reproducibility: `seed` feeds NumPy's global RNG and Torch's CPU and CUDA
# generators. NOTE(review): `random_seed` is not referenced by any other cell
# shown here — confirm whether it is still needed.
seed, random_seed = 50, 40
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# Force deterministic cuDNN kernels and disable cuDNN autotuning, which can
# otherwise choose different algorithms from one execution to the next.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Define train and test functions

In [None]:
def train(model, dataloader, criterion, optimizer, device):
    """Run one optimisation epoch over ``dataloader``.

    Puts ``model`` in training mode. For every batch, the batch is moved to
    ``device``, a forward pass is run, ``criterion`` is back-propagated and one
    ``optimizer`` step is taken.

    Args:
        model: module to train (its parameters are updated in place).
        dataloader: iterable of ``(inputs, targets)`` batches that supports ``len``.
        criterion: callable ``criterion(outputs, targets)`` returning a scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``. ``epoch_loss`` is the unweighted mean of
        the per-batch losses. ``epoch_acc`` is the fraction of samples whose
        arg-max output (over dim 1) equals the target.

    Raises:
        ZeroDivisionError: if the loader yields no batches.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        guesses = logits.max(dim=1).indices
        n_seen += labels.size(0)
        n_correct += guesses.eq(labels).sum().item()

    return loss_sum / len(dataloader), n_correct / n_seen

def test(model, dataloader, criterion, device, X_test):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    all_sobject_ids  = [] 
    true_labels = []
    predicted_labels = []
    predicted_probs = []

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            running_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            all_indices = (batch_idx * dataloader.batch_size) + np.arange(target.size(0))
            all_sobject_ids.extend(X_test.iloc[all_indices].index.tolist())
            true_labels.extend(target.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())
            predicted_probs.extend(F.softmax(output, dim=1).cpu().numpy())

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = correct / total

    return epoch_loss, epoch_acc, all_sobject_ids, true_labels, predicted_labels, predicted_probs

### Load data
The .h5 file should contain the spectra converted from .fits files as shown in fits_2_h5.ipynb, stored under the keys `X_train`, `X_val`, `X_test`, `y_train`, `y_val` and `y_test`.

In [None]:
# Load the pre-split features and labels.
# - Read-only mode makes a wrong path fail immediately. The default mode 'a'
#   would silently create a new, empty HDF5 file instead.
# - The context manager closes the store even if one of the keys is missing.
with pd.HDFStore('path-to-dataset.h5', mode='r') as store:
    X_train = store['X_train']
    X_val = store['X_val']
    X_test = store['X_test']
    # Labels are read as pandas objects and flattened to 1-D NumPy arrays.
    y_train = store['y_train'].values.flatten()
    y_val = store['y_val'].values.flatten()
    y_test = store['y_test'].values.flatten()

# PyTorch dataset wrapping one split's features and labels as tensors.
class SpectraDataset(Dataset):
    """Map-style dataset of spectra and their integer class labels.

    Args:
        X: feature table with one spectrum per row. This notebook passes a
            pandas DataFrame; any array-like accepted by ``np.asarray`` works.
        y: 1-D array-like of integer class labels, one per row of ``X``.

    Each item is ``(spectrum, label)``:
    - ``spectrum`` is float32 with a singleton axis inserted at position 1,
      so a row of length L becomes shape ``(1, L)``.
    - ``label`` is a 0-d int64 tensor.

    Raises:
        ValueError: if ``X`` and ``y`` have different numbers of rows.
    """

    def __init__(self, X, y):
        features = np.asarray(X)
        labels = np.asarray(y)
        # Catch misaligned inputs early instead of failing (or silently
        # dropping rows) later during indexing.
        if len(features) != len(labels):
            raise ValueError(
                f"X has {len(features)} rows but y has {len(labels)} labels"
            )
        # (n, L) -> (n, 1, L): a leading channel axis per spectrum.
        self.X = torch.tensor(features, dtype=torch.float32).unsqueeze(1)
        self.y = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# Wrap every split in a SpectraDataset (built in train, val, test order).
train_dataset, val_dataset, test_dataset = (
    SpectraDataset(features, labels)
    for features, labels in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

# Batches of 32. Only the training loader reshuffles each epoch; validation
# and test are iterated in their stored order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=32, shuffle=False)
    for split in (val_dataset, test_dataset)
)

### Train the model

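First we define three settings:
- `n_epochs`: the maximum number of epochs per run.
- `patience`: the number of consecutive epochs without an improvement in validation loss after which training stops early.
- `run_count`: the number of times the model is trained from scratch.

The random seeds are set only once (above), so the runs do not each get their own seed. They differ because each run continues from the random state left by the previous one.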
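The output files are:
- `train/val/test_results_{modelname}.csv`: for every sample, its sobject_id and true label, plus, for each run, the predicted label and its softmax probability (the largest class probability).
- `{modelname}_train-val.csv`: the per-epoch loss and accuracy on the training and validation sets.
- `{modelname}_testAcc.csv`: the test accuracy and the number of epochs trained for each run.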

In [None]:
import copy  # snapshots of the best-so-far weights

n_epochs = 1000     # upper bound on training epochs per run
patience = 5        # consecutive epochs without val-loss improvement before stopping
run_count = 5       # number of independent training runs
flux_length = 4096  # spectrum length passed to the model constructor

run_data = {
    'run': [],
    'num_epochs': [],
    'test_acc': []
}

train_samples_df = None
val_samples_df = None
test_samples_df = None

# Per-epoch history of every run; concatenated once after the loop.
history_frames = []

# `test` labels predictions with sobject_ids looked up by position. That is
# only valid for a loader that yields samples in dataset order, but
# `train_loader` shuffles. Evaluate the training split through a fixed-order
# loader instead, which also keeps train_samples_df rows aligned across runs.
train_eval_loader = DataLoader(train_dataset, batch_size=train_loader.batch_size, shuffle=False)

# torch.save does not create missing directories.
os.makedirs('models_saved', exist_ok=True)

for run in range(run_count):
    model = CNN_1a(flux_length).to(device)
    # Class name without its first three characters; used in every output file name.
    model_name = model.__class__.__name__[3:]

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_losses, train_accuracies = [], []
    val_losses, val_accuracies = [], []

    best_val_loss = float('inf')
    best_state = copy.deepcopy(model.state_dict())
    epochs_no_improve = 0

    for epoch in range(n_epochs):
        print('Run: ', run, '   Epoch: ', epoch, end='\r')
        train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, _, _, _, _ = test(model, val_loader, criterion, device, X_val)

        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                break

    # Early stopping: evaluate and save the weights from the epoch with the
    # lowest validation loss, not the last epoch. By the time training stops,
    # the last `patience` epochs have not improved on the best one.
    model.load_state_dict(best_state)

    _, _, train_sobject_ids, train_true_labels, train_predicted_labels, train_predicted_probs = test(
        model, train_eval_loader, criterion, device, X_train)
    _, _, val_sobject_ids, val_true_labels, val_predicted_labels, val_predicted_probs = test(
        model, val_loader, criterion, device, X_val)
    _, test_acc, test_sobject_ids, true_labels, predicted_labels, predicted_probs = test(
        model, test_loader, criterion, device, X_test)

    # Ids and true labels come out in the same (dataset) order on every run,
    # so they are recorded once; later runs only add prediction columns.
    if run == 0:
        train_samples_df = pd.DataFrame({'sobject_id': train_sobject_ids, 'true_label': train_true_labels})
        val_samples_df = pd.DataFrame({'sobject_id': val_sobject_ids, 'true_label': val_true_labels})
        test_samples_df = pd.DataFrame({'sobject_id': test_sobject_ids, 'true_label': true_labels})

    # Per run: predicted class and its (maximum) softmax probability.
    train_samples_df[f'predicted_label_run_{run + 1}'] = train_predicted_labels
    train_samples_df[f'predicted_prob_run_{run + 1}'] = [max(probs) for probs in train_predicted_probs]

    val_samples_df[f'predicted_label_run_{run + 1}'] = val_predicted_labels
    val_samples_df[f'predicted_prob_run_{run + 1}'] = [max(probs) for probs in val_predicted_probs]

    test_samples_df[f'predicted_label_run_{run + 1}'] = predicted_labels
    test_samples_df[f'predicted_prob_run_{run + 1}'] = [max(probs) for probs in predicted_probs]

    run_data['run'].append(run + 1)
    # Epochs actually trained (equals epoch + 1; also correct when n_epochs == 0).
    run_data['num_epochs'].append(len(train_losses))
    run_data['test_acc'].append(test_acc)

    history_frames.append(pd.DataFrame({
        'run': [run + 1] * len(train_losses),
        'epoch': list(range(1, len(train_losses) + 1)),
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    }))
    torch.save(model.state_dict(), f'models_saved/{model_name}_run_{run + 1}.pth')

df_run = pd.DataFrame(run_data)
all_data = pd.concat(history_frames, ignore_index=True) if history_frames else pd.DataFrame()

train_samples_df.to_csv(f"train_results_{model_name}.csv", index=False)
val_samples_df.to_csv(f"val_results_{model_name}.csv", index=False)
test_samples_df.to_csv(f"test_results_{model_name}.csv", index=False)
all_data.to_csv(f'{model_name}_train-val.csv', index=False)
df_run.to_csv(f'{model_name}_testAcc.csv', index=False)

print('FINISHED!')