Skip to content

Commit

Permalink
Update predict_proba of classifier by diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
tbonald committed May 22, 2023
1 parent d64011d commit 31dd1a5
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 71 deletions.
14 changes: 7 additions & 7 deletions docs/reference/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ Linear algebra

Tools of linear algebra.

Normalization
-------------

.. autofunction:: sknetwork.linalg.normalize

.. autofunction:: sknetwork.linalg.diagonal_pseudo_inverse

Sparse + Low Rank
-----------------

Expand All @@ -17,10 +24,3 @@ Solvers

.. _lanczossvd:
.. autoclass:: sknetwork.linalg.LanczosSVD

Normalization
-------------

.. autofunction:: sknetwork.linalg.normalize

.. autofunction:: sknetwork.linalg.diagonal_pseudo_inverse
43 changes: 0 additions & 43 deletions docs/tutorials/classification/diffusion.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -209,49 +209,6 @@
"SVG(image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# threshold\n",
"diffusion = DiffusionClassifier(threshold=0.1)\n",
"labels_pred = diffusion.fit_predict(adjacency, labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# unclassified node\n",
"np.flatnonzero(labels_pred == -1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"image = svg_graph(adjacency, position, labels=labels_pred, seeds=labels)\n",
"SVG(image)"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down
26 changes: 9 additions & 17 deletions sknetwork/classification/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ class DiffusionClassifier(BaseClassifier):
Number of iterations of the diffusion (discrete time).
centering : bool
If ``True``, center the temperature of each label to its mean before classification (default).
threshold : float
Minimum difference of temperatures between the 2 top labels to classify a node (default = 0).
If the difference of temperatures does not exceed this threshold, return -1 for this node (no label).
scale : float
Multiplicative factor applied to temperatures before softmax (default = 5).
Used only when centering is ``True``.
Attributes
----------
Expand Down Expand Up @@ -60,15 +60,15 @@ class DiffusionClassifier(BaseClassifier):
Zhu, X., Lafferty, J., & Rosenfeld, R. (2005). `Semi-supervised learning with graphs`
(Doctoral dissertation, Carnegie Mellon University, language technologies institute, school of computer science).
"""
def __init__(self, n_iter: int = 10, centering: bool = True, threshold: float = 0):
def __init__(self, n_iter: int = 10, centering: bool = True, scale: float = 5):
super(DiffusionClassifier, self).__init__()

if n_iter <= 0:
raise ValueError('The number of iterations must be positive.')
else:
self.n_iter = n_iter
self.centering = centering
self.threshold = threshold
self.scale = scale

def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray],
labels: Optional[Union[dict, np.ndarray]] = None, labels_row: Optional[Union[dict, np.ndarray]] = None,
Expand Down Expand Up @@ -107,25 +107,17 @@ def fit(self, input_matrix: Union[sparse.csr_matrix, np.ndarray],
temperatures[labels >= 0] = temperatures_seeds
if self.centering:
temperatures -= temperatures.mean(axis=0)
temperatures = np.maximum(temperatures, 0)

labels_ = temperatures.argmax(axis=1)

# softmax
if self.centering:
temperatures = np.exp(self.scale * temperatures)

# set label -1 to nodes not reached by diffusion
distances = get_distances(adjacency, source=np.flatnonzero(labels >= 0))
labels_[distances < 0] = -1
temperatures[distances < 0] = 0

# set label -1 to nodes with low confidence
if self.threshold >= 0:
n_labels = temperatures.shape[1]
if n_labels > 2:
top_temperatures = np.partition(-temperatures, 2, axis=1)[:, :2]
else:
top_temperatures = temperatures
gap = np.abs(top_temperatures[:, 0] - top_temperatures[:, 1])
labels_[gap <= self.threshold] = -1

self.labels_ = labels_
self.probs_ = sparse.csr_matrix(normalize(temperatures))
self._split_vars(input_matrix.shape)
Expand Down
6 changes: 3 additions & 3 deletions sknetwork/classification/tests/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def test_graph(self):
self.assertTrue(len(algo.labels_) == n_nodes)
with self.assertRaises(ValueError):
DiffusionClassifier(n_iter=0)
algo = DiffusionClassifier(centering=False, threshold=1)
algo.fit(adjacency, labels=labels)
self.assertTrue(max(algo.labels_) == -1)
algo = DiffusionClassifier(centering=True, scale=10)
probs = algo.fit_predict_proba(adjacency, labels=labels)[:, 1]
self.assertTrue(max(probs) > 0.99)

def test_bipartite(self):
biadjacency = test_bigraph()
Expand Down
2 changes: 1 addition & 1 deletion sknetwork/linalg/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def normalize(matrix: Union[sparse.csr_matrix, np.ndarray, LinearOperator], p=1)
Returns
-------
normalized matrix :
Normalized matrix.
Normalized matrix (same format as input matrix).
"""
norms = get_norms(matrix, p)
diag = diagonal_pseudo_inverse(norms)
Expand Down

0 comments on commit 31dd1a5

Please sign in to comment.