In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Highway(nn.Module):
    """Stack of highway layers over fixed-width feature vectors.

    Every layer owns two square ``nn.Linear(size, size)`` projections: one
    producing a candidate activation and one producing a transform gate.
    The layer output mixes the candidate with the unchanged input:

        t = sigmoid(gate(x))
        h = relu(linear(x))
        y = t * h + (1 - t) * x

    Because both projections map ``size`` features to ``size`` features,
    the output has the same shape as the input.

    Args:
        size: Number of features in the last dimension of the input.
        num_layers: How many highway layers to stack (default 2).
    """

    def __init__(self, size, num_layers=2):
        super().__init__()
        self.num_layers = num_layers

        def build_stack():
            # One square projection per highway layer.
            return nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])

        # Candidate projections are built before gate projections so that
        # parameter creation (and therefore seeded initialisation) follows
        # the order: all linears first, then all gates.
        self.linears = build_stack()
        self.gates = build_stack()

    def forward(self, x):
        """Run ``x`` through every highway layer in sequence.

        Args:
            x: Tensor whose last dimension equals ``size``.

        Returns:
            Tensor with the same shape as ``x``. With zero layers the
            input is returned unchanged.
        """
        hidden = x
        for transform, gate_proj in zip(self.linears, self.gates):
            # Transform gate: fraction of the candidate to let through.
            t = gate_proj(hidden).sigmoid()
            # Candidate activation from the non-linear branch.
            candidate = transform(hidden).relu()
            # Blend: the transform gate weights the candidate, while the
            # carry gate (1 - t) weights the untouched input.
            hidden = t * candidate + (1 - t) * hidden
        return hidden

In [None]:
# Example: push a random batch through a two-layer highway network.
BATCH_SIZE = 64    # number of samples in the demo batch
FEATURE_DIM = 512  # width of each input vector

highway_net = Highway(size=FEATURE_DIM, num_layers=2)
input_tensor = torch.randn(BATCH_SIZE, FEATURE_DIM)
output_tensor = highway_net(input_tensor)
print(output_tensor.shape)  # prints torch.Size([64, 512]), the same shape as the input