In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt

# Generate synthetic data
num_tasks = 4
num_classes = 10
samples = 14000
input_shape = (1024, 2)
data_seed = 0  # fixed seed so Restart & Run All reproduces the same data and splits


def make_task_datasets(num_tasks, num_classes, samples, input_shape,
                       train_fraction=0.8, seed=None):
    """Build random (train, test) dataset pairs, one pair per task.

    Each task gets `samples` standard-normal float64 examples of shape
    `input_shape`, with the two trailing axes swapped, plus integer labels
    drawn uniformly from ``[0, num_classes)``. The examples are then split
    randomly into a train and a test subset.

    Args:
        num_tasks: Number of (train, test) pairs to create.
        num_classes: Number of distinct label values.
        samples: Number of examples generated per task.
        input_shape: Per-example shape *before* the transpose. It must have
            exactly two dimensions, because axes 1 and 2 of the batched
            tensor are swapped.
        train_fraction: Fraction of each task's examples placed in the train
            subset; the remainder becomes the test subset.
        seed: Seed for a private random generator. ``None`` draws a fresh
            non-deterministic seed. The global torch RNG is never touched.

    Returns:
        A list of ``(train_dataset, test_dataset)`` tuples of length
        `num_tasks`.

    Raises:
        ValueError: If `train_fraction` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")

    generator = torch.Generator()
    if seed is None:
        # A new Generator starts from a fixed default seed; reseed it from a
        # non-deterministic source to keep "unseeded" runs random.
        generator.seed()
    else:
        generator.manual_seed(seed)

    tasks = []
    for _ in range(num_tasks):
        # Allocate directly in float64 instead of creating a float32 tensor
        # and converting it, which would briefly hold both copies in memory.
        data = torch.randn((samples, *input_shape), dtype=torch.double, generator=generator)

        # Transpose the data: (samples, d0, d1) -> (samples, d1, d0)
        data = data.transpose(1, 2)
        labels = torch.randint(num_classes, size=(samples,), generator=generator)
        dataset = torch.utils.data.TensorDataset(data, labels)

        # Split into train and test subsets
        train_size = int(train_fraction * len(dataset))
        test_size = len(dataset) - train_size
        train_dataset, test_dataset = random_split(
            dataset, [train_size, test_size], generator=generator
        )

        tasks.append((train_dataset, test_dataset))
    return tasks


task_datasets = make_task_datasets(
    num_tasks, num_classes, samples, input_shape, seed=data_seed
)

In [None]:
# Define the WDCNN model
class WDCNN(nn.Module):
    """Small 1-D convolutional classifier.

    Architecture: two ``Conv1d(kernel_size=3, padding=1) -> ReLU ->
    MaxPool1d(2, 2)`` stages (32 then 64 channels), followed by a
    128-unit fully connected layer and a linear output layer that produces
    one logit per class.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input signal (default 2).
        input_length: Length of the input signal along the last axis
            (default 1024). It fixes the size of the first fully connected
            layer, so inputs passed to `forward` must have this length.

    Raises:
        ValueError: If `input_length` is too short to survive both pooling
            stages (fewer than 4 samples).
    """

    def __init__(self, num_classes, in_channels=2, input_length=1024):
        super().__init__()
        # The padded convolutions keep the length; each pooling stage halves
        # it, rounding down.
        pooled_length = input_length // 2 // 2
        if pooled_length < 1:
            raise ValueError(
                f"input_length must be at least 4 to survive two poolings, got {input_length}"
            )

        self.conv1 = nn.Conv1d(in_channels, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * pooled_length, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Tensor of shape ``(batch, in_channels, input_length)``.
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # Collapse (channels, length) into one feature axis; flatten also
        # copes with non-contiguous tensors, unlike view.
        x = x.flatten(start_dim=1)
        x = self.relu3(self.fc1(x))
        return self.fc2(x)


def _evaluate_task(model, loader, criterion, device, dtype):
    """Return ``(accuracy_percent, mean_batch_loss)`` of `model` on `loader`.

    Both values are NaN when `loader` yields no batches.
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device, dtype=dtype)
            labels = labels.to(device, dtype=torch.long)
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss_sum += criterion(outputs, labels).item()

    num_batches = len(loader)
    accuracy = 100 * correct / total if total else float("nan")
    loss_avg = loss_sum / num_batches if num_batches else float("nan")
    return accuracy, loss_avg


# Define the EWC training function
def train_ewc(model, fisher_matrix, task_datasets, num_epochs=10, lr=0.001, batch_size=32,
              ewc_lambda=1000.0):
    """Train `model` on a sequence of tasks with an EWC-style penalty.

    Tasks are trained one after another with a single Adam optimizer. While
    a task is trained, the squared gradients of its cross-entropy loss are
    averaged over all optimisation steps; this running average is used as
    the diagonal importance ("Fisher") estimate of that task. Note that this
    is an estimate taken along the training trajectory, not the Fisher
    information evaluated at the final weights.

    Once importance is available (from a previous task, or passed in by the
    caller), every step adds the gradient of

        (ewc_lambda / 2) * sum_i F_i * (theta_i - anchor_i) ** 2

    to the task gradient, where ``F`` is the importance summed over the
    tasks seen so far and ``anchor`` holds the weights at the end of the most
    recent task (the weights at entry for the first task). A single anchor
    is used for all earlier tasks rather than one penalty term per task.

    Args:
        model: Network to train. Its first parameter determines the device
            and floating dtype that input batches are moved/cast to.
        fisher_matrix: List with one importance tensor per model parameter,
            in ``model.parameters()`` order. It is updated in place after
            each task. ``None`` or an empty list starts from zeros.
        task_datasets: Iterable of ``(train_dataset, test_dataset)`` pairs.
        num_epochs: Passes over each task's training set.
        lr: Adam learning rate.
        batch_size: Batch size for the train and test loaders.
        ewc_lambda: Strength of the penalty; ``0`` disables it. The default
            is only a starting point and should be tuned per problem.

    Returns:
        ``(task_accuracies, task_losses)``: per-task test accuracy in
        percent and mean test batch loss, each measured right after that
        task was trained.

    Raises:
        ValueError: If `ewc_lambda` is negative or `fisher_matrix` does not
            have one entry per model parameter.
    """
    if ewc_lambda < 0:
        raise ValueError(f"ewc_lambda must be non-negative, got {ewc_lambda}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    params = list(model.parameters())
    # Take device/dtype from the model instead of a module-level global so
    # the function works for any model placement and precision.
    device = params[0].device
    dtype = params[0].dtype

    if fisher_matrix is None:
        fisher_matrix = []
    if len(fisher_matrix) == 0:
        fisher_matrix.extend(torch.zeros_like(p.detach()) for p in params)
    elif len(fisher_matrix) != len(params):
        raise ValueError(
            f"fisher_matrix has {len(fisher_matrix)} entries but the model has "
            f"{len(params)} parameters"
        )

    # Working copy of the importance in the parameters' dtype/device; the
    # caller's tensors may have been allocated before the model was cast.
    importance = [
        f.detach().to(device=p.device, dtype=p.dtype).clone()
        for f, p in zip(fisher_matrix, params)
    ]
    anchor = [p.detach().clone() for p in params]
    penalty_active = ewc_lambda > 0 and any(bool(f.any()) for f in importance)

    task_accuracies = []
    task_losses = []

    for task_idx, (train_dataset, test_dataset) in enumerate(task_datasets):
        print(f"Training Task {task_idx+1}")
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        # Running sum of squared task-loss gradients for the current task
        task_fisher = [torch.zeros_like(p) for p in anchor]
        num_steps = 0

        for epoch in range(num_epochs):
            model.train()
            for inputs, labels in train_loader:
                inputs = inputs.to(device, dtype=dtype)
                labels = labels.to(device, dtype=torch.long)

                optimizer.zero_grad()
                loss = criterion(model(inputs), labels)
                loss.backward()

                with torch.no_grad():
                    for i, param in enumerate(params):
                        if param.grad is None:
                            continue
                        # Record importance from the pure task gradient,
                        # before the penalty gradient is mixed in.
                        task_fisher[i] += param.grad ** 2
                        if penalty_active:
                            # d/dtheta of (lambda/2) * F * (theta - anchor)^2
                            param.grad.add_(
                                importance[i] * (param - anchor[i]), alpha=ewc_lambda
                            )
                num_steps += 1

                optimizer.step()

        # Fold this task's importance into the total and re-anchor
        with torch.no_grad():
            for i, param in enumerate(params):
                if num_steps:
                    # Average so the scale does not depend on the number of
                    # epochs or batches.
                    importance[i] += task_fisher[i] / num_steps
                fisher_matrix[i].copy_(importance[i])
                anchor[i].copy_(param)
        penalty_active = ewc_lambda > 0

        # Evaluate on the current task
        accuracy, loss_avg = _evaluate_task(model, test_loader, criterion, device, dtype)
        task_accuracies.append(accuracy)
        task_losses.append(loss_avg)

        print(f"Task {task_idx+1} Accuracy: {accuracy:.2f}%")
        print(f"Task {task_idx+1} Loss: {loss_avg:.4f}")

    return task_accuracies, task_losses


# Create the WDCNN model instance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Cast to float64 before allocating the Fisher buffers so that they share the
# parameters' dtype (the synthetic inputs are float64 as well).
model = WDCNN(num_classes).to(device).double()

# Initialise the Fisher information buffers: one zero tensor per parameter
fisher_matrix = [torch.zeros_like(param.data) for param in model.parameters()]

# Train the model with EWC
task_accuracies, task_losses = train_ewc(model, fisher_matrix, task_datasets, num_epochs=10)

# Visualise the results: per-task test accuracy (left) and test loss (right)
task_numbers = range(1, len(task_accuracies) + 1)
fig, (ax_accuracy, ax_loss) = plt.subplots(1, 2, figsize=(10, 5))

ax_accuracy.plot(task_numbers, task_accuracies, marker='o')
ax_accuracy.set_xlabel('Task')
ax_accuracy.set_ylabel('Accuracy (%)')
ax_accuracy.set_title('Task Accuracies')

ax_loss.plot(task_numbers, task_losses, marker='o')
ax_loss.set_xlabel('Task')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Task Losses')

fig.tight_layout()
plt.show()

