In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from src.dataset import pca_data_loader
from tqdm import tqdm

In [2]:
# Pick the compute device: the first CUDA GPU when one is available, otherwise the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load the PCA-reduced train/test DataLoaders (the loader logs the cache file it reads)
train_loader, test_loader = pca_data_loader()

Loading PCA data from F:\pyDistilledFDTD\dataset\.cache\pca\pca-components-10\batch-size-64.pkl


In [3]:
# Define the LSTM classifier
class CustomLSTM(nn.Module):
    """Stacked LSTM encoder followed by a linear classification head.

    The LSTM output is mean-pooled over the time axis and mapped to
    ``output_size`` logits by a fully connected layer (no softmax is applied).

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each LSTM layer's hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of output logits (classes).
    """

    def __init__(self, input_size=10, hidden_size=50, num_layers=2, output_size=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # batch_first=True: the LSTM consumes (batch, seq_len, input_size) tensors
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        # Projects the pooled hidden representation to output_size logits
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return logits of shape (batch, output_size) for x of shape (batch, seq_len, input_size)."""
        # Zero initial hidden/cell states created on x's own device and dtype, so the
        # module works on any device or precision without depending on a global `device`.
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

        # LSTM forward pass: out has shape (batch, seq_len, hidden_size)
        out, _ = self.lstm(x, (h0, c0))

        # Mean-pool over ALL time steps (this is not "last time step" pooling)
        out = out.mean(dim=1)

        # Fully connected head -> logits
        return self.fc(out)

In [4]:
# Helper that turns static feature vectors into sequences for the LSTM
def data_expand(data, time_step=20, method="repeat"):
    """Expand a batch of feature vectors along a new time axis.

    Args:
        data: Tensor of shape (batch_size, ports); a time axis is inserted at dim 1.
        time_step: Length of the generated sequence; must be >= 1.
        method: Expansion strategy:
            - "repeat": every time step is ``data`` itself. Returns an ``expand``
              view that shares memory with ``data`` (no copy is made).
            - "gaussian": ``data`` scaled by a discrete Gaussian window centred on
              the middle step. The window sums to 1, so summing the result over
              time recovers ``data``.
            - "sin": ``sin(omega * t + pi * data)``, where the time axis samples
              exactly one period of the assumed optical wave.

    Returns:
        Tensor of shape (batch_size, time_step, ports).

    Raises:
        ValueError: if ``time_step < 1`` or ``method`` is not recognised.
    """
    if time_step < 1:
        raise ValueError(f"time_step must be >= 1, got {time_step}")

    if method == "repeat":
        return data.unsqueeze(1).expand(-1, time_step, -1)
    elif method == "gaussian":
        t = torch.arange(time_step, device=data.device, dtype=data.dtype)
        if time_step == 1:
            # With one step std_dev would be 0 and the formula below yields 0/0 = NaN;
            # the normalised one-step window is simply [1].
            gaussian_curve = torch.ones_like(t)
        else:
            mean = (time_step - 1) / 2.0
            std_dev = (time_step - 1) / 6.0  # +/- 3 sigma spans the whole window
            gaussian_curve = torch.exp(-((t - mean) ** 2) / (2 * std_dev ** 2))
            gaussian_curve = gaussian_curve / gaussian_curve.sum()  # Shape: (time_step,)
        return gaussian_curve.view(1, -1, 1) * data.unsqueeze(1)  # Shape: (batch_size, time_step, ports)
    elif method == "sin":
        WAVELENGTH = 1  # assumed wavelength
        SPEED_LIGHT = 3e8  # assumed speed of light

        period = WAVELENGTH / SPEED_LIGHT  # period of the wave
        omega = 2 * torch.pi / period      # angular frequency

        # Sample the wave with interval dt = period / time_step so the sequence spans
        # exactly one period. An interval of 1 (s) would make omega * t an exact
        # multiple of 2*pi, so every step would collapse to sin(phase_shift); in
        # float32, omega * t (~1e10) would also lose all precision. The per-step
        # phase increment is therefore computed as a plain float first.
        angular_step = omega * (period / time_step)
        t = torch.arange(time_step, device=data.device, dtype=data.dtype).view(1, time_step, 1)  # Shape: (1, time_step, 1)

        # Maps data in [0, 1] to a phase in [0, pi]
        # NOTE(review): assumes data is normalised to [0, 1] — TODO confirm for PCA features
        phase_shift = (data * torch.pi).unsqueeze(1)  # Shape: (batch_size, 1, ports)

        return torch.sin(angular_step * t + phase_shift)  # Shape: (batch_size, time_step, ports)
    else:
        raise ValueError("Invalid method")

In [5]:
# Initialise the model, loss function and optimiser.
# Seed torch's global RNG so weight initialisation is reproducible under
# Restart & Run All. This presumably also fixes train_loader shuffling if it
# uses the default generator — TODO confirm.
# NOTE(review): cuDNN LSTM kernels on GPU may still be non-deterministic.
torch.manual_seed(0)
# input_size=10 matches the 10 PCA components named in the cache path logged above
model = CustomLSTM(input_size=10, hidden_size=50, num_layers=5, output_size=10).to(device).double()
# The model emits raw logits (no softmax), which is what CrossEntropyLoss expects
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [6]:
# Train the model
epochs = 10
with tqdm(total=epochs * len(train_loader), desc="Training Progress", unit='batch') as pbar:
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Expand each feature vector into a 20-step sequence (same setting as evaluation)
            inputs = data_expand(inputs, time_step=20, method="repeat")

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.update(1)

        # Show the epoch's mean per-batch loss on the progress bar so convergence can be
        # monitored; max(..., 1) guards against an empty loader.
        pbar.set_postfix(epoch=epoch + 1, loss=running_loss / max(len(train_loader), 1))

Training Progress: 100%|██████████| 9380/9380 [03:07<00:00, 49.98batch/s]


In [7]:
# Measure classification accuracy on the held-out test set
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in test_loader:
        labels = labels.to(device)
        # Apply the same 20-step "repeat" expansion that training used
        inputs = data_expand(inputs.to(device), time_step=20, method="repeat")

        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)  # index of the largest logit per sample

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Test Accuracy: 90.00%
