Skip to content

Commit

Permalink
Merge pull request gdsfactory#1258 from tvt173/fix-flatten-refs-recur…
Browse files Browse the repository at this point in the history
…sive

Fix flatten refs recursive

Former-commit-id: 69a4136 [formerly 085241c]
Former-commit-id: fe16a3a421d04cd2d74f9f769de5c1953bd0eb45
  • Loading branch information
joamatab committed Feb 9, 2023
2 parents 4117747 + 340cc43 commit de9ae78
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 81 deletions.
148 changes: 70 additions & 78 deletions gdsfactory/component.py
Expand Up @@ -2279,78 +2279,66 @@ def recurse_structures(


def flatten_invalid_refs_recursive(
component: Component, grid_size: Optional[float] = None
) -> Component:
"""Returns new Component with flattened references.
component: Component,
grid_size: Optional[float] = None,
updated_components=None,
traversed_components=None,
):
"""Recursively flattens component references which have invalid transformations (i.e. non-90 deg rotations or sub-grid translations) and returns a copy if any subcells have been modified.
WARNING: this function will produce same-name copies of cells. It is strictly meant to be used on write of the GDS file and
should not be mixed with other cells, or you will likely experience issues with duplicate cells
Args:
component: to flatten invalid references.
grid_size: optional grid size in um.
component: the component to fix (in place).
grid_size: the GDS grid size, in um, defaults to active PDK.get_grid_size()
any translations with higher resolution than this are considered invalid.
updated_components: the running dictionary of components which have been modified by this transformation. Should always be None, except for recursive invocations.
traversed_components: the set of component names which have been traversed. Should always be None, except for recursive invocations.
"""
from gdsfactory.decorators import is_invalid_ref
from gdsfactory.functions import transformed
import networkx as nx

def _create_dag(component):
"""DAG where components point to references which then point to components again."""
nodes = {}
edges = {}

def _add_nodes_recursive(g, component):
g.add_node(component.name)
nodes[component.name] = component
for ref in component.references:
edge_name = f"{component.name}:{ref.name}"
g.add_edge(component.name, edge_name)
g.add_edge(edge_name, ref.parent.name)
edges[edge_name] = ref
_add_nodes_recursive(g, ref.parent)

g = nx.DiGraph()
_add_nodes_recursive(g, component)

return g, nodes, edges

def _find_leaves(g):
leaves = [n for n, d in g.out_degree() if d == 0]
return leaves

def _prune_leaves(g):
"""Prune components AND references pointing to them at the bottom of the DAG.
Helper function
"""
comps = _find_leaves(g)
for component in comps:
g.remove_node(component)
refs = _find_leaves(g)
for r in refs:
g.remove_node(r)
return g, comps, refs

finished_comps = {}
g, comps, refs = _create_dag(component)
while True:
g, comp_leaves, ref_leaves = _prune_leaves(g)
if not comp_leaves:
break
new_comps = {}
for ref_name in ref_leaves:
r = refs[ref_name]
comp_name, _ = ref_name.split(":")
if comp_name in finished_comps:
continue
new_comps[comp_name] = comps[comp_name] = new_comps.get(
comp_name
) or Component(name=comp_name)
if is_invalid_ref(r, grid_size):
comp = transformed(r, cache=False, decorator=None) # type: ignore
comps[comp.name] = comp
r = refs[ref_name] = ComponentReference(comp)
comps[comp_name].add(
copy_reference(refs[ref_name], parent=comps[r.parent.name])
)
finished_comps.update(new_comps)
return finished_comps[component.name]

invalid_refs = []
refs = component.references
subcell_modified = False
if updated_components is None:
updated_components = {}
if traversed_components is None:
traversed_components = set()
for ref in refs:
# mark any invalid refs for flattening
# otherwise, check if there are any modified cells beneath (we need not do this if the ref will be flattened anyways)
if is_invalid_ref(ref, grid_size):
invalid_refs.append(ref.name)
else:
# otherwise, recursively flatten refs if the subcell has not already been traversed
if ref.parent.name not in traversed_components:
flatten_invalid_refs_recursive(
ref.parent,
grid_size=grid_size,
updated_components=updated_components,
traversed_components=traversed_components,
)
# now, if the ref's cell been modified, mark it as such
if ref.parent.name in updated_components:
subcell_modified = True
if invalid_refs or subcell_modified:
new_component = component.copy()
new_component.name = component.name
# make sure all modified cells have their references updated
new_refs = new_component.references.copy()
for ref in new_refs:
if ref.name in invalid_refs:
new_component.flatten_reference(ref)
elif (
ref.parent.name in updated_components
and ref.parent is not updated_components[ref.parent.name]
):
ref.parent = updated_components[ref.parent.name]
component = new_component
updated_components[component.name] = new_component
traversed_components.add(component.name)
return component


def test_same_uid() -> None:
Expand Down Expand Up @@ -2483,33 +2471,37 @@ def test_import_gds_settings():

def test_flatten_invalid_refs_recursive():
import gdsfactory as gf
from gdsfactory.difftest import run_xor
from gdsfactory.routing.all_angle import get_bundle_all_angle

@gf.cell
def flat():
c = gf.Component()
mmi1 = (c << gf.components.mmi1x2()).move((0, 1e-4))
mmi2 = (c << gf.components.mmi1x2()).rotate(90)
mmi1 = (c << gf.components.mmi1x2()).move((0, -1.0005))
mmi2 = (c << gf.components.mmi1x2()).rotate(80)
mmi2.move((40, 20))
route = gf.routing.get_route(mmi1.ports["o2"], mmi2.ports["o1"], radius=5)
c.add(route.references)
bundle = get_bundle_all_angle([mmi1.ports["o2"]], [mmi2.ports["o1"]])
for route in bundle:
c.add(route.references)
return c

@gf.cell
def hierarchy():
c = gf.Component()
(c << flat()).rotate(33)
(c << flat()).rotate(33).move((0, 100))
(c << flat()).move((100, 0))
return c

c_orig = hierarchy()
c_new = flatten_invalid_refs_recursive(c_orig)
assert c_new is not c_orig
assert c_new != c_orig
assert c_orig.references[0].parent.name != c_new.references[0].parent.name
assert (
c_orig.references[1].parent.references[0].parent.name
!= c_new.references[1].parent.references[0].parent.name
)
invalid_refs_filename = "invalid_refs.gds"
invalid_refs_fixed_filename = "invalid_refs_fixed.gds"
# gds files should still be same to 1nm tolerance
c_orig.write_gds(invalid_refs_filename)
c_new.write_gds(invalid_refs_fixed_filename)
run_xor(invalid_refs_filename, invalid_refs_fixed_filename)


if __name__ == "__main__":
Expand Down
18 changes: 15 additions & 3 deletions gdsfactory/component_reference.py
Expand Up @@ -174,7 +174,7 @@ def __init__(
columns=columns, rows=rows, v1=v1, v2=v2
)

self.ref_cell = component
self._ref_cell = component
self._owner = None
self._name = name

Expand Down Expand Up @@ -207,9 +207,13 @@ def columns(self) -> int:
def spacing(self) -> Optional[Tuple[float, float]]:
return self._reference.repetition.spacing

@property
def ref_cell(self):
return self._ref_cell

@property
def parent(self):
return self.ref_cell
return self._ref_cell

@property
def origin(self):
Expand Down Expand Up @@ -243,9 +247,17 @@ def x_reflection(self) -> bool:
def x_reflection(self, value):
self._reference.x_reflection = value

def _set_ref_cell(self, value):
self._ref_cell = value
self._reference.cell = value._cell

@ref_cell.setter
def ref_cell(self, value):
self._set_ref_cell(value)

@parent.setter
def parent(self, value):
self.ref_cell = value
self._set_ref_cell(value)

def get_polygons(
self,
Expand Down

0 comments on commit de9ae78

Please sign in to comment.