Skip to content

Commit

Permalink
Merge pull request #12 from takuti/fix-10
Browse files Browse the repository at this point in the history
Disallow to build fature-based recommender from empty vector
  • Loading branch information
takuti committed Jan 20, 2022
2 parents a4d47b9 + 52c1880 commit 58e73cf
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 13 deletions.
12 changes: 7 additions & 5 deletions flurs/data/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class Base(object):

def __init__(self, index, feature=np.array([0.])):
def __init__(self, index, feature=np.array([])):
self.index = index
self.feature = feature

Expand Down Expand Up @@ -36,7 +36,7 @@ def index_one_hot(self, dim):
class User(Base):

def __str__(self):
if len(self.feature) == 1 and self.feature[0] == 0.:
if len(self.feature) == 0:
return 'User(index={})'.format(self.index)
else:
return 'User(index={}, feature={})'.format(self.index, self.feature)
Expand All @@ -45,15 +45,15 @@ def __str__(self):
class Item(Base):

def __str__(self):
if len(self.feature) == 1 and self.feature[0] == 0.:
if len(self.feature) == 0:
return 'Item(index={})'.format(self.index)
else:
return 'Item(index={}, feature={})'.format(self.index, self.feature)


class Event(object):

def __init__(self, user, item, value=1., context=np.array([0.])):
def __init__(self, user, item, value=1., context=np.array([])):
self.user = user
self.item = item
self.value = value
Expand All @@ -73,10 +73,12 @@ def encode(self, n_user=None, n_item=None,
feature=feature, vertical=False)
x = np.concatenate((x, iv))

assert len(x) > 0, 'feature vector has zero dimension'

return x if not vertical else np.array([x]).T

def __str__(self):
if len(self.context) == 1 and self.context[0] == 0.:
if len(self.context) == 0:
return 'Event(user={}, item={}, value={})'.format(self.user, self.item, self.value)
else:
return 'Event(user={}, item={}, value={}, context={})'.format(self.user, self.item, self.value, self.context)
12 changes: 11 additions & 1 deletion flurs/recommender/factorization_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ class FMRecommender(FactorizationMachine, FeatureRecommenderMixin):
"""

def initialize(self, static=False, use_index=False):
"""Initialize a recommender.
Parameters
----------
static : bool, default=False
Disable incremental update if True.
use_index : bool, default=False
Incorporate onehot-encoded user/item index into a feature vector.
"""
super(FMRecommender, self).initialize()
self.static = static
self.use_index = use_index
Expand Down Expand Up @@ -111,7 +121,7 @@ def update(self, e, batch_train=False):

self.update_model(x, e.value)

def score(self, user, candidates, context):
def score(self, user, candidates, context=np.array([])):
# i_mat is (n_item_context, n_item) for all possible items
# extract only target items
i_mat = self.i_mat[:, candidates]
Expand Down
2 changes: 1 addition & 1 deletion flurs/recommender/online_sketch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def update(self, e, batch_train=False):
y = e.encode(index=False, feature=True, context=True)
self.update_model(y)

def score(self, user, candidates, context):
def score(self, user, candidates, context=np.array([])):
# i_mat is (n_item_context, n_item) for all possible items
# extract only target items
i_mat = self.i_mat[:, candidates]
Expand Down
12 changes: 9 additions & 3 deletions flurs/recommender/tests/test_fm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,19 @@ def test_register_item(self):
def test_update(self):
self.recommender.register(User(0))
self.recommender.register(Item(0))
self.recommender.update(Event(User(0), Item(0), 1))
self.recommender.update(
Event(User(0), Item(0), 1, context=np.array([1, 2, 3]))
)
self.assertEqual(self.recommender.n_user, 1)
self.assertEqual(self.recommender.n_item, 1)

def test_score(self):
self.recommender.register(User(0))
self.recommender.register(Item(0))
self.recommender.update(Event(User(0), Item(0), 1))
score = self.recommender.score(User(0), np.array([0]), np.array([0]))
self.recommender.update(
Event(User(0), Item(0), 1, context=np.array([1, 2, 3]))
)
score = self.recommender.score(
User(0), candidates=np.array([0]), context=np.array([1, 2, 3])
)
self.assertTrue(score >= 0.)
12 changes: 9 additions & 3 deletions flurs/recommender/tests/test_sketch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,19 @@ def test_register_item(self):
def test_update(self):
self.recommender.register(User(0))
self.recommender.register(Item(0))
self.recommender.update(Event(User(0), Item(0), 1))
self.recommender.update(
Event(User(0), Item(0), 1, context=np.array([1, 2, 3]))
)
self.assertEqual(self.recommender.n_user, 1)
self.assertEqual(self.recommender.n_item, 1)

def test_score(self):
self.recommender.register(User(0))
self.recommender.register(Item(0))
self.recommender.update(Event(User(0), Item(0), 1))
score = self.recommender.score(User(0), np.array([0]), np.array([0]))
self.recommender.update(
Event(User(0), Item(0), 1, context=np.array([1, 2, 3]))
)
score = self.recommender.score(
User(0), candidates=np.array([0]), context=np.array([1, 2, 3])
)
self.assertTrue(score >= 0. and score <= 1.0)

0 comments on commit 58e73cf

Please sign in to comment.