# Graph Optimization with torch.fx
---
This notebook demonstrates how to implement custom optimization passes in PyTorch using `torch.fx`.
We will:
- Fuse `Linear → BatchNorm → ReLU` into a single module
- Remove redundant operations (e.g., cancelling consecutive transposes)
- Rewrite control flow into more graph-friendly operations
- Compare correctness and performance (execution time, memory)


In [1]:
#SETUP
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx
import time

print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
print("GPU name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")


# Select device
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Quick PyTorch check
x_cpu = torch.randn(2000, 2000)
x_gpu = x_cpu.to(device)


# CPU matmul
t0 = time.time()
_ = x_cpu @ x_cpu
print("CPU matmul time:", time.time() - t0)


# GPU matmul (no warm-up, so this first call also pays one-time CUDA/cuBLAS initialization)
if device == "cuda":
    torch.cuda.synchronize()
    t0 = time.time()
    _ = x_gpu @ x_gpu
    torch.cuda.synchronize()
    print("GPU matmul time:", time.time() - t0)


PyTorch version: 2.6.0+cu124
CUDA version: 12.4
CUDA available: True
GPU name: NVIDIA GeForce RTX 4050 Laptop GPU
Using device: cuda
CPU matmul time: 0.03013134002685547
GPU matmul time: 0.06492304801940918


## Define a toy model with `Linear → BatchNorm → ReLU`

We begin by defining a lightweight feed-forward neural network (`ToyModel`) that captures the canonical `Linear → BatchNorm → ReLU → Linear` computation pattern.  
This serves as a controlled testbed for our custom optimization passes. The intermediate composition of linear, normalization, and activation layers is a typical target for graph-level fusion, as it introduces multiple operator boundaries that can be collapsed into a single composite operation without altering functional semantics.  

By setting the model to evaluation mode (`.eval()`), we ensure that batch normalization uses its stored statistics rather than minibatch statistics, yielding deterministic behavior suitable for graph transformations and benchmarking.
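
To illustrate why `.eval()` matters here, the following minimal sketch shows that a `BatchNorm1d` layer updates its running statistics on every forward pass in training mode but leaves them untouched in evaluation mode. The tensor sizes and variable names are illustrative and not taken from the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 3.0 + 5.0  # batch whose statistics differ from the defaults

# Training mode: the forward pass normalizes with batch statistics
# and updates running_mean / running_var as a side effect.
bn.train()
before_train = bn.running_mean.clone()
bn(x)
changed_in_train = not torch.equal(before_train, bn.running_mean)

# Eval mode: the forward pass uses the stored statistics and does not modify them.
bn.eval()
before_eval = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
changed_in_eval = not torch.equal(before_eval, bn.running_mean)

print("running_mean changed in train mode:", changed_in_train)
print("running_mean changed in eval mode: ", changed_in_eval)
```

Folding relies on the second behavior: the stored statistics must be constants for the fused weights to stay valid.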


In [2]:
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(32, 64)
        self.bn = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.head = nn.Linear(64, 10)

    def forward(self, x):
        x = self.fc(x)
        x = self.bn(x)
        x = self.relu(x)
        return self.head(x)

model = ToyModel().to(device).eval()
x = torch.randn(16, 32, device=device)
with torch.no_grad():
    out = model(x)
print("Output shape:", out.shape)

Output shape: torch.Size([16, 10])


In [3]:
def assert_bn_eval(model: nn.Module):
    """Ensure all BatchNorm layers are in eval mode before fusion."""
    for name, m in model.named_modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            if m.training:
                raise RuntimeError(
                    f"BatchNorm layer '{name}' must be in eval() mode before fusion."
                )
    print("All BatchNorm layers confirmed to be in eval mode.")
assert_bn_eval(model)

All BatchNorm layers confirmed to be in eval mode.


## Helper function to fold BatchNorm into Linear

This utility function implements **batch normalization folding** into a preceding linear layer’s parameters, a common graph-level optimization for inference.  

Given a linear layer (`nn.Linear`) and a batch normalization layer (`nn.BatchNorm1d`), we analytically compute the equivalent weight (`W_fused`) and bias (`b_fused`) that absorb the normalization and affine transformation into the linear layer. This eliminates the separate `BatchNorm` operator during forward execution, removing one operator call and its intermediate activation.  

Key steps include:  
1. Extracting the linear layer’s original weights and biases.  
2. Accounting for the batch normalization’s affine parameters (`gamma`, `beta`) if they exist.  
3. Computing the scaling factor based on the running variance and epsilon.  
4. Applying the folding transformation to both weights and bias to preserve functional equivalence.  

This approach preserves the model’s output while enabling subsequent graph-level fusion and simplification.
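
The folding transformation in step 4 follows from substituting the linear layer's output into the eval-mode batch normalization formula. As a sketch, write $y = Wx + b$ and let $\mu$, $\sigma^2$ denote the running mean and variance (the symbols $\mu$, $\sigma^2$, and $s$ are chosen here for illustration):

```latex
\begin{aligned}
\mathrm{BN}(y) &= \gamma \odot \frac{y - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta,
\qquad s = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \\
\mathrm{BN}(Wx + b) &= s \odot (Wx + b - \mu) + \beta \\
&= \underbrace{\operatorname{diag}(s)\, W}_{W_{\text{fused}}}\, x
 + \underbrace{\beta + (b - \mu) \odot s}_{b_{\text{fused}}}
\end{aligned}
```

Multiplying by $\operatorname{diag}(s)$ scales row $i$ of $W$ by $s_i$, so the scale is applied per output feature.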


In [4]:
class FusedLinearReLU(nn.Module):
    def __init__(self, linear: nn.Linear, with_relu=True):
        super().__init__()
        self.linear = linear
        self.act = nn.ReLU() if with_relu else nn.Identity()

    def forward(self, x):
        return self.act(self.linear(x))


def fold_bn_into_linear_params(linear: nn.Linear, bn: nn.BatchNorm1d):
    """Fold eval-mode BatchNorm1d statistics into a preceding Linear layer.

    Returns (W_fused, b_fused) such that, for 2D inputs, a Linear layer with
    these parameters matches bn(linear(x)) up to floating-point error.
    """
    W = linear.weight
    # Treat missing parameters as neutral values, matching W's dtype and device
    b = linear.bias if linear.bias is not None else W.new_zeros(W.size(0))

    gamma = bn.weight if bn.affine else W.new_ones(W.size(0))
    beta = bn.bias if bn.affine else W.new_zeros(W.size(0))
    mean = bn.running_mean
    var = bn.running_var
    eps = bn.eps

    # Per-output-feature scale: gamma / sqrt(var + eps)
    scale = gamma / torch.sqrt(var + eps)
    W_fused = W * scale.unsqueeze(1)  # scale row i of W by scale[i]
    b_fused = beta + (b - mean) * scale

    return W_fused, b_fused

## Function for Custom FX Fusion Pass: Linear → BatchNorm → ReLU

This cell defines a transformation pass on a `torch.fx.GraphModule` that detects the pattern `Linear → BatchNorm (eval mode) → [optional ReLU]` and replaces it with a single fused module.

**STEP 1: Pattern Matching:**
- Iterate through graph nodes and check for a `Linear` module.
- Verify it is followed by a `BatchNorm` module in `eval()` mode.
- Optionally detect a `ReLU` (either as a module or function call) immediately after `BatchNorm`.

**STEP 2: Parameter Folding:**
- Use `fold_bn_into_linear_params` to mathematically fold the `BatchNorm` parameters into the `Linear` layer’s weights and biases.
- This eliminates the runtime `BatchNorm` computation.

**STEP 3: Fused Module Creation:**
- Construct a new `FusedLinearReLU` module.
- If a `ReLU` is present, attach it inside the fused module; otherwise, use an identity mapping.

**STEP 4: Graph Rewriting:**
- Insert the fused module node into the graph.
- Redirect all uses of the old `BatchNorm`/`ReLU` outputs to the fused node.
- Erase the original `Linear`, `BatchNorm`, and (if present) `ReLU` nodes.

**STEP 5: Graph Validation and Recompile:**
- Run `graph.lint()` to ensure graph consistency.
- Call `gm.recompile()` so the updated module reflects the optimized graph.

This pass enables static graph optimizations that reduce operator count, improve inference latency, and simplify downstream compilation or deployment.
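
The pattern-matching step can be sketched in isolation. The snippet below assumes a hypothetical two-layer module `Tiny` and looks modules up through `dict(gm.named_modules())`. It only collects matches and does not rewrite the graph.

```python
import torch.nn as nn
import torch.fx as fx


class Tiny(nn.Module):  # hypothetical module used only for this illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 8)
        self.bn = nn.BatchNorm1d(8)

    def forward(self, x):
        return self.bn(self.fc(x))


gm = fx.symbolic_trace(Tiny().eval())
modules = dict(gm.named_modules())  # qualified name -> module

matches = []
for node in gm.graph.nodes:
    # A candidate starts at a call_module node whose target is an nn.Linear
    if node.op != "call_module" or not isinstance(modules[node.target], nn.Linear):
        continue
    users = list(node.users)  # nodes that consume this node's output
    # Require exactly one consumer, and that consumer must be a BatchNorm1d module
    if (
        len(users) == 1
        and users[0].op == "call_module"
        and isinstance(modules[users[0].target], nn.BatchNorm1d)
    ):
        matches.append((node.name, users[0].name))

print(matches)
```

Requiring a single user matters: if another node also reads the `Linear` output, erasing the `Linear` node would break that consumer.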

In [5]:
def fuse_linear_bn_relu_fx(gm: fx.GraphModule) -> fx.GraphModule:
    graph = gm.graph
    counter = 0

    for node in list(graph.nodes):
        if node.op != "call_module":
            continue

        # Check Linear
        try:
            linear_mod = gm.get_submodule(node.target)
        except Exception:
            continue
        if not isinstance(linear_mod, nn.Linear):
            continue

        linear_users = list(node.users)
        if len(linear_users) != 1:
            continue

        # Check BatchNorm
        bn_node = linear_users[0]
        if bn_node.op != "call_module":
            continue
        try:
            bn_mod = gm.get_submodule(bn_node.target)
        except Exception:
            continue
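        # Folding into Linear is only valid for BatchNorm1d; this pass assumes (N, C) inputs
        if not isinstance(bn_mod, nn.BatchNorm1d):
            continue
        # Skip BN in training mode or without stored running statistics
        if bn_mod.training or bn_mod.running_mean is None:
            continue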

        # Optional ReLU
        bn_users = list(bn_node.users)
        relu_node, is_relu = None, False
        if len(bn_users) == 1:
            candidate = bn_users[0]
            if candidate.op == "call_module":
                try:
                    relu_mod = gm.get_submodule(candidate.target)
                    if isinstance(relu_mod, nn.ReLU):
                        relu_node, is_relu = candidate, True
                except Exception:
                    pass
            elif candidate.op == "call_function" and candidate.target in (torch.relu, F.relu):
                relu_node, is_relu = candidate, True

        # Fold BN into Linear; build the fused layer on the same device/dtype as the original
        Wf, bf = fold_bn_into_linear_params(linear_mod, bn_mod)
        fused_linear = nn.Linear(
            linear_mod.in_features, linear_mod.out_features, bias=True,
            device=Wf.device, dtype=Wf.dtype,
        )
        with torch.no_grad():
            fused_linear.weight.copy_(Wf)
            fused_linear.bias.copy_(bf)

        fused_mod = FusedLinearReLU(fused_linear, with_relu=is_relu)
        fused_name = f"fused_{counter}_{linear_mod._get_name()}_{bn_mod._get_name()}"
        counter += 1
        gm.add_submodule(fused_name, fused_mod)

        # Insert fused node
        with graph.inserting_before(node):
            fused_node = graph.call_module(fused_name, args=(node.args[0],))

        # Replace old nodes
        last_node = relu_node if relu_node is not None else bn_node
        last_node.replace_all_uses_with(fused_node)

        # Erase old nodes
        if relu_node is not None:
            graph.erase_node(relu_node)
        graph.erase_node(bn_node)
        graph.erase_node(node)

        print(f"Fused nodes: {node.name} -> {bn_node.name}" +
              (f" -> {relu_node.name}" if is_relu else "") +
              f" as {fused_name}")

    graph.lint()
    gm.recompile()
    return gm

## Compare correctness between original and fused models
This cell performs functional correctness verification of the fused graph transformation.

`compare_models` evaluates two modules (`a` and `b`) on the same input tensor `x` and checks element-wise equivalence within specified absolute (`atol`) and relative (`rtol`) tolerances. It reports the maximum absolute difference to provide insight into numerical deviations introduced by the fusion process.

`fx.symbolic_trace` generates a `GraphModule` representation of the original `ToyModel`, which is then passed to `fuse_linear_bn_relu_fx` to produce a fused variant.

Finally, `compare_models` confirms that the fused model produces outputs numerically equivalent to the original, ensuring that the `Linear → BatchNorm → ReLU` fusion preserves the model’s functional semantics.

This step is critical in validating that graph-level optimizations maintain correctness before any performance benchmarking.

In [6]:
def compare_models(a: nn.Module, b: nn.Module, x: torch.Tensor, atol=1e-6, rtol=1e-5):
    a.eval()
    b.eval()
    with torch.no_grad():
        out_a = a(x)
        out_b = b(x)
    max_diff = (out_a - out_b).abs().max().item()
    print(f"max_abs_diff: {max_diff:.6e}")
    assert torch.allclose(out_a, out_b, atol=atol, rtol=rtol), "Mismatch detected!"
    return True

In [7]:
#Test fusion by tracing a model, applying fusion, visualizing graphs before/after, printing fused FX code, and verifying correctness.

def test_fusion(model: nn.Module, input_shape=(4, 32)):
    # Ensure BN layers are in eval
    assert_bn_eval(model)

    # FX trace
    traced = fx.symbolic_trace(model.eval())

    # Nodes before fusion
    print("\nNodes BEFORE fusion:")
    for n in traced.graph.nodes:
        print(f"{n.name:20} | {n.op:12} | {n.target}")

    # Fusion
    fused = fuse_linear_bn_relu_fx(traced)

    # Nodes after fusion
    print("\nNodes AFTER fusion:")
    for n in fused.graph.nodes:
        print(f"{n.name:20} | {n.op:12} | {n.target}")

    # Fused FX code
    print("\nFused Graph Code:\n", fused.code)

    # Test correctness
    x = torch.randn(*input_shape)
    compare_models(model, fused, x)
    print("Fusion correctness verified\n")

In [8]:
if __name__ == "__main__":
    model = ToyModel().eval()
    test_fusion(model)

All BatchNorm layers confirmed to be in eval mode.

Nodes BEFORE fusion:
x                    | placeholder  | x
fc                   | call_module  | fc
bn                   | call_module  | bn
relu                 | call_module  | relu
head                 | call_module  | head
output               | output       | output
Fused nodes: fc -> bn -> relu as fused_0_Linear_BatchNorm1d

Nodes AFTER fusion:
x                    | placeholder  | x
fused_0_linear_batch_norm1d | call_module  | fused_0_Linear_BatchNorm1d
head                 | call_module  | head
output               | output       | output

Fused Graph Code:
 


def forward(self, x):
    fused_0_linear_batch_norm1d = self.fused_0_Linear_BatchNorm1d(x);  x = None
    head = self.head(fused_0_linear_batch_norm1d);  fused_0_linear_batch_norm1d = None
    return head
    
max_abs_diff: 1.005828e-07
Fusion correctness verified
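
A freshly constructed `BatchNorm1d` has `running_mean = 0`, `running_var = 1`, `weight = 1`, and `bias = 0`, so the check above exercises a fold that is close to the identity. A stricter check could randomize the statistics first. The sketch below repeats the folding arithmetic inline so that it is self-contained. The layer sizes mirror `ToyModel`, but the random ranges are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(32, 64).eval()
bn = nn.BatchNorm1d(64).eval()

with torch.no_grad():
    # Give BatchNorm non-trivial statistics and affine parameters
    bn.running_mean.normal_(0.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.normal_(0.0, 1.0)

    # Same folding arithmetic as fold_bn_into_linear_params
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    folded = nn.Linear(32, 64)
    folded.weight.copy_(linear.weight * scale.unsqueeze(1))
    folded.bias.copy_(bn.bias + (linear.bias - bn.running_mean) * scale)

    # Compare the folded layer against the unfused Linear -> BatchNorm reference
    x = torch.randn(16, 32)
    reference = bn(linear(x))
    max_diff = (reference - folded(x)).abs().max().item()

print(f"max_abs_diff with randomized BN statistics: {max_diff:.3e}")
```

If the fold were wrong, for example if the mean were not subtracted from the bias, this check would fail even though the default-initialized check passes.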



## Benchmarking Performance

This cell implements a **performance benchmarking routine** to quantify the impact of graph-level fusion on inference efficiency.  

`benchmark_model` measures both **execution latency** and **peak memory usage** for a given model and input shape:  
1. Warm-up iterations stabilize runtime performance and account for JIT or caching effects.  
2. For CUDA devices, `torch.cuda.Event` is used to record high-precision timing, and `torch.cuda.max_memory_allocated` captures peak memory consumption.  
3. For CPU execution, `time.perf_counter` provides wall-clock timing, though memory profiling is not included.  

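Comparing the original and fused models shows what the fusion buys in practice. In the run recorded below, latency drops from roughly 0.048 ms to 0.036 ms per iteration (about 25%), while peak memory is essentially unchanged at about 41.56 MB. For this toy model, the benefit of the custom `Linear → BatchNorm → ReLU` fusion pass is therefore reduced operator overhead rather than a smaller memory footprint.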


In [9]:
import time

def benchmark_model(model: nn.Module, input_shape=(1024, 32), device=None, warmup=10, iters=100):
    """
    Benchmark a model's inference performance.

    Args:
        model (nn.Module): The model to benchmark (should already be in eval mode).
        input_shape (tuple): Shape of a single input batch.
        device (torch.device or None): Target device (defaults to 'cuda' if available, else 'cpu').
        warmup (int): Number of warm-up iterations.
        iters (int): Number of timed iterations.

    Returns:
        dict with 'latency_ms' (per iteration) and 'peak_mem_mb' (if CUDA).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(device)  # also accept strings such as "cuda"

    model = model.to(device)
    x = torch.randn(*input_shape, device=device)

    # Ensure determinism
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Warm-up phase
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(x)

    # CUDA benchmark
    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats(device)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        torch.cuda.synchronize(device)
        start_event.record()
        with torch.no_grad():
            for _ in range(iters):
                _ = model(x)
        end_event.record()
        torch.cuda.synchronize(device)

        elapsed_ms = start_event.elapsed_time(end_event) / iters
        peak_mem_mb = torch.cuda.max_memory_allocated(device) / (1024 ** 2)

        result = {"latency_ms": elapsed_ms, "peak_mem_mb": peak_mem_mb}

    # CPU benchmark
    else:
        start = time.perf_counter()
        with torch.no_grad():
            for _ in range(iters):
                _ = model(x)
        end = time.perf_counter()
        elapsed_ms = (end - start) * 1000 / iters

        result = {"latency_ms": elapsed_ms, "peak_mem_mb": None}

    print(f"Benchmark on {device}: {result}")
    return result


In [10]:
if __name__ == "__main__":
    model = ToyModel().eval()
    fused_model = fuse_linear_bn_relu_fx(fx.symbolic_trace(model))

    # Benchmark both
    print("\n--- Benchmark Original ---")
    bench_orig = benchmark_model(model, input_shape=(1024, 32))

    print("\n--- Benchmark Fused ---")
    bench_fused = benchmark_model(fused_model, input_shape=(1024, 32))


Fused nodes: fc -> bn -> relu as fused_0_Linear_BatchNorm1d

--- Benchmark Original ---
Benchmark on cuda: {'latency_ms': 0.047531838417053225, 'peak_mem_mb': 41.5556640625}

--- Benchmark Fused ---
Benchmark on cuda: {'latency_ms': 0.0357478404045105, 'peak_mem_mb': 41.56396484375}
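
The two latency figures can be turned into a relative speedup with a few lines of arithmetic. The values below are copied from the output above. They come from a single run, so treat the result as indicative rather than a stable measurement.

```python
# Latencies in ms per iteration, copied from the benchmark output above
orig_ms = 0.047531838417053225
fused_ms = 0.0357478404045105

speedup = orig_ms / fused_ms
reduction_pct = (orig_ms - fused_ms) / orig_ms * 100

print(f"Speedup: {speedup:.2f}x")
print(f"Latency reduction: {reduction_pct:.1f}%")
```

This works out to a speedup of about 1.33x. At roughly 0.04 ms per iteration the workload is dominated by per-operator launch overhead, which is consistent with removing one of four operators having a visible effect.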


In [11]:
def measure_peak_memory(model: nn.Module, input_shape=(1024, 32), device=None):
    """
    Measure peak GPU memory usage (MB) for a single forward pass.
    Returns None if not running on CUDA.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type != "cuda":
        print("Peak memory measurement only supported on CUDA.")
        return None

    model = model.to(device).eval()
    x = torch.randn(*input_shape, device=device)

    torch.cuda.reset_peak_memory_stats(device)
    with torch.no_grad():
        _ = model(x)
    torch.cuda.synchronize(device)

    peak_mem_mb = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
    return peak_mem_mb


def compare_peak_memory(orig_model: nn.Module, fused_model: nn.Module, input_shape=(1024, 32)):
    """
    Compare peak GPU memory usage between original and fused models.
    On CPU, reports N/A.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mem_orig = measure_peak_memory(orig_model, input_shape, device)
    mem_fused = measure_peak_memory(fused_model, input_shape, device)

    print("\nPeak GPU Memory Usage")
    if device.type == "cuda":
        print(f"Original: {mem_orig:.2f} MB")
        print(f"Fused   : {mem_fused:.2f} MB")
        reduction = (mem_orig - mem_fused) / mem_orig * 100 if mem_orig else 0.0
        print(f"Reduction: {reduction:.2f}%")
    else:
        print("CUDA not available — memory usage measurement skipped.")

    return {"orig_mb": mem_orig, "fused_mb": mem_fused}


In [12]:
if __name__ == "__main__":
    model = ToyModel().eval()
    traced = fx.symbolic_trace(model)
    fused_model = fuse_linear_bn_relu_fx(traced)

    compare_peak_memory(model, fused_model, input_shape=(1024, 32))


Fused nodes: fc -> bn -> relu as fused_0_Linear_BatchNorm1d

Peak GPU Memory Usage
Original: 41.58 MB
Fused   : 41.59 MB
Reduction: -0.02%
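
The fused model shows a slightly higher peak, not a lower one. One plausible reading, not verified in the notebook, is that the fused layer's parameters are allocated in addition to the original model's parameters, which are still resident on the device. The arithmetic below checks that reading against the benchmark figures from cell 10. It assumes float32 parameters and that the CUDA caching allocator rounds each allocation up to a multiple of 512 bytes.

```python
MB = 1024 ** 2
BLOCK = 512  # assumed allocator rounding granularity, in bytes


def rounded(nbytes: int, block: int = BLOCK) -> int:
    """Round an allocation size up to a multiple of `block` bytes."""
    return -(-nbytes // block) * block


# Peak values (MB) copied from the benchmark output of cell 10
observed_delta = (41.56396484375 - 41.5556640625) * MB

# float32 parameters of the fused nn.Linear(32, 64)
weight_bytes = 64 * 32 * 4
bias_bytes = 64 * 4
expected_delta = rounded(weight_bytes) + rounded(bias_bytes)

print("observed delta (bytes):", observed_delta)
print("expected delta (bytes):", expected_delta)
```

Under these assumptions both numbers come out to 8704 bytes. This suggests the measurement reflects cumulative allocations on the device rather than the footprint of each model in isolation.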


## Redundant Operation Removal (Transpose Cancellation)

This cell implements a **graph-level optimization pass to eliminate redundant transpositions**.  

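The function `remove_redundant_transposes` traverses the FX `GraphModule` and identifies consecutive `torch.transpose` calls that **cancel each other out** (i.e., transposing the same two dimensions twice). As written, it matches only `call_function` nodes. Method-style calls such as `x.transpose(1, 2)` are traced as `call_method` nodes and are left untouched, which is why the `TransposeToy` example below reports "No redundant transposes found." When a matching pattern is detected:  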
1. The output of the second transpose is replaced directly with the input of the first, effectively bypassing both operations.  
2. Both transpose nodes are removed from the graph, reducing computational overhead.  

By performing this transformation, the graph becomes shorter, dropping two no-op view operations and their dispatch overhead while preserving functional equivalence. The calls to `graph.lint()` and `gm.recompile()` ensure the modified FX graph is consistent and executable.


In [13]:

def remove_redundant_transposes(gm: fx.GraphModule, verbose: bool = True) -> fx.GraphModule:
    """
    Graph-level optimization pass to remove redundant transpose ops.
    Looks for consecutive transpose calls with identical (dim0, dim1) that cancel out.
    """
    graph = gm.graph
    removed = 0

    for node in list(graph.nodes):
        if node.op != "call_function" or node.target != torch.transpose:
            continue

        users = list(node.users)
        if len(users) != 1:
            continue
        next_node = users[0]

        if next_node.op == "call_function" and next_node.target == torch.transpose:
            # both should have the same dims
            if (len(node.args) >= 3 and len(next_node.args) >= 3):
                dim0, dim1 = node.args[1], node.args[2]
                dim0_next, dim1_next = next_node.args[1], next_node.args[2]

                if dim0 == dim0_next and dim1 == dim1_next:
                    # cancel: replace all uses of next_node with the original input to first transpose
                    orig_input = node.args[0]
                    next_node.replace_all_uses_with(orig_input)

                    # Erase the two transposes
                    graph.erase_node(next_node)
                    graph.erase_node(node)

                    removed += 2
                    if verbose:
                        print(f"Removed redundant transpose pair on dims ({dim0}, {dim1})")

    if removed > 0:
        graph.lint()
        gm.recompile()
    else:
        if verbose:
            print("No redundant transposes found.")

    return gm



In [14]:
class TransposeToy(nn.Module):
    def forward(self, x):
        # Redundant double transpose
        return x.transpose(1, 2).transpose(1, 2)

model = TransposeToy()
traced = fx.symbolic_trace(model)

print("Before optimization:")
print(traced.graph)

optimized = remove_redundant_transposes(traced)

print("\nAfter optimization:")
print(optimized.graph)


Before optimization:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %transpose : [num_users=1] = call_method[target=transpose](args = (%x, 1, 2), kwargs = {})
    %transpose_1 : [num_users=1] = call_method[target=transpose](args = (%transpose, 1, 2), kwargs = {})
    return transpose_1
No redundant transposes found.

After optimization:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %transpose : [num_users=1] = call_method[target=transpose](args = (%x, 1, 2), kwargs = {})
    %transpose_1 : [num_users=1] = call_method[target=transpose](args = (%transpose, 1, 2), kwargs = {})
    return transpose_1
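
The trace above records `x.transpose(1, 2)` as `call_method` nodes, which a check against `torch.transpose` as a `call_function` target does not match. A variant that accepts both forms could be sketched as follows. The helper names `is_transpose`, `int_dims`, and `remove_transpose_pairs`, and the `MixedTransposes` module, are hypothetical. The sketch is deliberately conservative: it only cancels pairs whose dims are plain positional integers naming the same two dimensions.

```python
import torch
import torch.nn as nn
import torch.fx as fx


def is_transpose(node: fx.Node) -> bool:
    """Match both torch.transpose(x, a, b) and x.transpose(a, b)."""
    if node.op == "call_function":
        return node.target == torch.transpose
    return node.op == "call_method" and node.target == "transpose"


def int_dims(node: fx.Node):
    """Return the swapped dims as a sorted tuple, or None if not plain positional ints."""
    if len(node.args) != 3 or node.kwargs:
        return None
    d0, d1 = node.args[1], node.args[2]
    if not (isinstance(d0, int) and isinstance(d1, int)):
        return None
    return tuple(sorted((d0, d1)))


def remove_transpose_pairs(gm: fx.GraphModule) -> fx.GraphModule:
    graph = gm.graph
    erased = set()
    for node in list(graph.nodes):
        if node in erased or not is_transpose(node):
            continue
        users = list(node.users)
        if len(users) != 1 or not is_transpose(users[0]):
            continue
        nxt = users[0]
        dims = int_dims(node)
        # The second transpose must swap the same dims and consume the first as its tensor
        if dims is None or dims != int_dims(nxt) or nxt.args[0] is not node:
            continue
        nxt.replace_all_uses_with(node.args[0])
        graph.erase_node(nxt)   # erase the consumer first, then its producer
        graph.erase_node(node)
        erased.update((node, nxt))
    graph.lint()
    gm.recompile()
    return gm


class MixedTransposes(nn.Module):  # hypothetical demo module
    def forward(self, x):
        # Method form followed by function form, then some real work
        return torch.transpose(x.transpose(1, 2), 1, 2) * 2.0


gm = remove_transpose_pairs(fx.symbolic_trace(MixedTransposes()))
remaining = [n.name for n in gm.graph.nodes if is_transpose(n)]
x = torch.randn(2, 3, 4)
print("remaining transposes:", remaining)
print("matches eager output:", torch.equal(gm(x), MixedTransposes()(x)))
```

Sorting the dims lets `transpose(1, 2)` cancel `transpose(2, 1)`. Equivalent negative indices such as `(-2, -1)` are not recognized, because resolving them would require shape information.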


In [15]:
def test_fusion_and_optimization(model: nn.Module, input_shape=(4, 32)):
    """
    Run fusion (Linear+BN+ReLU) and redundant transpose elimination,
    then compare against the original model for correctness.
    """
    # Ensure BatchNorm layers are frozen
    assert_bn_eval(model)

    # Trace
    traced = fx.symbolic_trace(model.eval())

    # --- Step 1: Fusion ---
    fused = fuse_linear_bn_relu_fx(traced)

    # --- Step 2: Remove redundant transposes ---
    optimized = remove_redundant_transposes(fused)

    # --- Step 3: Print final graph ---
    print("\nOptimized Graph Code:\n", optimized.code)

    # --- Step 4: Correctness test ---
    x = torch.randn(*input_shape)
    compare_models(model, optimized, x)
    print("Fusion + Optimization correctness verified\n")

    return optimized


In [16]:
if __name__ == "__main__":
    model = ToyModel().eval()
    optimized_model = test_fusion_and_optimization(model, input_shape=(4, 32))


All BatchNorm layers confirmed to be in eval mode.
Fused nodes: fc -> bn -> relu as fused_0_Linear_BatchNorm1d
No redundant transposes found.

Optimized Graph Code:
 


def forward(self, x):
    fused_0_linear_batch_norm1d = self.fused_0_Linear_BatchNorm1d(x);  x = None
    head = self.head(fused_0_linear_batch_norm1d);  fused_0_linear_batch_norm1d = None
    return head
    
max_abs_diff: 7.450581e-08
Fusion + Optimization correctness verified



## Control Flow Rewriting
Masked assignments (e.g., `y[mask] = value`) are not compatible with `torch.fx` tracing because in-place item assignment on `Proxy` objects is unsupported. To make the model graph-friendly:
- Replace in-place masked assignments with `torch.where`.
- Use `torch.full_like` to create the replacement tensor safely, preserving shape, dtype, and device, even when `x` is a `Proxy`.

The rewritten module:
- Is compatible with symbolic tracing (`torch.fx.symbolic_trace`).
- Preserves functional correctness: outputs match the original in-place masked assignment.
- Enables downstream graph-level optimizations, fusions, or export.

Functional equivalence is tested by comparing the output of the original module against the rewritten version. The FX graph confirms that the in-place assignment has been replaced with a `torch.full_like` call feeding a single `torch.where` operation.

In [17]:
class OriginalMaskedAssign(nn.Module):
    def forward(self, x, mask):
        # In-place masked assignment (not FX-traceable)
        y = x.clone()
        y[mask] = 0.0
        return y


class MaskedAssignWhere(nn.Module):
    def forward(self, x, mask):
        # Graph-friendly version: use torch.where
        replacement = torch.full_like(x, 0.0)
        return torch.where(mask, replacement, x)


def test_masked_assign_equivalence():
    x = torch.randn(4, 5)
    mask = x > 0  # boolean mask

    m_orig = OriginalMaskedAssign()
    m_rewritten = MaskedAssignWhere()

    with torch.no_grad():
        y1 = m_orig(x, mask)
        y2 = m_rewritten(x, mask)

    # Functional correctness check
    max_diff = (y1 - y2).abs().max().item()
    print(f"max_abs_diff: {max_diff:.6e}")
    assert torch.allclose(y1, y2), "Mismatch detected!"
    print("Outputs match between original and rewritten version")

    # FX trace of the rewritten version
    traced = fx.symbolic_trace(m_rewritten)
    print("\nFX Graph of rewritten module:")
    print(traced.graph)


# --- Run---
if __name__ == "__main__":
    test_masked_assign_equivalence()


max_abs_diff: 0.000000e+00
Outputs match between original and rewritten version

FX Graph of rewritten module:
graph():
    %x : [num_users=2] = placeholder[target=x]
    %mask : [num_users=1] = placeholder[target=mask]
    %full_like : [num_users=1] = call_function[target=torch.full_like](args = (%x, 0.0), kwargs = {})
    %where : [num_users=1] = call_function[target=torch.where](args = (%mask, %full_like, %x), kwargs = {})
    return where
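
To see the tracing limitation directly, the sketch below attempts to trace the in-place version and catches the resulting error. It also shows `Tensor.masked_fill` as another traceable rewrite. The class names `InPlaceMasked` and `MaskedFill` are hypothetical, and `masked_fill` is offered here as an alternative rather than something the notebook itself uses.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class InPlaceMasked(nn.Module):  # same pattern as OriginalMaskedAssign
    def forward(self, x, mask):
        y = x.clone()
        y[mask] = 0.0
        return y


class MaskedFill(nn.Module):  # hypothetical alternative rewrite
    def forward(self, x, mask):
        return x.masked_fill(mask, 0.0)


# Tracing the in-place version is expected to raise, because Proxy
# objects do not support item assignment.
try:
    fx.symbolic_trace(InPlaceMasked())
    trace_failed = False
except Exception as exc:
    trace_failed = True
    print("Tracing failed:", type(exc).__name__)

# The masked_fill version traces cleanly and matches the eager in-place result.
traced = fx.symbolic_trace(MaskedFill())
x = torch.randn(4, 5)
mask = x > 0
same = torch.equal(traced(x, mask), InPlaceMasked()(x, mask))
print("masked_fill matches in-place assignment:", same)
```

Compared with `torch.where`, `masked_fill` avoids materializing a full replacement tensor, though it only applies when the replacement is a scalar.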
