In [1]:
import os
import sys

import numpy as np
import pandas as pd
from scipy.signal import fftconvolve

In [2]:
# Table with a STRAIN column; presumably one row per trial (NTRIAL below counts
# the matching rows) — confirm.
FSTRAIN = 'data_raw/fly/strains.csv'

# Behaviour targets: f'{TARG_BHV}_MN_{window}' for each window in TWDWS.
TARG_BHV = 'MTN'
TWDWS = [.03, 1, 60]
TARGS = [f'{TARG_BHV}_MN_{twdw}' for twdw in TWDWS]

# Strains to include, plus a lowercase key naming the combination.
STRAINS = ['NM91', 'ZH23']
STRAIN_KEY = '_'.join(STRAINS).lower()

# Read the strain column once rather than re-reading the CSV for every strain.
_strain_col = pd.read_csv(FSTRAIN)['STRAIN']
MSTRAINS = [(_strain_col == strain) for strain in STRAINS]
# Fail loudly on a misspelled strain instead of silently selecting fewer trials.
_missing = [strain for strain, mask in zip(STRAINS, MSTRAINS) if not mask.any()]
assert not _missing, f'strains not found in {FSTRAIN}: {_missing}'
MSTRAIN = np.any(MSTRAINS, axis=0)  # boolean mask over rows of FSTRAIN
ISTRAIN = MSTRAIN.nonzero()[0]      # row positions of the selected rows

NTRIAL = MSTRAIN.sum()

# Causal exponential kernels applied to 'S' and 'P' song frames by exp_filt.
# Time constants are in samples (frames, presumably — confirm the frame rate).
TAU_S = 10
TAU_P = 20
# Kernel support: 5x the longer time constant, starting at lag 0.
TH = np.arange(5*TAU_P, dtype=float)
# Both kernels equal exactly 1/20 at lag 0, so a single song frame contributes
# 1/20 on the next frame (the threshold used for the 'clf' targets).
H_S = np.exp(-TH/TAU_S)/20
H_P = np.exp(-TH/TAU_P)/20

FDECIM = .005  # how much of the original data to actually keep

In [3]:
def exp_filt(song, h_s=None, h_p=None):
    """Causally filter a per-frame song string with exponential kernels.

    Each character of ``song`` is one frame. 'S' frames are convolved with
    ``h_s`` and 'P' frames with ``h_p`` (presumably sine and pulse song —
    confirm). Any other character (e.g. 'Q') contributes nothing. The two
    traces are summed, truncated to ``len(song)``, and shifted right by one
    frame, so the value at frame t depends only on frames strictly before t.

    Args:
        song: str, one character per frame.
        h_s: 1-D kernel for 'S' frames; defaults to the module-level H_S.
        h_p: 1-D kernel for 'P' frames; defaults to the module-level H_P.

    Returns:
        float64 ndarray of length ``len(song)``. Element 0 is always 0. An
        empty song gives an empty array.
    """
    if h_s is None:
        h_s = H_S
    if h_p is None:
        h_p = H_P

    chars = np.array(list(song), dtype='U1')
    n = chars.size
    if n == 0:
        return np.zeros(0)

    s = (chars == 'S').astype(float)
    p = (chars == 'P').astype(float)
    # Direct (not FFT) convolution. With exact zeros elsewhere, one song frame
    # yields exactly h[0], so threshold tests such as `>= 1/20` are not flipped
    # by FFT round-off.
    temp = np.convolve(s, h_s)[:n] + np.convolve(p, h_p)[:n]
    # Shift by one frame to make the filter strictly causal.
    return np.concatenate([[0.], temp[:-1]])

In [4]:
# Dataset variants: classification ('clf') and regression ('rgr') targets, each
# paired with a '_scrambled' control whose targets are shuffled within a trial.
PFXS = [f'{task}{suffix}' for task in ('clf', 'rgr') for suffix in ('', '_scrambled')]

# Song-history window lengths (frames) to build datasets for. Set this to
# [100, 1000, 2000, 4000] to also build the longer-history datasets.
LOOK_BACKS = [100]

In [5]:
columns = ['fmtn', 'session', 'frame', 'song']
df_full = pd.read_csv('data_raw/fly/c_song_f_behav_true.csv')

# One DataFrame per selected trial. NOTE(review): assumes values in the ID
# column line up with the row positions in strains.csv (ISTRAIN) — confirm.
df_trs = [df_full[df_full.ID == i] for i in ISTRAIN]
del df_full

# Target noise, scrambling and decimation are all random. Use one seeded
# generator so the written datasets are reproducible.
rng = np.random.default_rng(0)


def _song_string(df_tr):
    """Encode a trial as one character per frame.

    'P' where column P == 1, otherwise 'S' where column S == 1, otherwise 'Q'.
    'P' wins if both flags are set.
    """
    s_on = df_tr['S'].to_numpy() == 1
    p_on = df_tr['P'].to_numpy() == 1
    return ''.join(np.where(p_on, 'P', np.where(s_on, 'S', 'Q')))


def _targets(song, pfx):
    """Per-frame targets for dataset variant `pfx` (one of PFXS).

    'clf*': 1 where the causal song filter is >= 1/20, else 0.
    'rgr*': the filter value plus Gaussian noise with sd = 0.1 * its std.
    '*scrambled': targets permuted across frames within the trial (control).
    """
    if pfx.startswith('clf'):
        fmtn = (exp_filt(song) >= 1/20).astype(int)
    elif pfx.startswith('rgr'):
        fmtn = exp_filt(song)
        fmtn = fmtn + .1*rng.standard_normal(len(fmtn))*np.std(fmtn)  # add some wee noise
    else:
        # Without this, `fmtn` would be undefined or stale from a previous trial.
        raise ValueError(f'unknown dataset prefix: {pfx!r}')
    if pfx.endswith('scrambled'):
        fmtn = fmtn[rng.permutation(len(fmtn))]
    return fmtn


paths_all = []

for pfx in PFXS:
    for look_back in LOOK_BACKS:
        sys.stdout.write(f'lookback={look_back}')
        records = []

        for df_tr in df_trs:
            sys.stdout.write('.')
            if len(df_tr) == 0:
                continue  # no frames -> no records

            song = _song_string(df_tr)
            fmtn = _targets(song, pfx)
            frames = df_tr['FRAME'].to_numpy().astype(int)
            sessions = df_tr['ID'].to_numpy().astype(int)

            # Left-pad with 'Q' so padded[c:c + look_back] is exactly the
            # look_back frames preceding frame c (frame c itself excluded).
            padded = 'Q'*look_back + song

            # Decimate before building windows. Only ~FDECIM of frames survive,
            # so materialising a record for every frame would be wasted work.
            kept = np.flatnonzero(rng.random(len(song)) < FDECIM)
            records.extend(
                {'fmtn': fmtn[c],
                 'session': sessions[c],
                 'frame': frames[c],
                 'song': padded[c:c + look_back]}
                for c in kept)

        print('')
        df = pd.DataFrame(columns=columns, data=records)
        path = f'data_s5/fly_mini/{pfx}_lookback_{look_back}_fdecim_{FDECIM}.tsv'
        os.makedirs(os.path.dirname(path), exist_ok=True)
        df.to_csv(path, sep='\t', index=False, header=False)

        paths_all.append(path)

lookback=100.......................................................................................
lookback=100.......................................................................................
lookback=100.......................................................................................
lookback=100.......................................................................................


In [6]:
for path in paths_all:
    sys.stdout.write(f'Loading {path}...\n')
    try:
        df = pd.read_csv(path, sep='\t', header=None)
    except pd.errors.EmptyDataError:
        # Decimation kept no rows, so the file is empty. Write empty splits.
        df = pd.DataFrame()

    # Sequential (unshuffled) 80/20 split. Rows were written trial by trial,
    # so the held-out rows come from the last trials. The same held-out slice
    # serves as both the validation and the test set.
    nrow_train = int(len(df)*.8)
    df_train = df.iloc[:nrow_train, :]
    df_holdout = df.iloc[nrow_train:, :]

    stem = os.path.splitext(path)[0]
    df_train.to_csv(stem + '.train.tsv', sep='\t', header=False, index=False)
    df_holdout.to_csv(stem + '.eval.tsv', sep='\t', header=False, index=False)
    df_holdout.to_csv(stem + '.test.tsv', sep='\t', header=False, index=False)

Loading data_s5/fly_mini/clf_lookback_100_fdecim_0.005.tsv...
Loading data_s5/fly_mini/clf_scrambled_lookback_100_fdecim_0.005.tsv...
Loading data_s5/fly_mini/rgr_lookback_100_fdecim_0.005.tsv...
Loading data_s5/fly_mini/rgr_scrambled_lookback_100_fdecim_0.005.tsv...
