Skip to content

Commit

Permalink
Merge pull request #120 from ljwolf/adbscan-fix
Browse files Browse the repository at this point in the history
Adbscan fix
  • Loading branch information
sjsrey committed Jul 5, 2020
2 parents 4e650ad + e57976d commit fffb6c1
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions esda/adbscan.py
Expand Up @@ -7,17 +7,17 @@
import warnings
import pandas
import numpy as np
from geopandas import GeoSeries
from libpysal.cg.alpha_shapes import alpha_shape_auto
from scipy.spatial import cKDTree
from collections import Counter
from sklearn.cluster import DBSCAN
from sklearn.neighbors import KNeighborsClassifier
from sklearn.base import BaseEstimator as _BaseEstimator, ClusterMixin as _ClusterMixin

__all__ = ["ADBSCAN", "remap_lbls", "ensemble", "get_cluster_boundary"]


class ADBSCAN:
class ADBSCAN(_ClusterMixin, _BaseEstimator):
"""
A-DBSCAN, as introduced in :cite:`ab_gl_vm2020joue`.
Expand Down Expand Up @@ -184,7 +184,7 @@ def fit(self, X, y=None, sample_weight=None, xy=["X", "Y"]):
columns=["rep-%s" % str(i).zfill(zfiller) for i in range(self.reps)],
)
# Multi-core implementation of parallel draws
if (self.n_jobs is -1) or (self.n_jobs > 1):
if (self.n_jobs == -1) or (self.n_jobs > 1):
pool = _setup_pool(self.n_jobs)
# Set different parallel seeds!!!
warn_msg = (
Expand Down Expand Up @@ -329,7 +329,7 @@ def remap_lbls(solus, xys, xy=["X", "Y"], n_jobs=1):
index=solus.index,
columns=solus.columns,
)
if (n_jobs is -1) or (n_jobs > 1):
if (n_jobs == -1) or (n_jobs > 1):
pool = _setup_pool(n_jobs)
s_ids = solus.drop(ref, axis=1).columns.tolist()
to_loop_over = [(solus[s], ref_centroids, ref_kdt, xys, xy) for s in s_ids]
Expand Down Expand Up @@ -420,7 +420,9 @@ def ensemble(solus_relabelled):
counts = np.array(list(map(f, solus_relabelled.values)))
winner = counts[:, 0]
votes = counts[:, 1].astype(int) / solus_relabelled.shape[1]
pred = pandas.DataFrame({"lbls": winner, "pct": votes}, index=solus_relabelled.index)
pred = pandas.DataFrame(
{"lbls": winner, "pct": votes}, index=solus_relabelled.index
)
return pred


Expand Down Expand Up @@ -500,6 +502,12 @@ def get_cluster_boundary(labels, xys, xy=["X", "Y"], n_jobs=1, crs=None, step=1)
>>> polys[0].wkt
'POLYGON ((0.7217553174317995 0.8192869956700687, 0.7605307121989587 0.9086488808086682, 0.9177741225129434 0.8568503024577332, 0.8126209616521135 0.6262871483113925, 0.6125260668293881 0.5475861559192435, 0.5425443680112613 0.7546476915298572, 0.7217553174317995 0.8192869956700687))'
"""
try:
from geopandas import GeoSeries
except ModuleNotFoundError:

def GeoSeries(data, index=None, crs=None):
return list(data)

lbl_type = type(labels.iloc[0])
noise = lbl_type(-1)
Expand Down

0 comments on commit fffb6c1

Please sign in to comment.