Below is just the code for the PyTorch dataset class, which will read the different data points for us in a convenient way. The parts of interest in this code snippet are the lines in `__init__` where the scaler object is loaded from the pickle format, and the definition of `scale_data` at the end of the class.

You do not need to encapsulate the scaling of the data points in such a class; this just serves to illustrate how one can go about loading the scaler and then subsequently using it the right way, even with data that has some columns dropped which the scaler expects to be there.

In [28]:
import torch
import os
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from sklearn.preprocessing import StandardScaler
import joblib #this is used to "unpickle" the scaler

import spacetimeformer as stf

import matplotlib.pyplot as plt

class BatsCSVDataset(torch.utils.data.Dataset):
    """Data loader class for the bats dataset.

    It is written so that it can load data that is generated from the
    Sonobats software and is then processed in a particular way by the
    data/prepare_data.py script.

    The directory ``root_path`` is expected to contain
    ``{prefix}_mapping.csv`` (columns ``Filename`` and ``n_samples``),
    ``{prefix}_config.csv`` (rows ``max_length`` and ``min_length`` in
    ``parameter``/``value`` columns), ``{prefix}_scaler.pkl`` and the feather
    files listed in the mapping. Only the base name of each mapped
    ``Filename`` is used; it is resolved inside ``root_path``.

    Samples are numbered globally across all mapped files. The first
    ``train_split`` fraction belongs to the "train" split, the next
    ``val_split`` fraction to "val" and the following ``test_split`` fraction
    to "test".

    Args:
        root_path: The path to the directory where the data is stored.
        prefix: The prefix of the files that are stored in ``root_path``.
        ignore_cols: A list of columns to ignore when loading the data. A
            single column name given as a string is treated as a one-element
            list.
        target_cols: A list of columns that are the target columns (the
            columns that we want to predict). When empty, every column of the
            first data file except the time column and ``ignore_cols`` is used.
        time_col_name: The name that we want to give to the time column.
        val_split: Fraction of all samples assigned to the "val" split.
        test_split: Fraction of all samples assigned to the "test" split.
        context_points: Number of leading rows of a sample used as context.
        target_points: Number of trailing rows of a sample used as target.
        split: Which split this instance serves: "train", "val" or "test".

    At least one of ``context_points`` / ``target_points`` must be ``None``;
    the missing one is derived so that both add up to ``max_length``.
    """

    def __init__(self,
                 root_path = '/home/vdesai/bats_data/new_dataset/splits',
                 prefix = 'split',
                 ignore_cols = None,
                 target_cols = None,
                 time_col_name = "TimeIndex",
                 val_split = 0.1,
                 test_split = 0.1,
                 context_points = 57,
                 target_points = 1,
                 split = "train"
    ):
        assert root_path is not None, "root_path must be provided"
        assert prefix is not None, "prefix must be provided"

        # Avoid shared mutable defaults, and accept a bare column name so that
        # it is not iterated character by character further down.
        if ignore_cols is None:
            ignore_cols = []
        elif isinstance(ignore_cols, str):
            ignore_cols = [ignore_cols]

        self.root_path = root_path
        self.prefix = prefix

        self.mapping_df = pd.read_csv(f"{root_path}/{prefix}_mapping.csv")
        self.config_df = pd.read_csv(f"{root_path}/{prefix}_config.csv")

        self.max_length = self._config_value("max_length")
        self.min_length = self._config_value("min_length")

        # NOTE(review): the signature defaults set BOTH values (57 and 1), so
        # callers currently have to pass None for one of them explicitly.
        assert context_points is None or target_points is None, (
            "at least one of context_points / target_points must be None; "
            "the other one is derived from max_length"
        )

        if context_points is None and target_points is None:
            context_points = self.max_length - 1
            target_points = 1
        elif context_points is None:
            context_points = self.max_length - target_points
        else:
            target_points = self.max_length - context_points

        self.seq_length = context_points + target_points
        self.time_col_name = time_col_name
        self.ignore_cols = ignore_cols
        self.context_points = context_points
        self.target_points = target_points
        self.split = split
        self.metadata_cols = ["Filename", "Cntxt_sz"]

        self.val_split = val_split
        self.test_split = test_split
        self.train_split = 1 - val_split - test_split

        assert self.train_split > 0, "val_split + test_split must be below 1"

        self.run_sanity_check()

        # Exclusive running total: the global index of each file's first sample.
        self.mapping_df["cumulative_count"] = (
            self.mapping_df["n_samples"].cumsum() - self.mapping_df["n_samples"]
        )

        self.total_chirps = self.mapping_df["n_samples"].sum()
        self.train_chirps = int(self.total_chirps * self.train_split)
        self.val_chirps = int(self.total_chirps * self.val_split)
        self.test_chirps = int(self.total_chirps * self.test_split)

        ## Loading the scaler
        # NOTE: joblib.load unpickles the file; only point it at trusted data.
        self.scaler = joblib.load(f"{root_path}/{prefix}_scaler.pkl")
        self.scaler.set_output(transform = "pandas")
        self.scaler_cols = list(self.scaler.get_feature_names_out())

        if not target_cols:
            # Default targets: everything in the first data file except the
            # time column and the ignored columns.
            first_file = self._resolve_path(self.mapping_df.iloc[0]["Filename"])
            target_cols = [
                col for col in pd.read_feather(first_file).columns.tolist()
                if col != time_col_name and col not in ignore_cols
            ]

        self.target_cols = target_cols
        # Per-file cache filled lazily by get_file_id_to_samples.
        self.file_id_to_samples = {}

    def _config_value(self, parameter):
        """Return the first ``value`` stored for ``parameter`` in the config CSV."""
        rows = self.config_df[self.config_df.parameter == parameter]
        return rows["value"].values[0]

    def _resolve_path(self, filename):
        """Map a ``Filename`` entry of the mapping CSV to a path inside ``root_path``."""
        return os.path.join(self.root_path, filename.split("/")[-1])

    def _samples_per_file_id(self, df):
        """Return one row per ``file_id`` with its largest ``chirp_idx`` and sample count."""
        per_file = df.groupby("file_id")["chirp_idx"].max().reset_index()
        per_file["n_samples"] = per_file["chirp_idx"] - self.min_length + 2
        return per_file

    def run_sanity_check(self):
        """Verify that the mapping CSV agrees with the data files on disk.

        Raises:
            AssertionError: if a mapped file is missing, or if a file's
                ``n_samples`` entry differs from the count derived from its
                contents.
        """
        #check that every file in the mapping df actually exists
        for filename in self.mapping_df["Filename"]:
            path = self._resolve_path(filename)
            assert os.path.exists(path), f"missing data file: {path}"

        #check that the count in the mapping df actually is equal to the number of rows in the file
        for filename, n_samples in zip(self.mapping_df["Filename"], self.mapping_df["n_samples"]):
            df = pd.read_feather(self._resolve_path(filename))
            total_samples = self._samples_per_file_id(df)["n_samples"].sum()
            assert n_samples == total_samples, (
                f"{filename}: mapping lists {n_samples} samples, file yields {total_samples}"
            )

    def __len__(self):
        """Return the number of samples in the active split (KeyError if the split is unknown)."""
        return {
            "train": self.train_chirps,
            "val": self.val_chirps,
            "test": self.test_chirps
        }[self.split]

    #add one more dimension to the tensor if only 1 dimensional
    def _torch(self, *dfs):
        """Convert pandas objects to float tensors; 1-D inputs become column vectors."""
        return tuple(
            torch.from_numpy(x.values).float().unsqueeze(1) if len(x.shape) == 1
            else torch.from_numpy(x.values).float()
            for x in dfs
        )

    def make_len(self, df, seq_len):
        """Return ``df`` front-padded with zero rows to ``seq_len`` rows.

        The time column is (over)written with the standardised positions
        ``0 .. seq_len - 1``. The caller's frame is never modified. A frame
        longer than ``seq_len`` is not truncated, so the time-column
        assignment then fails on the length mismatch.
        """
        n_missing = seq_len - len(df)
        if n_missing > 0:
            #pad with rows containing zeros to make df of length seq_len
            padding = pd.DataFrame(np.zeros((n_missing, len(df.columns))), columns = df.columns)
            df = pd.concat([padding, df], axis = 0)
        else:
            # Work on a copy so the time column is not written into the input.
            df = df.copy()

        df[self.time_col_name] = StandardScaler().fit_transform(np.arange(seq_len).reshape(-1,1))
        return df

    def get_file_id_to_samples(self, df, filename):
        """Return (and cache per ``filename``) the per-``file_id`` sample table of ``df``.

        On top of ``n_samples`` the table carries ``cum_samples``, the
        exclusive running total of ``n_samples``, i.e. the offset of each
        ``file_id``'s first sample within the file.
        """
        cached = self.file_id_to_samples.get(filename)
        if cached is not None:
            return cached

        per_file = self._samples_per_file_id(df)
        per_file["cum_samples"] = per_file["n_samples"].cumsum() - per_file["n_samples"]
        self.file_id_to_samples[filename] = per_file
        return per_file

    def __getitem__(self, idx):
        """Return sample ``idx`` of the active split as a padded DataFrame.

        Negative indices count from the end of the split.

        Raises:
            IndexError: if ``idx`` lies outside the active split. Without it,
                plain iteration over the dataset would never stop and
                out-of-range indices would read samples of another split.
        """
        n_items = len(self)
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            raise IndexError(f"index out of range for split '{self.split}' of size {n_items}")

        # Translate the split-local index into a global sample index.
        if self.split == "val":
            idx += self.train_chirps
        elif self.split == "test":
            idx += self.train_chirps + self.val_chirps

        # Last mapped file whose first global sample index is <= idx.
        split_to_use = self.mapping_df[self.mapping_df["cumulative_count"] <= idx].iloc[-1]
        filename = split_to_use["Filename"]
        sample_idx = idx - split_to_use["cumulative_count"]
        df = pd.read_feather(self._resolve_path(filename))

        # Same lookup one level down: the file_id containing sample_idx.
        file_id_to_samples = self.get_file_id_to_samples(df, filename)
        file_id_to_use_ = file_id_to_samples[file_id_to_samples["cum_samples"] <= sample_idx].iloc[-1]
        file_id_to_use = file_id_to_use_["file_id"]
        # Number of trailing rows dropped from this file_id's sequence.
        chirps_to_use = sample_idx - file_id_to_use_["cum_samples"]
        df_slice = df[df.file_id == file_id_to_use].copy()

        if self.ignore_cols:
            df_slice.drop(columns=self.ignore_cols, inplace=True, errors = 'ignore')

        if chirps_to_use > 0:
            df_slice = df_slice.iloc[:-chirps_to_use]

        return self.make_len(df_slice, self.seq_length)

    def _context_target_tensors(self, series_slice):
        """Split a padded sample into ``(ctxt_x, ctxt_y, trgt_x, trgt_y)`` tensors.

        This is the conversion that used to sit, unreachable, behind the
        return statement of ``__getitem__``; ``__getitem__`` itself keeps
        returning the DataFrame.
        """
        ctxt_slice = series_slice.iloc[: self.context_points]
        trgt_slice = series_slice.iloc[self.context_points :]

        return self._torch(
            ctxt_slice[self.time_col_name],
            ctxt_slice[self.target_cols],
            trgt_slice[self.time_col_name],
            trgt_slice[self.target_cols],
        )

    #Function which scales the data back into the original space
    def scale_data(self, data):
        """Map scaled ``data`` back into the original feature space.

        ``data`` (a DataFrame) is first laid out on the full column set the
        scaler was fitted on; columns that are absent from ``data`` are filled
        with NaN for the call. The result keeps only the columns of ``data``
        that the scaler knows, in ``data``'s column order.
        """
        full_width = pd.DataFrame(data, columns = self.scaler_cols)
        restored = pd.DataFrame(
            self.scaler.inverse_transform(full_width, copy = True),
            columns = self.scaler_cols,
        )
        return restored[[c for c in data.columns if c in self.scaler_cols]]

Here, just as an example, let us set `ignore_cols` to a list containing the string `'PrecedingIntrvl'`. This allows us to check the functionality when we are trying to rescale a data point back into the original space while some of the columns are missing from it. Below, we just instantiate a basic `BatsCSVDataset`. There is actually no need for it to inherit from `torch.utils.data.Dataset` here; it can just be its own standalone class.

In [29]:
# Derive the context length from max_length (context_points=None) and drop
# 'PrecedingIntrvl' so that scale_data is exercised with a missing column.
bats_dataset = BatsCSVDataset(
    root_path='/home/vdesai/data/training_data/daytime/splits/',
    context_points=None,
    target_points=1,
    ignore_cols=['PrecedingIntrvl'],
)

In [30]:
# Take the first sample of the dataset and invert the scaling on it.
bats_dataset.scale_data(bats_dataset[0])

Unnamed: 0,TimeInFile,CallsPerSec,CallDuration,Fc,HiFreq,LowFreq,Bndwdth,FreqMaxPwr,PrcntMaxAmpDur,TimeFromMaxToFc,...,PreFc500Residue,PreFc1000Residue,PreFc3000,PreFc3000Residue,KneeToFcResidue,Kn-FcCurviness,meanKn-FcCurviness,Kn-FcCurvinessTrndSlp,MinAccpQuality,Max#CallsConsidered
0,2456.85532,10.835649,5.073571,26.423852,43.774069,24.695117,19.078952,29.228546,62.112623,1.513489,...,0.015013,0.045651,-3.901869,0.140392,0.141357,0.071402,0.001445,-0.004776,0.8,32.0
1,2456.85532,10.835649,5.073571,26.423852,43.774069,24.695117,19.078952,29.228546,62.112623,1.513489,...,0.015013,0.045651,-3.901869,0.140392,0.141357,0.071402,0.001445,-0.004776,0.8,32.0
2,2456.85532,10.835649,5.073571,26.423852,43.774069,24.695117,19.078952,29.228546,62.112623,1.513489,...,0.015013,0.045651,-3.901869,0.140392,0.141357,0.071402,0.001445,-0.004776,0.8,32.0
3,2456.85532,10.835649,5.073571,26.423852,43.774069,24.695117,19.078952,29.228546,62.112623,1.513489,...,0.015013,0.045651,-3.901869,0.140392,0.141357,0.071402,0.001445,-0.004776,0.8,32.0
4,2456.85532,10.835649,5.073571,26.423852,43.774069,24.695117,19.078952,29.228546,62.112623,1.513489,...,0.015013,0.045651,-3.901869,0.140392,0.141357,0.071402,0.001445,-0.004776,0.8,32.0
5,2456.85532,10.835649,5.073571,26.423852,43.774069,24.695117,19.078952,29.228546,62.112623,1.513489,...,0.015013,0.045651,-3.901869,0.140392,0.141357,0.071402,0.001445,-0.004776,0.8,32.0
6,2456.85532,10.835649,5.073571,26.423852,43.774069,24.695117,19.078952,29.228546,62.112623,1.513489,...,0.015013,0.045651,-3.901869,0.140392,0.141357,0.071402,0.001445,-0.004776,0.8,32.0
7,2456.85532,10.835649,5.073571,26.423852,43.774069,24.695117,19.078952,29.228546,62.112623,1.513489,...,0.015013,0.045651,-3.901869,0.140392,0.141357,0.071402,0.001445,-0.004776,0.8,32.0
8,2456.85532,10.835649,5.073571,26.423852,43.774069,24.695117,19.078952,29.228546,62.112623,1.513489,...,0.015013,0.045651,-3.901869,0.140392,0.141357,0.071402,0.001445,-0.004776,0.8,32.0
9,2456.85532,10.835649,5.073571,26.423852,43.774069,24.695117,19.078952,29.228546,62.112623,1.513489,...,0.015013,0.045651,-3.901869,0.140392,0.141357,0.071402,0.001445,-0.004776,0.8,32.0


The first few rows of this dataframe are identical, which is because this dataframe was padded with zeros before it was scaled back. To just get the values for the chirps which were originally there in the files, one could try not padding tht