# Simulating Complex Physics with Graph Networks: step by step

## Overview

• By Peng Chen, Shiyu Li, Haochen Shi as part of Stanford CS224W course project. 

• This tutorial provides a step-by-step guide for how to build a Graph Network to simulate complex physics.

**Before we get started:**
- This Colab includes a concise PyG implementation of paper ***Learning to Simulate Complex Physics with Graph Networks*.
- We adapted our code from open-source tensorflow implementation by DeepMind.
    - Link to pdf of this paper: https://arxiv.org/abs/2002.09405
    - Link to Deepmind's implementation: https://github.com/deepmind/deepmind-research/tree/master/learning_to_simulate
    - Link to video site by DeepMind: https://sites.google.com/view/learning-to-simulate
- Run **sequentially run all cells in each section**, so intermediate variables / packages will carry over to next cell.


## Device

We recommend using a GPU for this Colab. Click `Runtime` then `Change runtime type`. Then set `hardware accelerator` to **GPU**.

## Setup

installation of PyG on Colab can be a little bit tricky. Before we get started, let's check which version of PyTorch you are running.

In [1]:
#!pip install rectpack

In [2]:
# import re

# def identify_text(text):
#     # Define a regex pattern to extract the text between the timestamp and newline
#     pattern = r"\[\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\] - (.*?)\n"
    
#     # Search for the pattern in the text
#     match = re.search(pattern, text)
    
#     if match:
#         # Extract the text between the timestamp and newline
#         extracted_text = match.group(1)
#         print(extracted_text)
#     else:
#         print("Pattern not found")

# # Sample texts
# text1 = "[2024-08-31 15:38:16] - rollout_path\ntemp/rollouts/WaterDrop"
# text2 = "[2024-08-31 15:38:17] - self.metadata - OneStepDataset\nLength:9\nType:<class 'dict'>..."

# # Test the function
# identify_text(text1)  # Should print "rollout_path"
# identify_text(text2)  # Should print "self.metadata - OneStepDataset"

In [3]:
# Dataset Source #1:
# https://drive.google.com/file/d/1ZmiKpsQVLFxPOIff-LfFkZwe5ZYG1FEb/view?usp=drive_link

# Dataset Source #2:
# https://drive.google.com/drive/mobile/folders/11uuYl0peqPg2DQno64YPYMODPu8fjDXU?usp=sharing

In [4]:
#!pip install torch

In [5]:
!export LD_LIBRARY_PATH=/home/admin1/anaconda3/envs/GNN/lib:$LD_LIBRARY_PATH
!export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH    

In [6]:
import random
import os
from PIL import Image, ImageDraw, ImageFont, ImageOps
imageindex = 0

In [7]:
import re

def remove_timestamp(log_entry):
    # Use regex to match the timestamp pattern and remove it
    cleaned_entry = re.sub(r'\[\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\] - ', '', log_entry)
    return cleaned_entry

def text_to_image_function(text, font_size, output_file, selected_font_name):
    # Remove timestamp
    text = remove_timestamp(text)

    # Define the initial image size and other properties
    initial_max_width = 640
    initial_max_height = 640
    background_color = "white"
    text_color = "black"
    border_color = "black"
    padding = 20
    border_width = 1

    # Initialize variables for the actual size
    required_width = initial_max_width
    required_height = initial_max_height

    # Load the font
    font = ImageFont.truetype(selected_font_name, font_size)

    # Create a temporary image to measure the text size
    temp_image = Image.new("RGB", (initial_max_width, initial_max_height), background_color)
    draw = ImageDraw.Draw(temp_image)
    
    # Measure the text size
    text_size = draw.textsize(text, font=font)
    #text_size = draw.textbbox((0, 0), text, font=font)

    # Calculate the required size based on the measured text size
    required_width = text_size[0] + 2 * (padding + border_width)
    required_height = text_size[1] + 2 * (padding + border_width)

    # Ensure the image is not smaller than the initial size
    required_width = max(required_width, initial_max_width)
    required_height = max(required_height, initial_max_height)

    # Create the final image with the calculated size
    image = Image.new("RGB", (required_width, required_height), background_color)
    draw = ImageDraw.Draw(image)

    # Draw the text on the image
    text_position = (padding, padding)
    draw.text(text_position, text, fill=text_color, font=font)

    # Draw a border around the text
    border_rectangle = [
        padding - border_width, 
        padding - border_width, 
        padding + text_size[0] + border_width, 
        padding + text_size[1] + border_width
    ]
    draw.rectangle(border_rectangle, outline=border_color, width=border_width)

    border_rectangle1 = [
        padding - border_width, 
        padding - border_width, 
        padding + text_size[0] + border_width + 1, 
        padding + text_size[1] + border_width + 1
    ]
    
    # Crop the image to the size of the border
    cropped_image = image.crop(border_rectangle1)

    # Save the cropped image
    cropped_image.save(output_file, "PNG")

    # Optionally, copy the cropped image to the clipboard (requires `pyperclip` and `Pillow` integration)
    # pyperclip.copy(cropped_image)  # Not directly supported; requires custom implementation
    
    

In [8]:
import os
import torch
print(f"PyTorch has version {torch.__version__} with cuda {torch.version.cuda}")

PyTorch has version 1.12.0+cu102 with cuda 10.2


• Download necessary packages for PyG. 

• ensure your version of torch matches output from cell above. 

• In case of any issues, more information may be found on [PyG's installation page](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html)

!pip3 install torch==1.12.1+cu102 torchvision==0.13.1+cu102 torchaudio==0.12.1 torchtext --extra-index-url https://download.pytorch.org/whl/cu102

!pip install https://data.pyg.org/whl/torch-1.12.0%2Bcu102/torch_cluster-1.6.0%2Bpt112cu102-cp37-cp37m-linux_x86_64.whl

!pip install https://data.pyg.org/whl/torch-1.12.0%2Bcu102/torch_scatter-2.1.0%2Bpt112cu102-cp37-cp37m-linux_x86_64.whl

!pip install https://data.pyg.org/whl/torch-1.12.0%2Bcu102/torch_sparse-0.6.16%2Bpt112cu102-cp37-cp37m-linux_x86_64.whl

!pip install torch-geometric

!pip install matplotlib

!pip install networkx


# Dataset Preparation
!cd /home/admin1/Desktop/gnndataset/datasets/WaterDrop/

# metadata.json
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1o6cKxgbnfUUFPTX1JngBzB928w2bUIwk' -O metadata.json

# test_offset.json
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1vr4JiVliKCQNWVV4kziyusxNVUvQuAYL' -O test_offset.json

# test_particle_type.dat
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1Z_r9ivdKqKZzVJG80gb2uY6JDVRd0wAt' -O test_particle_type.dat

# test_position.dat
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1wCeBz1pZ5hxmlqWw4eylajg6pzFgQjIJ' -O test_position.dat

# train_offset.json
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=160wnp9PEc1HuzsBi7kO0ryMu3tnon2tI' -O train_offset.json

# train_particle_type.dat
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1LVtGLld7assF4sPk0mF2Bz2F7FBaxU0O' -O train_particle_type.dat

# train_position.dat
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1YCXcir_fmJZLvXkbPjchsrr8VuuWugH0' -O train_position.dat

# valid_offset.json
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1tiDP5uHMJQDTNxyRNSb6sEZCWAADPu8a' -O valid_offset.json

# valid_particle_type.dat
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1fXIw9RWM0xzfK2sGn1H0DaAOxzm59ZEd' -O valid_particle_type.dat

# valid_position.dat
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1U9QuV3Ra0E1tDD1HgXYCYyn4SeLKXQGs' -O valid_position.dat


## Dataset

• Dataset WaterDropSmall includes 100 videos of dropping water to ground rendered in a particle-based physics simulator. 

• It is a cropped version of WaterDrop dataset by Deepmind. 

• will download this dataset from Google Cloud stoarge to folder `temp/datasets` in file system. 

• may inspect downloaded files on **Files** menu on left of this Colab.

`metadata.json` file in dataset includes following information:
1. sequence length of each video data point
2. dimensionality, 2d or 3d
3. box bounds - specify bounding box for scene
4. default connectivity radius - defines size of each particle's neighborhood
5. statistics for normalization e.g. velocity mean and standard deviation and acceleration of particles


Each data point in dataset includes following information:
1. Particle type, such as water
2. particle positions at each frame in video

In [9]:
from datetime import datetime
import inspect
import os

# Global flags to enable/disable debugging and verbosity
DEBUG_ENABLED = True
VERBOSE_ENABLED = False

# Global dictionary to store logged headers and their counts
logged_header_counts = {}
folders_created = []  # Initialize an empty list

def debug_log(theVariable, functionName=None, ShowShape=False, ShowLength=False, ShowType=False, ExplicitVariableName=None):
    
    print("#################", flush=True)
    print("## theVariable ##", flush=True)
    print("#################", flush=True)
    print(theVariable, flush=True)
    print("", flush=True)    
    
    global logged_header_counts  # Access the global dictionary
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    frame = inspect.currentframe().f_back
    variable_names = [name for name, val in frame.f_locals.items() if val is theVariable]
    theVariableName = variable_names[0] if variable_names else ExplicitVariableName

    print("#####################", flush=True)
    print("## theVariableName ##", flush=True)
    print("#####################", flush=True)
    print(theVariableName, flush=True)
    print("", flush=True)    
    
    thefilename = ''
    if functionName is not None:
        functionName = functionName.replace("\\", "_")

    if DEBUG_ENABLED:
        
        if ExplicitVariableName is None:
            # INCLUDE functionName
            if functionName:
                header = f"{theVariableName} - {functionName}"
                log_message = f"{timestamp} {header}\n"
                thefilename = header
            else:
                # EXCLUDE functionName
                header = f"{theVariableName}"
                log_message = f"{timestamp} {header}\n"
                thefilename = header
        else:
                header = ExplicitVariableName
                log_message = f"{timestamp} {header}\n"
                thefilename = header


        print("############", flush=True)
        print("## header ##", flush=True)
        print("############", flush=True)
        print(header, flush=True)
        print("", flush=True)

            
        # Check if the header has been logged less than 2 times
        if header in logged_header_counts:
            if logged_header_counts[header] >= 1:
            # if logged_header_counts[header] >= 2:
                return  # Skip logging if the header has been logged twice
            else:
                logged_header_counts[header] += 1  # Increment the count
        else:
            logged_header_counts[header] = 1  # Add new header to the dictionary with a count of 1

        #if ShowShape:
        #            log_message += "Shape:" + str(theVariable.shape) + "\n"
        if ShowShape:
            if hasattr(theVariable, 'shape'):
                log_message += "Shape:" + str(theVariable.shape) + "\n"
            else:
                log_message += "Shape: Not applicable\n"

        
        # if ShowLength:
        #     if isinstance(theVariable, torch.Tensor):
        #         if theVariable.dim() == 1:
        #             log_message += "Length:" + str(len(theVariable)) + "\n"
        #         else:
        #             length = theVariable.numel()
        #             log_message += "Length:" + str(length) + "\n"
        #     else:
        #         log_message += "Length:" + str(len(theVariable)) + "\n"

        if ShowLength:
            try:
                if isinstance(theVariable, torch.Tensor):
                    if theVariable.dim() == 1:
                        log_message += "Length:" + str(len(theVariable)) + "\n"
                    else:
                        length = theVariable.numel()  # Total number of elements
                        log_message += "Length:" + str(length) + "\n"
                else:
                    log_message += "Length:" + str(len(theVariable)) + "\n"
            except (TypeError, AttributeError):
                log_message += "Length: Not applicable\n"

        
        # if ShowType:
        #     if isinstance(theVariable, torch.Tensor):
        #         log_message += "Type:" + str(theVariable.dtype) + "\n"
        #     else:
        #         log_message += "Type:" + str(type(theVariable)) + "\n"

        if ShowType:
            if isinstance(theVariable, torch.Tensor):
                if hasattr(theVariable, 'dtype'):
                    log_message += "Type:" + str(theVariable.dtype) + "\n"
                else:
                    log_message += "Type: No dtype attribute\n"
            else:
                log_message += "Type:" + str(type(theVariable)) + "\n"

        
        # VARIABLE CONTENTS
        log_message += str(theVariable) + "\n"

        # Create an image
        global imageindex
        imageindex = imageindex + 1
        thefilename = thefilename.replace("\\", "_")
        text = log_message
        # Create the folder if it does not exist
        print("##################", flush=True)
        print("## functionName ##", flush=True)
        print("##################", flush=True)
        print(functionName, flush=True)
        print("", flush=True)

        print("#################", flush=True)
        print("## thefilename ##", flush=True)
        print("#################", flush=True)
        print(thefilename, flush=True)
        print("", flush=True)

        # Create Folder if it does nto exist
        os.makedirs("outputpng", exist_ok=True)
        os.makedirs(r"outputpng/" + str(functionName), exist_ok=True)
        
        folder_path = str(functionName)
        # Append the string only if it doesn't already exist in the list
        if folder_path not in folders_created:
            folders_created.append(folder_path)
            
        output_file = os.path.join(r"outputpng", str(functionName), f"{imageindex:07d} {thefilename}.png")
        # output_file = os.path.join("outputpng\\" + str(functionName), f"{imageindex:07d}{thefilename}.png")
        text_to_image_function(text, 16, output_file, "/usr/share/fonts/truetype/freefont/Arial.ttf")

        log_message += "---------------------------------------------------------" + "\n"

        # Get the current date and time
        current_date = datetime.now().strftime('%Y-%m-%d')

        # Write to log file
        with open(f'debugGNN_{current_date}.txt', 'a') as file:
            file.write(log_message)

    if VERBOSE_ENABLED:
        print(timestamp)
        
        if ShowShape:
            print("Shape:", theVariable.shape)
        if ShowLength:
            if isinstance(theVariable, torch.Tensor):
                if theVariable.dim() == 1:
                    print("Length:", str(len(theVariable)), flush=True)
                else:
                    length = theVariable.numel()
                    print("Length:", str(length), flush=True)
            else:
                print("Length:", str(len(theVariable)), flush=True)                    
        if ShowType:
            if isinstance(theVariable, torch.Tensor):
                print("Type:", str(theVariable.dtype), flush=True)  
            else:
                print("Type:", str(type(theVariable)), flush=True) 

        # VARIABLE CONTENTS                
        if functionName:
            print('#' * len("## " + theVariableName + ' ## ' + functionName + " ##"), flush=True)
            print("## " + theVariableName + ' ## ' + functionName + " ##", flush=True)
            print('#' * len("## " + theVariableName + ' ## ' + functionName + " ##"), flush=True)            
            print(str(theVariable), flush=True)
        else:
            print('#' * len("## " + theVariableName + " ##"), flush=True)
            print("## " + theVariableName + " ##", flush=True)
            print('#' * len("## " + theVariableName + " ##"), flush=True)            
            print(str(theVariable), flush=True)

print(logged_header_counts, flush=True)

{}


In [10]:
from datetime import datetime
import inspect
# Global flags to enable/disable debugging and verbosity
DEBUG_ENABLED = True
VERBOSE_ENABLED = False

def debug_log_old(theVariable, functionName=None, ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName=None):
    
    print("#################")
    print("## theVariable ##")
    print("#################")
    print(theVariable)
    print("")

    
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    frame = inspect.currentframe().f_back
    variable_names = [name for name, val in frame.f_locals.items() if val is theVariable]
    # theVariableName = variable_names[0]
    theVariableName = variable_names[0] if variable_names else ExplicitVariableName

    print("#####################")
    print("## theVariableName ##")
    print("#####################")
    print(theVariableName)
    print("")

    
#     if theVariableName == "unknown_variable":
#         frame_info = traceback.extract_stack(limit=2)[0]
#         log_message = f"[{timestamp}] - Variable name unknown in {frame_info.filename} at line {frame_info.lineno}\n"
#     else:
#         log_message = f"[{timestamp}] - {theVariableName}\n"
    
    thefilename=''
    if functionName is not None:
        functionName=functionName.replace("\\", "_")
    
    if DEBUG_ENABLED:
        # INCLUDE functionName
        if functionName:
            log_message = f"[{timestamp}] - {theVariableName} - {functionName}\n"
            thefilename = f"{theVariableName} - {functionName}"
        else:
            # EXCLUDE functionName
            log_message = f"[{timestamp}] - {theVariableName}\n"
            thefilename = f"{theVariableName}"
            
        if ShowShape:
            log_message += "Shape:" + str(theVariable.shape) + "\n"
        if ShowLength:
            if isinstance(theVariable, torch.Tensor):
                if theVariable.dim() == 1:
                    log_message += "Length:" + str(len(theVariable)) + "\n"
                else:
                    length = theVariable.numel()
                    log_message += "Length:" + str(length) + "\n"
            else:
                log_message += "Length:" + str(len(theVariable)) + "\n"
        if ShowType:
            if isinstance(theVariable, torch.Tensor):
                log_message += "Type:" + str(theVariable.dtype) + "\n"
            else:
                log_message += "Type:" + str(type(theVariable)) + "\n"

        # VARIABLE CONTENTS
        log_message += str(theVariable) + "\n"

        # Create an image
        global imageindex
        imageindex = imageindex + 1
        thefilename = thefilename.replace("\\", "_")
        text = log_message
        output_file = f"{imageindex:07d} {thefilename}.png"
        # text_to_image_function(text, 12, output_file, "/usr/share/fonts/truetype/freefont/FreeSans.ttf")
        text_to_image_function(text, 16, output_file, "/usr/share/fonts/truetype/freefont/Arial.ttf")

        
        log_message += "---------------------------------------------------------" + "\n"
            
        # Get the current date and time
        current_date = datetime.now().strftime('%Y-%m-%d')
        
        # with open('debugGNN.txt', 'a') as file:
        with open(f'debugGNN_{current_date}.txt', 'a') as file:
            file.write(log_message)

    
    if VERBOSE_ENABLED:
        print(timestamp)
        
        if ShowShape:
            print("Shape:", theVariable.shape)
        if ShowLength:
            if isinstance(theVariable, torch.Tensor):
                if theVariable.dim() == 1:
                    print("Length:", str(len(theVariable)))
                else:
                    length = theVariable.numel()
                    print("Length:", str(length))
            else:
                print("Length:", str(len(theVariable)))                    
        if ShowType:
            if isinstance(theVariable, torch.Tensor):
                print("Type:", str(theVariable.dtype))  
            else:
                print("Type:", str(type(theVariable))) 

        # VARIABLE CONTENTS                
        if functionName:
            print('#' * len("## " + theVariableName + ' ## ' + functionName + " ##"))
            print("## " + theVariableName + ' ## ' + functionName + " ##")
            print('#' * len("## " + theVariableName + ' ## ' + functionName + " ##"))            
            print(str(theVariable))
        else:
            print('#' * len("## " + theVariableName + " ##"))
            print("## " + theVariableName + " ##")
            print('#' * len("## " + theVariableName + " ##"))            
            print(str(theVariable))

In [11]:
# Example Usage:
abc = 123
debug_log("abc1", "NoFunctionaabbcc", ShowShape=True,ShowLength=True,ShowType=True)
debug_log("abc2", "NoFunctionaabbcc", ShowShape=True,ShowLength=True,ShowType=True)
debug_log("abc3", "NoFunctionaabbcc", ShowShape=True,ShowLength=True,ShowType=True)

#################
## theVariable ##
#################
abc1

#####################
## theVariableName ##
#####################
None

############
## header ##
############
None - NoFunctionaabbcc

##################
## functionName ##
##################
NoFunctionaabbcc

#################
## thefilename ##
#################
None - NoFunctionaabbcc

#################
## theVariable ##
#################
abc2

#####################
## theVariableName ##
#####################
None

############
## header ##
############
None - NoFunctionaabbcc

#################
## theVariable ##
#################
abc3

#####################
## theVariableName ##
#####################
None

############
## header ##
############
None - NoFunctionaabbcc



  text_size = draw.textsize(text, font=font)


In [12]:
print(logged_header_counts)

{'None - NoFunctionaabbcc': 1}


In [13]:
def debug_log_special(var):
    # Use inspect to find the variable name in the caller's frame
    frame = inspect.currentframe()
    try:
        caller_locals = frame.f_back.f_locals
        var_name = [name for name, value in caller_locals.items() if value is var]
        var_name = var_name[0] if var_name else "unknown"
    finally:
        del frame  # Clean up the frame to avoid reference cycles

    # Print the variable name and its content
    print(f"{var_name}: {var}")    
    
    with open('debugGNN1.txt', 'a') as file:
        file.write(f"{var_name}: {var}")


In [14]:
import os
import torch
print(f"PyTorch has version {torch.__version__} with cuda {torch.version.cuda}")

DATASET_NAME = "WaterDrop"
OUTPUT_DIR = os.path.join("/home/admin1/Desktop/GNN/gnndataset/datasets/WaterDrop")

debug_log(DATASET_NAME, ShowShape=True, ShowLength=True, ShowType=True)

debug_log(OUTPUT_DIR, ShowShape=True, ShowLength=True, ShowType=True)

# BASE_URL = f"https://storage.googleapis.com/cs224w_course_project_dataset/{DATASET_NAME}"

# !mkdir -p "$OUTPUT_DIR"

# META_DATA_PATH = f"{OUTPUT_DIR}/metadata.json"
# CLOUD_PATH = f"{BASE_URL}/metadata.json"
# !wget -O "$META_DATA_PATH" "$CLOUD_PATH"
# for split in ["test", "train", "valid"]:
#   for suffix in ["offset.json", "particle_type.dat", "position.dat"]:
#       DATA_PATH = f"{OUTPUT_DIR}/{split}_{suffix}"
#       CLOUD_PATH = f"{BASE_URL}/{split}_{suffix}"
#       !wget -O "$DATA_PATH" "$CLOUD_PATH"

PyTorch has version 1.12.0+cu102 with cuda 10.2
#################
## theVariable ##
#################
WaterDrop

#####################
## theVariableName ##
#####################
DATASET_NAME

############
## header ##
############
DATASET_NAME

##################
## functionName ##
##################
None

#################
## thefilename ##
#################
DATASET_NAME

#################
## theVariable ##
#################
/home/admin1/Desktop/GNN/gnndataset/datasets/WaterDrop

#####################
## theVariableName ##
#####################
OUTPUT_DIR

############
## header ##
############
OUTPUT_DIR

##################
## functionName ##
##################
None

#################
## thefilename ##
#################
OUTPUT_DIR



  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)


## Data Preprocessing

• Cannot apply raw data in dataset to train GNN model directly, so must perform below steps to convert raw data into graphs with descriptive node features and edge features:
1. Apply noise to trajectory to have more diverse training examples
1. Construct graph based on distance between particles
1. Extract node-level features: particle velocities and their distance to boundary
1. Extract edge-level features: displacement and distance between particles

In [15]:
import json
import numpy as np
import torch_geometric as pyg

def generate_noise(position_seq, noise_std):
    """Generate noise for a trajectory"""
    velocity_seq = position_seq[:, 1:] - position_seq[:, :-1]
    debug_log(velocity_seq, "generate_noise", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[1]] generate_noise--velocity_seq")


    time_steps = velocity_seq.size(1)
    debug_log(time_steps, "generate_noise", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[2]] generate_noise--time_steps")
    
    velocity_noise = torch.randn_like(velocity_seq) * (noise_std / time_steps ** 0.5)
    debug_log(velocity_noise, "generate_noise", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[3]] generate_noise--velocity-noise1")
    
    velocity_noise = velocity_noise.cumsum(dim=1)
    debug_log(velocity_noise, "generate_noise", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[4]] generate_noise--velocity-noise2")    
    
    position_noise = velocity_noise.cumsum(dim=1)
    debug_log(position_noise, "generate_noise", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[5]] generate_noise--position-noise1")
       
    position_noise = torch.cat((torch.zeros_like(position_noise)[:, 0:1], position_noise), dim=1)
    debug_log(position_noise, "generate_noise", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[6]] generate_noise--position-noise2")
    
    return position_noise


def preprocess(particle_type, position_seq, target_position, metadata, noise_std):
    """Preprocess a trajectory and construct graph"""
    # apply noise to trajectory
    position_noise = generate_noise(position_seq, noise_std)
    debug_log(position_noise, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[7]] preprocess--position_noise")
    
    position_seq = position_seq + position_noise
    debug_log(position_seq, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[8]] preprocess--position_seq")

    # calculate velocities of particles
    recent_position = position_seq[:, -1]
    debug_log(recent_position, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[9]] preprocess--recent_position")
    
    
    velocity_seq = position_seq[:, 1:] - position_seq[:, :-1]
    debug_log(velocity_seq, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[10]] preprocess--velocity_seq")
    
    
    # construct graph based on distances between particles
    n_particle = recent_position.size(0)
    debug_log(n_particle, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[11]] preprocess--n_particle")
    
    
    edge_index = pyg.nn.radius_graph(recent_position, metadata["default_connectivity_radius"], loop=True, max_num_neighbors=n_particle)
    debug_log(edge_index, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[12]] preprocess--edge_index")

    
    # node-level features: velocity, distance to boundary
    normal_velocity_seq = (velocity_seq - torch.tensor(metadata["vel_mean"])) / torch.sqrt(torch.tensor(metadata["vel_std"]) ** 2 + noise_std ** 2)
    debug_log(normal_velocity_seq, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[13]] preprocess--normal_velocity_seq")
    
    boundary = torch.tensor(metadata["bounds"])
    debug_log(boundary, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[14]] preprocess--boundary")
        
    distance_to_lower_boundary = recent_position - boundary[:, 0]
    debug_log(distance_to_lower_boundary, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[15]] preprocess--distance_to_lower_boundary")
        
    distance_to_upper_boundary = boundary[:, 1] - recent_position
    debug_log(distance_to_upper_boundary, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[16]] preprocess--distance_to_upper_boundary")
    
    distance_to_boundary = torch.cat((distance_to_lower_boundary, distance_to_upper_boundary), dim=-1)
    debug_log(distance_to_boundary, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[17]] preprocess--distance_to_boundary1")    
    
    
    distance_to_boundary = torch.clip(distance_to_boundary / metadata["default_connectivity_radius"], -1.0, 1.0)
    debug_log(distance_to_boundary, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[18]] preprocess--distance_to_boundary2")
    
    

    # edge-level features: displacement, distance
    dim = recent_position.size(-1)
    debug_log(dim, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[19]] dim-preprocess")
    
    edge_displacement = (torch.gather(recent_position, dim=0, index=edge_index[0].unsqueeze(-1).expand(-1, dim)) - torch.gather(recent_position, dim=0, index=edge_index[1].unsqueeze(-1).expand(-1, dim)))
    debug_log(edge_displacement, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[20]] preprocess--edge_displacement1")    
    
    
    edge_displacement /= metadata["default_connectivity_radius"]
    debug_log(edge_displacement, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[21]] preprocess--edge_displacement2")
    
    
    edge_distance = torch.norm(edge_displacement, dim=-1, keepdim=True)
    debug_log(edge_distance, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[22]] preprocess--edge_distance")    


    
    # ground truth for training
    if target_position is not None:
        last_velocity = velocity_seq[:, -1]
        debug_log(last_velocity, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[23]] preprocess--last_velocity")        
       
        
        next_velocity = target_position + position_noise[:, -1] - recent_position
        debug_log(next_velocity, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[24]] preprocess--next_velocity")        
        
        
        
        acceleration = next_velocity - last_velocity
        debug_log(acceleration, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[25]] preprocess--acceleration1")
        
        
        
        acceleration = (acceleration - torch.tensor(metadata["acc_mean"])) / torch.sqrt(torch.tensor(metadata["acc_std"]) ** 2 + noise_std ** 2)
        debug_log(acceleration, "preprocess", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[26]] preprocess--acceleration2")
        
        
        
    else:
        acceleration = None

    # return graph with features
    graph = pyg.data.Data(
        x=particle_type,
        edge_index=edge_index,
        edge_attr=torch.cat((edge_displacement, edge_distance), dim=-1),
        y=acceleration,
        pos=torch.cat((velocity_seq.reshape(velocity_seq.size(0), -1), distance_to_boundary), dim=-1)
    )
    return graph

  import scipy.cluster


### One Step Dataset

• Each datapoint in this dataset contains trajectories sliced to short time windows. 

• We use this dataset in training phase because history of particles' states are necessary for model to make predictions. 

• But in meantime, since long-horizon prediction is inaccurate and time-consuming, sliced trajectories to short time windows to improve perfomance of model.

In [16]:
class OneStepDataset(pyg.data.Dataset):
    def __init__(self, data_path, split, window_length=7, noise_std=0.0, return_pos=False):
        super().__init__()

        debug_log(data_path, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[27]] OneStepDataset--_init_--data_path")

        
        # load dataset from disk
        with open(os.path.join(data_path, "metadata.json")) as f:
            self.metadata = json.load(f)
            debug_log(self.metadata, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[28]] OneStepDataset--_init_--self.metadata")
            
        with open(os.path.join(data_path, f"{split}_offset.json")) as f:
            self.offset = json.load(f)
            # debug_log(self.offset, "OneStepDataset 1", ShowShape=True, ShowLength=True, ShowType=True)        
            debug_log(self.offset, "OneStepDataset 1", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[29]] OneStepDataset--_init_--self.offset1")
            
            
        self.offset = {int(k): v for k, v in self.offset.items()}
        # debug_log(self.offset, "OneStepDataset 2", ShowShape=True, ShowLength=True, ShowType=True)        
        debug_log(self.offset, "OneStepDataset 2", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[30]] OneStepDataset--_init_--self.offset2")
        
        
        self.window_length = window_length
        debug_log(window_length, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[31]] OneStepDataset--_init_--window_length")
        
        
        
        self.noise_std = noise_std
        debug_log(noise_std, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[32]] OneStepDataset--_init_--noise_std")

        
        self.return_pos = return_pos
        debug_log(return_pos, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[33]] OneStepDataset--_init_--return_pos")
        
        

        self.particle_type = np.memmap(os.path.join(data_path, f"{split}_particle_type.dat"), dtype=np.int64, mode="r")
        # debug_log(self.particle_type, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True)
        debug_log(self.particle_type, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[34]] OneStepDataset--_init_--self.particle_type")
        
        self.position = np.memmap(os.path.join(data_path, f"{split}_position.dat"), dtype=np.float32, mode="r")
        # debug_log(self.position, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True)
        debug_log(self.position, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[35]] OneStepDataset--_init_--self.position")
        
        for traj in self.offset.values():
            self.dim = traj["position"]["shape"][2]
            # debug_log(self.dim, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True)
            debug_log(self.dim, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[36]] OneStepDataset--_init_--self.dim")
            
            break

        # cut particle trajectories according to time slices
        self.windows = []
        for traj in self.offset.values():
            size = traj["position"]["shape"][1]
            debug_log(size, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[37]] OneStepDataset--traj--size")
            
            length = traj["position"]["shape"][0] - window_length + 1
            debug_log(length, "OneStepDataset", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[38]] OneStepDataset--traj--length")
            
            
            
            for i in range(length):
                desc = {
                    "size": size,
                    "type": traj["particle_type"]["offset"],
                    "pos": traj["position"]["offset"] + i * size * self.dim,
                }
                self.windows.append(desc)

    def len(self):
        return len(self.windows)

    def get(self, idx):
        # load corresponding data for this time slice
        window = self.windows[idx]
        debug_log(window, "get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[39]] OneStepDataset--get--window")
        
        
        size = window["size"]
        debug_log(size, "get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[40]] OneStepDataset--get--size")
        
        
        particle_type = self.particle_type[window["type"]: window["type"] + size].copy()
        debug_log(particle_type, "get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[41]] OneStepDataset--get--particle_type1")
                
        
        particle_type = torch.from_numpy(particle_type)
        debug_log(particle_type, "get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[42]] OneStepDataset--get--particle_type2")
        
        
        position_seq = self.position[window["pos"]: window["pos"] + self.window_length * size * self.dim].copy()
        debug_log(position_seq, "get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[43]] OneStepDataset--get--position_seq1")
        
        position_seq.resize(self.window_length, size, self.dim)
        debug_log(position_seq, "get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[44]] OneStepDataset--get--position_seq2")
        
        position_seq = position_seq.transpose(1, 0, 2)
        debug_log(position_seq, "get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[45]] OneStepDataset--get--position_seq3")
        
        target_position = position_seq[:, -1]
        debug_log(target_position, "get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[46]] OneStepDataset--get--target_position1")
        
        
        position_seq = position_seq[:, :-1]
        debug_log(position_seq, "get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[47]] OneStepDataset--get--position_seq4")
        
        target_position = torch.from_numpy(target_position)
        debug_log(target_position, "get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[48]] OneStepDataset--get--target_position2")
        
        
        position_seq = torch.from_numpy(position_seq)
        debug_log(position_seq, "get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[49]] OneStepDataset--get--position_seq5")
        

        # construct graph
        with torch.no_grad():
            graph = preprocess(particle_type, position_seq, target_position, self.metadata, self.noise_std)
        if self.return_pos:
            return graph, position_seq[:, -1]
        return graph

### Rollout Dataset

• Each datapoint in this dataset contains trajectories of particles over 1000 time frames. 

• This dataset used in evaluation phase to measure model's ability to make long-horizon predictions.

In [17]:
class RolloutDataset(pyg.data.Dataset):
    def __init__(self, data_path, split, window_length=7):
        super().__init__()

        # load data from disk
        with open(os.path.join(data_path, "metadata.json")) as f:
            self.metadata = json.load(f)
            # debug_log(self.metadata, "RolloutDataset\_init_", ShowShape=True, ShowLength=True, ShowType=True)
            debug_log(self.metadata, "RolloutDataset\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[50]] RolloutDataset--_init_--self.metadata")
            
        with open(os.path.join(data_path, f"{split}_offset.json")) as f:
            self.offset = json.load(f)
            # debug_log(self.offset, "RolloutDataset\_init_", ShowShape=True, ShowLength=True, ShowType=True)
            debug_log(self.offset, "RolloutDataset\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[51]] RolloutDataset1--_init_--self.offset")
            
        self.offset = {int(k): v for k, v in self.offset.items()}
        # debug_log(self.offset, "RolloutDataset\_init_", ShowShape=True, ShowLength=True, ShowType=True)
        debug_log(self.offset, "RolloutDataset\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[52]] RolloutDataset2--_init_--self.offset")
        
        self.window_length = window_length
        debug_log(window_length, "RolloutDataset\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[53]] RolloutDataset--_init_--window_length")
        
        

        self.particle_type = np.memmap(os.path.join(data_path, f"{split}_particle_type.dat"), dtype=np.int64, mode="r")
        # debug_log(self.particle_type, "RolloutDataset\_init_", ShowShape=True, ShowLength=True, ShowType=True)
        debug_log(self.particle_type, "RolloutDataset\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[54]] RolloutDataset--_init_--self.particle_type")
        
        self.position = np.memmap(os.path.join(data_path, f"{split}_position.dat"), dtype=np.float32, mode="r")
        # debug_log(self.position, "RolloutDataset\_init_", ShowShape=True, ShowLength=True, ShowType=True)
        debug_log(self.position, "RolloutDataset\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[55]] RolloutDataset--_init_--self.position")
        
        for traj in self.offset.values():
            self.dim = traj["position"]["shape"][2]
            break

    def len(self):
        return len(self.offset)

    def get(self, idx):
        traj = self.offset[idx]
        debug_log(traj, "RolloutDataset\get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[56]] RolloutDataset--get--traj")
        
        size = traj["position"]["shape"][1]
        debug_log(size, "RolloutDataset\get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[57]] RolloutDataset--get--size")
        
        
        time_step = traj["position"]["shape"][0]
        debug_log(time_step, "RolloutDataset\get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[58]] RolloutDataset--get--time_step")
        
        
        
        particle_type = self.particle_type[traj["particle_type"]["offset"]: traj["particle_type"]["offset"] + size].copy()
        debug_log(particle_type, "RolloutDataset\get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[59]] RolloutDataset1--get--particle_type")
        
        
        
        particle_type = torch.from_numpy(particle_type)
        debug_log(particle_type, "RolloutDataset\get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[60]] RolloutDataset2--get--particle_type")
        
        
        position = self.position[traj["position"]["offset"]: traj["position"]["offset"] + time_step * size * self.dim].copy()
        debug_log(position, "RolloutDataset\get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[61]] RolloutDataset1--get--position")
                
        
        position.resize(traj["position"]["shape"])
        debug_log(position, "RolloutDataset\get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[62]] RolloutDataset2--get--position")
        
        position = torch.from_numpy(position)
        debug_log(position, "RolloutDataset\get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[63]] RolloutDataset3--get--position")
        
        data = {"particle_type": particle_type, "position": position}
        debug_log(data, "RolloutDataset\get", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[64]] RolloutDataset--get--data")
        
        
        
        return data

### Visualize a graph in dataset

• Each data point in dataset is a `pyg.data.Data` object which describes a graph. 

• explain contents of first data point, visualize graph.

In [18]:
#!pip install numpy==1.23


## GNN Model

We will walk through implementation of GNN model in this section!

### Helper class

• first define a class for Multi-Layer Perceptron (MLP). 

• This class generates an MLP given width and depth of it. 

• Because MLPs are used in several places of GNN, this helper class will make code cleaner.

In [19]:
import math
import torch_scatter

class MLP(torch.nn.Module):
    """Multi-Layer perceptron"""
    def __init__(self, input_size, hidden_size, output_size, layers, layernorm=True):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        # debug_log(self.layers, "MLP\_init_", ShowShape=True, ShowLength=True, ShowType=True)
        debug_log(self.layers, "MLP\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[65]] MLP--_init_--self.layers")
        
        
        for i in range(layers):
            self.layers.append(torch.nn.Linear(
                input_size if i == 0 else hidden_size,
                output_size if i == layers - 1 else hidden_size,
            ))
            
            
            if i != layers - 1:
                self.layers.append(torch.nn.ReLU())
                # debug_log(self.layers, "MLP\_init_\i", ShowShape=True, ShowLength=True, ShowType=True)
                debug_log(self.layers, "MLP\_init_\i", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[66]] MLP--_init_--i--self.layers")
                
                
        if layernorm:
            self.layers.append(torch.nn.LayerNorm(output_size))
            # debug_log(self.layers, "MLP", ShowShape=True, ShowLength=True, ShowType=True)
            debug_log(self.layers, "MLP", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[67]] MLP--self.layers")
            
            
            
        self.reset_parameters()

    def reset_parameters(self):
        
        
        for layer in self.layers:
            debug_log(layer, "MLP\reset_parameters", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[68]] MLP1--reset_parameters--layer")
            
            
            if isinstance(layer, torch.nn.Linear):
                layer.weight.data.normal_(0, 1 / math.sqrt(layer.in_features))
                debug_log(layer, "MLP\reset_parameters", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[69]] MLP2--reset_parameters--layer")
                
                layer.bias.data.fill_(0)
                debug_log(layer, "MLP\reset_parameters", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[70]] MLP3--reset_parameters--layer")
                

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

### GNN layers

In following code block, we implement one type of GNN layer named `InteractionNetwork` (IN), which is proposed by paper *Interaction Networks for Learning about Objects,
Relations and Physics*.

• For a graph $G$, let feature of node $i$ be $v_i$, feature of edge $(i, j)$ be $e_{i, j}$. 

• three stages for IN to generate new features of nodes and edges.

1. **Message generation.**

• If there is an edge pointing from node $i$ to node $j$, node $i$ sends a message to node $j$. 

• message carries information of edge and its two nodes, so it is generated by following equation $\mathrm{Msg}_{i,j} = \mathrm{MLP}(v_i, v_j, e_{i,j})$.

2. **Message aggregation.**

• In this stage, each node of graph aggregates all messages it received to a fixed-sized representation. 

• In IN, aggregation means summing all messages up, i.e., $\mathrm{Agg}_i=\sum_{(j,i)\in G}\mathrm{Msg}_{i,j}$.

3. **Update.**

• update features of nodes and edges with results of previous stages. 

• For each edge, its new feature is sum of its old feature and correspond message, i.e., $e'_{i,j}=e_{i,j}+\mathrm{Msg}_{i,j}$. 

• For each node, new feature is determined by its old feature and aggregated message, i.e., $v'_i=v_i+\mathrm{MLP}(v_i, \mathrm{Agg}_i)$.

• In PyG, GNN layers are implemented as subclass of `MessagePassing`. 

• must override three critical functions to implement `InteractionNetwork` GNN layer. 

• Each function corresponds to one stage of GNN layer.

1. `message()` -> message generation

• This function controls how a message is generated on each edge of graph. 

• It takes three arguments:

• (1) `x_i`, features of source nodes; 

• (2) `x_j`, features of target nodes; 

• (3) `edge_feature`, features of edges themselves. 

• In IN, concatenate all these features and generate messages with an MLP.

1. `aggregate()` -> message aggregation

• This function aggregates messages for nodes. 

• It depends on two arguments:

• (1) `inputs`, messages; 

• (2) `index`, graph structure. 

• handle over task of message aggregation to function `torch_scatter.scatter` and specifies in argument `reduce` that want to sum messages up. 

• Because want to retain messages themselves to update edge features, return both messages and aggregated messages.

1. `forward()` -> update

• This function puts everything together. 

• `x` is node features, `edge_index` is graph structure and `edge_feature` is edge features. 

• function`MessagePassing.propagate` invokes functions `message` and `aggregate` for us. 

• Then, update node features and edge features and return them.

In [20]:
class InteractionNetwork(pyg.nn.MessagePassing):
    """Interaction Network as proposed in this paper:
    https://proceedings.neurips.cc/paper/2016/hash/3147da8ab4a0437c15ef51a5cc7f2dc4-Abstract.html"""
    def __init__(self, hidden_size, layers):
        super().__init__()
        self.lin_edge = MLP(hidden_size * 3, hidden_size, hidden_size, layers)
        self.lin_node = MLP(hidden_size * 2, hidden_size, hidden_size, layers)

    def forward(self, x, edge_index, edge_feature):
        edge_out, aggr = self.propagate(edge_index, x=(x, x), edge_feature=edge_feature)
        debug_log(edge_out, r"InteractionNetwork\forward", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[71]] InteractionNetwork--forward--edge_out1")
        
        node_out = self.lin_node(torch.cat((x, aggr), dim=-1))
        debug_log(node_out, r"InteractionNetwork\forward", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[72]] InteractionNetwork--forward--node_out1")

        
        
        edge_out = edge_feature + edge_out
        debug_log(edge_out, r"InteractionNetwork\forward", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[73]] InteractionNetwork--forward--edge_out2")
        
        
        
        node_out = x + node_out
        debug_log(node_out, r"InteractionNetwork\forward", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[74]] InteractionNetwork--forward--node_out2")
        
        
        
        return node_out, edge_out

    def message(self, x_i, x_j, edge_feature):
        x = torch.cat((x_i, x_j, edge_feature), dim=-1)
        debug_log(x, "InteractionNetwork\message", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[75]] InteractionNetwork--message--x1")
        
        
        
        x = self.lin_edge(x)
        debug_log(x, "InteractionNetwork\message", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[76]] InteractionNetwork--message--x2")
        
        
        
        return x

    def aggregate(self, inputs, index, dim_size=None):
        out = torch_scatter.scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum")
        debug_log(out, "InteractionNetwork\aggregate", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[77]] InteractionNetwork--aggregate--out")

        
        return (inputs, out)

### GNN

• Now its time to stack GNN layers to a GNN. 

• Besides GNN layers, pre-processing and post-processing blocks in GNN. 

• Before GNN layers, input features are transformed by MLP so expressiveness of GNN is improved without increasing GNN layers. 

• After GNN layers, final outputs (accelerations of particles in case) are extracted from features generated by GNN layers to meet requirement of task.

In [21]:
class LearnedSimulator(torch.nn.Module):
    """Graph Network-based Simulators(GNS)"""
    def __init__(
        self,
        hidden_size=128,
        n_mp_layers=10, # number of GNN layers
        num_particle_types=9,
        particle_type_dim=16, # embedding dimension of particle types
        dim=2, # dimension of world, typical 2D or 3D
        window_size=5, # model looks into W frames before frame to be predicted
    ):
        super().__init__()
        self.window_size = window_size
        debug_log(window_size, "LearnedSimulator\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[78]] LearnedSimulator--_init_--window_size")
        
        self.embed_type = torch.nn.Embedding(num_particle_types, particle_type_dim)
        # debug_log(self.embed_type, "LearnedSimulator\_init_", ShowShape=True, ShowLength=True, ShowType=True)
        debug_log(self.embed_type, "LearnedSimulator\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[79]] LearnedSimulator--_init_--self.embed_type")
        
        
        self.node_in = MLP(particle_type_dim + dim * (window_size + 2), hidden_size, hidden_size, 3)
        # debug_log(self.node_in, "LearnedSimulator\_init_", ShowShape=True, ShowLength=True, ShowType=True)
        debug_log(self.node_in, "LearnedSimulator\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[80]] LearnedSimulator--_init_--self.node_in")
        
        
        self.edge_in = MLP(dim + 1, hidden_size, hidden_size, 3)
        # debug_log(self.node_in, "LearnedSimulator\_init_", ShowShape=True, ShowLength=True, ShowType=True)
        debug_log(self.edge_in, "LearnedSimulator\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[81]] LearnedSimulator--_init_--self.edge_in")
        
        
        self.node_out = MLP(hidden_size, hidden_size, dim, 3, layernorm=False)
        # debug_log(self.node_out, "LearnedSimulator\_init_", ShowShape=True, ShowLength=True, ShowType=True)
        debug_log(self.node_out, "LearnedSimulator\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[82]] LearnedSimulator--_init_--self.node_out")
        
        
        self.n_mp_layers = n_mp_layers
        debug_log(n_mp_layers, "LearnedSimulator\_init_", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[83]] LearnedSimulator--_init_--n_mp_layers")

        
        
        self.layers = torch.nn.ModuleList([InteractionNetwork(
            hidden_size, 3
        ) for _ in range(n_mp_layers)])

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.embed_type.weight)

    def forward(self, data):
        # pre-processing
        # node feature: combine categorial feature data.x and contiguous feature data.pos.
        node_feature = torch.cat((self.embed_type(data.x), data.pos), dim=-1)
        debug_log(node_feature, r"LearnedSimulator\forward", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[84]] LearnedSimulator--forward--node_feature1")
        
        
        
        node_feature = self.node_in(node_feature)
        debug_log(node_feature, r"LearnedSimulator\forward", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[85]] LearnedSimulator--forward--node_feature2")
        
        
        
        edge_feature = self.edge_in(data.edge_attr)
        debug_log(edge_feature, r"LearnedSimulator\forward", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[86]] LearnedSimulator--forward--edge_feature")

        
        
        # stack of GNN layers
        for i in range(self.n_mp_layers):
            node_feature, edge_feature = self.layers[i](node_feature, data.edge_index, edge_feature=edge_feature)
            debug_log(node_feature, r"LearnedSimulator\forward\i", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[87]] LearnedSimulator--forward--i--node_feature")
            debug_log(edge_feature, r"LearnedSimulator\forward\i", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[88]] LearnedSimulator--forward--i--edge_feature")
            
        # post-processing
        out = self.node_out(node_feature)
        
        
        return out

## Training

• Before start training model, let's configure hyperparameters! 

• Since accessible computaion power is limited in Colab, will only run 1 epoch of training, which takes about 1.5 hour. 

• won't produce as accurate results as shown in original paper in this Colab. 

• provide a checkpoint of training model on entire WaterDrop dataset for 5 epochs, which takes about 14 hours with a GeForce RTX 3080 Ti.

In [22]:
data_path = OUTPUT_DIR
debug_log(data_path, ShowShape=True, ShowLength=True, ShowType=True)


model_path = os.path.join("temp", "models", DATASET_NAME)
debug_log(model_path, ShowShape=True, ShowLength=True, ShowType=True)


rollout_path = os.path.join("temp", "rollouts", DATASET_NAME)
debug_log(rollout_path, ShowShape=True, ShowLength=True, ShowType=True)


!mkdir -p "$model_path"
!mkdir -p "$rollout_path"

params = {
    "epoch": 1,
    #"epoch": 20,
    "batch_size": 4,
    "lr": 1e-4,
    "noise": 3e-4,
    "save_interval": 1000,
    "eval_interval": 1000,
    "rollout_interval": 200000,
}

#################
## theVariable ##
#################
/home/admin1/Desktop/GNN/gnndataset/datasets/WaterDrop

#####################
## theVariableName ##
#####################
OUTPUT_DIR

############
## header ##
############
OUTPUT_DIR

#################
## theVariable ##
#################
temp/models/WaterDrop

#####################
## theVariableName ##
#####################
model_path

############
## header ##
############
model_path

##################
## functionName ##
##################
None

#################
## thefilename ##
#################
model_path

#################
## theVariable ##
#################
temp/rollouts/WaterDrop

#####################
## theVariableName ##
#####################
rollout_path

############
## header ##
############
rollout_path

##################
## functionName ##
##################
None

#################
## thefilename ##
#################
rollout_path



  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)


Below are some helper functions for evaluation.

In [23]:
def rollout(model, data, metadata, noise_std):
    device = next(model.parameters()).device
    debug_log(device, "rollout", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[89]] rollout--device")
    
    model.eval()
    
    window_size = model.window_size + 1
    debug_log(window_size, "rollout", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[90]] rollout--window_size")
   
    total_time = data["position"].size(0)
    debug_log(total_time, "rollout", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[91]] rollout--total_time")

    
    traj = data["position"][:window_size]
    debug_log(traj, "rollout", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[92]] rollout--traj1")
    
    
    traj = traj.permute(1, 0, 2)
    debug_log(traj, "rollout", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[93]] rollout--traj2")
    
    
    particle_type = data["particle_type"]
    debug_log(particle_type, "rollout", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[94]] rollout--particle_type")
    

    for time in range(total_time - window_size):
        with torch.no_grad():
            graph = preprocess(particle_type, traj[:, -window_size:], None, metadata, 0.0)
            debug_log(graph, "rollout\time", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[95]] rollout--time--graph1")

            
            
            graph = graph.to(device)
            debug_log(graph, "rollout\time", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[96]] rollout--time--graph2")            

            
            
            acceleration = model(graph).cpu()
            debug_log(acceleration, "rollout\time", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[97]] rollout--time--acceleration1")

            
            
            acceleration = acceleration * torch.sqrt(torch.tensor(metadata["acc_std"]) ** 2 + noise_std ** 2) + torch.tensor(metadata["acc_mean"])
            debug_log(acceleration, "rollout\time", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[98]] rollout--time--acceleration2")

                        

            recent_position = traj[:, -1]
            debug_log(recent_position, "rollout\time", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[99]] rollout--time--recent_position")
            
            
            
            recent_velocity = recent_position - traj[:, -2]
            debug_log(recent_velocity, "rollout\time", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[100]] rollout--time--recent_velocity")
            
            
            
            new_velocity = recent_velocity + acceleration
            debug_log(new_velocity, "rollout\time", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[101]] rollout--time--new_velocity1")

            
            
            new_position = recent_position + new_velocity
            debug_log(new_velocity, "rollout\time", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[102]] rollout--time--new_velocity2")            

            
            
            traj = torch.cat((traj, new_position.unsqueeze(1)), dim=1)
            debug_log(traj, "rollout\time", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[103]] rollout--time--traj")            

            

    return traj


def oneStepMSE(simulator, dataloader, metadata, noise):
    """Returns two values, loss and MSE"""
    total_loss = 0.0
    total_mse = 0.0
    batch_count = 0
    simulator.eval()
    with torch.no_grad():
        scale = torch.sqrt(torch.tensor(metadata["acc_std"]) ** 2 + noise ** 2).cuda()
        for data in valid_loader:
            data = data.cuda()
            debug_log(data, "oneStepMSE\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[104]] oneStepMSE--data--data")
            
            
            
            pred = simulator(data)
            debug_log(pred, "oneStepMSE\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[105]] oneStepMSE--data--pred")
            
            
            
            mse = ((pred - data.y) * scale) ** 2
            debug_log(mse, "oneStepMSE\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[106]] oneStepMSE--data--mse1")
            
            
            
            mse = mse.sum(dim=-1).mean()
            debug_log(mse, "oneStepMSE\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[107]] oneStepMSE--data--mse2")

            
            
            loss = ((pred - data.y) ** 2).mean()
            debug_log(loss, "oneStepMSE\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[108]] oneStepMSE--data--loss")
            
            
            
            total_mse += mse.item()
            debug_log(total_mse, "oneStepMSE\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[109]] oneStepMSE--data--total_mse")

            
            
            total_loss += loss.item()
            debug_log(total_loss, "oneStepMSE\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[110]] oneStepMSE--data--total_loss")
            
           
            
            batch_count += 1
            debug_log(batch_count, "oneStepMSE\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[111]] oneStepMSE--data--batch_count")
            
            
            
    return total_loss / batch_count, total_mse / batch_count


def rolloutMSE(simulator, dataset, noise):
    total_loss = 0.0
    batch_count = 0
    simulator.eval()
    with torch.no_grad():
        for rollout_data in dataset:
            debug_log(rollout_data, "rolloutMSE\rollout_data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[112]] rolloutMSE--rollout_data--rollout_data")

            
            rollout_out = rollout(simulator, rollout_data, dataset.metadata, noise)
            debug_log(rollout_out, "rolloutMSE\rollout_data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[113]] rolloutMSE1--rollout_data--rollout_out")

            
            
            rollout_out = rollout_out.permute(1, 0, 2)
            debug_log(rollout_out, "rolloutMSE\rollout_data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[114]] rolloutMSE2--rollout_data--rollout_out")
            
            
            
            loss = (rollout_out - rollout_data["position"]) ** 2
            debug_log(loss, "rolloutMSE\rollout_data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[115]] rolloutMSE1--rollout_data--loss")
            
            
            loss = loss.sum(dim=-1).mean()
            debug_log(loss, "rolloutMSE\rollout_data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[116]] rolloutMSE2--rollout_data--loss")

            
            
            total_loss += loss.item()
            debug_log(total_loss, "rolloutMSE\rollout_data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[117]] rolloutMSE--rollout_data--total_loss")

            
            
            batch_count += 1
            debug_log(batch_count, "rolloutMSE\rollout_data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[118]] rolloutMSE--rollout_data--batch_count")

            
            
    return total_loss / batch_count

Here is main training loop!

In [24]:
from tqdm import tqdm

def train(params, simulator, train_loader, valid_loader, valid_rollout_dataset):
    loss_fn = torch.nn.MSELoss()
    debug_log(loss_fn, "train", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[119]] train--loss_fn")
    
    
    
    optimizer = torch.optim.Adam(simulator.parameters(), lr=params["lr"])
    debug_log(optimizer, "train", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[120]] train--optimizer")
    
       
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1 ** (1 / 5e6))
    debug_log(scheduler, "train", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[121]] train--scheduler")
    
    
    

    # recording loss curve
    train_loss_list = []
    eval_loss_list = []
    onestep_mse_list = []
    rollout_mse_list = []
    total_step = 0

    for i in range(params["epoch"]):
        simulator.train()
        
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {i}")
        debug_log(progress_bar, "train\i", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[122]] train--i--progress_bar")
        
        
        
        total_loss = 0
        debug_log(total_loss, "train\i", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[123]] train--i--total_loss")

        
        
        batch_count = 0
        debug_log(batch_count, "train\i", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[124]] train--i--batch_count")
        
        
        
        
        for data in progress_bar:
            optimizer.zero_grad()
            debug_log(optimizer, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[125]] train--i--data--optimizer")
            
            
            
            data = data.cuda()
            debug_log(data, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[126]] train--i--data--data")
            
            
            
            pred = simulator(data)
            debug_log(pred, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[127]] train--i--data--pred")
            
            
            
            loss = loss_fn(pred, data.y)
            debug_log(loss, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[128]] train--i--data--loss")

            
            
            loss.backward()


            optimizer.step()


            scheduler.step()


            total_loss += loss.item()
            debug_log(total_loss, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[129]] train--i--data--total_loss")
            


            batch_count += 1
            debug_log(batch_count, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[130]] train--i--data--batch_count")
            


            progress_bar.set_postfix({"loss": loss.item(), "avg_loss": total_loss / batch_count, "lr": optimizer.param_groups[0]["lr"]})


            total_step += 1
            debug_log(total_step, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[131]] train--i--data--total_step")
            


            train_loss_list.append((total_step, loss.item()))
            debug_log(train_loss_list, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[132]] train--i--data--train_loss_list")
            



            # evaluation
            if total_step % params["eval_interval"] == 0:
                simulator.eval()
                eval_loss, onestep_mse = oneStepMSE(simulator, valid_loader, valid_dataset.metadata, params["noise"])
                debug_log(eval_loss, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[133]] train--i--data--eval_loss")


                eval_loss_list.append((total_step, eval_loss))
                debug_log(eval_loss_list, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[134]] train--i--data--eval_loss_list")
                


                onestep_mse_list.append((total_step, onestep_mse))
                debug_log(onestep_mse_list, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[135]] train--i--data--onestep_mse_list")
                



                tqdm.write(f"\nEval: Loss: {eval_loss}, One Step MSE: {onestep_mse}")
                simulator.train()

            # do rollout on valid set
            if total_step % params["rollout_interval"] == 0:
                simulator.eval()
                rollout_mse = rolloutMSE(simulator, valid_rollout_dataset, params["noise"])
                debug_log(rollout_mse, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[136]] train--i--data--rollout_mse")


                rollout_mse_list.append((total_step, rollout_mse))
                debug_log(rollout_mse_list, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[137]] train--i--data--rollout_mse_list")
                


                tqdm.write(f"\nEval: Rollout MSE: {rollout_mse}")
                simulator.train()

            # save model
            if total_step % params["save_interval"] == 0:
                debug_log(total_step, "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[138]] train--i--data--total_step")
                debug_log(params[save_interval], "train\i\data", ShowShape=True, ShowLength=True, ShowType=True, ExplicitVariableName = "[[139]] train--i--data--params[save_interval]")

                
                torch.save(
                    {
                        "model": simulator.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                    },
                    os.path.join(model_path, f"checkpoint_{total_step}.pt")
                )
    return train_loss_list, eval_loss_list, onestep_mse_list, rollout_mse_list

• let's load dataset and train model! 

• It takes roughly 1.5 hour to run this block on Colab with default parameters. 

• **If you are impatient, highly recommend you to skip next 2 blocks and load checkpoint provided to save some time;**

• **otherwise, make a cup of tea/coffee and come back later to see results of training!**

In [1]:
#######################################################################
## Just an illustation of how long strings of number are being used. ##
#######################################################################
import numpy as np
import torch

# 1. Load particle types (simulated data for example)
particle_types = np.array([4532172933455552486, 4532560202066482713, 4531245598476515001], dtype=np.int64)

# 2. Convert to PyTorch tensor
particle_type_tensor = torch.from_numpy(particle_types)

# 3. Create embedding layer
num_particle_types = 9  # Number of unique particle types
embedding_dim = 16      # Size of embedding vector
embed_layer = torch.nn.Embedding(num_particle_types, embedding_dim)

# 4. Get embedded representations
embedded_particles = embed_layer(particle_type_tensor)

print("Original particle type:", particle_types[0])
print("Embedded representation shape:", embedded_particles.shape)
print("First particle embedding:", embedded_particles[0])

IndexError: index out of range in self

In [None]:
# Training model is time-consuming. We highly recommend you to skip this block and load checkpoint in next block.

# load dataset
train_dataset = OneStepDataset(data_path, "train", noise_std=params["noise"])
debug_log(train_dataset, ShowShape=True, ShowLength=True, ShowType=True)

valid_dataset = OneStepDataset(data_path, "valid", noise_std=params["noise"])
debug_log(valid_dataset, ShowShape=True, ShowLength=True, ShowType=True)

train_loader = pyg.loader.DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True, pin_memory=True, num_workers=2)
debug_log(train_loader, ShowShape=True, ShowLength=True, ShowType=True)

valid_loader = pyg.loader.DataLoader(valid_dataset, batch_size=params["batch_size"], shuffle=False, pin_memory=True, num_workers=2)
debug_log(valid_loader, ShowShape=True, ShowLength=True, ShowType=True)

valid_rollout_dataset = RolloutDataset(data_path, "valid")
debug_log(valid_rollout_dataset, ShowShape=True, ShowLength=True, ShowType=True)

# build model
simulator = LearnedSimulator()

simulator = simulator.cuda()

# train model
train_loss_list, eval_loss_list, onestep_mse_list, rollout_mse_list = train(params, simulator, train_loader, valid_loader, valid_rollout_dataset)



#################
## theVariable ##
#################
/home/admin1/Desktop/GNN/gnndataset/datasets/WaterDrop

#####################
## theVariableName ##
#####################
data_path

############
## header ##
############
[[27]] OneStepDataset--_init_--data_path

##################
## functionName ##
##################
OneStepDataset

#################
## thefilename ##
#################
[[27]] OneStepDataset--_init_--data_path

#################
## theVariable ##
#################
{'bounds': [[0.1, 0.9], [0.1, 0.9]], 'sequence_length': 1000, 'default_connectivity_radius': 0.015, 'dim': 2, 'dt': 0.0025, 'vel_mean': [-3.964619574176163e-05, -0.00026272129664401046], 'vel_std': [0.0013722809722366911, 0.0013119977252142715], 'acc_mean': [2.602686518497945e-08, 1.0721623948191945e-07], 'acc_std': [6.742962470925277e-05, 8.700719180424815e-05]}

#####################
## theVariableName ##
#####################
[[28]] OneStepDataset--_init_--self.metadata

############
## header ##
####

  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
{0: {'particle_type': {'offset': 0, 'shape': [678]}, 'position': {'offset': 0, 'shape': [1001, 678, 2]}}, 1: {'particle_type': {'offset': 678, 'shape': [355]}, 'position': {'offset': 1357356, 'shape': [1001, 355, 2]}}, 2: {'particle_type': {'offset': 1033, 'shape': [461]}, 'position': {'offset': 2068066, 'shape': [1001, 461, 2]}}, 3: {'particle_type': {'offset': 1494, 'shape': [307]}, 'position': {'offset': 2990988, 'shape': [1001, 307, 2]}}, 4: {'particle_type': {'offset': 1801, 'shape': [300]}, 'position': {'offset': 3605602, 'shape': [1001, 300, 2]}}, 5: {'particle_type': {'offset': 2101, 'shape': [398]}, 'position': {'offset': 4206202, 'shape': [1001, 398, 2]}}, 6: {'particle_type': {'offset': 2499, 'shape': [362]}, 'position': {'offset': 5002998, 'shape': [1001, 362, 2]}}, 7: {'particle_type': {'offset': 2861, 'shape': [317]}, 'position': {'offset': 5727722, 'shape': [1001, 317, 2]}}, 8: {'particle_type': {'offset': 3178, 'shap

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
7

#####################
## theVariableName ##
#####################
window_length

############
## header ##
############
[[31]] OneStepDataset--_init_--window_length

##################
## functionName ##
##################
OneStepDataset

#################
## thefilename ##
#################
[[31]] OneStepDataset--_init_--window_length

#################
## theVariable ##
#################
0.0003

#####################
## theVariableName ##
#####################
noise_std

############
## header ##
############
[[32]] OneStepDataset--_init_--noise_std

##################
## functionName ##
##################
OneStepDataset

#################
## thefilename ##
#################
[[32]] OneStepDataset--_init_--noise_std

#################
## theVariable ##
#################
False

#####################
## theVariableName ##
#####################
return_pos

############
## header ##
############
[[33]] OneStepDataset--_init_--return

  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)


#####################
size

############
## header ##
############
[[37]] OneStepDataset--traj--size

#################
## theVariable ##
#################
995

#####################
## theVariableName ##
#####################
length

############
## header ##
############
[[38]] OneStepDataset--traj--length

#################
## theVariable ##
#################
317

#####################
## theVariableName ##
#####################
size

############
## header ##
############
[[37]] OneStepDataset--traj--size

#################
## theVariable ##
#################
995

#####################
## theVariableName ##
#####################
length

############
## header ##
############
[[38]] OneStepDataset--traj--length

#################
## theVariable ##
#################
451

#####################
## theVariableName ##
#####################
size

############
## header ##
############
[[37]] OneStepDataset--traj--size

#################
## theVariable ##
#################
995

###########

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



#####################
## theVariableName ##
#####################
layer

############
## header ##
############
[[69]] MLP2--reset_parameters--layer

#################
## theVariable ##
#################
Linear(in_features=128, out_features=128, bias=True)

#####################
## theVariableName ##
#####################
layer

############
## header ##
############
[[70]] MLP3--reset_parameters--layer

#################
## theVariable ##
#################
LayerNorm((128,), eps=1e-05, elementwise_affine=True)

#####################
## theVariableName ##
#####################
layer

############
## header ##
############
[[68]] MLP1--reset_parameters--layer

#################
## theVariable ##
#################
ModuleList()

#####################
## theVariableName ##
#####################
[[65]] MLP--_init_--self.layers

############
## header ##
############
[[65]] MLP--_init_--self.layers

#################
## theVariable ##
#################
ModuleList(
  (0): Linear(in_features=25

  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
Epoch 0:   0%|                                        | 0/24875 [00:00<?, ?it/s]

#################
## theVariable ##
#################
Epoch 0:   0%|                                        | 0/24875 [00:00<?, ?it/s]

#####################
## theVariableName ##
#####################
progress_bar

############
## header ##
############
[[122]] train--i--progress_bar

##################
## functionName ##
##################
train_i

#################
## thefilename ##
#################
[[122]] train--i--progress_bar

#################
## theVariable ##
#################
0

#####################
## theVariableName ##
#####################
total_step

############
## header ##
############
[[123]] train--i--total_loss

##################
## functionName ##
##################
train_i

#################
## thefilename ##
#################
[[123]] train--i--total_loss

#################
## theVariable ##
#################
0

#####################
## theVariableName ##
#####################
total_step

############
## header ##
############
[[124]] train--i--batch_count

##

  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)


##################################

## theVariable #### theVariable ##

##################################

{'size': 298, 'type': 8523, 'pos': 17235290}{'size': 834, 'type': 49954, 'pos': 100138012}



##########################################

## theVariableName #### theVariableName ##

##########################################

windowwindow



########################

## header #### header ##

########################

[[39]] OneStepDataset--get--window[[39]] OneStepDataset--get--window



####################################

## functionName #### functionName ##

####################################

getget



##################################

## thefilename #### thefilename ##

##################################

[[39]] OneStepDataset--get--window[[39]] OneStepDataset--get--window





  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
298

######################################

## theVariable #### theVariableName ##

######################################

834size



#################################

## theVariableName #### header ##

#################################

size[[40]] OneStepDataset--get--size



##############################

## header #### functionName ##

##############################

[[40]] OneStepDataset--get--sizeget



###################################

## functionName #### thefilename ##

###################################

[[40]] OneStepDataset--get--sizeget



#################

  text_size = draw.textsize(text, font=font)



## thefilename ##
#################
[[40]] OneStepDataset--get--size



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5]

#####################
## theVariableName ##
######################################

particle_type## theVariable ##


#############################

## header ##
############
[[41]] OneStepDataset--get--particle_type1[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5

  text_size = draw.textsize(text, font=font)


##################
## functionName ##
##################
get

#################
## thefilename ##
#################
[[41]] OneStepDataset--get--particle_type1



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
[0.8883967  0.3199687  0.8897685  ... 0.11657852 0.34995112 0.14074798]

#####################
## theVariableName ##
#####################
position_seq

############
## header ##
############
[[43]] OneStepDataset--get--position_seq1

##################
## functionName ##
##################
get

#################
## thefilename ##
#################
[[43]] OneStepDataset--get--position_seq1



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
[[[0.8883967  0.3199687 ]
  [0.8897685  0.3059608 ]
  [0.886135   0.3469879 ]
  ...
  [0.36658156 0.12412727]
  [0.43722335 0.11601578]
  [0.33163002 0.14118138]]

 [[0.8884005  0.31769165]
  [0.88977116 0.30363458]
  [0.8861363  0.34479195]
  ...
  [0.3690738  0.12412269]
  [0.43908584 0.11613952]
  [0.33464614 0.14110887]]

 [[0.8884044  0.3153525 ]
  [0.8897761  0.30124477]
  [0.88613796 0.34253412]
  ...
  [0.37158018 0.12412111]
  [0.4409715  0.11625104]
  [0.33767855 0.14103411]]

 ...

 [[0.88841367 0.3104892 ]
  [0.8897833  0.29628107]
  [0.8861407  0.33783388]
  ...
  [0.3766522  0.12412423]
  [0.44481936 0.11643182]
  [0.34377918 0.14089642]]

 [[0.8884175  0.30796513]
  [0.8897865  0.29370707]
  [0.886142   0.33539152]
  ...
  [0.3791826  0.12412355]
  [0.44679564 0.11650595]
  [0.3468595  0.14082374]]

 [[0.8884219  0.30537885]
  [0.8897927  0.29106918]
  [0.8861442  0.3328874 ]
  ...
  [0.3817084  0.12412481]
  [0.44880

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
[0.6068777  0.20707317 0.5910884  ... 0.30910864 0.6791412  0.3018587 ]

#####################
## theVariableName ##
#####################
position_seq

############
## header ##
############
[[43]] OneStepDataset--get--position_seq1

##################
## functionName ##
##################
get

#################
## thefilename ##
#################
[[43]] OneStepDataset--get--position_seq1



  text_size = draw.textsize(text, font=font)


##################################

## theVariable ##
## theVariable ##
##################################

[[[0.8883967  0.3199687 ]
  [0.8884005  0.31769165]
  [0.8884044  0.3153525 ]
  ...
  [0.88841367 0.3104892 ]
  [0.8884175  0.30796513]
  [0.8884219  0.30537885]]

 [[0.8897685  0.3059608 ]
  [0.88977116 0.30363458]
  [0.8897761  0.30124477]
  ...
  [0.8897833  0.29628107]
  [0.8897865  0.29370707]
  [0.8897927  0.29106918]]

 [[0.886135   0.3469879 ]
  [0.8861363  0.34479195]
  [0.88613796 0.34253412]
  ...
  [0.8861407  0.33783388]
  [0.886142   0.33539152]
  [0.8861442  0.3328874 ]]

 ...

 [[0.36658156 0.12412727]
  [0.3690738  0.12412269]
  [0.37158018 0.12412111]
  ...
  [0.3766522  0.12412423]
  [0.3791826  0.12412355]
  [0.3817084  0.12412481]]

 [[0.43722335 0.11601578]
  [0.43908584 0.11613952]
  [0.4409715  0.11625104]
  ...
  [0.44481936 0.11643182]
  [0.44679564 0.11650595]
  [0.4488083  0.11657852]]

 [[0.33163002 0.14118138]
  [0.33464614 0.14110887]
  [0.33767855 

  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)


#################
################### theVariable ##

## theVariable ###################

#################
[[[0.6068777  0.20707317]
  [0.60752547 0.20417434]
  [0.6081577  0.20130847]
  ...
  [0.6094693  0.19565801]
  [0.6102009  0.19300392]
  [0.61096805 0.1905966 ]]

 [[0.5910884  0.20960028]
  [0.5914843  0.20661514]
  [0.5918747  0.20368075]
  ...
  [0.5927195  0.19795752]
  [0.5931793  0.19530351]
  [0.59369004 0.19286266]]

 [[0.5912196  0.20252024]
  [0.591632   0.19967951]
  [0.59204483 0.19687004]
  ...
  [0.59297293 0.19145095]
  [0.5934924  0.1889229 ]
  [0.59404975 0.18659285]]

 ...

 [[0.46208322 0.3144311 ]
  [0.46135092 0.30970663]
  [0.4608904  0.30498713]
  ...
  [0.46080825 0.2953966 ]
  [0.46096754 0.29043543]
  [0.4610456  0.28541487]]

 [[0.49122664 0.3378954 ]
  [0.49067688 0.33355263]
  [0.49046072 0.3292398 ]
  ...
  [0.4908704  0.31954634]
  [0.49117857 0.3142921 ]
  [0.49141696 0.30910864]]

 [[0.67831975 0.3316923 ]
  [0.6789196  0.32691386]
  [0.6792179  

  text_size = draw.textsize(text, font=font)



#################
## thefilename ##
#################
[[46]] OneStepDataset--get--target_position1



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
[[0.61096805 0.1905966 ]
 [0.59369004 0.19286266]
 [0.59404975 0.18659285]
 ...
 [0.4610456  0.28541487]
 [0.49141696 0.30910864]
 [0.6791412  0.3018587 ]]

#####################
## theVariableName ##
#####################
target_position

############
## header ##
############
[[46]] OneStepDataset--get--target_position1

##################
## functionName ##
##################
get

#################
## thefilename ##
#################
[[46]] OneStepDataset--get--target_position1



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
[[[0.6068777  0.20707317]
  [0.60752547 0.20417434]
  [0.6081577  0.20130847]
  [0.60880166 0.19845098]
  [0.6094693  0.19565801]
  [0.6102009  0.19300392]]

 [[0.5910884  0.20960028]
  [0.5914843  0.20661514]
  [0.5918747  0.20368075]
  [0.59229076 0.20077623]
  [0.5927195  0.19795752]
  [0.5931793  0.19530351]]

 [[0.5912196  0.20252024]
  [0.591632   0.19967951]
  [0.59204483 0.19687004]
  [0.59249514 0.19411397]
  [0.59297293 0.19145095]
  [0.5934924  0.1889229 ]]

 ...

 [[0.46208322 0.3144311 ]
  [0.46135092 0.30970663]
  [0.4608904  0.30498713]
  [0.4607243  0.3002418 ]
  [0.46080825 0.2953966 ]
  [0.46096754 0.29043543]]

 [[0.49122664 0.3378954 ]
  [0.49067688 0.33355263]
  [0.49046072 0.3292398 ]
  [0.49057606 0.32460326]
  [0.4908704  0.31954634]
  [0.49117857 0.3142921 ]]

 [[0.67831975 0.3316923 ]
  [0.6789196  0.32691386]
  [0.6792179  0.3220813 ]
  [0.6792846  0.3171421 ]
  [0.6792262  0.31211042]
  [0.6791521  0.3070

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[0.6110, 0.1906],
        [0.5937, 0.1929],
        [0.5940, 0.1866],
        ...,
        [0.4610, 0.2854],
        [0.4914, 0.3091],
        [0.6791, 0.3019]])

#####################
## theVariableName ##
#####################
target_position

############
## header ##
############
[[48]] OneStepDataset--get--target_position2

##################
## functionName ##
##################
get

#################
## thefilename ##
#################
[[48]] OneStepDataset--get--target_position2



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[0.6069, 0.2071],
         [0.6075, 0.2042],
         [0.6082, 0.2013],
         [0.6088, 0.1985],
         [0.6095, 0.1957],
         [0.6102, 0.1930]],

        [[0.5911, 0.2096],
         [0.5915, 0.2066],
         [0.5919, 0.2037],
         [0.5923, 0.2008],
         [0.5927, 0.1980],
         [0.5932, 0.1953]],

        [[0.5912, 0.2025],
         [0.5916, 0.1997],
         [0.5920, 0.1969],
         [0.5925, 0.1941],
         [0.5930, 0.1915],
         [0.5935, 0.1889]],

        ...,

        [[0.4621, 0.3144],
         [0.4614, 0.3097],
         [0.4609, 0.3050],
         [0.4607, 0.3002],
         [0.4608, 0.2954],
         [0.4610, 0.2904]],

        [[0.4912, 0.3379],
         [0.4907, 0.3336],
         [0.4905, 0.3292],
         [0.4906, 0.3246],
         [0.4909, 0.3195],
         [0.4912, 0.3143]],

        [[0.6783, 0.3317],
         [0.6789, 0.3269],
         [0.6792, 0.3221],
         [0.6793, 0.3171],
    

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[ 6.4778e-04, -2.8988e-03],
         [ 6.3223e-04, -2.8659e-03],
         [ 6.4397e-04, -2.8575e-03],
         [ 6.6763e-04, -2.7930e-03],
         [ 7.3159e-04, -2.6541e-03]],

        [[ 3.9589e-04, -2.9851e-03],
         [ 3.9041e-04, -2.9344e-03],
         [ 4.1604e-04, -2.9045e-03],
         [ 4.2874e-04, -2.8187e-03],
         [ 4.5979e-04, -2.6540e-03]],

        [[ 4.1240e-04, -2.8407e-03],
         [ 4.1282e-04, -2.8095e-03],
         [ 4.5031e-04, -2.7561e-03],
         [ 4.7779e-04, -2.6630e-03],
         [ 5.1945e-04, -2.5281e-03]],

        ...,

        [[-7.3230e-04, -4.7245e-03],
         [-4.6051e-04, -4.7195e-03],
         [-1.6612e-04, -4.7453e-03],
         [ 8.3953e-05, -4.8452e-03],
         [ 1.5929e-04, -4.9612e-03]],

        [[-5.4976e-04, -4.3428e-03],
         [-2.1616e-04, -4.3128e-03],
         [ 1.1533e-04, -4.6365e-03],
         [ 2.9433e-04, -5.0569e-03],
         [ 3.0819e-04, -5.2542e-03]]

  text_size = draw.textsize(text, font=font)


#################
################### theVariable ##

## theVariable ###################

#################
[[[0.8883967  0.3199687 ]
  [0.8884005  0.31769165]
  [0.8884044  0.3153525 ]
  [0.88840985 0.3129515 ]
  [0.88841367 0.3104892 ]
  [0.8884175  0.30796513]]

 [[0.8897685  0.3059608 ]
  [0.88977116 0.30363458]
  [0.8897761  0.30124477]
  [0.8897807  0.29879367]
  [0.8897833  0.29628107]
  [0.8897865  0.29370707]]

 [[0.886135   0.3469879 ]
  [0.8861363  0.34479195]
  [0.88613796 0.34253412]
  [0.88613945 0.34021473]
  [0.8861407  0.33783388]
  [0.886142   0.33539152]]

 ...

 [[0.36658156 0.12412727]
  [0.3690738  0.12412269]
  [0.37158018 0.12412111]
  [0.37411395 0.12412371]
  [0.3766522  0.12412423]
  [0.3791826  0.12412355]]

 [[0.43722335 0.11601578]
  [0.43908584 0.11613952]
  [0.4409715  0.11625104]
  [0.44288096 0.11634777]
  [0.44481936 0.11643182]
  [0.44679564 0.11650595]]

 [[0.33163002 0.14118138]
  [0.33464614 0.14110887]
  [0.33767855 0.14103411]
  [0.340721   0.14

  text_size = draw.textsize(text, font=font)






  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[-1.1910e-04,  2.3559e-05],
         [-1.5191e-04,  8.5017e-05],
         [ 2.4898e-05, -2.9097e-05],
         [-7.4941e-05,  4.6808e-05],
         [-3.2453e-05, -8.7752e-05]],

        [[-2.2465e-05,  7.6370e-05],
         [-3.4197e-05,  7.2648e-05],
         [-5.4867e-05,  4.3461e-05],
         [ 1.0264e-05, -9.7250e-05],
         [ 1.3289e-04, -1.2871e-04]],

        [[-2.4693e-04, -2.1784e-04],
         [ 9.1019e-05, -8.2735e-05],
         [ 1.3491e-04,  1.0490e-04],
         [-6.2018e-05,  1.6419e-04],
         [ 1.9304e-04,  1.0882e-04]],

        ...,

        [[-5.3923e-05,  2.2973e-04],
         [-5.0115e-05, -2.1199e-04],
         [-1.9191e-05,  5.4799e-05],
         [ 5.1583e-05, -7.9330e-05],
         [ 1.0420e-04, -2.9990e-04]],

        [[ 1.0627e-04, -8.0945e-05],
         [ 9.3317e-05, -8.8124e-05],
         [-3.3442e-05, -1.4162e-04],
         [ 2.4614e-05,  9.6071e-05],
         [ 1.5872e-04,  6.3743e-05]]

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[0.8884, 0.3054],
        [0.8898, 0.2911],
        [0.8861, 0.3329],
        [0.8914, 0.3138],
        [0.8889, 0.3447],
        [0.8921, 0.2867],
        [0.8913, 0.1737],
        [0.8902, 0.1752],
        [0.8708, 0.1294],
        [0.8863, 0.1479],
        [0.8777, 0.1470],
        [0.8832, 0.2903],
        [0.8893, 0.2087],
        [0.8908, 0.2355],
        [0.8932, 0.2295],
        [0.8930, 0.1665],
        [0.8767, 0.1108],
        [0.8623, 0.1055],
        [0.8719, 0.1447],
        [0.8856, 0.2997],
        [0.8866, 0.2535],
        [0.8885, 0.2550],
        [0.8877, 0.1775],
        [0.8803, 0.1599],
        [0.8934, 0.1767],
        [0.8794, 0.1427],
        [0.8767, 0.1075],
        [0.8358, 0.1056],
        [0.4281, 0.1025],
        [0.7665, 0.1018],
        [0.6460, 0.1003],
        [0.4176, 0.1032],
        [0.8700, 0.1225],
        [0.8876, 0.2191],
        [0.8939, 0.2531],
        [0.8964, 0.2130],
        [0

  text_size = draw.textsize(text, font=font)


## header ##
############
[[4]] generate_noise--velocity-noise2

##################
## functionName ##
##################
generate_noise

#################
## thefilename ##
#################
[[4]] generate_noise--velocity-noise2



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[-1.1910e-04,  2.3559e-05],
         [-3.9011e-04,  1.3214e-04],
         [-6.3622e-04,  2.1161e-04],
         [-9.5727e-04,  3.3790e-04],
         [-1.3108e-03,  3.7644e-04]],

        [[-2.2465e-05,  7.6370e-05],
         [-7.9126e-05,  2.2539e-04],
         [-1.9066e-04,  4.1787e-04],
         [-2.9192e-04,  5.1310e-04],
         [-2.6030e-04,  4.7962e-04]],

        [[-2.4693e-04, -2.1784e-04],
         [-4.0284e-04, -5.1841e-04],
         [-4.2385e-04, -7.1408e-04],
         [-5.0687e-04, -7.4556e-04],
         [-3.9685e-04, -6.6823e-04]],

        ...,

        [[-5.3923e-05,  2.2973e-04],
         [-1.5796e-04,  2.4746e-04],
         [-2.8119e-04,  3.1999e-04],
         [-3.5284e-04,  3.1319e-04],
         [-3.2028e-04,  6.4881e-06]],

        [[ 1.0627e-04, -8.0945e-05],
         [ 3.0586e-04, -2.5001e-04],
         [ 4.7200e-04, -5.6070e-04],
         [ 6.6276e-04, -7.7532e-04],
         [ 1.0122e-03, -9.2620e-04]]

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[ 0.0000e+00,  0.0000e+00],
         [-1.1910e-04,  2.3559e-05],
         [-3.9011e-04,  1.3214e-04],
         [-6.3622e-04,  2.1161e-04],
         [-9.5727e-04,  3.3790e-04],
         [-1.3108e-03,  3.7644e-04]],

        [[ 0.0000e+00,  0.0000e+00],
         [-2.2465e-05,  7.6370e-05],
         [-7.9126e-05,  2.2539e-04],
         [-1.9066e-04,  4.1787e-04],
         [-2.9192e-04,  5.1310e-04],
         [-2.6030e-04,  4.7962e-04]],

        [[ 0.0000e+00,  0.0000e+00],
         [-2.4693e-04, -2.1784e-04],
         [-4.0284e-04, -5.1841e-04],
         [-4.2385e-04, -7.1408e-04],
         [-5.0687e-04, -7.4556e-04],
         [-3.9685e-04, -6.6823e-04]],

        ...,

        [[ 0.0000e+00,  0.0000e+00],
         [-5.3923e-05,  2.2973e-04],
         [-1.5796e-04,  2.4746e-04],
         [-2.8119e-04,  3.1999e-04],
         [-3.5284e-04,  3.1319e-04],
         [-3.2028e-04,  6.4881e-06]],

        [[ 0.0000e+00,  0.0000e+00],

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[ 0.0000e+00,  0.0000e+00],
         [-1.1910e-04,  2.3559e-05],
         [-3.9011e-04,  1.3214e-04],
         [-6.3622e-04,  2.1161e-04],
         [-9.5727e-04,  3.3790e-04],
         [-1.3108e-03,  3.7644e-04]],

        [[ 0.0000e+00,  0.0000e+00],
         [-2.2465e-05,  7.6370e-05],
         [-7.9126e-05,  2.2539e-04],
         [-1.9066e-04,  4.1787e-04],
         [-2.9192e-04,  5.1310e-04],
         [-2.6030e-04,  4.7962e-04]],

        [[ 0.0000e+00,  0.0000e+00],
         [-2.4693e-04, -2.1784e-04],
         [-4.0284e-04, -5.1841e-04],
         [-4.2385e-04, -7.1408e-04],
         [-5.0687e-04, -7.4556e-04],
         [-3.9685e-04, -6.6823e-04]],

        ...,

        [[ 0.0000e+00,  0.0000e+00],
         [-5.3923e-05,  2.2973e-04],
         [-1.5796e-04,  2.4746e-04],
         [-2.8119e-04,  3.1999e-04],
         [-3.5284e-04,  3.1319e-04],
         [-3.2028e-04,  6.4881e-06]],

        [[ 0.0000e+00,  0.0000e+00],

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[0.6069, 0.2071],
         [0.6074, 0.2042],
         [0.6078, 0.2014],
         [0.6082, 0.1987],
         [0.6085, 0.1960],
         [0.6089, 0.1934]],

        [[0.5911, 0.2096],
         [0.5915, 0.2067],
         [0.5918, 0.2039],
         [0.5921, 0.2012],
         [0.5924, 0.1985],
         [0.5929, 0.1958]],

        [[0.5912, 0.2025],
         [0.5914, 0.1995],
         [0.5916, 0.1964],
         [0.5921, 0.1934],
         [0.5925, 0.1907],
         [0.5931, 0.1883]],

        ...,

        [[0.4621, 0.3144],
         [0.4613, 0.3099],
         [0.4607, 0.3052],
         [0.4604, 0.3006],
         [0.4605, 0.2957],
         [0.4606, 0.2904]],

        [[0.4912, 0.3379],
         [0.4908, 0.3335],
         [0.4908, 0.3290],
         [0.4910, 0.3240],
         [0.4915, 0.3188],
         [0.4922, 0.3134]],

        [[0.6783, 0.3317],
         [0.6788, 0.3268],
         [0.6791, 0.3219],
         [0.6790, 0.3167],
    

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[0.6089, 0.1934],
        [0.5929, 0.1958],
        [0.5931, 0.1883],
        ...,
        [0.4606, 0.2904],
        [0.4922, 0.3134],
        [0.6782, 0.3067]])

#####################
## theVariableName ##
#####################
recent_position

############
## header ##
############
[[9]] preprocess--recent_position

##################
## functionName ##
##################
preprocess

#################
## thefilename ##
#################
[[9]] preprocess--recent_position



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[ 5.2869e-04, -2.8753e-03],
         [ 3.6120e-04, -2.7573e-03],
         [ 3.9786e-04, -2.7780e-03],
         [ 3.4660e-04, -2.6667e-03],
         [ 3.7807e-04, -2.6156e-03]],

        [[ 3.7342e-04, -2.9088e-03],
         [ 3.3373e-04, -2.7854e-03],
         [ 3.0452e-04, -2.7120e-03],
         [ 3.2747e-04, -2.7235e-03],
         [ 4.9144e-04, -2.6875e-03]],

        [[ 1.6546e-04, -3.0586e-03],
         [ 2.5690e-04, -3.1100e-03],
         [ 4.2933e-04, -2.9517e-03],
         [ 3.9476e-04, -2.6945e-03],
         [ 6.2948e-04, -2.4507e-03]],

        ...,

        [[-7.8622e-04, -4.4948e-03],
         [-5.6455e-04, -4.7018e-03],
         [-2.8935e-04, -4.6728e-03],
         [ 1.2308e-05, -4.8520e-03],
         [ 1.9184e-04, -5.2679e-03]],

        [[-4.4349e-04, -4.4237e-03],
         [-1.6570e-05, -4.4819e-03],
         [ 2.8148e-04, -4.9472e-03],
         [ 4.8506e-04, -5.2716e-03],
         [ 6.5768e-04, -5.4051e-03]]

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[0.8884, 0.3200],
         [0.8884, 0.3177],
         [0.8884, 0.3154],
         [0.8884, 0.3130],
         [0.8884, 0.3105],
         [0.8884, 0.3080]],

        [[0.8898, 0.3060],
         [0.8898, 0.3036],
         [0.8898, 0.3012],
         [0.8898, 0.2988],
         [0.8898, 0.2963],
         [0.8898, 0.2937]],

        [[0.8861, 0.3470],
         [0.8861, 0.3448],
         [0.8861, 0.3425],
         [0.8861, 0.3402],
         [0.8861, 0.3378],
         [0.8861, 0.3354]],

        ...,

        [[0.3666, 0.1241],
         [0.3691, 0.1241],
         [0.3716, 0.1241],
         [0.3741, 0.1241],
         [0.3767, 0.1241],
         [0.3792, 0.1241]],

        [[0.4372, 0.1160],
         [0.4391, 0.1161],
         [0.4410, 0.1163],
         [0.4429, 0.1163],
         [0.4448, 0.1164],
         [0.4468, 0.1165]],

        [[0.3316, 0.1412],
         [0.3346, 0.1411],
         [0.3377, 0.1410],
         [0.3407, 0.1410],
    

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
834

#####################
## theVariableName ##
#####################
n_particle

############
## header ##
############
[[11]] preprocess--n_particle

##################
## functionName ##
##################
preprocess

#################
## thefilename ##
#################
[[11]] preprocess--n_particle



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[ 3.8147e-06, -2.2770e-03],
         [ 3.9339e-06, -2.3392e-03],
         [ 5.4240e-06, -2.4010e-03],
         [ 3.8147e-06, -2.4623e-03],
         [ 3.8147e-06, -2.5241e-03]],

        [[ 2.6822e-06, -2.3262e-03],
         [ 4.9472e-06, -2.3898e-03],
         [ 4.5896e-06, -2.4511e-03],
         [ 2.6226e-06, -2.5126e-03],
         [ 3.1590e-06, -2.5740e-03]],

        [[ 1.3113e-06, -2.1960e-03],
         [ 1.6689e-06, -2.2578e-03],
         [ 1.4901e-06, -2.3194e-03],
         [ 1.2517e-06, -2.3808e-03],
         [ 1.3113e-06, -2.4424e-03]],

        ...,

        [[ 2.4922e-03, -4.5747e-06],
         [ 2.5064e-03, -1.5870e-06],
         [ 2.5338e-03,  2.6077e-06],
         [ 2.5383e-03,  5.1409e-07],
         [ 2.5304e-03, -6.7800e-07]],

        [[ 1.8625e-03,  1.2375e-04],
         [ 1.8857e-03,  1.1152e-04],
         [ 1.9095e-03,  9.6731e-05],
         [ 1.9384e-03,  8.4050e-05],
         [ 1.9763e-03,  7.4126e-05]]

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[  3,   0,   5,  ..., 397, 734, 620],
        [  0,   0,   0,  ..., 833, 833, 833]])

#####################
## theVariableName ##
#####################
edge_index

############
## header ##
############
[[12]] preprocess--edge_index

##################
## functionName ##
##################
preprocess

#################
## thefilename ##
#################
[[12]] preprocess--edge_index



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
5

#####################
################### theVariableName ##

## theVariable #######################

#################time_steps


############
tensor([[[ 0.4046, -1.9412],
         [ 0.2854, -1.8535],
         [ 0.3115, -1.8689],
         [ 0.2750, -1.7862],
         [ 0.2974, -1.7482]],

        [[ 0.2941, -1.9661],
         [ 0.2658, -1.8744],
         [ 0.2450, -1.8199],
         [ 0.2613, -1.8284],
         [ 0.3781, -1.8016]],

        [[ 0.1460, -2.0774],
         [ 0.2111, -2.1156],
         [ 0.3339, -1.9980],
         [ 0.3093, -1.8069],
         [ 0.4764, -1.6257]],

        ...,

        [[-0.5315, -3.1445],
         [-0.3737, -3.2983],
         [-0.1778, -3.2768],
         [ 0.0370, -3.4099],
         [ 0.1648, -3.7189]],

        [[-0.2875, -3.0917],
         [ 0.0164, -3.1349],
         [ 0.2286, -3.4807],
         [ 0.3735, -3.7217],
         [ 0.4964, -3.8209]],

        [[ 0.3549, -3.4258],
         [ 0.2257, -

  text_size = draw.textsize(text, font=font)



##################
preprocess

#################
## thefilename ##
#################
[[13]] preprocess--normal_velocity_seq



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[ 1.4786e-05,  1.3114e-04],
         [ 9.1346e-05,  7.6718e-06],
         [ 1.2528e-04, -9.7611e-05],
         [ 2.1357e-04, -7.5648e-05],
         [-1.8323e-04,  5.1911e-05]],

        [[ 4.1247e-05,  7.7859e-05],
         [ 9.4982e-05, -1.4180e-05],
         [ 2.1288e-05, -1.2925e-04],
         [ 4.6868e-05, -3.3948e-04],
         [ 1.2214e-04, -2.2966e-06]],

        [[-2.2177e-05, -1.3857e-04],
         [-1.0623e-04, -1.9847e-04],
         [-5.3886e-05, -8.2772e-05],
         [-5.1427e-05, -9.3546e-05],
         [ 4.3043e-05, -5.4113e-05]],

        ...,

        [[-6.4798e-05, -1.0856e-04],
         [-1.3388e-04, -5.9050e-05],
         [-2.0405e-04,  1.5249e-05],
         [-1.5141e-04,  3.4267e-04],
         [ 6.9576e-05,  4.0609e-04]],

        [[ 1.2758e-04,  5.4823e-05],
         [ 1.8896e-04,  8.5986e-05],
         [ 1.9130e-04,  1.1565e-06],
         [ 8.1164e-05, -2.9405e-05],
         [ 4.2326e-06,  2.1377e-05]]

  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[0.1000, 0.9000],
        [0.1000, 0.9000]])

#####################
## theVariableName ##
#####################
boundary

############
## header ##
############
[[14]] preprocess--boundary

##################
## functionName ##
##################
preprocess

#################
## thefilename ##
#################
[[14]] preprocess--boundary



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[0.5089, 0.0934],
        [0.4929, 0.0958],
        [0.4931, 0.0883],
        ...,
        [0.3606, 0.1904],
        [0.3922, 0.2134],
        [0.5782, 0.2067]])

#####################
################### theVariableName ##

####################### theVariable ##

distance_to_lower_boundary
#################

############
## header ##
############
tensor([[[ 1.4786e-05,  1.3114e-04],
         [ 1.0613e-04,  1.3881e-04],
         [ 2.3142e-04,  4.1196e-05],
         [ 4.4498e-04, -3.4452e-05],
         [ 2.6175e-04,  1.7459e-05]],

        [[ 4.1247e-05,  7.7859e-05],
         [ 1.3623e-04,  6.3679e-05],
         [ 1.5752e-04, -6.5572e-05],
         [ 2.0439e-04, -4.0505e-04],
         [ 3.2653e-04, -4.0735e-04]],

        [[-2.2177e-05, -1.3857e-04],
         [-1.2841e-04, -3.3703e-04],
         [-1.8229e-04, -4.1981e-04],
         [-2.3372e-04, -5.1335e-04],
         [-1.9068e-04, -5.6747e-04]],

        ...,

        [[-6.

  text_size = draw.textsize(text, font=font)



## functionName ##
##################
generate_noise

#################
## thefilename ##
#################
[[4]] generate_noise--velocity-noise2



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[0.2911, 0.7066],
        [0.3071, 0.7042],
        [0.3069, 0.7117],
        ...,
        [0.4394, 0.6096],
        [0.4078, 0.5866],
        [0.2218, 0.5933]])

#####################
## theVariableName ##
#####################
distance_to_upper_boundary

############
## header ##
############
[[16]] preprocess--distance_to_upper_boundary

##################
## functionName ##
##################
preprocess

#################
## thefilename ##
#################
[[16]] preprocess--distance_to_upper_boundary



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[ 1.4786e-05,  1.3114e-04],
         [ 1.2092e-04,  2.6994e-04],
         [ 3.5233e-04,  3.1114e-04],
         [ 7.9732e-04,  2.7669e-04],
         [ 1.0591e-03,  2.9415e-04]],

        [[ 4.1247e-05,  7.7859e-05],
         [ 1.7748e-04,  1.4154e-04],
         [ 3.3499e-04,  7.5967e-05],
         [ 5.3938e-04, -3.2909e-04],
         [ 8.6591e-04, -7.3643e-04]],

        [[-2.2177e-05, -1.3857e-04],
         [-1.5059e-04, -4.7560e-04],
         [-3.3288e-04, -8.9541e-04],
         [-5.6660e-04, -1.4088e-03],
         [-7.5728e-04, -1.9762e-03]],

        ...,

        [[-6.4798e-05, -1.0856e-04],
         [-2.6348e-04, -2.7617e-04],
         [-6.6621e-04, -4.2853e-04],
         [-1.2203e-03, -2.3822e-04],
         [-1.7049e-03,  3.5818e-04]],

        [[ 1.2758e-04,  5.4823e-05],
         [ 4.4412e-04,  1.9563e-04],
         [ 9.5196e-04,  3.3760e-04],
         [ 1.5410e-03,  4.5016e-04],
         [ 2.1342e-03,  5.8409e-04]]

  text_size = draw.textsize(text, font=font)



[[17]] preprocess--distance_to_boundary1



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        ...,
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])

#####################
## theVariableName ##
#####################
distance_to_boundary

############
## header ##
############
[[18]] preprocess--distance_to_boundary2

##################
## functionName ##
##################
preprocess

#################
## thefilename ##
#################
[[18]] preprocess--distance_to_boundary2



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[[ 0.0000e+00,  0.0000e+00],
         [ 1.4786e-05,  1.3114e-04],
         [ 1.2092e-04,  2.6994e-04],
         [ 3.5233e-04,  3.1114e-04],
         [ 7.9732e-04,  2.7669e-04],
         [ 1.0591e-03,  2.9415e-04]],

        [[ 0.0000e+00,  0.0000e+00],
         [ 4.1247e-05,  7.7859e-05],
         [ 1.7748e-04,  1.4154e-04],
         [ 3.3499e-04,  7.5967e-05],
         [ 5.3938e-04, -3.2909e-04],
         [ 8.6591e-04, -7.3643e-04]],

        [[ 0.0000e+00,  0.0000e+00],
         [-2.2177e-05, -1.3857e-04],
         [-1.5059e-04, -4.7560e-04],
         [-3.3288e-04, -8.9541e-04],
         [-5.6660e-04, -1.4088e-03],
         [-7.5728e-04, -1.9762e-03]],

        ...,

        [[ 0.0000e+00,  0.0000e+00],
         [-6.4798e-05, -1.0856e-04],
         [-2.6348e-04, -2.7617e-04],
         [-6.6621e-04, -4.2853e-04],
         [-1.2203e-03, -2.3822e-04],
         [-1.7049e-03,  3.5818e-04]],

        [[ 0.0000e+00,  0.0000e+00],

  text_size = draw.textsize(text, font=font)




#################
## thefilename ##
#################
[[19]] dim-preprocess



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[-0.0043,  0.0096],
        [ 0.0000,  0.0000],
        [ 0.0083, -0.0109],
        ...,
        [ 0.0038,  0.0078],
        [-0.0119,  0.0087],
        [-0.0067,  0.0133]])

#####################
## theVariableName ##
#####################
edge_displacement

############
## header ##
############
[[20]] preprocess--edge_displacement1

##################
## functionName ##
##################
preprocess

#################
## thefilename ##
#################
[[20]] preprocess--edge_displacement1



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[-0.2856,  0.6395],
        [ 0.0000,  0.0000],
        [ 0.5504, -0.7261],
        ...,
        [ 0.2503,  0.5181],
        [-0.7958,  0.5788],
        [-0.4436,  0.8884]])

#####################
## theVariableName ##
#####################
edge_displacement
#################

############## theVariable ##

## header ###################

############
[[21]] preprocess--edge_displacement2

tensor([[[ 0.0000e+00,  0.0000e+00],
         [ 1.4786e-05,  1.3114e-04],
         [ 1.2092e-04,  2.6994e-04],
         [ 3.5233e-04,  3.1114e-04],
         [ 7.9732e-04,  2.7669e-04],
         [ 1.0591e-03,  2.9415e-04]],

        [[ 0.0000e+00,  0.0000e+00],
         [ 4.1247e-05,  7.7859e-05],
         [ 1.7748e-04,  1.4154e-04],
         [ 3.3499e-04,  7.5967e-05],
         [ 5.3938e-04, -3.2909e-04],
         [ 8.6591e-04, -7.3643e-04]],

        [[ 0.0000e+00,  0.0000e+00],
         [-2.2177e-05, -1.3857e-04],
         [-1.5059e-04, -

  text_size = draw.textsize(text, font=font)


[[7]] preprocess--position_noise

##################
## functionName ##
##################
preprocess

##################################
## theVariable ##

## thefilename ###################

#################tensor([[0.7003],
        [0.0000],
        [0.9111],
        ...,
        [0.5754],
        [0.9840],
        [0.9930]])

#####################
## theVariableName ##
#####################
edge_distance

############
## header ##
############
[[22]] preprocess--edge_distance


##################[[7]] preprocess--position_noise

## functionName ##

##################


  text_size = draw.textsize(text, font=font)


preprocess

#################
## thefilename ##
#################
[[22]] preprocess--edge_distance



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[ 0.0004, -0.0026],
        [ 0.0005, -0.0027],
        [ 0.0006, -0.0025],
        ...,
        [ 0.0002, -0.0053],
        [ 0.0007, -0.0054],
        [-0.0005, -0.0051]])

#####################
## theVariableName ##
#####################
last_velocity

############
## header ##
############
[[23]] preprocess--last_velocity

##################
## functionName ##
##################
preprocess

#################
## thefilename ##
#################
[[23]] preprocess--last_velocity



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[ 7.6717e-04, -2.4073e-03],
        [ 5.1075e-04, -2.4409e-03],
        [ 5.5736e-04, -2.3300e-03],
        ...,
        [ 7.8052e-05, -5.0206e-03],
        [ 2.3839e-04, -5.1835e-03],
        [-1.0848e-05, -5.1648e-03]])

#####################
## theVariableName ##
#####################
next_velocity

############
## header ##
############
[[24]] preprocess--next_velocity

##################
## functionName ##
##################
#################preprocess

## theVariable ##

##################################

## thefilename ##
#################tensor([[[0.8884, 0.3200],
         [0.8884, 0.3178],
         [0.8885, 0.3156],
         [0.8888, 0.3133],
         [0.8892, 0.3108],
         [0.8895, 0.3083]],

        [[0.8898, 0.3060],
         [0.8898, 0.3037],
         [0.8900, 0.3014],
         [0.8901, 0.2989],
         [0.8903, 0.2960],
         [0.8907, 0.2930]],

        [[0.8861, 0.3470],
         [0.8861, 0.3447],
   

  text_size = draw.textsize(text, font=font)


#####################
position_seq

############
## header ##
############
[[8]] preprocess--position_seq

##################
## functionName ##
##################
preprocess

#################
## thefilename ##
#################
[[8]] preprocess--position_seq



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
tensor([[ 3.8910e-04,  2.0823e-04],
        [ 1.9312e-05,  2.4661e-04],
        [-7.2122e-05,  1.2067e-04],
        ...,
        [-1.1379e-04,  2.4730e-04],
        [-4.1929e-04,  2.2164e-04],
        [ 4.8721e-04, -7.4744e-05]])

#####################
## theVariableName ##
#####################
acceleration

############
## header ##
############
[[25]] preprocess--acceleration1

##################
## functionName ##
##################
preprocess

#################
## thefilename ##
#################
[[25]] preprocess--acceleration1



  text_size = draw.textsize(text, font=font)


##################################

## theVariable #### theVariable ##

#################
#################
tensor([[ 1.2653,  0.6663],
        [ 0.0627,  0.7892],
        [-0.2346,  0.3860],
        ...,
        [-0.3701,  0.7914],
        [-1.3637,  0.7092],
        [ 1.5844, -0.2396]])tensor([[0.8895, 0.3083],
        [0.8907, 0.2930],
        [0.8854, 0.3334],
        [0.8912, 0.3165],
        [0.8891, 0.3474],
        [0.8923, 0.2890],
        [0.8899, 0.1777],
        [0.8907, 0.1762],
        [0.8711, 0.1290],
        [0.8863, 0.1490],
        [0.8792, 0.1491],
        [0.8838, 0.2919],
        [0.8905, 0.2111],
        [0.8927, 0.2373],
        [0.8925, 0.2331],
        [0.8935, 0.1685],
        [0.8778, 0.1094],
        [0.8627, 0.1053],
        [0.8738, 0.1470],
        [0.8874, 0.3021],
        [0.8871, 0.2579],
        [0.8901, 0.2575],
        [0.8898, 0.1803],
        [0.8821, 0.1616],
        [0.8924, 0.1785],
        [0.8791, 0.1438],
        [0.8775, 0.1074],
        [

  text_size = draw.textsize(text, font=font)


##################
## functionName ##
##################
preprocess

#################
## thefilename ##
#################
[[9]] preprocess--recent_position



  text_size = draw.textsize(text, font=font)


#################
## theVariable ##
#################
{'size': 378, 'type': 18903, 'pos': 37964766}

#####################
## theVariableName ##
#####################
window

############
## header ##
############
[[39]] OneStepDataset--get--window

#################
## theVariable ##
#################
378

#####################
## theVariableName ##
#####################
size

############
## header ##
############
[[40]] OneStepDataset--get--size

#################
## theVariable ##
#################
[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[[10]] preprocess--velocity_seq

#################
## theVariable ##
#################
540

#####################
## theVariableName ##
#####################
n_particle

############
## header ##
############
[[11]] preprocess--n_particle

#################
## theVariable ##
#################
tensor([[ 70, 236,   0,  ..., 495, 463, 539],
        [  0,   0,   0,  ..., 539, 539, 539]])

#####################
## theVariableName ##
#####################
edge_index

############
## header ##
############
[[12]] preprocess--edge_index

#################
## theVariable ##
#################
tensor([[[ 1.9511e-01,  4.1043e+00],
         [ 1.7746e-01,  4.0459e+00],
         [ 1.9822e-02,  3.9373e+00],
         [-1.5991e-02,  3.9774e+00],
         [ 4.1378e-02,  3.8338e+00]],

        [[ 9.9161e-01,  2.3748e+00],
         [ 9.7192e-01,  2.6049e+00],
         [ 8.3678e-01,  2.7173e+00],
         [ 7.4283e-01,  2.5557e+00],
         [ 5.9551e-01,  2.6225e+00]],

        [[ 5.8002e-01,  2.8010e+00],

  text_size = draw.textsize(text, font=font)


############
[[2]] generate_noise--time_steps

#################
## theVariable ##
#################
tensor([[[ 9.0958e-05,  2.5717e-04],
         [ 1.5207e-04, -6.7838e-05],
         [ 8.2535e-05,  2.5835e-04],
         [ 1.4766e-04, -5.2654e-05],
         [-1.6054e-04,  1.3672e-04]],

        [[-7.6841e-05,  3.5972e-05],
         [ 8.1135e-05,  2.4659e-04],
         [-4.1790e-05,  4.3251e-05],
         [ 7.7678e-05,  9.2969e-05],
         [ 6.1541e-05,  6.4312e-05]],

        [[-7.7594e-05,  1.9398e-04],
         [ 7.1957e-05, -1.6674e-04],
         [-9.7380e-05,  1.4615e-05],
         [ 6.5730e-05, -3.1698e-04],
         [ 1.0832e-04, -1.4472e-04]],

        ...,

        [[ 1.3235e-04,  7.7824e-05],
         [ 1.0396e-05,  1.4078e-04],
         [-1.4538e-04, -8.4839e-05],
         [-4.8348e-05, -1.5104e-04],
         [-2.2140e-04,  1.6969e-04]],

        [[ 1.3709e-04, -2.7550e-05],
         [ 1.2144e-04,  1.9734e-05],
         [ 3.1758e-04, -2.1651e-04],
         [ 4.4721e-05,  1.

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)




[[44]] OneStepDataset--get--position_seq2

#################
## theVariable ##
#################
[[[0.6489846  0.1165855 ]
  [0.651161   0.11644513]
  [0.65333444 0.11630179]
  ...
  [0.6577065  0.11600553]
  [0.65989965 0.11586868]
  [0.66208565 0.11574224]]

 [[0.5720598  0.11664462]
  [0.57366675 0.11651514]
  [0.57527345 0.1163831 ]
  ...
  [0.57847667 0.11612815]
  [0.5800851  0.11600711]
  [0.58170706 0.11589061]]

 [[0.5778363  0.11624688]
  [0.57949096 0.11613748]
  [0.58114487 0.11602958]
  ...
  [0.58445114 0.11581698]
  [0.5861125  0.11571129]
  [0.58778685 0.11560898]]

 ...

 [[0.104117   0.10814847]
  [0.10415406 0.10790288]
  [0.10419458 0.10765178]
  ...
  [0.10428331 0.1071392 ]
  [0.10433215 0.10687867]
  [0.10438326 0.1066163 ]]

 [[0.312965   0.13451943]
  [0.31254798 0.13433726]
  [0.31213775 0.13416076]
  ...
  [0.31134215 0.13384034]
  [0.31095853 0.1336992 ]
  [0.31058493 0.13357061]]

 [[0.15806788 0.22495027]
  [0.15851454 0.224853  ]
  [0.15897936 0.22472036

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



tensor([[3.5757e-01, 7.8391e-01],
        [3.4502e-01, 7.8661e-01],
        [2.5702e-01, 7.9558e-01],
        [4.2978e-01, 7.8118e-01],
        [4.4368e-01, 7.8201e-01],
        [3.6260e-01, 7.9082e-01],
        [1.5267e-01, 7.9046e-01],
        [3.8190e-01, 7.8406e-01],
        [3.6213e-01, 7.8583e-01],
        [4.8556e-01, 7.8776e-01],
        [5.0327e-01, 7.8976e-01],
        [3.7898e-01, 7.8125e-01],
        [3.5099e-01, 7.9393e-01],
        [5.1231e-01, 7.8513e-01],
        [4.6496e-01, 7.9197e-01],
        [4.3465e-01, 7.9263e-01],
        [4.8265e-01, 7.8788e-01],
        [4.4527e-01, 7.8407e-01],
        [4.7398e-01, 7.8832e-01],
        [4.3214e-01, 7.8773e-01],
        [4.1263e-01, 7.8168e-01],
        [4.1349e-01, 7.9154e-01],
        [4.3003e-01, 7.8642e-01],
        [5.3933e-01, 7.9612e-01],
        [5.3805e-01, 7.9466e-01],
        [5.2970e-01, 7.9366e-01],
        [3.8273e-01, 7.9101e-01],
        [4.1318e-01, 7.9006e-01],
        [2.5869e-01, 7.9707e-01],
        [4.041

  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)



## theVariableName ##
#####################
distance_to_boundary

tensor([[ 1.5298,  0.5248,  0.6871,  ..., -0.1136, -1.0352, -0.7545],
        [ 1.5298,  0.5248,  0.6871,  ...,  1.1272, -1.4832, -1.2998],
        [ 1.5298,  0.5248,  0.6871,  ...,  1.5172, -1.2316, -1.2507],
        ...,
        [ 0.7198,  1.7564, -0.4489,  ...,  1.3925, -1.1657, -1.1484],
        [ 0.7198,  1.7564, -0.4489,  ...,  1.7560, -1.3068, -1.3833],
        [ 0.7198,  1.7564, -0.4489,  ..., -0.1026, -0.5318, -0.7208]],
       device='cuda:0', grad_fn=<CatBackward0>)
############

#####################
## theVariableName ##
## header #######################

x

########################

## header ##
############
[[18]] preprocess--distance_to_boundary2[[75]] InteractionNetwork--message--x1



#################
## theVariable ###################
## theVariable ##

#################
#################
2

#####################
## theVariableName ##
#####################
dimtensor([[ 0.9767,  1.0508,  0.4894,  ...,

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)




############
[[74]] InteractionNetwork--forward--node_out2
velocity_noise

#################
## theVariable ##

#################
############
## header ##
############
[[3]] generate_noise--velocity-noise1

#################
## theVariable ##
#################
tensor([[ 1.5410, -2.1603,  2.4003,  ...,  6.9187, -2.8205,  1.7909],
        [ 1.8032, -1.7260,  2.2857,  ...,  7.4635, -3.0452,  1.6968],
        [ 1.2507, -1.5180,  2.1233,  ...,  6.5638, -2.8096,  2.1182],
        ...,
        [ 1.2497, -1.9512,  1.8434,  ...,  6.7896, -3.3808,  1.6599],
        [ 1.7878, -2.3987,  2.1432,  ...,  7.1810, -3.4267,  2.0000],
        [ 2.0596, -1.3140,  0.9780,  ...,  6.7081, -1.7179,  2.0261]],
       device='cuda:0', grad_fn=<AddBackward0>)

tensor([[[ 1.5270e-05,  2.6688e-04],
         [-9.2950e-05,  3.6848e-04],
         [-7.0711e-05,  2.8998e-04],
         [-1.2911e-04,  4.2283e-04],
         [-3.4193e-04,  3.8514e-04]],

        [[-4.4846e-05,  2.0355e-04],
         [-4.8840e-07,  1.2225

  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
Epoch 0:   0%|   | 0/24875 [00:07<?, ?it/s, loss=3.33, avg_loss=3.33, lr=0.0001]


[[3]] generate_noise--velocity-noise1#################


## theVariable ##
#################
#################
## theVariable ##
1
#################

#####################
## theVariableName ##
#####################
tensor([[[ 1.3893e-04,  4.5594e-05],
         [ 5.0974e-05, -1.2769e-04],
         [ 1.9151e-04,  2.6238e-04],
         [ 1.2286e-04,  2.6550e-04],
         [ 8.1659e-05,  3.1011e-04]],

        [[ 9.3464e-05, -9.4696e-05],
         [ 2.3110e-05, -1.1285e-04],
         [ 4.6341e-05, -1.2963e-04],
         [ 9.7379e-05, -3.6434e-04],
         [ 2.9983e-04, -3.1934e-04]],

        [[ 9.7919e-05, -2.1390e-04],
         [ 1.0846e-04, -1.9615e-04],
         [ 3.9862e-05, -3.6081e-04],
         [ 2.6158e-05, -2.8955e-04],
         [-3.8164e-05, -2.2467e-04]],

        ...,

        [[ 1.6127e-04,  1.8907e-04],
         [ 1.7530e-04,  1.0848e-04],
         [-3.8174e-05,  1.3798e-04],
         [-2.0555e-04, -1.4584e-05],
         [-3.8267e-04, -6.4578e-06]],

        [[-7.5526e-05

  text_size = draw.textsize(text, font=font)
  text_size = draw.textsize(text, font=font)
Epoch 0:   0%| | 1/24875 [00:07<53:20:20,  7.72s/it, loss=3.33, avg_loss=3.33, l


#################
#################
## theVariable ##
#################
tensor([[[ 0.0000e+00,  0.0000e+00],
         [ 1.3893e-04,  4.5594e-05],
         [ 1.8991e-04, -8.2092e-05],
         [ 3.8142e-04,  1.8029e-04],
         [ 5.0429e-04,  4.4579e-04],
         [ 5.8595e-04,  7.5590e-04]],

        [[ 0.0000e+00,  0.0000e+00],
         [ 9.3464e-05, -9.4696e-05],
         [ 1.1657e-04, -2.0754e-04],
         [ 1.6291e-04, -3.3717e-04],
         [ 2.6029e-04, -7.0151e-04],
         [ 5.6013e-04, -1.0209e-03]],

        [[ 0.0000e+00,  0.0000e+00],
         [ 9.7919e-05, -2.1390e-04],
         [ 2.0638e-04, -4.1005e-04],
         [ 2.4624e-04, -7.7086e-04],
         [ 2.7240e-04, -1.0604e-03],
         [ 2.3424e-04, -1.2851e-03]],

        ...,

        [[ 0.0000e+00,  0.0000e+00],
         [ 1.6127e-04,  1.8907e-04],
         [ 3.3657e-04,  2.9755e-04],
         [ 2.9839e-04,  4.3554e-04],
         [ 9.2848e-05,  4.2095e-04],
         [-2.8982e-04,  4.1449e-04]],

        [[ 0.0000

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [None]:
# Save trained model
model_save_path = "simulator_model_justoneepoch.pth"


torch.save(simulator.state_dict(), model_save_path)


print(f"Model saved to {model_save_path}")

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# visualize loss curve
plt.figure()
plt.plot(*zip(*train_loss_list), label="train")
plt.plot(*zip(*eval_loss_list), label="valid")
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('Loss')
plt.legend()
plt.show()

• Load checkpoint trained by us. 

• Do **not** run this block if you have trained your model in previous block.

In [None]:
################
## LOAD MODEL ##
################
simulator = LearnedSimulator()


simulator = simulator.cuda()


#!wget -O temp/models/WaterDrop_checkpoint.pt https://storage.googleapis.com/cs224w_course_project_dataset/Checkpoints/WaterDrop_checkpoint.pt
# checkpoint = torch.load("simulator_model_20epoch.pth")
# simulator.load_state_dict(checkpoint["model"])
# model_save_path = "simulator_model_20epoch.pth"
model_save_path = "simulator_model_justoneepoch.pth"

simulator.load_state_dict(torch.load(model_save_path))

## Visualization

Since video is 1000 frames long, it might take a few minutes to rollout.

In [None]:
rollout_dataset = RolloutDataset(data_path, "valid")


simulator.eval()


rollout_data = rollout_dataset[0]


rollout_out = rollout(simulator, rollout_data, rollout_dataset.metadata, params["noise"])


rollout_out = rollout_out.permute(1, 0, 2)



In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

TYPE_TO_COLOR = {
    3: "black",
    0: "green",
    7: "magenta",
    6: "gold",
    5: "blue",
}


def visualize_prepare(ax, particle_type, position, metadata):
    bounds = metadata["bounds"]
    ax.set_xlim(bounds[0][0], bounds[0][1])
    ax.set_ylim(bounds[1][0], bounds[1][1])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect(1.0)
    points = {type_: ax.plot([], [], "o", ms=2, color=color)[0] for type_, color in TYPE_TO_COLOR.items()}
    return ax, position, points


def visualize_pair(particle_type, position_pred, position_gt, metadata):
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    plot_info = [
        visualize_prepare(axes[0], particle_type, position_gt, metadata),
        visualize_prepare(axes[1], particle_type, position_pred, metadata),
    ]
    axes[0].set_title("Ground truth")
    axes[1].set_title("Prediction")

    plt.close()

    def update(step_i):
        outputs = []


        for _, position, points in plot_info:


            for type_, line in points.items():
                mask = particle_type == type_


                line.set_data(position[step_i, mask, 0], position[step_i, mask, 1])


            outputs.append(line)
        return outputs

    return animation.FuncAnimation(fig, update, frames=np.arange(0, position_gt.size(0)), interval=10, blit=True)

anim = visualize_pair(rollout_data["particle_type"], rollout_out, rollout_data["position"], rollout_dataset.metadata)
HTML(anim.to_html5_video())

## Conclusion

• Hope this Colab is helpful for you to understand how to apply GNN in a real-world application like simulating complex physics! 

• If you're interested in technical details, read [medium post](https://) or see [original paper](https://arxiv.org/abs/2002.09405) by DeepMind. 

• Thanks for spending your time with us!

In [None]:
# ####################
# ## IMAGE STITCHER ## 
# ####################
# import os
# from PIL import Image

# def stitch_images_bfdh(image_folder, output_path):
#     # Load all PNG images from the folder
#     images = []
#     for filename in os.listdir(image_folder):
#         if filename.endswith('.png'):
#             img_path = os.path.join(image_folder, filename)
#             img = Image.open(img_path)
#             images.append(img)

#     # Sort images by width in descending order for better packing
#     images.sort(key=lambda img: img.width, reverse=True)

#     # Initialize variables for row-based packing
#     rows = []
#     current_row = []
#     current_width = 0
#     max_height_in_row = 0
#     max_canvas_width = 0

#     max_width = max(img.width for img in images)  # Maximum width of any image
#     max_total_height = 0

#     # Group images into rows based on width (bin-packing approach)
#     for img in images:
#         if current_width + img.width <= max_width:
#             # Add image to the current row
#             current_row.append(img)
#             current_width += img.width
#             max_height_in_row = max(max_height_in_row, img.height)
#         else:
#             # Move to the next row
#             rows.append((current_row, current_width, max_height_in_row))
#             max_canvas_width = max(max_canvas_width, current_width)
#             max_total_height += max_height_in_row
            
#             # Reset for the new row
#             current_row = [img]
#             current_width = img.width
#             max_height_in_row = img.height

#     # Add the last row
#     if current_row:
#         rows.append((current_row, current_width, max_height_in_row))
#         max_canvas_width = max(max_canvas_width, current_width)
#         max_total_height += max_height_in_row

#     # Create a new blank canvas large enough to hold all rows
#     stitched_image = Image.new('RGBA', (max_canvas_width, max_total_height))

#     # Variable to keep track of current y-position (vertical stacking)
#     y_offset = 0

#     # Stitch images row by row
#     for row, row_width, row_height in rows:
#         x_offset = 0
#         for img in row:
#             # Paste each image into its row
#             stitched_image.paste(img, (x_offset, y_offset))
#             x_offset += img.width  # Move to the right for the next image
#         y_offset += row_height  # Move down for the next row

#     # Save the stitched image
#     stitched_image.save(output_path)
#     print(f"Stitched image saved as {output_path}")
    

# # Run stitch_images_bfdh for each item in the folders_created list
# for folder in folders_created:
#     image_folder = os.path.expanduser(os.path.join('~/Desktop/GNN/outputpng/', folder))
#     output_path = os.path.expanduser(os.path.join('~/Desktop/GNN/outputpng/', folder, f"{folder}.png"))
#     stitch_images_bfdh(image_folder, output_path)  


In [None]:
import os
from PIL import Image

def stitch_images_bfdh(image_folder, output_path, spacing=0):
    # Load all PNG images from the folder
    images = []
    for filename in os.listdir(image_folder):
        if filename.endswith('.png'):
            img_path = os.path.join(image_folder, filename)
            img = Image.open(img_path)
            images.append(img)

    # Sort images by width in descending order for better packing
    images.sort(key=lambda img: img.width, reverse=True)

    # Initialize variables for row-based packing
    rows = []
    current_row = []
    current_width = 0
    max_height_in_row = 0
    max_canvas_width = 0

    max_width = max(img.width for img in images)  # Maximum width of any image
    max_total_height = 0

    # Group images into rows based on width (bin-packing approach)
    for img in images:
        if current_width + img.width + (len(current_row) * spacing) <= max_width:
            # Add image to the current row
            current_row.append(img)
            current_width += img.width
            max_height_in_row = max(max_height_in_row, img.height)
        else:
            # Move to the next row
            rows.append((current_row, current_width, max_height_in_row))
            max_canvas_width = max(max_canvas_width, current_width + (len(current_row) - 1) * spacing)
            max_total_height += max_height_in_row + spacing

            # Reset for the new row
            current_row = [img]
            current_width = img.width
            max_height_in_row = img.height

    # Add the last row
    if current_row:
        rows.append((current_row, current_width, max_height_in_row))
        max_canvas_width = max(max_canvas_width, current_width + (len(current_row) - 1) * spacing)
        max_total_height += max_height_in_row  # Don't add spacing after the last row

    # Create a new blank canvas large enough to hold all rows
    stitched_image = Image.new('RGBA', (max_canvas_width, max_total_height))

    # Variable to keep track of current y-position (vertical stacking)
    y_offset = 0

    # Stitch images row by row
    for row, row_width, row_height in rows:
        x_offset = 0
        for img in row:
            # Paste each image into its row with spacing
            stitched_image.paste(img, (x_offset, y_offset))
            x_offset += img.width + spacing  # Move to the right for the next image
        y_offset += row_height + spacing  # Move down for the next row with spacing

    # Save the stitched image
    stitched_image.save(output_path)
    print(f"Stitched image saved as {output_path}")

# Run stitch_images_bfdh for each item in the folders_created list
for folder in folders_created:
    image_folder = os.path.expanduser(os.path.join('~/Desktop/GNN/outputpng/', folder))
    output_path = os.path.expanduser(os.path.join('~/Desktop/GNN/outputpng/', folder, f"{folder}.png"))
    stitch_images_bfdh(image_folder, output_path, spacing=10)

# outputfile = r"C:\Users\Admin\Desktop\GNNoutputpng\OneStepDataset\outputpng.png"
# # if outputfile exists then delete it
# if os.path.exists(outputfile):
#     os.remove(outputfile)  # Remove the file
#     print(f"File '{outputfile}' has been removed.")
# else:
#     print(f"File '{outputfile}' does not exist.")

# image_folder = r'C:\Users\Admin\Desktop\GNNoutputpng\OneStepDataset'
# stitch_images_bfdh(image_folder, outputfile, spacing=10)  # Example with 10px spacing
