In [10]:
import torch
import torch.nn as nn

In [11]:
class NonSeq(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.fc1 = nn.Linear(100, 50)
        self.fc2 = nn.Linear(50, 50)
        self.fc3 = nn.Linear(50, 50)
        self.fc4 = nn.Linear(50, 30)
        self.fc5 = nn.Linear(30, 2)

    def forward(self, x):
        out1 = self.fc1(x)
        out2 = self.fc2(out1)
        out3 = self.fc3(out2)
        out4 = self.fc4(out3 + out2)
        out5 = self.fc5(out4)
        return out5

model = NonSeq()

In [12]:
dummy_input = torch.randn(1, 100)

In [13]:
trace = torch.jit.trace(model, dummy_input)

In [14]:
trace_output = trace(dummy_input)

In [15]:
print(trace.graph)

graph(%self.1 : __torch__.___torch_mangle_15.NonSeq,
      %x : Float(1, 100, strides=[100, 1], requires_grad=0, device=cpu)):
  %fc5 : __torch__.torch.nn.modules.linear.___torch_mangle_14.Linear = prim::GetAttr[name="fc5"](%self.1)
  %fc4 : __torch__.torch.nn.modules.linear.___torch_mangle_13.Linear = prim::GetAttr[name="fc4"](%self.1)
  %fc3 : __torch__.torch.nn.modules.linear.___torch_mangle_12.Linear = prim::GetAttr[name="fc3"](%self.1)
  %fc2 : __torch__.torch.nn.modules.linear.___torch_mangle_11.Linear = prim::GetAttr[name="fc2"](%self.1)
  %fc1 : __torch__.torch.nn.modules.linear.___torch_mangle_10.Linear = prim::GetAttr[name="fc1"](%self.1)
  %61 : Tensor = prim::CallMethod[name="forward"](%fc1, %x)
  %62 : Tensor = prim::CallMethod[name="forward"](%fc2, %61)
  %63 : Tensor = prim::CallMethod[name="forward"](%fc3, %62)
  %32 : int = prim::Constant[value=1]() # /tmp/ipykernel_16069/242438593.py:14:0
  %input.5 : Float(1, 50, strides=[50, 1], requires_grad=1, device=cpu) = aten::

In [16]:
print(trace.code, type(trace.code))

def forward(self,
    x: Tensor) -> Tensor:
  fc5 = self.fc5
  fc4 = self.fc4
  fc3 = self.fc3
  fc2 = self.fc2
  fc1 = self.fc1
  _0 = (fc2).forward((fc1).forward(x, ), )
  input = torch.add((fc3).forward(_0, ), _0)
  _1 = (fc5).forward((fc4).forward(input, ), )
  return _1
 <class 'str'>


In [19]:
edges = []
for node in trace.graph.nodes():
    for next_node in node.outputs():
        edges.append((node, next_node))
        print(edges[-1])

(%fc5 : __torch__.torch.nn.modules.linear.___torch_mangle_14.Linear = prim::GetAttr[name="fc5"](%self.1)
, fc5 defined in (%fc5 : __torch__.torch.nn.modules.linear.___torch_mangle_14.Linear = prim::GetAttr[name="fc5"](%self.1)
))
(%fc4 : __torch__.torch.nn.modules.linear.___torch_mangle_13.Linear = prim::GetAttr[name="fc4"](%self.1)
, fc4 defined in (%fc4 : __torch__.torch.nn.modules.linear.___torch_mangle_13.Linear = prim::GetAttr[name="fc4"](%self.1)
))
(%fc3 : __torch__.torch.nn.modules.linear.___torch_mangle_12.Linear = prim::GetAttr[name="fc3"](%self.1)
, fc3 defined in (%fc3 : __torch__.torch.nn.modules.linear.___torch_mangle_12.Linear = prim::GetAttr[name="fc3"](%self.1)
))
(%fc2 : __torch__.torch.nn.modules.linear.___torch_mangle_11.Linear = prim::GetAttr[name="fc2"](%self.1)
, fc2 defined in (%fc2 : __torch__.torch.nn.modules.linear.___torch_mangle_11.Linear = prim::GetAttr[name="fc2"](%self.1)
))
(%fc1 : __torch__.torch.nn.modules.linear.___torch_mangle_10.Linear = prim::GetA