-
Notifications
You must be signed in to change notification settings - Fork 15
/
get_plot_forecast_data.R
224 lines (205 loc) · 8.66 KB
/
get_plot_forecast_data.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
#' Combine [load_truth()] and [pivot_forecasts_wider()]
#'
#' Filters `forecast_data` to the requested models, forecast dates, locations,
#' target variable and horizons, reshapes it with [pivot_forecasts_wider()],
#' and, when `plot_truth` is `TRUE`, row-binds truth data that is either
#' supplied through `truth_data` or loaded with [load_truth()].
#'
#' @param forecast_data required data.frame with forecasts in the format
#' returned by [load_forecasts()].
#' It has columns `model`, `forecast_date`, `location`, `target_variable`,
#' `type`, `quantile`, `value`, `horizon` and `target_end_date`.
#' @param truth_data optional data.frame from one truth source in the format
#' returned by [load_truth()]. It needs to have columns `model`,
#' `target_variable`, `target_end_date`, `location` and `value`.
#' The `model` column can be "Observed Data (a truth source)".
#' @param models_to_plot characters of model abbreviations
#' @param forecast_dates_to_plot date string vectors for forecast dates to
#' plot. Default to all forecast dates available in `forecast_data`.
#' An error is raised if any requested date is not in `forecast_data`.
#' @param horizons_to_plot forecasts are plotted for the horizon time steps
#' after the forecast date: only rows with `horizon <= horizons_to_plot` are
#' kept. If missing, all horizons in `forecast_data` are kept.
#' @param quantiles_to_plot vector of quantiles to include in the plot.
#' Passed on to [pivot_forecasts_wider()].
#' @param locations_to_plot a vector of strings of fips code or CBSA codes or
#' location names, such as "Hampshire County, MA", "Alabama",
#' "United Kingdom".
#' A US county location name must include the state abbreviation.
#' Default to `NULL` which would include all locations available in
#' `forecast_data`. Locations that are not valid for `hub` are dropped.
#' @param plot_truth logical to indicate whether truth data should be plotted.
#' Default to `TRUE`.
#' @param truth_source character specifying where the truth data will
#' be loaded from if `truth_data` is not provided. Currently support `"JHU"`,
#' `"NYTimes"` and `"HealthData"` for the US hub, `"JHU"` and `"ECDC"` for
#' the ECDC hub and `"HealthData"` for the FluSight hub.
#' Optional if `truth_data` is provided or if `plot_truth` is `FALSE`.
#' @param target_variable_to_plot string specifying target type. It should be
#' one of `"cum death"`, `"inc case"`, `"inc death"`, `"inc hosp"` and
#' `"inc flu hosp"`.
#' @param truth_as_of the plot includes the truth data that would have been
#' in real time as of the `truth_as_of` date. Versioned truth data is
#' currently not supported: a non-`NULL` value only triggers a warning.
#' @param hub character, which hub to use. Default is `"US"`.
#' Other options are `"ECDC"` and `"FluSight"`.
#'
#' @return data.frame in the format returned by [pivot_forecasts_wider()]
#' with an added `truth_forecast` column (`"forecast"` or `"truth"`).
#' The original location code is stored in `abbr` and `location` holds the
#' location name (`full_location_name` for the US hub, `location_name`
#' otherwise). When `plot_truth` is `TRUE`, truth rows are appended with the
#' observed value stored in `point`.
#'
#' @export
get_plot_forecast_data <- function(forecast_data,
                                   truth_data = NULL,
                                   models_to_plot,
                                   forecast_dates_to_plot,
                                   horizons_to_plot,
                                   quantiles_to_plot,
                                   locations_to_plot = NULL,
                                   plot_truth = TRUE,
                                   truth_source,
                                   target_variable_to_plot,
                                   truth_as_of = NULL,
                                   hub = c("US", "ECDC", "FluSight")) {
  hub <- match.arg(hub,
    choices = c("US", "ECDC", "FluSight"),
    several.ok = TRUE
  )

  # get lists of valid parameter choices based on the first entry of `hub`
  # NOTE: target_variable_to_plot is not validated against the hub here.
  if (hub[1] == "US") {
    valid_location_codes <- covidHubUtils::hub_locations$fips
    valid_truth_sources <- c("JHU", "NYTimes", "HealthData", "USAFacts")
  } else if (hub[1] == "ECDC") {
    valid_location_codes <- covidHubUtils::hub_locations_ecdc$location
    valid_truth_sources <- c("JHU", "jhu", "ECDC", "ecdc")
  } else if (hub[1] == "FluSight") {
    valid_location_codes <- covidHubUtils::hub_locations_flusight$fips
    valid_truth_sources <- c("HealthData")
  }

  # Move the location code to `abbr` and the hub-specific location name
  # column to `location`; shared by the forecast and truth data.
  rename_location_columns <- function(data) {
    if (hub[1] == "US") {
      dplyr::rename(data, abbr = location, location = full_location_name)
    } else {
      dplyr::rename(data, abbr = location, location = location_name)
    }
  }

  # validate locations_to_plot
  # An explicit NULL is treated like a missing argument so that the
  # documented default (all locations in forecast_data) also applies when a
  # caller passes `locations_to_plot = NULL` through.
  if (missing(locations_to_plot) || is.null(locations_to_plot)) {
    locations_to_plot <- unique(forecast_data$location)
  } else {
    # Convert location names to fips codes or country abbreviations
    locations_to_plot <- name_to_fips(locations_to_plot, hub)
  }
  # keep only locations that are valid for the selected hub
  locations_to_plot <- intersect(
    as.character(locations_to_plot),
    as.character(valid_location_codes)
  )

  # validate forecast_dates_to_plot
  if (missing(forecast_dates_to_plot)) {
    forecast_dates_to_plot <- unique(forecast_data$forecast_date)
  } else {
    forecast_dates_to_plot <- as.Date(forecast_dates_to_plot)
    if (!all(forecast_dates_to_plot %in% forecast_data$forecast_date)) {
      stop("Error in get_plot_forecast_data: Not all forecast_dates are available in forecast data.")
    }
  }

  if (!is.null(truth_data)) {
    # validate truth data if provided
    required_truth_columns <- c(
      "model", "target_variable",
      "target_end_date", "location", "value"
    )
    # check if truth_data has all needed columns
    if (!all(required_truth_columns %in% colnames(truth_data))) {
      stop(
        "Error in get_plot_forecast_data: Please provide columns model,\n",
        "target_variable, target_end_date, location and value in truth_data."
      )
    }
    # check if all fips codes in location column are valid
    if (!all(truth_data$location %in% valid_location_codes)) {
      stop("Error in get_plot_forecast_data: Please make sure all fips codes in location column are valid.")
    }
    # check if truth_data has data from specified location
    if (!all(locations_to_plot %in% truth_data$location)) {
      stop("Error in get_plot_forecast_data: Please provide a valid locations_to_plot.")
    }
    # check if truth_data has specified target variable
    if (!(target_variable_to_plot %in% truth_data$target_variable)) {
      stop("Error in get_plot_forecast_data: Please provide a valid target variable.")
    }
  } else if (plot_truth) {
    # validate truth_source; it is only needed when truth has to be loaded,
    # so forecasts-only calls do not have to supply it
    if (tolower(truth_source) == "usafacts") {
      stop("USAFacts can no longer be downloaded. Please use another truth source.")
    }
    truth_source <- match.arg(truth_source,
      choices = valid_truth_sources,
      several.ok = FALSE
    )
  }

  # create temporal resolution for loading truth
  if (target_variable_to_plot == "inc hosp") {
    temporal_resolution <- "daily"
  } else {
    temporal_resolution <- "weekly"
  }

  # warning for truth_as_of
  if (!is.null(truth_as_of)) {
    warning("Warning in get_plot_forecast_data: Currently versioned truth data is not supported.")
  }

  # filter to include selected models, forecast dates, locations and
  # target variable
  forecast_data <- forecast_data %>%
    dplyr::filter(
      model %in% models_to_plot,
      forecast_date %in% forecast_dates_to_plot,
      location %in% locations_to_plot,
      target_variable == target_variable_to_plot
    )

  if (!missing(horizons_to_plot)) {
    forecast_data <- forecast_data %>%
      dplyr::filter(horizon <= horizons_to_plot)
  }

  forecasts <- pivot_forecasts_wider(forecast_data, quantiles_to_plot) %>%
    dplyr::mutate(truth_forecast = "forecast") %>%
    rename_location_columns()

  if (!plot_truth) {
    return(forecasts)
  }

  if (is.null(truth_data)) {
    # call load_truth if the user did not provide truth_data
    truth <- load_truth(
      truth_source = truth_source,
      target_variable = target_variable_to_plot,
      locations = locations_to_plot,
      temporal_resolution = temporal_resolution,
      hub = hub
    ) %>%
      dplyr::rename(point = value) %>%
      dplyr::mutate(truth_forecast = "truth")
  } else {
    # process truth_data for plotting
    truth <- truth_data %>%
      dplyr::filter(
        location %in% locations_to_plot,
        target_variable == target_variable_to_plot
      )

    # add location info if user-provided truth does not have them
    has_location_names <- all(
      c("location_name", "full_location_name") %in% colnames(truth)
    )
    if (!has_location_names) {
      truth <- truth %>%
        dplyr::select(
          model, target_variable, target_end_date, location, value
        ) %>%
        join_with_hub_locations(hub = hub)
    }

    truth <- truth %>%
      dplyr::rename(point = value) %>%
      dplyr::mutate(truth_forecast = "truth", point = as.numeric(point))
  }

  truth <- rename_location_columns(truth)

  dplyr::bind_rows(forecasts, truth)
}