In [1]:
#imports
import sys #you need this to add path to utils folder
sys.path.append('../utils') #now you can import from utils folder
# sys.path.append("../..")
from configs import config_general
from datasharders import datasharder_imu_joints
from datasets import ImuJointPairDataset
from data_utils import fft_filter_signal, wavelet_filter_signal
from model_utils import plot_predictions, plot_loss, train_model, evaluate_model
import torch
from torch.utils.data import DataLoader, random_split 
import matplotlib.pyplot as plt
from LSTM import DeepLSTMModel
import torch.nn as nn
import seaborn as sns

In [2]:
# Experiment configuration: every tunable setting for data sharding, the dataset and the model lives here.
# TODO: an earlier note here read "some config errors" with no detail — identify which settings are wrong.

# Channel counts are defined once so the model's input/output sizes cannot drift
# out of sync with the number of IMU / joint channels in the data.
NUM_CHANNELS_IMU = 3
NUM_CHANNELS_JOINTS = 3

# Initialize config object
config = config_general(
    batch_size=64,
    epochs=1000,
    lr=0.002,
    scheduler=None,
    num_channels_imu=NUM_CHANNELS_IMU,
    num_channels_joints=NUM_CHANNELS_JOINTS,
    num_sessions=1,
    num_patients=3,
    seed=42,
    data_folder_name="../../datacollection/data",
    dataset_root="../../datasets",
    dataset_train_name="train_dataset",
    dataset_test_name="test_dataset",
    window_length=100,
    imu_transforms=[fft_filter_signal],
    joint_transforms=[],
    hidden_size=256,
    num_layers=6,
    # NOTE(review): assumes the LSTM consumes one feature per IMU channel — TODO confirm.
    input_size=NUM_CHANNELS_IMU,
    # NOTE(review): assumes the model predicts one value per joint channel — TODO confirm.
    output_size=NUM_CHANNELS_JOINTS,
)


In [3]:
# Main driver: optionally rebuild the windowed datasets, preview one sample,
# train the LSTM, then evaluate and plot predictions on the validation and test sets.


def _plot_preview_sample(dataset, index=1, channel=1):
    """Plot one channel of a single (imu, joint) pair from ``dataset`` as a visual sanity check.

    Args:
        dataset: indexable dataset returning an ``(imu_tensor, joint_tensor)`` pair.
        index: which sample to preview.
        channel: which column (second axis) of each tensor to plot.
    """
    imu_tensor, joint_tensor = dataset[index]
    print(imu_tensor.shape, joint_tensor.shape)

    # detach/cpu make the numpy conversion safe even if the dataset ever returns
    # tensors that track gradients or live on the GPU.
    imu = imu_tensor.detach().cpu().numpy()
    joint = joint_tensor.detach().cpu().numpy()

    sns.set(style="ticks")
    fig, ax = plt.subplots()
    # NOTE(review): assumes tensors are (time, channels) — TODO confirm.
    ax.plot(imu[:, channel], label="imu")
    ax.plot(joint[:, channel], label="joint")
    ax.set(title=f"Training sample {index}, channel {channel}", xlabel="time step", ylabel="value")
    ax.legend()
    plt.show()


def _split_train_val(dataset, val_fraction, generator):
    """Randomly split ``dataset`` into (train, val) subsets whose sizes always sum to ``len(dataset)``.

    Args:
        dataset: the dataset to split.
        val_fraction: fraction of samples held out for validation, in [0, 1].
        generator: ``torch.Generator`` controlling the random permutation.

    Returns:
        A ``(train_subset, val_subset)`` tuple from ``random_split``.
    """
    n_total = len(dataset)
    n_train = int((1.0 - val_fraction) * n_total)
    # Give the rounding remainder to validation. Truncating both sizes
    # independently can drop samples and make random_split raise ValueError.
    n_val = n_total - n_train
    return random_split(dataset, [n_train, n_val], generator=generator)


def main(config, remake_dataset, sample_rate=16000, val_fraction=0.25):
    """Build, train and evaluate a ``DeepLSTMModel`` on paired IMU/joint data.

    Args:
        config: ``config_general`` instance. This function reads ``input_size``,
            ``hidden_size``, ``num_layers``, ``output_size``, ``lr``,
            ``batch_size``, ``epochs`` and (if present) ``seed`` from it. It also
            passes the object to ``datasharder_imu_joints`` and ``ImuJointPairDataset``.
        remake_dataset: if True, load the raw data with ``datasharder_imu_joints``
            and save fresh "train"/"test" windows before building the datasets.
            If False, the datasets are built from whatever windows were saved previously.
        sample_rate: forwarded to ``datasharder_imu_joints`` when
            ``remake_dataset`` is True. The default is the previously hard-coded value.
        val_fraction: fraction of the training windows held out for validation.

    Returns:
        None. Loss curves and prediction plots are displayed as figures.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Seed before building the model so weight init, the train/val split and
    # DataLoader shuffling are reproducible across runs.
    seed = getattr(config, "seed", None)
    if seed is not None:
        torch.manual_seed(seed)
        split_generator = torch.Generator().manual_seed(seed)
    else:
        split_generator = torch.default_generator

    model = DeepLSTMModel(
        input_size=config.input_size,
        hidden_size=config.hidden_size,
        num_layers=config.num_layers,
        output_size=config.output_size,
    ).to(device)
    criterion = nn.MSELoss()  # changes depending on model
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)  # changes depending on model

    # Only re-shard when the raw data or windowing settings changed; otherwise reuse saved windows.
    if remake_dataset:
        datasharder = datasharder_imu_joints(config, sample_rate=sample_rate)
        training_data, testing_data = datasharder.load_data()
        datasharder.save_windowed_data(training_data, "train")
        datasharder.save_windowed_data(testing_data, "test")

    # Load the saved train and test windows.
    dataset_train = ImuJointPairDataset(config, "train")
    dataset_test = ImuJointPairDataset(config, "test")

    _plot_preview_sample(dataset_train, index=1, channel=1)

    train_dataset, val_dataset = _split_train_val(dataset_train, val_fraction, split_generator)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
    test_loader = DataLoader(dataset_test, batch_size=config.batch_size, shuffle=False)

    train_losses, val_losses = train_model(model, train_loader, val_loader, config.epochs, criterion, optimizer, device)
    plot_loss(train_losses, val_losses)

    # Evaluate on the validation set, then the test set, plotting predictions for each.
    # NOTE(review): num_channels is hard-coded to 3, which matches both the IMU and joint
    # channel counts in the current config — confirm which one plot_predictions expects.
    for loader in (val_loader, test_loader):
        inputs, targets, predictions = evaluate_model(model, loader, device)
        plot_predictions(inputs.squeeze(), targets.squeeze(), predictions.squeeze(), num_channels=3)

In [None]:
# Run the full pipeline on the previously saved train/test windows;
# remake_dataset=False skips the datasharder step that regenerates them.
main(config, remake_dataset=False)