In [1]:
import torch

In [5]:
@torch.jit.script
def foo(a, b):
    """Scripted demo: return ``(2 * b, row)``.

    ``a`` is incremented IN PLACE (the caller's tensor is mutated). ``row`` is
    ``a[0]`` when the incremented ``a`` has a max above 4, else ``b[0]``.
    """
    doubled_b = 2 * b
    a += 1  # in-place update: visible to the caller after the call
    if a.max() > 4:
        first_row = a[0]
    else:
        first_row = b[0]
    return doubled_b, first_row

In [6]:
# 2x3 example tensors (zeros and ones); presumably meant as inputs for the scripted `foo`.
a, b = torch.zeros(2, 3), torch.ones(2, 3)

In [7]:
import torch

def foo(x, y):
    """Plain eager function returning ``2 * x + y``; it is the tracing target below.

    NOTE(review): this definition shadows the scripted ``foo`` from the earlier cell.
    """
    return 2 * x + y


# Record the ops run on two example 3-element tensors into a TorchScript function.
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))


@torch.jit.script
def bar(x):
    """Scripted wrapper showing that scripted code can call a traced function."""
    combined = traced_foo(x, x)
    return combined


print(traced_foo.graph)

graph(%x : Float(3),
      %y : Float(3)):
  %2 : Long() = prim::Constant[value={2}]() # <ipython-input-7-3c10302655ac>:4:0
  %3 : Float(3) = aten::mul(%x, %2) # <ipython-input-7-3c10302655ac>:4:0
  %4 : int = prim::Constant[value=1]() # <ipython-input-7-3c10302655ac>:4:0
  %5 : Float(3) = aten::add(%3, %y, %4) # <ipython-input-7-3c10302655ac>:4:0
  return (%5)



In [8]:
@torch.jit.script
def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
    """One LSTM-cell step, compiled with TorchScript.

    Returns ``(hy, cy)``: the new hidden state and the new cell state.
    """
    # Pre-activations for the four gates, stacked along dim 1 and split apart.
    gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
    i_pre, f_pre, g_pre, o_pre = gates.chunk(4, 1)

    input_gate = torch.sigmoid(i_pre)
    forget_gate = torch.sigmoid(f_pre)
    candidate = torch.tanh(g_pre)
    output_gate = torch.sigmoid(o_pre)

    # New cell state mixes the previous cell with the candidate; hidden is a gated tanh of it.
    cy = (forget_gate * cx) + (input_gate * candidate)
    hy = output_gate * torch.tanh(cy)
    return hy, cy

In [None]:
# LSTMCellS takes seven tensors, so tracing needs one example per argument with
# mutually consistent shapes (the original passed only two 1-D tensors and failed).
# Shapes follow from the body: x.mm(w_ih.t()), hx.mm(w_hh.t()), gates.chunk(4, 1);
# sizes used here: batch=1, input=10, hidden=20 (so gates have 4 * 20 = 80 columns).
# Bound to a new name so the earlier `traced_foo` (2 * x + y) is not clobbered.
traced_lstm_cell = torch.jit.trace(
    LSTMCellS,
    (
        torch.rand(1, 10),   # x
        torch.rand(1, 20),   # hx
        torch.rand(1, 20),   # cx
        torch.rand(80, 10),  # w_ih
        torch.rand(80, 20),  # w_hh
        torch.rand(80),      # b_ih
        torch.rand(80),      # b_hh
    ),
)

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    """Two traced 5x5 Conv2d layers (1 -> 20 -> 20 channels), each followed by ReLU.

    The convolutions become TorchScript modules via ``torch.jit.trace`` in
    ``__init__``; the whole module is then compiled with ``torch.jit.script``,
    which calls into those traced sub-modules.
    """

    def __init__(self):
        super().__init__()
        # Trace each Conv2d on a random example input so conv1 / conv2 are
        # already TorchScript modules when this module gets scripted.
        first_conv = nn.Conv2d(1, 20, 5)
        self.conv1 = torch.jit.trace(first_conv, torch.rand(1, 1, 16, 16))
        second_conv = nn.Conv2d(20, 20, 5)
        self.conv2 = torch.jit.trace(second_conv, torch.rand(1, 20, 16, 16))

    def forward(self, input):
        hidden = F.relu(self.conv1(input))
        return F.relu(self.conv2(hidden))


scripted_module = torch.jit.script(MyModule())

In [10]:
# Display the TorchScript IR graph of `scripted_module` (the scripted MyModule).
scripted_module.graph

graph(%self : __torch__.MyModule,
      %input.1 : Tensor):
  %12 : Function = prim::Constant[name="relu"]()
  %6 : Function = prim::Constant[name="relu"]()
  %5 : bool = prim::Constant[value=0]()
  %2 : __torch__.torch.nn.modules.module.Module = prim::GetAttr[name="conv1"](%self)
  %4 : Tensor = prim::CallMethod[name="forward"](%2, %input.1) # <ipython-input-9-0766fbb23a95>:13:21
  %input.3 : Tensor = prim::CallFunction(%6, %4, %5) # <ipython-input-9-0766fbb23a95>:13:14
  %8 : __torch__.torch.nn.modules.module.___torch_mangle_3.Module = prim::GetAttr[name="conv2"](%self)
  %10 : Tensor = prim::CallMethod[name="forward"](%8, %input.3) # <ipython-input-9-0766fbb23a95>:14:21
  %input.5 : Tensor = prim::CallFunction(%12, %10, %5) # <ipython-input-9-0766fbb23a95>:14:14
  return (%input.5)

In [12]:
# Defining the network (LeNet-5)
class LeNet5(torch.nn.Module):
    """LeNet-5 style CNN: two conv/ReLU/max-pool stages, then three linear layers.

    The flatten step hard-codes 16 * 5 * 5 features, so each sample's feature map
    must be 16x5x5 at that point (e.g. a 1x28x28 input, which padding=2 in conv1
    lets behave like the 32x32 input LeNet-5 was designed for).

    Args:
        num_classes: number of logits produced by the final layer (default 10).
    """

    def __init__(self, num_classes: int = 10):
        super(LeNet5, self).__init__()
        # 1 -> 6 channels, 5x5 kernel; padding=2 keeps the spatial size of a 28x28 input.
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2, bias=True)
        self.max_pool_1 = torch.nn.MaxPool2d(kernel_size=2)
        # 6 -> 16 channels, 5x5 kernel, no padding.
        self.conv2 = torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0, bias=True)
        self.max_pool_2 = torch.nn.MaxPool2d(kernel_size=2)
        # Fully connected head: 400 -> 120 -> 84 -> num_classes.
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, num_classes)

    def forward(self, x):
        # conv -> ReLU -> 2x2 max-pool, twice
        x = torch.nn.functional.relu(self.conv1(x))
        x = self.max_pool_1(x)
        x = torch.nn.functional.relu(self.conv2(x))
        x = self.max_pool_2(x)
        # Flatten to 16*5*5 columns (see https://stackoverflow.com/a/42482819/7551231).
        # NOTE(review): view(-1, 400) silently regroups rows if the per-sample map is
        # not 16x5x5; torch.flatten(x, 1) would fail loudly instead (but would change
        # the result for unbatched 3-D inputs, so it is not swapped in here).
        x = x.view(-1, 16 * 5 * 5)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        return self.fc3(x)


# `optimized_execution` takes a bool (`should_optimize`); the original passed 3.
# NOTE(review): this flag governs graph-executor optimization while scripted code
# runs; wrapping only the script() call probably has little effect — confirm intent.
with torch.jit.optimized_execution(True):
    scripted_module = torch.jit.script(LeNet5())

In [17]:
# Display the TorchScript IR graph of `scripted_module` (here, the scripted LeNet5).
scripted_module.graph

graph(%self : __torch__.LeNet5,
      %x.1 : Tensor):
  %40 : Function = prim::Constant[name="relu"]()
  %34 : Function = prim::Constant[name="relu"]()
  %23 : int = prim::Constant[value=-1]() # <ipython-input-12-90096b50665d>:30:19
  %15 : Function = prim::Constant[name="relu"]()
  %6 : Function = prim::Constant[name="relu"]()
  %5 : bool = prim::Constant[value=0]()
  %24 : int = prim::Constant[value=16]() # <ipython-input-12-90096b50665d>:30:23
  %25 : int = prim::Constant[value=5]() # <ipython-input-12-90096b50665d>:30:26
  %2 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self)
  %4 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # <ipython-input-12-90096b50665d>:21:37
  %x.3 : Tensor = prim::CallFunction(%6, %4, %5) # <ipython-input-12-90096b50665d>:21:12
  %8 : __torch__.torch.nn.modules.pooling.MaxPool2d = prim::GetAttr[name="max_pool_1"](%self)
  %x.5 : Tensor = prim::CallMethod[name="forward"](%8, %x.3) # <ipython-input-12-90096b50665d>:23:12
  %

In [18]:
# Display the Python-like TorchScript source generated for `scripted_module`.
scripted_module.code

'def forward(self,\n    x: Tensor) -> Tensor:\n  _0 = __torch__.torch.nn.functional.___torch_mangle_17.relu\n  _1 = __torch__.torch.nn.functional.___torch_mangle_18.relu\n  _2 = __torch__.torch.nn.functional.___torch_mangle_19.relu\n  _3 = __torch__.torch.nn.functional.___torch_mangle_20.relu\n  x0 = _0((self.conv1).forward(x, ), False, )\n  x1 = (self.max_pool_1).forward(x0, )\n  x2 = _1((self.conv2).forward(x1, ), False, )\n  x3 = (self.max_pool_2).forward(x2, )\n  x4 = torch.view(x3, [-1, torch.mul(torch.mul(16, 5), 5)])\n  x5 = _2((self.fc1).forward(x4, ), False, )\n  x6 = _3((self.fc2).forward(x5, ), False, )\n  return (self.fc3).forward(x6, )\n'

In [26]:
import torch
import torch.nn as nn

class Net(nn.Module):
    """Minimal module: a single 3x3 Conv2d with 1 input and 1 output channel."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)


n = Net()
# NOTE(review): `example_weight` is not used anywhere in this cell.
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method.
# `optimized_execution` takes a bool (`should_optimize`); the original passed 3.
with torch.jit.optimized_execution(True):
    module = torch.jit.trace(n.forward, example_forward_input)

# Alternative: trace the module itself (implicitly traces `forward`) and
# construct a `ScriptModule` with a single `forward` method:
# module = torch.jit.trace(n, example_forward_input)

In [24]:
# Display the TorchScript IR graph of the traced `module`
# (the line was previously the incomplete attribute access `module.`, a SyntaxError).
module.graph

graph(%self.1 : __torch__.torch.nn.modules.module.___torch_mangle_34.Module,
      %input : Float(1, 1, 3, 3)):
  %28 : __torch__.torch.nn.modules.module.___torch_mangle_33.Module = prim::GetAttr[name="conv"](%self.1)
  %30 : Tensor = prim::CallMethod[name="forward"](%28, %input)
  return (%30)

In [27]:
# Display the TorchScript IR graph of the traced `module` (built from `Net.forward`).
module.graph

graph(%self.1 : __torch__.torch.nn.modules.module.___torch_mangle_38.Module,
      %input : Float(1, 1, 3, 3)):
  %28 : __torch__.torch.nn.modules.module.___torch_mangle_37.Module = prim::GetAttr[name="conv"](%self.1)
  %30 : Tensor = prim::CallMethod[name="forward"](%28, %input)
  return (%30)