In [None]:
import os
import random
import numpy as np

import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Subset

from timm.scheduler import CosineLRScheduler

from utils import convert_to_epd_list, rotate_epds, pixelization, train_one_epoch, test
from load_models import load_model
from gdeep.data.datasets.persistence_diagrams_from_graphs_builder import PersistenceDiagramFromGraphBuilder
from gdeep.data.datasets import PersistenceDiagramFromFiles

def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Fix every random generator up front so that runs are reproducible
_seed_everything(42)

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Experiment selection ---
# dataname: one of ['IMDB-BINARY', 'IMDB-MULTI', 'MUTAG', 'PROTEINS', 'COX2', 'DHFR']
dataname = 'PROTEINS'
# model_name: one of ['xpert', 'gin', 'gin_assisted_concat', 'gin_assisted_sum']
model_name = 'xpert'

# --- Input and model settings ---
grid_size = 50    # side length of each pixelized persistence diagram
patch_size = 5    # forwarded to load_model
embed_dim = 192   # forwarded to load_model
depth = 5         # forwarded to load_model

# --- Optimisation settings ---
epochs = 300      # epochs per fold; also the cosine schedule length
lr = 0.001        # AdamW learning rate
warmup_t = 50     # warm-up length handed to the LR scheduler
batch_size = 64   # mini-batch size of both data loaders

# --- Evaluation settings ---
n_splits = 10     # number of stratified cross-validation folds
patience = 100    # early-stopping patience (epochs without improvement)
min_epochs = 30   # early-stopping lower bound on trained epochs

print(f"Dataset: {dataname}, Model: {model_name}, Patch Size: {patch_size}, Embed Dim: {embed_dim}, Depth: {depth}")

def labels_preprocess(labels, name=None):
    """Remap raw labels according to the dataset they come from.

    Args:
        labels: Numeric labels supporting arithmetic (e.g. a NumPy array).
        name: Dataset name that selects the remapping rule. When omitted,
            the module-level ``dataname`` is used, which keeps existing
            single-argument calls working unchanged.

    Returns:
        ``labels - 1`` when ``name`` is 'IMDB-MULTI' or 'PROTEINS';
        ``0.5 * labels + 0.5`` when ``name`` contains 'MUTAG', 'COX2' or
        'DHFR'; otherwise ``labels`` unchanged.
    """
    if name is None:
        name = dataname

    # Exact-name match: shift the labels down by one.
    if name in ('IMDB-MULTI', 'PROTEINS'):
        labels = labels - 1

    # Substring match: affine rescale (sends -1 -> 0 and 1 -> 1).
    # NOTE(review): presumably these datasets store labels as {-1, 1} - confirm.
    if any(tag in name for tag in ('MUTAG', 'COX2', 'DHFR')):
        labels = 0.5 * labels + 0.5

    return labels

# Graph dataset and its number of target classes
dataset = TUDataset(root='./data/GraphDatasets/', name=dataname)
num_classes = dataset.num_classes

# Buffer for the pixelized diagrams: one (4, grid_size, grid_size) image per graph
ppd = torch.zeros((len(dataset), 4, grid_size, grid_size), dtype=torch.float32)

# Build the persistence diagrams for the chosen dataset
diffusion_parameter = 1.0
pd_creator = PersistenceDiagramFromGraphBuilder(
    dataname,
    diffusion_parameter=diffusion_parameter,
    root='./data',
)
pd_creator.create()

# Open the diagram dataset from its directory under ./data
pd_dir = os.path.join('./data', f"{dataname}_{diffusion_parameter}_extended_persistence")
pd_ds = PersistenceDiagramFromFiles(pd_dir)

# The second entry of every diagram sample is its label; remap them per dataset
labels = labels_preprocess(np.array([pd_ds[k][1] for k in range(len(pd_ds))]))
print(f'{dataname} labels: {np.unique(labels)}')

# Convert the diagrams, then rotate them
epds, _ = convert_to_epd_list(pd_ds)
repds = rotate_epds(epds)  # List of rotated persistence diagrams

# Rasterize each rotated diagram into its slot of the buffer
for slot, diagram in enumerate(repds):
    ppd[slot] = pixelization(diagram, grid_size=grid_size, device='cpu')

# Cross-check: labels stored with the graphs must equal the diagram labels
graph_labels = [graph.y.item() for graph in dataset]
label_matches = np.array(graph_labels) == labels
num_same_labels = label_matches.sum()
sanity = num_same_labels == len(dataset)
print(f"Graph labels are the same as the labels in the persistence diagram dataset: {sanity}")

if not sanity:
    print("Warning: Graph labels do not match persistence diagram labels.")

# Attach a constant node feature and the matching diagram image to every graph
data_list = []
for position, graph in enumerate(dataset):
    graph.node_feat = torch.ones((graph.num_nodes, 1), dtype=torch.float32)
    graph.ppd = ppd[position]  # shape (4, grid_size, grid_size)
    data_list.append(graph)

# Stratified K-Fold cross-validation.
# A fresh model, optimizer and LR scheduler are created for every fold, and the
# value recorded per fold is the highest accuracy seen on that fold's held-out
# split across epochs.
# NOTE(review): because the best epoch is picked on the same split that is
# reported, the final average is an optimistic estimate - confirm this matches
# the intended evaluation protocol.
skfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
criterion = torch.nn.CrossEntropyLoss().to(device)
best_test_acc_list = []

# Cross-validation loop
for fold, (train_idx, test_idx) in enumerate(skfold.split(data_list, labels)):
    print(f"Starting Fold {fold + 1}/{n_splits}")

    # Load model
    model = load_model(
        model_name, device, num_classes, grid_size, patch_size,
        depth=depth, embed_dim=embed_dim
    )

    # Data loaders (only the training split is shuffled)
    train_loader = DataLoader(Subset(data_list, train_idx), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(Subset(data_list, test_idx), batch_size=batch_size)

    # Optimizer and scheduler: a single cosine cycle over `epochs`, with the
    # warm-up start and the floor both set to 5% of the base learning rate
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = CosineLRScheduler(
        optimizer,
        t_initial=epochs,
        cycle_mul=1,
        lr_min=0.05 * lr,
        cycle_decay=1.0,
        warmup_lr_init=0.05 * lr,
        warmup_t=warmup_t,
        cycle_limit=1,
        t_in_epochs=True
    )

    best_test_acc = 0.0
    best_epoch = 0
    epochs_no_improve = 0  # Counter for epochs with no improvement

    # Training loop
    for epoch in range(epochs):
        # The scheduler is handed to the training helper.
        # NOTE(review): presumably train_one_epoch steps it - confirm in utils.
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, scheduler=scheduler
        )

        # Validation
        test_loss, test_acc = test(model, test_loader, criterion, device)

        # Track the best accuracy and how long it has been since it improved
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            best_epoch = epoch
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping: previously the counter was maintained but never
        # acted on, and `patience` / `min_epochs` were unused. Stop once at
        # least `min_epochs` epochs have run and the accuracy has not improved
        # for `patience` consecutive epochs.
        if epoch + 1 >= min_epochs and epochs_no_improve >= patience:
            print(f'Early stopping at epoch {epoch + 1} (best epoch: {best_epoch + 1})')
            break

    print(f'Fold {fold + 1}/{n_splits} - Best Test Accuracy: {best_test_acc:.3f}')
    best_test_acc_list.append(best_test_acc)

# Final results: mean and standard deviation of the per-fold best accuracies
avg_acc = np.mean(best_test_acc_list)
std_acc = np.std(best_test_acc_list)
print(f'Average Best Test Accuracy over {n_splits} folds: {avg_acc:.3f} ± {std_acc:.3f}')




No TPUs...
Dataset: PROTEINS, Model: xpert, Patch Size: 5, Embed Dim: 192, Depth: 5
Dataset PROTEINS already exists! Skipping: dataset will not be created.
PROTEINS labels: [0 1]
Graph labels are the same as the labels in the persistence diagram dataset: True
Starting Fold 1/10
Fold 1/10 - Best Test Accuracy: 0.714
Starting Fold 2/10
Fold 2/10 - Best Test Accuracy: 0.723
Starting Fold 3/10
Fold 3/10 - Best Test Accuracy: 0.777
Starting Fold 4/10
Fold 4/10 - Best Test Accuracy: 0.757
Starting Fold 5/10
