In [1]:
import torch
import time
import os
from tqdm import tqdm
import numpy as np
from src.dataset import core_data_loader
from src.model import DistillModel, StudentLSTMModule
import config as cfg

In [2]:
# Build the train/test loaders with the project's core_data_loader helper;
# eta and batch_size are passed through exactly as given here.
loaders = core_data_loader(eta=0.01, batch_size=1)
train_loader, test_loader = loaders

# Report the length of each loader (label first, then len()).
for label, loader in (('train_loader', train_loader), ('test_loader', test_loader)):
    print(f'{label}: ', len(loader))

Loading core data from F:\pyDistilledFDTD\data\cache\core_data_loader\greedy\pca-components-10\batch-size-1-eta-0.01.pkl
train_loader:  600
test_loader:  10000


In [3]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Seed NumPy and torch so that the sampled radius matrix (and any torch-based
# weight initialisation inside the student model) is repeatable after a
# "Restart Kernel -> Run All".
SEED = 42
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# Sample a 10x10 matrix uniformly from [0, 1) and zero out every entry below
# 0.3. np.where builds a new array instead of mutating the sample in place.
radius_sampled = rng.random((10, 10))
radius_matrix = np.where(radius_sampled < 0.3, 0.0, radius_sampled)
print('radius_matrix: ', radius_matrix.flatten())

# Wrap the LSTM student in the distillation model and move it to the
# selected device.
student_model = StudentLSTMModule(input_size=10, hidden_size=128, output_size=10, num_layer=2)
model = DistillModel(radius_matrix, student_model, expand_method='sin')
model = model.to(device)

# Only the student's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(model.student_model.parameters(), lr=0.001)

radius_matrix:  [0.82165994 0.         0.6864635  0.81904375 0.67695378 0.38863517
 0.84291541 0.         0.62167543 0.41282549 0.98717967 0.84503808
 0.81256787 0.         0.64648234 0.71247146 0.53465788 0.
 0.81691363 0.86312937 0.81148358 0.         0.51422165 0.58486831
 0.36441316 0.94109372 0.         0.67398161 0.78479876 0.39157251
 0.91648447 0.         0.76754852 0.45025886 0.59141299 0.54177986
 0.         0.48506218 0.89734392 0.         0.32146157 0.45711014
 0.44443593 0.65867254 0.41297662 0.         0.         0.42931176
 0.55526522 0.77691304 0.81009626 0.42011165 0.45555794 0.
 0.         0.         0.47847764 0.82664114 0.         0.96371125
 0.90487697 0.         0.52596768 0.         0.47302456 0.50633056
 0.41210205 0.79473886 0.         0.32727209 0.         0.44322516
 0.         0.79650177 0.         0.73142252 0.         0.80810544
 0.53252688 0.7204535  0.         0.50184238 0.4591588  0.66266283
 0.73488166 0.41460809 0.50734342 0.88801483 0.         0.
 0.

In [None]:
# Train for a fixed number of epochs, with one progress-bar tick per batch.
epochs = 100
# Summed loss of the previous epoch, shown in the progress bar as 'loss'
# (stays 0.0 until the first epoch has finished).
last_loss = 0.0
with tqdm(total=epochs * len(train_loader)) as pbar:
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        with torch.enable_grad():
            # The second element of each batch is not used here.
            for inputs, _ in train_loader:
                inputs = inputs.to(device)
                optimizer.zero_grad()
                # The model's forward pass returns the loss to minimise.
                loss = model(inputs)
                loss.backward()
                optimizer.step()

                # Read the scalar once per step; each .item() call is a
                # device-to-host transfer, so reuse the Python float below.
                step_loss = loss.item()
                running_loss += step_loss
                pbar.update(1)
                pbar.set_postfix({'data_loss': step_loss, 'loss': last_loss})
        last_loss = running_loss

  0%|          | 0/60000 [00:00<?, ?it/s]

In [None]:
# Create the output directory first so a missing folder cannot make the save
# fail and discard the trained weights.
os.makedirs(cfg.PROCESSED_DATA_DIR, exist_ok=True)

# Timestamped file name, so repeated runs do not overwrite each other.
save_path = os.path.join(cfg.PROCESSED_DATA_DIR, f"param-{time.strftime('%Y-%m-%d-%H-%M-%S')}.pth")

# Store the radius matrix together with the student's weights.
torch.save({
    'radius': radius_matrix,
    'model_state_dict': model.student_model.state_dict()
}, save_path)