In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Highway network
class Highway(nn.Module):
    """Stack of size-preserving highway layers.

    Each layer mixes a ReLU transform of its input with the unchanged input,
    weighted by a learned sigmoid gate t:

        t = sigmoid(W_t x + b_t)
        y = t * relu(W_h x + b_h) + (1 - t) * x

    Args:
        size: Feature dimension of the last axis; the output has the same size.
        num_layers: Number of highway layers applied in sequence.
    """

    def __init__(self, size, num_layers=2):
        super().__init__()
        self.num_layers = num_layers
        # All transform (H) layers are built before all gate (T) layers; this order
        # determines parameter initialisation under a fixed seed, so keep it.
        self.linears = nn.ModuleList(nn.Linear(size, size) for _ in range(num_layers))
        self.gates = nn.ModuleList(nn.Linear(size, size) for _ in range(num_layers))

    def forward(self, x):
        """Apply each highway layer in turn; the output has the same shape as `x`."""
        out = x
        for transform, gate in zip(self.linears, self.gates):
            t = gate(out).sigmoid()
            h = transform(out).relu()
            # Gated mix of the transform path (h) and the carry path (out).
            out = t * h + (1 - t) * out
        return out

In [None]:
# Model architecture: char embedding -> char CNN -> highway -> projection -> BiLSTM + residual
class BiLSTMWithResidual(nn.Module):
    """Character-level encoder with a residual connection around a bidirectional LSTM.

    Pipeline (shapes along the forward pass):
        x                        (batch, seq_len)           integer character indices
        -> char_embedding        (batch, seq_len, char_embed_dim)
        -> char_conv + ReLU      (batch, seq_len, char_conv_dim)
        -> highway               (batch, seq_len, char_conv_dim)
        -> proj                  (batch, seq_len, proj_dim)       -- LSTM input, residual source
        -> bidirectional LSTM    (batch, seq_len, 2 * hidden_dim)
        -> proj_lstm             (batch, seq_len, proj_dim)
        -> + LSTM input          (batch, seq_len, proj_dim)

    Args:
        input_dim: Not used by the model; kept only so existing callers keep working.
            The embedding size is controlled by `char_embed_dim`.
        hidden_dim: LSTM hidden size per direction. The default of 4096 makes the model
            very large (about 0.55B parameters in the LSTM alone with num_layers=2 and
            proj_dim=512).
        proj_dim: Width of the pre-LSTM projection and of the final output.
        num_layers: Number of stacked LSTM layers.
        vocab_size: Number of distinct character indices accepted by the embedding.
        char_embed_dim: Character embedding size.
        char_conv_dim: Number of Conv1d output channels; also the highway width.
        highway_layers: Number of highway layers.
    """

    def __init__(self, input_dim, hidden_dim=4096, proj_dim=512, num_layers=2, vocab_size=100,
                 char_embed_dim=64, char_conv_dim=2048, highway_layers=2):
        super(BiLSTMWithResidual, self).__init__()

        # Character embedding table.
        self.char_embedding = nn.Embedding(vocab_size, char_embed_dim)

        # Convolution over neighbouring characters (n-gram features). kernel_size=3 with
        # padding=1 keeps the sequence length unchanged.
        self.char_conv = nn.Conv1d(in_channels=char_embed_dim, out_channels=char_conv_dim,
                                   kernel_size=3, padding=1)

        # Highway layers on the convolution features (2 by default).
        self.highway = Highway(size=char_conv_dim, num_layers=highway_layers)

        # Linear projection down to proj_dim (512 by default).
        self.proj = nn.Linear(char_conv_dim, proj_dim)

        # Bidirectional LSTM; its output concatenates both directions (2 * hidden_dim).
        self.lstm = nn.LSTM(input_size=proj_dim, hidden_size=hidden_dim, num_layers=num_layers,
                            bidirectional=True, batch_first=True)

        # Projects the 2 * hidden_dim LSTM output back to proj_dim, so it can be added
        # elementwise to the LSTM input in the residual connection.
        self.proj_lstm = nn.Linear(2 * hidden_dim, proj_dim)

    def forward(self, x):
        """Encode a batch of character-index sequences.

        Args:
            x: Integer tensor of character indices, shape (batch_size, seq_len), with
                values in [0, vocab_size).

        Returns:
            Tensor of shape (batch_size, seq_len, proj_dim).
        """
        # (batch, seq_len) -> (batch, seq_len, char_embed_dim)
        char_embeds = self.char_embedding(x)

        # Conv1d expects channels on axis 1: (batch, char_embed_dim, seq_len).
        char_embeds = char_embeds.permute(0, 2, 1)

        conv_output = F.relu(self.char_conv(char_embeds))  # (batch, char_conv_dim, seq_len)

        # Back to (batch, seq_len, char_conv_dim) so the Linear layers act on the last axis.
        conv_output = conv_output.permute(0, 2, 1)

        highway_output = self.highway(conv_output)  # (batch, seq_len, char_conv_dim)

        proj_output = self.proj(highway_output)  # (batch, seq_len, proj_dim)

        lstm_output, _ = self.lstm(proj_output)  # (batch, seq_len, 2 * hidden_dim)

        # Residual connection from the LSTM input to its output. lstm_output is
        # 2 * hidden_dim wide while proj_output is proj_dim wide, so the LSTM output
        # must be projected to proj_dim before the two can be added.
        final_output = proj_output + self.proj_lstm(lstm_output)  # (batch, seq_len, proj_dim)

        return final_output

In [None]:
# Model smoke test: one forward pass on random character indices, then check the output shape.
torch.manual_seed(0)  # reproducible weights and inputs

batch_size = 32
seq_len = 50
vocab_size = 100  # assumed character vocabulary size
input_dim = 64  # character embedding dimension used for this test
proj_dim = 512  # model output dimension

# Random character indices in [0, vocab_size).
x = torch.randint(0, vocab_size, (batch_size, seq_len))

# The model ignores `input_dim`; the embedding size is set via `char_embed_dim`, so pass
# it explicitly. NOTE: the default hidden_dim=4096 yields a very large model (over 0.5B
# parameters); pass a smaller hidden_dim for a quicker check.
model = BiLSTMWithResidual(input_dim=input_dim, proj_dim=proj_dim,
                           vocab_size=vocab_size, char_embed_dim=input_dim)
model.eval()

# Inference only: no_grad avoids keeping activations for backprop, which is large here.
with torch.no_grad():
    output = model(x)

assert output.shape == (batch_size, seq_len, proj_dim), output.shape
print(output.shape)  # (batch_size, seq_len, proj_dim)