### Import libraries

In [None]:
import pandas as pd
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch
from models import *

### Random, Torch and Cuda settings

In [None]:
# Fixed seeds so that repeated evaluations are reproducible.
seed, random_seed = 50, 40

# Ask cuDNN for deterministic kernels and turn off its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Seed the torch (CPU and all CUDA devices) and NumPy generators.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Define test function

In [None]:
def test(model, dataloader, criterion, device, X_test):
    """Evaluate ``model`` on ``dataloader`` and collect per-sample results.

    Sample identifiers are looked up by position in ``X_test.index``: the
    i-th sample yielded by ``dataloader`` is matched with the i-th row of
    ``X_test``. The loader must therefore iterate in dataset order
    (``shuffle=False``).

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(data, return_conv=True)`` and expected to return
        ``(output, conv_out)``.
    dataloader : iterable
        Yields ``(data, target)`` batches of tensors.
    criterion : callable
        Loss function applied as ``criterion(output, target)``.
    device : torch.device
        Device the batches are moved to before the forward pass.
    X_test : pandas.DataFrame
        Frame whose index holds the identifier of every sample, in the same
        order as the loader yields them.

    Returns
    -------
    tuple
        ``(epoch_loss, epoch_acc, all_sobject_ids, true_labels,
        predicted_labels, predicted_probs, conv_outputs)``.
        ``epoch_loss`` is the mean of the per-batch losses and ``epoch_acc``
        the fraction of correctly classified samples; both are NaN when the
        loader yields no batches. The next four entries are flat per-sample
        lists, and ``conv_outputs`` holds one array per batch.

    Raises
    ------
    IndexError
        If the loader yields more samples than ``X_test`` has rows.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    all_sobject_ids = []
    true_labels = []
    predicted_labels = []
    predicted_probs = []
    conv_outputs = []

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output, conv_out = model(data, return_conv=True)
            loss = criterion(output, target)

            running_loss += loss.item()
            num_batches += 1
            _, predicted = output.max(1)
            correct += predicted.eq(target).sum().item()

            # Positions come from a running sample count rather than
            # batch_idx * batch_size, so loaders built with a batch_sampler
            # (batch_size is None) or with uneven batches also work.
            batch_len = target.size(0)
            positions = np.arange(total, total + batch_len)
            all_sobject_ids.extend(X_test.index[positions].tolist())
            total += batch_len

            true_labels.extend(target.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())
            predicted_probs.extend(F.softmax(output, dim=1).cpu().numpy())
            conv_outputs.append(conv_out.cpu().numpy())

    # An empty loader has no defined loss or accuracy; report NaN instead of
    # raising ZeroDivisionError.
    epoch_loss = running_loss / num_batches if num_batches else float("nan")
    epoch_acc = correct / total if total else float("nan")

    return epoch_loss, epoch_acc, all_sobject_ids, true_labels, predicted_labels, predicted_probs, conv_outputs


# The name starts with "test", so pytest would try to collect this helper as a
# test case if it were ever imported into a test module; opt out explicitly.
test.__test__ = False

### Load criterion and optimizer

In [None]:
# Loss used to report the test loss of each saved model.
criterion = nn.CrossEntropyLoss()
# No optimizer is created here: this notebook only evaluates saved models, and
# `model` does not exist yet at this point (it is built inside the run loop),
# so constructing one from model.parameters() would raise a NameError.

### Load data
The .csv file should contain two columns: `sobject_id` and `label` (1 for binary, 0 for single). The .h5 file should contain the spectra converted from .fits files, as shown in `fits_2_h5.ipynb`.

In [None]:
# Table of labels (one row per sobject_id).
labels_df = pd.read_csv(r"path-to-labels.csv")

# Open the store read-only: the default append mode would silently create an
# empty file when the path is wrong. The context manager closes the store even
# if one of the keys is missing.
with pd.HDFStore('path-to-file.h5', mode='r') as store:
    X_test = store['X_test']
    # Flatten the stored labels into a 1-D array.
    y_test = store['y_test'].values.flatten()

### Run the model
First set `run_count`, which defines how many runs of the model to evaluate (each run with a different random seed). Then choose the model architecture; all available models are listed in `models.py`. The trained models are loaded from the folder `models_saved`. After all runs have been evaluated, the results are saved to a .csv file.

In [None]:
# PyTorch dataset wrapping the spectra table and its labels
class SpectraDataset(Dataset):
    """In-memory dataset of ``(spectrum, label)`` pairs.

    ``X`` must expose ``.values`` (e.g. a pandas DataFrame). Its values are
    stored as a float32 tensor with an extra axis inserted at position 1
    (presumably the channel axis the CNN models expect; confirm against
    models.py). ``y`` is stored as an int64 tensor.
    """

    def __init__(self, X, y):
        fluxes = torch.tensor(X.values, dtype=torch.float32)
        labels = torch.tensor(y, dtype=torch.long)
        self.X = fluxes.unsqueeze(dim=1)
        self.y = labels

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(spectrum, label)`` pair stored at ``idx``."""
        spectrum = self.X[idx]
        label = self.y[idx]
        return spectrum, label

# Keep the loader in dataset order: test() matches predictions to the rows of
# X_test by position, so shuffling would mislabel the samples.
test_dataset = SpectraDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Filled on the first run; stays empty if run_count is 0.
test_samples_df = pd.DataFrame()

run_count = 5
flux_length = 4096

for run in range(run_count):

    model = CNN_1a(flux_length)
    model = model.to(device)

    # Map the checkpoint onto the active device so that it also loads on
    # CPU-only machines (a hard-coded 'cuda' target fails without a GPU).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(f'models_saved/{model.__class__.__name__[3:]}_run_{run + 1}.pth', map_location=device))
    model.eval()

    test_loss, test_acc, test_sobject_ids, true_labels, predicted_labels, predicted_probs, conv_outputs = test(model, test_loader, criterion, device, X_test)

    # Identifiers and true labels are the same for every run; store them once.
    if run == 0:
        test_samples_df = pd.DataFrame({
            'sobject_id': test_sobject_ids,
            'true_label': true_labels
        })

    # One pair of columns per run: the predicted class and its probability
    # (the largest softmax value of each sample).
    test_samples_df[f'predicted_label_run_{run + 1}'] = predicted_labels
    test_samples_df[f'predicted_prob_run_{run + 1}'] = [max(probs) for probs in predicted_probs]

    print(f"Run {run + 1} completed.")

test_samples_df.to_csv("path-to-file.csv", index=False)

print('FINISHED!')
