In [1]:
from torch import tensor, nn
import numpy as np
import torch

In [2]:
# Three 2x5 float32 tensors holding 0..9, 10..19 and 20..29 respectively.
a = torch.arange(0, 10, dtype=torch.float32).reshape(2, 5)
b = torch.arange(10, 20, dtype=torch.float32).reshape(2, 5)
c = torch.arange(20, 30, dtype=torch.float32).reshape(2, 5)

In [3]:
# Join a, b and c along a new leading dimension.
stack_ = torch.stack((a, b, c), dim=0)

In [4]:
# Display the stacked tensor: three (2, 5) tensors stacked on dim 0 give shape (3, 2, 5).
stack_

tensor([[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.]],

        [[10., 11., 12., 13., 14.],
         [15., 16., 17., 18., 19.]],

        [[20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.]]])

In [5]:
# Cut the last dimension (length 5) into two pieces of width 2 and 3.
torch.split(stack_, [2, 3], dim=-1)

(tensor([[[ 0.,  1.],
          [ 5.,  6.]],
 
         [[10., 11.],
          [15., 16.]],
 
         [[20., 21.],
          [25., 26.]]]),
 tensor([[[ 2.,  3.,  4.],
          [ 7.,  8.,  9.]],
 
         [[12., 13., 14.],
          [17., 18., 19.]],
 
         [[22., 23., 24.],
          [27., 28., 29.]]]))

In [6]:
# Divide dim 1 (length 2) into 2 chunks, i.e. one row per chunk.
torch.chunk(stack_, chunks=2, dim=1)

(tensor([[[ 0.,  1.,  2.,  3.,  4.]],
 
         [[10., 11., 12., 13., 14.]],
 
         [[20., 21., 22., 23., 24.]]]),
 tensor([[[ 5.,  6.,  7.,  8.,  9.]],
 
         [[15., 16., 17., 18., 19.]],
 
         [[25., 26., 27., 28., 29.]]]))

## 自动微分

实例:
$$
f(\mathbf{x})=2\mathbf{x}+1,\quad g(\mathbf{y})=\mathbf{y}^2+5,\quad z=\operatorname{mean}\big(g(f(\mathbf{x}))\big)
$$
求$\frac{dz}{dx}$

In [7]:
def f(x: torch.Tensor) -> torch.Tensor:
    """Inner function of the autograd example: f(x) = 2x + 1, element-wise.

    Args:
        x: input tensor.

    Returns:
        Tensor with 2*x + 1 applied to every element.
    """
    return 2*x + 1

def g(x: torch.Tensor) -> torch.Tensor:
    """Outer function of the autograd example: g(x) = x**2 + 5, element-wise.

    Args:
        x: input tensor.

    Returns:
        Tensor with x**2 + 5 applied to every element.
    """
    return x**2 + 5

def mean(x: torch.Tensor) -> torch.Tensor:
    """Reduce a tensor to the scalar (0-dim) mean of all its elements.

    Args:
        x: floating-point input tensor.

    Returns:
        0-dim tensor holding the mean over every element of x.
    """
    return torch.mean(x)

In [8]:
def dz_dx(x: torch.Tensor) -> torch.Tensor:
    """Closed-form gradient of z = mean((2x + 1)**2 + 5) with respect to x.

    By the chain rule, dz/dx_i = 2 * (2*x_i + 1) * 2 / n = (8*x_i + 4) / n,
    where n = x.numel(). This is the analytic counterpart of the value that
    autograd stores in x.grad.

    Args:
        x: input tensor.

    Returns:
        Tensor of the same shape as x holding dz/dx for every element. If x
        requires grad, the result is itself tracked by autograd.
    """
    return (8*x + 4) / x.numel()

In [9]:
# Random 2x5 float32 tensor of whole numbers in [1, 10), tracked by autograd.
x = torch.randint(
    low=1,
    high=10,
    size=(2, 5),
    dtype=torch.float32,
    requires_grad=True,
)

In [13]:
x.requires_grad
# x.requires_grad_(True)  # if this shows False, tracking can be switched on afterwards

False

In [12]:
# In-place switch: stops autograd from tracking x. x needs requires_grad=True
# again before backward() is called on anything computed from it.
x.requires_grad_(False)

tensor([[6., 9., 1., 9., 1.],
        [6., 5., 9., 5., 7.]])

torch自动求导结果

In [178]:
# An earlier cell switches tracking off with x.requires_grad_(False); turn it
# back on here, otherwise z has no grad_fn and backward() raises.
x.requires_grad_(True)
x.grad = None  # drop any gradient left by a previous run so re-running does not accumulate
z = mean(g(f(x)))
z.backward()  # back-propagate: autograd fills x.grad
x.grad  # dz/dx

tensor([[2.0000, 6.8000, 2.0000, 3.6000, 1.2000],
        [4.4000, 6.0000, 5.2000, 7.6000, 1.2000]])

解析求导结果

In [179]:
# Analytic gradient; should equal the autograd result stored in x.grad.
dz_dx(x)  # dz/dx

tensor([[2.0000, 6.8000, 2.0000, 3.6000, 1.2000],
        [4.4000, 6.0000, 5.2000, 7.6000, 1.2000]], grad_fn=<DivBackward0>)

In [175]:
# The autograd graph node that produced z (the final mean operation).
z.grad_fn

<MeanBackward1 at 0x20f7cf2af98>

### Bass模型拟合

In [188]:
def bass(p: torch.Tensor, q: torch.Tensor, m: torch.Tensor, T: int) -> torch.Tensor:  # redefine this function to fit a different model
    """Per-period adoptions predicted by the Bass diffusion curve.

    The cumulative curve evaluated at t = 1..T is
        F(t) = m * (1 - exp(-(p + q) t)) / (1 + (q / p) * exp(-(p + q) t)),
    and the returned series is its first difference, with the first entry
    equal to F(1).

    Args:
        p: first rate parameter of the curve (must be non-zero, it divides q).
        q: second rate parameter of the curve.
        m: scale factor multiplying the curve.
        T: number of periods to evaluate; T == 0 yields an empty tensor.

    Returns:
        1-D float32 tensor of length T. It stays attached to the autograd
        graph of p, q and m when they require grad.
    """
    t_tensor = torch.arange(1, T + 1, dtype=torch.float32)
    decay = torch.exp(-(p + q) * t_tensor)  # shared by numerator and denominator
    diffu_cont = m * (1 - decay) / (1 + q / p * decay)
    # First difference of the cumulative curve, keeping F(1) as the first entry.
    # Vectorised so autograd records one op instead of T in-place writes.
    return torch.cat((diffu_cont[:1], diffu_cont[1:] - diffu_cont[:-1]))

In [208]:
def meanSquaredLoss(p, q, m, y):
    """Mean squared error between the Bass curve for (p, q, m) and the series y.

    The curve is evaluated over as many periods as y has elements.
    """
    residual = bass(p, q, m, y.numel()) - y
    return residual.pow(2).mean()

In [213]:
def r_2(p, q, m, y):
    """Coefficient of determination (R^2) of the Bass fit for (p, q, m) against y."""
    residual = y - bass(p, q, m, y.numel())
    sse = residual.pow(2).sum()        # sum of squared residuals of the fit
    sst = (y - y.mean()).pow(2).sum()  # total sum of squares around the mean of y
    return (sst - sse) / sst

In [184]:
# Series the Bass curve is fitted to: 13 values, one per period.
y = torch.tensor(
    [96, 195, 238, 380, 1045, 1230, 1267, 1828, 1586, 1673, 1800, 1580, 1500],
    dtype=torch.float32,
)

In [191]:
# Starting guesses for the Bass parameters (same values the fitting loop below
# starts from). They are defined here so this cell runs on a fresh kernel
# instead of relying on variables left over from a later cell.
p = tensor(0.0001, requires_grad=True)
q = tensor(0.3, requires_grad=True)
m = tensor(20000.0, requires_grad=True)
r = meanSquaredLoss(p, q, m, y)
r.backward()  # fills p.grad, q.grad and m.grad

In [235]:
# Quick check: the sign of a negative scalar tensor is -1.
tensor(-1).sign()

tensor(-1)

In [238]:
# Fit the Bass parameters with a sign-based update: on every round each
# parameter moves by a fixed step against the sign of its gradient
# (0.0001 for p and q, 10 for m), so the gradient magnitude is ignored.
p = tensor(0.0001, requires_grad=True)
q = tensor(0.3, requires_grad=True)
m = tensor(20000.0, requires_grad=True)
for i in range(200):
    r = meanSquaredLoss(p, q, m, y)
    r.backward()
    # Update the leaves in place under no_grad so the steps are not recorded
    # by autograd; this replaces direct mutation through `.data`.
    with torch.no_grad():
        p.sub_(0.0001 * torch.sign(p.grad))
        q.sub_(0.0001 * torch.sign(q.grad))
        m.sub_(10 * torch.sign(m.grad))
        r2 = r_2(p, q, m, y)  # fit quality after this round's update

    print(f"第{i+1}轮, r2={r2.item():.4f}\n    p:{p.item():.4f}, q:{q.item():.4f}, m:{m.item():.1f}")
    p.grad.zero_()  # clear the gradients, otherwise backward() accumulates into them
    q.grad.zero_()
    m.grad.zero_()

第1轮, r2=-2.7695
    p:0.0002, q:0.3001, m:20010.0
第2轮, r2=-2.6051
    p:0.0003, q:0.3002, m:20020.0
第3轮, r2=-2.4500
    p:0.0004, q:0.3003, m:20030.0
第4轮, r2=-2.3036
    p:0.0005, q:0.3004, m:20040.0
第5轮, r2=-2.1652
    p:0.0006, q:0.3005, m:20050.0
第6轮, r2=-2.0343
    p:0.0007, q:0.3006, m:20060.0
第7轮, r2=-1.9103
    p:0.0008, q:0.3007, m:20070.0
第8轮, r2=-1.7928
    p:0.0009, q:0.3008, m:20080.0
第9轮, r2=-1.6814
    p:0.0010, q:0.3009, m:20090.0
第10轮, r2=-1.5757
    p:0.0011, q:0.3010, m:20100.0
第11轮, r2=-1.4752
    p:0.0012, q:0.3011, m:20110.0
第12轮, r2=-1.3796
    p:0.0013, q:0.3012, m:20120.0
第13轮, r2=-1.2887
    p:0.0014, q:0.3013, m:20130.0
第14轮, r2=-1.2021
    p:0.0015, q:0.3014, m:20140.0
第15轮, r2=-1.1195
    p:0.0016, q:0.3015, m:20150.0
第16轮, r2=-1.0408
    p:0.0017, q:0.3016, m:20160.0
第17轮, r2=-0.9657
    p:0.0018, q:0.3017, m:20170.0
第18轮, r2=-0.8939
    p:0.0019, q:0.3018, m:20180.0
第19轮, r2=-0.8253
    p:0.0020, q:0.3019, m:20190.0
第20轮, r2=-0.7597
    p:0.0021, q:0.3020,