Skip to content

Commit

Permalink
refactor: replace binary by linear search in the C implementation of …
Browse files Browse the repository at this point in the history
…closest; also fixes #65
  • Loading branch information
sgibb committed Sep 24, 2020
1 parent 2013f41 commit 47f74b6
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 160 deletions.
288 changes: 128 additions & 160 deletions src/closest.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,56 +7,6 @@
#include <Rinternals.h>
#include <math.h>

/**
 * Find leftmost interval.
 *
 * This function is a simplified version of R's findInterval.
 *
 * \param x key value to look for.
 * \param ptable table/haystack where to look for.
 * \param low lowest/leftmost index to start the search.
 * \param n number of elements in `ptable`; indices above `n - 1` are never
 * read.
 * \return index `low` of the left border of an interval with
 * `ptable[low] <= x <= ptable[low + 1]`. If `x` equals a table value either
 * of the two neighbouring intervals may be reported. If there is no element
 * to the right of the given `low` (`low + 1 >= n`) `low` is returned
 * unchanged.
 *
 * \note For easier implementation we ignore some corner cases here. Index `0`
 * and `n - 1` have to be handled outside this function. Also if it is called
 * multiple times (e.g. if x is an array in the calling function) x has to be
 * sorted increasingly (we never lower `low`).
 *
 * \sa r-source/src/appl/interv.c
 */
int leftmost(double x, double* ptable, int low, int n) {
    int high = low + 1;

    /* defensive: without an element to the right of `low` there is no
     * interval to test and ptable[high] must not be read
     */
    if (high >= n)
        return low;

    /* same interval as the last one */
    if (x < ptable[high] && ptable[low] <= x)
        return low;

    /* exponential search; the step is doubled only while that cannot
     * overflow, any step >= n ends the search in the next iteration anyway
     */
    for (int step = 1;; step = (step <= n / 2) ? step * 2 : n) {
        /* x is still >= ptable[high] so we can update low as well to keep
         * search space small
         */
        low = high;
        if (step >= n - low) {
            /* low + step would run past the table, clamp to the last index
             * (tested without computing the sum to avoid signed overflow)
             */
            high = n - 1;
            break;
        }
        high = low + step;
        if (x < ptable[high])
            break;
    }

    /* binary search; the midpoint is computed from the distance to avoid
     * overflowing low + high
     */
    while (high - low > 1) {
        int mid = low + (high - low) / 2;
        if (x > ptable[mid])
            low = mid;
        else
            high = mid;
    }
    return low;
}

/**
* Find closest value to table, keep duplicates.
*
Expand All @@ -70,91 +20,40 @@ int leftmost(double x, double* ptable, int low, int n) {
* \note x and table have to be sorted increasingly and must not contain any NA.
*/
/*
 * For every x[i] the nearest table element is determined; its 1-based index
 * is stored in the result if the distance is within tolerance[i], otherwise
 * nomatch is stored.
 *
 * NOTE(review): this span shows removed and added lines of the diff side by
 * side (px, nx, ptable and ptolerance are each declared twice), so it does
 * not compile as displayed.
 */
SEXP C_closest_dup_keep(SEXP x, SEXP table, SEXP tolerance, SEXP nomatch) {
int nx = LENGTH(x);
double* px = REAL(x);
double *px = REAL(x);
const unsigned int nx = LENGTH(x);

int ntable = LENGTH(table);
double* ptable = REAL(table);
double *ptable = REAL(table);
/* index of the last table element
 * NOTE(review): stored as unsigned, so an empty table would presumably wrap
 * around to a huge value -- confirm callers never pass an empty table
 */
const unsigned int ntable1 = LENGTH(table) - 1;

double* ptolerance = REAL(tolerance);
double *ptolerance = REAL(tolerance);

/* result: one integer per element of x */
SEXP out = PROTECT(allocVector(INTSXP, nx));
int* pout = INTEGER(out);

/* low is carried over between iterations, leftmost() never lowers it */
int low = 0, cur = 0;

for (int i = 0; i < nx; ++i) {
/* keys outside the table range are compared against the first/last element */
if (px[i] < ptable[0])
cur = 0;
else if (px[i] >= ptable[ntable - 1])
cur = ntable - 1;
else {
low = leftmost(px[i], ptable, low, ntable);
/* choose the nearer border of the interval, `<=` prefers the left one on
 * equal distances
 */
cur = (px[i] - ptable[low] <= ptable[low + 1] - px[i]) ?
low : low + 1;
}

/* cur + 1 is needed here to translate between R's and C's indices */
pout[i] = fabs(px[i] - ptable[cur]) <= ptolerance[i] ?
cur + 1 : asInteger(nomatch);
}
/* table index used by the linear search, starts at the second element */
unsigned int j = 1;

UNPROTECT(1);
return(out);
}
double prevdiff = R_PosInf, nextdiff = R_PosInf;

/**
* Find closest value to table, keep just closest duplicate.
* This function should be reused in the outer join C implementation.
*
* \param pout pointer to output array.
* \param px pointer to arrays with key values to look for.
* \param nx number of elements in `px`.
* \param ptable table/haystack where to look for.
* \param ntable number of elements in `ptable`.
* \param ptolerance pointer to tolerance array that stores the tolerance
* to be accepted as match, has to be of length == length(x).
* \param nomatch value that should be returned if a key couldn't be matched.
* \return index of the closest element
*
* \note x and table have to be sorted increasingly and must not contain any NA.
*/
void closest_dup_closest(int *pout, double *px, int nx,
double *ptable, int ntable,
double *ptolerance, int nomatch) {
for (unsigned int i = 0; i < nx; ++i) {
while(j < ntable1 && ptable[j] < px[i])
++j;

int low = 0, cur = 0;
int last = R_NegInf;
double absdiff, lastdiff = R_PosInf;

for (int i = 0; i < nx; ++i) {
if (px[i] < ptable[0])
cur = 0;
else if (px[i] >= ptable[ntable - 1])
cur = ntable - 1;
else {
low = leftmost(px[i], ptable, low, ntable);
cur = (px[i] - ptable[low] <= ptable[low + 1] - px[i]) ?
low : low + 1;
}
/* fabs should be just needed for the first element */
prevdiff = fabs(px[i] - ptable[j - 1]);
nextdiff = fabs(ptable[j] - px[i]);

absdiff = fabs(px[i] - ptable[cur]);
if (absdiff <= ptolerance[i]) {
if (last == cur) {
if (absdiff < lastdiff && last == cur) {
pout[i] = cur + 1;
pout[i - 1] = nomatch;
lastdiff = absdiff;
} else
pout[i] = nomatch;
} else {
pout[i] = cur + 1;
last = cur;
lastdiff = absdiff;
}
if (prevdiff <= ptolerance[i] || nextdiff <= ptolerance[i]) {
if (prevdiff <= nextdiff)
pout[i] = j;
else
pout[i] = ++j;
} else
pout[i] = nomatch;
pout[i] = asInteger(nomatch);
}

UNPROTECT(1);
return out;
}

/**
Expand All @@ -170,18 +69,76 @@ void closest_dup_closest(int *pout, double *px, int nx,
* \note x and table have to be sorted increasingly and must not contain any NA.
*/
/*
 * For every x[i] the nearer of the two table elements around it is looked up
 * with a linear scan. If several x would get the same table index only the
 * one with the smallest difference keeps it, the others are set to nomatch.
 *
 * NOTE(review): this span shows removed and added lines of the diff side by
 * side (nx is declared twice, the old helper call precedes the new loop, two
 * return statements), so it does not compile as displayed.
 */
SEXP C_closest_dup_closest(SEXP x, SEXP table, SEXP tolerance, SEXP nomatch) {
int nx = LENGTH(x);
double *px = REAL(x);
const unsigned int nx = LENGTH(x);

double *ptable = REAL(table);
/* index of the last table element
 * NOTE(review): stored as unsigned, so an empty table would presumably wrap
 * around to a huge value -- confirm callers never pass an empty table
 */
const unsigned int ntable1 = LENGTH(table) - 1;

double *ptolerance = REAL(tolerance);

/* result: one integer per element of x */
SEXP out = PROTECT(allocVector(INTSXP, nx));
int* pout = INTEGER(out);

/* j: index of the right neighbour in table (j - 1 is the left neighbour);
 * lastj: value of j after the most recent match;
 * lastdiff: difference of the most recent match
 */
unsigned int j = 1, lastj = 0;
double prevdiff = R_PosInf, nextdiff = R_PosInf, lastdiff = R_PosInf;

closest_dup_closest(
INTEGER(out),
REAL(x), nx,
REAL(table), LENGTH(table),
REAL(tolerance), asInteger(nomatch)
);
for (int i = 0; i < nx; ++i) {
/* advance j to the first table element that is not smaller than x[i],
 * stopping at the last element
 */
while(j < ntable1 && ptable[j] < px[i])
++j;

/* fabs should be just needed for the first element */
/* NOTE(review): j starts at 1, so ptable[1] is read even if the loop above
 * never runs -- presumably the table must have at least two elements,
 * confirm against the caller
 */
prevdiff = fabs(px[i] - ptable[j - 1]);
nextdiff = fabs(ptable[j] - px[i]);

if (prevdiff <= ptolerance[i] || nextdiff <= ptolerance[i]) {
if (prevdiff < nextdiff) {
/* match on the left */
/* j itself is the 1-based (R) index of the left neighbour ptable[j - 1] */
if (lastj == j) {
if (prevdiff < lastdiff) {
/* same match as before but with smaller difference */
pout[i] = j;
pout[i - 1] = asInteger(nomatch);
lastdiff = prevdiff;
} else
pout[i] = asInteger(nomatch);
} else {
pout[i] = j;
lastdiff = prevdiff;
}
} else if (prevdiff > nextdiff) {
/* match on the right */
/* NOTE(review): lastj is only ever assigned the current j and j never
 * decreases, so `lastj == j + 1` looks unreachable -- confirm
 */
if (lastj == j + 1) {
if (nextdiff < lastdiff) {
/* same match as before but with smaller difference */
pout[i] = ++j;
pout[i - 1] = asInteger(nomatch);
lastdiff = nextdiff;
} else
pout[i] = asInteger(nomatch);
} else {
/* ++j yields the 1-based index of the right neighbour and makes it the
 * left neighbour for the next x
 */
pout[i] = ++j;
lastdiff = nextdiff;
}
} else {
/* match on the left or right */
/* the left element is skipped only if the previous x already matched it
 * with a smaller difference
 * NOTE(review): if lastj == j and prevdiff <= lastdiff the else branch
 * assigns j again without resetting pout[i - 1] -- confirm that no
 * duplicated index can result
 */
if (lastj == j && prevdiff > lastdiff) {
pout[i] = ++j;
lastdiff = nextdiff;
} else {
pout[i] = j;
lastdiff = prevdiff;
}
}
/* remember the position of this match for the duplicate test of the next x */
lastj = j;
} else {
/* no table element within tolerance; lastj is left untouched here */
pout[i] = asInteger(nomatch);
lastdiff = R_PosInf;
}
}

UNPROTECT(1);
return(out);
return out;
}

/**
Expand All @@ -197,40 +154,51 @@ SEXP C_closest_dup_closest(SEXP x, SEXP table, SEXP tolerance, SEXP nomatch) {
* \note x and table have to be sorted increasingly and must not contain any NA.
*/
/*
 * For every x[i] the nearer of the two table elements around it is looked
 * up. If consecutive x get the same table index all of them are set to
 * nomatch.
 *
 * NOTE(review): this span shows removed and added lines of the diff side by
 * side (duplicated declarations, two loops, two UNPROTECT/return pairs), so
 * it does not compile as displayed.
 */
SEXP C_closest_dup_remove(SEXP x, SEXP table, SEXP tolerance, SEXP nomatch) {
int nx = LENGTH(x);
double* px = REAL(x);
double *px = REAL(x);
const unsigned int nx = LENGTH(x);

int ntable = LENGTH(table);
double* ptable = REAL(table);
double *ptable = REAL(table);
/* index of the last table element
 * NOTE(review): stored as unsigned, so an empty table would presumably wrap
 * around to a huge value -- confirm callers never pass an empty table
 */
const unsigned int ntable1 = LENGTH(table) - 1;

double* ptolerance = REAL(tolerance);
double *ptolerance = REAL(tolerance);

/* result: one integer per element of x */
SEXP out = PROTECT(allocVector(INTSXP, nx));
int* pout = INTEGER(out);

int low = 0, cur = 0;
/* NOTE(review): R_NegInf is presumably R's double -Inf, which is not
 * representable as int -- confirm what `last` is meant to start with
 */
int last = R_NegInf;

/* binary search based loop */
for (int i = 0; i < nx; ++i) {
/* keys outside the table range are compared against the first/last element */
if (px[i] < ptable[0])
cur = 0;
else if (px[i] >= ptable[ntable - 1])
cur = ntable - 1;
else {
low = leftmost(px[i], ptable, low, ntable);
/* choose the nearer border of the interval, `<=` prefers the left one on
 * equal distances
 */
cur = (px[i] - ptable[low] <= ptable[low + 1] - px[i]) ?
low : low + 1;
}

/* NOTE(review): the closing parenthesis of fabs() encloses the comparison,
 * so fabs() is applied to the 0/1 result of `<=` instead of the difference;
 * also pout[i] is not written when this condition is false
 */
if (fabs(px[i] - ptable[cur] <= ptolerance[i])) {
if (last != cur)
pout[i] = cur + 1;
else
pout[i - 1] = pout[i] = asInteger(nomatch);
last = cur;
}
}
/* j: index of the right neighbour in table (j - 1 is the left neighbour);
 * lastj: value of j at the end of the previous iteration
 */
unsigned int j = 1, lastj = 0;
double prevdiff = R_PosInf, nextdiff = R_PosInf;

/* linear search based loop */
for (unsigned int i = 0; i < nx; ++i) {
/* advance j to the first table element that is not smaller than x[i],
 * stopping at the last element
 */
while(j < ntable1 && ptable[j] < px[i])
++j;

/* fabs should be just needed for the first element */
/* NOTE(review): j starts at 1, so ptable[1] is read even if the loop above
 * never runs -- presumably the table must have at least two elements,
 * confirm against the caller
 */
prevdiff = fabs(px[i] - ptable[j - 1]);
nextdiff = fabs(ptable[j] - px[i]);

if (prevdiff <= ptolerance[i] || nextdiff <= ptolerance[i]) {
if (prevdiff <= nextdiff) {
/* match on the left */
/* j itself is the 1-based (R) index of the left neighbour ptable[j - 1] */
if (lastj == j) {
/* same position as in the previous iteration: drop both entries */
pout[i] = asInteger(nomatch);
pout[i - 1] = asInteger(nomatch);
} else {
pout[i] = j;
}
} else {
/* match on the right */
/* NOTE(review): lastj is assigned the current j at the end of every
 * iteration and j never decreases, so `lastj == j + 1` looks unreachable
 * -- confirm
 */
if (lastj == j + 1) {
pout[i] = asInteger(nomatch);
pout[i - 1] = asInteger(nomatch);
} else {
/* ++j yields the 1-based index of the right neighbour and makes it the
 * left neighbour for the next x
 */
pout[i] = ++j;
}
}
} else
pout[i] = asInteger(nomatch);
/* NOTE(review): executed on every iteration, also after a no-match; a
 * following x that matches on the left at the same j is then treated as a
 * duplicate (e.g. table {1, 10}, x {0, 1}, tolerance 0.5) -- confirm whether
 * this is intended
 */
lastj = j;
}

UNPROTECT(1);
return(out);
UNPROTECT(1);
return out;
}
7 changes: 7 additions & 0 deletions tests/testthat/test_matching.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ test_that("closest, duplicates", {
expect_equal(closest(1.6, 1:2, tolerance = 0.5, duplicates = "keep"), 2)
expect_equal(closest(1.6, 1:2, tolerance = 0.5, duplicates = "closest"), 2)
expect_equal(closest(1.5, 1:2, tolerance = 0.5, duplicates = "remove"), 1)

# equal distances, see
# https://github.com/rformassspectrometry/MsCoreUtils/issues/65
x <- as.numeric(c(1, 3, 5, 6, 8))
y <- as.numeric(c(3, 4, 5, 7))
expect_equal(closest(x, y, tolerance = 1, duplicates = "closest"),
c(NA, 1, 3, 4, NA))
})

test_that("common", {
Expand Down

0 comments on commit 47f74b6

Please sign in to comment.