42 changes: 27 additions & 15 deletions lib/files.c
Expand Up @@ -7,29 +7,37 @@
#include "log.h"
#include "tpm2_util.h"

static bool get_file_size(FILE *fp, unsigned long *file_size, const char *path) {
bool files_get_file_size(FILE *fp, unsigned long *file_size, const char *path) {

long current = ftell(fp);
if (current < 0) {
LOG_ERR("Error getting current file offset for file \"%s\" error: %s", path, strerror(errno));
if (path) {
LOG_ERR("Error getting current file offset for file \"%s\" error: %s", path, strerror(errno));
}
return false;
}

int rc = fseek(fp, 0, SEEK_END);
if (rc < 0) {
LOG_ERR("Error seeking to end of file \"%s\" error: %s", path, strerror(errno));
if (path) {
LOG_ERR("Error seeking to end of file \"%s\" error: %s", path, strerror(errno));
}
return false;
}

long size = ftell(fp);
if (size < 0) {
LOG_ERR("ftell on file \"%s\" failed: %s", path, strerror(errno));
if (path) {
LOG_ERR("ftell on file \"%s\" failed: %s", path, strerror(errno));
}
return false;
}

rc = fseek(fp, current, SEEK_SET);
if (rc < 0) {
LOG_ERR("Could not restore initial stream position for file \"%s\" failed: %s", path, strerror(errno));
if (path) {
LOG_ERR("Could not restore initial stream position for file \"%s\" failed: %s", path, strerror(errno));
}
return false;
}

Expand All @@ -41,23 +49,27 @@ static bool get_file_size(FILE *fp, unsigned long *file_size, const char *path)
static bool read_bytes_from_file(FILE *f, UINT8 *buf, UINT16 *size,
const char *path) {
unsigned long file_size;
bool result = get_file_size(f, &file_size, path);
bool result = files_get_file_size(f, &file_size, path);
if (!result) {
/* get_file_size() logs errors */
return false;
}

/* max is bounded on UINT16 */
if (file_size > *size) {
LOG_ERR(
"File \"%s\" size is larger than buffer, got %lu expected less than %u",
path, file_size, *size);
if (path) {
LOG_ERR(
"File \"%s\" size is larger than buffer, got %lu expected less than %u",
path, file_size, *size);
}
return false;
}

result = files_read_bytes(f, buf, file_size);
if (!result) {
LOG_ERR("Could not read data from file \"%s\"", path);
if (path) {
LOG_ERR("Could not read data from file \"%s\"", path);
}
return false;
}

Expand All @@ -66,7 +78,7 @@ static bool read_bytes_from_file(FILE *f, UINT8 *buf, UINT16 *size,
return true;
}

bool files_load_bytes_from_file(const char *path, UINT8 *buf, UINT16 *size) {
bool files_load_bytes_from_path(const char *path, UINT8 *buf, UINT16 *size) {
if (!buf || !size || !path) {
return false;
}
Expand All @@ -83,12 +95,12 @@ bool files_load_bytes_from_file(const char *path, UINT8 *buf, UINT16 *size) {
return result;
}

bool files_load_bytes_from_stdin(UINT8 *buf, UINT16 *size) {
bool files_load_bytes_from_file(FILE *file, UINT8 *buf, UINT16 *size, const char *path) {
if (!buf || !size) {
return false;
}

return read_bytes_from_file(stdin, buf, size, "stdin");
return read_bytes_from_file(file, buf, size, path);
}

bool files_save_bytes_to_file(const char *path, UINT8 *buf, UINT16 size) {
Expand Down Expand Up @@ -308,7 +320,7 @@ bool files_does_file_exist(const char *path) {
return false;
}

bool files_get_file_size(const char *path, unsigned long *file_size) {
bool files_get_file_size_path(const char *path, unsigned long *file_size) {

bool result = false;

Expand All @@ -328,7 +340,7 @@ bool files_get_file_size(const char *path, unsigned long *file_size) {
return false;
}

result = get_file_size(fp, file_size, path);
result = files_get_file_size(fp, file_size, path);

fclose(fp);
return result;
Expand Down
36 changes: 33 additions & 3 deletions lib/files.h
Expand Up @@ -21,7 +21,22 @@
* @return
* True on success, false otherwise.
*/
bool files_load_bytes_from_file(const char *path, UINT8 *buf, UINT16 *size);
bool files_load_bytes_from_path(const char *path, UINT8 *buf, UINT16 *size);

/**
* Reads a series of bytes from a stdio FILE object.
* @param file
* The file to read from.
* @param buf
* The buffer to read into.
* @param size
* The size of the buffer to read into.
* @param path
* An optional path for error reporting. A NULL path disables error logging.
* @return
* True on success, False otherwise.
*/
bool files_load_bytes_from_file(FILE *file, UINT8 *buf, UINT16 *size, const char *path);

/**
* Reads a series of bytes from the standard input as a byte array. This is similar to
Expand All @@ -36,7 +51,9 @@ bool files_load_bytes_from_file(const char *path, UINT8 *buf, UINT16 *size);
* @return
* True on success, false otherwise.
*/
bool files_load_bytes_from_stdin(UINT8 *buf, UINT16 *size);
static inline bool files_load_bytes_from_stdin(UINT8 *buf, UINT16 *size) {
return files_load_bytes_from_file(stdin, buf, size, "<stdin>");
}

/**
* Similar to files_write_bytes(), in that it writes an array of bytes to disk,
Expand Down Expand Up @@ -101,7 +118,20 @@ bool files_does_file_exist(const char *path);
* @return
* True for success or False for error.
*/
bool files_get_file_size(const char *path, unsigned long *file_size);
bool files_get_file_size_path(const char *path, unsigned long *file_size);

/**
* Similar to files_get_file_size_path(), but uses an already opened FILE object.
* @param fp
* The file pointer to query the size of.
* @param file_size
* Output of the file size.
* @param path
* An optional path used for error reporting, a NULL path disables error logging.
* @return
* True on success, False otherwise.
*/
bool files_get_file_size(FILE *fp, unsigned long *file_size, const char *path);

/**
* Writes a TPM2.0 header to a file.
Expand Down
6 changes: 3 additions & 3 deletions lib/tpm2_policy.c
Expand Up @@ -54,7 +54,7 @@ static bool evaluate_populate_pcr_digests(TPML_PCR_SELECTION pcr_selections,
//Check if the input pcrs file size is the same size as the pcr selection setlist
if (raw_pcrs_file) {
unsigned long filesize = 0;
bool result = files_get_file_size(raw_pcrs_file, &filesize);
bool result = files_get_file_size_path(raw_pcrs_file, &filesize);
if (!result) {
LOG_ERR("Could not retrieve raw_pcrs_file size");
return false;
Expand Down Expand Up @@ -123,8 +123,8 @@ TPM_RC tpm2_policy_pcr_build(TSS2_SYS_CONTEXT *sapi_context,
// Calculate hashes
TPM2B_DIGEST pcr_digest = TPM2B_TYPE_INIT(TPM2B_DIGEST, buffer);
rval = tpm_hash_sequence(sapi_context,
policy_session->authHash, pcr_values.count,
&pcr_values.digests[0], &pcr_digest);
policy_session->authHash, TPM_RH_NULL, pcr_values.count,
pcr_values.digests, &pcr_digest, NULL);
if (rval != TPM_RC_SUCCESS) {
return rval;
}
Expand Down