In [1]:
import torch
import time
import os
from tqdm import tqdm
import numpy as np
from dataset import core_data_loader
from model import DistillModel, StudentSequenceModel

In [2]:
# 设置缓存路径
dir_path = os.getcwd()
cache_path = os.path.join(dir_path, ".cache")
os.makedirs(cache_path, exist_ok=True)

print(f"current working dir: {dir_path}")
print(f"cache path: {cache_path}")

current working dir: /root/workspace
cache path: /root/workspace/.cache


In [3]:
# 加载数据
train_loader, test_loader = core_data_loader(eta=0.01, batch_size=1)

print(f"train data size: {len(train_loader)}")
print(f"test data size: {len(test_loader)}")

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Loading core data from /root/workspace/dataset/.cache/core/greedy/pca-components-10/batch-size-1-eta-0.01.pkl
train data size: 600


In [4]:
# 加载模型、优化器
radius_matrix = np.random.rand(10, 10)
print(radius_matrix)
model = DistillModel(radius_matrix, StudentSequenceModel).to(device)
optimizer = torch.optim.Adam(model.student_model.parameters(), lr=0.01)

[[0.74144352 0.93312757 0.39944224 0.44104345 0.81284844 0.64433732
  0.01969037 0.41854158 0.83051604 0.36676751]
 [0.66537331 0.42724964 0.79199806 0.16269146 0.09730517 0.12815412
  0.21394893 0.53558227 0.31730907 0.13952898]
 [0.59946679 0.70686517 0.52314236 0.35449604 0.53243085 0.78531035
  0.27728216 0.32957445 0.02195541 0.00550055]
 [0.23811725 0.55398471 0.8808384  0.3740212  0.61429861 0.83877077
  0.73148468 0.14524496 0.55022673 0.44976414]
 [0.47046916 0.63228118 0.73088256 0.16831211 0.47753573 0.82318034
  0.49314063 0.99890957 0.31838699 0.70840112]
 [0.44344174 0.64846021 0.50238308 0.1285439  0.50276071 0.51782969
  0.51579816 0.12860355 0.79742273 0.33012636]
 [0.35129276 0.69827958 0.02250548 0.10200456 0.48717925 0.15999549
  0.42714782 0.18456714 0.74224733 0.85790775]
 [0.97355415 0.82533421 0.21060241 0.49877147 0.45632135 0.46725178
  0.27105936 0.39653894 0.41981504 0.19208537]
 [0.29022756 0.11241783 0.95865382 0.80947836 0.89260983 0.78821077
  0.17257296

In [1]:
# 训练student模型
epochs = 10
with tqdm(total=epochs * len(train_loader)) as pbar:
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        with torch.enable_grad():
            for inputs, _ in train_loader:
                inputs = inputs.to(device)
                
                loss = model(inputs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()
                pbar.update(1)
                pbar.set_postfix(loss=running_loss)
        print(f"epoch: {epoch}, loss: {running_loss}")

NameError: name 'tqdm' is not defined

In [None]:
# 保存模型

save_path = os.path.join(cache_path, f"student_model_{time.strftime('%Y-%m-%d-%H-%M-%S')}.pth")
torch.save({
    'radius_matrix': radius_matrix,
    'student_model': model.student_model.state_dict(),
}, save_path)

print(f"Saved to : {save_path}")

In [None]:
# 测试模型
model.eval()
model.set_simulation_mode(fdtd=True, lstm=True)
criterion = torch.nn.MSELoss()
total_loss = 0.0
total = 0.0
with torch.no_grad():
    for inputs, _ in test_loader:
        inputs = inputs.to(device)
        fdtd_output, lstm_output = model(inputs)
        loss = criterion(fdtd_output, lstm_output)
        total_loss += loss.item()
        total += fdtd_output.sum(dim=0).item()
        print(f"loss: {loss.item()}")
print(f"average loss: {total_loss / total}")