In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from dataset import load_core_set_data, load_pca_data
from simulation.simulator import FDTDSimulator
from simulation.student import LSTMPredictor

# Load the dataset

In [None]:
class FDTDDataset(Dataset):
    """In-memory map-style dataset pairing each input sample with its label.

    Args:
        data: array-like of input samples (list, numpy array or tensor);
            stored as a float32 tensor.
        labels: array-like of labels aligned with ``data`` along the first
            axis; stored as a float32 tensor.

    Raises:
        ValueError: if ``data`` and ``labels`` have different lengths.
    """

    def __init__(self, data, labels):
        # as_tensor(...).clone() always yields a float32 copy, so later changes to the
        # caller's array never leak into the dataset (replaces legacy torch.FloatTensor).
        self.data = torch.as_tensor(data, dtype=torch.float32).clone()
        self.labels = torch.as_tensor(labels, dtype=torch.float32).clone()
        # Catch misaligned inputs here instead of as an IndexError deep inside a DataLoader.
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(self.data)} and {len(self.labels)}"
            )

    def __len__(self):
        """Number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at position ``idx``."""
        return self.data[idx], self.labels[idx]


# Batch size shared by the train and test loaders.
BATCH_SIZE = 32
# Seed for the train-loader shuffle so the per-epoch sample order is reproducible.
SHUFFLE_SEED = 0

# Project-local loader; its result is unpacked into train/test inputs and labels.
train_data, train_labels, test_data, test_labels = load_core_set_data()

train_dataset = FDTDDataset(train_data, train_labels)
test_dataset = FDTDDataset(test_data, test_labels)
# Fail early: with an empty training set the training loop would have nothing to iterate.
assert len(train_dataset) > 0, "training set is empty"

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # Dedicated generator: deterministic shuffling without touching the global torch RNG.
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Seed for the random radius matrix. The matrix configures the simulator, whose outputs are
# the training targets, so fixing it makes the whole run reproducible.
RADIUS_SEED = 0
# 10x10 matrix sampled uniformly from [0, 1); a dedicated generator leaves the global torch RNG untouched.
radius_matrix = torch.rand(10, 10, generator=torch.Generator().manual_seed(RADIUS_SEED))

In [None]:
# Train on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Student network. input_size=10 presumably matches the feature dimension of the loaded
# data — TODO confirm against load_core_set_data().
# LSTMPredictor receives `device`, but whether it moves its own parameters is not visible
# here; the explicit .to(device) guarantees the parameters sit on the same device as the
# inputs the training loop sends there (it is a no-op if they already do).
model = LSTMPredictor(
    input_size=10,
    hidden_size=64,
    num_layers=2,
    dropout=0.1,
    device=device,
).to(device)

# Simulator whose outputs serve as the training targets.
# NOTE(review): radius_matrix is a CPU tensor; if FDTDSimulator does not handle device
# placement itself, CUDA inputs may need extra handling — TODO confirm.
simulator = FDTDSimulator(radius_matrix=radius_matrix)

# Optimisation settings.
num_epochs = 100
learning_rate = 1e-3

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Mean-squared error between the predicted sequence and the simulator's sequence.
criterion = nn.MSELoss()

In [None]:
# Training loop. Dataset labels are ignored here: for each batch the target is the
# simulator's output on the same inputs, and the model is trained to reproduce it.
model.train()  # training mode (matters for dropout, if the model uses it), even when re-run
for epoch in range(num_epochs):
    epoch_loss_sum = 0.0
    epoch_samples = 0
    for inputs, _ in train_loader:
        inputs = inputs.to(device)
        optimizer.zero_grad()

        _, pred_seq = model.get_sequence_output(inputs)

        # Targets come from the simulator; no gradient is needed through it.
        with torch.no_grad():
            target_seq = simulator(inputs)
        # Align the target with the prediction's device (no-op if already there).
        target_seq = target_seq.to(pred_seq.device)

        loss = criterion(pred_seq, target_seq)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        epoch_loss_sum += loss.item() * batch_size
        epoch_samples += batch_size

    # Report the sample-weighted mean over the whole epoch rather than the last batch's loss;
    # max(..., 1) avoids a division by zero if the loader yields nothing.
    epoch_loss = epoch_loss_sum / max(epoch_samples, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

In [None]:
# Save the trained weights and the radius_matrix the simulator was built from in one
# checkpoint file. Despite its name, `save_dir` is the checkpoint *file* path.
save_dir = "data/model/lstm_model.pth"
checkpoint_folder = os.path.dirname(save_dir)
os.makedirs(checkpoint_folder, exist_ok=True)

checkpoint = {
    "model_state_dict": model.state_dict(),
    "radius_matrix": radius_matrix,
}
torch.save(checkpoint, save_dir)