In [1]:
import json
import urllib
import urllib.request

import numpy as np
import pandas as pd

In [2]:
# 일반적인 모듈 
import numpy as np
import matplotlib.pyplot as plt 
import networkx as nx 
from tqdm import tqdm 

# 파이토치 관련 
import torch
import torch.nn.functional as F

# PyG 관련 
from torch_geometric.data import Data ## Data: 그래프자료형을 만드는 클래스

# STGCN 관련 
import torch_geometric_temporal
from torch_geometric_temporal.nn.recurrent import GConvGRU
from torch_geometric_temporal.signal import temporal_signal_split 

In [3]:
# Download the raw windmill-output JSON and list its top-level keys.
url = "https://graphmining.ai/temporal_datasets/windmill_output.json"
# Use the response as a context manager so the HTTP connection is closed after reading
# (`urllib.request` must be imported explicitly; `import urllib` alone does not load it).
with urllib.request.urlopen(url) as response:
    data_dict = json.loads(response.read())
data_dict.keys()

dict_keys(['block', 'time_periods', 'weights', 'edges'])

In [7]:
type(data_dict["weights"])  # the edge weights arrive as a plain Python list

list

In [10]:
type(data_dict["block"])  # the `block` payload is also a plain (nested) Python list

list

In [24]:
np.asarray(data_dict["block"])  # materialise the nested list as a 2-D numpy array for inspection

array([[0.1287, 0.1167, 0.0812, ..., 0.027 , 0.0201, 0.0228],
       [0.0817, 0.1078, 0.1054, ..., 0.0439, 0.0262, 0.021 ],
       [0.9418, 0.9589, 0.9447, ..., 0.7815, 0.8621, 0.2498],
       ...,
       [0.1391, 0.1829, 0.1383, ..., 0.0359, 0.0335, 0.0219],
       [0.5972, 0.6057, 0.6123, ..., 0.2606, 0.4203, 0.1954],
       [0.1298, 0.1504, 0.1442, ..., 0.0256, 0.093 , 0.0158]])

In [22]:
np.shape(data_dict["block"])  # (rows, columns); the loaders below slide their lag window along axis 0

(17472, 319)

In [16]:
type(data_dict["block"])  # same check as the earlier `block` type cell

list

In [21]:
np.shape(data_dict["weights"])  # a flat, 1-D collection of weights

(101761,)

In [18]:
# data_dict['block']

In [25]:
# data_dict['edges']

In [None]:
import json
import urllib
import urllib.request  # `import urllib` alone does not load the `request` submodule

import numpy as np

# The original `from ..signal import StaticGraphTemporalSignal` is a package-relative
# import copied from the library source; it raises ImportError inside a notebook.
# Use the same absolute path as the later import cell instead.
from torch_geometric_temporal.signal.static_graph_temporal_signal import StaticGraphTemporalSignal


class WindmillOutputLargeDatasetLoader(object):
    """Hourly energy output of windmills from a European country
    for more than 2 years. Vertices represent 319 windmills and
    weighted edges describe the strength of relationships. The target
    variable allows for regression tasks.
    """

    def __init__(self):
        self._read_web_data()

    def _read_web_data(self):
        """Download the windmill JSON and store the parsed dict in ``self._dataset``."""
        url = "https://graphmining.ai/temporal_datasets/windmill_output.json"
        # Close the HTTP response deterministically instead of leaving it to the GC.
        with urllib.request.urlopen(url) as response:
            self._dataset = json.loads(response.read().decode())

    def _get_edges(self):
        """Store the ``edges`` list as a transposed numpy array in ``self._edges``.

        NOTE(review): presumably ``edges`` is a list of [source, target] pairs, so the
        transpose yields the (2, E) layout expected for an edge index — confirm.
        """
        self._edges = np.array(self._dataset["edges"]).T

    def _get_edge_weights(self):
        """Store the ``weights`` list as a numpy array in ``self._edge_weights``."""
        # The weights are 1-D (one value per edge), so no transpose is needed.
        self._edge_weights = np.array(self._dataset["weights"])

    def _get_targets_and_features(self):
        """Build standardized sliding-window features and one-step-ahead targets.

        ``block`` is stacked into a 2-D array whose axis 0 is walked by the lag
        window. Each column is z-scored; then for every start ``i`` in
        ``range(T - lags)`` the features are rows ``i : i + lags`` transposed and the
        target is row ``i + lags``.
        """
        stacked_target = np.stack(self._dataset["block"])
        # Column-wise z-score; the tiny epsilon avoids division by zero for constant columns.
        # NOTE(review): mean/std use every time step, including any later test split.
        standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / (
            np.std(stacked_target, axis=0) + 10 ** -10
        )
        n_windows = standardized_target.shape[0] - self.lags
        self.features = [
            standardized_target[i : i + self.lags, :].T for i in range(n_windows)
        ]
        # A single row is already 1-D, so it needs no transpose.
        self.targets = [standardized_target[i + self.lags, :] for i in range(n_windows)]

    def get_dataset(self, lags: int = 8) -> StaticGraphTemporalSignal:
        """Returning the Windmill Output data iterator.

        Args types:
            * **lags** *(int)* - The number of time lags.
        Return types:
            * **dataset** *(StaticGraphTemporalSignal)* - The Windmill Output dataset.
        """
        self.lags = lags
        self._get_edges()
        self._get_edge_weights()
        self._get_targets_and_features()
        dataset = StaticGraphTemporalSignal(
            self._edges, self._edge_weights, self.features, self.targets
        )
        return dataset

In [37]:
# Inspect the StaticGraphTemporalSignal class exposed by torch_geometric_temporal.signal.
torch_geometric_temporal.signal.StaticGraphTemporalSignal

In [45]:
import json
import urllib
import numpy as np
from torch_geometric_temporal.signal.static_graph_temporal_signal import StaticGraphTemporalSignal

In [38]:
class SolarDatasetLoader(object):
    """Loader for the solar temporal-graph dataset stored in ``solar.json``.

    The JSON provides ``edges`` and ``weights`` for a static weighted graph and an
    ``FX`` time series (NOTE(review): presumably one solar-output value per node per
    time step — confirm against the data source). Each snapshot's node features are
    the previous ``lags`` values of ``FX`` and the target is the next value. Unlike
    the windmill loader, no standardization is applied.
    """

    def __init__(self):
        self._read_web_data()

    def _read_web_data(self):
        """Download the solar JSON and store the parsed dict in ``self._dataset``."""
        url = "https://raw.githubusercontent.com/pinkocto/noteda/main/posts/SOLAR/data/solar.json"
        # Close the HTTP response deterministically instead of leaving it to the GC.
        with urllib.request.urlopen(url) as response:
            self._dataset = json.loads(response.read().decode())

    def _get_edges(self):
        """Store the ``edges`` list as a transposed numpy array in ``self._edges``."""
        self._edges = np.array(self._dataset["edges"]).T

    def _get_edge_weights(self):
        """Store the ``weights`` list as a numpy array in ``self._edge_weights``."""
        self._edge_weights = np.array(self._dataset["weights"]).T

    def _get_targets_and_features(self):
        """Build raw (unstandardized) sliding-window features and one-step-ahead targets.

        ``FX`` is stacked into a 2-D array whose axis 0 is walked by the lag window.
        For every start ``i`` in ``range(T - lags)`` the features are rows
        ``i : i + lags`` transposed and the target is row ``i + lags``.
        """
        stacked_target = np.stack(self._dataset["FX"])
        n_windows = stacked_target.shape[0] - self.lags
        self.features = [
            stacked_target[i : i + self.lags, :].T for i in range(n_windows)
        ]
        # A single row is already 1-D, so it needs no transpose.
        self.targets = [stacked_target[i + self.lags, :] for i in range(n_windows)]

    # The return annotation is a string so that defining this class does not require
    # StaticGraphTemporalSignal to be imported yet (this cell ran before the import cell).
    def get_dataset(self, lags: int = 4) -> "StaticGraphTemporalSignal":
        """Return the solar dataset as a static-graph temporal signal.

        Args types:
            * **lags** *(int)* - The number of time lags used as node features.
        Return types:
            * **dataset** *(StaticGraphTemporalSignal)* - The solar dataset.
        """
        self.lags = lags
        self._get_edges()
        self._get_edge_weights()
        self._get_targets_and_features()
        dataset = StaticGraphTemporalSignal(
            self._edges, self._edge_weights, self.features, self.targets
        )
        return dataset

In [41]:
# Load the solar graph signal with 4 lagged values per node, then split it into
# training and testing parts with train_ratio=0.8.
loader = SolarDatasetLoader()
dataset = loader.get_dataset(lags=4)
train_dataset, test_dataset = temporal_signal_split(
    dataset,
    train_ratio=0.8,
)

In [44]:
# Show the StaticGraphTemporalSignal instance built by SolarDatasetLoader.get_dataset.
dataset

<torch_geometric_temporal.signal.static_graph_temporal_signal.StaticGraphTemporalSignal at 0x7f82258d5f40>

In [None]:
# NOTE(review): `RecurrentGCN` is not defined anywhere in this notebook, so this cell
# fails on a fresh kernel. Define or import it before running this cell.
torch.manual_seed(0)  # make weight initialisation reproducible across runs

# The model's input width must equal the number of lags used to build the dataset;
# read it from a snapshot instead of hard-coding it so the two cannot drift apart.
n_node_features = next(iter(train_dataset)).x.shape[1]

model = RecurrentGCN(node_features=n_node_features, filters=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

for epoch in tqdm(range(50)):
    for snapshot in train_dataset:
        yt_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        # Match the prediction to the target's shape before computing the loss. If the
        # model returns (N, 1) while y is (N,), `yt_hat - y` would silently broadcast
        # to an (N, N) matrix; reshape raises instead if the sizes really differ.
        cost = torch.mean((yt_hat.reshape(snapshot.y.shape) - snapshot.y) ** 2)
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()

In [46]:
import mysolar