In [1]:
# Install required packages (if not already installed)
import subprocess
import sys

packages = [
    'torch',
    'torch-geometric',
    'numpy',
    'tqdm',
    'wandb',
]

for pkg in packages:
    try:
        __import__(pkg.replace('-', '_'))
    except ImportError:
        print(f'Installing {pkg}...')
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', pkg])

print('All packages ready!')

  cpu = _conversion_method_template(device=torch.device("cpu"))


Installing torch-geometric...
Installing wandb...
All packages ready!


In [2]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool
from torch_geometric.datasets import QM9
from torch_geometric.data import DataLoader, Data
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import os

print('Imports successful')

  from .autonotebook import tqdm as notebook_tqdm


Imports successful


## GCN Model Definition

In [11]:
class GCN(nn.Module):
    """Configurable GCN with optional batchnorm, dropout and mean+max pooled MLP readout."""

    def __init__(
        self,
        num_node_features,
        hidden_channels: int = 128,
        num_layers: int = 3,
        dropout: float = 0.0,
        use_batchnorm: bool = True,
    ):
        super(GCN, self).__init__()

        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.dropout = float(dropout)
        self.use_batchnorm = use_batchnorm

        # Build GCN layers
        self.convs = nn.ModuleList()
        in_channels = num_node_features
        for i in range(num_layers):
            out_channels = hidden_channels
            self.convs.append(GCNConv(in_channels, out_channels))
            in_channels = out_channels

        # Optional batchnorms
        if use_batchnorm and num_layers > 0:
            self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_channels) for _ in range(num_layers)])
        else:
            self.bns = None

        # MLP readout: (mean+max pooled) -> hidden -> 1
        readout_hidden = max(hidden_channels // 2, 16)
        self.readout = nn.Sequential(
            nn.Linear(2 * hidden_channels, readout_hidden),  # 2x because mean+max concat
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout),
            nn.Linear(readout_hidden, 1),
        )

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Node embedding through stacked GCN layers
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            # Apply BatchNorm before the non-linearity
            if self.bns is not None:
                x = self.bns[i](x)
            x = F.relu(x)
            if self.dropout > 0:
                x = F.dropout(x, p=self.dropout, training=self.training)

        # Readout: concatenate mean and max pooled node features per graph
        mean_pool = global_mean_pool(x, batch)
        max_pool = global_max_pool(x, batch)
        x = torch.cat([mean_pool, max_pool], dim=1)  # [batch_size, 2*hidden_channels]

        # MLP head -> returns shape [batch_size, 1]
        out = self.readout(x)
        return out

print('GCN model defined')

GCN model defined


## Configuration & Data Setup

In [12]:
# Configuration (hardcoded)
cfg = {
    'seed': 0,
    'device': 'cpu',  # use 'cuda' if GPU available
    'data_dir': './data',
    'target_idx': 2,  # QM9 property index (alpha polarizability)
    
    # Model
    'num_node_features': 11,
    'hidden_channels': 128,
    'num_layers': 3,
    'dropout': 0.0,
    'use_batchnorm': True,
    
    # Dataset
    'batch_size_train': 100,
    'batch_size_inference': 2048,
    'num_workers': 0,
    'splits': [0.72, 0.08, 0.1, 0.1],
    'subset_size': None,  # None = use full dataset
    
    # Trainer
    'total_epochs': 250,
    'validation_interval': 10,
    'early_stopping_patience': 10,
    'lr': 0.0005,
    'weight_decay': 0.005,
    
    # Mean-Teacher
    'lambda_mt': 1.0,
    'ema_decay': 0.99,
    'mt_augment_scale': 0.05,
    'grad_clip_norm': 0.0,
    'use_target_normalization': True,
}

# Set random seed
torch.manual_seed(cfg['seed'])
np.random.seed(cfg['seed'])

device = torch.device(cfg['device'])
print(f'Device: {device}')
print(f'Config: {cfg}')

Device: cpu
Config: {'seed': 0, 'device': 'cpu', 'data_dir': './data', 'target_idx': 2, 'num_node_features': 11, 'hidden_channels': 128, 'num_layers': 3, 'dropout': 0.0, 'use_batchnorm': True, 'batch_size_train': 100, 'batch_size_inference': 2048, 'num_workers': 0, 'splits': [0.72, 0.08, 0.1, 0.1], 'subset_size': None, 'total_epochs': 250, 'validation_interval': 10, 'early_stopping_patience': 10, 'lr': 0.0005, 'weight_decay': 0.005, 'lambda_mt': 1.0, 'ema_decay': 0.99, 'mt_augment_scale': 0.05, 'grad_clip_norm': 0.0, 'use_target_normalization': True}


In [13]:
# Download and load QM9 dataset
print('Downloading QM9 dataset...')
dataset = QM9(root=cfg['data_dir'])

# Normalize edge attributes if needed
if hasattr(dataset.data, 'edge_attr') and dataset.data.edge_attr is not None:
    dataset.data.edge_attr = dataset.data.edge_attr.float()

# Get target property
y = dataset.data.y[:, cfg['target_idx']]
dataset.data.y = y

print(f'Dataset size: {len(dataset)}')
print(f'Node features: {dataset.num_node_features}')
print(f'Edge features: {dataset.num_edge_features}')

Downloading QM9 dataset...
Dataset size: 130831
Node features: 11
Edge features: 4
Dataset size: 130831
Node features: 11
Edge features: 4


  if hasattr(dataset.data, 'edge_attr') and dataset.data.edge_attr is not None:
  dataset.data.edge_attr = dataset.data.edge_attr.float()
  y = dataset.data.y[:, cfg['target_idx']]
  dataset.data.y = y


In [14]:
# Split dataset
n = len(dataset)
splits = cfg['splits']
indices = torch.randperm(n)

train_idx = indices[:int(splits[0] * n)]
val_idx = indices[int(splits[0] * n):int((splits[0] + splits[1]) * n)]
test_idx = indices[int((splits[0] + splits[1]) * n):]

# Split unlabeled (for semi-supervised)
unlabeled_idx = indices[int((splits[0] + splits[1] + splits[2]) * n):]

# Apply subset if configured
if cfg['subset_size'] is not None:
    train_idx = train_idx[:cfg['subset_size']]
    unlabeled_idx = unlabeled_idx[:cfg['subset_size']]

train_data = [dataset[i] for i in train_idx]
val_data = [dataset[i] for i in val_idx]
test_data = [dataset[i] for i in test_idx]
unlabeled_data = [dataset[i] for i in unlabeled_idx]

print(f'Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}, Unlabeled: {len(unlabeled_data)}')

# Create dataloaders
train_loader = DataLoader(train_data, batch_size=cfg['batch_size_train'], shuffle=True, num_workers=cfg['num_workers'])
val_loader = DataLoader(val_data, batch_size=cfg['batch_size_inference'], shuffle=False, num_workers=cfg['num_workers'])
test_loader = DataLoader(test_data, batch_size=cfg['batch_size_inference'], shuffle=False, num_workers=cfg['num_workers'])
unlabeled_loader = DataLoader(unlabeled_data, batch_size=cfg['batch_size_train'], shuffle=True, num_workers=cfg['num_workers'])

print('Dataloaders created')

Train: 94198, Val: 10466, Test: 26167, Unlabeled: 13084
Dataloaders created


  train_loader = DataLoader(train_data, batch_size=cfg['batch_size_train'], shuffle=True, num_workers=cfg['num_workers'])
  val_loader = DataLoader(val_data, batch_size=cfg['batch_size_inference'], shuffle=False, num_workers=cfg['num_workers'])
  test_loader = DataLoader(test_data, batch_size=cfg['batch_size_inference'], shuffle=False, num_workers=cfg['num_workers'])
  unlabeled_loader = DataLoader(unlabeled_data, batch_size=cfg['batch_size_train'], shuffle=True, num_workers=cfg['num_workers'])


## Trainer: Semi-Supervised Ensemble with Mean-Teacher

In [15]:
class SemiSupervisedEnsemble:
    def __init__(
        self,
        supervised_criterion,
        optimizer,
        scheduler,
        device,
        models,
        logger,
        datamodule,
        lambda_mt,
        ema_decay,
        mt_augment_scale,
        grad_clip_norm: float = 0.0,
        use_target_normalization: bool = True,
    ):
        # Semi-supervised hyperparameters (configurable)
        self.device = device
        self.models = models
        for m in self.models:
            m.to(device)
            for p in m.parameters():
                p.requires_grad = True

        self.teacher_models = deepcopy(models)
        for teacher in self.teacher_models:
            teacher.to(device)
            for p in teacher.parameters():
                p.requires_grad = False
            # keep teacher in eval mode so it provides a stable target (no dropout/bn updates)
            teacher.eval()

        for m in self.models:
            m.to(device)
            for p in m.parameters():
                p.requires_grad = True

                
        # set from init args so they can be configured via Hydra
        self.lambda_mt = float(lambda_mt)
        self.lambda_cps = 1.0
        self.ema_decay = float(ema_decay)
        self.mt_augment_scale = float(mt_augment_scale)
        self.grad_clip_norm = float(grad_clip_norm)
        self.use_target_normalization = bool(use_target_normalization)

        # Optim related things
        self.supervised_criterion = supervised_criterion
        all_params = [p for m in self.models for p in m.parameters()]
        self.optimizer = optimizer(params=all_params)
        self.scheduler = scheduler(optimizer=self.optimizer)

        # Dataloader setup
        self.train_dataloader = datamodule.train_dataloader()
        self.val_dataloader = datamodule.val_dataloader()
        self.test_dataloader = datamodule.test_dataloader()
        self.unlabeled_train_dataloader = datamodule.unsupervised_train_dataloader()

        # Logging
        self.logger = logger
        # place to store best model weights found during training
        self._best_state = None
        self._best_epoch = None


    # ---------------------------
    # Mean Teacher EMA update
    # ---------------------------
    def update_teacher(self):
        for teacher, student in zip(self.teacher_models, self.models):
            for tp, sp in zip(teacher.parameters(), student.parameters()):
                tp.data = self.ema_decay * tp.data + (1.0 - self.ema_decay) * sp.data

    # ---------------------------
    # N-CPS consistency
    # ---------------------------
    def noisy_augment(self, data):
        # simple example: gaussian noise using configurable scale
        noisy_x = data.x + float(self.mt_augment_scale) * torch.randn_like(data.x)
        data_aug = deepcopy(data)
        data_aug.x = noisy_x
        return data_aug

    def validate(self):
        for model in self.models:
            model.eval()

        val_losses = []
        
        with torch.no_grad():
            for x, targets in self.val_dataloader:
                x, targets = x.to(self.device), targets.to(self.device)
                
                # Ensemble prediction
                preds = [model(x) for model in self.models]
                # If using target normalization, model outputs are in normalized space;
                # un-normalize before computing validation MSE so it's in original scale.
                if getattr(self, 'use_target_normalization', False):
                    preds = [(p * self.target_std + self.target_mean) for p in preds]
                avg_preds = torch.stack(preds).mean(0)

                val_loss = torch.nn.functional.mse_loss(avg_preds, targets)
                val_losses.append(val_loss.item())
        val_loss = np.mean(val_losses)
        return {"val_MSE": val_loss}

    def test(self):
        """Evaluate models on the test set and return test metrics.

        Returns a dict like {"test_MSE": float}.
        """
        for model in self.models:
            model.eval()

        test_losses = []
        with torch.no_grad():
            for x, targets in self.test_dataloader:
                x, targets = x.to(self.device), targets.to(self.device)

                preds = [model(x) for model in self.models]
                # If using target normalization, un-normalize predictions
                if getattr(self, 'use_target_normalization', False):
                    preds = [(p * self.target_std + self.target_mean) for p in preds]

                avg_preds = torch.stack(preds).mean(0)
                test_loss = torch.nn.functional.mse_loss(avg_preds, targets)
                test_losses.append(test_loss.item())

        test_loss = float(np.mean(test_losses)) if len(test_losses) > 0 else float('nan')
        # log and return
        try:
            self.logger.log_dict({"test_MSE": test_loss})
        except Exception:
            pass
        return {"test_MSE": test_loss}

    def train(self, total_epochs, validation_interval, early_stopping_patience=None, **kwargs):
        final_results = {}
        patience_counter = 0
        # allow overriding patience from config; default to 10 if not provided
        patience = int(early_stopping_patience) if early_stopping_patience is not None else 10
        best_val_loss = float('inf')

        # If target normalization is enabled, compute train target mean/std once
        self.target_mean = 0.0
        self.target_std = 1.0
        if self.use_target_normalization:
            all_targets = []
            for _, t in self.train_dataloader:
                # t may be (batch,1) or (batch,)
                if isinstance(t, (list, tuple)):
                    t = t[0]
                all_targets.append(t.detach().cpu())
            if len(all_targets) > 0:
                all_targets = torch.cat(all_targets, dim=0)
                self.target_mean = float(all_targets.mean())
                self.target_std = float(all_targets.std())
                if self.target_std == 0:
                    self.target_std = 1.0

        unlabeled_iter = iter(self.unlabeled_train_dataloader)

        for epoch in (pbar := tqdm(range(1, total_epochs + 1))):
            for m in self.models:
                m.train()

            supervised_log = []
            mt_log = []

            for x_labeled, targets in self.train_dataloader:
                # Get unlabeled batch
                try:
                    x_unl = next(unlabeled_iter)
                except StopIteration:
                    unlabeled_iter = iter(self.unlabeled_train_dataloader)
                    x_unl = next(unlabeled_iter)

                x_labeled, targets = x_labeled.to(self.device), targets.to(self.device)
                x_unl = x_unl[0].to(self.device)

                # create an augmented view for the student so MT loss is meaningful
                x_unl_student = self.noisy_augment(x_unl)

                self.optimizer.zero_grad()

                # -------------------------
                # 1. Supervised loss (optionally using target normalization)
                # -------------------------
                preds = [m(x_labeled) for m in self.models]

                if self.use_target_normalization:
                    targets_norm = (targets - self.target_mean) / self.target_std
                else:
                    targets_norm = targets

                # loss used for backward (on normalized targets when enabled)
                sup_losses = [self.supervised_criterion(p, targets_norm) for p in preds]
                sup_loss = sum(sup_losses) / len(self.models)

                # For logging, compute un-normalized MSE between ensemble preds and raw targets
                try:
                    with torch.no_grad():
                        preds_un = [(p * self.target_std + self.target_mean) if self.use_target_normalization else p for p in preds]
                        ensemble_un = torch.stack(preds_un).mean(0)
                        supervised_log.append(torch.nn.functional.mse_loss(ensemble_un, targets).item())
                except Exception:
                    supervised_log.append(sup_loss.item())


                # -------------------------
                # 2. Mean Teacher loss
                # -------------------------
                # student sees augmented view, teacher sees original (stable) view
                student_out = [m(x_unl_student) for m in self.models]
                teacher_out = [tm(x_unl).detach() for tm in self.teacher_models]


                mt_loss = 0
                for s, t in zip(student_out, teacher_out):
                    mt_loss += torch.nn.functional.mse_loss(s, t)
                mt_loss = mt_loss / len(self.models)
                mt_log.append(mt_loss.item())


                # -------------------------
                # Total loss
                # -------------------------
                loss = sup_loss + self.lambda_mt * mt_loss
                loss.backward()

                # gradient clipping (if enabled)
                if self.grad_clip_norm and self.grad_clip_norm > 0.0:
                    params = [p for m in self.models for p in m.parameters() if p.grad is not None]
                    torch.nn.utils.clip_grad_norm_(params, self.grad_clip_norm)

                self.optimizer.step()

                # Update EMA teacher
                self.update_teacher()

            self.scheduler.step()

            summary_dict = {
                "supervised_loss": np.mean(supervised_log),
                "mean_teacher_loss": np.mean(mt_log),
            }

            if epoch % validation_interval == 0 or epoch == total_epochs:
                val_metrics = self.validate()
                summary_dict.update(val_metrics)
                pbar.set_postfix(summary_dict)

                # Early stopping
                cur_val = val_metrics["val_MSE"]
                if cur_val < best_val_loss:
                    best_val_loss = cur_val
                    patience_counter = 0
                    # save best model weights (deepcopy state_dicts)
                    try:
                        self._best_state = [ {k: v.cpu().clone() for k, v in m.state_dict().items()} for m in self.models ]
                        self._best_epoch = epoch
                    except Exception:
                        self._best_state = None
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(f"Early stopping at epoch {epoch}")
                        break

            self.logger.log_dict(summary_dict, step=epoch)
            final_results = summary_dict

        # If we saved a best checkpoint during training, restore it so subsequent
        # testing uses the best-validation weights rather than the final weights.
        if self._best_state is not None:
            try:
                for m, state in zip(self.models, self._best_state):
                    m.load_state_dict(state)
                print(f"Restored best model from epoch {self._best_epoch} for testing/evaluation.")
            except Exception:
                print("Failed to restore best model state; using final weights.")

        return final_results

print('SemiSupervisedEnsemble trainer defined')

SemiSupervisedEnsemble trainer defined


## Training

In [17]:
# Create wrapper to convert PyG Data objects to (x, targets) tuples
class TupleDataLoader:
    def __init__(self, loader, has_labels=True):
        self.loader = loader
        self.has_labels = has_labels
    
    def __iter__(self):
        for batch in self.loader:
            if self.has_labels:
                # Return (batch, targets) where targets is extracted from batch.y
                targets = batch.y.view(-1, 1) if batch.y.dim() == 1 else batch.y
                yield batch, targets
            else:
                # For unlabeled data, return (batch,) as a tuple
                yield (batch,)
    
    def __len__(self):
        return len(self.loader)

# Create a simple datamodule wrapper
class SimpleDataModule:
    def __init__(self, train_loader, val_loader, test_loader, unlabeled_loader):
        self._train_loader = TupleDataLoader(train_loader, has_labels=True)
        self._val_loader = TupleDataLoader(val_loader, has_labels=True)
        self._test_loader = TupleDataLoader(test_loader, has_labels=True)
        self._unlabeled_loader = TupleDataLoader(unlabeled_loader, has_labels=False)
    
    def train_dataloader(self):
        return self._train_loader
    
    def val_dataloader(self):
        return self._val_loader
    
    def test_dataloader(self):
        return self._test_loader
    
    def unsupervised_train_dataloader(self):
        return self._unlabeled_loader

# Create a simple logger
class SimpleLogger:
    def log_dict(self, metrics, step=None):
        # Just print for standalone notebook
        pass

# Create model
model = GCN(
    num_node_features=cfg['num_node_features'],
    hidden_channels=cfg['hidden_channels'],
    num_layers=cfg['num_layers'],
    dropout=cfg['dropout'],
    use_batchnorm=cfg['use_batchnorm'],
)
models = [model]

# Loss
criterion = nn.MSELoss()

# Create datamodule and logger
datamodule = SimpleDataModule(train_loader, val_loader, test_loader, unlabeled_loader)
logger = SimpleLogger()

# Optimizer and scheduler factories (trainer will instantiate them)
def optimizer_factory(params):
    return torch.optim.AdamW(params, lr=cfg['lr'], weight_decay=cfg['weight_decay'])

def scheduler_factory(optimizer):
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg['total_epochs'])

# Create trainer
trainer = SemiSupervisedEnsemble(
    supervised_criterion=criterion,
    optimizer=optimizer_factory,
    scheduler=scheduler_factory,
    device=device,
    models=models,
    logger=logger,
    datamodule=datamodule,
    lambda_mt=cfg['lambda_mt'],
    ema_decay=cfg['ema_decay'],
    mt_augment_scale=cfg['mt_augment_scale'],
    grad_clip_norm=cfg['grad_clip_norm'],
    use_target_normalization=cfg['use_target_normalization'],
)

print('Model and trainer ready. Starting training...')
train_results = trainer.train(
    total_epochs=cfg['total_epochs'],
    validation_interval=cfg['validation_interval'],
    early_stopping_patience=cfg['early_stopping_patience']
)
print(f'Training complete. Final results: {train_results}')

Model and trainer ready. Starting training...


  2%|‚ñè         | 4/250 [01:15<1:17:26, 18.89s/it]



KeyboardInterrupt: 

## Testing & Final Evaluation

In [None]:
# Run test
test_results = trainer.test()
print(f'Test results: {test_results}')

# Run final validation
val_results = trainer.validate()
print(f'Validation results: {val_results}')

print(f'\n=== Summary ===')
print(f'Train (supervised + MT): {train_results}')
print(f'Validation MSE: {val_results["val_MSE"]:.6f}')
print(f'Test MSE: {test_results["test_MSE"]:.6f}')