# 5xFAD Resting State Analysis

Analysis of resting-state EEG data from the 5xFAD mouse model.

## Setup and Dependencies

In [None]:
# Activate the project environment two directories up and print its package status.
using Pkg
Pkg.activate("../..")
Pkg.status()
# Local helper files. Presumably SessionIO provides `from_hdf5` and plotify.jl provides
# `plotify`, both used below — confirm against those files.
include("../../modules/sessionIO/SessionIO.jl")
include("../../../zzz configs/custom/plotify.jl")

In [None]:
using CairoMakie

## Load Sessions

In [None]:
# Load every preprocessed session file found in `datapath`.
# `datapath` and `files` intentionally end up as globals (a `begin` block does not open a
# new scope) because later cells index into `files`.
sessions = begin
    datapath = "5xFAD-resting-state-preprocessed"
    # Skip hidden entries (e.g. `.DS_Store`) and subdirectories: they are not session
    # files and would make `from_hdf5` fail for the whole comprehension.
    files = filter(readdir(datapath)) do file
        !startswith(file, ".") && isfile(joinpath(datapath, file))
    end
    [from_hdf5(joinpath(datapath, file)) for file in files]
end

In [None]:
# Print each session file with its 1-based position; that position is what
# `selected_file_index` refers to.
println("Available session files:")
for i in eachindex(files)
    println("$i. $(files[i])")
end

## Select Session

Set `selected_file_index` below (the 1-based position of the file in the list above) to choose which session to analyze:

In [None]:
# Select which session to analyze (change this index; 1-based position in `files`).
selected_file_index = 1
# Fail with an actionable message instead of a bare BoundsError.
1 <= selected_file_index <= length(files) || throw(ArgumentError(
    "selected_file_index must be between 1 and $(length(files)), got $selected_file_index"))
selection = files[selected_file_index]

println("Selected file: $selection")

In [None]:
# Find the already-loaded session whose `session` ID equals the selected file name with
# its extension stripped.
name = splitext(selection)[1]
matching = filter(s -> s.session == name, sessions)
if isempty(matching)
    # Report the ID that was actually searched for, not the file name.
    error("no session found with ID: $name")
elseif length(matching) > 1
    @warn "Multiple sessions found with ID: $name"
end
session = first(matching)
println("Session loaded: $(session.session)")
nothing

## Session Information

In [None]:
# Print a short summary of the selected session.
let s = session
    println("Session ID: $(s.session)")
    println("Sampling Rate: $(s.sampling_rate) Hz")
    println("Data Shape: $(size(s.raw))")
    good = s.good_channels
    println("Good Channels: $(good)")
    println("Number of Good Channels: $(length(good))")
end

### Raw EEG Visualization

In [None]:
# Sample window (1-based, inclusive) for the raw-signal plot; adjust as needed.
sample_start = 1
sample_end = 1000
# Samples run along dimension 2 of `session.raw`; clamp so short recordings do not raise
# a BoundsError in the plotting cell.
sample_end = min(sample_end, size(session.raw, 2))
sample_range = sample_start:sample_end

println("Visualizing samples $sample_start to $sample_end")

In [None]:
# Overlay every channel of the raw signal over `sample_range`.
fig = Figure(size = (900, 450))
ax = plotify(
    fig,
    position = [1, 1],
    title = "Resting State EEG (RAW)",
    xlabel = "Time (ms)",
    ylabel = "Amplitude"
)

# Convert sample indices to milliseconds from the session's own sampling rate (Hz)
# instead of assuming 10 ms per sample (i.e. 100 Hz).
ms_per_sample = 1000 / session.sampling_rate
time_ms = (sample_range .- 1) .* ms_per_sample

# Rows of `session.raw` are channels: good channels in cyan, the rest in red.
for channel in axes(session.raw, 1)
    lines!(
        ax,
        time_ms,
        session.raw[channel, sample_range],
        color = channel in session.good_channels ? (:darkcyan, 0.7) : (:red, 0.7)
    )
end
fig

### Bandpower Analysis

Multitaper (DPSS) spectral estimation, per-band power computation, and robust logistic scaling of the resulting features:

In [None]:
using DSP, FFTW, Statistics, LinearAlgebra

"""
    dpss_tapers(N, NW, K)

Return the first `K` discrete prolate spheroidal sequences (Slepian tapers) of length `N`
with time-half-bandwidth product `NW`, as a `K × N` matrix with one taper per row.

The tapers are the eigenvectors of the symmetric tridiagonal matrix that commutes with the
DPSS concentration operator (Percival & Walden, *Spectral Analysis for Physical
Applications*, 1993), with `W = NW / N`:

- diagonal:     `((N - 1)/2 - n)^2 * cos(2πW)` for `n = 0, …, N-1`
- off-diagonal: `n * (N - n) / 2`              for `n = 1, …, N-1`

Rows are ordered by decreasing eigenvalue, so row 1 is the most band-concentrated taper.
Each taper has unit Euclidean norm; its overall sign is arbitrary.

Throws `ArgumentError` unless `1 ≤ K ≤ N`.
"""
function dpss_tapers(N, NW, K)
    1 <= K <= N || throw(ArgumentError("need 1 ≤ K ≤ N, got K = $K, N = $N"))
    W = NW / N
    n = 0:N-1
    main_diag = ((N - 1) / 2 .- n) .^ 2 .* cos(2π * W)
    m = 1:N-1
    off_diag = m .* (N .- m) ./ 2  # the standard form divides by 2, not 4
    A = SymTridiagonal(main_diag, off_diag)
    # Only the K largest eigenpairs are needed; they come back in ascending order.
    F = eigen(A, (N - K + 1):N)
    return permutedims(F.vectors[:, end:-1:1])
end

# Average of |rfft(signal .* taper)|² over the rows of `tapers` (one taper per row).
function _multitaper_psd(signal, tapers)
    psd = zeros(length(signal) ÷ 2 + 1)
    for taper in eachrow(tapers)
        psd .+= abs2.(rfft(signal .* taper))
    end
    return psd ./ size(tapers, 1)
end

"""
    bandpower_analysis(session; bands=nothing, nw=2)

Compute multitaper band power for every epoch and channel of `session.data`.

# Arguments
- `session`: object with a `data` array indexed as `[epoch, channel, sample]` and a
  `sampling_rate` in Hz.

# Keywords
- `bands`: collection of `name => (low, high)` pairs in Hz, either a `Dict` or an
  ordered vector of pairs. Defaults to delta, theta, alpha, beta and gamma in that
  order. Each band covers the half-open interval `low ≤ f < high`, so a frequency bin
  on a shared edge is counted in only one of the adjacent bands.
- `nw`: time-half-bandwidth product of the DPSS tapers. `floor(Int, 2nw) - 1` tapers
  are used, so `nw` must be at least 1.

# Returns
`(features, band_names)`, where `features[epoch, channel, b]` is the summed,
unnormalized multitaper PSD within band `band_names[b]`.
"""
function bandpower_analysis(session; bands=nothing, nw=2)
    # A vector of pairs keeps the default bands in a deterministic, ascending order.
    band_list = if isnothing(bands)
        [
            "delta" => (1, 4),
            "theta" => (4, 8),
            "alpha" => (8, 12),
            "beta" => (12, 25),
            "gamma" => (25, 50),
        ]
    else
        collect(bands)
    end
    band_names = first.(band_list)

    fs = session.sampling_rate
    epochs, channels, samples = size(session.data)
    frequencies = rfftfreq(samples, fs)

    # Integer taper count, also valid for non-integer `nw`.
    K = floor(Int, 2 * nw) - 1
    K >= 1 || throw(ArgumentError("nw must be at least 1, got $nw"))
    tapers = dpss_tapers(samples, nw, K)

    # Band masks depend only on the frequency grid, so build them once.
    masks = [(frequencies .>= low) .& (frequencies .< high) for (_, (low, high)) in band_list]

    features = zeros(epochs, channels, length(band_list))
    for epoch in 1:epochs, channel in 1:channels
        psd = _multitaper_psd(view(session.data, epoch, channel, :), tapers)
        for (b, mask) in enumerate(masks)
            features[epoch, channel, b] = sum(psd[mask])
        end
    end

    return features, band_names
end

"""
    logistic_scaler(features)

Robustly squash `features[epoch, channel, band]` into [0, 1], independently for each
`(channel, band)` column across epochs.

Each column `x` goes through a logistic centred on its median with slope
`λ = 2 log(3) / IQR`, which maps the first and third quartiles to exactly 0.25 and 0.75.
If the IQR is zero, the column falls back to min–max scaling. A constant column (where
min–max is undefined) maps to 0.5, the logistic's midpoint, instead of NaN.
"""
function logistic_scaler(features)
    scaled = zeros(size(features))
    _, channels, bands = size(features)

    for channel in 1:channels, band in 1:bands
        x = features[:, channel, band]
        q1, med, q3 = quantile(x, [0.25, 0.5, 0.75])
        iqr = q3 - q1

        if iqr != 0
            λ = 2 * log(3) / iqr
            scaled[:, channel, band] .= 1 ./ (1 .+ exp.(-λ .* (x .- med)))
        else
            lo, hi = extrema(x)
            # Guard the constant column: min–max would divide by zero.
            scaled[:, channel, band] .= hi == lo ? 0.5 : (x .- lo) ./ (hi - lo)
        end
    end

    return scaled
end

# Confirmation in the cell output that the definitions above were evaluated.
println("Bandpower analysis functions defined!")

In [None]:
# Bandpower needs epoched data (`session.data`); otherwise just report the raw shape.
if !hasfield(typeof(session), :data) || isnothing(session.data)
    println("Session does not have epoched data. Raw data shape: $(size(session.raw))")
    println("You may need to epoch the data first.")
else
    println("Session has epoched data with shape: $(size(session.data))")

    # Multitaper bandpower per (epoch, channel, band) ...
    println("Computing bandpower features...")
    bandpower_features, band_names = bandpower_analysis(session)

    # ... then robustly squashed per channel/band across epochs.
    println("Applying logistic scaling...")
    scaled_features = logistic_scaler(bandpower_features)

    println("Bandpower computation complete!")
    println("Feature shape: $(size(scaled_features))")
    println("Bands: $(band_names)")
end

### Bandpower Visualization

Plot the scaled bandpower of a single epoch for each good channel, one series per frequency band:

In [None]:
# Scatter the scaled bandpower of one epoch for each good channel, one series per band.
# Only runs once the bandpower cell has defined `scaled_features`.
if @isdefined scaled_features
    epoch_idx = 1  # epoch to visualize; change as needed

    fig_bp = Figure(size = (1000, 600))

    # Rows of `good_ch_data` follow the order of `session.good_channels`.
    good_ch_data = scaled_features[epoch_idx, session.good_channels, :]

    ax_bp = plotify(
        fig_bp,
        position = [1, 1],
        title = "Bandpower Features (Epoch $epoch_idx)",
        xlabel = "Channels",
        ylabel = "Scaled Bandpower"
    )

    # Plot at the real channel indices so excluded (bad) channels leave visible gaps
    # instead of the remaining channels being renumbered 1..n.
    channel_ids = collect(session.good_channels)
    colors = [:red, :orange, :green, :blue, :purple]
    for (i, band) in enumerate(band_names)
        scatter!(
            ax_bp,
            channel_ids,
            good_ch_data[:, i],
            # Cycle the palette so more than five bands does not raise a BoundsError.
            color = colors[mod1(i, length(colors))],
            label = band,
            markersize = 8
        )
    end

    axislegend(ax_bp, position = :rt)
    fig_bp
else
    println("Bandpower features not computed yet. Run the bandpower analysis cell first.")
end