In [1]:
import torch
import time
import os
from tqdm import tqdm
import numpy as np
from src.dataset import core_data_loader
from src.model import DistillModel, StudentSequenceModel

In [2]:
# 设置缓存路径
dir_path = os.getcwd()
cache_path = os.path.join(dir_path, ".cache")
os.makedirs(cache_path, exist_ok=True)

print(f"current working dir: {dir_path}")
print(f"cache path: {cache_path}")

current working dir: /root/pyDistilledFDTD
cache path: /root/pyDistilledFDTD/.cache


In [3]:
# 加载数据
train_loader, test_loader = core_data_loader(eta=0.01, batch_size=1)

print(f"train data size: {len(train_loader)}")
print(f"test data size: {len(test_loader)}")

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Loading core data from /root/pyDistilledFDTD/dataset/.cache/core/greedy/pca-components-10/batch-size-1-eta-0.01.pkl
train data size: 600
test data size: 10000


In [4]:
# 加载模型、优化器
radius_matrix = np.random.rand(10, 10)
print(radius_matrix)
model = DistillModel(radius_matrix, StudentSequenceModel).to(device)
optimizer = torch.optim.Adam(model.student_model.parameters(), lr=0.01)

[[0.0841148  0.31256194 0.23945658 0.13212715 0.07424489 0.9855788
  0.85097928 0.51481715 0.12865044 0.56303219]
 [0.43497856 0.82494826 0.86788573 0.10375143 0.09723539 0.18044094
  0.09320979 0.92111445 0.36111127 0.65310328]
 [0.82586661 0.96721154 0.41039585 0.35956671 0.27274564 0.66721724
  0.99516399 0.93233188 0.29265336 0.81817654]
 [0.80969515 0.15622216 0.54194143 0.49552312 0.65293209 0.34732713
  0.47953365 0.03801998 0.86377679 0.17622188]
 [0.74031416 0.35443611 0.98661807 0.94338242 0.20615526 0.45194494
  0.00926944 0.816122   0.39067166 0.12424215]
 [0.82505376 0.02891013 0.8704457  0.31029998 0.73462796 0.82354464
  0.07291546 0.15312064 0.46156672 0.01916062]
 [0.38615059 0.73439155 0.76119317 0.29808613 0.15549912 0.10429426
  0.35488086 0.48662795 0.17148327 0.46923418]
 [0.70432705 0.59026092 0.31907682 0.20005936 0.20742479 0.37161586
  0.69760461 0.1079499  0.72966881 0.45409293]
 [0.87408946 0.10738391 0.03695502 0.80038825 0.29943694 0.1779864
  0.04275086 0

In [5]:
# 训练student模型
epochs = 10
with tqdm(total=epochs * len(train_loader)) as pbar:
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        with torch.enable_grad():
            for inputs, _ in train_loader:
                inputs = inputs.to(device)
                
                loss = model(inputs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()
                pbar.update(1)
                pbar.set_postfix(loss=running_loss)
        print(f"epoch: {epoch}, loss: {running_loss}")

 10%|█         | 602/6000 [3:19:31<16:37:18, 11.09s/it, loss=1.46e-7]

epoch: 0, loss: 0.02212213845563494


 20%|██        | 1203/6000 [3:20:12<05:21, 14.93it/s, loss=2.09e-7]  

epoch: 1, loss: 6.565220685679181e-05


 30%|███       | 1802/6000 [3:20:52<04:47, 14.59it/s, loss=2.77e-7]

epoch: 2, loss: 7.010195126013484e-05


 40%|████      | 2403/6000 [3:21:33<03:58, 15.08it/s, loss=3.26e-7] 

epoch: 3, loss: 0.00010183143679172813


 50%|█████     | 3002/6000 [3:22:14<03:27, 14.44it/s, loss=4.53e-7] 

epoch: 4, loss: 0.0001086599355312437


 60%|██████    | 3603/6000 [3:22:55<02:42, 14.78it/s, loss=7.11e-7] 

epoch: 5, loss: 0.00011136552088588767


 70%|███████   | 4203/6000 [3:23:36<02:01, 14.79it/s, loss=7.06e-7] 

epoch: 6, loss: 0.00012844397995794043


 80%|████████  | 4802/6000 [3:24:17<01:27, 13.77it/s, loss=1.04e-6] 

epoch: 7, loss: 0.00012277293219905202


 90%|█████████ | 5402/6000 [3:24:58<00:40, 14.87it/s, loss=1.69e-7] 

epoch: 8, loss: 0.00011635792250668723


100%|██████████| 6000/6000 [3:25:38<00:00,  2.06s/it, loss=0.000113]

epoch: 9, loss: 0.00011258437696693945





In [6]:
# 保存模型

save_path = os.path.join(cache_path, f"student_model_{time.strftime('%Y-%m-%d-%H-%M-%S')}.pth")
torch.save({
    'radius_matrix': radius_matrix,
    'student_model': model.student_model.state_dict(),
}, save_path)

print(f"Saved to : {save_path}")

Saved to : /root/pyDistilledFDTD/.cache/student_model_2024-10-08-18-04-11.pth


In [8]:
# 测试模型
model.eval()
model.set_simulation_mode(fdtd=True, lstm=True)
criterion = torch.nn.MSELoss()
total_loss = 0.0
total = 0.0
with torch.no_grad():
    for inputs, _ in test_loader:
        inputs = inputs.to(device)
        fdtd_output, lstm_output = model(inputs)
        loss = criterion(fdtd_output, lstm_output)
        total_loss += loss.item()
        total += fdtd_output.sum().item()
        print(f"loss: {loss.item()}")
print(f"average loss: {total_loss / total}")

loss: 0.12154486958248022
loss: 0.11878986641578358
loss: 0.12319986217161474
loss: 0.1133796394792897
loss: 0.11723643475887094
loss: 0.12315040310169963
loss: 0.12444731747123434
loss: 0.11585749357702375
loss: 0.12100962278551859
loss: 0.11612048275957143
loss: 0.1266953444100639
loss: 0.12054311755218068
loss: 0.11533282352613448
loss: 0.12235594493172851
loss: 0.12146662343681879
loss: 0.12774961703148144
loss: 0.12016510502750762
loss: 0.1200572242207406
loss: 0.11896133818759394
loss: 0.11551049090727816
loss: 0.11781998277548039
loss: 0.11329768501671966
loss: 0.11815767659167593
loss: 0.1149401695309988
loss: 0.12065743503152232
loss: 0.11500962190833582
loss: 0.12122953550462141
loss: 0.11249078668172977
loss: 0.12254717930094108
loss: 0.12346451278118581
loss: 0.12407964753889625
loss: 0.12422780364958537
loss: 0.12326301103203266
loss: 0.11862041000356956
loss: 0.1194855605359813
loss: 0.11664039042756112
loss: 0.12281664913341142
loss: 0.12268863417628395
loss: 0.125587742