In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense, Dropout
from sklearn.model_selection import train_test_split
import numpy as np
import pickle as pkl
from src.utils.data_transform import *
import pandas as pd

In [None]:
def load_data(file_path):
    """Read a pickled collection of records and split it into two arrays.

    Each record is indexed positionally: element 0 is collected as the
    signal and element 1 as the label.

    Args:
        file_path: Path handed straight to ``pd.read_pickle``.

    Returns:
        A ``(signal_data, label_data)`` tuple of NumPy arrays, with one
        entry per record, in the order the records are iterated.

    Note:
        Unpickling can execute arbitrary code; only load files from a
        trusted source.
    """
    records = pd.read_pickle(file_path)

    collected_signals = []
    collected_labels = []
    for record in records:
        collected_signals.append(record[0])
        collected_labels.append(record[1])

    return np.array(collected_signals), np.array(collected_labels)

In [None]:
def split_data(signal_data, label_data, test_size=0.2, random_state=42):
    """Split signals and labels into training and validation subsets.

    Args:
        signal_data: Array-like of input signals, one entry per sample.
        label_data: Array-like of labels aligned with ``signal_data``.
        test_size: Forwarded to ``train_test_split``; the share of samples
            held out for validation. Defaults to 0.2.
        random_state: Seed forwarded to ``train_test_split`` so the shuffle
            is reproducible. Defaults to 42, the value that was previously
            hard-coded.

    Returns:
        Whatever ``train_test_split`` returns for two input arrays, i.e.
        ``X_train, X_val, y_train, y_val`` in that order.
    """
    # Convert to numpy arrays for compatibility with train_test_split.
    signal_array = np.array(signal_data)
    label_array = np.array(label_data)
    return train_test_split(
        signal_array,
        label_array,
        test_size=test_size,
        random_state=random_state,
    )

In [None]:
def build_model(input_shape, num_classes=5):
    """Build and compile a two-stage 1D convolutional classifier.

    The network stacks two Conv1D + MaxPooling1D stages (64 then 128
    filters, kernel size 10), flattens the result, applies 50% dropout and
    ends in a softmax layer with ``num_classes`` units.

    Args:
        input_shape: Per-sample input shape passed to the first Conv1D
            layer (batch dimension excluded).
        num_classes: Number of units in the softmax output layer. Defaults
            to 5, the value that was previously hard-coded.

    Returns:
        The compiled ``Sequential`` model (Adam optimizer,
        categorical cross-entropy loss, accuracy metric).
    """
    model = Sequential()
    model.add(Conv1D(filters=64, kernel_size=10, activation='relu', input_shape=input_shape, padding='same'))
    model.add(MaxPooling1D(pool_size=2))
    model.add(Conv1D(filters=128, kernel_size=10, activation='relu', padding='same'))
    model.add(MaxPooling1D(pool_size=2))
    model.add(Flatten())
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))

    # categorical_crossentropy expects one-hot encoded targets of width
    # num_classes; integer class ids would need sparse_categorical_crossentropy.
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

In [None]:
# Load the pickled dataset and unpack it into separate signal and label arrays.
# The path is relative, so it is resolved against the kernel's working directory.
file_path = '../data/processed.nosync/all_final/all_final.pkl'  # Replace with your file path
signals, labels = load_data(file_path)

In [None]:
# Hold out a validation subset using split_data's default test_size.
X_train, X_val, y_train, y_val = split_data(signals, labels)


In [None]:
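# NOTE(review): one-hot encoding is currently disabled. The model is compiled with
# categorical_crossentropy, which expects one-hot labels — confirm the labels loaded
# above are already one-hot encoded, or re-enable the two lines below and pass
# y_train_cat / y_val_cat to model.fit instead of y_train / y_val.
# y_train_cat = to_categorical(y_train)
# y_val_cat = to_categorical(y_val)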

In [None]:
# Print the shapes of the training and validation arrays as a sanity check before training.
print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
print(f"X_val shape: {X_val.shape}, y_val shape: {y_val.shape}")


In [None]:
# NOTE(review): input_shape is hard-coded; presumably it must equal X_train.shape[1:]
# — confirm against the shapes printed above.
model = build_model(input_shape=(20, 6))

In [None]:
# Train the model; validation_data is evaluated at the end of every epoch and the
# per-epoch metrics are kept in `history`.
history = model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=32, batch_size=32)