In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import QM9

from dataset import create_QM9_pyg_datasets
from models import GCNModel, GATModel, GINModel, MPNNModel, SchNetModel, DimeNetPlusPlusModel

# --- Reproducibility and device ---
# Seed the global RNGs before any DataLoader or model is created, so that the
# shuffled batch order and the initial weights repeat across
# "Restart Kernel -> Run All". Some GPU kernels may still be non-deterministic;
# seeding removes the seed-dependent part of the variation.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- Data Loading ---
# create_QM9_pyg_datasets hands back the train / validation / test splits,
# each as a list of PyG Data objects.
train_graphs, val_graphs, test_graphs = create_QM9_pyg_datasets(
    random_seed=42,
    subset_size=1000,  # small subset keeps experimentation quick
)

# Wrap every split in a DataLoader; only the training split is shuffled.
# Each Data object still carries its original data.y ([1, 19]), so the
# targets of interest are selected later, at training/evaluation time.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=32, shuffle=reshuffle)
    for split, reshuffle in (
        (train_graphs, True),
        (val_graphs, False),
        (test_graphs, False),
    )
)

# --- Determine Model Input/Output Sizes ---
# Read the node/edge feature widths from the first training graph. When the
# training split is empty, fall back to a QM9 dataset instance, which exposes
# the same two attributes. The temporary name is dropped afterwards either way.
feature_source = train_graphs[0] if train_graphs else QM9(root='./data/QM9')
num_node_features = feature_source.num_node_features
num_edge_features = feature_source.num_edge_features
del feature_source

# Regression targets: columns 2 (HOMO) and 3 (LUMO) of QM9's 19 targets.
target_idx = [2, 3]
num_targets = len(target_idx)

# The edge-feature count only matters for the MPNN model.
print(
    f"Number of node features: {num_node_features}",
    f"Number of edge features: {num_edge_features}",
    f"Number of targets: {num_targets}",
    sep="\n",
)

# --- Model Initialization (Choose one model to train) ---
# Valid names: "GCN", "GAT", "GIN", "MPNN", "SchNet", "DimeNet".
selected_model_name = "GAT"

# Name -> zero-argument builder. The lambdas defer construction, so only the
# selected architecture is actually instantiated.
_model_builders = {
    "GCN": lambda: GCNModel(num_node_features=num_node_features, hidden_channels=64, num_targets=num_targets, num_layers=3),
    "GAT": lambda: GATModel(num_node_features=num_node_features, hidden_channels=64, num_targets=num_targets, heads=4, num_layers=3),
    "GIN": lambda: GINModel(num_node_features=num_node_features, hidden_channels=64, num_targets=num_targets, num_layers=3),
    "MPNN": lambda: MPNNModel(num_node_features=num_node_features, num_edge_features=num_edge_features, hidden_channels=64, num_targets=num_targets, num_layers=3),
    # SchNet and DimeNet++ are fed atomic numbers (z) and positions (pos),
    # so they take no num_node_features argument.
    "SchNet": lambda: SchNetModel(num_targets=num_targets, hidden_channels=128),
    "DimeNet": lambda: DimeNetPlusPlusModel(num_targets=num_targets, hidden_channels=128),
}

if selected_model_name not in _model_builders:
    raise ValueError(f"Unknown model name: {selected_model_name}")
model = _model_builders[selected_model_name]().to(device)

print(f"\nSelected model: {selected_model_name}")
print(model)

# --- Optimizer and Loss Function ---
# Regression setup: mean squared error, minimised with Adam over every
# parameter of the selected model.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

ImportError: cannot import name 'GCNModel' from 'models' (c:\msc_1\GraphDL\Graph-DL-HW\models.py)

In [None]:
def _forward_batch(model, data, model_type, phase):
    """Call ``model`` with the inputs its architecture family expects.

    Args:
        model: Callable network.
        data: Batch object exposing the attributes the chosen family reads
            (``x``/``edge_index``/``batch``, plus ``edge_attr`` for MPNN, or
            ``z``/``pos``/``batch`` for SchNet and DimeNet).
        model_type: One of "GCN", "GAT", "GIN", "MPNN", "SchNet", "DimeNet".
        phase: Word used in the error message ("training" or "evaluation").

    Returns:
        Whatever ``model`` returns for the batch.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type in ("GCN", "GAT", "GIN"):
        return model(data.x, data.edge_index, data.batch)
    if model_type == "MPNN":
        return model(data.x, data.edge_index, data.edge_attr, data.batch)
    if model_type in ("SchNet", "DimeNet"):
        return model(data.z, data.pos, data.batch)
    raise ValueError(f"Unsupported model type for {phase}: {model_type}")


def _select_targets(data, target_indices_list):
    """Slice the requested target columns out of ``data.y``.

    ``squeeze(1)`` removes a singleton middle dimension if one is present and
    is a no-op otherwise, so both [num_graphs, 1, T] and [num_graphs, T]
    layouts end up as [num_graphs, len(target_indices_list)].
    NOTE(review): with per-graph ``y`` of shape [1, 19], batches are presumably
    [num_graphs, 19] -- confirm against the dataset helper.
    """
    return data.y.squeeze(1)[:, target_indices_list]


def _mean_per_graph(loss_sum, graphs_seen):
    """Return ``loss_sum / graphs_seen``, rejecting an empty pass explicitly."""
    if graphs_seen == 0:
        raise ValueError("Cannot average the loss: the loader yielded no graphs")
    return loss_sum / graphs_seen


def train_epoch(model, loader, optimizer, criterion, device, target_indices_list, model_type):
    """Run one optimisation pass over ``loader`` and return the mean loss per graph.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable of batches exposing ``to(device)``, ``y`` and ``num_graphs``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable ``criterion(prediction, target)``; the per-batch
            value is weighted by ``num_graphs``, i.e. it is treated as a batch mean.
        device: Device each batch is moved to.
        target_indices_list: Column indices of ``data.y`` to regress on.
        model_type: Architecture name, which decides the forward-call signature.

    Returns:
        Sum of ``loss * num_graphs`` over all batches divided by the number of
        graphs actually processed. Counting the graphs seen (instead of using
        ``len(loader.dataset)``) keeps the mean correct when the loader skips
        items, e.g. with ``drop_last=True`` or a sub-sampling sampler.

    Raises:
        ValueError: For an unsupported ``model_type`` or a loader with no graphs.
    """
    model.train()
    loss_sum = 0.0
    graphs_seen = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        out = _forward_batch(model, data, model_type, "training")
        actual_y = _select_targets(data, target_indices_list)

        loss = criterion(out, actual_y)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * data.num_graphs
        graphs_seen += data.num_graphs
    return _mean_per_graph(loss_sum, graphs_seen)


@torch.no_grad()
def evaluate_epoch(model, loader, criterion, device, target_indices_list, model_type):
    """Compute the mean loss per graph over ``loader`` without updating weights.

    Takes the same arguments as ``train_epoch`` minus the optimizer. The model
    is put in eval mode and the whole pass runs with gradients disabled.

    Returns:
        Sum of ``loss * num_graphs`` over all batches divided by the number of
        graphs actually processed.

    Raises:
        ValueError: For an unsupported ``model_type`` or a loader with no graphs.
    """
    model.eval()
    loss_sum = 0.0
    graphs_seen = 0
    for data in loader:
        data = data.to(device)

        out = _forward_batch(model, data, model_type, "evaluation")
        actual_y = _select_targets(data, target_indices_list)

        loss = criterion(out, actual_y)
        loss_sum += loss.item() * data.num_graphs
        graphs_seen += data.num_graphs
    return _mean_per_graph(loss_sum, graphs_seen)

# --- Training ---
num_epochs = 20  # Adjust as needed
train_losses, val_losses = [], []  # per-epoch history of both losses

print(f"\nStarting training for {selected_model_name}...")
for epoch in range(1, num_epochs + 1):
    # One optimisation pass over the training split ...
    train_loss = train_epoch(
        model=model,
        loader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
        target_indices_list=target_idx,
        model_type=selected_model_name,
    )
    # ... followed by a gradient-free pass over the validation split.
    val_loss = evaluate_epoch(
        model=model,
        loader=val_loader,
        criterion=criterion,
        device=device,
        target_indices_list=target_idx,
        model_type=selected_model_name,
    )

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(f'Epoch {epoch:02d}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

print("\nTraining finished.")

# --- Final Evaluation on Test Set ---
# Score the model as it stands after the last epoch on the held-out split.
test_loss = evaluate_epoch(
    model=model,
    loader=test_loader,
    criterion=criterion,
    device=device,
    target_indices_list=target_idx,
    model_type=selected_model_name,
)
print(f'Final Test Loss for {selected_model_name}: {test_loss:.4f}')