# Main

This file contains everything you need to run the model. It requires the data to already be in CSV format in the "data/clean" directory.

You must run everything in "scripts" before running this file.

In [None]:
import os
import numpy as np
import pandas as pd

## Resample the data and derive velocity

The fpv-uzh and mid-air datasets have different sample rates and formats. This code resamples everything to 10 Hz (sample time = 0.1 s). The `resample` function keeps only the timestamp and position columns, so it also removes the rotation data from the fpv-uzh dataset.

The sample rate can actually be thought of as a hyperparameter for our model. For now, we are following the VECTOR GRU paper.

The model can either be trained on positional or velocity data.

In [None]:
def resample(df: pd.DataFrame, sampling_time: float):
    """Resample a trajectory onto an evenly spaced time grid.

    Positions are linearly interpolated (``np.interp``) at the new timestamps.

    Args:
        df: Trajectory with at least the columns ``timestamp``, ``tx``,
            ``ty`` and ``tz``. Rows may be in any order; they are sorted by
            ``timestamp`` first.
        sampling_time: Spacing of the new time grid, in the same units as
            the ``timestamp`` column.

    Returns:
        A new DataFrame with exactly the columns ``timestamp``, ``tx``,
        ``ty`` and ``tz``; any other input columns are dropped.
    """
    ordered = df.sort_values("timestamp")
    source_times = ordered["timestamp"]

    # Half-open grid: starts at the first timestamp and stops before the last.
    grid = np.arange(source_times.min(), source_times.max(), sampling_time)

    # Build the output column by column, interpolating each position axis.
    columns = {"timestamp": grid}
    for axis in ("tx", "ty", "tz"):
        columns[axis] = np.interp(grid, source_times, ordered[axis])

    return pd.DataFrame(columns)

def velocity(df: pd.DataFrame):
    """Derive per-axis velocities from a position trajectory.

    Each velocity is the backward finite difference of the position divided
    by the difference of the timestamps between consecutive rows.

    Args:
        df: Trajectory with the columns ``timestamp``, ``tx``, ``ty`` and
            ``tz``, in the row order the differences should be taken in.

    Returns:
        A DataFrame with the columns ``timestamp``, ``vx``, ``vy`` and
        ``vz``. The first row is dropped because it has no predecessor to
        difference against.
    """
    result = pd.DataFrame(columns=["timestamp", "vx", "vy", "vz"])
    elapsed = df["timestamp"].diff()

    result["timestamp"] = df["timestamp"]
    for position_col, velocity_col in (("tx", "vx"), ("ty", "vy"), ("tz", "vz")):
        result[velocity_col] = df[position_col].diff() / elapsed

    # Row 0 only holds NaN velocities (diff of the first row), so skip it.
    return result.iloc[1:]

# Everything found under "data/clean" is resampled; positions are saved in
# "data/resampled" and the derived velocities in "data/velocity".
out_path_pos = "../data/resampled"
out_path_vel = "../data/velocity"
for _out_dir in (out_path_pos, out_path_vel):
    # exist_ok=True makes re-running this cell harmless.
    os.makedirs(_out_dir, exist_ok=True)

def _resample_tree(root: str, sampling_time: float = 0.1, **read_csv_kwargs) -> None:
    """Resample every trajectory file under ``root`` and save the results.

    Walks ``root`` recursively. Each file whose name ends in ``csv`` or
    ``txt`` is read with ``pd.read_csv``, resampled with ``resample``, and
    differentiated with ``velocity``. The resampled positions are written to
    ``out_path_pos`` and the velocities to ``out_path_vel``, both under the
    file's original name.

    Args:
        root: Directory to search for trajectory files.
        sampling_time: Time step passed to ``resample``.
        **read_csv_kwargs: Extra keyword arguments forwarded to
            ``pd.read_csv`` for datasets that are not plain CSV.
    """
    for dirpath, _dirnames, filenames in os.walk(root):
        for filename in filenames:
            if not filename.endswith(("csv", "txt")):
                continue

            df = pd.read_csv(os.path.join(dirpath, filename), **read_csv_kwargs)
            res = resample(df, sampling_time)
            vel = velocity(res)

            # NOTE(review): outputs are keyed by base filename only, so two
            # inputs with the same name (in different subdirectories or
            # datasets) overwrite each other in the flat output directories.
            res.to_csv(os.path.join(out_path_pos, filename))
            vel.to_csv(os.path.join(out_path_vel, filename))


def walk_midair(sampling_time: float = 0.1) -> None:
    """Resample all mid-air trajectories found under "data/clean/mid-air".

    Args:
        sampling_time: Time step of the resampled trajectories.
    """
    _resample_tree("../data/clean/mid-air", sampling_time)


def walk_fpv(sampling_time: float = 0.1) -> None:
    """Resample all fpv-uzh trajectories found under "data/clean/fpv-uzh".

    The fpv files are read differently from the other datasets: they are
    parsed as space-separated, with no header row and "#" starting a
    comment, so the column names are supplied here.

    Args:
        sampling_time: Time step of the resampled trajectories.
    """
    _resample_tree(
        "../data/clean/fpv-uzh",
        sampling_time,
        sep=' ',
        comment="#",
        header=None,
        names=["timestamp", "tx", "ty", "tz", "qx", "qy", "qz", "qw"],
    )


def walk_riotu(sampling_time: float = 0.1) -> None:
    """Resample all riotu-labs trajectories found under "data/clean/riotu-labs".

    Args:
        sampling_time: Time step of the resampled trajectories.
    """
    _resample_tree("../data/clean/riotu-labs", sampling_time)


def walk_random_trajectories(sampling_time: float = 0.1) -> None:
    """Resample all trajectories found under "data/clean/random-trajectories".

    Args:
        sampling_time: Time step of the resampled trajectories.
    """
    _resample_tree("../data/clean/random-trajectories", sampling_time)
    
# Process each dataset in turn: mid-air, fpv-uzh, riotu-labs, random trajectories.
for _walk_dataset in (walk_midair, walk_fpv, walk_riotu, walk_random_trajectories):
    _walk_dataset()

## Split data

We use 5-fold stratified cross-validation. The split must be stratified because the training data comes from different distributions.