-
Notifications
You must be signed in to change notification settings - Fork 0
/
slice.R
69 lines (67 loc) · 1.96 KB
/
slice.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#' Slice samples from a Magnitude table
#'
#' Reads up to `n` consecutive rows from the `magnitude` table, skipping the
#' first `offset` rows, and converts them with `db_result_to_vec()`.
#'
#' @param conn a Magnitude connection.
#' @param n integer; maximum number of rows to read. Only the first element
#'   is used.
#' @param offset integer; number of rows to skip before reading. Only the
#'   first element is used. An error is raised when it is larger than the
#'   number of rows reported by `dim(conn)[1]`; an `offset` equal to that
#'   size is accepted and selects no rows.
#' @param normalized logical;
#'   whether or not vector embeddings should be normalized?
#' @returns a tibble.
#' @export
slice_n <- function(conn, n, offset = 0, normalized = TRUE) {
  n <- n[1]
  offset <- offset[1]
  if (offset > dim(conn)[1]) {
    rlang::abort("`offset` must be smaller than the Magnitude size.")
  }
  res <- RSQLite::dbSendQuery(
    conn,
    "SELECT * FROM magnitude LIMIT ? OFFSET ?",
    params = list(as.integer(n), as.integer(offset))
  )
  # `finally` releases the result set even when fetching fails, so the
  # connection is never left with a pending statement; it still runs before
  # `db_result_to_vec()` is called, as in the original ordering.
  tbl <- tryCatch(
    RSQLite::dbFetch(res) %>%
      tibble::as_tibble(),
    finally = RSQLite::dbClearResult(res)
  )
  db_result_to_vec(conn, tbl, normalized)
}
#' Slice samples by index from a Magnitude table
#'
#' Selects the rows of the `magnitude` table whose ROWID is in `index` and
#' converts them with `db_result_to_vec()`.
#'
#' @param conn a Magnitude connection.
#' @param index integer vector of ROWIDs. Entries larger than
#'   `dim(conn)[1]` are dropped with a warning; `NA` entries are dropped
#'   silently because they can never match a ROWID.
#' @param normalized logical;
#'   whether or not vector embeddings should be normalized?
#' @returns a tibble.
#' @export
slice_index <- function(conn, index, normalized = TRUE) {
  size <- dim(conn)[1]
  # `any(..., na.rm = TRUE)` stays well-defined for empty or NA-containing
  # input, where `max()` would warn and return -Inf, or yield NA and make
  # `if` fail.
  if (any(index > size, na.rm = TRUE)) {
    rlang::warn("`index` bigger than the Magnitude table size is ignored.")
  }
  index <- as.integer(index[!is.na(index) & index <= size])
  # NOTE(review): `index` is bound as a multi-row parameter, so DBI/RSQLite
  # runs the statement once per element; confirm that duplicate indices
  # yielding duplicate rows is intended.
  res <- RSQLite::dbSendQuery(
    conn,
    "SELECT * FROM magnitude WHERE ROWID IN (?)",
    params = list(index)
  )
  # Always release the result set, even if fetching fails.
  tbl <- tryCatch(
    RSQLite::dbFetch(res) %>%
      tibble::as_tibble(),
    finally = RSQLite::dbClearResult(res)
  )
  db_result_to_vec(conn, tbl, normalized)
}
#' Slice samples by fraction from a Magnitude table
#'
#' Draws `trunc(size * frac)` distinct ROWIDs uniformly at random from
#' `1:size` (where `size` is `dim(conn)[1]`) and fetches them with
#' `slice_index()`.
#'
#' @param conn a Magnitude connection.
#' @param frac numeric in `[0, 1]`; fraction of rows to sample. Only the
#'   first element is used.
#' @param normalized logical;
#'   whether or not vector embeddings should be normalized?
#' @returns a tibble.
#' @export
slice_frac <- function(conn, frac = 0.001, normalized = TRUE) {
  frac <- as.numeric(frac)[1]
  # Reject NA/negative input up front instead of letting it surface as an
  # opaque `if (NA)` or `sample()` "invalid 'size'" error.
  if (is.na(frac) || frac < 0) {
    rlang::abort("`frac` must be a non-negative number.")
  }
  if (frac > 1) {
    rlang::abort("`frac` must be smaller than 1.")
  }
  size <- dim(conn)[1]
  # Same draw as `sample(seq_len(size), ...)` without allocating the vector.
  index <- sample.int(size, size = trunc(size * frac), replace = FALSE)
  slice_index(conn, index, normalized)
}