In [8]:
from model.ResNet34 import resnet34_1d
import torch
from util.util import load_data_percentage, SpecDataset, create_data_loaders
import torch.nn as nn 
import torch
import datetime
import os

%load_ext autoreload
%autoreload 2


In [9]:
# Run configuration. `current_time` tags the log file and checkpoints written by
# the training cell so that separate runs never overwrite each other.
SEED = 42
CUDA_INDEX = 3          # GPU to train on; falls back to CPU when it is not present
DATA_PERCENTAGE = 100   # forwarded to load_data_percentage as `percentage`

# Seed torch's RNGs (CPU and all CUDA devices) so weight initialisation and any
# torch-based shuffling/splitting are repeatable across runs.
# NOTE(review): if load_data_percentage / create_data_loaders sample with numpy or
# `random`, seed those libraries here as well.
torch.manual_seed(SEED)

now = datetime.datetime.now()
current_time = now.strftime("%Y_%m_%d_%H_%M_%S")
X, Y, Z = load_data_percentage('./data/X.npy', './data/Y.npy', './data/Z.npy', percentage=DATA_PERCENTAGE)

# torch.device("cuda:N") can be constructed on any machine, but moving tensors to it
# fails when that GPU does not exist -- so check before selecting it.
if torch.cuda.is_available() and torch.cuda.device_count() > CUDA_INDEX:
    cuda_device = torch.device(f"cuda:{CUDA_INDEX}")
else:
    cuda_device = torch.device("cpu")
    print(f"cuda:{CUDA_INDEX} not available; falling back to {cuda_device}")

config = {
    'model_type': 'resnet34',
    'current_time': current_time,
    'epoch': 250,
    'batch_size': 32,
    'cuda_device': cuda_device,
}

dataset = SpecDataset(X, Y, Z, config)

In [10]:
device = config['cuda_device']
num_epochs = config['epoch']


def _evaluate(model, loader, criterion, device):
    """Return the sample-weighted mean of `criterion` over `loader`, without gradients.

    Puts `model` in eval mode. Each batch is unpacked as an (x, _, z) triple; only
    the input `x` and the target `z` are used. Returns NaN if `loader` yields no samples.
    """
    model.eval()
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():
        for x, _, z in loader:
            x, z = x.to(device), z.to(device)
            batch_size = x.size(0)
            # criterion uses reduction='mean' (nn.MSELoss default), so scale by the
            # batch size to get a true per-sample average even when the last batch is short.
            total_loss += criterion(model(x), z).item() * batch_size
            n_samples += batch_size
    return total_loss / n_samples if n_samples else float('nan')


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, num_epochs, log_every=10):
    """Run one optimisation pass over `loader`; return the sample-weighted mean training loss.

    Prints the current batch loss every `log_every` steps. Returns NaN for an empty loader.
    """
    model.train()
    total_loss, n_samples = 0.0, 0
    for batch_idx, (x, _, z) in enumerate(loader):
        x, z = x.to(device), z.to(device)
        optimizer.zero_grad()
        # NOTE(review): nn.MSELoss broadcasts silently if model(x) and z differ in
        # shape (e.g. (B, 1) vs (B,)) -- confirm the two shapes match.
        loss = criterion(model(x), z)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_size = x.size(0)
        total_loss += batch_loss * batch_size
        n_samples += batch_size

        if (batch_idx + 1) % log_every == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(loader)}], Loss: {batch_loss:.4f}')
    return total_loss / n_samples if n_samples else float('nan')


model = resnet34_1d()
model.to(device)

best_val_loss = float('inf')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

os.makedirs('./log', exist_ok=True)
os.makedirs('./weights', exist_ok=True)
criterion = nn.MSELoss()

train_loader, val_loader, test_loader = create_data_loaders(dataset, batch_size=config['batch_size'])

# All artefacts of this run share the timestamp + model-type prefix.
run_prefix = f"{config['current_time']}_{config['model_type']}"
best_model_name = f"./weights/{run_prefix}_best.pth"
last_model_name = f"./weights/{run_prefix}_last.pth"
saved_best = False  # stays False if validation loss never improves (e.g. NaN losses)

with open(f"./log/{run_prefix}_{config['epoch']}.txt", "a") as log_file:
    for epoch in range(num_epochs):
        avg_train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, num_epochs)
        avg_val_loss = _evaluate(model, val_loader, criterion, device)
        # Test loss is logged for monitoring only; checkpoint selection uses validation loss.
        avg_test_loss = _evaluate(model, test_loader, criterion, device)

        epoch_summary = (f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, '
                         f'Val Loss: {avg_val_loss:.4f}, Test Loss: {avg_test_loss:.4f}')
        print(epoch_summary)
        # Flush every epoch so the log survives a crash or kernel interrupt mid-run.
        log_file.write(epoch_summary + '\n')
        log_file.flush()

        # Checkpoint whenever validation loss improves.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_name)
            saved_best = True
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")

print("Training finished!")

# Save the last-epoch weights; `model` keeps holding them after this cell.
torch.save(model.state_dict(), last_model_name)
print("Saved last model")

# Final evaluation of the last-epoch model.
avg_final_test_loss = _evaluate(model, test_loader, criterion, device)
print(f"Final Test Loss (last model): {avg_final_test_loss:.4f}")

# Also evaluate the best-validation checkpoint, which is the model selection result
# that should normally be reported. Loaded into a separate instance so `model` is unchanged.
if saved_best:
    best_model = resnet34_1d()
    best_model.load_state_dict(torch.load(best_model_name, map_location=device))
    best_model.to(device)
    avg_best_test_loss = _evaluate(best_model, test_loader, criterion, device)
    print(f"Final Test Loss (best model, val loss {best_val_loss:.4f}): {avg_best_test_loss:.4f}")
else:
    print("No best checkpoint was saved (validation loss never improved).")

Epoch [1/1], Step [10/405], Loss: 15688.6846
Epoch [1/1], Step [20/405], Loss: 23526.7324
Epoch [1/1], Step [30/405], Loss: 19239.1875
Epoch [1/1], Step [40/405], Loss: 21241.3105
Epoch [1/1], Step [50/405], Loss: 25979.8887
Epoch [1/1], Step [60/405], Loss: 25614.5332
Epoch [1/1], Step [70/405], Loss: 18301.7500
Epoch [1/1], Step [80/405], Loss: 18509.6777
Epoch [1/1], Step [90/405], Loss: 21134.0742
Epoch [1/1], Step [100/405], Loss: 16733.2031
Epoch [1/1], Step [110/405], Loss: 13843.2617
Epoch [1/1], Step [120/405], Loss: 22475.9727
Epoch [1/1], Step [130/405], Loss: 25172.7988
Epoch [1/1], Step [140/405], Loss: 16109.7803
Epoch [1/1], Step [150/405], Loss: 21423.6621
Epoch [1/1], Step [160/405], Loss: 18699.7734
Epoch [1/1], Step [170/405], Loss: 18864.5664
Epoch [1/1], Step [180/405], Loss: 18897.2051
Epoch [1/1], Step [190/405], Loss: 19016.1523
Epoch [1/1], Step [200/405], Loss: 26219.0762
Epoch [1/1], Step [210/405], Loss: 19252.7227
Epoch [1/1], Step [220/405], Loss: 20864.59