In [None]:
import numpy as np
from check import check

1. 初始状态概率（π）

$$
\pi_i = \frac{\text{Count}(q_1 = s_i)}{\text{总序列数}}
$$

2. 状态转移概率（A）

$$
a_{ij} = \frac{\text{Count}(s_i \rightarrow s_j)}{\text{Count}(s_i \rightarrow \text{任意状态})}
$$

3. 观测发射概率（B）

$$
b_j(k) = \frac{\text{Count}(s_j \text{ 生成 } v_k)}{\text{Count}(s_j)}
$$

符号说明：
- $s_i$ ：隐藏状态
- $v_k$ ：观测符号
- $q_t$ ：t时刻的状态
- $o_t$ ：t时刻的观测
- $T$ ：序列长度


In [28]:
class HMM:
    """Discrete hidden Markov model with count-based training and Viterbi decoding.

    Parameters are nested dicts keyed by labels:
    ``pi[q]`` is the initial probability of state ``q``,
    ``A[q_i][q_j]`` the transition probability ``q_i -> q_j`` and
    ``B[q][o]`` the probability that state ``q`` emits observation ``o``.
    """

    def __init__(self, states, observations, pi, A, B):
        """Store the model definition.

        Args:
            states: Ordered sequence of hidden-state labels. The order fixes the
                index used internally by ``viterbi``.
            observations: Sequence of known observation symbols.
            pi, A, B: Parameter dicts as described on the class. They are kept
                by reference, so ``supervised_learning`` updates the caller's
                dicts in place.
        """
        self.states = states
        self.observations = observations
        self.pi = pi
        self.A = A
        self.B = B

    @staticmethod
    def _safe_log(values):
        """Element-wise natural log as a float array; probability 0 maps to -inf silently."""
        with np.errstate(divide='ignore'):
            return np.log(np.asarray(values, dtype=float))

    # Viterbi decoding
    def viterbi(self, obs_seq, unk_prob=1e-8):
        """Return the most likely hidden-state sequence for ``obs_seq``.

        The recursion runs in log space. Multiplying raw probabilities
        underflows to 0.0 on long sentences, after which all paths tie and the
        decoded tags become arbitrary.

        Args:
            obs_seq: Sequence of observation symbols.
            unk_prob: Emission probability used, for every state, when a symbol
                is missing from that state's ``B`` row (e.g. an unseen word).

        Returns:
            List of state labels, one per observation ([] for an empty input).
        """
        T = len(obs_seq)
        if T == 0:
            return []
        N = len(self.states)

        log_pi = self._safe_log([self.pi[q] for q in self.states])
        log_A = self._safe_log(
            [[self.A[q_i][q_j] for q_j in self.states] for q_i in self.states]
        )

        def log_emit(obs):
            # Log emission probabilities of one symbol for every state (length N).
            return self._safe_log([self.B[q].get(obs, unk_prob) for q in self.states])

        delta = np.empty((T, N))           # delta[t, j]: best log-score of a path ending in state j at t
        psi = np.zeros((T, N), dtype=int)  # psi[t, j]: predecessor index on that best path

        # Initialisation
        delta[0] = log_pi + log_emit(obs_seq[0])

        # Recursion: scores[i, j] = delta[t-1, i] + log A[i, j]; the emission
        # term does not depend on i, so it is added after the max.
        cols = np.arange(N)
        for t in range(1, T):
            scores = delta[t - 1][:, np.newaxis] + log_A
            best_prev = np.argmax(scores, axis=0)  # first index wins ties
            psi[t] = best_prev
            delta[t] = scores[best_prev, cols] + log_emit(obs_seq[t])

        # Backtrack from the best final state
        path = [int(np.argmax(delta[T - 1]))]
        for t in range(T - 1, 0, -1):
            path.append(int(psi[t, path[-1]]))
        path.reverse()

        return [self.states[idx] for idx in path]

    @staticmethod
    def _normalize(counts, keys):
        """Map a count dict to probabilities over ``keys``; uniform if every count is 0."""
        total = sum(counts.values())
        if total > 0:
            return {k: counts[k] / total for k in keys}
        return {k: 1.0 / len(keys) for k in keys}

    # Supervised learning
    def supervised_learning(self, state_seqs, obs_seqs):
        """Estimate pi, A and B by maximum-likelihood counting on labelled data.

        Args:
            state_seqs: List of hidden-state sequences.
            obs_seqs: List of observation sequences, aligned one-to-one and
                token-by-token with ``state_seqs``.

        Raises:
            ValueError: If the two lists, or any paired sequences, differ in length.

        Empty sequence pairs are skipped. Any distribution with no counts becomes
        uniform: ``pi`` when there are no non-empty sequences, a row of ``A``
        for a state that is never left, and a row of ``B`` for a state that
        never emits. The existing ``self.pi`` / ``self.A[q]`` / ``self.B[q]``
        dicts are updated in place.
        """
        if len(state_seqs) != len(obs_seqs):
            raise ValueError(
                f"got {len(state_seqs)} state sequences but {len(obs_seqs)} observation sequences"
            )

        pi_counts = {q: 0 for q in self.states}
        A_counts = {q_i: {q_j: 0 for q_j in self.states} for q_i in self.states}
        B_counts = {q: {o: 0 for o in self.observations} for q in self.states}

        for n, (state_seq, obs_seq) in enumerate(zip(state_seqs, obs_seqs)):
            if len(state_seq) != len(obs_seq):
                raise ValueError(
                    f"sequence {n}: {len(state_seq)} states but {len(obs_seq)} observations"
                )
            if not state_seq:
                continue  # an empty pair carries no information
            # Initial state
            pi_counts[state_seq[0]] += 1
            # Transitions between consecutive states
            for current_state, next_state in zip(state_seq, state_seq[1:]):
                A_counts[current_state][next_state] += 1
            # Emissions, including the last position
            for state, obs in zip(state_seq, obs_seq):
                B_counts[state][obs] += 1

        # The pi total equals the number of non-empty sequences (one start each).
        self.pi.update(self._normalize(pi_counts, self.states))
        for q_i in self.states:
            self.A[q_i].update(self._normalize(A_counts[q_i], self.states))
        for q in self.states:
            self.B[q].update(self._normalize(B_counts[q], self.observations))

In [29]:
def process_data(file_path):
    """Read a whitespace-separated token/tag file into HMM training sequences.

    Each non-blank line contributes one token. The first field is the word
    (observation) and the last field is the tag (hidden state). Lines with
    fewer than two fields are ignored. Blank lines end a sentence, and a final
    sentence without a trailing blank line is still kept.

    Returns:
        dict with
          'states'       - sorted list of distinct tags,
          'observations' - sorted list of distinct words,
          'state_seqs'   - one list of tags per sentence,
          'obs_seqs'     - the matching list of words per sentence.
    """
    tag_vocab, word_vocab = set(), set()
    all_tags, all_words = [], []
    sentence = []  # (word, tag) pairs of the sentence currently being read

    def close_sentence():
        # Emit the buffered sentence, if any, and start an empty one.
        if sentence:
            all_words.append([w for w, _ in sentence])
            all_tags.append([t for _, t in sentence])
            sentence.clear()

    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                close_sentence()  # blank line marks a sentence boundary
                continue
            fields = stripped.split()
            if len(fields) < 2:
                continue  # need at least a word and a tag
            word, tag = fields[0], fields[-1]
            word_vocab.add(word)
            tag_vocab.add(tag)
            sentence.append((word, tag))

    close_sentence()  # the file may not end with a blank line

    # Sorted vocabularies keep state/observation order deterministic.
    return {
        'states': sorted(tag_vocab),
        'observations': sorted(word_vocab),
        'state_seqs': all_tags,
        'obs_seqs': all_words,
    }

In [30]:
# Load the labelled training corpus (swap the comment to train on the Chinese data).
train_data_path = "./NER/English/train.txt"
# train_data_path = "./NER/Chinese/train.txt"
train_data = process_data(train_data_path)

# Unpack the vocabularies and the aligned tag/word sequences.
states, observations, state_seqs, obs_seqs = (
    train_data[key] for key in ('states', 'observations', 'state_seqs', 'obs_seqs')
)

In [31]:
# Zero-initialised parameter tables; supervised_learning fills them in place.
pi = dict.fromkeys(states, 0)
A = {q_from: dict.fromkeys(states, 0) for q_from in states}
B = {q: dict.fromkeys(observations, 0) for q in states}

# Build the model and fit it by counting over the labelled training sequences.
hmm = HMM(states, observations, pi, A, B)
hmm.supervised_learning(state_seqs, obs_seqs)

In [32]:
# Decode the validation sentences and write one "word tag" line per token,
# with a blank line after each sentence (swap the comment for the Chinese data).
valid_data_path = "./NER/English/validation.txt"
# valid_data_path = "./NER/Chinese/validation.txt"
valid_data = process_data(valid_data_path)

output_path = "./hmm_validation_output.txt"
with open(output_path, "w", encoding="utf-8") as fout:
    sentences = valid_data['obs_seqs']
    predictions = [hmm.viterbi(seq) for seq in sentences]
    for obs_seq, pred_states in zip(sentences, predictions):
        fout.writelines(f"{word} {tag}\n" for word, tag in zip(obs_seq, pred_states))
        fout.write("\n")  # blank line between sentences

In [None]:
# Score the written predictions against the gold validation file
# (use the commented call instead when evaluating the Chinese data).
check(
    language="English",
    gold_path="NER/English/validation.txt",
    my_path="hmm_validation_output.txt",
)
# check(language="Chinese", gold_path="NER/Chinese/validation.txt", my_path="hmm_validation_output.txt")