Skip to content

Commit

Permalink
Expose dictionary of CommWordStringKernel to the modular interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
vigsterkr committed Jun 14, 2017
1 parent 25c71a5 commit c65c2c5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 20 deletions.
18 changes: 6 additions & 12 deletions src/shogun/kernel/string/CommWordStringKernel.cpp
Expand Up @@ -46,20 +46,17 @@ CCommWordStringKernel::CCommWordStringKernel(

bool CCommWordStringKernel::init_dictionary(int32_t size)
{
dictionary_size= size;
SG_FREE(dictionary_weights);
dictionary_weights=SG_MALLOC(float64_t, size);
dictionary_weights=SGVector<float64_t>(size);
SG_DEBUG("using dictionary of %d words\n", size)
clear_normal();

return dictionary_weights!=NULL;
return dictionary_weights.vector!=NULL;
}

CCommWordStringKernel::~CCommWordStringKernel()
{
cleanup();

SG_FREE(dictionary_weights);
SG_FREE(dict_diagonal_optimization);
}

Expand Down Expand Up @@ -98,7 +95,7 @@ float64_t CCommWordStringKernel::compute_diag(int32_t idx_a)
ASSERT((1<<(sizeof(uint16_t)*8)) > alen)

int32_t num_symbols=(int32_t) l->get_num_symbols();
ASSERT(num_symbols<=dictionary_size)
ASSERT(num_symbols<=dictionary_weights.vlen)

int32_t* dic = dict_diagonal_optimization;
memset(dic, 0, num_symbols*sizeof(int32_t));
Expand Down Expand Up @@ -286,7 +283,7 @@ void CCommWordStringKernel::add_to_normal(int32_t vec_idx, float64_t weight)

void CCommWordStringKernel::clear_normal()
{
memset(dictionary_weights, 0, dictionary_size*sizeof(float64_t));
dictionary_weights.zero();
set_is_initialized(false);
}

Expand Down Expand Up @@ -600,9 +597,6 @@ char* CCommWordStringKernel::compute_consensus(

void CCommWordStringKernel::init()
{
dictionary_size=0;
dictionary_weights=NULL;

use_sign=false;
use_dict_diagonal_optimization=false;
dict_diagonal_optimization=NULL;
Expand All @@ -611,8 +605,8 @@ void CCommWordStringKernel::init()
init_dictionary(1<<(sizeof(uint16_t)*8));
set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));

m_parameters->add_vector(&dictionary_weights, &dictionary_size, "dictionary_weights",
"Dictionary for applying kernel.");
SG_ADD(&dictionary_weights, "dictionary_weights",
"Dictionary for applying kernel.", MS_NOT_AVAILABLE);
SG_ADD(&use_sign, "use_sign",
"If signum(counts) is used instead of counts.", MS_AVAILABLE);
SG_ADD(&use_dict_diagonal_optimization,
Expand Down
12 changes: 4 additions & 8 deletions src/shogun/kernel/string/CommWordStringKernel.h
Expand Up @@ -149,13 +149,11 @@ class CCommWordStringKernel : public CStringKernel<uint16_t>

/** get dictionary
*
* @param dsize dictionary size will be stored in here
* @param dweights dictionary weights will be stored in here
* @return dictionary weights
*/
void get_dictionary(int32_t& dsize, float64_t*& dweights)
SGVector<float64_t> get_dictionary() const
{
dsize=dictionary_size;
dweights = dictionary_weights;
return dictionary_weights;
}

/** compute scoring
Expand Down Expand Up @@ -240,11 +238,9 @@ class CCommWordStringKernel : public CStringKernel<uint16_t>
void init();

protected:
/** size of dictionary (number of possible strings) */
int32_t dictionary_size;
/** dictionary weights - array to hold counters for all possible
* strings */
float64_t* dictionary_weights;
SGVector<float64_t> dictionary_weights;

/** if sign shall be used */
bool use_sign;
Expand Down

0 comments on commit c65c2c5

Please sign in to comment.