In [1]:
import os
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from dataset import load_core_set_data, load_pca_data
from simulation.simulator import FDTDSimulator
from simulation.student import LSTMPredictor

# Load the dataset

In [2]:
class FDTDDataset(Dataset):
    """In-memory map-style dataset pairing FDTD inputs with their labels.

    Both inputs are converted once to ``float32`` tensors at construction,
    so ``__getitem__`` is a cheap index into pre-built tensors.

    Args:
        data: Array-like of samples, indexed along the first axis.
        labels: Array-like of labels, indexed along the first axis; must have
            the same number of entries as ``data``.

    Raises:
        ValueError: If ``data`` and ``labels`` have different lengths.
    """

    def __init__(self, data, labels):
        # torch.as_tensor replaces the legacy torch.FloatTensor constructor,
        # which silently allocates an *uninitialised* tensor when given an int.
        self.data = torch.as_tensor(data, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.float32)
        # Fail fast here instead of with an IndexError deep inside a DataLoader.
        if len(self.data) != len(self.labels):
            raise ValueError(
                "data and labels must have the same length, "
                f"got {len(self.data)} and {len(self.labels)}"
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at position ``idx``."""
        return self.data[idx], self.labels[idx]


# Seed for the training-set shuffle, so the sample order (and therefore the
# training trajectory) is reproducible across kernel restarts.
DATA_SEED = 0
# One sample per optimiser step, as in the original run.
BATCH_SIZE = 1

train_data, train_labels, test_data, test_labels = load_core_set_data()

train_dataset = FDTDDataset(train_data, train_labels)
test_dataset = FDTDDataset(test_data, test_labels)
# An empty split would otherwise surface later as a division by zero
# in the per-epoch average loss.
assert len(train_dataset) > 0, "load_core_set_data() returned an empty training split"

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(DATA_SEED),
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [3]:
# Random radius matrix with entries uniformly drawn from [0, MAX_RADIUS).
# A dedicated seeded generator makes the simulator geometry reproducible on
# Restart & Run All without touching the global RNG used for model init.
# NOTE(review): the physical units of these radii are not visible here;
# confirm against FDTDSimulator.
RADIUS_SEED = 0
RADIUS_MATRIX_SHAPE = (10, 10)
MAX_RADIUS = 10

radius_matrix = torch.rand(
    RADIUS_MATRIX_SHAPE, generator=torch.Generator().manual_seed(RADIUS_SEED)
) * MAX_RADIUS

radius_matrix

tensor([[9.8255, 7.3699, 2.8815, 1.5214, 6.0255, 8.2805, 4.9387, 7.2811, 2.4591,
         3.9089],
        [2.0647, 2.0806, 6.3558, 2.2175, 8.8540, 2.2342, 6.0093, 8.0540, 1.7990,
         0.4360],
        [8.0274, 4.3012, 2.9524, 9.2272, 8.8699, 8.8332, 9.2471, 5.4805, 2.6575,
         6.1031],
        [9.7515, 8.3529, 1.8866, 3.3360, 2.4874, 2.4479, 2.3108, 1.9156, 4.2194,
         7.2804],
        [2.8730, 0.7302, 2.6532, 6.2523, 7.1793, 7.8691, 4.6819, 4.6316, 2.2732,
         8.7599],
        [5.4082, 3.4759, 3.0826, 2.8230, 6.7669, 2.7832, 7.1524, 9.9375, 8.7169,
         4.4255],
        [3.3696, 9.6699, 0.1420, 0.8477, 7.8621, 7.3240, 7.9770, 4.6472, 8.7734,
         0.1351],
        [4.6822, 4.4650, 5.7432, 3.8443, 6.9914, 6.0555, 1.5298, 9.7675, 5.2538,
         9.0771],
        [6.5631, 5.8581, 9.5429, 6.5055, 7.0407, 0.5696, 5.2386, 3.3710, 6.9758,
         9.2995],
        [9.5105, 2.2587, 4.7564, 8.3978, 5.8785, 7.3033, 5.8868, 6.6483, 4.4559,
         1.6768]])


In [4]:
# Training hyperparameters, read by the training-loop cell.
num_epochs = 100
learning_rate = 1e-3

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Student LSTM. Its sequence output is regressed onto the simulator's output
# in the training loop.
model = LSTMPredictor(input_size=10, hidden_size=128, output_size=10,
                      num_layers=2, dropout=0.1, device=device)

# Simulator built from the radius matrix defined in the previous cell;
# it supplies the regression targets.
simulator = FDTDSimulator(radius_matrix=radius_matrix)

# Mean-squared error between predicted and simulated sequences, minimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [5]:
# Train the LSTM (student) to reproduce the simulator's output sequence with MSE.
# The dataset labels are ignored; targets come from the simulator.
# NOTE(review): the logged run took ~10 s per step (~80 h projected for
# 100 epochs), and the simulator is re-run for every input on every epoch.
# If FDTDSimulator is deterministic, computing the targets once up front may
# cut this substantially — confirm before changing.
model.train()

total_steps = num_epochs * len(train_loader)

# The context manager closes the bar even when the loop is interrupted
# (e.g. KeyboardInterrupt), so later tqdm output is not garbled.
with tqdm(total=total_steps, desc="Training Progress") as progress_bar:
    for epoch in range(num_epochs):
        epoch_loss = 0.0

        for inputs, _ in train_loader:
            inputs = inputs.to(device)
            optimizer.zero_grad()

            # Targets need no autograd graph.
            with torch.no_grad():
                target_seq = simulator(inputs).detach()

            _, pred_seq = model.get_sequence_output(inputs)
            loss = criterion(pred_seq, target_seq)
            loss.backward()
            optimizer.step()

            # .item() forces a device sync, so call it once per step.
            step_loss = loss.item()
            epoch_loss += step_loss

            # Update the progress bar. set_postfix keeps the "Training Progress"
            # label instead of overwriting it with the loss.
            progress_bar.update(1)
            progress_bar.set_postfix(loss=f"{step_loss:.6f}")

        avg_loss = epoch_loss / max(len(train_loader), 1)
        tqdm.write(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.6f}")

Loss: 0.049515:   1%|          | 300/30000 [47:53<82:22:25,  9.98s/it] 

Epoch 1/100, Average Loss: 0.040788


Loss: 0.036160:   1%|▏         | 438/30000 [1:09:20<78:23:42,  9.55s/it]

KeyboardInterrupt: 

In [None]:
# Save the trained weights together with the radius matrix the simulator was
# built from. The optimizer state is included so training can be resumed from
# this checkpoint (restore both state dicts when loading).
save_dir = "data/model/lstm_model.pth"  # full file path, not a directory
# dirname() returns "" for a bare filename, and os.makedirs("") raises.
os.makedirs(os.path.dirname(save_dir) or ".", exist_ok=True)

torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "radius_matrix": radius_matrix,
}, save_dir)