Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/scitokens.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,68 @@ int scitoken_get_claim_string(const SciToken token, const char *key, char **valu
}


int scitoken_set_claim_string_list(const SciToken token, const char *key,
const char **value, char **err_msg)
{
auto real_token = reinterpret_cast<scitokens::SciToken*>(token);
if (real_token == nullptr) {
if (err_msg) *err_msg = strdup("NULL scitoken passed to scitoken_get_claim_string_list");
return -1;
}
std::vector<std::string> claim_list;
int idx = 0;
while (value[idx++]) {}
claim_list.reserve(idx);

idx = 0;
while (value[idx++]) {
claim_list.emplace_back(value[idx-1]);
}
real_token->set_claim_list(key, claim_list);

return 0;
}


int scitoken_get_claim_string_list(const SciToken token, const char *key, char ***value, char **err_msg) {
auto real_token = reinterpret_cast<scitokens::SciToken*>(token);
if (real_token == nullptr) {
if (err_msg) *err_msg = strdup("NULL scitoken passed to scitoken_get_claim_string_list");
return -1;
}
std::vector<std::string> claim_list;
try {
claim_list = real_token->get_claim_list(key);
} catch (std::exception &exc) {
if (err_msg) {*err_msg = strdup(exc.what());}
return -1;
}
auto claim_list_c = static_cast<char **>(malloc(sizeof(char **) * (claim_list.size() + 1)));
claim_list_c[claim_list.size()] = nullptr;
int idx = 0;
for (const auto &entry : claim_list) {
claim_list_c[idx] = strdup(entry.c_str());
if (!claim_list_c[idx]) {
scitoken_free_string_list(claim_list_c);
if (err_msg) {*err_msg = strdup("Failed to create a copy of string entry in list");}
return -1;
}
idx++;
}
*value = claim_list_c;
return 0;
}


void scitoken_free_string_list(char **value) {
int idx = 0;
do {
free(value[idx++]);
} while (value[idx]);
free(value);
}


int scitoken_get_expiration(const SciToken token, long long *expiry, char **err_msg) {
scitokens::SciToken *real_token = reinterpret_cast<scitokens::SciToken*>(token);
if (!real_token->has_claim("exp")) {
Expand Down
21 changes: 21 additions & 0 deletions src/scitokens.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,27 @@ int scitoken_set_claim_string(SciToken token, const char *key, const char *value

int scitoken_get_claim_string(const SciToken token, const char *key, char **value, char **err_msg);

/**
* Given a SciToken object, parse a specific claim's value as a list of strings. If the JSON value
* is not actually a list of strings - or the claim is not set - returns an error and sets the
* err_msg appropriately.
*
* The returned value is a list of strings that ends with a nullptr.
*/
int scitoken_get_claim_string_list(const SciToken token, const char *key, char ***value, char **err_msg);

/**
* Given a list of strings that was returned by scitoken_get_claim_string_list, free all the associated
* memory.
*/
void scitoken_free_string_list(char **value);

/**
* Set the value of a claim to a list of strings.
*/
int scitoken_set_claim_string_list(const SciToken token, const char *key,
const char **values, char **err_msg);

int scitoken_get_expiration(const SciToken token, long long *value, char **err_msg);

void scitoken_set_lifetime(SciToken token, int lifetime);
Expand Down
25 changes: 25 additions & 0 deletions src/scitokens_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ friend class scitokens::Validator;
return m_claims.find(key) != m_claims.end();
}

void
set_claim_list(const std::string &claim, std::vector<std::string> &claim_list) {
picojson::array array;
array.reserve(claim_list.size());
for (const auto &entry : claim_list) {
array.emplace_back(entry);
}
m_claims[claim] = jwt::claim(picojson::value(array));
}

// Return a claim as a string
// If the claim is not a string, it can throw
// a std::bad_cast() exception.
Expand All @@ -157,6 +167,21 @@ friend class scitokens::Validator;
return m_claims[key].as_string();
}

const std::vector<std::string>
get_claim_list(const std::string &key) {
picojson::array array;
try {
array = m_claims[key].as_array();
} catch (std::bad_cast &) {
throw JsonException("Claim's value is not a JSON list");
}
std::vector<std::string> result;
for (const auto &value : array) {
result.emplace_back(value.get<std::string>());
}
return result;
}

void
set_lifetime(int lifetime) {
m_lifetime = lifetime;
Expand Down
26 changes: 26 additions & 0 deletions test/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ class SerializeTest : public ::testing::Test {
"1", ec_public, &err_msg);
ASSERT_TRUE(rv == 0);

const char *groups[3] = {nullptr, nullptr, nullptr};
const char group0[] = "group0";
const char group1[] = "group1";
groups[0] = group0;
groups[1] = group1;
rv = scitoken_set_claim_string_list(m_token.get(), "groups", groups,
&err_msg);
ASSERT_TRUE(rv == 0);

m_read_token.reset(scitoken_create(nullptr));
ASSERT_TRUE(m_read_token.get() != nullptr);
}
Expand Down Expand Up @@ -115,6 +124,23 @@ TEST_F(SerializeTest, VerifyTest) {
EXPECT_FALSE(rv == 0);
}

TEST_F(SerializeTest, TestStringList) {
char *err_msg = nullptr;

char **value;
auto rv = scitoken_get_claim_string_list(m_token.get(), "groups", &value, &err_msg);
ASSERT_TRUE(rv == 0);
ASSERT_TRUE(value != nullptr);

ASSERT_TRUE(value[0] != nullptr);
EXPECT_STREQ(value[0], "group0");

ASSERT_TRUE(value[1] != nullptr);
EXPECT_STREQ(value[1], "group1");

EXPECT_TRUE(value[2] == nullptr);
}


TEST_F(SerializeTest, VerifyWLCGTest) {

Expand Down