In [5]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# 设置随机种子
tf.random.set_seed(42)
np.random.seed(42)

# 1. 数据生成
def generate_data(num_samples=500):
    """生成三类数据（健康/不健康/非植物），每类num_samples个样本"""
    # 特征1: 绿色通道（健康高，不健康中，非植物低）
    green = np.concatenate([
        np.random.normal(0.8, 0.1, num_samples),  # 健康
        np.random.normal(0.5, 0.1, num_samples),  # 不健康
        np.random.normal(0.2, 0.1, num_samples)   # 非植物
    ])
    
    # 特征2: 纹理（健康中，不健康高，非植物低）
    texture = np.concatenate([
        np.random.normal(0.5, 0.1, num_samples),
        np.random.normal(0.8, 0.1, num_samples),
        np.random.normal(0.2, 0.1, num_samples)
    ])
    
    # 标签：0=健康, 1=不健康, 2=非植物
    labels = np.concatenate([
        np.zeros(num_samples),
        np.ones(num_samples),
        np.full(num_samples, 2)
    ])
    
    X = np.stack([green, texture], axis=1)
    y = labels.astype(int)
    return X, y

# 生成数据（旧任务只有健康/不健康植物）
X_old, y_old = generate_data(num_samples=500)
mask = (y_old < 2)  # 筛选标签为0或1的样本
X_old, y_old = X_old[mask], y_old[mask]

# 生成新任务数据（含非植物）
X_new, y_new = generate_data(num_samples=500)

# 2. 构建模型
model = Sequential([
    Dense(32, activation='relu', input_shape=(2,)),
    Dense(32, activation='relu'),
    Dense(3, activation='softmax')
])
optimizer = Adam(learning_rate=0.001)

# 3. 训练旧任务
model.compile(optimizer=optimizer, 
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(X_old, y_old, epochs=10, batch_size=32)

# 保存旧参数和计算Fisher矩阵
prev_weights = [tf.identity(w) for w in model.trainable_variables]

# 4. 计算Fisher矩阵
def compute_fisher(model, X, y):
    fisher = [tf.zeros_like(w) for w in model.trainable_variables]
    for x, true_label in zip(X, y):
        with tf.GradientTape() as tape:
            prob = model(np.expand_dims(x, axis=0))[0, true_label]
            log_prob = tf.math.log(prob)
        grads = tape.gradient(log_prob, model.trainable_variables)
        fisher = [f + g**2 for f, g in zip(fisher, grads)]
    return [f / len(X) for f in fisher]

fisher_matrix = compute_fisher(model, X_old, y_old)

# 5. 用EWC训练新任务
def train_with_ewc(model, X, y, prev_weights, fisher, lambda_ewc=1.0, epochs=10):
    for epoch in range(epochs):
        for x_batch, y_batch in tf.data.Dataset.from_tensor_slices((X, y)).batch(32):
            with tf.GradientTape() as tape:
                # 常规损失
                pred = model(x_batch, training=True)
                loss = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(y_batch, pred)
                )
                # EWC损失
                ewc_loss = sum(
                    tf.reduce_sum(f * (w - pw)**2)
                    for w, pw, f in zip(model.trainable_variables, prev_weights, fisher)
                )
                loss += lambda_ewc * ewc_loss
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

train_with_ewc(model, X_new, y_new, prev_weights, fisher_matrix, lambda_ewc=1.0)

# 6. 评估
def evaluate(model, X, y, name):
    pred = model.predict(X)
    acc = np.mean(np.argmax(pred, axis=1) == y)
    print(f"{name}准确率: {acc:.4f}")

evaluate(model, X_old, y_old, "旧任务（健康/不健康植物）")
evaluate(model, X_new, y_new, "新任务（含非植物）")

import os

# 创建保存目录
save_dir = "ewc_assets"
os.makedirs(save_dir, exist_ok=True)

# 保存旧任务参数和Fisher矩阵
def save_ewc_assets(prev_weights, fisher_matrix, save_dir):
    # 保存旧参数（TensorFlow格式）
    model.save(os.path.join(save_dir, "old_model"))  # 保存整个模型（可选）
    
    # 保存Fisher矩阵（NumPy格式）
    fisher_numpy = [f.numpy() for f in fisher_matrix]  # 转换为NumPy数组
    np.savez(os.path.join(save_dir, "fisher_matrix.npz"), *fisher_numpy)
    
    print(f"EWC参数已保存至 {save_dir}")

# 加载旧参数和Fisher矩阵
def load_ewc_assets(save_dir):
    # 加载旧模型参数
    old_model = tf.keras.models.load_model(os.path.join(save_dir, "old_model"))
    prev_weights = old_model.trainable_variables
    
    # 加载Fisher矩阵
    fisher_data = np.load(os.path.join(save_dir, "fisher_matrix.npz"))
    fisher_matrix = [tf.constant(f) for f in fisher_data.values()]
    
    return prev_weights, fisher_matrix

# 调用保存函数
save_ewc_assets(prev_weights, fisher_matrix, save_dir)

# 后续使用时可以加载
prev_weights_loaded, fisher_matrix_loaded = load_ewc_assets(save_dir)



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
旧任务（健康/不健康植物）准确率: 0.9830
新任务（含非植物）准确率: 0.9833
INFO:tensorflow:Assets written to: ewc_assets/old_model/assets


INFO:tensorflow:Assets written to: ewc_assets/old_model/assets


EWC参数已保存至 ewc_assets
