# 追踪

In [1]:
import numpy as np
import tvm
from tvm import te, relay
from tqdm.asyncio import tqdm
from testing.relay.utils.tag_span import _create_span, _set_span, _verify_structural_equal_with_span


def list_ops(expr):
    """Collect the operator nodes that appear in a Relay expression.

    Parameters
    ----------
    expr
        Expression handed to ``ExprVisitor.visit`` for traversal.

    Returns
    -------
    list
        The operator nodes passed to ``visit_op``, in first-visit order,
        with each operator listed at most once.
    """

    class OpLister(tvm.relay.ExprVisitor):
        """Visitor that records every operator the first time it is visited."""

        def visit_op(self, op):
            # The set answers "seen before?"; the list keeps first-visit
            # order. Both must be updated, otherwise the membership test
            # never filters anything.
            if op not in self.node_set:
                self.node_set.add(op)
                self.node_list.append(op)
            return super().visit_op(op)

        def list_nodes(self, expr):
            """Reset the bookkeeping, traverse ``expr`` and return the ops found."""
            self.node_set = set()
            self.node_list = []
            self.visit(expr)
            return self.node_list

    return OpLister().list_nodes(expr)
   

def gen_ir_module(model, inputs, use_parser_friendly_name=False):
    """Trace a PyTorch model and convert the trace into a Relay module.

    The model is traced with ``torch.jit.trace`` on ``inputs``; the inputs are
    named ``input0``, ``input1``, ... and paired with their shapes before the
    trace is handed to ``relay.frontend.from_pytorch``.

    Parameters
    ----------
    model
        Model passed to ``torch.jit.trace``.
    inputs
        Example inputs; each element must expose ``.shape``.
    use_parser_friendly_name : bool
        Forwarded unchanged to ``relay.frontend.from_pytorch``.

    Returns
    -------
    The module returned by the frontend; the accompanying params are dropped.
    """
    traced = torch.jit.trace(model, inputs)
    shape_list = [
        (f"input{pos}", tensor.shape) for pos, tensor in enumerate(inputs)
    ]
    ir_mod, _unused_params = relay.frontend.from_pytorch(
        traced,
        shape_list,
        use_parser_friendly_name=use_parser_friendly_name,
    )
    return ir_mod

def assert_shapes_match(tru, est):
    """Verify that two objects report the same ``shape``.

    Parameters
    ----------
    tru
        Reference value; only its ``shape`` attribute is read.
    est
        Value under test; only its ``shape`` attribute is read.

    Raises
    ------
    AssertionError
        If the two shapes compare unequal. The message includes both shapes.
    """
    if tru.shape == est.shape:
        return
    raise AssertionError(
        "Output shapes {} and {} don't match".format(tru.shape, est.shape)
    )

In [2]:
import torch
from torch import nn

# Turn off gradient tracking globally for everything that follows.
torch.set_grad_enabled(False)

# Shape used for the example input and for the constant tensor built in Add3.
input_shape = [10]

class Add1(nn.Module):
    """Module that adds its first positional input to itself."""

    def forward(self, *args):
        """Return ``args[0] + args[0]``; any further arguments are ignored."""
        first = args[0]
        return first + first

class Add2(nn.Module):
    """Module that adds the Python integer 1 to its first positional input."""

    def forward(self, *args):
        """Return ``args[0] + 1``; any further arguments are ignored."""
        first = args[0]
        return first + 1

class Add3(nn.Module):
    """Module that adds a float tensor of ones, shaped ``input_shape``, to its input."""

    def forward(self, *args):
        """Return ``args[0]`` plus a ones tensor built from the global ``input_shape``."""
        ones = torch.ones(input_shape, dtype=torch.float)
        # Move the constant to the GPU whenever CUDA is available.
        ones = ones.cuda() if torch.cuda.is_available() else ones
        return args[0] + ones

class Add4(nn.Module):
    """Module that adds a zero-dimensional float tensor holding 1 to its input."""

    def forward(self, *args):
        """Return ``args[0]`` plus a scalar (shape ``[]``) ones tensor."""
        ones = torch.ones([], dtype=torch.float)
        # Move the constant to the GPU whenever CUDA is available.
        ones = ones.cuda() if torch.cuda.is_available() else ones
        return args[0] + ones
# Random example input of shape `input_shape`, cast to float.
input_data = torch.rand(input_shape).float()

In [3]:
# Reference model in float / eval mode, plus the inputs it is run on.
baseline_model = Add1().float().eval()
baseline_input = [input_data]

# Run the PyTorch model on clones so the original inputs are left untouched.
with torch.no_grad():
    baseline_outputs = baseline_model(*(tensor.clone() for tensor in baseline_input))

# Normalise the reference outputs to a tuple of NumPy arrays.
if not isinstance(baseline_outputs, tuple):
    baseline_outputs = (baseline_outputs,)
baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)

# Trace the model on fresh clones and put the trace in float / eval mode.
trace = torch.jit.trace(baseline_model, [tensor.clone() for tensor in baseline_input])
trace = trace.float().eval()

# Inputs are named input0, input1, ... and paired with their shapes.
input_names = [f"input{idx}" for idx in range(len(baseline_input))]
input_shapes = [(name, inp.shape) for name, inp in zip(input_names, baseline_input)]

In [4]:
# Name the inputs input0, input1, ... and pair each name with its shape.
input_names = [f"input{idx}" for idx in range(len(baseline_input))]
input_shapes = [(name, inp.shape) for name, inp in zip(input_names, baseline_input)]

# Convert the traced model with the Relay PyTorch frontend and show the result.
mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map=None)
print(mod["main"])

# The leading parameters of "main" must carry the input names chosen above.
for arg in mod["main"].params[: len(input_names)]:
    assert arg.name_hint in input_names

# Keyword arguments for the compiled executor: input name -> NumPy copy.
compiled_input = {
    name: inp.clone().cpu().numpy() for name, inp in zip(input_names, baseline_input)
}

fn (%input0: Tensor[(10), float32] /* span=aten::add_0.input0:0:0 */) {
  add(%input0, %input0) /* span=aten::add_0:0:0 */
}


In [5]:
# Executor kind and the targets to check.
kind = "graph"
targets = ["llvm"]
# targets = ["llvm", "cuda"]
check_correctness = True
# Tolerances handed to np.testing.assert_allclose.
rtol = 1e-5
atol = 1e-5
# Operator names that must appear in the converted module (none required here).
expected_ops = []
for target in targets:
    # Skip targets the TVM runtime reports as not enabled.
    if not tvm.runtime.enabled(target):
        continue
    dev = tvm.device(target, 0)
    exe = relay.create_executor(
        kind, mod=mod, params=params, device=dev, target=target
    ).evaluate()
    result = exe(**compiled_input)
    # Wrap a non-list result so it can be indexed like a list of outputs.
    result = result if isinstance(result, list) else [result]

    # Compare every reference output with the matching TVM output.
    for i, baseline_output in tqdm(enumerate(baseline_outputs)):
        output = result[i].numpy()
        assert_shapes_match(baseline_output, output)
        if check_correctness:
            np.testing.assert_allclose(baseline_output, output, rtol=rtol, atol=atol)
    def visit(op):
        # Strike an expected operator name off the list once it is encountered.
        if isinstance(op, tvm.ir.op.Op) and op.name in expected_ops:
            expected_ops.remove(op.name)

    tvm.relay.analysis.post_order_visit(mod["main"].body, visit)

    # Anything still listed was never seen in the module body.
    if expected_ops:
        msg = "TVM Relay do not contain expected ops {}"
        raise AssertionError(msg.format(expected_ops))

1it [00:00, 1993.49it/s]
