Skip to content

Commit

Permalink
add sorted neighbourhood
Browse files Browse the repository at this point in the history
  • Loading branch information
GreatYYX committed Oct 2, 2019
1 parent f89b009 commit 0cc26cb
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 0 deletions.
7 changes: 7 additions & 0 deletions rltk/blocking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,11 @@
from rltk.blocking.hash_block_generator import HashBlockGenerator
from rltk.blocking.token_block_generator import TokenBlockGenerator
from rltk.blocking.canopy_block_generator import CanopyBlockGenerator
from rltk.blocking.sorted_neighbourhood_block_generator import SortedNeighbourhoodBlockGenerator
from rltk.blocking.blocking_helper import BlockingHelper

Blocker = BlockGenerator
HashBlocker = HashBlockGenerator
TokenBlocker = TokenBlockGenerator
CanopyBlocker = CanopyBlockGenerator
SortedNeighbourhoodBlocker = SortedNeighbourhoodBlockGenerator
94 changes: 94 additions & 0 deletions rltk/blocking/sorted_neighbourhood_block_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from typing import Callable
from functools import cmp_to_key

from rltk.blocking.block_generator import BlockGenerator
from rltk.blocking.block import Block
from rltk.blocking.block_black_list import BlockBlackList


class SortedNeighbourhoodBlockGenerator(BlockGenerator):
"""
Sorted Neighbourhood Blocker.
Args:
window_size (int): Window size.
comparator (Callable): Define how to compare two tokens t1 and t2.
The signature is `comparator(t1: str, t2: str) -> int`.
If return is 0, t1 equals t2; if return is -1, t1 is less than t2;
if return is 1, t1 is greater than t2.
block_id_prefix (str): The block id prefix of each block.
"""
def __init__(self, window_size: int = 3, comparator: Callable = None, block_id_prefix='sorted_neighbourhood_'):
if comparator is None:
comparator = self._default_comparator
self.window_size = window_size
self.comparator = comparator
self.block_id_prefix = block_id_prefix

def block(self, dataset, function_: Callable = None, property_: str = None,
block: Block = None, block_black_list: BlockBlackList = None, base_on: Block = None):
"""
The return of `property_` or `function_` should be a vector (list).
"""
block = super()._block_args_check(function_, property_, block)

if base_on:
for block_id, dataset_id, record_id in base_on:
if dataset.id == dataset_id:
r = dataset.get_record(record_id)
value = function_(r) if function_ else getattr(r, property_)
if not isinstance(value, (list, set)):
value = value(set)
for v in value:
if not isinstance(v, str):
raise ValueError('Elements in return list should be string')
if block_black_list and block_black_list.has(v):
continue
v = block_id + '-' + v
block.add(v, dataset.id, r.id)
if block_black_list:
block_black_list.add(v, block)

else:
for r in dataset:
value = function_(r) if function_ else getattr(r, property_)
if not isinstance(value, (list, set)):
value = set(value)
for v in value:
if not isinstance(v, str):
raise ValueError('Elements in return list should be string')
if block_black_list and block_black_list.has(v):
continue
block.add(v, dataset.id, r.id)
if block_black_list:
block_black_list.add(v, block)

return block

def generate(self, block1: Block, block2: Block, output_block: Block = None):
output_block = BlockGenerator._generate_args_check(output_block)

# TODO: in-memory operations here, need to update
# concatenation
all_records = []
for block_id, ds_id, record_id in block1:
all_records.append((block_id, ds_id, record_id))
for block_id, ds_id, record_id in block2:
all_records.append((block_id, ds_id, record_id))
sorted_all_records = sorted(all_records, key=cmp_to_key(self._comparator_wrapper))

# apply slide window
for i in range(len(sorted_all_records) - self.window_size + 1):
block_id = self.block_id_prefix + str(i)
for j in range(self.window_size):
record = sorted_all_records[i + j]
output_block.add(block_id, record[1], record[2])

return output_block

def _comparator_wrapper(self, t1, t2):
return self.comparator(t1[0], t2[0])

@staticmethod
def _default_comparator(t1, t2):
return 0 if t1 == t2 else (1 if t1 > t2 else -1)
49 changes: 49 additions & 0 deletions rltk/tests/test_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from rltk.blocking.hash_block_generator import HashBlockGenerator
from rltk.blocking.token_block_generator import TokenBlockGenerator
from rltk.blocking.canopy_block_generator import CanopyBlockGenerator
from rltk.blocking.sorted_neighbourhood_block_generator import SortedNeighbourhoodBlockGenerator


class ConcreteRecord(Record):
Expand Down Expand Up @@ -81,3 +82,51 @@ def test_canopy_block_generator():
output_block = bg.generate(block, block)
for k, _ in output_block.key_set_adapter:
assert k in ('[1]', '[2]', '[0]', '[15]')


def test_sorted_neighbourhood_block_generator():
class SNConcreteRecord1(Record):
@property
def id(self):
return self.raw_object['id']

@property
def char(self):
return self.raw_object['char']

class SNConcreteRecord2(SNConcreteRecord1):
pass

sn_raw_data_1 = [
{'id': '11', 'char': 'a'},
{'id': '12', 'char': 'd'},
{'id': '13', 'char': 'c'},
{'id': '14', 'char': 'e'},
]

sn_raw_data_2 = [
{'id': '21', 'char': 'b'},
{'id': '22', 'char': 'a'},
{'id': '23', 'char': 'e'},
{'id': '24', 'char': 'f'},
]

ds1 = Dataset(reader=ArrayReader(sn_raw_data_1), record_class=SNConcreteRecord1)
ds2 = Dataset(reader=ArrayReader(sn_raw_data_2), record_class=SNConcreteRecord2)

bg = SortedNeighbourhoodBlockGenerator(window_size=3)
block = bg.generate(
bg.block(ds1, property_='char'),
bg.block(ds2, property_='char')
)

for block_id, set_ in block.key_set_adapter:
block_data = []
for did, rid in set_:
if did == ds1.id:
block_data.append(ds1.get_record(rid).char)
else:
block_data.append(ds2.get_record(rid).char)
block_data.sort()
for i in range(len(block_data) - 1):
assert block_data[i] <= block_data[i+1] # should be less than or equal to previous char

0 comments on commit 0cc26cb

Please sign in to comment.