In [5]:
import torch
from torch import nn
import math
from utils.activate import Swish

In [2]:
# Swish activation function: f(x) = x * sigmoid(beta * x).
# NOTE(review): the import cell also brings in `Swish` from `utils.activate`;
# this cell redefines the name, so whichever cell ran last decides which class is used.
class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(beta * x)``.

    Args:
        beta: factor applied to the input inside the sigmoid (default 1.0).
    """

    def __init__(self, beta=1.0):
        super(Swish, self).__init__()
        self.beta = beta

    def forward(self, x):
        """Apply the gate elementwise and return a tensor shaped like ``x``."""
        gate = torch.sigmoid(self.beta * x)
        return x * gate
    
    

In [19]:
class TimeEmbedding(nn.Module):
    """
    Map time steps ``t`` to vectors of size ``n_channels``.

    ``t`` is first encoded with the sin/cos scheme used for Transformer
    positional encodings, giving ``n_channels // 4`` features. That encoding is
    then passed through Linear -> Swish -> Linear. The output has shape
    ``(batch_size, n_channels)``; ``n_channels`` is the time-embedding
    ("time_channels") dimension.
    """

    def __init__(self, n_channels: int):
        """
        :param n_channels: size of the output embedding (the "time_channels").
        """
        super().__init__()
        self.n_channels = n_channels
        # lin1 consumes the sinusoidal encoding, which has n_channels // 4 features.
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        self.act = Swish()
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        """
        :param t: 1-D tensor of shape (batch_size,) holding the time steps.
        :return: tensor of shape (batch_size, n_channels).
        """
        # Width of the sinusoidal encoding; it must equal lin1's in_features.
        enc_dim = self.n_channels // 4
        half_dim = enc_dim // 2

        # Same frequency ladder as Transformer positional encodings:
        # exp(-i * log(10000) / (half_dim - 1)) for i in [0, half_dim).
        # The divisor is clamped to 1 so half_dim == 1 does not divide by zero.
        scale = math.log(10_000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        emb = t[:, None] * freqs[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

        # sin/cos always produce an even number of features. When enc_dim is
        # odd, zero-pad the missing trailing column so lin1 receives enc_dim inputs.
        missing = enc_dim - emb.shape[1]
        if missing > 0:
            emb = nn.functional.pad(emb, (0, missing))

        # Transform with the MLP.
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        # Output shape: (batch_size, n_channels).
        return emb

t.shape torch.Size([5])
t[:,None].shape torch.Size([5, 1])
emb.shape torch.Size([4])
emb[None,:] torch.Size([1, 4])
emb.shape torch.Size([5, 4])
emb.shape torch.Size([5, 8])
emb.shape torch.Size([5, 32])


tensor([[-0.0715, -0.1873, -0.2228,  0.0173,  0.0673,  0.0914, -0.2218,  0.2959,
          0.1476, -0.0188, -0.1146, -0.3305, -0.1300,  0.2299, -0.1995,  0.0468,
          0.3782, -0.0090,  0.0759, -0.3204,  0.2322,  0.1280,  0.0982,  0.1014,
          0.0041,  0.2335,  0.1778, -0.0114,  0.1533,  0.1659,  0.0922,  0.0412],
        [ 0.0282, -0.3186, -0.0710, -0.0284,  0.0559,  0.1017, -0.2001,  0.2207,
          0.1051,  0.0107, -0.1226, -0.2305, -0.1056,  0.2142, -0.2897,  0.0847,
          0.2580, -0.0158, -0.0205, -0.3133,  0.2667,  0.1834,  0.0462,  0.0989,
          0.0120,  0.1591,  0.0761, -0.0761,  0.1562,  0.0580,  0.0889,  0.1325],
        [ 0.1652, -0.3248,  0.0217,  0.0526,  0.0475,  0.0485, -0.1060,  0.1598,
          0.0114, -0.0076, -0.0923, -0.1456, -0.0532,  0.1368, -0.3094,  0.1204,
          0.2065,  0.0156, -0.0945, -0.3069,  0.3100,  0.2385, -0.0463,  0.1196,
          0.0191,  0.1236,  0.0485, -0.1265,  0.0622,  0.0290,  0.1054,  0.1305],
        [ 0.2021, -0.1955