## Copy of : 2024-09-30-wt

In [4]:
using Pkg
using IJulia
using DifferentialEquations
using Plots, CSV, XLSX, DataFrames
using Peaks, Statistics,Roots, NaNStatistics, Random
using StaticArrays
using IterTools
using Printf
using XLSX
using DataFrames

In [13]:
"""
    freeR(rt, at, ks)

Return `(rt - at - ks + sqrt((at - rt - ks)^2 + 4*at*ks)) / 2`, evaluated element-wise.

Algebraically this is the non-negative root of `x^2 + (at - rt + ks)*x - ks*rt = 0`
(presumably the free amount of `R` in a 1:1 binding equilibrium with dissociation
constant `ks` -- confirm against the model description).

Every operation is broadcast, so `rt`, `at` and `ks` may each be a scalar or an array of
compatible shape; all-scalar input returns a scalar.
"""
function freeR(rt, at, ks)
    return @. (rt - at - ks + sqrt((at - rt - ks)^2 + 4 * at * ks)) / 2
end
"""
    freeA(rt, at, ks)

Return `(at - rt - ks + sqrt((at - rt - ks)^2 + 4*at*ks)) / 2`, evaluated element-wise.

Algebraically this is the non-negative root of `x^2 + (rt - at + ks)*x - ks*at = 0`
(presumably the free amount of `A` in a 1:1 binding equilibrium with dissociation
constant `ks` -- confirm against the model description).

Every operation is broadcast, so `rt`, `at` and `ks` may each be a scalar or an array of
compatible shape; all-scalar input returns a scalar.
"""
function freeA(rt, at, ks)
    return @. (at - rt - ks + sqrt((at - rt - ks)^2 + 4 * at * ks)) / 2
end
"""
    compRA(rt, at, ks)

Return `(at + rt + ks - sqrt((at - rt - ks)^2 + 4*at*ks)) / 2`, evaluated element-wise.

Algebraically this is the smaller root of `x^2 - (at + rt + ks)*x + at*rt = 0`
(presumably the amount of `R`-`A` complex in a 1:1 binding equilibrium with dissociation
constant `ks` -- confirm against the model description).

Every operation is broadcast, so `rt`, `at` and `ks` may each be a scalar or an array of
compatible shape; all-scalar input returns a scalar.
"""
function compRA(rt, at, ks)
    return @. (at + rt + ks - sqrt((at - rt - ks)^2 + 4 * at * ks)) / 2
end
"""
    coeffI(rt, at, ka, ks, kb, kd)

Return the ratio

    (ks + ks*kd/kb + R) / (ks + ks*kd/kb + R * ks*kd/(ka*kb))

where `R = freeR(rt, at, ks)`, evaluated element-wise.

`freeR` is evaluated once and reused in numerator and denominator. All arithmetic is
broadcast, so the arguments may be scalars or arrays of compatible shape.
"""
function coeffI(rt, at, ka, ks, kb, kd)
    r_free = freeR(rt, at, ks)
    numerator = @. ks + (ks * kd / kb) + r_free
    denominator = @. ks + (ks * kd / kb) + r_free * (ks * kd / (ka * kb))
    return numerator ./ denominator
end
"""
    coeffJ(rt, at, ka, ks, kb, kd)

Return the ratio

    (ks + ka + R) / (ks + ks*kd/kb + R * ks*kd/(ka*kb))

where `R = freeR(rt, at, ks)`, evaluated element-wise.

`freeR` is evaluated once and reused in numerator and denominator. All arithmetic is
broadcast, so the arguments may be scalars or arrays of compatible shape.
"""
function coeffJ(rt, at, ka, ks, kb, kd)
    r_free = freeR(rt, at, ks)
    numerator = @. ks + ka + r_free
    denominator = @. ks + (ks * kd / kb) + r_free * (ks * kd / (ka * kb))
    return numerator ./ denominator
end
"""
    tranBSD(rt, at, ka, ks, kb, kd)

Return `I / (1 + I + J * freeR(rt, at, ks) / kb)` with

    I = coeffI(rt, at, ka, ks, kb, kd) * freeA(rt, at, ks) / ka
    J = coeffJ(rt, at, ka, ks, kb, kd) * freeA(rt, at, ks) / ka

evaluated element-wise. This value is used as the production term of the first state
variable in `modelDro!` and `modelWT!`.

Each helper is evaluated once and reused; the arithmetic order matches the single-line
formula, so results are unchanged.
"""
function tranBSD(rt, at, ka, ks, kb, kd)
    a_free = freeA(rt, at, ks)
    r_free = freeR(rt, at, ks)
    term_i = coeffI(rt, at, ka, ks, kb, kd) .* a_free ./ ka
    term_j = coeffJ(rt, at, ka, ks, kb, kd) .* a_free ./ ka
    return term_i ./ (1 .+ term_i .+ term_j .* (r_free ./ kb))
end
"""
    rhythm(t, amp, p, ph, lev)

Return the cosine wave `(amp / 2) * cos(2π * (t - ph) / p) + lev` at time `t`.

`amp` is the peak-to-trough amplitude, `p` the period, `ph` the time of the maximum and
`lev` the mean level around which the wave oscillates.
"""
function rhythm(t, amp, p, ph, lev)
    phase = 2 * pi * (t - ph) / p
    return (amp / 2) * cos(phase) + lev
end
"""
    modelDro!(dp, p, pa, t)

In-place right-hand side of a three-variable ODE system (it is passed to `ODEProblem`
elsewhere in this file). Writes the derivatives of `p[1:3]` into `dp[1:3]`.

`pa` must unpack into `(a3, b1, b2, b3, AT, Ka, Ks, Kb, Kd)`. `t` is accepted for the
solver's calling convention but is not used.
"""
function modelDro!(dp, p, pa, t)
    a3, b1, b2, b3, AT, Ka, Ks, Kb, Kd = pa
    production = tranBSD(p[3], AT, Ka, Ks, Kb, Kd)
    conversion = a3 * p[2]  # leaves variable 2 and enters variable 3
    dp[1] = production - b1 * p[1]
    dp[2] = p[1] - b2 * p[2] - conversion
    dp[3] = conversion - b3 * p[3]
end
"""
    modelWT!(dp, p, pa, t)

In-place right-hand side of the three-variable ODE system with time-dependent rates (it
is passed to `ODEProblem` elsewhere in this file). Writes the derivatives of `p[1:3]`
into `dp[1:3]`.

`pa` must unpack into
`(a3, b1, b2, b3, AT, Ka, Ks, Kb, Kd, ampD, phD, ampT, phT)`. Compared with `modelDro!`,
the decay rate of variable 1 and the rate multiplying `p[1]` in the second equation are
24-period cosine waves produced by `rhythm`, with mean levels `b1` and `1` respectively.
"""
function modelWT!(dp, p, pa, t)
    a3, b1, b2, b3, AT, Ka, Ks, Kb, Kd, ampD, phD, ampT, phT = pa
    decay_rate = rhythm(t, ampD, 24, phD, b1)        # Decay
    translation_rate = rhythm(t, ampT, 24, phT, 1)   # Translation
    dp[1] = tranBSD(p[3], AT, Ka, Ks, Kb, Kd) - decay_rate * p[1]
    dp[2] = translation_rate * p[1] - b2 * p[2] - a3 * p[2]
    dp[3] = a3 * p[2] - b3 * p[3]
end

"""
    format_array(arr)

Return a new array holding every element of `arr` rounded to 4 decimal places.
"""
function format_array(arr)
    return collect(round(value; digits=4) for value in arr)
end

"""
    cv(x)

Return the coefficient of variation of `x`: its standard deviation divided by its mean.
"""
function cv(x)
    spread = std(x)
    centre = mean(x)
    return spread / centre
end

# Build one panel: the simulated trace as a green line over `0:0.01:bound`, the
# experimental points as red dots, both axes fixed (x: 0..bound, y: 0..1.2).
function _trace_panel(simulated, timept, observed, bound, heading)
    panel = plot(0:0.01:bound, simulated, c=:green, label="Sim", linewidth=3)
    plot!(panel, timept, observed, seriestype=:scatter, c=:red, label="Exp (dots)",
        legend=:bottomleft)
    xlims!(panel, 0, bound)
    ylims!(panel, 0, 1.2)
    title!(panel, heading)
    return panel
end

"""
    create_combined_plot(detWT, timept, wtMper, wtPER, bound, wtMperScale, wtPerScale)

Return a 1200x1200 figure with two stacked panels comparing simulation and experiment.

The top panel ("WT_mPer") shows row 1 of `detWT` divided by `wtMperScale` against the
points `(timept, wtMper)`; the bottom panel ("WT_PER") shows row 3 divided by
`wtPerScale` against `(timept, wtPER)`.

Only saved samples `90001:(90001 + bound*100)` of `detWT` are drawn, plotted against
`0:0.01:bound`, so `bound` must be an integer and `detWT` must contain at least
`90001 + bound*100` saved samples.
"""
function create_combined_plot(detWT, timept, wtMper, wtPER, bound, wtMperScale, wtPerScale)
    window = 90001:(90001 + bound * 100)
    p1 = _trace_panel(detWT[1, window] / wtMperScale, timept, wtMper, bound, "WT_mPer")
    p2 = _trace_panel(detWT[3, window] / wtPerScale, timept, wtPER, bound, "WT_PER")
    return plot(p1, p2, layout=(2, 1), size=(1200, 1200))
end

"""
    find_significant_extrema(data, window_size=100)

Return `(peak_indices, trough_indices)` for `data`.

Candidate extrema come from `findmaxima` / `findminima`; a candidate is kept only when
`is_significant_extremum` accepts it, i.e. when it is strictly the largest (peaks) or
smallest (troughs) sample within `window_size` samples on either side.
"""
function find_significant_extrema(data, window_size=100)
    candidate_peaks, _ = findmaxima(data)
    candidate_troughs, _ = findminima(data)
    keep_peak(i) = is_significant_extremum(data, i, window_size, >)
    keep_trough(i) = is_significant_extremum(data, i, window_size, <)
    return filter(keep_peak, candidate_peaks), filter(keep_trough, candidate_troughs)
end
"""
    is_significant_extremum(data, index, window_size, compare_func)

Return `true` when `compare_func(data[index], data[j])` holds for every other position
`j` within `window_size` samples of `index` (the window is clipped to the bounds of
`data`), and `false` as soon as one comparison fails.

Pass `>` to test for a strict local maximum and `<` for a strict local minimum.
"""
function is_significant_extremum(data, index, window_size, compare_func)
    lo = max(1, index - window_size)
    hi = min(length(data), index + window_size)
    for neighbour in lo:hi
        neighbour == index && continue
        compare_func(data[index], data[neighbour]) || return false
    end
    return true
end
"""
    analyze_oscillations(time, data, n_cycles=10)

Walk through the significant extrema of `data` in index order and return

    (peak_values, trough_values, peak_times, trough_times, amplitudes, periods)

all as `Vector{Float64}`. `time[i]` is the time stamp of `data[i]`.

- An amplitude (`peak - most recent trough`) is recorded at every peak that has at least
  one trough before it.
- A period (`peak time - previous peak time`) is recorded at every peak after the first.
- The walk stops once `n_cycles` periods have been recorded.
"""
function analyze_oscillations(time, data, n_cycles=10)
    peak_idx, trough_idx = find_significant_extrema(data)
    peak_lookup = Set(peak_idx)

    peak_values = Float64[]
    trough_values = Float64[]
    peak_times = Float64[]
    trough_times = Float64[]
    amplitudes = Float64[]
    periods = Float64[]

    previous_peak_time = nothing
    completed_cycles = 0
    for idx in sort(vcat(peak_idx, trough_idx))
        completed_cycles >= n_cycles && break
        if idx in peak_lookup
            push!(peak_values, data[idx])
            push!(peak_times, time[idx])
            # Amplitude is measured from the latest trough up to this peak.
            isempty(trough_values) || push!(amplitudes, data[idx] - trough_values[end])
            if previous_peak_time !== nothing
                # A peak-to-peak interval closes one cycle.
                push!(periods, time[idx] - previous_peak_time)
                completed_cycles += 1
            end
            previous_peak_time = time[idx]
        else
            push!(trough_values, data[idx])
            push!(trough_times, time[idx])
        end
    end
    return peak_values, trough_values, peak_times, trough_times, amplitudes, periods
end

"""
    calculate_pcm_scaling_factor(paInitPcm)

Integrate `modelDro!` with parameters `paInitPcm` from `zeros(3)` over `t = 0..1200`
(saved every 0.01) and return `(dp1_div_factor, dp3_div_factor)`: the largest value
among the significant peaks of state variables 1 and 3 within saved samples
`90001:120001`. A factor falls back to `1.0` when its signal has no significant peak in
that window.

The solution must contain all 120001 saved samples; a shorter solution makes the window
indexing throw.
"""
function calculate_pcm_scaling_factor(paInitPcm)
    bound = 300
    window = 90001:(90001 + bound * 100)

    # Solve the ODE problem
    prob = ODEProblem(modelDro!, zeros(3), (0.0, 1200.0), paInitPcm)
    detDro = solve(prob, Rosenbrock23(); saveat=0.01)

    # Use views instead of slices
    dp1_solution_pcm = @view detDro[1, window]
    dp3_solution_pcm = @view detDro[3, window]

    # One detection pass per signal is enough: only the peak heights are needed here.
    peaks_dp1, _ = find_significant_extrema(dp1_solution_pcm)
    peaks_dp3, _ = find_significant_extrema(dp3_solution_pcm)

    # Scaling factor = highest significant peak, or 1.0 when there is none
    dp1_div_factor =
        isempty(peaks_dp1) ? 1.0 : maximum(i -> dp1_solution_pcm[i], peaks_dp1)
    dp3_div_factor =
        isempty(peaks_dp3) ? 1.0 : maximum(i -> dp3_solution_pcm[i], peaks_dp3)

    return dp1_div_factor, dp3_div_factor
end

global saved_count = 0

# Render "name = value" pairs (values rounded to 4 decimals) joined by ", ".
function _format_named_values(names, values)
    entries = ("$(name) = $(round(value, digits=4))" for (name, value) in zip(names, values))
    return join(entries, ", ")
end

"""
    calculate_amp_CV_period_CV(paInit, wtMper, wtPER)

Simulate `modelWT!` with the 13 parameters in `paInit` and, if the oscillation is regular
enough, print the parameters and save a comparison figure. Returns `nothing`.

Steps:
1. Integrate from `zeros(3)` over `t = 0..1200` (saved every 0.01) and analyse state
   variables 1 and 3 within saved samples `90001:120001`.
2. For each signal compute the CV of the relative amplitudes (amplitude / peak value),
   the CV of the peak-to-peak periods, and the period cost
   `sqrt(sum((1 - period/24)^2))`.
3. When all six numbers are below 0.1: print parameters and costs, save a PNG produced by
   `create_combined_plot` (scaled by `calculate_pcm_scaling_factor(paInit[1:9])`) into
   `graphs_output_20241002_folder_1`, and increment the global `saved_count`.

`wtMper` and `wtPER` are only forwarded to the plot.
"""
function calculate_amp_CV_period_CV(paInit, wtMper, wtPER)
    global saved_count

    a3, b1, b2, b3, AT, Ka, Ks, Kb, Kd, ampD, phD, ampT, phT = paInit

    time_range = 0:0.01:300
    bound = 300
    window = 90001:(90001 + bound * 100)
    # NOTE(review): `timept` has 76 entries (0:4:300) while the caller in this file passes
    # 6-point `wtMper`/`wtPER`; confirm the plotting backend's handling of that length
    # mismatch is the intended display.
    timept = collect(0:4:bound)

    # Solve the ODE problem
    prob = ODEProblem(modelWT!, zeros(3), (0.0, 1200.0), paInit)
    detWT = solve(prob, Rosenbrock23(); saveat=0.01)

    # Scale factors are used for plotting only; the CVs and period costs below do not
    # depend on the overall scale of the signal.
    wtMperScale, wtPerScale = calculate_pcm_scaling_factor(paInit[1:9])

    # Use views instead of slices
    dp1_solution = @view detWT[1, window]
    dp3_solution = @view detWT[3, window]

    # `typemax(Int)` as cycle limit means "analyse every peak in the window".
    peaks_dp1, _, _, _, amplitudes_dp1, periods_dp1 =
        analyze_oscillations(time_range, dp1_solution, typemax(Int))
    peaks_dp3, _, _, _, amplitudes_dp3, periods_dp3 =
        analyze_oscillations(time_range, dp3_solution, typemax(Int))

    # `analyze_oscillations` records an amplitude only for peaks that follow a trough, so
    # the amplitudes belong to the LAST `length(amplitudes)` peaks. Pair them accordingly
    # so each amplitude is divided by its own peak value.
    relamp_1 = amplitudes_dp1 ./ peaks_dp1[(end - length(amplitudes_dp1) + 1):end]
    relamp_3 = amplitudes_dp3 ./ peaks_dp3[(end - length(amplitudes_dp3) + 1):end]

    periods_wtMper = format_array(periods_dp1)
    periods_wtPER = format_array(periods_dp3)

    # Calculate the period cost (distance of the periods from 24)
    cost_periods_wtMper = sqrt(sum((1 .- periods_wtMper ./ 24) .^ 2))
    cost_periods_wtPER = sqrt(sum((1 .- periods_wtPER ./ 24) .^ 2))

    cv_relamp_1 = cv(relamp_1)
    cv_relamp_3 = cv(relamp_3)
    cv_periods_wtMper = cv(periods_wtMper)
    cv_periods_wtPER = cv(periods_wtPER)

    costs = (cv_relamp_1, cv_relamp_3, cv_periods_wtMper, cv_periods_wtPER,
        cost_periods_wtMper, cost_periods_wtPER)

    # Every cost must be below the threshold; a NaN cost (too few cycles) fails the test.
    threshold = 0.1
    if all(<(threshold), costs)
        param_names = ("a3", "b1", "b2", "b3", "AT", "Ka", "Ks", "Kb", "Kd",
            "ampD", "phD", "ampT", "phT")
        param_values = (a3, b1, b2, b3, AT, Ka, Ks, Kb, Kd, ampD, phD, ampT, phT)
        cost_names = ("cv(relamp_1)", "cv(relamp_3)", "cv(periods_wtMper)",
            "cv(periods_wtPER)", "cost_periods_wtMper", "cost_periods_wtPER")

        # Print the parameters and the cost values, rounded to 4 decimals
        println("WT param: ", _format_named_values(param_names, param_values))
        println("Costs: ", _format_named_values(cost_names, costs), "\n")

        # Create the folder that receives the images (no-op when it already exists)
        output_folder = "graphs_output_20241002_folder_1"
        mkpath(output_folder)

        # Build the file name and save the figure
        filename = @sprintf("ampD=%.4f_phD=%.4f_ampT=%.4f_phT=%.4f.png",
            ampD, phD, ampT, phT)
        combined_plot = create_combined_plot(detWT, timept, wtMper, wtPER, bound,
            wtMperScale, wtPerScale)
        filepath = joinpath(output_folder, filename)
        savefig(combined_plot, filepath)

        # NOTE: exporting the parameters and costs to an Excel workbook
        # ("oscillation_results.xlsx") is not implemented; results are only printed and
        # plotted.

        # Increment saved count and print progress
        saved_count += 1
        if saved_count % 50 == 0
            println("Progress: $saved_count permutations saved.")
        end
    end
    return nothing
end


calculate_amp_CV_period_CV (generic function with 2 methods)

## main()

In [15]:
# Experimental series at six time points; each series is normalised by the maximum of the
# corresponding `pcm` series.
timept = [0, 4, 8, 12, 16, 20]
wtMper = [34.845, 26.06667, 26.09167, 49.225, 55.625, 64.27083]
pcmMper = [54.59242, 57.97024, 39.78646, 37.68333, 38.33725, 47.00128]
normMper = maximum(pcmMper)
wtMper = wtMper/normMper
pcmMper = pcmMper/normMper

wtPER = [788.8639, 558.5368, 401.4341, 330.7683, 308.5475, 336.6191]
pcmPER = [1180.3421, 1028.4309, 1521.4779, 1305.6787, 895.0438, 821.742]
normPER = maximum(pcmPER)
wtPER = wtPER/normPER
pcmPER = pcmPER/normPER

# Define constants
b1 = 0.022480368

# Define ranges for parameters
ampD = [2 * i * b1 / 10 for i in 1:10]
phD = [0, 3, 6, 9, 12, 15, 18, 21]
ampT = [0.2 * i for i in 1:10]
phT = [0, 3, 6, 9, 12, 15, 18, 21]

# For a quick single-point test, replace the four ranges above with:
#   ampD = [0.001]; phD = [0]; ampT = [0.001]; phT = [0]

# Use Iterators.product for lazy evaluation of permutations
permutations = Iterators.product(ampD, phD, ampT, phT)

# Calculate the total number of permutations without creating a full list
total_permutations = length(permutations)
println("Total permutations to attempt: $total_permutations")

# The first nine parameters (a3, b1, b2, b3, AT, Ka, Ks, Kb, Kd) are the same for every
# permutation.
fixed_params = [1.2591611472941155, b1, 0.12111630480870345, 0.9419074673224994,
                0.094151061, 5.1853665988298946e-5, 3.0062548853823446e-9,
                7.670935885911712e-5, 20.62876347705542]

# Iterate through permutations without storing them in memory. The loop variables have
# their own names so the global range arrays above are not overwritten, and `enumerate`
# supplies the counter so no global is mutated inside the loop.
for (i, (ampD_i, phD_i, ampT_i, phT_i)) in enumerate(permutations)
    pcmPa = vcat(fixed_params, ampD_i, phD_i, ampT_i, phT_i)

    # Perform the calculation
    calculate_amp_CV_period_CV(pcmPa, wtMper, wtPER)

    # Track progress and print every 300 iterations
    if i % 300 == 0
        percent_complete = round(i / total_permutations * 100, digits=2)
        println("Progress: $i / $total_permutations total permutations attempted ($(percent_complete)% completed)")
    end
end



Total permutations to attempt: 6400
Progress: 300 / 6400 total permutations attempted (4.69% completed)
Progress: 600 / 6400 total permutations attempted (9.38% completed)
WT param: a3 = 1.2592, b1 = 0.0225, b2 = 0.1211, b3 = 0.9419, AT = 0.0942, Ka = 0.0001, Ks = 0.0, Kb = 0.0001, Kd = 20.6288, ampD = 0.0045, phD = 0.0, ampT = 2.0, phT = 0.0
Costs: cv(relamp_1) = 0.0045, cv(relamp_3) = 0.0, cv(periods_wtMper) = 0.0007, cv(periods_wtPER) = 0.0001, cost_periods_wtMper = 0.0024, cost_periods_wtPER = 0.0004
WT param: a3 = 1.2592, b1 = 0.0225, b2 = 0.1211, b3 = 0.9419, AT = 0.0942, Ka = 0.0001, Ks = 0.0, Kb = 0.0001, Kd = 20.6288, ampD = 0.009, phD = 0.0, ampT = 2.0, phT = 0.0
Costs: cv(relamp_1) = 0.0038, cv(relamp_3) = 0.0001, cv(periods_wtMper) = 0.0004, cv(periods_wtPER) = 0.0002, cost_periods_wtMper = 0.0013, cost_periods_wtPER = 0.0006
WT param: a3 = 1.2592, b1 = 0.0225, b2 = 0.1211, b3 = 0.9419, AT = 0.0942, Ka = 0.0001, Ks = 0.0, Kb = 0.0001, Kd = 20.6288, ampD = 0.0135, phD = 0.0,

LoadError: InterruptException: