In [1]:
import torch
import pickle
import random
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch_geometric.nn as gnn

from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data

In [2]:
from net import PSOGNN
from pso import PSO 
from base_function import *

In [3]:
# Pickled dataset holding the pre-split benchmark-function descriptions.
# NOTE(review): absolute, machine-specific path; adjust when running elsewhere.
splited_path = r'A:\Code\deepso\splited_data_no.pkl'

# NOTE(review): pickle.load can execute arbitrary code, so only open files
# that come from a trusted source.
with open(splited_path, 'rb') as dataset_file:
    dataset = pickle.load(dataset_file)


In [4]:
# Pick the training and test splits out of the loaded dataset; the validation
# split is currently unused.
train_set, test_set = dataset['train'], dataset['test']
# val_set = dataset['validation']

# Show one entry so the record layout (dim / func_type / params) is visible.
print(train_set[0])


{'dim': 1, 'func_type': 'ackley', 'params': [16.98135401917914, 0.28507902333427554, 5.580495503135101]}


In [5]:
def create_batches(dataset, batch_size, shuffle=True):
    """Yield consecutive batches drawn from ``dataset``.

    The input sequence is copied before shuffling, so the caller's list keeps
    its original order (the previous version shuffled it in place).

    Args:
        dataset: Sequence (or any iterable) of items to batch.
        batch_size: Maximum number of items per batch; must be >= 1.
        shuffle: When True (default) the items are randomly permuted with
            ``random.shuffle`` before batching.

    Yields:
        Lists of at most ``batch_size`` items; only the final batch may be
        shorter. Nothing is yielded for an empty dataset.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Work on a private copy so the caller's data is never reordered.
    items = list(dataset)
    if shuffle:
        random.shuffle(items)

    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

In [6]:
import torch
print(torch.cuda.is_available())  # True when a CUDA device is usable; the device selection below falls back to CPU otherwise


True


In [7]:
# Mixed-precision helpers are only switched on when a CUDA device is present;
# the same check decides which device tensors and models are placed on.
use_amp = torch.cuda.is_available()
device = torch.device("cuda") if use_amp else torch.device("cpu")

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Training schedule: functions per batch, functions per mini-batch, epochs.
num_epochs = 10
batch_size = 16
mini_batch_size = 4

# Swarm settings passed to PSO: lower/upper bound and number of particles.
LB, UB = 0, 1
num_particles = 100

def save_model(model, optimizer, epoch, path="model_checkpoint.pth"):
    """Write a training checkpoint to ``path`` and report it on stdout.

    The checkpoint is a dict with the keys ``'epoch'``, ``'model_state_dict'``
    and ``'optimizer_state_dict'``, serialised with ``torch.save``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch number recorded in the checkpoint.
        path: Destination file.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)
    print(f"Model saved at epoch {epoch} to {path}")

def load_model(model, optimizer, path="model_checkpoint.pth"):
    """Restore model (and optionally optimizer) state from a checkpoint.

    The checkpoint is expected to be a dict with the keys ``'epoch'``,
    ``'model_state_dict'`` and ``'optimizer_state_dict'``.

    Args:
        model: Module that receives ``checkpoint['model_state_dict']``.
        optimizer: Optimizer that receives ``checkpoint['optimizer_state_dict']``.
            May be ``None`` (e.g. for inference), in which case the optimizer
            state is left untouched.
        path: Checkpoint file to read.

    Returns:
        The epoch number stored in the checkpoint.
    """
    # Deserialize straight onto the device the model lives on. Without
    # map_location a checkpoint saved from CUDA tensors cannot be loaded on a
    # machine where CUDA is unavailable.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    checkpoint = torch.load(path, map_location=target_device)

    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    epoch = checkpoint['epoch']
    print(f"Model loaded from {path}, resuming from epoch {epoch}")
    return epoch


def train(dataloader, accumulate_step, save_path="model_checkpoint.pth"):
    """Train PSOGNN on a list of benchmark-function descriptions.

    Every entry of ``dataloader`` is a dict with the keys ``'dim'``,
    ``'func_type'`` and ``'params'``. For each entry a random swarm of
    ``num_particles`` positions is passed to ``PSO(...).run()`` and the mean
    fitness it returns is used as the loss.

    Changes compared with the previous version:

    * The model, optimizer, LR scheduler and grad scaler are created once per
      input dimension and reused, instead of being rebuilt for every function
      (which discarded all learned weights and accumulated gradients).
    * The optimizer steps after every ``accumulate_step`` mini-batches, with a
      flush at the end of each batch. The old check ``(i + 1) % accumulate_step``
      used the slice offset ``i`` (0, 4, 8, ...) and never fired for
      ``accumulate_step=4``.
    * Gradients are unscaled before clipping, and clipped once per optimizer
      step.
    * The loss is normalised the same way with and without the grad scaler.
    * The epoch average covers every processed function, not just the last
      one of each batch.
    * The loss curve is plotted against the number of epochs actually run, so
      early stopping no longer causes a length mismatch.

    Reported losses (per-function ranges, epoch averages) are the raw mean
    fitness values, i.e. the same quantity ``test`` reports.

    Relies on the notebook globals ``num_epochs``, ``batch_size``,
    ``mini_batch_size``, ``num_particles``, ``LB``, ``UB``, ``device`` and
    ``use_amp``.

    NOTE(review): the forward pass is not wrapped in ``torch.autocast`` (it
    was not in the previous version either), so ``use_amp`` only enables loss
    scaling. Confirm PSO/PSOGNN are dtype-safe before enabling autocast.

    NOTE(review): the checkpoint format holds a single model. If the dataset
    mixes several ``dim`` values, only the most recently used model is saved.

    Args:
        dataloader: List of function descriptions (see above).
        accumulate_step: Number of mini-batches whose gradients are
            accumulated before an optimizer step; must be >= 1.
        save_path: File the checkpoint is written to after each epoch.

    Returns:
        Dict mapping each function type to a ``(min_loss, max_loss)`` tuple
        over all epochs.

    Raises:
        ValueError: If ``accumulate_step`` < 1 or ``dataloader`` is empty.
    """
    if accumulate_step < 1:
        raise ValueError(f"accumulate_step must be >= 1, got {accumulate_step}")
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty; there is nothing to train on")

    epoch_losses = []
    function_losses = {}
    function_loss_ranges = {}

    patience = 5
    best_avg_loss = float('inf')
    epochs_no_improve = 0

    max_norm = 1.0

    # dim -> training state. PSOGNN is constructed with node_input_dim=dim, so
    # one model is kept per distinct dimension found in the data.
    trainers = {}
    last_trainer = None

    def get_trainer(dim):
        """Return the training state for ``dim``, creating it on first use."""
        if dim not in trainers:
            model = PSOGNN(node_input_dim=dim).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
            trainers[dim] = {
                'model': model,
                'optimizer': optimizer,
                # verbose=True is omitted: the argument is deprecated in
                # recent PyTorch releases.
                'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer, mode='min', factor=0.1, patience=2),
                # A disabled scaler is a transparent pass-through, which lets
                # the AMP and non-AMP paths share the same code.
                'scaler': torch.amp.GradScaler('cuda', enabled=use_amp),
                'pending': False,
            }
        return trainers[dim]

    def step_pending():
        """Apply one optimizer step to every model with accumulated gradients."""
        for trainer in trainers.values():
            if not trainer['pending']:
                continue
            optimizer = trainer['optimizer']
            scaler = trainer['scaler']
            # Clip the true gradients, not the loss-scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainer['model'].parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            trainer['pending'] = False

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        total_loss = 0.0
        function_count = 0

        for batch in create_batches(dataloader, batch_size):
            mini_batch_starts = range(0, len(batch), mini_batch_size)
            for mini_batch_index, start in enumerate(mini_batch_starts):
                mini_batch = batch[start:start + mini_batch_size]

                for func in mini_batch:
                    dim = func['dim']
                    func_type = func['func_type']
                    params = func['params']

                    trainer = get_trainer(dim)
                    last_trainer = trainer

                    # Bind the current function description explicitly so the
                    # closure cannot pick up values from a later iteration.
                    def objective(x, func_type=func_type, params=params):
                        function_instance = Function.get_function(func_type, x, params)
                        return function_instance.evaluate_function()

                    X = torch.rand(num_particles, dim).to(device)
                    swarm = PSO(X, objective, trainer['model'], LB, UB)
                    position_best, best, mean_fitness = swarm.run()

                    # A NaN/inf loss would poison the persistent weights.
                    if not torch.isfinite(mean_fitness).all():
                        print(f"Warning: non-finite fitness for '{func_type}' "
                              f"at epoch {epoch+1}; skipping this function")
                        continue

                    # Average over the mini-batch and over the accumulation
                    # window so the summed gradient is a mean.
                    loss = mean_fitness / (len(mini_batch) * accumulate_step)
                    trainer['scaler'].scale(loss).backward()
                    trainer['pending'] = True

                    fitness_value = mean_fitness.item()
                    function_losses.setdefault(func_type, []).append(fitness_value)
                    total_loss += fitness_value
                    function_count += 1

                if (mini_batch_index + 1) % accumulate_step == 0:
                    step_pending()

            # Flush gradients left over from an incomplete accumulation window.
            step_pending()

        # inf (rather than 0) when nothing valid was processed, so an
        # all-invalid epoch never counts as an improvement.
        avg_loss = total_loss / function_count if function_count > 0 else float('inf')
        epoch_losses.append(avg_loss)

        for trainer in trainers.values():
            trainer['scheduler'].step(avg_loss)

        if avg_loss < best_avg_loss:
            best_avg_loss = avg_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        if last_trainer is not None:
            save_model(last_trainer['model'], last_trainer['optimizer'], epoch+1, save_path)

    print("Training complete!")
    plt.figure(figsize=(10, 6))
    # Use the number of completed epochs: early stopping may end the run
    # before num_epochs is reached.
    plt.plot(range(1, len(epoch_losses) + 1), epoch_losses, marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.title('Average Loss over Epochs')
    plt.show()

    for func_type, losses in function_losses.items():
        lowest = min(losses)
        highest = max(losses)
        function_loss_ranges[func_type] = (lowest, highest)
        print(f"Function Type: {func_type}, Loss Range: Min={lowest:.4f}, Max={highest:.4f}")

    return function_loss_ranges



def test(dataloader, model_path="model_checkpoint.pth"):
    """Evaluate a saved PSOGNN checkpoint on a list of function descriptions.

    Every entry of ``dataloader`` is a dict with the keys ``'dim'``,
    ``'func_type'`` and ``'params'``. For each entry a random swarm of
    ``num_particles`` positions is passed to ``PSO(...).run()`` under
    ``torch.no_grad()`` and the returned mean fitness is recorded as the loss.

    Changes compared with the previous version:

    * The model is built and the checkpoint read once per distinct ``dim``
      instead of once per test function.
    * The test set is iterated in its given order; it is no longer shuffled
      in place through ``create_batches``.
    * The plot labels say "function" instead of "batch", because each point
      is the loss of a single function.

    Relies on the notebook globals ``num_particles``, ``LB``, ``UB`` and
    ``device``.

    NOTE(review): reusing one model per ``dim`` assumes ``PSO.run()`` does not
    modify the model's weights or buffers during evaluation; confirm against
    the ``pso`` module.

    Args:
        dataloader: Iterable of function descriptions (see above).
        model_path: Checkpoint file written by ``save_model``.

    Returns:
        The average loss over all evaluated functions, or 0 when
        ``dataloader`` is empty.
    """
    all_test_losses = []
    # dim -> evaluation-ready model restored from the checkpoint.
    models = {}

    with torch.no_grad():
        for func in dataloader:
            dim = func['dim']
            func_type = func['func_type']
            params = func['params']

            model = models.get(dim)
            if model is None:
                model = PSOGNN(node_input_dim=dim).to(device)
                # load_model restores optimizer state as well, so a matching
                # optimizer is created here only to satisfy that call.
                optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
                load_model(model, optimizer, model_path)
                model.eval()
                models[dim] = model

            # Bind the current function description explicitly so the closure
            # cannot pick up values from a later iteration.
            def objective(x, func_type=func_type, params=params):
                function_instance = Function.get_function(func_type, x, params)
                return function_instance.evaluate_function()

            X = torch.rand(num_particles, dim).to(device)

            swarm = PSO(X, objective, model, LB, UB)
            position_best, best, mean_fitness = swarm.run()

            all_test_losses.append(mean_fitness.item())

    avg_test_loss = sum(all_test_losses) / len(all_test_losses) if all_test_losses else 0
    print(f"Average Test Loss: {avg_test_loss:.4f}")

    # Plot the loss of every evaluated function.
    plt.figure(figsize=(10, 6))
    plt.plot(all_test_losses, marker='o')
    plt.xlabel('Test function index')
    plt.ylabel('Loss')
    plt.title('Loss per Function on Test Set')
    plt.show()

    return avg_test_loss


In [None]:
# Number of gradient-accumulation steps handed to train().
accumulate_step = 4

print("Starting Training...")
train_loss_ranges = train(
    train_set,
    accumulate_step,
    save_path="model_checkpoint.pth",
)


Starting Training...
Epoch 1/10




In [10]:

# Evaluate the checkpoint written by the training cell on the test split.
print("\nStarting Testing...")
avg_test_loss = test(
    test_set,
    model_path="model_checkpoint.pth",
)


Starting Testing...


UnboundLocalError: cannot access local variable 'dim' where it is not associated with a value

: 

In [None]:
# Example usage: train on `train_set`, then evaluate on `test_set`, using the
# default checkpoint path of train() and test().
accumulate_step = 4  # gradient-accumulation steps handed to train()

print("Starting Training...")
train_loss_ranges = train(train_set, accumulate_step)

print("\nStarting Testing...")
avg_test_loss = test(test_set)

In [None]:
# Example: start a fresh training run.
# The active train() takes (dataloader, accumulate_step, save_path). It has no
# validation-set or start_epoch parameter, and `val_set` is never defined (its
# assignment is commented out), so only the training split is passed here.
train(train_set, accumulate_step=4)

In [None]:
# batch_size = 16
# mini_batch_size = 4
# num_epochs = 10

# UB = 1
# LB = 0
# num_particles = 100

# import matplotlib.pyplot as plt

# def train(dataloader, accumulate_step):
#     epoch_losses = []
#     batch_losses = {}
#     function_losses = {}
#     function_loss_ranges = {}

#     scaler = torch.cuda.amp.GradScaler()  # Initialize the GradScaler for mixed precision

#     for epoch in range(num_epochs):
#         print(f"Epoch {epoch+1}/{num_epochs}")
#         total_loss = 0
#         batch_count = 0
#         function_loader = create_batches(dataloader, batch_size)

#         for batch in function_loader:
#             for i in range(0, len(batch), mini_batch_size):
#                 mini_batch = batch[i : i + mini_batch_size]
#                 for func in mini_batch:
#                     dim = func['dim']
#                     func_type = func['func_type']
#                     params = func['params']

#                     def function(x):
#                         function_instance = Function.get_function(func_type, x, params)
#                         return function_instance.evaluate_function()

#                     X = torch.rand(num_particles, dim).to(device)
#                     model = PSOGNN(node_input_dim=dim).to(device)
#                     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

#                     Pso = PSO(X, function, model, LB, UB)
#                     position_best, best, mean_fitness = Pso.run()

#                     with torch.cuda.amp.autocast():  # Mixed precision training
#                         loss = mean_fitness / len(mini_batch)  # Normalize by batch size
#                         loss = loss / accumulate_step  # Scale by accumulate step
#                         scaler.scale(loss).backward()  # Scale and backpropagate

#                     if (i + 1) % accumulate_step == 0:
#                         scaler.step(optimizer)  # Apply optimizer step
#                         scaler.update()
#                         optimizer.zero_grad()

#                     if func_type not in function_losses:
#                         function_losses[func_type] = []
#                     function_losses[func_type].append(loss.item())

#                     if epoch not in batch_losses:
#                         batch_losses[epoch] = []
#                     batch_losses[epoch].append(loss.item())

#                     if torch.isinf(loss).any() or torch.isnan(loss).any():
#                         print(f"Warning: Loss is {'-inf' if torch.isinf(loss).any() else 'NaN'} at Epoch {epoch+1}, Step {i+1}")
#                         break

#             total_loss += loss.item()
#             if torch.isinf(torch.tensor(total_loss)) or torch.isnan(torch.tensor(total_loss)):
#                 print(f"Warning: total_loss is {'-inf' if torch.isinf(torch.tensor(total_loss)) else 'NaN'} at Epoch {epoch+1}, Step {i+1}")
#                 break
#             batch_count += 1

#         avg_loss = total_loss / batch_count if batch_count > 0 else 0
#         epoch_losses.append(avg_loss)

#         if torch.isinf(torch.tensor(avg_loss)) or torch.isnan(torch.tensor(avg_loss)):
#             print(f"Warning: avg_loss is {'-inf' if torch.isinf(torch.tensor(avg_loss)) else 'NaN'} at Epoch {epoch+1}")

#         print(f"Avg Loss Epoch {epoch+1}: {avg_loss:.4f}")

#     print("Training complete!")
#     plt.figure(figsize=(10, 6))
#     plt.plot(range(1, num_epochs + 1), epoch_losses, marker='o')
#     plt.xlabel('Epoch')
#     plt.ylabel('Average Loss')
#     plt.title('Average Loss over Epochs')
#     plt.show()
    
#     # Print the min-max range for each function type
#     for func_type, losses in function_losses.items():
#         min_loss = min(losses)
#         max_loss = max(losses)
#         function_loss_ranges[func_type] = (min_loss, max_loss)
#         print(f"Function Type: {func_type}, Loss Range: Min={min_loss:.4f}, Max={max_loss:.4f}")

#     return function_loss_ranges


In [7]:
# import torch
# import matplotlib.pyplot as plt
# import os

# # Initialize parameters
# batch_size = 16
# mini_batch_size = 4
# num_epochs = 15
# UB = 1
# LB = 0
# num_particles = 100

# # Path to save the model
# MODEL_PATH = r'A:\Code\deepso\check_point.pth'

# # Define the evaluation function
# def evaluate(dataloader, device):
#     total_loss = 0
#     batch_count = 0

#     with torch.no_grad():  # Disable gradient calculation
#         for batch in create_batches(dataloader, batch_size):
#             for i in range(0, len(batch), mini_batch_size):
#                 mini_batch = batch[i: i + mini_batch_size]
#                 for func in mini_batch:
#                     dim = func['dim']  # Get the dimension from the function data
#                     func_type = func['func_type']
#                     params = func['params']

#                     # Define the function
#                     def function(x):
#                         function_instance = Function.get_function(func_type, x, params)
#                         return function_instance.evaluate_function()
                    
#                     # Re-create the model for each 'func' based on the 'dim'
#                     model = PSOGNN(node_input_dim=dim).to(device)

#                     # Initialize PSO with this model
#                     X = torch.rand(num_particles, dim).to(device)
#                     Pso = PSO(X, function, model, LB, UB)
#                     _, _, mean_fitness = Pso.run()

#                     loss = mean_fitness
#                     total_loss += loss.item()
#                     batch_count += 1

#     avg_loss = total_loss / batch_count if batch_count > 0 else 0
#     print(f"Evaluation Avg Loss: {avg_loss:.4f}")
#     return avg_loss


# # Define the training function with model saving
# def train(dataloader, validation_loader, accumulate_step, start_epoch = 0):
#     epoch_losses = []  # To store average loss for each epoch
#     val_losses = []    # To store validation loss for each epoch
#     function_losses = {}  # To store loss per function type
#     batch_losses = {}  # To store batch loss per epoch

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     for epoch in range(start_epoch, num_epochs):
#         print(f"Epoch {epoch+1}/{num_epochs}")
#         total_loss = 0
#         batch_count = 0
#         batch_losses[epoch] = []  # Initialize batch loss tracking for this epoch

#         for batch in create_batches(dataloader, batch_size):
#             for i in range(0, len(batch), mini_batch_size):
#                 mini_batch = batch[i: i + mini_batch_size]
#                 for func in mini_batch:
#                     dim = func['dim']  # Get the dimension from the function data
#                     func_type = func['func_type']
#                     params = func['params']

#                     # Define the function
#                     def function(x):
#                         function_instance = Function.get_function(func_type, x, params)
#                         return function_instance.evaluate_function()

#                     # Re-create the model for each 'func' based on the 'dim'
#                     model = PSOGNN(node_input_dim=dim).to(device)
#                     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

#                     # Initialize PSO with this model
#                     X = torch.rand(num_particles, dim).to(device)
#                     Pso = PSO(X, function, model, LB, UB)
#                     position_best, best, mean_fitness = Pso.run()

#                     loss = mean_fitness / accumulate_step
#                     loss.backward()

#                     if func_type not in function_losses:
#                         function_losses[func_type] = []  # Initialize list if func_type is not tracked
#                     function_losses[func_type].append(loss.item())  # Track loss per function type
                    
#                     batch_losses[epoch].append(loss.item())  # Track loss per batch for this epoch

#                     if (i + 1) % accumulate_step == 0:
#                         optimizer.step()
#                         optimizer.zero_grad()

#                     # Check for invalid loss values
#                     if torch.isinf(loss).any() or torch.isnan(loss).any():
#                         print(f"Warning: Loss is {'-inf' if torch.isinf(loss).any() else 'NaN'} at Epoch {epoch+1}, Step {i+1}")
#                         break

#             total_loss += loss.item()
#             batch_count += 1
        
#         avg_loss = total_loss / batch_count if batch_count > 0 else 0
#         epoch_losses.append(avg_loss)  # Track average loss for this epoch
        
#         if torch.isinf(torch.tensor(avg_loss)) or torch.isnan(torch.tensor(avg_loss)):
#             print(f"Warning: avg_loss is {'-inf' if torch.isinf(torch.tensor(avg_loss)) else 'NaN'} at Epoch {epoch+1}")
        
#         print(f"Avg Loss Epoch {epoch+1}: {avg_loss:.4f}")

        
#         # Evaluate on the validation set after each epoch
#         val_loss = evaluate(validation_loader, device)
#         val_losses.append(val_loss)

#         # Save model checkpoint after each epoch
#         save_checkpoint(model, optimizer, epoch, epoch_losses, val_losses, MODEL_PATH)

#     print("Training complete!")

#     # Plot average loss over epochs
#     plt.figure(figsize=(10, 6))
#     plt.plot(range(1, num_epochs + 1), epoch_losses, marker='o', label='Training Loss')
#     plt.plot(range(1, num_epochs + 1), val_losses, marker='o', label='Validation Loss', color='red')
#     plt.xlabel('Epoch')
#     plt.ylabel('Average Loss')
#     plt.title('Average Loss over Epochs')
#     plt.legend()
#     plt.show()

#     # Plot loss per function type across batches
#     for func_type, losses in function_losses.items():
#         plt.figure(figsize=(10, 6))
#         plt.plot(losses, marker='o')
#         plt.xlabel('Batch')
#         plt.ylabel('Loss')
#         plt.title(f'Loss per Batch for {func_type}')
#         plt.show()

#     # Plot batch loss per epoch
#     for epoch, losses in batch_losses.items():
#         plt.figure(figsize=(10, 6))
#         plt.plot(losses, marker='o')
#         plt.xlabel('Batch')
#         plt.ylabel('Loss')
#         plt.title(f'Loss per Batch for Epoch {epoch+1}')
#         plt.show()

# # Save checkpoint function
# def save_checkpoint(model, optimizer, epoch, epoch_losses, val_losses, path):
#     torch.save({
#         'epoch': epoch,
#         'model_state_dict': model.state_dict(),
#         'optimizer_state_dict': optimizer.state_dict(),
#         'epoch_losses': epoch_losses,
#         'val_losses': val_losses
#     }, path)
#     print(f"Model saved at {path}")

# # Load checkpoint function
# def load_checkpoint(path):
#     if os.path.exists(path):
#         checkpoint = torch.load(path)
#         model = PSOGNN(node_input_dim=dim).to(device)
#         optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#         model.load_state_dict(checkpoint['model_state_dict'])
#         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#         start_epoch = checkpoint['epoch'] + 1
#         epoch_losses = checkpoint['epoch_losses']
#         val_losses = checkpoint['val_losses']
#         print(f"Loaded model from {path}, starting from epoch {start_epoch}")
#         return model, optimizer, start_epoch, epoch_losses, val_losses
#     else:
#         print(f"No checkpoint found at {path}")
#         return None, None, 0, [], []




In [None]:

# train(train_set, accumulate_step=4)

In [None]:
# import torch
# import matplotlib.pyplot as plt

# batch_size = 16
# mini_batch_size = 4
# num_particles = 100
# UB = 1
# LB = 0

# def test(test_loader):
#     total_loss = 0
#     batch_count = 0
#     batch_losses = []  # Track loss per batch
#     function_losses = {}  # Track loss per function type

#     with torch.no_grad():  # Disable gradient computation for evaluation
#         for batch in create_batches(test_loader, batch_size):
#             for i in range(0, len(batch), mini_batch_size):
#                 mini_batch = batch[i: i + mini_batch_size]
#                 for func in mini_batch:
#                     dim = func['dim']
#                     func_type = func['func_type']
#                     params = func['params']

#                     def function(x):
#                         function_instance = Function.get_function(func_type, x, params)
#                         return function_instance.evaluate_function()

#                     X = torch.rand(num_particles, dim).to(device)

#                     # Initialize the model based on the input dimension of the function `func`
#                     model = PSOGNN(node_input_dim=dim).to(device)

#                     # No optimizer is needed here because no weights are updated
#                     Pso = PSO(X, function, model, LB, UB)
#                     position_best, best, mean_fitness = Pso.run()

#                     loss = mean_fitness

#                     if func_type not in function_losses:
#                         function_losses[func_type] = []  # Initialize if not tracked
#                     function_losses[func_type].append(loss.item())  # Track loss per function type

#                     batch_losses.append(loss.item())  # Track batch losses

#                     total_loss += loss.item()
#                     batch_count += 1

#     avg_loss = total_loss / batch_count if batch_count > 0 else 0
#     print(f"Test Avg Loss: {avg_loss:.4f}")

#     # Plot loss per function type across batches
#     for func_type, losses in function_losses.items():
#         plt.figure(figsize=(10, 6))
#         plt.plot(losses, marker='o')
#         plt.xlabel('Batch')
#         plt.ylabel('Loss')
#         plt.title(f'Test Loss per Batch for {func_type}')
#         plt.show()

#     # Plot batch loss for the whole test set
#     plt.figure(figsize=(10, 6))
#     plt.plot(batch_losses, marker='o')
#     plt.xlabel('Batch')
#     plt.ylabel('Loss')
#     plt.title('Test Loss per Batch')
#     plt.show()

#     return avg_loss

# # Example call to test
# test_set = dataset['test']
# test(test_set)



In [None]:
# import torch
# import matplotlib.pyplot as plt
# import os

# # Initialize parameters
# batch_size = 16
# mini_batch_size = 4
# num_epochs = 15
# UB = 1
# LB = 0
# num_particles = 100

# # Path to save the model
# MODEL_PATH = r'A:\Code\deepso\check_point.pth'

# # Define the evaluation function
# def evaluate(dataloader, device):
#     total_loss = 0
#     batch_count = 0

#     with torch.no_grad():  # Disable gradient calculation
#         for batch in create_batches(dataloader, batch_size):
#             for i in range(0, len(batch), mini_batch_size):
#                 mini_batch = batch[i: i + mini_batch_size]
#                 for func in mini_batch:
#                     dim = func['dim']  # Get the dimension from the function data
#                     func_type = func['func_type']
#                     params = func['params']

#                     # Define the function
#                     def function(x):
#                         function_instance = Function.get_function(func_type, x, params)
#                         return function_instance.evaluate_function()
                    
#                     # Re-create the model for each 'func' based on the 'dim'
#                     model = PSOGNN(node_input_dim=dim).to(device)

#                     # Initialize PSO with this model
#                     X = torch.rand(num_particles, dim).to(device)
#                     Pso = PSO(X, function, model, LB, UB)
#                     _, _, mean_fitness = Pso.run()

#                     loss = mean_fitness
#                     total_loss += loss.item()
#                     batch_count += 1

#     avg_loss = total_loss / batch_count if batch_count > 0 else 0
#     print(f"Evaluation Avg Loss: {avg_loss:.4f}")
#     return avg_loss


# # Define the training function with model saving
# def train(dataloader, validation_loader, accumulate_step, start_epoch = 0):
#     epoch_losses = []  # To store average loss for each epoch
#     val_losses = []    # To store validation loss for each epoch
#     function_losses = {}  # To store loss per function type
#     batch_losses = {}  # To store batch loss per epoch

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     for epoch in range(start_epoch, num_epochs):
#         print(f"Epoch {epoch+1}/{num_epochs}")
#         total_loss = 0
#         batch_count = 0
#         batch_losses[epoch] = []  # Initialize batch loss tracking for this epoch

#         for batch in create_batches(dataloader, batch_size):
#             for i in range(0, len(batch), mini_batch_size):
#                 mini_batch = batch[i: i + mini_batch_size]
#                 for func in mini_batch:
#                     dim = func['dim']  # Get the dimension from the function data
#                     func_type = func['func_type']
#                     params = func['params']

#                     # Define the function
#                     def function(x):
#                         function_instance = Function.get_function(func_type, x, params)
#                         return function_instance.evaluate_function()

#                     # Re-create the model for each 'func' based on the 'dim'
#                     model = PSOGNN(node_input_dim=dim).to(device)
#                     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

#                     # Initialize PSO with this model
#                     X = torch.rand(num_particles, dim).to(device)
#                     Pso = PSO(X, function, model, LB, UB)
#                     position_best, best, mean_fitness = Pso.run()

#                     loss = mean_fitness / accumulate_step
#                     loss.backward()

#                     if func_type not in function_losses:
#                         function_losses[func_type] = []  # Initialize list if func_type is not tracked
#                     function_losses[func_type].append(loss.item())  # Track loss per function type
                    
#                     batch_losses[epoch].append(loss.item())  # Track loss per batch for this epoch

#                     if (i + 1) % accumulate_step == 0:
#                         optimizer.step()
#                         optimizer.zero_grad()

#                     # Check for invalid loss values
#                     if torch.isinf(loss).any() or torch.isnan(loss).any():
#                         print(f"Warning: Loss is {'-inf' if torch.isinf(loss).any() else 'NaN'} at Epoch {epoch+1}, Step {i+1}")
#                         break

#             total_loss += loss.item()
#             batch_count += 1
        
#         avg_loss = total_loss / batch_count if batch_count > 0 else 0
#         epoch_losses.append(avg_loss)  # Track average loss for this epoch
        
#         if torch.isinf(torch.tensor(avg_loss)) or torch.isnan(torch.tensor(avg_loss)):
#             print(f"Warning: avg_loss is {'-inf' if torch.isinf(torch.tensor(avg_loss)) else 'NaN'} at Epoch {epoch+1}")
        
#         print(f"Avg Loss Epoch {epoch+1}: {avg_loss:.4f}")

        
#         # Evaluate on the validation set after each epoch
#         val_loss = evaluate(validation_loader, device)
#         val_losses.append(val_loss)

#         # Save model checkpoint after each epoch
#         save_checkpoint(model, optimizer, epoch, epoch_losses, val_losses, MODEL_PATH)

#     print("Training complete!")

#     # Plot average loss over epochs
#     plt.figure(figsize=(10, 6))
#     plt.plot(range(1, num_epochs + 1), epoch_losses, marker='o', label='Training Loss')
#     plt.plot(range(1, num_epochs + 1), val_losses, marker='o', label='Validation Loss', color='red')
#     plt.xlabel('Epoch')
#     plt.ylabel('Average Loss')
#     plt.title('Average Loss over Epochs')
#     plt.legend()
#     plt.show()

#     # Plot loss per function type across batches
#     for func_type, losses in function_losses.items():
#         plt.figure(figsize=(10, 6))
#         plt.plot(losses, marker='o')
#         plt.xlabel('Batch')
#         plt.ylabel('Loss')
#         plt.title(f'Loss per Batch for {func_type}')
#         plt.show()

#     # Plot batch loss per epoch
#     for epoch, losses in batch_losses.items():
#         plt.figure(figsize=(10, 6))
#         plt.plot(losses, marker='o')
#         plt.xlabel('Batch')
#         plt.ylabel('Loss')
#         plt.title(f'Loss per Batch for Epoch {epoch+1}')
#         plt.show()

# # Save checkpoint function
# def save_checkpoint(model, optimizer, epoch, epoch_losses, val_losses, path):
#     torch.save({
#         'epoch': epoch,
#         'model_state_dict': model.state_dict(),
#         'optimizer_state_dict': optimizer.state_dict(),
#         'epoch_losses': epoch_losses,
#         'val_losses': val_losses
#     }, path)
#     print(f"Model saved at {path}")

# # Load checkpoint function
# def load_checkpoint(path):
#     if os.path.exists(path):
#         checkpoint = torch.load(path)
#         model = PSOGNN(node_input_dim=dim).to(device)
#         optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#         model.load_state_dict(checkpoint['model_state_dict'])
#         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#         start_epoch = checkpoint['epoch'] + 1
#         epoch_losses = checkpoint['epoch_losses']
#         val_losses = checkpoint['val_losses']
#         print(f"Loaded model from {path}, starting from epoch {start_epoch}")
#         return model, optimizer, start_epoch, epoch_losses, val_losses
#     else:
#         print(f"No checkpoint found at {path}")
#         return None, None, 0, [], []


