# 🔥Causal Graph Neural Networks for Wildfire Danger Prediction🔥
Re-implementation of original work by Zhao et al.(2024) (https://arxiv.org/abs/2403.08414)

IDL S25 Group 23: Wenting Yue, Wenyu Liu, Youyou Huang (Group 23)

## Retrieve files from github repository
If `only notebook` is downloaded locally

In [1]:
import os
import sys

# Get the current working directory
print("Current working directory:", os.getcwd())
# repo = "https://github.com/youyouh511/11785_IDL_S25_Final-Project.git"
# !git clone {repo}
!git pull

Current working directory: /home/ubuntu/11785_IDL_S25_Final-Project
Already up to date.


# Set up

## Environment

Environment setup
```bash
conda env create -f env.yml
```

Activate environment and check device
```bash
conda activate idl
nvidia-smi
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
```

# Imports

In [2]:
from data import (
    JsonFireDataset
)
from model import (
    AdjacencyMatrix,
    TemporalLSTM,
    CausalGNN
)
# from train import (
    
# )
# from utils import (
# )


import numpy as np
import tqdm
import matplotlib.pyplot as plt
import json
import zipfile
import torch
import requests
import xarray as xr
import yaml
from torchinfo import summary
import shutil
import wandb
import time
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


Using device: cuda


# Config

In [3]:
%%writefile config.yaml

###### Dataset
    root                    : "./data"
    train_json_path         : "train.json"
    val_json_path           : "val.json"
    test_json_path          : "test.json"
    matrix_json_path        : "matrix.json"
    subset                  : 1.0
    batch_size              : 128
    NUM_WORKERS             : 4

    ### Target threshold
    fire_threshold          : 10


###### Model
    ### Adjacency matrix
    local_var_lag           : 8
    oci_var_lag             : 31
    max_lag                 : 312
    independence_test       : "ParCorr"
    tau_max                 : 186
    pc_alpha                : 0.05
    mask_target             : True

    ### Temporal LSTM
    lstm_layer              : 1
    hidden_dim              : 256

    ### GNN
    gnn_nodes               : 7


###### Training
    epochs                  : 30

    lr                      : 1.0e-5
    min_lr                  : 1.0e-9
    
    optimizer               : "Adam"
    betas                   : [0.9, 0.999]
    eps                     : 1.0e-8
    weight_decay            : 5.0e-6

    lr_scheduler            : "CosineAnnealingLR"
    patience                : 10
    early_stop              : True

    save_model              : True
    save_model_path         : "./checkpoints"
    load_model              : False
    load_model_path         : "./checkpoints/best.pth"
    wandb_log               : True
    wandb_project           : "IDL_Final"
    wandb_run_id            : None

Overwriting config.yaml


In [4]:
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)

config

{'root': './data',
 'train_json_path': 'train.json',
 'val_json_path': 'val.json',
 'test_json_path': 'test.json',
 'matrix_json_path': 'matrix.json',
 'subset': 1.0,
 'batch_size': 128,
 'NUM_WORKERS': 4,
 'fire_threshold': 10,
 'local_var_lag': 8,
 'oci_var_lag': 31,
 'max_lag': 312,
 'independence_test': 'ParCorr',
 'tau_max': 186,
 'pc_alpha': 0.05,
 'mask_target': True,
 'lstm_layer': 1,
 'hidden_dim': 256,
 'gnn_nodes': 7,
 'epochs': 30,
 'lr': 1e-05,
 'min_lr': 1e-09,
 'optimizer': 'Adam',
 'betas': [0.9, 0.999],
 'eps': 1e-08,
 'weight_decay': 5e-06,
 'lr_scheduler': 'CosineAnnealingLR',
 'patience': 10,
 'early_stop': True,
 'save_model': True,
 'save_model_path': './checkpoints',
 'load_model': False,
 'load_model_path': './checkpoints/best.pth',
 'wandb_log': True,
 'wandb_project': 'IDL_Final',
 'wandb_run_id': 'None'}

# Data Retrieval & Pre-process
Refer to data_preprocessing.ipynb

# Datasets

In [5]:
local_keys  = ["T2M_MEAN","TP","VPD"]
oci_keys    = ["OCI_NAO", "OCI_NINA34_ANOM", "OCI_AO"]

train_ds = JsonFireDataset(
    json_path   = "data/train.json",
    local_keys  = local_keys,
    oci_keys    = oci_keys
)
val_ds = JsonFireDataset(
    json_path   = "data/val.json",
    local_keys  = local_keys,
    oci_keys    = oci_keys
)
test_ds = JsonFireDataset(
    json_path   = "data/test.json",
    local_keys  = local_keys,
    oci_keys    = oci_keys
)

# Model

## Adjacency Matrix

In [6]:
## Filter a subset of matrix samples from matrix.json
subset_frac         = 0.1
rng_seed            = 11785
input_matrix_file   = "data/matrix.json"
subset_matrix_file  = f"data/matrix_{subset_frac}_{rng_seed}.json"

AdjacencyMatrix.sample_json_file(
    subset_frac  = subset_frac,
    rng_seed     = rng_seed,
    input_path   = input_matrix_file,
    output_path  = subset_matrix_file,
)

In [None]:
matrix_builder = AdjacencyMatrix(subset_matrix_file, independence_test="ParCorr",tau_max=23)
output, varlist = matrix_builder.gen_adj_matrix("val", "mean", True, "target", True)
print(varlist)


### Panel 0 last 5 rows ###
     OCI_AO  OCI_NAO  OCI_NINA34_ANOM    T2M_MEAN         TP       VPD  target
961   0.099    -0.33            -0.94  256.633453   3.656824  0.403680     0.0
962   0.201     0.18            -1.06  256.120178   1.920475  0.314687     0.0
963   0.201     0.18            -1.06  267.182312  12.704536  0.497526     0.0
964   0.201     0.18            -1.06  261.647827   4.729585  0.375239     0.0
965   0.201     0.18            -1.06  259.047424   4.493174  0.267562     0.0

### Panel 1 last 5 rows ###
     OCI_AO  OCI_NAO  OCI_NINA34_ANOM    T2M_MEAN         TP       VPD  target
961   0.099    -0.33            -0.94  256.733704   5.492703  0.360529     0.0
962   0.201     0.18            -1.06  256.949310   1.383027  0.333360     0.0
963   0.201     0.18            -1.06  267.576721  14.041216  0.580773     0.0
964   0.201     0.18            -1.06  262.697357   4.590917  0.516119     0.0
965   0.201     0.18            -1.06  260.237701   1.244266  0.346518   

## Causal GNN

In [None]:
model = CausalGNN(
    adj_matrix=norm_matrix,
    num_nodes=config['gnn_nodes'],
    hidden_dim=config['hidden_dim']
).to(device)

model_stats = summary(model)