# PyTorch model

In [1]:
import os
import torch
import sys
from torch import nn
# import lightning as L
import pandas as pd
import xarray as xr
import csv
# Put the parent of the current working directory on sys.path.
# Presumably this is what makes the `simspice` package importable below — confirm.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(parent_dir)

In [2]:
import numpy as np
import matplotlib.pyplot as plt
# from brokenaxes import brokenaxes
import warnings 
import tqdm
# Silence every warning for the rest of the session (no category or module filter is given).
warnings.filterwarnings("ignore")

In [3]:
from simspice.data.SproutDataset_NeMg import SproutDataset
from torch.utils.data import DataLoader
from simspice.utils.Augmentation import Augmentation
from simspice.data.Sprout_ML_NeMg import Sprout_ML, interpolate_arrays

In [4]:
import wandb

In [5]:
# Sanity-check the CUDA setup: availability, number of visible GPUs, and the name of device 0.
print(torch.cuda.is_available())
print(torch.cuda.device_count())
# torch.cuda.device(0) builds a context-manager object, so this prints its repr, not device details.
print(torch.cuda.device(0))
torch.cuda.get_device_name(0)

True
4
<torch.cuda.device object at 0x7fb1c0401a90>


'NVIDIA A100 80GB PCIe'

In [6]:
# Restrict this process to physical GPUs 2 and 3.
# NOTE(review): CUDA_VISIBLE_DEVICES is normally read when CUDA is first initialised. The
# previous cell already queried torch.cuda, so this assignment likely has no effect in this
# kernel — consider moving it above `import torch` / the first torch.cuda call. Confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

In [7]:
# Scratch check: count the entries of a hand-typed grid running from 766.6 to 774.01.
# Presumably a wavelength grid for the spectral window used below — confirm.
len([766.6 , 766.79,  766.98,  767.17,  767.36,  767.55,  767.74,  767.93,
        768.12,  768.31,  768.5 ,  768.69,  768.88,  769.07,  769.26,
        769.45,  769.64,  769.83,  770.02,  770.21,  770.4 ,  770.59,
        770.78,  770.97,  771.16,  771.35,  771.54,  771.73,  771.92,
        772.11,  772.3 ,  772.49,  772.68,  772.87,  773.06,  773.25,
        773.44,  773.63,  773.82,  774.01])

40

### Write spectra to NetCDF format

In [8]:
def write_files_to_netcdf(directory, netcdf_file, nbr_files=None):
    """Resample every spectrum found in ``directory`` onto a common wavelength
    grid and store them all in a single NetCDF file.

    Each regular file in ``directory`` is opened with ``Sprout_ML``. Its padded
    flux arrays and masks are stacked, and for every ``(k, l)`` position of the
    stacked array the spectrum and its mask are passed through
    ``interpolate_arrays`` onto the file's ``common_wvl`` grid. One row is
    written per ``(file, k, l)`` combination.

    Parameters
    ----------
    directory : str
        Directory holding the input files. Sub-directories are ignored.
    netcdf_file : str
        Path of the NetCDF file to write.
    nbr_files : int or None, optional
        Maximum number of files to process. ``None`` (default) processes all.
        The limit is applied to the sorted list of regular files, so the
        selection is reproducible and sub-directories do not use up slots.

    Raises
    ------
    ValueError
        If there is no file to process, or the files yield no spectrum.

    Notes
    -----
    The output dataset has dimensions ``index`` and ``wvl`` and the variables
    ``flux``, ``mask``, ``x-index`` (k), ``y-index`` (l) and ``filename``.
    """
    # Filter to regular files and sort *before* applying the limit:
    # os.listdir returns entries in arbitrary order and includes directories.
    candidates = sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )[:nbr_files]
    if not candidates:
        raise ValueError(
            f"No files to process in {directory!r} (nbr_files={nbr_files!r})"
        )

    all_flux = []
    all_mask = []
    all_k = []
    all_l = []
    filenames = []
    wvl_common = None

    for filename in tqdm.tqdm(candidates):
        file = Sprout_ML(directory, filename)
        # NOTE(review): the grid of the last processed file is used as the `wvl`
        # coordinate; this assumes every file exposes the same common_wvl — confirm.
        wvl_common = file.common_wvl

        padded_spectra, masks = file.pad_flux_array(method='zeros')  # padded flux arrays and masks
        wvl_arrays_padded = file.pad_wvl_arrays()
        full_spectra = np.vstack(padded_spectra)  # combine spectra into a single array
        full_mask = np.vstack(masks)  # combine masks
        full_wvl = np.hstack(wvl_arrays_padded)

        # Axis 0 is sliced whole for each (k, l) position (presumably the spectral axis).
        for k in range(full_spectra.shape[1]):
            for l in range(full_spectra.shape[2]):
                spec_common_wvl, mask_common_wvl = interpolate_arrays(
                    wvl_common, full_wvl, full_spectra[:, k, l], full_mask[:, k, l]
                )
                all_flux.append(spec_common_wvl)
                all_mask.append(mask_common_wvl)
                all_k.append(k)
                all_l.append(l)
                filenames.append(filename)

    if not all_flux:
        raise ValueError(f"No spectra were extracted from the files in {directory!r}")

    # One running integer per stored spectrum, across all files.
    index = list(range(len(all_flux)))

    # xarray Dataset
    ds = xr.Dataset(
        {
            "flux": (("index", "wvl"), np.array(all_flux)),
            "mask": (("index", "wvl"), np.array(all_mask)),
            "x-index": (("index"), all_k),
            "y-index": (("index"), all_l),
            "filename": (("index"), filenames),
        },
        coords={
            "index": index,
            "wvl": wvl_common,  # common grid every spectrum was interpolated onto
        },
    )

    # Save to NetCDF
    ds.to_netcdf(netcdf_file)

# Export the training set: read from data_L2/ and write spectra_train_NeMg.nc.
# NOTE(review): absolute, machine-specific paths — consider a configurable data directory.
directory_path = '/d0/tvaresano/SimSPICE/data_L2/'
netcdf_file_path = '/d0/tvaresano/SimSPICE/spectra_train_NeMg.nc'

# 17 is passed as nbr_files, the upper limit on how many inputs are processed.
write_files_to_netcdf(directory_path, netcdf_file_path, 17)

100%|██████████| 17/17 [02:23<00:00,  8.44s/it]


In [11]:
# Same export for the data_L2/Feb2023/ directory, limited to 1 via nbr_files,
# written to spectra_Feb2023_NeMg.nc. This rebinds directory_path / netcdf_file_path.
netcdf_file_path = '/d0/tvaresano/SimSPICE/spectra_Feb2023_NeMg.nc'
directory_path = '/d0/tvaresano/SimSPICE/data_L2/Feb2023/'
write_files_to_netcdf(directory_path, netcdf_file_path, 1)

100%|██████████| 1/1 [00:08<00:00,  8.53s/it]
