Navigation Menu

Skip to content

Commit

Permalink
add string list support for ProtobufFile
Browse files Browse the repository at this point in the history
  • Loading branch information
gsomix committed Sep 23, 2013
1 parent 345c989 commit 1f0ae3e
Show file tree
Hide file tree
Showing 5 changed files with 382 additions and 122 deletions.
312 changes: 215 additions & 97 deletions src/shogun/io/ProtobufFile.cpp
Expand Up @@ -39,9 +39,7 @@ void CProtobufFile::init()
version=1;
message_size=1024*1024;

// repeated field contains pairs key-value
// so we need in worst case double-sized buffer for elements
buffer=SG_MALLOC(uint8_t, message_size*2);
buffer=SG_MALLOC(uint8_t, message_size*sizeof(uint32_t));
}

#define GET_VECTOR(sg_type) \
Expand Down Expand Up @@ -157,7 +155,7 @@ void CProtobufFile::set_matrix(const sg_type* matrix, int32_t num_feat, int32_t
int32_t num_messages=compute_num_messages(num_feat*num_vec, sizeof(sg_type)); \
write_global_header(ShogunVersion::MATRIX); \
write_matrix_header(num_feat, num_vec, num_messages); \
write_memory_block(matrix, num_feat*num_vec, num_messages); \
write_memory_block(matrix, num_feat*num_vec, num_messages); \
}

SET_MATRIX(int8_t)
Expand Down Expand Up @@ -196,23 +194,21 @@ SET_SPARSE_MATRIX(SCNi16, int16_t)
SET_SPARSE_MATRIX(SCNu16, uint16_t)
#undef SET_SPARSE_MATRIX

/** Placeholder overload of get_string_list() for char strings.
 *
 * The body consists solely of SG_NOTIMPLEMENTED, so none of the output
 * parameters are assigned here.
 * NOTE(review): SG_NOTIMPLEMENTED presumably reports a "not implemented"
 * error -- confirm against its definition.
 *
 * @param strings reference to the string array (not written by this overload)
 * @param num_str reference to the string count (not written by this overload)
 * @param max_string_len reference to the maximum string length (not written
 *        by this overload)
 */
void CProtobufFile::get_string_list(
SGString<char>*& strings, int32_t& num_str,
int32_t& max_string_len)
{
SG_NOTIMPLEMENTED
}

/**
 * Generates CProtobufFile::get_string_list() for one element type.
 *
 * The generated method validates the global header as a STRING_LIST, reads
 * the string list header, copies the string count and the maximum string
 * length from it into the output parameters and then delegates to
 * read_string_list() to fill `strings`.
 *
 * The leftover SG_NOTIMPLEMENTED statement that used to open the body has
 * been dropped: it ran before the implementation and (presumably, as it
 * reports an error) kept the code below it from ever executing.
 *
 * @param sg_type element type of the strings
 */
#define GET_STRING_LIST(sg_type) \
void CProtobufFile::get_string_list( \
	SGString<sg_type>*& strings, int32_t& num_str, \
	int32_t& max_string_len) \
{ \
	read_and_validate_global_header(ShogunVersion::STRING_LIST); \
	StringListHeader data_header=read_string_list_header(); \
	num_str=data_header.num_str(); \
	max_string_len=data_header.max_string_len(); \
	read_string_list(strings, data_header); \
}

GET_STRING_LIST(int8_t)
GET_STRING_LIST(uint8_t)
GET_STRING_LIST(char)
GET_STRING_LIST(int32_t)
GET_STRING_LIST(uint32_t)
GET_STRING_LIST(int64_t)
Expand All @@ -224,21 +220,19 @@ GET_STRING_LIST(int16_t)
GET_STRING_LIST(uint16_t)
#undef GET_STRING_LIST

/** Placeholder overload of set_string_list() for char strings.
 *
 * The body consists solely of SG_NOTIMPLEMENTED; neither argument is used.
 * NOTE(review): SG_NOTIMPLEMENTED presumably reports a "not implemented"
 * error -- confirm against its definition.
 *
 * @param strings array of strings (unused by this overload)
 * @param num_str number of strings (unused by this overload)
 */
void CProtobufFile::set_string_list(
const SGString<char>* strings, int32_t num_str)
{
SG_NOTIMPLEMENTED
}

/**
 * Generates CProtobufFile::set_string_list() for one element type.
 *
 * The generated method writes the global header tagged STRING_LIST, then
 * the string list header (which returns the summed length of all strings),
 * derives the number of data messages from that total and finally writes
 * the string contents with write_string_list().
 *
 * The leftover SG_NOTIMPLEMENTED statement that used to open the body has
 * been dropped: it ran before the implementation and (presumably, as it
 * reports an error) kept the code below it from ever executing.
 *
 * @param sg_type element type of the strings
 */
#define SET_STRING_LIST(sg_type) \
void CProtobufFile::set_string_list( \
	const SGString<sg_type>* strings, int32_t num_str) \
{ \
	write_global_header(ShogunVersion::STRING_LIST); \
	uint64_t summary_len=write_string_list_header(strings, num_str); \
	int32_t num_messages=compute_num_messages(summary_len, sizeof(sg_type)); \
	write_string_list(strings, num_str, num_messages); \
}

SET_STRING_LIST(int8_t)
SET_STRING_LIST(uint8_t)
SET_STRING_LIST(char)
SET_STRING_LIST(int32_t)
SET_STRING_LIST(uint32_t)
SET_STRING_LIST(int64_t)
Expand All @@ -250,82 +244,10 @@ SET_STRING_LIST(int16_t)
SET_STRING_LIST(uint16_t)
#undef SET_STRING_LIST

/**
 * Generates CProtobufFile::read_memory_block() for one element type.
 *
 * The generated method allocates `len` elements with SG_MALLOC and fills
 * them from `num_messages` consecutive chunk messages. Every message
 * carries up to message_size/sizeof(sg_type) elements; the last one may
 * carry fewer.
 *
 * @param chunk_type protobuf message type the elements are read from
 * @param sg_type element type of the destination array
 */
#define READ_MEMORY_BLOCK(chunk_type, sg_type) \
void CProtobufFile::read_memory_block(sg_type*& vector, int32_t len, int32_t num_messages) \
{ \
	vector=SG_MALLOC(sg_type, len); \
	\
	const int32_t elements_at_message=message_size/sizeof(sg_type); \
	chunk_type chunk; \
	for (int32_t msg=0; msg<num_messages; msg++) \
	{ \
		read_message(chunk); \
		\
		/* first destination index served by this message */ \
		const int32_t base=msg*elements_at_message; \
		/* a full message unless fewer elements are left */ \
		const int32_t remaining=len-base; \
		const int32_t count= \
			(remaining<=elements_at_message) ? remaining : elements_at_message; \
		\
		for (int32_t k=0; k<count; k++) \
			vector[base+k]=chunk.data(k); \
	} \
}

READ_MEMORY_BLOCK(Int32Chunk, int8_t)
READ_MEMORY_BLOCK(UInt32Chunk, uint8_t)
READ_MEMORY_BLOCK(Int32Chunk, char)
READ_MEMORY_BLOCK(Int32Chunk, int32_t)
READ_MEMORY_BLOCK(UInt32Chunk, uint32_t)
READ_MEMORY_BLOCK(Float32Chunk, float32_t)
READ_MEMORY_BLOCK(Float64Chunk, float64_t)
READ_MEMORY_BLOCK(Float64Chunk, floatmax_t)
READ_MEMORY_BLOCK(Int32Chunk, int16_t)
READ_MEMORY_BLOCK(UInt32Chunk, uint16_t)
READ_MEMORY_BLOCK(Int64Chunk, int64_t)
READ_MEMORY_BLOCK(UInt64Chunk, uint64_t)
#undef READ_MEMORY_BLOCK

/**
 * Generates CProtobufFile::write_memory_block() for one element type.
 *
 * The generated method splits `len` elements of `vector` over
 * `num_messages` chunk messages. Every message carries up to
 * message_size/sizeof(sg_type) elements; the last one may carry fewer.
 * A fresh chunk is built for every message.
 *
 * @param chunk_type protobuf message type the elements are stored in; it
 *        has to be the same type READ_MEMORY_BLOCK uses for sg_type
 * @param sg_type element type of the source array
 */
#define WRITE_MEMORY_BLOCK(chunk_type, sg_type) \
void CProtobufFile::write_memory_block(const sg_type* vector, int32_t len, int32_t num_messages) \
{ \
	const int32_t elements_at_message=message_size/sizeof(sg_type); \
	for (int32_t i=0; i<num_messages; i++) \
	{ \
		chunk_type chunk; \
		\
		/* a full message unless fewer elements are left */ \
		const int32_t offset=i*elements_at_message; \
		const int32_t remaining=len-offset; \
		const int32_t num_elements_to_write= \
			(remaining<=elements_at_message) ? remaining : elements_at_message; \
		\
		for (int32_t j=0; j<num_elements_to_write; j++) \
			chunk.add_data(vector[offset+j]); \
		\
		write_message(chunk); \
	} \
}

WRITE_MEMORY_BLOCK(Int32Chunk, int8_t)
WRITE_MEMORY_BLOCK(UInt32Chunk, uint8_t)
WRITE_MEMORY_BLOCK(Int32Chunk, char)
WRITE_MEMORY_BLOCK(Int32Chunk, int32_t)
/* uint32_t is read back through UInt32Chunk, so it is written with the
 * same chunk type (it used to be UInt64Chunk) */
WRITE_MEMORY_BLOCK(UInt32Chunk, uint32_t)
WRITE_MEMORY_BLOCK(Int64Chunk, int64_t)
WRITE_MEMORY_BLOCK(UInt64Chunk, uint64_t)
WRITE_MEMORY_BLOCK(Float32Chunk, float32_t)
WRITE_MEMORY_BLOCK(Float64Chunk, float64_t)
WRITE_MEMORY_BLOCK(Float64Chunk, floatmax_t)
WRITE_MEMORY_BLOCK(Int32Chunk, int16_t)
WRITE_MEMORY_BLOCK(UInt32Chunk, uint16_t)
#undef WRITE_MEMORY_BLOCK


void CProtobufFile::write_big_endian_uint(uint32_t number, uint8_t* array, uint32_t size)
{
if (size<4)
SG_ERROR("CProtobufFile::write_big_endian_uint:: Array is too small to write\n");
SG_ERROR("array is too small to write\n");

array[0]=(number>>24)&0xffu;
array[1]=(number>>16)&0xffu;
Expand All @@ -336,16 +258,16 @@ void CProtobufFile::write_big_endian_uint(uint32_t number, uint8_t* array, uint3
/** Decodes a 32-bit unsigned integer stored in big-endian byte order.
 *
 * An error is reported through SG_ERROR when fewer than four bytes are
 * available. Previously a second SG_ERROR call followed the guarded one
 * and fired on every call, and the guarded message named
 * write_big_endian_uint instead of this function.
 *
 * @param array buffer to decode from; array[0] is the most significant byte
 * @param size number of bytes available in array, has to be at least 4
 * @return the decoded value
 */
uint32_t CProtobufFile::read_big_endian_uint(uint8_t* array, uint32_t size)
{
	if (size<4)
		SG_ERROR("array is too small to read\n");

	/* widen every byte to uint32_t before shifting: a bare uint8_t is
	 * promoted to signed int, and shifting a byte >= 0x80 left by 24
	 * would not fit into it */
	return (static_cast<uint32_t>(array[0])<<24)
		| (static_cast<uint32_t>(array[1])<<16)
		| (static_cast<uint32_t>(array[2])<<8)
		| static_cast<uint32_t>(array[3]);
}

int32_t CProtobufFile::compute_num_messages(int32_t len, int32_t sizeof_type) const
int32_t CProtobufFile::compute_num_messages(uint64_t len, int32_t sizeof_type) const
{
int32_t elements_at_message=message_size/sizeof_type;
int32_t num_messages=len/elements_at_message;
if (len % elements_at_message > 0)
uint32_t elements_in_message=message_size/sizeof_type;
uint32_t num_messages=len/elements_in_message;
if (len % elements_in_message > 0)
num_messages++;

return num_messages;
Expand Down Expand Up @@ -383,6 +305,14 @@ MatrixHeader CProtobufFile::read_matrix_header()
return data_header;
}

/** Reads the next message from the file as a StringListHeader.
 *
 * @return the header filled in by read_message()
 */
StringListHeader CProtobufFile::read_string_list_header()
{
	StringListHeader header;
	read_message(header);
	return header;
}

void CProtobufFile::write_vector_header(int32_t len, int32_t num_messages)
{
VectorHeader data_header;
Expand All @@ -400,6 +330,40 @@ void CProtobufFile::write_matrix_header(int32_t num_feat, int32_t num_vec, int32
write_message(data_header);
}

#define WRITE_STRING_LIST_HEADER(sg_type) \
uint64_t CProtobufFile::write_string_list_header(const SGString<sg_type>* strings, int32_t num_str) \
{ \
uint64_t counter=0; \
int32_t max_string_len=0; \
StringListHeader data_header; \
data_header.set_num_str(num_str); \
for (int32_t i=0; i<num_str; i++) \
{ \
data_header.add_str_len(strings[i].slen); \
counter+=strings[i].slen; \
if (strings[i].slen>max_string_len) \
max_string_len=strings[i].slen; \
} \
data_header.set_max_string_len(max_string_len); \
write_message(data_header); \
\
return counter; \
}

WRITE_STRING_LIST_HEADER(int8_t)
WRITE_STRING_LIST_HEADER(uint8_t)
WRITE_STRING_LIST_HEADER(char)
WRITE_STRING_LIST_HEADER(int32_t)
WRITE_STRING_LIST_HEADER(uint32_t)
WRITE_STRING_LIST_HEADER(int64_t)
WRITE_STRING_LIST_HEADER(uint64_t)
WRITE_STRING_LIST_HEADER(float32_t)
WRITE_STRING_LIST_HEADER(float64_t)
WRITE_STRING_LIST_HEADER(floatmax_t)
WRITE_STRING_LIST_HEADER(int16_t)
WRITE_STRING_LIST_HEADER(uint16_t)
#undef WRITE_STRING_LIST_HEADER

void CProtobufFile::read_message(google::protobuf::Message& message)
{
uint32_t bytes_read=0;
Expand Down Expand Up @@ -435,4 +399,158 @@ void CProtobufFile::write_message(const google::protobuf::Message& message)
REQUIRE(bytes_write==msg_size, "IO error\n");
}

/**
 * Generates CProtobufFile::read_memory_block() for one element type.
 *
 * The generated method allocates `len` elements with SG_MALLOC and fills
 * them from `num_messages` consecutive chunk messages. Every message
 * carries up to message_size/sizeof(sg_type) elements; the last one may
 * carry fewer.
 *
 * `len` is unsigned, so the number of elements left is tracked through a
 * running offset that never exceeds `len`. The earlier test
 * `(len-(i+1)*elements_in_message)<=0` wrapped around to a huge value for
 * the last, partially filled message instead of becoming negative, so a
 * full message worth of elements was copied past the end of both the chunk
 * and the destination array.
 *
 * @param chunk_type protobuf message type the elements are read from
 * @param sg_type element type of the destination array
 */
#define READ_MEMORY_BLOCK(chunk_type, sg_type) \
void CProtobufFile::read_memory_block(sg_type*& vector, uint64_t len, int32_t num_messages) \
{ \
	vector=SG_MALLOC(sg_type, len); \
	\
	chunk_type chunk; \
	const uint64_t elements_in_message=message_size/sizeof(sg_type); \
	uint64_t offset=0; \
	for (int32_t i=0; i<num_messages; i++) \
	{ \
		read_message(chunk); \
		\
		/* offset<=len holds throughout, so this cannot wrap */ \
		const uint64_t remaining=len-offset; \
		const int32_t num_elements_to_read=static_cast<int32_t>( \
			remaining<elements_in_message ? remaining : elements_in_message); \
		\
		for (int32_t j=0; j<num_elements_to_read; j++) \
			vector[offset+j]=chunk.data(j); \
		\
		offset+=num_elements_to_read; \
	} \
}

READ_MEMORY_BLOCK(Int32Chunk, int8_t)
READ_MEMORY_BLOCK(UInt32Chunk, uint8_t)
READ_MEMORY_BLOCK(UInt32Chunk, char)
READ_MEMORY_BLOCK(Int32Chunk, int32_t)
READ_MEMORY_BLOCK(UInt32Chunk, uint32_t)
READ_MEMORY_BLOCK(Float32Chunk, float32_t)
READ_MEMORY_BLOCK(Float64Chunk, float64_t)
READ_MEMORY_BLOCK(Float64Chunk, floatmax_t)
READ_MEMORY_BLOCK(Int32Chunk, int16_t)
READ_MEMORY_BLOCK(UInt32Chunk, uint16_t)
READ_MEMORY_BLOCK(Int64Chunk, int64_t)
READ_MEMORY_BLOCK(UInt64Chunk, uint64_t)
#undef READ_MEMORY_BLOCK

/**
 * Generates CProtobufFile::write_memory_block() for one element type.
 *
 * The generated method splits `len` elements of `vector` over
 * `num_messages` chunk messages. Every message carries up to
 * message_size/sizeof(sg_type) elements; the last one may carry fewer.
 * One chunk object is reused and cleared after each write.
 *
 * `len` is unsigned, so the number of elements left is tracked through a
 * running offset that never exceeds `len`. The earlier test
 * `(len-(i+1)*elements_in_message)<=0` wrapped around to a huge value for
 * the last, partially filled message instead of becoming negative, so a
 * full message worth of elements was read past the end of `vector`.
 *
 * @param chunk_type protobuf message type the elements are stored in; it
 *        has to be the same type READ_MEMORY_BLOCK uses for sg_type
 * @param sg_type element type of the source array
 */
#define WRITE_MEMORY_BLOCK(chunk_type, sg_type) \
void CProtobufFile::write_memory_block(const sg_type* vector, uint64_t len, int32_t num_messages) \
{ \
	chunk_type chunk; \
	const uint64_t elements_in_message=message_size/sizeof(sg_type); \
	uint64_t offset=0; \
	for (int32_t i=0; i<num_messages; i++) \
	{ \
		/* offset<=len holds throughout, so this cannot wrap */ \
		uint64_t num_elements_to_write=len-offset; \
		if (num_elements_to_write>elements_in_message) \
			num_elements_to_write=elements_in_message; \
		\
		for (uint64_t j=0; j<num_elements_to_write; j++) \
			chunk.add_data(vector[offset+j]); \
		offset+=num_elements_to_write; \
		\
		write_message(chunk); \
		chunk.Clear(); \
	} \
}

WRITE_MEMORY_BLOCK(Int32Chunk, int8_t)
WRITE_MEMORY_BLOCK(UInt32Chunk, uint8_t)
WRITE_MEMORY_BLOCK(UInt32Chunk, char)
WRITE_MEMORY_BLOCK(Int32Chunk, int32_t)
/* uint32_t is read back through UInt32Chunk, so it is written with the
 * same chunk type (it used to be UInt64Chunk) */
WRITE_MEMORY_BLOCK(UInt32Chunk, uint32_t)
WRITE_MEMORY_BLOCK(Int64Chunk, int64_t)
WRITE_MEMORY_BLOCK(UInt64Chunk, uint64_t)
WRITE_MEMORY_BLOCK(Float32Chunk, float32_t)
WRITE_MEMORY_BLOCK(Float64Chunk, float64_t)
WRITE_MEMORY_BLOCK(Float64Chunk, floatmax_t)
WRITE_MEMORY_BLOCK(Int32Chunk, int16_t)
WRITE_MEMORY_BLOCK(UInt32Chunk, uint16_t)
#undef WRITE_MEMORY_BLOCK

#define READ_STRING_LIST(chunk_type, sg_type) \
void CProtobufFile::read_string_list( \
SGString<sg_type>*& strings, const StringListHeader& data_header) \
{ \
strings=SG_MALLOC(SGString<sg_type>, data_header.num_str()); \
\
chunk_type chunk; \
read_message(chunk); \
int32_t elements_in_message=message_size/sizeof(sg_type); \
int32_t buffer_counter=0; \
for (uint32_t i=0; i<data_header.num_str(); i++) \
{ \
strings[i]=SGString<sg_type>(data_header.str_len(i)); \
for (int32_t j=0; j<strings[i].slen; j++) \
{ \
strings[i].string[j]=chunk.data(buffer_counter); \
buffer_counter++; \
\
if (buffer_counter==elements_in_message) \
{ \
read_message(chunk); \
buffer_counter=0; \
} \
} \
} \
}

READ_STRING_LIST(Int32Chunk, int8_t)
READ_STRING_LIST(UInt32Chunk, uint8_t)
READ_STRING_LIST(UInt32Chunk, char)
READ_STRING_LIST(Int32Chunk, int32_t)
READ_STRING_LIST(UInt32Chunk, uint32_t)
READ_STRING_LIST(Float32Chunk, float32_t)
READ_STRING_LIST(Float64Chunk, float64_t)
READ_STRING_LIST(Float64Chunk, floatmax_t)
READ_STRING_LIST(Int32Chunk, int16_t)
READ_STRING_LIST(UInt32Chunk, uint16_t)
READ_STRING_LIST(Int64Chunk, int64_t)
READ_STRING_LIST(UInt64Chunk, uint64_t)
#undef READ_STRING_LIST

/**
 * Generates CProtobufFile::write_string_list() for one element type.
 *
 * The generated method streams the elements of all strings, in order,
 * into chunk messages of up to message_size/sizeof(sg_type) elements. A
 * message is written as soon as it is full; a final, partially filled
 * message is written only if it holds at least one element.
 *
 * `num_messages` is part of the signature but is not consulted here: the
 * number of messages follows from the string lengths.
 *
 * @param chunk_type protobuf message type the elements are stored in; it
 *        has to be the same type READ_STRING_LIST uses for sg_type
 * @param sg_type element type of the strings
 */
#define WRITE_STRING_LIST(chunk_type, sg_type) \
void CProtobufFile::write_string_list( \
	const SGString<sg_type>* strings, int32_t num_str, int32_t num_messages) \
{ \
	chunk_type chunk; \
	int32_t elements_in_message=message_size/sizeof(sg_type); \
	int32_t buffer_counter=0; \
	for (int32_t i=0; i<num_str; i++) \
	{ \
		for (int32_t j=0; j<strings[i].slen; j++) \
		{ \
			chunk.add_data(strings[i].string[j]); \
			buffer_counter++; \
			\
			/* flush as soon as the message is full */ \
			if (buffer_counter==elements_in_message) \
			{ \
				write_message(chunk); \
				chunk.Clear(); \
				buffer_counter=0; \
			} \
		} \
	} \
	\
	/* flush the trailing, partially filled message */ \
	if (buffer_counter!=0) \
		write_message(chunk); \
}

WRITE_STRING_LIST(Int32Chunk, int8_t)
WRITE_STRING_LIST(UInt32Chunk, uint8_t)
WRITE_STRING_LIST(UInt32Chunk, char)
WRITE_STRING_LIST(Int32Chunk, int32_t)
/* uint32_t is read back through UInt32Chunk, so it is written with the
 * same chunk type (it used to be UInt64Chunk) */
WRITE_STRING_LIST(UInt32Chunk, uint32_t)
WRITE_STRING_LIST(Int64Chunk, int64_t)
WRITE_STRING_LIST(UInt64Chunk, uint64_t)
WRITE_STRING_LIST(Float32Chunk, float32_t)
WRITE_STRING_LIST(Float64Chunk, float64_t)
WRITE_STRING_LIST(Float64Chunk, floatmax_t)
WRITE_STRING_LIST(Int32Chunk, int16_t)
WRITE_STRING_LIST(UInt32Chunk, uint16_t)
#undef WRITE_STRING_LIST

#endif /* HAVE_PROTOBUF */

0 comments on commit 1f0ae3e

Please sign in to comment.