In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Causal multi-head self-attention
class GPT2SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with a lower-triangular mask.

    The embedding dimension is split into ``heads`` chunks of ``head_dim``
    features. A single bias-free ``head_dim -> head_dim`` projection is used
    for each of queries, keys and values (the same weights are applied to
    every head), and the concatenated heads are projected by ``fc_out``.

    Args:
        embedding_size: Size of the last dimension of the inputs and output.
        heads: Number of attention heads; must divide ``embedding_size``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        AssertionError: If ``embedding_size`` is not divisible by ``heads``.
    """

    def __init__(self, embedding_size, heads, dropout=0.1):
        super(GPT2SelfAttention, self).__init__()
        self.embedding_size = embedding_size
        self.heads = heads
        self.head_dim = embedding_size // heads

        assert (self.head_dim * heads == embedding_size), "Embedding size needs to be divisible by heads"

        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(self.heads * self.head_dim, embedding_size)

    def forward(self, values, keys, queries):
        """Apply masked attention.

        Args:
            values: Tensor viewable as ``(N, values_len, heads, head_dim)``.
            keys: Tensor viewable as ``(N, keys_len, heads, head_dim)``.
            queries: Tensor viewable as ``(N, queries_len, heads, head_dim)``.
                ``values_len`` must equal ``keys_len`` for the weighted sum.

        Returns:
            Tensor of shape ``(N, queries_len, embedding_size)``.
        """
        N = queries.shape[0]
        values_len, keys_len, queries_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding dimension into heads.
        values = values.view(N, values_len, self.heads, self.head_dim)
        keys = keys.view(N, keys_len, self.heads, self.head_dim)
        queries = queries.view(N, queries_len, self.heads, self.head_dim)

        Q = self.queries(queries)
        K = self.keys(keys)
        V = self.values(values)

        # Scaled dot product of Q and K -> (N, heads, queries_len, keys_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", Q, K) / (self.head_dim ** 0.5)

        # Lower-triangular mask: query i may only look at keys 0..i. It is built
        # as a boolean tensor directly on the target device and broadcast over
        # batch and heads instead of being materialised per (N, heads) on the CPU.
        causal_mask = torch.tril(
            torch.ones(queries_len, keys_len, dtype=torch.bool, device=energy.device)
        )
        # Use the most negative finite value of the active dtype so the fill is
        # representable in reduced precision as well as in float32.
        energy = energy.masked_fill(~causal_mask, torch.finfo(energy.dtype).min)

        # Attention weights.
        attn_weights = F.softmax(energy, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values, then merge the heads back together.
        context = torch.einsum("nhqk,nkhd->nqhd", attn_weights, V).reshape(
            N, queries_len, self.heads * self.head_dim
        )
        return self.fc_out(context)

In [4]:
# Position-wise feed-forward network
class FeedForward(nn.Module):
    """Two linear layers with ReLU and dropout in between.

    Computes ``fc2(dropout(relu(fc1(x))))``; the last dimension goes
    ``embedding_size -> hidden_dim -> embedding_size``.
    """

    def __init__(self, embedding_size, hidden_dim, dropout=0.1):
        super().__init__()
        # Expansion and contraction projections.
        self.fc1 = nn.Linear(in_features=embedding_size, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=embedding_size)
        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the transformed tensor; the leading dimensions are untouched."""
        hidden = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(hidden)

In [5]:
# Basic building block of the GPT-2 style model
class GPT2Block(nn.Module):
    """Self-attention and feed-forward sub-layers, each with a residual + LayerNorm.

    The normalisation is applied after each residual sum, and dropout is
    applied to the block's final output.
    """

    def __init__(self, embedding_size, heads, hidden_dim, dropout=0.1):
        super().__init__()
        self.attention = GPT2SelfAttention(embedding_size, heads, dropout)
        self.layer_norm1 = nn.LayerNorm(embedding_size)
        self.layer_norm2 = nn.LayerNorm(embedding_size)
        self.feed_forward = FeedForward(embedding_size, hidden_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run both sub-layers on ``x`` and return a tensor of the same shape."""
        # Self-attention sub-layer: the input serves as values, keys and queries.
        attended = self.layer_norm1(x + self.attention(x, x, x))
        # Feed-forward sub-layer.
        transformed = self.layer_norm2(attended + self.feed_forward(attended))
        return self.dropout(transformed)

In [6]:
# Core generation module of the GPT-2 style model
class GPT2TextGenerator(nn.Module):
    """Token + learned position embeddings, a stack of ``GPT2Block`` layers,
    and a linear projection to vocabulary logits.

    Args:
        vocab_size: Number of rows in the token embedding / size of the logits.
        embedding_size: Width of the embeddings and of every block.
        num_layers: Number of ``GPT2Block`` layers.
        heads: Attention heads per block.
        hidden_dim: Hidden width of each block's feed-forward network.
        max_len: Number of rows in the position embedding table, i.e. the
            longest sequence ``forward`` can embed.
        dropout: Dropout probability forwarded to every block.
    """

    def __init__(self, vocab_size, embedding_size, num_layers, heads, hidden_dim, max_len, dropout=0.1):
        super(GPT2TextGenerator, self).__init__()
        self.embedding_size = embedding_size
        self.token_embedding = nn.Embedding(vocab_size, embedding_size)
        self.position_embedding = nn.Embedding(max_len, embedding_size)  # max_len positions available
        self.layers = nn.ModuleList(
            [GPT2Block(embedding_size, heads, hidden_dim, dropout) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embedding_size, vocab_size)

    def forward(self, x):
        """Return logits of shape ``(batch, seq, vocab_size)`` for token ids ``x``.

        Args:
            x: Integer tensor of shape ``(batch, seq)``; ``seq`` must not exceed
                the size of the position embedding table.
        """
        batch_size, seq_length = x.size()
        # Build the position ids directly on the input's device.
        positions = torch.arange(seq_length, device=x.device).expand(batch_size, seq_length)

        x = self.token_embedding(x) + self.position_embedding(positions)
        for layer in self.layers:
            x = layer(x)
        return self.fc_out(x)

    def generate(self, start_token, max_len, temperature=1.0, eos_token_id=0):
        """Greedily extend a prompt until ``max_len`` tokens or the end token.

        Generation runs in eval mode (dropout disabled) without gradient
        tracking; the module's previous train/eval mode is restored afterwards.
        The caller's ``start_token`` sequence is not modified.

        Args:
            start_token: Non-empty 1-D sequence of prompt token ids.
            max_len: Maximum length of the returned sequence, prompt included.
                When the running sequence is longer than the position table,
                only the most recent tokens that fit are fed to the model.
            temperature: Positive divisor applied to the logits. Decoding is
                greedy (argmax), so this does not change which token is picked.
            eos_token_id: Token id that stops generation once produced.

        Returns:
            A new list of ints: the prompt followed by the generated tokens
            (including the end token if it was produced).

        Raises:
            ValueError: If ``temperature`` is not positive or the prompt is empty.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        # Copy so that the caller's prompt is left untouched.
        generated = [int(token) for token in start_token]
        if not generated:
            raise ValueError("start_token must contain at least one token id")

        device = next(self.parameters()).device
        context_size = self.position_embedding.num_embeddings
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_len - len(generated)):
                    # Keep only as many trailing tokens as there are position embeddings.
                    window = generated[-context_size:]
                    x = torch.tensor(window, dtype=torch.long, device=device).unsqueeze(0)
                    logits = self(x)[:, -1, :] / temperature
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.argmax(probs, dim=-1).item()
                    generated.append(next_token)
                    if next_token == eos_token_id:
                        break
        finally:
            self.train(was_training)
        return generated

In [7]:
# Hyperparameters
vocab_size = 10000
embedding_size = 128  # must be divisible by `heads`
num_layers = 4
heads = 8
hidden_dim = 512
max_len = 50
dropout = 0.1

# Seed PyTorch so weight initialisation (and any dropout) is reproducible
# across Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

In [8]:
# Initialise the model
gpt2_model = GPT2TextGenerator(vocab_size, embedding_size, num_layers, heads, hidden_dim, max_len, dropout)
# Simulated prompt input
start_token = [1, 5, 23, 67]  # assume these are the token ids of the prompt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpt2_model.to(device)
gpt2_model.eval()  # inference only: switch dropout off
# Not consumed by generate() below, which takes a plain list of ids.
start_token_tensor = torch.tensor(start_token).to(device)

# Generate text. A copy of the prompt is passed so `start_token` keeps its
# original contents, and gradients are not tracked during inference.
with torch.no_grad():
    generated_sequence = gpt2_model.generate(list(start_token), max_len=max_len, temperature=1.0)
print("生成的文本序列（token ID）：", generated_sequence)

生成的文本序列（token ID）： [1, 5, 23, 67, 5862, 7848, 7616, 8325, 4800, 2124, 5675, 3472, 2857, 8936, 3392, 8376, 8529, 4373, 4797, 1435, 2181, 2999, 4079, 5864, 9716, 9418, 7600, 3505, 2545, 3668, 661, 5486, 4424, 3750, 7071, 7905, 8075, 861, 9934, 7897, 9760, 8688, 4282, 4917, 7785, 2173, 6474, 6709, 4608, 5641]


In [9]:
# Mock decoding of the model output
def decode_tokens(token_ids):
    """Render token ids as space-separated ``<token_ID>`` placeholders."""
    pieces = []
    for token_id in token_ids:
        pieces.append(f"<token_{token_id}>")
    return " ".join(pieces)

# Turn the generated token ids into placeholder text and print it.
decoded_text = decode_tokens(generated_sequence)
print("生成的文本序列（解码后）：", decoded_text)

生成的文本序列（解码后）： <token_1> <token_5> <token_23> <token_67> <token_5862> <token_7848> <token_7616> <token_8325> <token_4800> <token_2124> <token_5675> <token_3472> <token_2857> <token_8936> <token_3392> <token_8376> <token_8529> <token_4373> <token_4797> <token_1435> <token_2181> <token_2999> <token_4079> <token_5864> <token_9716> <token_9418> <token_7600> <token_3505> <token_2545> <token_3668> <token_661> <token_5486> <token_4424> <token_3750> <token_7071> <token_7905> <token_8075> <token_861> <token_9934> <token_7897> <token_9760> <token_8688> <token_4282> <token_4917> <token_7785> <token_2173> <token_6474> <token_6709> <token_4608> <token_5641>
