In [None]:
!pip install numpy==1.26.4 scikit-learn==1.5.3
!pip install -U tsai
!pip install fastparquet

In [None]:
import os
import re

import matplotlib.pyplot as plt
import pandas as pd
# Star import stays before the explicit imports below so those names take precedence on clashes.
from tsai.all import *
from fastai.callback.wandb import WandbCallback
from fastai.callback.tracker import EarlyStoppingCallback
from fastai.callback.tracker import SaveModelCallback
from fastai.callback.fp16 import MixedPrecision
from sklearn.metrics import confusion_matrix
import wandb

In [None]:
# Training hyper-parameters.
EPOCHS = 300
PATIENCE = 5  # intended as the EarlyStoppingCallback patience (epochs without improvement)
LR = 1e-3

# API_WANDB is presumably provided by an earlier/hidden cell (e.g. a secrets lookup) — not visible here.
# On a fresh kernel fall back to the WANDB_API_KEY environment variable instead of dying with NameError;
# wandb.login(key=None) then uses its own credential lookup.
try:
    API_WANDB
except NameError:
    API_WANDB = os.environ.get("WANDB_API_KEY")

wandb.login(key=API_WANDB)
wandb.init(
    project="ts_classification",
    name="TSPerceiver_full_pipeline",
    entity="nasa-public_static_void_frogs",
    config={
        "epochs": EPOCHS,
        "patience": PATIENCE,
        "lr": LR,
        "model": "TSPerceiver"
    }
)

In [None]:
def sort_num(el):
    """Sort key for per-part data files named like ``version_02_<name>_<part>.<ext>``.

    Numeric parts sort by their integer value; the special part ``final`` sorts
    after every numeric part.

    Args:
        el: File name. ``re.search`` is used, so the pattern may appear anywhere
            in the string.

    Returns:
        ``int(part)`` for numeric parts, ``float('inf')`` for ``final``.

    Raises:
        ValueError: if ``el`` does not contain the expected pattern, or the part
            is neither ``final`` nor an integer.
    """
    match = re.search(r"version_02_(\w+)_(.+)\..+", el)
    if match is None:
        # Previously this surfaced as an opaque AttributeError on None.group().
        raise ValueError(
            f"unexpected file name, expected 'version_02_<name>_<part>.<ext>': {el!r}"
        )
    part = match.group(2)
    if part == 'final':
        return float('inf')
    try:
        return int(part)
    except ValueError:
        raise ValueError(
            f"part {part!r} in file name {el!r} is neither 'final' nor an integer"
        ) from None


# Folders that each hold paired per-part .csv (ids) and .parquet (values) files.
FOLDERS = ['/kaggle/input/nasa-cooked/init_df/', '/kaggle/input/nasa-cooked/init_df_not_in_koi/']
# Half-open slice [START_INDEX, END_INDEX) of the sorted parts loaded from each folder.
START_INDEX = 0
END_INDEX = 10

In [None]:
# For every folder, collect its .parquet and .csv file names ordered by part
# number (sort_num). Index i of `parquets` / `csvs` corresponds to FOLDERS[i],
# and the loading cell pairs csv/parquet parts by position within a folder.
parquets = []
csvs = []
for folder in FOLDERS:
    names = os.listdir(folder)
    print(names)
    folder_parquets = [name for name in names if name.endswith('.parquet')]
    folder_csvs = [name for name in names if name.endswith('.csv')]
    parquets.append(folder_parquets)
    csvs.append(folder_csvs)

    # Sort in place after appending, so the lists stored above end up ordered.
    folder_parquets.sort(key=sort_num)
    folder_csvs.sort(key=sort_num)

In [None]:
def _load_part(folder, csv_name, parquet_name):
    """Load one (ids .csv, values .parquet) part and join them column-wise.

    Args:
        folder: Directory containing both files.
        csv_name: File name of the ids table (must contain KEPID and PLANET_NUM columns).
        parquet_name: File name of the values table (string column labels convertible to int).

    Returns:
        DataFrame indexed by ['KEPID', 'PLANET_NUM'] with the csv columns followed
        by the value columns.

    Raises:
        ValueError: if the two files carry different part numbers, or their row
            counts disagree after orienting the values table.
    """
    # The listing cell pairs csv/parquet files by sorted position only; guard against
    # a missing file silently shifting every later pair.
    if sort_num(csv_name) != sort_num(parquet_name):
        raise ValueError(f'mismatched pair in {folder}: {csv_name!r} vs {parquet_name!r}')

    values = pd.read_parquet(os.path.join(folder, parquet_name), engine="fastparquet")
    values.columns = values.columns.astype(int)
    ids = pd.read_csv(os.path.join(folder, csv_name))
    print(f'{values.shape=}, {ids.shape=}')

    # Heuristic kept from the original: more rows than columns is taken to mean the
    # parquet stores one sample per column, so it is transposed to one sample per row.
    # NOTE(review): this misfires if a part ever has more samples than time steps — confirm.
    if values.shape[0] > values.shape[1]:
        print('concating with transpose')
        values = values.T
    else:
        print('concating without transpose')
    # pd.concat(axis=1) aligns on the index; reset it in both orientations so value rows
    # pair with id rows by position (read_csv gives ids a 0..n-1 index).
    values = values.reset_index(drop=True)

    if len(values) != len(ids):
        # Otherwise concat would silently pad the shorter side with NaN rows.
        raise ValueError(
            f'row count mismatch in {folder}: {csv_name!r} has {len(ids)} rows, '
            f'{parquet_name!r} has {len(values)}'
        )

    part = pd.concat([ids, values], axis=1)
    print(f'{part.shape=}')
    return part.set_index(['KEPID', 'PLANET_NUM'])


# Load the first START_INDEX:END_INDEX parts of every folder.
dfs_to_concat = []
for folder, csv_list, parquet_list in zip(FOLDERS, csvs, parquets):
    print(f'using folder {folder}')
    for csv_name, parquet_name in zip(csv_list[START_INDEX:END_INDEX], parquet_list[START_INDEX:END_INDEX]):
        print(f'concating pair {(csv_name, parquet_name)}')
        dfs_to_concat.append(_load_part(folder, csv_name, parquet_name))

In [None]:
# Stack all loaded parts row-wise (pd.concat defaults to axis=0); each part's index is kept.
full_df = pd.concat(dfs_to_concat)

In [None]:
# Missing values per column, largest first, for columns that have any.
nan_counts = full_df.isna().sum()
nan_counts = nan_counts[nan_counts > 0].sort_values(ascending=False)

if nan_counts.empty:
    print('No missing values in full_df')
else:
    fig, ax = plt.subplots(figsize=(10, 5))
    # Plot labels as strings: integer column labels would otherwise be used as numeric
    # x positions (losing the descending order), and mixed int/str labels would fail.
    ax.bar(nan_counts.index.astype(str), nan_counts.values)
    ax.set(title='Missing values per column', xlabel='column', ylabel='NaN count')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.show()

In [None]:
# Columns whose fraction of missing values is at or above this threshold are dropped.
threshold = 0.6

nan_ratio = full_df.isna().mean()
cols_to_keep = nan_ratio[nan_ratio < threshold].index
full_df = full_df[cols_to_keep]

# NOTE(review): errors='coerce' applies to every column, LABEL included; a non-numeric
# LABEL would become all-NaN here — verify the LABEL dtype in the ids csv.
full_df = full_df.apply(pd.to_numeric, errors='coerce')
# Moves KEPID / PLANET_NUM from the index back into regular columns.
full_df = full_df.reset_index()
# NOTE(review): interpolate() runs along axis=0 (down each column), i.e. it fills a
# missing time step of one object from *other* objects' rows. If the intent is to fill
# gaps inside each series, this should run along axis=1 over the time-step columns only.
full_df = full_df.interpolate(method='polynomial', order=3, limit_direction='both')
# fillna(method=...) is deprecated in pandas >= 2.1; bfill()/ffill() are the equivalents.
full_df = full_df.bfill().ffill()


In [None]:
# Identifier columns (moved out of the index by reset_index() during cleaning) are not
# signal and must not be fed to the model as time steps.
# NOTE(review): any other metadata columns from the ids csv would still be included — confirm.
ID_COLUMNS = ['KEPID', 'PLANET_NUM']
SPLIT_SEED = 42  # fixed so the train/valid split is reproducible across runs

features_df = full_df.drop(columns="LABEL").drop(columns=ID_COLUMNS, errors='ignore')
# tsai data utilities are built around numpy arrays / tensors; convert explicitly.
x = features_df.to_numpy(dtype='float32')
y = full_df["LABEL"].to_numpy()
splits = RandomSplitter(valid_pct=0.2, seed=SPLIT_SEED)(range_of(y))

In [None]:
# Batch size passed to get_ts_dls (kept at the original value).
BATCH_SIZE = 2

dls = get_ts_dls(
    x,
    y,
    splits=splits,
    path='.',
    bs=BATCH_SIZE,
    # No item transform on X; TSClassification encodes the labels as categories.
    tfms=[None, TSClassification()],
    # Standardize each sample with its own statistics.
    batch_tfms=[TSStandardize(by_sample=True)],
)

In [None]:
# tsai lays X out as (samples, variables, steps); a 2-D x is (samples, steps). Either way
# the number of time steps is the last axis. (The previously computed c_in/c_out were
# unused — ts_learner infers them from dls — and c_in read the wrong axis for 3-D input.)
seq_len = x.shape[-1]

learn = ts_learner(
    dls,
    TSPerceiver,
    metrics=accuracy,
    cbs=[
        ShowGraph(),
        # Saves the best-accuracy weights and reloads them after fit; WandbCallback's
        # log_model=True relies on this callback being present to upload a model.
        SaveModelCallback(monitor='accuracy'),
        WandbCallback(log_model=True),
        EarlyStoppingCallback(monitor='accuracy', patience=PATIENCE),
        MixedPrecision()
    ],
    arch_config={'seq_len': seq_len}
)


In [None]:
# One-cycle LR schedule over at most EPOCHS epochs (early stopping may end it sooner).
# NOTE(review): fit_one_cycle warms up over its first pct_start share of epochs; with a long
# EPOCHS and a short early-stopping patience, training may stop during warm-up — confirm.
learn.fit_one_cycle(n_epoch=EPOCHS, lr_max=LR)

In [None]:
# Validation-set predictions (get_preds defaults to the validation DataLoader).
preds, targs = learn.get_preds()
pred_idx = preds.argmax(dim=1)
# Class names in encoded-index order, taken from the TSClassification vocab. A set built
# from validation targets is unordered and misses classes absent from that split.
class_names = [str(c) for c in learn.dls.vocab]

# accuracy() returns a tensor; log a plain float.
acc = accuracy(preds, targs).item()
wandb.log({"final_accuracy": acc})

cm = confusion_matrix(targs.numpy(), pred_idx.numpy(), labels=list(range(len(class_names))))
print(cm)

wandb.log({
    "confusion_matrix": wandb.plot.confusion_matrix(
        probs=None,
        y_true=targs.tolist(),
        preds=pred_idx.tolist(),
        class_names=class_names
    )
})
wandb.finish()