-
Notifications
You must be signed in to change notification settings - Fork 7
/
plot_designmatrix.jl
118 lines (100 loc) · 3.7 KB
/
plot_designmatrix.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
plot_designmatrix!(f::Union{GridPosition, GridLayout, Figure}, data::Unfold.DesignMatrix; kwargs...)
plot_designmatrix(data::Unfold.DesignMatrix; kwargs...)
Plot a designmatrix.
## Arguments
- `f::Union{GridPosition, GridLayout, Figure}`\\
`Figure`, `GridLayout`, or `GridPosition` to draw the plot.
- `data::Unfold.DesignMatrix`\\
Data for the plot visualization.
## Keyword arguments (kwargs)
- `standardize_data::Bool = false`\\
Indicates whether the data is standardized by pointwise division of the data with its sampled standard deviation.
- `sort_data::Bool = false`\\
Indicates whether the data is sorted. It uses `sortslices()` of Base Julia.
- `xticks::Num = nothing`\\
Sets the number of labels shown on the x axis. With `nothing` (the default), all labels are placed.
- `xticks` = 0: no labels are placed.
- `xticks` = 1: first possible label is placed.
- `xticks` = 2: first and last possible labels are placed.
- 2 < `xticks` < `number of labels`: equally distribute the labels.
- `xticks` ≥ `number of labels`: all labels are placed.
$(_docstring(:designmat))
**Return Value:** `Figure` displaying the Design matrix.
"""
function plot_designmatrix(
    data::Union{<:Vector{<:AbstractDesignMatrix},<:AbstractDesignMatrix};
    kwargs...,
)
    # Non-mutating entry point: draw into a freshly created `Figure` and forward
    # every keyword argument unchanged to `plot_designmatrix!`.
    fig = Figure()
    return plot_designmatrix!(fig, data; kwargs...)
end
# Vector method: only a single design matrix can be drawn, so the first element of
# `data` is plotted. A warning is logged when `data` holds more than one matrix.
# Returns whatever the single-matrix `plot_designmatrix!` method returns.
function plot_designmatrix!(f, data::Vector{<:AbstractDesignMatrix}; kwargs...)
    has_extra_matrices = length(data) > 1
    has_extra_matrices &&
        @warn "multiple $(length(data)) designmatrices found, plotting the first one"
    return plot_designmatrix!(f, data[1]; kwargs...)
end
# Thin out x-axis tick labels so that at most `xticks` of them stay visible.
#
# Returns `labels` untouched when `xticks === nothing` or `xticks` is not smaller than
# `length(labels)`. Otherwise returns a vector of the same length as `labels` in which
# hidden entries are replaced by "" so that every kept label stays on its own column:
# - `xticks >= 1` keeps the first label,
# - `xticks >= 2` additionally keeps the last label,
# - `xticks > 2` additionally keeps labels spread over the columns in between.
# Throws an `AssertionError` for negative `xticks`.
function _thin_xtick_labels(labels, xticks)
    nlabels = length(labels)
    needs_thinning = xticks !== nothing && xticks < nlabels
    needs_thinning || return labels
    @assert(xticks >= 0, "xticks shouldn't be negative")

    # One entry per column; `Any` element type because the element type of `labels`
    # is not known here.
    thinned = Any["" for _ = 1:nlabels]

    if xticks >= 1
        thinned[1] = labels[1]
    end

    # Interior ticks only exist for more than two ticks. Guarding here also avoids
    # dividing by zero (`xticks == 1`) or by a negative number (`xticks == 0`).
    if xticks > 2
        # sections between xticks
        section_size = (nlabels - 2) / (xticks - 1)
        for i = 1:(nlabels-2)
            # at the end of a section, but NO tick on the very last section
            # (that one is covered by the last label below)
            if i % section_size < 1 && i < ((xticks - 1) * section_size)
                thinned[i+1] = labels[i+1]
            end
        end
    end

    # The last column carries the LAST label (not the second-to-last one).
    if xticks >= 2
        thinned[nlabels] = labels[nlabels]
    end
    return thinned
end

# Draw the design matrix of `data` as a heatmap into `f[1, 1]`.
#
# Arguments
# - `f`: `Figure`, `GridLayout` or `GridPosition` to draw into.
# - `data`: design matrix object; its model matrix is read via
#   `UnfoldMakie.modelmatrices` and its column names via `Unfold.get_coefnames`.
# Keyword arguments
# - `xticks`: number of x-axis labels to show; `nothing` shows all of them.
# - `sort_data`: sort the rows with `sortslices`. Ignored (with a warning) for sparse,
#   i.e. time-expanded, matrices.
# - `standardize_data`: divide every column by its standard deviation.
# - `kwargs...`: forwarded to `config_kwargs!` to customise the plot configuration.
# Returns `f`.
function plot_designmatrix!(
    f::Union{GridPosition,GridLayout,Figure},
    data::AbstractDesignMatrix;
    xticks = nothing,
    sort_data = false,
    standardize_data = false,
    kwargs...,
)
    config = PlotConfig(:designmat)
    config_kwargs!(config; kwargs...)

    designmat = UnfoldMakie.modelmatrices(data)
    if standardize_data
        designmat = designmat ./ std(designmat, dims = 1)
        # columns with zero standard deviation produce Inf; show them as 1.0
        designmat[isinf.(designmat)] .= 1.0
    end

    # Remember the sparsity before the matrix is converted to a dense one below;
    # the flag is needed again when the axis is configured.
    is_sparse = isa(designmat, SparseMatrixCSC)
    if is_sparse
        if sort_data
            @warn "Sorting does not make sense for time-expanded designmatrices. sort_data has been set to `false`"
            sort_data = false
        end
        # Plot only a window of rows around the middle of the matrix. The window is
        # clamped to the valid row range so that short matrices do not raise a
        # `BoundsError`.
        nrows = size(designmat, 1)
        center = nrows ÷ 2
        window = max(1, center - 2000):min(nrows, center + 2000)
        designmat = Matrix(designmat[window, :])
    end
    if sort_data
        designmat = Base.sortslices(designmat, dims = 1)
    end

    # only change xticks if we want less than all
    labels = _thin_xtick_labels(Unfold.get_coefnames(data), xticks)

    # plot Designmatrix
    config.axis = merge(config.axis, (; xticks = (1:length(labels), labels)))
    ax = Axis(f[1, 1]; config.axis...)
    # transposed so that the matrix columns (one per coefficient name) lie on the x axis
    hm = heatmap!(ax, designmat'; config.visual...)
    if is_sparse
        ax.yreversed = true
    end
    apply_layout_settings!(config; fig = f, hm = hm)
    return f
end