In [1]:
%conda install -yq pytorch-cpu -c pytorch

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.
Channels:
 - pytorch
 - conda-forge
Platform: linux-64
Collecting package metadata (repodata.json): ...working... done
Solving environment: ...working... done

## Package Plan ##

  environment location: /opt/conda/envs/notebook

  added / updated specs:
    - pytorch-cpu


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    ca-certificates-2024.12.14 |       hbcca054_0         153 KB  conda-forge
    filelock-3.16.1            |     pyhd8ed1ab_1          17 KB  conda-forge
    gmpy2-2.1.5                |  py311h0f6cedb_3         198 KB  conda-forge
    libtorch-2.5.1             |cpu_generic_h1b269f6_6        51.0 MB  conda-forge
    mpc-1.3.1                  |       h24ddda3_1         114 KB  conda-forge
    mpmath-1.3.0               |

In [1]:
"""This notebook trains and evaluates a convolutional neural network for wavelet analysis.
"""

import os
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from copy import deepcopy

from model import ConvNet

In [2]:
cols = ["ebs", "exo", "flares", "rot"]
cat = pd.DataFrame([], columns=["TIC", "sector", *cols, "wavelet"])
cat.index.name = "filename"

for i, col in enumerate(cols):
    for s in os.listdir(f"wavelets/tess-{col}"):
        if s not in cat.index:
            wav = np.load(f"wavelets/tess-{col}/{s}")
            cat.loc[s] = [int(s[24:40]), int(s[19:23]), *([0]*len(cols)), wav/wav.max()]
        cat.loc[s, col] = 1

cat = cat.reset_index().set_index(["TIC", "sector"]).sort_index()
cat[cols] = cat[cols].div(cat[cols].sum(axis=1), axis=0)

In [None]:
# construct list of values "{tic}_{s}" for each sector of each TIC ID
# unused for now but this might be useful later.
# ebs = pd.read_csv("catalogs/tess-ebs.csv")
# ebs["sectors"] = ebs["sectors"].apply(lambda x: list(map(int, x.strip("[]").split(", "))))
# ids = [f"{row['ID']}_{s}" for _, row in ebs.iterrows() for s in row['sectors']]
# len(ids)

In [3]:
# Set global parameters based on command line input
RUN_NUMBER = 0 # which channels selection to use
BATCH_SIZE = 100
PATIENCE = 30
MAX_EPOCHS = 500
RUN_NAME = "batchnorm1"

MODEL_NAME = f"{RUN_NAME}_{RUN_NUMBER}"

# Channel configurations
channels = {
    0: [8, 16, 32],
    1: [16, 32, 64],
    2: [32, 64, 128],
    3: [64, 128, 256]
}

# Create output directories
for folder in ['models', 'plots', 'losses', 'output']:
    os.makedirs(os.path.join(MODEL_NAME, folder), exist_ok=True)

selected_channels = channels[RUN_NUMBER]
loss_function = torch.nn.CrossEntropyLoss()

In [4]:
class WaveletDataset(Dataset):
    """Custom dataset class for loading wavelet data from files.

    This class is responsible for loading wavelet data and corresponding labels
    from the specified file paths. It supports splitting the data into training,
    validation, and test sets.

    Attributes:
        data_frame (np.ndarray): Array containing the wavelet data.
        labels (np.ndarray): Array containing the scaled labels.
    """    
    def __init__(self, data, mode, split=[0.8, 0.1, 0.1]):
        """
        Initialize the dataset.
        
        Args:
            data (DataFrame): the input and output data.
            mode (str): Mode to load ('train', 'val', or 'test').
            split (list): train, validation, and test split fractions.
        """
        ftrain, fval, ftest = [s/sum(split) for s in split]
        i_train = int(ftrain * len(data))
        i_val = int((ftrain + fval) * len(data))
        
        if mode == "train":
            df = data.iloc[:i_train]
        elif mode == "val":
            df = data.iloc[i_train:i_val]
        elif mode == "test":
            df = data.iloc[i_val:]
            
        self.features = df["wavelet"].values
        self.labels = df[["ebs", "exo", "flares", "rot"]].values

    def __len__(self):
        """Return the length of the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Retrieve a single sample and its corresponding label.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: A tuple containing the sample data (torch.Tensor) and the 
                   corresponding label (torch.Tensor).
        """
        X = self.features[idx].astype('float32')
        X = torch.unsqueeze(torch.tensor(X), 0)
        label = torch.tensor(self.labels[idx].astype('float32'))
        return X, label
        

def train(model, device, train_loader, val_loader, patience, max_epochs, model_name="cnn"):
    """Train the neural network for the specified number of epochs.

    This function orchestrates the training loop, updating model weights based on
    the training data, validating the model on a validation set, and handling
    early stopping based on validation loss.

    Args:
        model (torch.nn.Module): The neural network model to be trained.
        device (torch.device): The device (CPU or GPU) on which to perform training.
        train_loader (DataLoader): DataLoader for the training dataset.
        val_loader (DataLoader): DataLoader for the validation dataset.
        patience (int): Early stopping patience.
        max_epochs (int): Maximum number of training iterations.
        model_name (str): The name of the model, used for saving the trained weights.

    Returns:
        tuple: A tuple containing the best model weights, training losses, and validation losses.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.7, patience=3)
    
    train_losses, val_losses = [], []
    min_loss = float('inf')
    early_stopping_count = 0
    best_weights = None

    for epoch in range(1, max_epochs + 1):
        train_loss = train_epoch(model, device, train_loader, optimizer, epoch)
        val_loss = test(model, device, val_loader, epoch, MODEL_NAME, mode="Validation")
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        scheduler.step(val_loss)

        if val_loss < min_loss:
            min_loss = val_loss
            early_stopping_count = 0
            best_weights = deepcopy(model.state_dict())
        else:
            early_stopping_count += 1
            print(f"        Early Stopping count: {early_stopping_count}/{patience}")
            if early_stopping_count == patience:
                print(f"\nEarly Stopping. Best Epoch: {epoch - patience} with loss {min_loss:.4f}.")
                break

    torch.save(best_weights, f"{model_name}/models/{model_name}.pt")
    return best_weights, train_losses, val_losses


def train_epoch(model, device, train_loader, optimizer, epoch):
    """Train the model for one epoch.

    This function processes each batch of training data, computes the loss,
    and updates the model weights accordingly.

    Args:
        model (torch.nn.Module): The neural network model to be trained.
        device (torch.device): The device (CPU or GPU) for training.
        train_loader (DataLoader): DataLoader for the training dataset.
        optimizer (torch.optim.Optimizer): The optimizer used for weight updates.
        epoch (int): The current epoch number.

    Returns:
        float: The average loss for the epoch.
    """
    model.train()
    losses = []

    ndata = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        ndata += len(data)
        #if (batch_idx * len(data)) % 10000 == 0:
        print(
            f'Train Epoch: {epoch:3d} [{ndata:6d}/{len(train_loader.dataset)}'
            f' ({100*ndata/len(train_loader.dataset):3.0f}%)]\tLoss: {losses[-1]:.6f}',
            end="\r")

    return np.mean(losses)


def test(model, device, test_loader, epoch=None, model_name=None, mode="Validation"):
    """Evaluate the model on the test set.

    This function assesses the model's performance on a specified dataset
    and computes the average loss. It can also generate a plot of predictions
    versus true values.

    Args:
        model (torch.nn.Module): The neural network model to be evaluated.
        device (torch.device): The device (CPU or GPU) for evaluation.
        test_loader (DataLoader): DataLoader for the test dataset.
        epoch (int, optional): The current epoch number (for labeling purposes).
        model_name (str, optional): The name of the model (for labeling purposes).
        mode (str, optional): Indicates whether the evaluation is for training, validation, or testing.

    Returns:
        float: The average loss on the test set.
    """
    model.eval()
    test_loss = 0
    targets, preds = [], []
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            targets.extend(target.cpu().numpy())
            preds.extend(output.cpu().numpy())
            test_loss += loss_function(output, target).item()

    test_loss /= len(test_loader)
    print(f'\n                     Average {mode} Loss: {test_loss:.4f}')

    if mode.lower() == "test":
        return np.squeeze(preds), np.squeeze(targets), test_loss
    return test_loss

In [5]:
# Create Datasets
np.random.seed(42)
mycat = cat.sample(frac=1)
train_loader = DataLoader(WaveletDataset(mycat, "train"), batch_size=BATCH_SIZE)
val_loader = DataLoader(WaveletDataset(mycat, "val"), batch_size=BATCH_SIZE)

device = torch.device("cpu")
model = ConvNet(channels[RUN_NUMBER], k=3).to(device)

# Training
weights, train_losses, val_losses = train(model, device, train_loader, val_loader, PATIENCE, MAX_EPOCHS, MODEL_NAME)

# Evaluate best-fit model
model.load_state_dict(weights)
print("\nFinal Performance!")
test(model, device, train_loader, mode="Training")
test(model, device, val_loader, mode="Validation")

# Plot learning curve
plt.figure()
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss', alpha=0.5)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.savefig(f"{MODEL_NAME}/plots/{MODEL_NAME}_performance.png")
plt.close()

                     Average Validation Loss: 1.3629
                     Average Validation Loss: 1.3080
                     Average Validation Loss: 1.2868
                     Average Validation Loss: 1.2695
                     Average Validation Loss: 1.2561
                     Average Validation Loss: 1.2430
                     Average Validation Loss: 1.2325
                     Average Validation Loss: 1.2246
                     Average Validation Loss: 1.2150
                     Average Validation Loss: 1.2076
                     Average Validation Loss: 1.2011
                     Average Validation Loss: 1.1948
                     Average Validation Loss: 1.1897
                     Average Validation Loss: 1.1839
                     Average Validation Loss: 1.1784
                     Average Validation Loss: 1.1752
                     Average Validation Loss: 1.1711
                     Average Validation Loss: 1.1670
                     Average Validation Loss: 

In [6]:
# Predictions on Test set
print('\nPrediction on Test set')
test_loader = DataLoader(WaveletDataset(mycat, "test"), batch_size=BATCH_SIZE)
preds, labels, loss = test(model, device, test_loader, mode="Test")


Prediction on Test set

                     Average Test Loss: 0.9887


In [7]:
output = mycat.iloc[-len(labels):].drop(columns="wavelet")
output = output.rename(columns={c: c+"_true" for c in cols})
output[[c+"_pred" for c in cols]] = preds
output.to_csv(f"{MODEL_NAME}/output/{MODEL_NAME}_output.csv")
output

Unnamed: 0_level_0,Unnamed: 1_level_0,filename,ebs_true,exo_true,flares_true,rot_true,ebs_pred,exo_pred,flares_pred,rot_pred
TIC,sector,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
366972961,25,tess2020133194932-s0025-0000000366972961-0182-...,1.0,0.0,0.0,0.0,0.992689,0.004412,0.001550,0.001349
349156098,31,tess2020294194027-s0031-0000000349156098-0198-...,1.0,0.0,0.0,0.0,0.890435,0.087259,0.013445,0.008861
139804406,1,tess2018206045859-s0001-0000000139804406-0120-...,0.0,0.0,1.0,0.0,0.007339,0.005922,0.926536,0.060204
237913194,28,tess2020212050318-s0028-0000000237913194-0190-...,0.0,1.0,0.0,0.0,0.004000,0.993885,0.001446,0.000668
238123653,7,tess2019006130736-s0007-0000000238123653-0131-...,0.0,0.0,0.0,1.0,0.008049,0.006662,0.044507,0.940781
...,...,...,...,...,...,...,...,...,...,...
264461976,32,tess2020324010417-s0032-0000000264461976-0200-...,0.0,0.0,0.0,1.0,0.166463,0.006412,0.036199,0.790926
339960875,7,tess2019006130736-s0007-0000000339960875-0131-...,0.0,0.0,0.0,1.0,0.009004,0.001515,0.030555,0.958926
343173162,24,tess2020106103520-s0024-0000000343173162-0180-...,0.0,0.0,0.0,1.0,0.000272,0.000790,0.001234,0.997703
350073391,26,tess2020160202036-s0026-0000000350073391-0188-...,0.0,0.0,0.0,1.0,0.008206,0.006810,0.018824,0.966161
