Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify and speed up duplicate rate checking #692

Merged
merged 2 commits into from
Nov 10, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 11 additions & 32 deletions pynucastro/rates/known_duplicates.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,23 @@
import collections

# there are some exceptions to the no-duplicate rates restriction. We
# list them here by class name and then fname
ALLOWED_DUPLICATES = [
(("ReacLibRate: p_p__d__weak__bet_pos_"),
("ReacLibRate: p_p__d__weak__electron_capture"))
{"ReacLibRate: p_p__d__weak__bet_pos_",
"ReacLibRate: p_p__d__weak__electron_capture"}
]


def find_duplicate_rates(rate_list):
"""given a list of rates, return a list of groups of duplicate
rates"""

duplicates = []
lookup = collections.defaultdict(list)
zingale marked this conversation as resolved.
Show resolved Hide resolved
for rate in rate_list:
same_links = [q for q in rate_list
if q != rate and
sorted(q.reactants) == sorted(rate.reactants) and
sorted(q.products) == sorted(rate.products)]

if same_links:
new_entry = [rate] + same_links
already_found = False
# we may have already found this pair
for dupe in duplicates:
if new_entry[0] in dupe:
already_found = True
break
if not already_found:
duplicates.append(new_entry)
lookup[tuple(sorted(rate.reactants)),
tuple(sorted(rate.products))].append(rate)

duplicates = [entry for entry in lookup.values() if len(entry) > 1]

return duplicates

Expand All @@ -35,17 +26,5 @@ def is_allowed_dupe(rate_list):
"""rate_list is a list of rates that provide the same connection
in a network. Return True if this is an allowed duplicate"""

for allowed_dupe in ALLOWED_DUPLICATES:
found = 0
if len(rate_list) == len(allowed_dupe):
found = 1
for r in rate_list:
rate_key = f"{r.__class__.__name__}: {r.fname}"
if rate_key in allowed_dupe:
found *= 1
else:
found *= 0
if found:
return True

return False
key_set = {f"{r.__class__.__name__}: {r.fname}" for r in rate_list}
return key_set in ALLOWED_DUPLICATES