Skip to content

Commit

Permalink
Merge branch 'hash-field-expiry-integ' into hfe-integ-aof-and-replica
Browse files Browse the repository at this point in the history
  • Loading branch information
moticless committed May 22, 2024
2 parents 9b5bd8f + 95cbe87 commit ef1b9e9
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 64 deletions.
2 changes: 1 addition & 1 deletion src/db.c
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,7 @@ void scanGenericCommand(client *c, robj *o, unsigned long long cursor) {
val = p; /* Keep pointer to value */

p = lpNext(lp, p);
serverAssert(!lpGetValue(p, NULL, &expire_at));
serverAssert(p && lpGetIntegerValue(p, &expire_at));

if (hashTypeIsExpired(o, expire_at) ||
(use_pattern && !stringmatchlen(pat, sdslen(pat), (char *)str, len, 0)))
Expand Down
12 changes: 12 additions & 0 deletions src/listpack.c
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,18 @@ unsigned char *lpGetValue(unsigned char *p, unsigned int *slen, long long *lval)
return vstr;
}

/* This is just a wrapper to lpGet() that is able to get an integer from an entry directly.
* Returns 1 and stores the integer in 'lval' if the entry is an integer.
* Returns 0 if the entry is a string. */
int lpGetIntegerValue(unsigned char *p, long long *lval) {
int64_t ele_len;
if (!lpGet(p, &ele_len, NULL)) {
*lval = ele_len;
return 1;
}
return 0;
}

/* Find pointer to the entry equal to the specified entry. Skip 'skip' entries
* between every comparison. Returns NULL when the field could not be found. */
unsigned char *lpFind(unsigned char *lp, unsigned char *p, unsigned char *s,
Expand Down
1 change: 1 addition & 0 deletions src/listpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ unsigned char *lpDup(unsigned char *lp);
unsigned long lpLength(unsigned char *lp);
unsigned char *lpGet(unsigned char *p, int64_t *count, unsigned char *intbuf);
unsigned char *lpGetValue(unsigned char *p, unsigned int *slen, long long *lval);
int lpGetIntegerValue(unsigned char *p, long long *lval);
unsigned char *lpFind(unsigned char *lp, unsigned char *p, unsigned char *s, uint32_t slen, unsigned int skip);
unsigned char *lpFirst(unsigned char *lp);
unsigned char *lpLast(unsigned char *lp);
Expand Down
9 changes: 5 additions & 4 deletions src/object.c
Original file line number Diff line number Diff line change
Expand Up @@ -490,10 +490,11 @@ void dismissHashObject(robj *o, size_t size_hint) {
/* Dismiss hash table memory. */
dismissMemory(d->ht_table[0], DICTHT_SIZE(d->ht_size_exp[0])*sizeof(dictEntry*));
dismissMemory(d->ht_table[1], DICTHT_SIZE(d->ht_size_exp[1])*sizeof(dictEntry*));
} else if (o->encoding == OBJ_ENCODING_LISTPACK ||
o->encoding == OBJ_ENCODING_LISTPACK_EX) {
unsigned char *lp = hashTypeListpackGetLp(o);
dismissMemory(lp, lpBytes(lp));
} else if (o->encoding == OBJ_ENCODING_LISTPACK) {
dismissMemory(o->ptr, lpBytes((unsigned char*)o->ptr));
} else if (o->encoding == OBJ_ENCODING_LISTPACK_EX) {
listpackEx *lpt = o->ptr;
dismissMemory(lpt->lp, lpBytes((unsigned char*)lpt->lp));
} else {
serverPanic("Unknown hash encoding type");
}
Expand Down
22 changes: 21 additions & 1 deletion src/rdb.c
Original file line number Diff line number Diff line change
Expand Up @@ -1817,6 +1817,7 @@ static int _lpEntryValidation(unsigned char *p, unsigned int head_count, void *u
int tuple_len;
long count;
dict *fields;
long long last_expireat;
} *data = userdata;

if (data->fields == NULL) {
Expand All @@ -1840,6 +1841,19 @@ static int _lpEntryValidation(unsigned char *p, unsigned int head_count, void *u
}
}

/* Validate TTL field, only for listpackex. */
if (data->count % data->tuple_len == 2) {
long long expire_at;
/* Must be an integer. */
if (!lpGetIntegerValue(p, &expire_at)) return 0;
/* Must be less than EB_EXPIRE_TIME_MAX. */
if (expire_at < 0 || (unsigned long long)expire_at > EB_EXPIRE_TIME_MAX) return 0;
/* TTL fields are ordered. If the current field has TTL, the previous field must
* also have one, and the current TTL must be greater than the previous one. */
if (expire_at != 0 && (data->last_expireat == 0 || expire_at < data->last_expireat)) return 0;
data->last_expireat = expire_at;
}

(data->count)++;
return 1;
}
Expand All @@ -1859,7 +1873,8 @@ int lpValidateIntegrityAndDups(unsigned char *lp, size_t size, int deep, int tup
int tuple_len;
long count;
dict *fields; /* Initialisation at the first callback. */
} data = {tuple_len, 0, NULL};
long long last_expireat; /* Last field's expiry time to ensure order in TTL fields. */
} data = {tuple_len, 0, NULL, -1};

int ret = lpValidateIntegrity(lp, size, 1, _lpEntryValidation, &data);

Expand Down Expand Up @@ -2257,6 +2272,11 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, redisDb* db, int rdbflags,
if (dupSearchDict != NULL) dictRelease(dupSearchDict);
return NULL;
}
if (expire > EB_EXPIRE_TIME_MAX) {
rdbReportCorruptRDB("invalid expire time: %llu", (unsigned long long)expire);
decrRefCount(o);
return NULL;
}

/* read the field name */
if ((field = rdbGenericLoadStringObject(rdb, RDB_LOAD_SDS, &fieldLen)) == NULL) {
Expand Down
78 changes: 23 additions & 55 deletions src/t_hash.c
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ static uint64_t listpackExExpireDryRun(const robj *o) {
serverAssert(o->encoding == OBJ_ENCODING_LISTPACK_EX);

uint64_t expired = 0;
unsigned char *fptr, *s;
unsigned char *fptr;
listpackEx *lpt = o->ptr;

fptr = lpFirst(lpt->lp);
Expand All @@ -356,10 +356,7 @@ static uint64_t listpackExExpireDryRun(const robj *o) {
fptr = lpNext(lpt->lp, fptr);
serverAssert(fptr);
fptr = lpNext(lpt->lp, fptr);
serverAssert(fptr);

s = lpGetValue(fptr, NULL, &val);
serverAssert(!s);
serverAssert(fptr && lpGetIntegerValue(fptr, &val));

if (!hashTypeIsExpired(o, val))
break;
Expand All @@ -376,15 +373,14 @@ static uint64_t listpackExGetMinExpire(robj *o) {
serverAssert(o->encoding == OBJ_ENCODING_LISTPACK_EX);

long long expireAt;
unsigned char *fptr, *s;
unsigned char *fptr;
listpackEx *lpt = o->ptr;

/* As fields are ordered by expire time, first field will have the smallest
* expiry time. Third element is the expiry time of the first field */
fptr = lpSeek(lpt->lp, 2);
if (fptr != NULL) {
s = lpGetValue(fptr, NULL, &expireAt);
serverAssert(!s);
serverAssert(lpGetIntegerValue(fptr, &expireAt));

/* Check if this is a non-volatile field. */
if (expireAt != HASH_LP_NO_TTL)
Expand All @@ -399,7 +395,7 @@ void listpackExExpire(redisDb *db, robj *o, ExpireInfo *info) {
char buf[LONG_STR_SIZE];
serverAssert(o->encoding == OBJ_ENCODING_LISTPACK_EX);
uint64_t min = EB_EXPIRE_TIME_INVALID;
unsigned char *ptr, *pTuple, *s;
unsigned char *ptr, *pTuple;
listpackEx *lpt = o->ptr;

ptr = lpFirst(lpt->lp);
Expand All @@ -414,10 +410,7 @@ void listpackExExpire(redisDb *db, robj *o, ExpireInfo *info) {
ptr = lpNext(lpt->lp, ptr);
serverAssert(ptr);
ptr = lpNext(lpt->lp, ptr);
serverAssert(ptr);

s = lpGetValue(ptr, NULL, &val);
serverAssert(!s);
serverAssert(ptr && lpGetIntegerValue(ptr, &val));

/* Fields are ordered by expiry time. If we reached to a non-expired
* or a non-volatile field, we know rest is not yet expired. */
Expand Down Expand Up @@ -493,7 +486,7 @@ static void listpackExUpdateExpiry(robj *o, sds field,
unsigned int slen;
long long val;
unsigned char tmp[512] = {0};
unsigned char *valstr, *s, *elem;
unsigned char *valstr, *elem;
listpackEx *lpt = o->ptr;
sds tmpval = NULL;

Expand Down Expand Up @@ -521,10 +514,7 @@ static void listpackExUpdateExpiry(robj *o, sds field,
fptr = lpNext(lpt->lp, fptr);
serverAssert(fptr);
fptr = lpNext(lpt->lp, fptr);
serverAssert(fptr);

s = lpGetValue(fptr, NULL, &currExpiry);
serverAssert(!s);
serverAssert(fptr && lpGetIntegerValue(fptr, &currExpiry));

if (currExpiry == HASH_LP_NO_TTL || (uint64_t) currExpiry >= expireAt) {
/* Found a field with no expiry time or with a higher expiry time.
Expand Down Expand Up @@ -565,7 +555,7 @@ static void listpackExUpdateExpiry(robj *o, sds field,

/* Add new field ordered by expire time. */
void listpackExAddNew(robj *o, sds field, sds value, uint64_t expireAt) {
unsigned char *fptr, *s, *elem;
unsigned char *fptr, *elem;
listpackEx *lpt = o->ptr;

/* Shortcut, just append at the end if this is a non-volatile field. */
Expand All @@ -581,10 +571,7 @@ void listpackExAddNew(robj *o, sds field, sds value, uint64_t expireAt) {
fptr = lpNext(lpt->lp, fptr);
serverAssert(fptr);
fptr = lpNext(lpt->lp, fptr);
serverAssert(fptr);

s = lpGetValue(fptr, NULL, &currExpiry);
serverAssert(!s);
serverAssert(fptr && lpGetIntegerValue(fptr, &currExpiry));

if (currExpiry == HASH_LP_NO_TTL || (uint64_t) currExpiry >= expireAt) {
/* Found a field with no expiry time or with a higher expiry time.
Expand Down Expand Up @@ -618,10 +605,8 @@ SetExRes hashTypeSetExpiryListpack(HashTypeSetEx *ex, sds field,
{
long long expireTime;
uint64_t prevExpire = EB_EXPIRE_TIME_INVALID;
unsigned char *s;

s = lpGetValue(tptr, NULL, &expireTime);
serverAssert(!s);
serverAssert(lpGetIntegerValue(tptr, &expireTime));

if (expireTime != HASH_LP_NO_TTL) {
prevExpire = (uint64_t) expireTime;
Expand Down Expand Up @@ -753,10 +738,7 @@ GetFieldRes hashTypeGetFromListpack(robj *o, sds field,
serverAssert(vptr != NULL);

h = lpNext(lpt->lp, vptr);
serverAssert(h != NULL);

h = lpGetValue(h, NULL, &expire);
serverAssert(h == NULL);
serverAssert(h && lpGetIntegerValue(h, &expire));

if (hashTypeIsExpired(o, expire))
return GET_FIELD_EXPIRED;
Expand Down Expand Up @@ -1276,7 +1258,6 @@ static SetExRes hashTypeSetExListpack(redisDb *db, robj *o, sds field, HashTypeS
if (fptr != NULL) {
fptr = lpFind(lpt->lp, fptr, (unsigned char*)field, sdslen(field), 2);
if (fptr != NULL) {
unsigned char *p;
/* Grab pointer to the value (fptr points to the field) */
vptr = lpNext(lpt->lp, fptr);
serverAssert(vptr != NULL);
Expand All @@ -1292,9 +1273,7 @@ static SetExRes hashTypeSetExListpack(redisDb *db, robj *o, sds field, HashTypeS
res = HSET_UPDATE;
}
tptr = lpNext(lpt->lp, vptr);
serverAssert(tptr != NULL);
p = lpGetValue(tptr, NULL, &expireTime);
serverAssert(!p);
serverAssert(tptr && lpGetIntegerValue(tptr, &expireTime));

if (ex) {
res = hashTypeSetExpiryListpack(ex, field, fptr, vptr, tptr,
Expand Down Expand Up @@ -1487,9 +1466,7 @@ int hashTypeNext(hashTypeIterator *hi, int skipExpiredFields) {
serverAssert(vptr != NULL);

tptr = lpNext(zl, vptr);
serverAssert(tptr != NULL);

lpGetValue(tptr, NULL, &expire_time);
serverAssert(tptr && lpGetIntegerValue(tptr, &expire_time));

if (!skipExpiredFields || !hashTypeIsExpired(hi->subject, expire_time))
break;
Expand Down Expand Up @@ -1737,7 +1714,7 @@ void hashTypeConvertListpackEx(robj *o, int enc, ebuckets *hexpires) {
hfieldFree(key); sdsfree(value); /* Needed for gcc ASAN */
hashTypeReleaseIterator(hi); /* Needed for gcc ASAN */
serverLogHexDump(LL_WARNING,"listpack with dup elements dump",
o->ptr,lpBytes(o->ptr));
lpt->lp,lpBytes(lpt->lp));
serverPanic("Listpack corruption detected");
}

Expand Down Expand Up @@ -2920,9 +2897,7 @@ static void httlGenericCommand(client *c, const char *cmd, long long basetime, i
fptr = lpNext(lpt->lp, fptr);
serverAssert(fptr);
fptr = lpNext(lpt->lp, fptr);
serverAssert(fptr);

lpGetValue(fptr, NULL, &expire);
serverAssert(fptr && lpGetIntegerValue(fptr, &expire));

if (expire == HASH_LP_NO_TTL) {
addReplyLongLong(c, HFE_GET_NO_TTL);
Expand Down Expand Up @@ -3166,7 +3141,7 @@ void hpersistCommand(client *c) {
return;
} else if (hashObj->encoding == OBJ_ENCODING_LISTPACK_EX) {
long long prevExpire;
unsigned char *fptr, *vptr, *tptr, *s;
unsigned char *fptr, *vptr, *tptr;
listpackEx *lpt = hashObj->ptr;

addReplyArrayLen(c, numFields);
Expand All @@ -3185,10 +3160,7 @@ void hpersistCommand(client *c) {
vptr = lpNext(lpt->lp, fptr);
serverAssert(vptr);
tptr = lpNext(lpt->lp, vptr);
serverAssert(tptr);

s = lpGetValue(tptr, NULL, &prevExpire);
serverAssert(!s);
serverAssert(tptr && lpGetIntegerValue(tptr, &prevExpire));

if (prevExpire == HASH_LP_NO_TTL) {
addReplyLongLong(c, HFE_PERSIST_NO_TTL);
Expand Down Expand Up @@ -3525,7 +3497,7 @@ static int hgetfParseAndRewriteArgs(client *c, int *flags, uint64_t *expireAt,
static int hgetfReplyValueAndSetExpiry(client *c, robj *o, sds field, int flag,
uint64_t expireAt, uint64_t *minPrevExp)
{
unsigned char *fptr = NULL, *vptr = NULL, *tptr, *h;
unsigned char *fptr = NULL, *vptr = NULL, *tptr;
hfield hf = NULL;
dict *d = NULL;
dictEntry *de = NULL;
Expand Down Expand Up @@ -3565,10 +3537,7 @@ static int hgetfReplyValueAndSetExpiry(client *c, robj *o, sds field, int flag,
serverAssert(vptr != NULL);

tptr = lpNext(lpt->lp, vptr);
serverAssert(tptr != NULL);

h = lpGetValue(tptr, NULL, &expire);
serverAssert(h == NULL);
serverAssert(tptr && lpGetIntegerValue(tptr, &expire));

if (expire != HASH_LP_NO_TTL)
prevExpire = expire;
Expand Down Expand Up @@ -3757,7 +3726,7 @@ static int hsetfSetFieldAndReply(client *c, robj *o, sds field, sds value,

if (o->encoding == OBJ_ENCODING_LISTPACK_EX) {
long long expire;
unsigned char *fptr, *vptr = NULL, *tptr, *h;
unsigned char *fptr, *vptr = NULL, *tptr;
listpackEx *lpt = o->ptr;

fptr = lpFirst(lpt->lp);
Expand All @@ -3766,8 +3735,7 @@ static int hsetfSetFieldAndReply(client *c, robj *o, sds field, sds value,
if (fptr != NULL) {
vptr = lpNext(lpt->lp, fptr);
tptr = lpNext(lpt->lp, vptr);
h = lpGetValue(tptr, NULL, &expire);
serverAssert(!h);
serverAssert(tptr && lpGetIntegerValue(tptr, &expire));

if (expire != HASH_LP_NO_TTL)
prevExpire = expire;
Expand Down Expand Up @@ -3926,7 +3894,7 @@ static int hsetfParseAndRewriteArgs(client *c, int *flags, uint64_t *expireAt,
*firstFieldPos = -1;
*fieldCount = -1;

for (int i = 2 ; i < c->argc ; ++i) {
for (int i = 2 ; i < c->argc ; i++) {
int flag = 0;

if (!strcasecmp(c->argv[i]->ptr, "fvs")) {
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/corrupt-dump-fuzzer.tcl
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ proc generate_collections {suffix elements} {
# add both string values and integers
if {$j % 2 == 0} {set val $j} else {set val "_$j"}
$rd hset hash$suffix $j $val
$rd hset hashmd$suffix $j $val
$rd hexpire hashmd$suffix [expr {int(rand() * 10000)}] FIELDS 1 $j
$rd lpush list$suffix $val
$rd zadd zset$suffix $j $val
$rd sadd set$suffix $val
$rd xadd stream$suffix * item 1 value $val
}
for {set j 0} {$j < $elements * 5} {incr j} {
for {set j 0} {$j < $elements * 7} {incr j} {
$rd read ; # Discard replies
}
$rd close
Expand Down

0 comments on commit ef1b9e9

Please sign in to comment.