In [3]:
import os, sys
import pandas as pd
import numpy as np
from statsmodels.tsa.seasonal import seasonal_decompose
import jax.numpy as jnp

# current_dir = os.path.dirname(os.path.abspath(__file__))
# parent_dir = os.path.dirname(current_dir)
# sys.path.append(parent_dir)
from database import db_operations


#TODO: THE RETURNED DATATYPE IS BEING MIXED AROUND, DETERMINE WHAT DATATYPE I WANT IT TO END AS AND ENFORCE IT WITH ->

class DataProcessed:
    def __init__(self):
        self.data = None
        self.shift_dict = {}
        self.y = None
        self.X = None
        self.ydates = None
        self.Xdates = None
        self.db = db_operations.Database(
            project_id=os.getenv("PROJECT_ID"),
            region=os.getenv("REGION"),
            instance_name=os.getenv("INSTANCE_NAME"),
            db_user=os.getenv("DB_USER"),
            db_pass=os.getenv("DB_PASS"),
            db_name=os.getenv("DB_NAME")
        )

    @staticmethod
    def date_to_unix(date_array):
        timestamps = np.array([int(datetime.fromisoformat(date_string.rstrip('Z')).timestamp()) for date_string in date_array])
        return timestamps
    
    def pull(self):

        y = self.db.pull_table("ao_pdo_enso.mytable")
        self.y = np.array([line[2:] for line in y], dtype=np.float32)
        #self.ydates = np.array([line[1] for line in y])
        self.ydates = self.date_to_unix([line[1] for line in y])
                
        x = self.db.pull_table("ao_pdo_enso.climate_indices")
        self.X = np.array([line[1:] for line in x], dtype=np.float32)
        #self.Xdates = np.array([line[0] for line in x])
        self.Xdates = self.date_to_unix([line[0].strftime('%Y-%m-%d 00:00:00+00:00') for line in x])

        #self.data = np.vstack()
        return None
    
    def auto_corr(self, target, feature, steps = 36):
        df = self.data[[target, feature]].copy()
        corrs = []
        
        for i in range(-steps, steps + 1):
            shifted_df = df.copy()
            shifted_df[feature] = shifted_df[feature].shift(periods = i)
            correlation = shifted_df.dropna().corr().iloc[0,1]
            corrs.append((i, correlation))
            
        return corrs
    
    def remove_seasonality(self, target_only=True, resid=False):
        target_cols = [col for col in self.data.columns if 'SNOTEL' in col]
        
        if target_only == False:
            target_cols = self.data.columns
            
        df = self.data
        df = df.set_index(pd.date_range('1985-01-01', '2024-08-01', freq='MS'))
        for i in target_cols:
            if resid == False:
                x = df[i].dropna()
                #ind = pd.date_range('1985-01-01',periods=len(x) , freq='MS')
                hold = seasonal_decompose(x, period=12)
                df[i] = hold.trend
            else:
                df[i] = df[i].dropna()
                hold = seasonal_decompose(df[i], period = 12)
                df[i] = hold.trend + hold.resid
            
            
        self.data = df
        
        return None
    
    def shift_for_correlation(self, target_column, steps=48):
        features = [col for col in self.data.columns if col != target_column]
        for feat in features:
            corrs = [(i, self.data[[feat, target_column]].shift(i).dropna().corr().iloc[0, 1]) for i in range(-steps, steps)]
            max_shift = max(corrs, key=lambda x: abs(x[1]))[0]
            self.data[feat] = self.data.loc[:, feat].shift(max_shift)
            self.shift_dict[feat] = max_shift
        return None

        
    def single_target(self, target):
        target_cols = [col for col in self.data.columns if 'SNOTEL' not in col]
        x1 = self.data.loc[:, target_cols]
        x2 = self.data.iloc[:, target]
        
        self.data = x1.join(x2)
        
        return None
        
    def drop_na(self):
        self.data = self.data.dropna()
        
        return None
    
    def seperate_y(self):
        target_cols = [col for col in self.data.columns if 'SNOTEL' in col]
        feature_cols = [col for col in self.data.columns if not 'SNOTEL' in col]


        self.y = self.data[target_cols]
        self.X = self.data[feature_cols]

        
        return None

In [4]:
x = DataProcessed()

In [5]:
x.pull()

Created engine: Engine(mysql+pymysql://)


In [10]:
x.X.shape

(512, 10)

In [9]:
x.y.shape

(476, 34)

In [17]:
np.full((36, 34), np.nan)

array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan]])

In [24]:
x1=np.r_[np.full((36,34), np.nan), x.y]
pd.DataFrame(x1)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,24,25,26,27,28,29,30,31,32,33
0,,,,,,,,,,,...,,,,,,,,,,
1,,,,,,,,,,,...,,,,,,,,,,
2,,,,,,,,,,,...,,,,,,,,,,
3,,,,,,,,,,,...,,,,,,,,,,
4,,,,,,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
507,,,0.0,0.0,,0.0,,,,,...,,,,,0.0,0.0,,,,
508,,,0.0,0.0,,,,,,,...,,,,,0.0,0.0,,,,
509,,,,,,,,,,,...,,,,,0.0,,,,,
510,,,,,,,,,,,...,,,,,0.0,,,,,


In [23]:
np.c_[x.X, x1]

array([[24.28000069, -0.23999999, 25.84000015, ...,         nan,
                nan,         nan],
       [25.37999916, -0.72000003, 26.26000023, ...,         nan,
                nan,         nan],
       [25.21999931, -1.38      , 26.92000008, ...,         nan,
                nan,         nan],
       ...,
       [22.52000046, -0.64999998, 26.51000023, ...,         nan,
                nan,         nan],
       [21.42000008, -0.41      , 25.79000092, ...,         nan,
                nan,         nan],
       [20.52000046, -0.34      , 24.96999931, ...,         nan,
                nan,         nan]])

In [22]:
print(x1.shape)
print(x.X.shape)

(512, 34)
(512, 10)


In [8]:
np.r_[x.X, x.y]

ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 10 and the array at index 1 has size 34

In [33]:
import jax

# Warm-up JIT compilation
key = jax.random.PRNGKey(0)
_ = jax.random.normal(key, shape=(1000, 1000))

# NumPy timing
%timeit np.random.normal(size=(1000, 1000))

# JAX timing with new key each time and blocking
%timeit jax.random.normal(jax.random.PRNGKey(0), shape=(100, 5)).block_until_ready()

10.1 ms ± 79.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
521 µs ± 68.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
