In [6]:
import torch
from torch import nn
import time
import numpy as np

# Sequential的混合式编程

In [7]:
def get_net(in_features=512, hidden_sizes=(256, 128), out_features=2):
    """Build a fully connected MLP as an ``nn.Sequential``.

    Every hidden layer is an ``nn.Linear`` followed by ``nn.ReLU``; the final
    layer is a plain ``nn.Linear`` with no activation. The defaults reproduce
    the original 512 -> 256 -> 128 -> 2 network, with modules created in the
    same order, so a given RNG seed yields the same initial weights.

    Args:
        in_features: number of features in each input sample.
        hidden_sizes: widths of the hidden layers, in order (may be empty).
        out_features: number of features in each output sample.

    Returns:
        A freshly initialised ``nn.Sequential``.
    """
    layers = []
    prev_width = in_features
    for width in hidden_sizes:
        layers += [nn.Linear(prev_width, width), nn.ReLU()]
        prev_width = width
    layers.append(nn.Linear(prev_width, out_features))
    return nn.Sequential(*layers)

# Fix the RNG so the random input and the random weight initialisation
# (and therefore the outputs shown below) are reproducible on re-run.
torch.manual_seed(0)
x = torch.randn(size=(1, 512))  # one sample with 512 features, matching the first Linear layer
net = get_net()
net(x)

tensor([[-0.0099, -0.0587]], grad_fn=<AddmmBackward0>)

使用`torch.jit.script`来转换模型：它把`nn.Sequential`编译成TorchScript的`ScriptModule`（作用类似于MXNet中的`HybridSequential`）。

In [8]:
# Compile the eager nn.Sequential into a TorchScript module. The eager model
# keeps its own name so both stay available for comparison and re-running this
# cell always scripts the same eager `net`.
scripted_net = torch.jit.script(net)
# Scripting must not change the math: outputs should agree with eager mode.
torch.testing.assert_close(scripted_net(x), net(x))
scripted_net(x)

tensor([[-0.0099, -0.0587]], grad_fn=<DifferentiableGraphBackward>)

In [9]:
class Timer:  #@save
    """Record the durations of multiple runs."""
    def __init__(self):
        self.times = []
        # The timer starts counting immediately on construction.
        self.start()

    def start(self):
        """Start (or restart) the timer."""
        # perf_counter is monotonic and high-resolution, which is what interval
        # timing needs; time.time() is wall-clock time and can jump (e.g. clock
        # adjustments) or be coarse on some platforms.
        self.tik = time.perf_counter()

    def stop(self):
        """Record and return the seconds elapsed since the most recent start().

        The start point is not reset, so calling stop() again without start()
        measures from the same start point.
        """
        self.times.append(time.perf_counter() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the mean recorded time (ZeroDivisionError if none recorded)."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the total of all recorded times."""
        return sum(self.times)

    def cumsum(self):
        """Return the running totals of the recorded times as a list."""
        return np.array(self.times).cumsum().tolist()

In [10]:
class Benchmark:
    """Context manager that prints how long its ``with`` body took.

    On exit it prints ``'<description>: <seconds> sec'`` with four decimals.
    """
    def __init__(self, description='Done'):
        self.description = description

    def __enter__(self):
        # Constructing a Timer starts it, so timing begins here.
        timer = Timer()
        self.timer = timer
        return self

    def __exit__(self, *args):
        # Called even if the body raised; returning None lets any exception propagate.
        elapsed = self.timer.stop()
        print(f'{self.description}: {elapsed:.4f} sec')

In [11]:
N_RUNS = 1000   # timed forward passes per variant
N_WARMUP = 10   # untimed passes so one-off first-call costs (e.g. TorchScript's
                # profiling/optimisation runs) are not counted in the comparison

net = get_net()
for _ in range(N_WARMUP):
    net(x)
with Benchmark('无torchscript'):
    for _ in range(N_RUNS):
        net(x)

net = torch.jit.script(net)
for _ in range(N_WARMUP):
    net(x)
with Benchmark('有torchscript'):
    for _ in range(N_RUNS):
        net(x)

无torchscript: 32.3858 sec
有torchscript: 27.7141 sec


In [13]:
# Serialize the scripted model (code and parameters) to disk.
net.save('my_mlp1')
# Round-trip check: the reloaded module must reproduce the in-memory outputs.
net_loaded = torch.jit.load('my_mlp1')
torch.testing.assert_close(net_loaded(x), net(x))