In [None]:
import numpy as np
import math
from check import check

In [None]:
# Feature: is the first letter of the current word uppercase?
def word_is_capitalized(seq, pos, current_tag, prev_tag):
    """Return 1 if the token at ``pos`` starts with an uppercase character, else 0.

    Args:
        seq: Sequence of string tokens.
        pos: Index of the token to inspect.
        current_tag: Unused; present so every feature function shares one signature.
        prev_tag: Unused; present so every feature function shares one signature.

    Returns:
        1 when the first character of ``seq[pos]`` is uppercase, otherwise 0
        (including when the token is an empty string).
    """
    # Slice rather than index: an empty token yields "" instead of raising
    # IndexError, and "".isupper() is False.
    return 1 if seq[pos][:1].isupper() else 0

# Feature: is the previous word 'the' (case-insensitive)?
def prev_word_is_the(seq, pos, current_tag, prev_tag):
    """Return 1 if the token before ``pos`` is 'the' (case-insensitive), else 0.

    Args:
        seq: Sequence of string tokens.
        pos: Index of the current token.
        current_tag: Unused; present so every feature function shares one signature.
        prev_tag: Unused; present so every feature function shares one signature.

    Returns:
        0 when ``pos`` is 0 (there is no preceding token); otherwise 1 if the
        lower-cased preceding token equals 'the', else 0.
    """
    has_predecessor = pos != 0
    # Short-circuit keeps seq[pos - 1] from being read at the first position.
    return int(has_predecessor and seq[pos - 1].lower() == 'the')

# Feature functions defined above, in a fixed order. Each one is called as
# func(seq, pos, current_tag, prev_tag) and returns 0 or 1.
feature_functions = [word_is_capitalized, prev_word_is_the]

In [None]:
class CRF:
    """Linear-chain conditional random field over a fixed set of feature functions.

    Every feature function is called as ``func(seq, pos, current_tag, prev_tag)``
    and is paired with one entry of ``self.weights``.  ``prev_tag`` is ``None``
    at position 0 and the previous position's tag everywhere else.

    All recursions operate on raw (non-log) exponentiated scores, so very long
    sequences or large weights can overflow or underflow.
    """

    def __init__(self, n_tags, feature_functions):
        """Store the model configuration and start with all-zero weights.

        Args:
            n_tags: Number of distinct tags; tags are the integers
                ``0 .. n_tags - 1``.
            feature_functions: List of feature functions (see class docstring).
        """
        self.n_tags = n_tags
        self.feature_functions = feature_functions
        # One weight per feature function.
        self.weights = np.zeros(len(feature_functions))

    def _potential(self, seq, t, tag, prev_tag):
        """Return the combined transition * state score for ``prev_tag -> tag`` at ``t``."""
        return (self.transition_score(seq, t, tag, prev_tag)
                * self.state_score(seq, t, tag, prev_tag))

    def forward(self, seq):
        """Run the forward algorithm.

        Returns:
            ``alpha`` of shape ``(len(seq), n_tags)`` where ``alpha[t, tag]`` is
            the summed score of all tag paths over ``seq[:t + 1]`` that end in
            ``tag`` at position ``t``.  Empty input gives a ``(0, n_tags)`` array.
        """
        T = len(seq)
        alpha = np.zeros((T, self.n_tags))
        if T == 0:
            return alpha

        # Initialisation: no previous tag exists at the first position.
        for tag in range(self.n_tags):
            alpha[0, tag] = self.state_score(seq, 0, tag, None)

        # Recursion.  The state score is evaluated with the actual previous
        # tag (inside the sum) so that alpha agrees with backward() for
        # features that depend on prev_tag.
        for t in range(1, T):
            for tag in range(self.n_tags):
                alpha[t, tag] = sum(
                    alpha[t - 1, prev_tag] * self._potential(seq, t, tag, prev_tag)
                    for prev_tag in range(self.n_tags)
                )

        return alpha

    def backward(self, seq):
        """Run the backward algorithm.

        Returns:
            ``beta`` of shape ``(len(seq), n_tags)`` where ``beta[t, tag]`` is
            the summed score of all tag paths over positions ``t + 1 ..`` given
            that position ``t`` carries ``tag``.  Empty input gives a
            ``(0, n_tags)`` array.
        """
        T = len(seq)
        beta = np.zeros((T, self.n_tags))
        if T == 0:
            return beta

        # Initialisation: nothing follows the last position.
        beta[T - 1, :] = 1

        # Recursion, right to left.
        for t in range(T - 2, -1, -1):
            for tag in range(self.n_tags):
                beta[t, tag] = sum(
                    beta[t + 1, next_tag] * self._potential(seq, t + 1, next_tag, tag)
                    for next_tag in range(self.n_tags)
                )

        return beta

    def compute_marginals(self, seq):
        """Return per-position tag marginals.

        Returns:
            Array of shape ``(len(seq), n_tags)`` where entry ``[t, tag]`` is
            the probability that position ``t`` has ``tag`` given ``seq``.
            Empty input gives a ``(0, n_tags)`` array.
        """
        if len(seq) == 0:
            return np.zeros((0, self.n_tags))

        alpha = self.forward(seq)
        beta = self.backward(seq)
        Z = np.sum(alpha[-1, :])  # partition function
        return (alpha * beta) / Z

    def viterbi_decode(self, seq):
        """Return the highest-scoring tag sequence for ``seq``.

        Returns:
            List of ``len(seq)`` integer tags; ``[]`` for empty input.  Ties
            are broken in favour of the lowest tag index.
        """
        T = len(seq)
        if T == 0:
            return []

        viterbi = np.zeros((T, self.n_tags))
        backptrs = np.zeros((T, self.n_tags), dtype=int)  # back-pointers

        # Initialisation.
        for tag in range(self.n_tags):
            viterbi[0, tag] = self.state_score(seq, 0, tag, None)

        # Recursion.  The full potential (transition and state score) is part
        # of the maximisation, so the chosen predecessor is optimal even when
        # features depend on prev_tag.
        for t in range(1, T):
            for tag in range(self.n_tags):
                max_score = -float('inf')
                best_prev_tag = 0
                for prev_tag in range(self.n_tags):
                    score = viterbi[t - 1, prev_tag] * self._potential(seq, t, tag, prev_tag)
                    if score > max_score:
                        max_score = score
                        best_prev_tag = prev_tag
                viterbi[t, tag] = max_score
                backptrs[t, tag] = best_prev_tag

        # Backtrack from the best final tag; build reversed, then flip once.
        current = int(np.argmax(viterbi[-1, :]))
        reversed_path = [current]
        for t in range(T - 1, 0, -1):
            current = int(backptrs[t, current])
            reversed_path.append(current)
        reversed_path.reverse()
        return reversed_path

    def state_score(self, sequence, position, current_tag, prev_tag):
        """Return ``exp`` of the weighted sum of all feature values at ``position``."""
        score = 0
        for i, func in enumerate(self.feature_functions):
            score += self.weights[i] * func(sequence, position, current_tag, prev_tag)
        return math.exp(score)

    def transition_score(self, sequence, position, current_tag, prev_tag):
        """Return the transition score; a constant 1.0 in this simplified model.

        No transition parameters are learned here.
        """
        return 1.0

    def _empirical_counts(self, seq, tag_seq):
        """Return the per-feature sums of feature values under the gold tags."""
        counts = np.zeros(len(self.weights))
        for t in range(len(seq)):
            current_tag = tag_seq[t]
            prev_tag = tag_seq[t - 1] if t > 0 else None
            for i, func in enumerate(self.feature_functions):
                counts[i] += func(seq, t, current_tag, prev_tag)
        return counts

    def _expected_counts(self, seq, alpha, beta, Z):
        """Return the per-feature expected feature values under the model.

        Position 0 uses the unary marginal with ``prev_tag=None``.  Later
        positions use the pairwise marginal of ``(prev_tag, current_tag)``, so
        each tag pair is weighted by its own probability and the total
        probability mass per position is exactly 1.
        """
        expected = np.zeros(len(self.weights))

        for current_tag in range(self.n_tags):
            prob = alpha[0, current_tag] * beta[0, current_tag] / Z
            for i, func in enumerate(self.feature_functions):
                expected[i] += prob * func(seq, 0, current_tag, None)

        for t in range(1, len(seq)):
            for current_tag in range(self.n_tags):
                for prev_tag in range(self.n_tags):
                    prob = (alpha[t - 1, prev_tag]
                            * self._potential(seq, t, current_tag, prev_tag)
                            * beta[t, current_tag] / Z)
                    for i, func in enumerate(self.feature_functions):
                        expected[i] += prob * func(seq, t, current_tag, prev_tag)

        return expected

    def train(self, sequences, tags, lr=0.01, max_iter=100):
        """Fit the weights by batch gradient ascent on the log-likelihood.

        Args:
            sequences: Iterable of token sequences.
            tags: Iterable of gold tag sequences, aligned with ``sequences``.
            lr: Step size of each weight update.
            max_iter: Number of full passes over the data.

        After every pass one line is printed.  The value labelled ``Loss`` is
        the summed log-likelihood of the gold tag sequences (higher is
        better).  Empty sequences are skipped.
        """
        for iteration in range(max_iter):
            total_loss = 0
            grad = np.zeros(len(self.weights))

            for seq, tag_seq in zip(sequences, tags):
                if len(seq) == 0:
                    continue

                # One forward/backward pass serves both the expectations and Z.
                alpha = self.forward(seq)
                beta = self.backward(seq)
                Z = np.sum(alpha[-1, :])

                empirical_expectation = self._empirical_counts(seq, tag_seq)
                model_expectation = self._expected_counts(seq, alpha, beta, Z)

                # Log-likelihood gradient: observed minus expected counts.
                grad += empirical_expectation - model_expectation

                # The gold path's unnormalised log-score is weights . counts.
                sequence_score = float(np.dot(self.weights, empirical_expectation))
                total_loss += sequence_score - math.log(Z)

            # Gradient ascent step.
            self.weights += lr * grad

            print(f"Iteration {iteration + 1}, Loss: {total_loss}")
        