Skip to content

Commit

Permalink
Update graphx.py
Browse files Browse the repository at this point in the history
Added px2ang as an argument for the graph class. Right now, self.coordinates_ang is used for finding the neighbors but uses self. coordinates to form the nx_graph and to plot so that the plots are made in pixel units.
  • Loading branch information
saimani5 committed Oct 31, 2022
1 parent 922f2a1 commit bc503f3
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions atomai/utils/graphx.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ class Graph:
"""

def __init__(self, coordinates: np.ndarray,
map_dict: Dict) -> None:
map_dict: Dict,
px2ang: float = 1) -> None:
"""
Initializes a graph object
"""
Expand All @@ -76,6 +77,8 @@ def __init__(self, coordinates: np.ndarray,
v = Node(i, coords[:-1].tolist(), map_dict[coords[-1]])
self.vertices.append(v)
self.coordinates = coordinates
self.coordinates_ang = deepcopy(coordinates)
self.coordinates_ang[:, :-1] = self.coordinates[:, :-1] * px2ang
self.map_dict = map_dict
self.size = len(coordinates)
self.rings = []
Expand All @@ -102,14 +105,14 @@ def find_neighbors(self, **kwargs: float):
Rij = get_interatomic_r
e = kwargs.get("expand", 1.2)
max_neighbors = kwargs.get("max_neighbors", -1)
tree = spatial.cKDTree(self.coordinates[:, :3])
uval = np.unique(self.coordinates[:, -1])
tree = spatial.cKDTree(self.coordinates_ang[:, :3])
uval = np.unique(self.coordinates_ang[:, -1])
if len(uval) == 1:
rmax = Rij([self.map_dict[uval[0]], self.map_dict[uval[0]]], e)
if max_neighbors == -1:
neighbors = tree.query_ball_point(self.coordinates[:, :3], r=rmax)
neighbors = tree.query_ball_point(self.coordinates_ang[:, :3], r=rmax)
else:
_, neighbors = tree.query(self.coordinates[:, :3], k=max_neighbors+1, distance_upper_bound = rmax)
_, neighbors = tree.query(self.coordinates_ang[:, :3], k=max_neighbors+1, distance_upper_bound = rmax)
for v, nn in zip(self.vertices, neighbors):
for n in nn:
if not n >= len(self.vertices):
Expand All @@ -123,7 +126,7 @@ def find_neighbors(self, **kwargs: float):
rij = [Rij([a[0], a[1]], e) for a in apairs]
rmax = np.max(rij)
rij = dict(zip(apairs, rij))
for v, coords in zip(self.vertices, self.coordinates):
for v, coords in zip(self.vertices, self.coordinates_ang):
atom1 = self.map_dict[coords[-1]]
if max_neighbors == -1:
nn = tree.query_ball_point(coords[:3], r=rmax)
Expand All @@ -132,7 +135,7 @@ def find_neighbors(self, **kwargs: float):

for n in nn:
if not n >= len(self.vertices):
coords2 = self.coordinates[n]
coords2 = self.coordinates_ang[n]
if self.vertices[n] != v:
atom2 = self.map_dict[coords2[-1]]
eucldist = np.linalg.norm(
Expand Down

0 comments on commit bc503f3

Please sign in to comment.