In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

: 

In [None]:
# Self-attention mechanism
class GPT2SelfAttention(nn.Module):
    """Multi-head causal self-attention.

    The embedding is split into ``num_heads`` chunks of ``head_dim`` features.
    A single bias-free ``head_dim -> head_dim`` projection for each of values,
    keys and queries is applied to every head, and a lower-triangular mask
    stops a position from attending to later positions.

    Args:
        embedding_size: Size of the last dimension of the inputs and output.
        num_heads: Number of attention heads; must divide ``embedding_size``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        AssertionError: If ``embedding_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, embedding_size, num_heads, dropout=0.1):
        super().__init__()
        self.embedding_size = embedding_size
        self.num_heads = num_heads
        self.head_dim = embedding_size // num_heads

        assert (
            self.head_dim * num_heads == embedding_size
        ), "Embedding size needs to be divisible by num_heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(num_heads * self.head_dim, embedding_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, values, keys, queries):
        """Apply masked multi-head attention.

        Args:
            values: Tensor reshaped internally to ``(N, value_len, num_heads, head_dim)``.
            keys: Tensor reshaped internally to ``(N, key_len, num_heads, head_dim)``.
            queries: Tensor reshaped internally to ``(N, query_len, num_heads, head_dim)``.

        Returns:
            Tensor of shape ``(N, query_len, embedding_size)``.
        """
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding into heads. reshape (unlike view) also accepts
        # inputs whose memory layout is not contiguous.
        values = values.reshape(N, value_len, self.num_heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.num_heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.num_heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Scaled dot-product attention scores, shape (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) / (self.head_dim ** 0.5)

        # Build the causal mask on the same device as the scores (a CPU mask
        # cannot be combined with CUDA scores); it broadcasts over batch and heads.
        causal_mask = torch.tril(
            torch.ones(query_len, key_len, device=energy.device)
        ).bool()
        # Use the most negative value representable in the score dtype so the
        # fill also works for reduced-precision dtypes such as float16.
        energy = energy.masked_fill(~causal_mask, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)
        attention = self.dropout(attention)

        # einsum may hand back a non-contiguous tensor, for which view() raises;
        # reshape copies only when it has to.
        out = torch.einsum("nhqk,nkhd->nqhd", attention, values).reshape(
            N, query_len, self.num_heads * self.head_dim
        )

        return self.fc_out(out)

In [None]:
# Position-wise feed-forward network
class FeedForward(nn.Module):
    """Two linear layers with a ReLU and dropout in between.

    Args:
        embedding_size: Input and output feature size.
        hidden_dim: Width of the intermediate layer.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, embedding_size, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embedding_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embedding_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``fc2(dropout(relu(fc1(x))))``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(self.dropout(hidden))

In [None]:
# GPT-2 basic block definition
class GPT2Block(nn.Module):
    """One transformer block: self-attention followed by a feed-forward network.

    The output of each sub-layer goes through dropout, is added to that
    sub-layer's input, and the sum is layer-normalised.

    Args:
        embedding_size: Feature size of the block's input and output.
        num_heads: Number of attention heads.
        hidden_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability shared by both sub-layers.
    """

    def __init__(self, embedding_size, num_heads, hidden_dim, dropout=0.1):
        super().__init__()
        self.attention = GPT2SelfAttention(embedding_size, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embedding_size)
        self.norm2 = nn.LayerNorm(embedding_size)
        self.feed_forward = FeedForward(embedding_size, hidden_dim, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run both sub-layers on ``x`` and return a tensor of the same shape."""
        # Self-attention sub-layer
        attended = self.dropout(self.attention(x, x, x))
        hidden = self.norm1(x + attended)

        # Feed-forward sub-layer
        projected = self.dropout(self.feed_forward(hidden))
        return self.norm2(hidden + projected)

In [None]:
# Basic text-generation module built on the GPT-2 style blocks
class GPT2TextGenerator(nn.Module):
    """Decoder-only language model with greedy, sampling and beam-search decoding.

    Token embeddings and learned position embeddings are summed, passed
    through ``num_layers`` ``GPT2Block`` modules and projected to vocabulary
    logits.

    Args:
        vocab_size: Number of distinct token ids.
        embedding_size: Width of the embeddings and of every block.
        num_layers: Number of stacked ``GPT2Block`` modules.
        num_heads: Attention heads per block.
        hidden_dim: Hidden width passed to each block's feed-forward network.
        max_length: Number of learned positions; also the longest sequence
            the decoding helpers will produce.
        dropout: Dropout probability.
        eos_token_id: Token id that stops decoding (defaults to 0).
    """

    def __init__(self, vocab_size, embedding_size, num_layers, num_heads, hidden_dim,
                 max_length=50, dropout=0.1, eos_token_id=0):
        super().__init__()
        self.embedding_size = embedding_size
        self.token_embedding = nn.Embedding(vocab_size, embedding_size)
        self.position_embedding = nn.Embedding(max_length, embedding_size)
        self.layers = nn.ModuleList(
            [GPT2Block(embedding_size, num_heads, hidden_dim, dropout) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embedding_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.max_length = max_length
        self.eos_token_id = eos_token_id

    def forward(self, input_ids):
        """Compute next-token logits for every position.

        Args:
            input_ids: Integer tensor of shape ``(N, seq_length)``.

        Returns:
            Tensor of shape ``(N, seq_length, vocab_size)``.

        Raises:
            IndexError: If ``seq_length`` exceeds ``max_length``.
        """
        N, seq_length = input_ids.shape
        # Fail early with a readable message instead of an out-of-range
        # embedding lookup.
        if seq_length > self.max_length:
            raise IndexError(
                f"sequence length {seq_length} exceeds max_length {self.max_length}"
            )
        # Create the positions directly on the input's device; shape (1, seq_length)
        # broadcasts against the (N, seq_length, embedding_size) token embeddings.
        positions = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x)

        return self.fc_out(x)

    def _prompt_as_list(self, start_token):
        """Return a fresh list of int token ids so the caller's object is never mutated."""
        if isinstance(start_token, torch.Tensor):
            tokens = start_token.flatten().tolist()
        else:
            tokens = list(start_token)
        if not tokens:
            raise ValueError("start_token must contain at least one token id")
        return tokens

    def _next_token_logits(self, token_ids):
        """Run the model on one sequence and return the last position's logits, shape (1, vocab_size)."""
        device = next(self.parameters()).device
        x = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)
        return self(x)[:, -1, :]

    def generate(self, start_token, max_len, temperature=1.0, do_sample=False):
        """Extend a prompt token by token.

        Decoding runs in eval mode without gradient tracking; the module's
        previous training flag is restored afterwards.

        Args:
            start_token: Prompt token ids (list or 1-D tensor). Not modified.
            max_len: Maximum total length of the returned sequence; capped at
                ``self.max_length`` because no positions exist beyond it.
            temperature: Positive divisor applied to the logits. It only
                changes the result when ``do_sample`` is true, since scaling
                by a positive number does not move the argmax.
            do_sample: If true, draw the next token from the softmax
                distribution; otherwise take the most likely token.

        Returns:
            A new list with the prompt followed by the generated ids. Decoding
            stops after ``eos_token_id`` is produced or the length limit is hit.

        Raises:
            ValueError: If the prompt is empty or ``temperature`` is not positive.
        """
        tokens = self._prompt_as_list(start_token)
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        limit = min(max_len, self.max_length)

        was_training = self.training
        self.eval()  # disable dropout so decoding is not perturbed
        try:
            with torch.no_grad():
                while len(tokens) < limit:
                    logits = self._next_token_logits(tokens) / temperature
                    if do_sample:
                        probs = F.softmax(logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1).item()
                    else:
                        next_token = torch.argmax(logits, dim=-1).item()
                    tokens.append(next_token)
                    if next_token == self.eos_token_id:
                        break
        finally:
            self.train(was_training)
        return tokens

    def greedy_search(self, start_token):
        """Greedy decoding up to ``self.max_length`` tokens.

        Args:
            start_token: Prompt token ids (list or 1-D tensor). Not modified.

        Returns:
            A new list with the prompt followed by the generated ids.
        """
        return self.generate(start_token, self.max_length)

    def beam_search(self, start_token, beam_width=3):
        """Beam-search decoding up to ``self.max_length`` tokens.

        Hypotheses are scored by the sum of their token log-probabilities. A
        hypothesis that has produced ``eos_token_id`` is kept as is and no
        longer extended; the search stops once every kept hypothesis has
        finished or the length limit is reached.

        Args:
            start_token: Prompt token ids (list or 1-D tensor). Not modified.
            beam_width: Number of hypotheses kept after each step.

        Returns:
            The highest-scoring sequence as a new list of token ids.

        Raises:
            ValueError: If the prompt is empty or ``beam_width`` is below 1.
        """
        prompt = self._prompt_as_list(start_token)
        if beam_width < 1:
            raise ValueError("beam_width must be at least 1")

        # Each hypothesis is (token ids, cumulative log-probability, finished flag).
        beams = [(prompt, 0.0, False)]

        was_training = self.training
        self.eval()  # disable dropout so decoding is not perturbed
        try:
            with torch.no_grad():
                for _ in range(self.max_length - len(prompt)):
                    candidates = []
                    for seq, score, finished in beams:
                        if finished:
                            # Finished hypotheses compete with their final score.
                            candidates.append((seq, score, True))
                            continue
                        log_probs = F.log_softmax(self._next_token_logits(seq), dim=-1)
                        # topk cannot return more entries than the vocabulary has.
                        k = min(beam_width, log_probs.size(-1))
                        top_scores, top_ids = torch.topk(log_probs, k, dim=-1)
                        for token_score, token_id in zip(top_scores[0].tolist(), top_ids[0].tolist()):
                            candidates.append(
                                (seq + [token_id], score + token_score, token_id == self.eos_token_id)
                            )
                    # Keep the beam_width best hypotheses.
                    candidates.sort(key=lambda item: item[1], reverse=True)
                    beams = candidates[:beam_width]
                    # Stop only when every kept hypothesis has ended with EOS.
                    if all(finished for _, _, finished in beams):
                        break
        finally:
            self.train(was_training)

        # Return the best-scoring hypothesis.
        return max(beams, key=lambda item: item[1])[0]

In [None]:
# Hyperparameter settings
# NOTE(review): the published GPT-2 vocabulary has 50257 entries; 50527 may be
# a digit transposition — confirm against the tokenizer that will be used.
vocab_size = 50527
embedding_size = 768
hidden_dim = 4 * embedding_size  # 3072
num_layers = 12
heads = 12
max_len = 20
dropout = 0.1

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed the RNG so the randomly initialised weights, and therefore the
# sequences generated below, are the same on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
# Initialize the model
gpt2_model = GPT2TextGenerator(vocab_size, embedding_size, num_layers, num_heads=heads, hidden_dim=hidden_dim, max_length=max_len, dropout=dropout)
# The notebook only runs inference: eval mode turns dropout off so decoding is
# deterministic. eval() returns the module, so the chain keeps the same binding.
gpt2_model = gpt2_model.to(device).eval()

In [None]:
# Simulated prompt input (token ids)
start_token = torch.tensor([1, 345, 876]).to(device)

In [None]:
# Generate a sequence with greedy search
# Inference only: no_grad avoids building an autograd graph at every decoding step.
with torch.no_grad():
    greedy_generated_sequence = gpt2_model.greedy_search(start_token.tolist())
print("Greedy Search生成的文本序列:", greedy_generated_sequence)

In [None]:
# Generate a sequence with beam search
# Inference only: no_grad avoids building an autograd graph at every decoding step.
with torch.no_grad():
    beam_generated_sequence = gpt2_model.beam_search(start_token.tolist())
print("Beam Search生成的文本序列:", beam_generated_sequence)

In [None]:
# Mock decoding of generated token ids
def decode_sequence(sequence):
    """Render each token id as a ``<token_ID>`` placeholder, joined by single spaces."""
    placeholders = (f"<token_{idx}>" for idx in sequence)
    return " ".join(placeholders)

# Decode and show the greedy-search result first, then the beam-search result.
greedy_decoded_text = decode_sequence(greedy_generated_sequence)
print("Greedy Search解码生成的文本序列:", greedy_decoded_text)

beam_decoded_text = decode_sequence(beam_generated_sequence)
print("Beam Search解码生成的文本序列:", beam_decoded_text)