In [1]:
import torch

In [2]:
# Small JIT Example

In [3]:
@torch.jit.script
def foo(a, b):
    """Toy TorchScript function combining an in-place update with a data-dependent branch.

    Args:
        a: Tensor that is incremented by 1 IN PLACE, so the caller's tensor changes.
        b: Tensor that is read but never modified.

    Returns:
        A tuple ``(doubled, row)``: ``doubled`` is ``2 * b``; ``row`` is ``a[0]``
        when the largest element of the incremented ``a`` is greater than 4,
        otherwise ``b[0]``.
    """
    doubled = 2 * b
    # In-place increment, equivalent to ``a += 1``.
    a.add_(1)
    # The branch depends on tensor values, so it is decided at run time.
    if a.max() > 4:
        row = a[0]
    else:
        row = b[0]
    return doubled, row

In [4]:
# Example inputs: a 2x3 tensor of zeros and a 2x3 tensor of ones.
a, b = torch.zeros(2, 3), torch.ones(2, 3)

In [5]:
import torch

def foo(x, y):
    """Return ``2 * x + y``."""
    doubled = 2 * x
    return doubled + y


# Trace `foo` by running it once on two random length-3 example tensors.
traced_foo = torch.jit.trace(
    foo,
    (torch.rand(3), torch.rand(3)),
)


@torch.jit.script
def bar(x):
    """Scripted function that calls the traced `traced_foo` with ``x`` as both arguments."""
    return traced_foo(x, x)


# Show the graph recorded for the traced function.
print(traced_foo.graph)

graph(%x : Float(3),
      %y : Float(3)):
  %2 : Long() = prim::Constant[value={2}]() # <ipython-input-5-3c10302655ac>:4:0
  %3 : Float(3) = aten::mul(%x, %2) # <ipython-input-5-3c10302655ac>:4:0
  %4 : int = prim::Constant[value=1]() # <ipython-input-5-3c10302655ac>:4:0
  %5 : Float(3) = aten::add(%3, %y, %4) # <ipython-input-5-3c10302655ac>:4:0
  return (%5)



# LeNet Model

In [6]:
# Defining the network (LeNet-5)  
class LeNet5(torch.nn.Module):
    """LeNet-5 style CNN: two conv/ReLU/max-pool stages followed by three linear layers.

    Takes single-channel images and produces 10 output values per image. For
    28x28 inputs, ``conv1`` (padding 2) keeps the 28x28 size, each pooling step
    halves it, and ``conv2`` (no padding) shrinks 14x14 to 10x10, so the
    classifier receives 16 * 5 * 5 = 400 features per image.
    """

    def __init__(self):
        super().__init__()
        # Stage 1. LeNet-5 expects 32x32 images; padding=2 compensates for smaller inputs.
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2, bias=True)
        self.max_pool_1 = torch.nn.MaxPool2d(kernel_size=2)
        # Stage 2. No padding here.
        self.conv2 = torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0, bias=True)
        self.max_pool_2 = torch.nn.MaxPool2d(kernel_size=2)
        # Classifier head: 400 -> 120 -> 84 -> 10 features.
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on ``x`` and return the output of the last linear layer.

        No activation is applied after ``fc3``.
        """
        # Conv -> ReLU -> 2x2 max-pool, twice.
        features = self.max_pool_1(torch.nn.functional.relu(self.conv1(x)))
        features = self.max_pool_2(torch.nn.functional.relu(self.conv2(features)))
        # Flatten to rows of 16*5*5 values; -1 lets the row count be inferred.
        flat = features.view(-1, 16 * 5 * 5)
        # Two hidden linear layers with ReLU, then the output layer.
        hidden = torch.nn.functional.relu(self.fc1(flat))
        hidden = torch.nn.functional.relu(self.fc2(hidden))
        return self.fc3(hidden)
    
    


# Unoptimized JIT TRACE of LeNet

In [7]:
# Trace a freshly constructed LeNet5 with one random 1x1x28x28 example input.
scripted_module = torch.jit.trace(
    LeNet5(),
    torch.rand(1, 1, 28, 28),
)
scripted_module.graph  # unoptimized

graph(%self.1 : __torch__.torch.nn.modules.module.___torch_mangle_8.Module,
      %input.1 : Float(1, 1, 28, 28)):
  %134 : __torch__.torch.nn.modules.module.___torch_mangle_7.Module = prim::GetAttr[name="fc3"](%self.1)
  %131 : __torch__.torch.nn.modules.module.___torch_mangle_6.Module = prim::GetAttr[name="fc2"](%self.1)
  %128 : __torch__.torch.nn.modules.module.___torch_mangle_5.Module = prim::GetAttr[name="fc1"](%self.1)
  %125 : __torch__.torch.nn.modules.module.___torch_mangle_4.Module = prim::GetAttr[name="max_pool_2"](%self.1)
  %124 : __torch__.torch.nn.modules.module.___torch_mangle_3.Module = prim::GetAttr[name="conv2"](%self.1)
  %121 : __torch__.torch.nn.modules.module.___torch_mangle_2.Module = prim::GetAttr[name="max_pool_1"](%self.1)
  %120 : __torch__.torch.nn.modules.module.Module = prim::GetAttr[name="conv1"](%self.1)
  %142 : Tensor = prim::CallMethod[name="forward"](%120, %input.1)
  %input.3 : Float(1, 6, 28, 28) = aten::relu(%142) # /home/ritesh/miniconda3/envs/

# Optimized JIT TRACE of LeNet

In [8]:
# `torch.jit.optimized_execution` takes a boolean switch (should the JIT
# executor optimize when running?), not an optimization level, so pass True
# rather than an integer such as 3.
with torch.jit.optimized_execution(True):
    scripted_module = torch.jit.trace(LeNet5(), torch.rand(1, 1, 28, 28))
# NOTE(review): `.graph` here shows the same IR as the trace cell without the
# context manager (compare the recorded outputs), so the switch does not appear
# to change what is displayed -- confirm how to inspect the optimized graph.
scripted_module.graph

graph(%self.1 : __torch__.torch.nn.modules.module.___torch_mangle_24.Module,
      %input.1 : Float(1, 1, 28, 28)):
  %134 : __torch__.torch.nn.modules.module.___torch_mangle_23.Module = prim::GetAttr[name="fc3"](%self.1)
  %131 : __torch__.torch.nn.modules.module.___torch_mangle_22.Module = prim::GetAttr[name="fc2"](%self.1)
  %128 : __torch__.torch.nn.modules.module.___torch_mangle_21.Module = prim::GetAttr[name="fc1"](%self.1)
  %125 : __torch__.torch.nn.modules.module.___torch_mangle_20.Module = prim::GetAttr[name="max_pool_2"](%self.1)
  %124 : __torch__.torch.nn.modules.module.___torch_mangle_19.Module = prim::GetAttr[name="conv2"](%self.1)
  %121 : __torch__.torch.nn.modules.module.___torch_mangle_18.Module = prim::GetAttr[name="max_pool_1"](%self.1)
  %120 : __torch__.torch.nn.modules.module.___torch_mangle_17.Module = prim::GetAttr[name="conv1"](%self.1)
  %142 : Tensor = prim::CallMethod[name="forward"](%120, %input.1)
  %input.3 : Float(1, 6, 28, 28) = aten::relu(%142) # /ho

# Unoptimized JIT Script of LeNet

In [9]:
# Scripting compiles the module from its Python source, so no example input is
# required. The second positional parameter of `torch.jit.script` is the
# deprecated `optimize` flag, not an example tensor, so nothing is passed there.
scripted_module = torch.jit.script(LeNet5())
scripted_module.graph  # unoptimized



graph(%self : __torch__.LeNet5,
      %x.1 : Tensor):
  %40 : Function = prim::Constant[name="relu"]()
  %34 : Function = prim::Constant[name="relu"]()
  %23 : int = prim::Constant[value=-1]() # <ipython-input-6-5b68f572a232>:30:19
  %15 : Function = prim::Constant[name="relu"]()
  %6 : Function = prim::Constant[name="relu"]()
  %5 : bool = prim::Constant[value=0]()
  %24 : int = prim::Constant[value=16]() # <ipython-input-6-5b68f572a232>:30:23
  %25 : int = prim::Constant[value=5]() # <ipython-input-6-5b68f572a232>:30:26
  %2 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self)
  %4 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # <ipython-input-6-5b68f572a232>:21:37
  %x.3 : Tensor = prim::CallFunction(%6, %4, %5) # <ipython-input-6-5b68f572a232>:21:12
  %8 : __torch__.torch.nn.modules.pooling.MaxPool2d = prim::GetAttr[name="max_pool_1"](%self)
  %x.5 : Tensor = prim::CallMethod[name="forward"](%8, %x.3) # <ipython-input-6-5b68f572a232>:23:12
  %11 : _

# Optimized JIT Script of LeNet

In [10]:
# `torch.jit.optimized_execution` takes a boolean switch, not an optimization
# level, so pass True rather than 3. `torch.jit.script` needs no example input:
# its second positional parameter is the deprecated `optimize` flag.
with torch.jit.optimized_execution(True):
    scripted_module = torch.jit.script(LeNet5())
# NOTE(review): the recorded output matches the script cell without the context
# manager, so `.graph` does not appear to reflect executor optimizations -- confirm.
scripted_module.graph

graph(%self : __torch__.LeNet5,
      %x.1 : Tensor):
  %40 : Function = prim::Constant[name="relu"]()
  %34 : Function = prim::Constant[name="relu"]()
  %23 : int = prim::Constant[value=-1]() # <ipython-input-6-5b68f572a232>:30:19
  %15 : Function = prim::Constant[name="relu"]()
  %6 : Function = prim::Constant[name="relu"]()
  %5 : bool = prim::Constant[value=0]()
  %24 : int = prim::Constant[value=16]() # <ipython-input-6-5b68f572a232>:30:23
  %25 : int = prim::Constant[value=5]() # <ipython-input-6-5b68f572a232>:30:26
  %2 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self)
  %4 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # <ipython-input-6-5b68f572a232>:21:37
  %x.3 : Tensor = prim::CallFunction(%6, %4, %5) # <ipython-input-6-5b68f572a232>:21:12
  %8 : __torch__.torch.nn.modules.pooling.MaxPool2d = prim::GetAttr[name="max_pool_1"](%self)
  %x.5 : Tensor = prim::CallMethod[name="forward"](%8, %x.3) # <ipython-input-6-5b68f572a232>:23:12
  %11 : _

In [11]:
import torch
import torch.nn as nn

class Net(nn.Module):
    """Minimal module wrapping one 3x3 convolution with 1 input and 1 output channel."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result."""
        convolved = self.conv(x)
        return convolved

n = Net()
# NOTE: `example_weight` is not used by anything in this cell; it is kept so
# that any other cell referring to it keeps working.
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct a `ScriptModule` with
# a single `forward` method.
# `torch.jit.optimized_execution` takes a boolean switch, not an optimization
# level, so pass True rather than an integer such as 3.
with torch.jit.optimized_execution(True):
    module = torch.jit.trace(n.forward, example_forward_input)

# Alternative: trace the module itself (implicitly traces `forward`) and
# construct a `ScriptModule` with a single `forward` method:
# module = torch.jit.trace(n, example_forward_input)

In [12]:
# Display the graph recorded for the traced `module` (last expression, so it is the cell output).
module.graph

graph(%self.1 : __torch__.torch.nn.modules.module.___torch_mangle_48.Module,
      %input : Float(1, 1, 3, 3)):
  %28 : __torch__.torch.nn.modules.module.___torch_mangle_47.Module = prim::GetAttr[name="conv"](%self.1)
  %30 : Tensor = prim::CallMethod[name="forward"](%28, %input)
  return (%30)