In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from avalanche.benchmarks import SplitMNIST
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger
from avalanche.training import Naive
from avalanche.training.plugins.ewc import EWCPlugin

# --------------------
# 1. Simple CNN model
# --------------------
class SimpleCNN(nn.Module):
    """Small two-block convolutional classifier for single-channel images.

    Layout: [conv 5x5 -> ReLU -> max-pool 2] x 2 -> flatten -> Linear(256) -> ReLU
    -> Linear(num_classes).

    Args:
        num_classes: number of output logits produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Block 1: 1 -> 32 channels. padding=2 preserves H/W for a 5x5 kernel;
        # the pool then halves them.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)

        # Block 2: 32 -> 64 channels, halved again by pooling.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)

        # Classifier head. The 64 * 7 * 7 input size requires the two poolings
        # to leave a 7x7 map, e.g. 28x28 MNIST images (28 -> 14 -> 7).
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 256)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map a batch of shape (N, 1, H, W) to logits of shape (N, num_classes)."""
        # Feature extractor: apply both conv -> ReLU -> pool stages in order.
        stages = (
            (self.conv1, self.relu1, self.pool1),
            (self.conv2, self.relu2, self.pool2),
        )
        for conv, activation, pool in stages:
            x = pool(activation(conv(x)))
        # Classifier head.
        hidden = self.relu3(self.fc1(self.flatten(x)))
        return self.fc2(hidden)

# --------------------
# 2. Build the SplitMNIST benchmark
# --------------------
# Fixed seed for reproducibility. It is passed to SplitMNIST, which uses it for
# its random class-to-experience split. It also seeds torch's global RNG, which
# later torch-based randomness (e.g. the model's weight initialisation further
# down) draws from.
SEED = 0
torch.manual_seed(SEED)

# 5 class-incremental experiences; no task labels are exposed to the model.
benchmark = SplitMNIST(
    n_experiences=5,
    return_task_id=False,
    seed=SEED,
)

# --------------------
# 3. Model, optimiser, loss and EWC plugin
# --------------------
# Hyperparameters, kept together so they are easy to tune.
LEARNING_RATE = 0.01
MOMENTUM = 0.9
EWC_LAMBDA = 1000  # regularisation strength; tune based on validation results
EWC_DECAY = 0.5    # factor by which accumulated Fisher information decays per task

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

model = SimpleCNN(num_classes=10).to(device)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
criterion = nn.CrossEntropyLoss()

# Elastic Weight Consolidation plugin. 'online' mode accumulates Fisher
# information across tasks, scaled by `decay_factor`; importance data from
# previous experiences is not kept (keep_importance_data=False).
ewc_plugin = EWCPlugin(
    ewc_lambda=EWC_LAMBDA,
    mode='online',
    decay_factor=EWC_DECAY,
    keep_importance_data=False,
)

# --------------------
# 4. Naive strategy with the EWC plugin attached
# --------------------
strategy = Naive(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    train_mb_size=32,
    eval_mb_size=64,
    train_epochs=3,
    plugins=[ewc_plugin],
)

# Logging: replace the evaluator's loggers with a single interactive logger.
# Other loggers imported above can be added to this list if needed, e.g.
#   TextLogger(open('ewc_splitmnist.log', 'w'))  -- close that file afterwards
#   TensorboardLogger('ewc_tb_logs')
interactive_logger = InteractiveLogger()
strategy.evaluator.loggers = [interactive_logger]

# --------------------
# 5. Train and evaluate
# --------------------
# One entry per training experience: whatever `strategy.eval` returns when run on
# the full test stream right after training on that experience (presumably
# Avalanche's metrics dict, which may contain more than accuracies -- verify).
# The name `task_accuracies` is kept because the save cell pickles it.
task_accuracies = []
print("Starting SplitMNIST + EWC...")
for experience in benchmark.train_stream:
    print(f"\n--- Training experience {experience.current_experience} ---")
    strategy.train(experience)
    print("Evaluation:")
    # Evaluate on the whole test stream (all experiences, seen and unseen) so
    # forgetting on earlier experiences is visible after each training step.
    eval_results = strategy.eval(benchmark.test_stream)
    task_accuracies.append(eval_results)

print("Training completed!")


Starting SplitMNIST + EWC...

--- Training experience 0 ---
-- >> Start of training phase << --
 81%|████████  | 321/397 [00:05<00:01, 69.44it/s]

KeyboardInterrupt: 

In [None]:
import os

# Destination for the pickled evaluation results. Set the EWC_RESULTS_PATH
# environment variable to override the default, instead of editing the path.
# NOTE: 'splitminist' (sic) is kept so existing result files and loaders still match.
file_path = os.environ.get(
    "EWC_RESULTS_PATH",
    "/home/yangz2/code/quantum_cl/results/list/splitminist_EWC.pkl",
)

# Create the parent directory if it does not exist. dirname() is "" for a bare
# filename, and os.makedirs("") would raise, so skip the call in that case.
results_dir = os.path.dirname(file_path)
if results_dir:
    os.makedirs(results_dir, exist_ok=True)

In [None]:
import pickle

# Save the collected evaluation results to `file_path`. The results list is
# wrapped in a one-element list; keep that format so existing loaders still work.
payload = [task_accuracies]
with open(file_path, "wb") as out_file:
    pickle.dump(payload, out_file)