Skip to content

Commit

Permalink
Get the number of steps more reliably.
Browse files Browse the repository at this point in the history
* it can now be retrieved from deeply nested nodes.
* fix error when it's a string
This only matters for prompt editing.
  • Loading branch information
shiimizu committed Sep 22, 2023
1 parent 518b040 commit f2d27af
Showing 1 changed file with 22 additions and 25 deletions.
47 changes: 22 additions & 25 deletions smZNodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,38 +732,34 @@ def prompt_handler(json_data):
data=json_data['prompt']
def tmp():
nonlocal data
current_clip_id = None
def find_nearest_ksampler(clip_id):
"""Find the nearest KSampler node that references the given CLIPTextEncode id."""
for ksampler_id, node in data.items():
if "Sampler" in node["class_type"] or "sampler" in node["class_type"]:
# Check if this KSampler node directly or indirectly references the given CLIPTextEncode node
if check_link_to_clip(ksampler_id, clip_id):
ksampler_steps_value = node["inputs"].get("steps", 1)
# If the steps value is a list, get the referenced node's steps
if isinstance(ksampler_steps_value, list):
referenced_node = data[ksampler_steps_value[0]]
# print('referenced_node',referenced_node['inputs'])
vs=ksampler_steps_value=list(referenced_node["inputs"].values())
if len(vs) == 1 and isinstance(vs[0], int):
return vs[0]
else:
return ksampler_steps_value
return get_steps(data, ksampler_id)
return None
def fd(node_id, steps_id, visited):
visited = set()

node = data[node_id]
def get_steps(graph, node_id):
node = graph.get(str(node_id), {})
steps_input_value = node.get("inputs", {}).get("steps", None)

if node_id in visited:
return False
visited.add(node_id)
for value in node["inputs"].values():
if isinstance(value, list):
return fd(node_id, visited)
elif isinstance(value, int):
return value
while(True):
# Base case: it's a direct value
if isinstance(steps_input_value, (int, float, str)):
return min(max(1, int(steps_input_value)), 10000)

# Loop case: it's a reference to another node
elif isinstance(steps_input_value, list):
ref_node_id, ref_input_index = steps_input_value
ref_node = graph.get(str(ref_node_id), {})
keys = list(ref_node.get("inputs", {}).keys())
ref_input_key = keys[ref_input_index % len(keys)]
steps_input_value = ref_node.get("inputs", {}).get(ref_input_key)
else:
raise NotImplementedError()
return None

def check_link_to_clip(node_id, clip_id, visited=None):
"""Check if a given node links directly or indirectly to a CLIPTextEncode node."""
Expand All @@ -788,11 +784,12 @@ def check_link_to_clip(node_id, clip_id, visited=None):
# Update each CLIPTextEncode node's steps with the steps from its nearest referencing KSampler node
for clip_id, node in data.items():
if node["class_type"] == "smZ CLIPTextEncode":
current_clip_id = clip_id
steps = find_nearest_ksampler(clip_id)
if opts.debug:
print(f'[smZNodes] id: {clip_id} | find_nearest_ksampler {steps}')
if steps is not None:
node["inputs"]["steps"] = steps
if opts.debug:
print(f'[smZNodes] id: {current_clip_id} | steps: {steps}')
tmp()
return json_data
PromptServer.instance.add_on_prompt_handler(prompt_handler)
Expand Down Expand Up @@ -850,7 +847,7 @@ def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, m
model_options['transformer_options'] = {}
model_options['transformer_options']['from_smZ'] = True

if not opts.use_CFGDenoiser:
if not opts.use_CFGDenoiser or not model_options['transformer_options'].get('from_smZ', False):
out = self.orig.apply_model(x, timestep, cc, uu, cond_scale, cond_concat, model_options, seed)
else:
# Only supports one cond
Expand Down

0 comments on commit f2d27af

Please sign in to comment.