In [1]:
import time
import hydra
from hydra.utils import to_absolute_path
import torch
import sys
import os
import torch.nn as nn

import argparse

from dgl.dataloading import GraphDataLoader
import dgl
from omegaconf import DictConfig

from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel

#project_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', ''))
#sys.path.append(project_path)
project_path = os.path.abspath(os.path.join(os.getcwd(), '..', ''))
sys.path.append(project_path)

from python.create_dgl_dataset import TelemacDataset
from modulus.distributed.manager import DistributedManager
from modulus.launch.logging import (
    PythonLogger,
    RankZeroLoggingWrapper,
    initialize_wandb,
)
from modulus.launch.utils import load_checkpoint, save_checkpoint
from python.CustomMeshGraphNet import MeshGraphNet

In [4]:

# Ensure that DGL and other dependencies are installed
# !pip install dgl

# Define your collate function
def collate_fn(batch):
    # batch is a list of sequences
    # Each sequence is a list of graphs (of length sequence_length)
    # We want to batch the graphs at each time step across sequences

    sequence_length = len(batch[0])  # Assuming all sequences have the same length

    batched_graphs = []
    for t in range(sequence_length):
        graphs_at_t = [sequence[t] for sequence in batch]
        batched_graph = dgl.batch(graphs_at_t)
        batched_graphs.append(batched_graph)

    return batched_graphs



# Define the trainer class
class MGNTrainer:
    def __init__(self):
        self.sequence_length = 10  # Adjust as needed
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Instantiate the dataset and dataloader
        dataset = TelemacDataset(
            name="telemac_train",
            data_dir="/work/m24046/m24046mrcr/results_data_30min/Multimesh_2_32_True.bin",
            dynamic_data_files=['/work/m24046/m24046mrcr/results_data_30min/Group_3_peak_1800_Group_3_peak_1800_0_0-80.pkl'],
            split="train",
            ckpt_path="/work/m24046/m24046mrcr//work/m24046/m24046mrcr/new_tests_Group3/config1",
            normalize=True,
            sequence_length=self.sequence_length,
        )
        self.dataloader = GraphDataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            drop_last=True,
            pin_memory=True,
            use_ddp=False,
            num_workers=1,
            collate_fn=collate_fn,
        )

        # Instantiate the model
        # Instantiate the model
        self.model = MeshGraphNet(
            9,
            3,
            3,
            processor_size=10,
            hidden_dim_processor=64,
            hidden_dim_node_encoder=64,
            hidden_dim_edge_encoder=64,
            hidden_dim_node_decoder=64,
            do_concat_trick=True,
            num_processor_checkpoint_segments=1,
        )
        if 0 :
            self.model = torch.jit.script(self.model).to(self.device)
        else:
            self.model = self.model.to(self.device)
            
        self.model.train()
        
        self.criterion = torch.nn.MSELoss()

        self.optimizer = None
        try:
            if True:
                from apex.optimizers import FusedAdam

                self.optimizer = FusedAdam(self.model.parameters(), lr=0.0001)
        except ImportError:
            print(
                "NVIDIA Apex (https://github.com/nvidia/apex) is not installed, "
                "FusedAdam optimizer will not be used."
            )
        if self.optimizer is None:
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=lambda epoch: 0.999999**epoch
        )
        self.scaler = GradScaler()

            
        self.epoch_init = load_checkpoint(
            to_absolute_path("/work/m24046/m24046mrcr//work/m24046/m24046mrcr/new_tests_Group3/config1"),
            models=self.model,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            scaler=self.scaler,
            device=self.device,
        )

        # Get node statistics from the dataset
        node_stats = dataset.node_stats
        device = torch.device('cuda')  # Use 'cuda' if GPU is available

        self.h_u_v_i_0_mean = torch.tensor([
            node_stats['h'],
            node_stats['u'],
            node_stats['v']
        ], dtype=torch.float32).to(device)

        self.h_u_v_i_0_std = torch.tensor([
            node_stats['h_std'],
            node_stats['u_std'],
            node_stats['v_std']
        ], dtype=torch.float32).to(device)

        self.delta_h_u_v_i_diff_mean = torch.tensor([
            node_stats['delta_h'],
            node_stats['delta_u'],
            node_stats['delta_v']
        ], dtype=torch.float32).to(device)

        self.delta_h_u_v_i_diff_std = torch.tensor([
            node_stats['delta_h_std'],
            node_stats['delta_u_std'],
            node_stats['delta_v_std']
        ], dtype=torch.float32).to(device)

    def _denormalize_data(self, tensor, mean, std):
        return tensor * std + mean

    def _normalize_data(self, tensor, mean, std):
        return (tensor - mean) / std

    def modify_graph(self, graph_0, pred_0, graph_1):
        # Clone graph_0 to avoid modifying the original
        predicted_modified_graph = graph_0.clone()

        pred_0_detached = pred_0

        # Update node features in predicted_modified_graph
        node_features_mod = predicted_modified_graph.ndata['x'].clone()
        one_hot_vectors = node_features_mod[:, :4]

        # Nodes with one-hot [0, 0, 1, 0]: replace x values with those from graph_1
        mask_replace = (one_hot_vectors == torch.tensor([0, 0, 1, 0], device=one_hot_vectors.device)).all(dim=1)
        node_features_mod[mask_replace] = graph_1.ndata['x'][mask_replace]
        # Nodes with one-hot [0, 1, 0, 0]: replace 7th feature with that from graph_1
        mask_7th = (one_hot_vectors == torch.tensor([0, 1, 0, 0], device=one_hot_vectors.device)).all(dim=1)
        node_features_mod[mask_7th, 6] = graph_1.ndata['x'][mask_7th, 6]
        # For other nodes, update 'h', 'u', 'v' features with pred_0
        mask_other = ~(mask_replace | mask_7th)

        # Denormalize current 'h', 'u', 'v' features
        node_features_to_update = node_features_mod[mask_other][:, 6:9]
        unnormalized_current = self._denormalize_data(
            node_features_to_update,
            self.h_u_v_i_0_mean,
            self.h_u_v_i_0_std
        )

        # Denormalize predicted increments
        unnormalized_increment = self._denormalize_data(
            pred_0_detached[mask_other],
            self.delta_h_u_v_i_diff_mean,
            self.delta_h_u_v_i_diff_std
        )

        # Update features
        unnormalized_updated = unnormalized_current + unnormalized_increment

        # Normalize updated features
        normalized_updated = self._normalize_data(
            unnormalized_updated,
            self.h_u_v_i_0_mean,
            self.h_u_v_i_0_std
        )
        
        # Update the node features in predicted_modified_graph
        node_features_mod[mask_other][:, 6:9] = normalized_updated

        predicted_modified_graph.ndata['x'] = node_features_mod

        return predicted_modified_graph

    def compute_loss(self, pred, target, graph):
        # Masks and loss computation as before
        one_hot_vectors = graph.ndata['x'][:, :4]
        mask_exclude = (one_hot_vectors == torch.tensor([0, 0, 1, 0], device=self.device)).all(dim=1)
        mask_include = ~mask_exclude
        mask_specific = (one_hot_vectors == torch.tensor([0, 1, 0, 0], device=self.device)).all(dim=1)
        mask_specific = mask_specific & mask_include
        mask_other = mask_include & ~mask_specific

        # Nodes with one-hot [0, 1, 0, 0]: predict only y[:,1:3]
        pred_specific = pred[mask_specific][:, 1:3]
        target_specific = target[mask_specific][:, 1:3]

        # Other nodes: predict all features
        pred_other = pred[mask_other]
        target_other = target[mask_other]

        # Compute losses
        loss_specific = self.criterion(pred_specific, target_specific)
        loss_other = self.criterion(pred_other, target_other)

        # Combine losses
        total_nodes = mask_include.sum()
        weight_specific = mask_specific.sum().float() / total_nodes
        weight_other = mask_other.sum().float() / total_nodes

        loss = weight_specific * loss_specific + weight_other * loss_other
        return loss

    def forward(self, batch):
        losses = []
        preds = []
        num_steps = self.sequence_length

        # Initialize the first graph
        graph = batch[0].to(self.device)
        node_features = graph.ndata["x"]
        edge_features = graph.edata.get("x", None)
        target = graph.ndata["y"]
        with autocast(enabled=True):
            pred = self.model(node_features, edge_features, graph)
            loss = self.compute_loss(pred, target, graph)
        losses.append(loss)
        preds.append(pred)

        # Loop through the sequence
        for t in range(1, num_steps):
            next_graph = batch[t].to(self.device)
            # Modify the current graph using the previous prediction
            graph = self.modify_graph(graph, preds[-1], next_graph)
            node_features = graph.ndata["x"]
            edge_features = graph.edata.get("x", None)
            target = next_graph.ndata["y"]
            with autocast(enabled=True):
                pred = self.model(node_features, edge_features, graph)
                loss = self.compute_loss(pred, target, graph)
            losses.append(loss)
            preds.append(pred)
        print(losses)
        #losses[0] = 0.0
        # Combine losses
        total_loss = sum(losses)
        return total_loss

    def train(self, batch):
        self.optimizer.zero_grad()

        # Enable gradient tracking for initial node features
        #batch[0].ndata['x'].requires_grad_(True)

        # Forward pass
        total_loss = self.forward(batch)

        # Backward pass
        self.scaler.scale(total_loss)
        self.scaler.scale(total_loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        # Check gradients of initial node features
        #initial_node_features_grad = batch[0].ndata['x'].grad
        #if initial_node_features_grad is not None:
        #    print("Gradient w.r.t initial node features exists.")
        #    print("Gradient norm:", initial_node_features_grad.norm().item())
        #else:
        #    print("No gradient w.r.t initial node features.")

        # Check if gradients are flowing through the model parameters
        #for name, param in self.model.named_parameters():
        #    if param.grad is not None:
        #        print(f"Gradient for {name}: {param.grad.norm().item()}")
        #    else:
        #        print(f"No gradient for {name}")

        

        return total_loss


# Instantiate the trainer
trainer = MGNTrainer()

# Run the training step
#for batch in trainer.dataloader:
#    loss = trainer.train(batch)
#    print(f"Training loss: {loss}")
#    break  # Only run one batch for the test


Normalizing data...


[16:57:36 - checkpoint - INFO] [92mLoaded model state dictionary /work/m24046/m24046mrcr/work/m24046/m24046mrcr/new_tests_Group3/config1/MeshGraphNet.0.1310.mdlus to device cuda[0m
[16:57:36 - checkpoint - INFO] [92mLoaded checkpoint file /work/m24046/m24046mrcr/work/m24046/m24046mrcr/new_tests_Group3/config1/checkpoint.0.1310.pt to device cuda[0m
[16:57:36 - checkpoint - INFO] [92mLoaded optimizer state dictionary[0m
[16:57:36 - checkpoint - INFO] [92mLoaded scheduler state dictionary[0m
[16:57:36 - checkpoint - INFO] [92mLoaded grad scaler state dictionary[0m


In [5]:
 for epoch in range(100):
    total_loss = 0 
    for batch in trainer.dataloader:
        loss = trainer.train(batch)
        total_loss += loss.item()
        print(f"batch loss: {loss.item()}")
                
    total_loss = total_loss / len(trainer.dataloader)
    print(f"Training loss: {total_loss}")


[tensor(0.2707, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.4414, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.2742, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.3573, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.2783, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.0559, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.4228, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.2855, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.2077, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.2994, device='cuda:0', grad_fn=<AddBackward0>)]
batch loss: 2.8932418823242188
[tensor(0.1902, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.1089, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.1142, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.0834, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.2730, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.3146, device='cuda:0', grad_fn=<AddBackward0>), tensor(0.2531, device='cuda:0', grad_fn=<AddBackward0>)


KeyboardInterrupt

