In [1]:
import time
import hydra
from hydra.utils import to_absolute_path
import torch
import sys
import os 
import numpy as np 
from dgl.dataloading import GraphDataLoader

from omegaconf import DictConfig

from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel

project_path = os.path.abspath(os.path.join(os.path.abspath(''), '..', ''))
sys.path.append(project_path)

from python.create_dgl_dataset import TelemacDataset
from modulus.distributed.manager import DistributedManager
from modulus.launch.logging import (
    PythonLogger,
    RankZeroLoggingWrapper,
    initialize_wandb,
)
from modulus.launch.utils import load_checkpoint, save_checkpoint
from python.CustomMeshGraphNet import MeshGraphNet

In [2]:
class MGNTrainer:
    def __init__(self, cfg: DictConfig, rank_zero_logger: RankZeroLoggingWrapper):
        assert DistributedManager.is_initialized()
        self.dist = DistributedManager()

        self.amp = cfg.amp

        # instantiate dataset
        dataset = TelemacDataset(
            name="telemac_train",
            data_dir=to_absolute_path(cfg.data_dir),
            split="train",
            num_samples=cfg.num_training_samples,
            num_steps=cfg.num_training_time_steps
        )

        # instantiate dataloader
        self.dataloader = GraphDataLoader(
            dataset,
            batch_size=cfg.batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            use_ddp=self.dist.world_size > 1,
            num_workers=cfg.num_dataloader_workers,
        )

        # instantiate the model
        self.model = MeshGraphNet(
            cfg.num_input_features,
            cfg.num_edge_features,
            cfg.num_output_features,
            processor_size=4,
            hidden_dim_processor=64,
            hidden_dim_node_encoder=64,
            hidden_dim_edge_encoder=64,
            hidden_dim_node_decoder=64,
            do_concat_trick=cfg.do_concat_trick,
            num_processor_checkpoint_segments=cfg.num_processor_checkpoint_segments,
        )
        if cfg.jit:
            if not self.model.meta.jit:
                raise ValueError("MeshGraphNet is not yet JIT-compatible.")
            self.model = torch.jit.script(self.model).to(self.dist.device)
        else:
            self.model = self.model.to(self.dist.device)
        

        # distributed data parallel for multi-node training
        if self.dist.world_size > 1:
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[self.dist.local_rank],
                output_device=self.dist.device,
                broadcast_buffers=self.dist.broadcast_buffers,
                find_unused_parameters=self.dist.find_unused_parameters,
            )

        # enable train mode
        self.model.train()

        # instantiate loss, optimizer, and scheduler
        self.criterion = torch.nn.MSELoss()

        self.optimizer = None
        try:
            if cfg.use_apex:
                from apex.optimizers import FusedAdam

                self.optimizer = FusedAdam(self.model.parameters(), lr=cfg.lr)
        except ImportError:
            rank_zero_logger.warning(
                "NVIDIA Apex (https://github.com/nvidia/apex) is not installed, "
                "FusedAdam optimizer will not be used."
            )
        if self.optimizer is None:
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr)
        rank_zero_logger.info(f"Using {self.optimizer.__class__.__name__} optimizer")

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=lambda epoch: cfg.lr_decay_rate**epoch
        )
        self.scaler = GradScaler()

        # load checkpoint
        if self.dist.world_size > 1:
            torch.distributed.barrier()
            
        self.epoch_init = load_checkpoint(
            to_absolute_path(cfg.ckpt_path),
            models=self.model,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            scaler=self.scaler,
            device=self.dist.device,
        )

    def train(self, graph):
        graph = graph.to(self.dist.device)
        self.optimizer.zero_grad()
        loss = self.forward(graph)
        self.backward(loss)
        self.scheduler.step()
        return loss

    #def forward(self, graph):
    #    # forward pass
    #    with autocast(enabled=self.amp):
    #        pred = self.model(graph.ndata["x"], graph.edata["x"], graph)
    #        #print(np.sum((graph.ndata['x'][:, 1] == 1).cpu().detach().numpy()))
    #        #print(np.sum((graph.ndata['x'][:, 0] == 1).cpu().detach().numpy()))
    #        #print(np.sum(((graph.ndata['x'][:, 0] == 1)|(graph.ndata['x'][:, 1] == 1)).cpu().detach().numpy()))
    #        mask_full = (graph.ndata['x'][:, 0] == 1)  # Mask for [1,0,0,0] nodes
    #        mask_partial = (graph.ndata['x'][:, 1] == 1)
    #        loss = self.criterion(pred[mask_full], graph.ndata["y"][mask_full])
    #        loss2 = self.criterion(pred[mask_partial][:,-2:], graph.ndata["y"][mask_partial][:,-2:])
    #        print('loss {}'.format(loss.cpu().detach()))
    #        print('loss2 {}'.format(loss2.cpu().detach()))
    #        return loss
    
    def forward(self, graph):
        # Forward pass
        with autocast(enabled=self.amp):
            pred = self.model(graph.ndata["x"], graph.edata["x"], graph)
            #print("pred shape {}".format(pred.shape))
    
            # Extracting the target labels
            target = graph.ndata['y']
            #print("target shape {}".format(target.shape))
    
            # Create masks for different node types
            mask_full = (graph.ndata['x'][:, 0] == 1)  # Mask for [1,0,0,0] nodes
            mask_partial = (graph.ndata['x'][:, 1] == 1)  # Mask for [0,1,0,0] nodes
            
            coeff_full = mask_full.shape[0]/(mask_full.shape[0]+mask_partial.shape[0])
            
            coeff_partial = mask_partial.shape[0]/(mask_full.shape[0]+mask_partial.shape[0])
            
            #print('mask full shape {}'.format(mask_full.shape))
            #print('mask partial shape {}'.format(mask_partial.shape))
    
            # Initialize loss to zero, ensure it has gradient support if no nodes are selected
            loss = torch.tensor(0.0, device=self.dist.device, requires_grad=True)
    
            # Compute loss for fully considered nodes
            if torch.any(mask_full):
                pred_full = pred[mask_full]
                target_full = target[mask_full]
                
                #print('pred full shape {}'.format(pred_full.shape))
                #print('target full shape {}'.format(target_full.shape))
                
                loss_full = self.criterion(pred_full, target_full)
                #print(loss_full)
                loss = loss + coeff_full*loss_full
    
            # Compute loss for partially considered nodes (last two dimensions)
            if torch.any(mask_partial):
                pred_partial = pred[mask_partial][:, -2:]  # Select only the last two dimensions
                target_partial = target[mask_partial][:, -2:]  # Corresponding targets
                
                #print('pred partial shape {}'.format(pred_partial.shape))
                #print('target partial shape {}'.format(target_partial.shape))
                
                loss_partial = self.criterion(pred_partial, target_partial)
                #print(loss_partial)
                loss = loss + coeff_partial*loss_partial
    
        return loss


    def backward(self, loss):
        # backward pass
        if self.amp:
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            self.optimizer.step()

In [None]:
import hydra
from hydra import initialize, compose
from omegaconf import DictConfig, OmegaConf

def train_model(cfg: DictConfig):
    # Initialize the distributed manager
    DistributedManager.initialize()
    dist = DistributedManager()

    logger = PythonLogger("main")
    rank_zero_logger = RankZeroLoggingWrapper(logger, dist)
    rank_zero_logger.file_logging()

    trainer = MGNTrainer(cfg, rank_zero_logger)
    start = time.time()
    rank_zero_logger.info("Training started...")
    for epoch in range(100):
        for graph in trainer.dataloader:
            loss = trainer.train(graph)
        
        if epoch % 10 == 0:
            rank_zero_logger.info(
                f"epoch: {epoch}, loss: {loss:.3e}, time per epoch: {(time.time() - start):.3e}"
            )
        # Save checkpoint
        if dist.world_size > 1:
            torch.distributed.barrier()
        if dist.rank == 0 and epoch % 20 == 0:
            save_checkpoint(
                to_absolute_path(cfg.ckpt_path),
                models=trainer.model,
                optimizer=trainer.optimizer,
                scheduler=trainer.scheduler,
                scaler=trainer.scaler,
                epoch=epoch,
            )
            logger.info(f"Saved model on rank {dist.rank}")
        start = time.time()
    rank_zero_logger.info("Training completed!")

# Initialize Hydra and set the configuration directory
with initialize(config_path="../bin/conf"):
    # Compose the configuration using the config name
    cfg = compose(config_name="config")
    
    # Display the configuration (optional)
    print(OmegaConf.to_yaml(cfg))

    # Now call the training function with the composed config
    train_model(cfg)

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path="../bin/conf"):


data_dir: ./data/toy_one_traj.bin
batch_size: 10
epochs: 1000
num_training_samples: 1
num_training_time_steps: 720
lr: 0.0005
lr_decay_rate: 0.9999991
num_input_features: 9
num_output_features: 3
num_edge_features: 3
use_apex: true
amp: false
jit: false
num_dataloader_workers: 4
do_concat_trick: true
num_processor_checkpoint_segments: 0
recompute_activation: false
ckpt_path: ./data/checkpoints

{'h': tensor([0.0953]), 'u': tensor([-0.0004]), 'v': tensor([0.1140]), 'strickler': tensor([16.2529]), 'z': tensor([-0.5487]), 'delta_h': tensor([0.0002]), 'delta_u': tensor([-7.2438e-07]), 'delta_v': tensor([0.0002]), 'h_std': tensor([0.3001]), 'u_std': tensor([0.0503]), 'v_std': tensor([0.3657]), 'strickler_std': tensor([7.2369]), 'z_std': tensor([0.8599]), 'delta_h_std': tensor([0.0034]), 'delta_u_std': tensor([0.0113]), 'delta_v_std': tensor([0.0140])}


[15:47:12 - main - INFO] [94mUsing FusedAdam optimizer[0m
[15:47:12 - checkpoint - INFO] [92mLoaded model state dictionary /users/daml/vmercier/gnn_modulus_project/data/checkpoints/MeshGraphNet.0.980.mdlus to device cuda:0[0m
[15:47:12 - checkpoint - INFO] [92mLoaded checkpoint file /users/daml/vmercier/gnn_modulus_project/data/checkpoints/checkpoint.0.980.pt to device cuda:0[0m
[15:47:12 - checkpoint - INFO] [92mLoaded optimizer state dictionary[0m
[15:47:12 - checkpoint - INFO] [92mLoaded scheduler state dictionary[0m
[15:47:12 - checkpoint - INFO] [92mLoaded grad scaler state dictionary[0m
[15:47:12 - main - INFO] [94mTraining started...[0m
[15:47:46 - main - INFO] [94mepoch: 0, loss: 1.057e-02, time per epoch: 3.387e+01[0m
[15:47:46 - checkpoint - INFO] [92mSaved model state dictionary: /users/daml/vmercier/gnn_modulus_project/data/checkpoints/MeshGraphNet.0.0.mdlus[0m
[15:47:46 - checkpoint - INFO] [92mSaved training checkpoint: /users/daml/vmercier/gnn_modulus_