<a href="https://colab.research.google.com/github/steve859/traffic_flow_prediction/blob/sensor_dataset_Duy/model/notebooks/01_pretrain_STGCN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Set Up

In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

import os
if os.path.exists("/content/traffic_flow_prediction"):
    !rm -rf /content/traffic_flow_prediction

!git clone https://github.com/steve859/traffic_flow_prediction.git
%cd traffic_flow_prediction
!git checkout sensor_dataset_Duy
# Setup dataset paths
METR_H5 = "/content/drive/MyDrive/Project Data/Dataset/metr-la.h5"
ADJ_PKL = "/content/drive/MyDrive/Project Data/Dataset/adj_mx.pkl"

print("SETUP DONE")


Mounted at /content/drive
Cloning into 'traffic_flow_prediction'...
remote: Enumerating objects: 21, done.[K
remote: Counting objects: 100% (21/21), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 21 (delta 3), reused 12 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (21/21), 7.46 KiB | 7.46 MiB/s, done.
Resolving deltas: 100% (3/3), done.
/content/traffic_flow_prediction
Branch 'sensor_dataset_Duy' set up to track remote branch 'sensor_dataset_Duy' from 'origin'.
Switched to a new branch 'sensor_dataset_Duy'
SETUP DONE


In [3]:
import os, h5py, numpy as np, pprint

# Path of the METR-LA HDF5 file under the mounted Drive folder; adjust it if
# the dataset is stored somewhere else.
METR_H5 = "/content/drive/MyDrive/Project Data/Dataset/metr-la.h5"

# Report whether the file is there, and stop right away with a readable error
# instead of letting the HDF5 open below fail on a missing path.
h5_file_found = os.path.exists(METR_H5)
print("Check file exists:", h5_file_found, METR_H5)
if not h5_file_found:
    raise FileNotFoundError(f"File not found: {METR_H5}. Kiểm tra đường dẫn trong Drive.")

def print_h5_tree(f):
    """Print one summary line for every object reachable from ``f``.

    Each line has the form ``<path>  |  <Group|Dataset>  |  shape=...  |  dtype=...``.
    Objects that are not an ``h5py.Group`` are labelled ``Dataset``. ``shape``
    and ``dtype`` fall back to ``None`` when the object has no such attribute.

    Args:
        f: An object exposing ``visititems(callback)`` (here an open
            ``h5py.File``); the callback receives ``(name, obj)`` pairs.
    """
    def describe(path, node):
        # Best-effort listing: a problem with one entry is reported on its own
        # line and does not stop the traversal.
        try:
            if isinstance(node, h5py.Group):
                kind = 'Group'
            else:
                kind = 'Dataset'
            node_shape = getattr(node, 'shape', None)
            node_dtype = getattr(node, 'dtype', None)
            print(f"{path}  |  {kind}  |  shape={node_shape}  |  dtype={node_dtype}")
        except Exception as err:
            print(f"{path}  |  <error getting info: {err}>")

    f.visititems(describe)

# Open read-only; the context manager closes the file even if selection fails.
with h5py.File(METR_H5, 'r') as f:
    print("HDF5 file keys / tree:")
    print_h5_tree(f)

    # Collect every dataset that could hold the 2-D readings matrix: it must
    # have at least two dimensions AND a numeric dtype, so that label/index
    # arrays stored as strings are never offered as candidates.
    candidates = []

    def collect(name, obj):
        """Record (name, shape, dtype) of numeric datasets with >= 2 dims."""
        if not isinstance(obj, h5py.Dataset):
            return
        # `shape` can be None for a dataset without a dataspace; treat it as 0-D.
        shape = getattr(obj, 'shape', None) or ()
        if len(shape) >= 2 and np.issubdtype(obj.dtype, np.number):
            candidates.append((name, shape, obj.dtype))

    f.visititems(collect)

    if not candidates:
        print("\nKhông tìm thấy dataset 2D trong file. In root keys:")
        pprint.pprint(list(f.keys()))
        # Fail here, at the cause. Otherwise `data` stays undefined (or, on a
        # reused kernel, silently keeps a value from an earlier run) and the
        # error only surfaces in a later cell.
        raise ValueError(
            f"No numeric dataset with at least 2 dimensions found in {METR_H5}"
        )

    print("\nCandidate numeric datasets (name, shape, dtype):")
    pprint.pprint(candidates)

    # Selection heuristic, in priority order:
    #   1) first dataset whose name contains one of the hint words;
    #   2) first dataset whose second dimension is within 100..500
    #      (METR has 207 nodes);
    #   3) the first candidate found.
    name_hints = ('speed', 'data', 'traffic')
    chosen = next(
        (name for name, _, _ in candidates
         if any(hint in name.lower() for hint in name_hints)),
        None,
    )
    if chosen is None:
        chosen = next(
            (name for name, shape, _ in candidates if 100 <= shape[1] <= 500),
            None,
        )
    if chosen is None:
        chosen = candidates[0][0]

    print(f"\n--> Loading dataset: {chosen}")
    # `[:]` reads the whole dataset into memory as a NumPy array, so `data`
    # stays usable after the file is closed.
    data = f[chosen][:]
    print("Loaded data shape:", data.shape, "dtype:", data.dtype)




Check file exists: True /content/drive/MyDrive/Project Data/Dataset/metr-la.h5
HDF5 file keys / tree:
df  |  Group  |  shape=None  |  dtype=None
df/axis0  |  Dataset  |  shape=(207,)  |  dtype=|S6
df/axis1  |  Dataset  |  shape=(34272,)  |  dtype=int64
df/block0_items  |  Dataset  |  shape=(207,)  |  dtype=|S6
df/block0_values  |  Dataset  |  shape=(34272, 207)  |  dtype=float64

Candidate numeric datasets (name, shape, dtype):
[('df/block0_values', (34272, 207), dtype('<f8'))]

--> Loading dataset: df/block0_values
Loaded data shape: (34272, 207) dtype: float64


In [4]:
# Convert the loaded matrix to float32 (re-running this cell is harmless: the
# cast is idempotent), then print its shape and value range as a sanity check.
data = data.astype(np.float32)
print(f"Final data shape: {data.shape}")
data_min, data_max = data.min(), data.max()
print("Min:", data_min, "Max:", data_max)


Final data shape: (34272, 207)
Min: 0.0 Max: 70.0


CLEAN MISSING VALUES

In [5]:
import pandas as pd
import numpy as np

# 1) Turn 0 into NaN (METR-LA treats a 0 reading as missing).
#    `replace` is used without `inplace=True`: a frame built directly from a
#    NumPy array can share memory with it, so an in-place replace may write
#    NaNs back into `data` and make earlier cells irreproducible on re-run.
df = pd.DataFrame(data).replace(0, np.nan)

# 2) Fill each column forward in time, then backward to cover gaps at the very
#    start. `ffill()` / `bfill()` replace the deprecated
#    `fillna(method=...)` spelling, which emits a FutureWarning.
df = df.ffill().bfill()

# NOTE: the former step 3, `df.interpolate()`, was removed. After a forward
# fill followed by a backward fill the only NaNs that can remain are columns
# that are entirely NaN, which linear interpolation cannot fill either, so the
# call never changed a value. If linear smoothing of gaps is actually wanted,
# interpolate BEFORE the fill - that changes the resulting data, so it is
# deliberately not done here.

# Copy out as a float32 NumPy array for the downstream model code.
clean_data = df.values.astype(np.float32)

print("Cleaned data shape:", clean_data.shape)
print("Remaining NaN:", np.isnan(clean_data).sum())


  df = df.fillna(method='ffill').fillna(method='bfill')


Cleaned data shape: (34272, 207)
Remaining NaN: 0
