In [1]:
from model.ResNet_attention import ResidualAttentionModel_92_32input_update
import torch
from util.util import load_data_percentage, SpecDataset, create_data_loaders
import torch.nn as nn 
import torch
import datetime
import os

%load_ext autoreload
%autoreload 2

In [2]:
# Timestamp used to tag this run's log file and checkpoint file names.
now = datetime.datetime.now()
current_time = now.strftime("%Y_%m_%d_%H_%M_%S")

# Seed torch's global RNG so weight initialisation (and any torch-based
# shuffling) is repeatable across "Restart & Run All".
# NOTE(review): the project helpers may also draw from other RNGs (e.g. numpy);
# that is not visible from this notebook -- confirm and seed those too if so.
SEED = 42
torch.manual_seed(SEED)

# Load the three arrays from disk; percentage=100 requests the full data set.
X, Y, Z = load_data_percentage('./data/X.npy', './data/Y.npy', './data/Z.npy', percentage=100)

# The run was developed on GPU index 5. Asking for an index the machine does not
# have only fails later (when tensors are moved to the device), so resolve a
# usable device up front: preferred GPU -> first GPU -> CPU.
PREFERRED_GPU_INDEX = 5
if torch.cuda.is_available():
    if PREFERRED_GPU_INDEX < torch.cuda.device_count():
        cuda_device = torch.device(f"cuda:{PREFERRED_GPU_INDEX}")
    else:
        cuda_device = torch.device("cuda:0")
        print(f"cuda:{PREFERRED_GPU_INDEX} is not available; falling back to {cuda_device}")
else:
    cuda_device = torch.device("cpu")
    print(f"CUDA is not available; falling back to {cuda_device}")

# Everything a reader might want to tune for a run lives in this dict.
config = {
    'model_type': 'resnet_attention',  # used in log / checkpoint file names
    'current_time': current_time,      # used in log / checkpoint file names
    'epoch': 250,                      # number of training epochs
    'batch_size': 32,                  # forwarded to create_data_loaders
    'cuda_device': cuda_device         # device the model and batches are moved to
}

# Wrap the arrays in the project's dataset class.
# NOTE(review): how SpecDataset uses `config` is not visible here -- confirm.
dataset = SpecDataset(X, Y, Z, config)

batch size is 10, dim is 2, length is 32768


In [3]:
# Pull the run settings chosen in the configuration cell.
num_epochs = config['epoch']
device = config['cuda_device']

# Build the network and place it on the target device before the optimiser
# is created from its parameters.
model = ResidualAttentionModel_92_32input_update()
model.to(device)

# Regression objective and optimiser.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Lowest validation loss seen so far; start at +inf so the first epoch always wins.
best_val_loss = float('inf')

# Make sure the output folders for the epoch log and the checkpoints exist.
for out_dir in ('./log', './weights'):
    os.makedirs(out_dir, exist_ok=True)

# Train / validation / test loaders over the dataset built earlier.
train_loader, val_loader, test_loader = create_data_loaders(
    dataset,
    batch_size=config['batch_size'],
)

def evaluate_loss(model, loader, criterion, device) -> float:
    """Return the mean per-batch loss of ``model`` over ``loader``.

    The model is switched to eval mode and gradients are disabled. Each batch
    is expected to unpack as ``(x, _, z)``: ``x`` is fed to the model and ``z``
    is the target passed to ``criterion``; the middle element is ignored.

    Returns NaN when ``loader`` yields no batches, instead of raising
    ZeroDivisionError. NaN never compares as "less than" the current best, so
    an empty validation split cannot trigger a checkpoint save.
    """
    model.eval()
    num_batches = len(loader)
    if num_batches == 0:
        return float('nan')

    total_loss = 0.0
    with torch.no_grad():
        for x, _, z in loader:
            x, z = x.to(device), z.to(device)
            total_loss += criterion(model(x), z).item()
    return total_loss / num_batches


def train_one_epoch(model, loader, criterion, optimizer, device, epoch, num_epochs, log_every=10) -> float:
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss.

    ``epoch`` (0-based) and ``num_epochs`` are only used for the progress line
    that is printed every ``log_every`` batches. Returns NaN when ``loader``
    yields no batches.
    """
    model.train()
    num_batches = len(loader)
    running_loss = 0.0
    for batch_idx, (x, _, z) in enumerate(loader):
        x, z = x.to(device), z.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), z)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        if (batch_idx + 1) % log_every == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{num_batches}], Loss: {batch_loss:.4f}')

    return running_loss / num_batches if num_batches else float('nan')


# File names are fixed for the whole run, so build them once.
log_path = f"./log/{config['current_time']}_{config['model_type']}_{config['epoch']}.txt"
best_model_name = f"./weights/{config['current_time']}_{config['model_type']}_best.pth"

with open(log_path, "a") as log_file:
    for epoch in range(num_epochs):
        avg_train_loss = train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, num_epochs
        )
        avg_val_loss = evaluate_loss(model, val_loader, criterion, device)
        # Test loss is reported for monitoring only; checkpoint selection below
        # looks at the validation loss alone.
        avg_test_loss = evaluate_loss(model, test_loader, criterion, device)

        epoch_summary = f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Test Loss: {avg_test_loss:.4f}'
        print(epoch_summary)

        # Log the losses of each epoch to the file; flush so the log survives
        # an interrupted run.
        log_file.write(epoch_summary + '\n')
        log_file.flush()

        # Save the best model (lowest validation loss so far).
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_name)
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")

print("Training finished!")

# Save the last model (weights after the final epoch, regardless of validation loss).
last_model_name = f"./weights/{config['current_time']}_{config['model_type']}_last.pth"
torch.save(model.state_dict(), last_model_name)
print("Saved last model")

# Final evaluation of the in-memory (last-epoch) weights on the test split.
# NOTE(review): this is not the best-validation checkpoint saved above; load
# that file first if the reported number should correspond to it.
avg_final_test_loss = evaluate_loss(model, test_loader, criterion, device)
print(f"Final Test Loss: {avg_final_test_loss:.4f}")

Epoch [1/250], Train Loss: 8907.9531, Val Loss: 2507.2625, Test Loss: 17007.9609
Saved best model with validation loss: 2507.2625
Epoch [2/250], Train Loss: 8892.9404, Val Loss: 2504.6936, Test Loss: 17001.3516
Saved best model with validation loss: 2504.6936
Epoch [3/250], Train Loss: 8878.9688, Val Loss: 2502.4429, Test Loss: 16996.3945
Saved best model with validation loss: 2502.4429
Epoch [4/250], Train Loss: 8865.7236, Val Loss: 2500.8550, Test Loss: 16993.6953
Saved best model with validation loss: 2500.8550
Epoch [5/250], Train Loss: 8849.0859, Val Loss: 2500.0754, Test Loss: 16993.7227
Saved best model with validation loss: 2500.0754
Epoch [6/250], Train Loss: 8829.0957, Val Loss: 2500.1338, Test Loss: 16996.5645
Epoch [7/250], Train Loss: 8809.8945, Val Loss: 2500.8984, Test Loss: 17001.8301
Epoch [8/250], Train Loss: 8788.7256, Val Loss: 2502.3777, Test Loss: 17009.6465
Epoch [9/250], Train Loss: 8767.0381, Val Loss: 2504.6909, Test Loss: 17020.1055
Epoch [10/250], Train Loss