# Create dataset for STGCN Ver1

## ref : [***STGCN튜토리얼***](https://miruetoto.github.io/yechan3/posts/3_Researches/ITSTGCN/2022-12-29-STGCN-tutorial.html)

In [1]:
# !pip install torch-geometric
# !pip install torch-geometric-temporal

## import

`-` 필요한 패키지 임포트

In [1]:
import pandas as pd
import numpy as np
import json
import urllib

In [2]:
# 일반적인 모듈 
import numpy as np
import matplotlib.pyplot as plt 
import networkx as nx 
from tqdm import tqdm 

# 파이토치 관련 
import torch
import torch.nn.functional as F

# PyG 관련 
from torch_geometric.data import Data

# STGCN 관련 
import torch_geometric_temporal
from torch_geometric_temporal.nn.recurrent import GConvGRU
from torch_geometric_temporal.signal import temporal_signal_split 

- `tqdm`: for문의 진행상태를 확인하기 위한 패키지
- `networkx`: 그래프 시그널 시각화를 위한 모듈
- `torch` : 파이토치 (STGCN은 파이토치 기반으로 만들어짐) 모듈
- `torch.nn.functional`: relu 등의 활성화함수를 불러오기 위한 모듈
- `Data` : 그래프 자료를 만들기 위한 클래스
- `GConvGRU` : STGCN layer를 만드는 클래스
- `temporal_signal_split` : STGCN dataset을 train/test 형태로 분리하는 기능이 있는 "함수"

`-` STGCN 학습을 위한 클래스 선언

In [3]:
class RecurrentGCN(torch.nn.Module):
    """Graph-recurrent regressor: a GConvGRU layer, ReLU, then a linear head.

    Args:
        node_features: Forwarded to ``GConvGRU`` as its first argument
            (presumably the number of input features per node -- confirm
            against the library signature).
        filters: Forwarded to ``GConvGRU`` as its second argument and used
            as the input width of the final linear layer.
    """

    def __init__(self, node_features, filters):
        super().__init__()
        # Third positional argument of GConvGRU is fixed at 2.
        self.recurrent = GConvGRU(node_features, filters, 2)
        # The head maps `filters` hidden channels to a single output value.
        self.linear = torch.nn.Linear(filters, 1)

    def forward(self, x, edge_index, edge_weight):
        """Run the recurrent layer, apply ReLU, and project to one output."""
        hidden = F.relu(self.recurrent(x, edge_index, edge_weight))
        return self.linear(hidden)

## 데이터셋 만들기

In [4]:
import pandas as pd
import numpy as np

### `node id` > edges > weight > feature

In [62]:
# Download the node-signal table (one column per node after the first column;
# presumably solar observations per station, judging by the repository path).
url = (
    'https://raw.githubusercontent.com/pinkocto/noteda/main/'
    'posts/SOLAR/data/df_new.csv'
)
df = pd.read_csv(url)

In [103]:
# Download the node-by-node weight table later flattened into edge weights.
url2 = (
    'https://raw.githubusercontent.com/pinkocto/noteda/main/'
    'posts/SOLAR/data/weight.csv'
)
df2 = pd.read_csv(url2)

In [70]:
# Keep every column except the first one, so that only node columns remain.
df = df.iloc[:, 1:]

In [71]:
# Column labels of `df` are the node names; number them by column position.
node_list = df.columns.tolist()
node_ids = dict(zip(node_list, range(len(node_list))))
node_ids

{'북춘천': 0,
 '철원': 1,
 '대관령': 2,
 '춘천': 3,
 '백령도': 4,
 '북강릉': 5,
 '강릉': 6,
 '서울': 7,
 '인천': 8,
 '원주': 9,
 '울릉도': 10,
 '수원': 11,
 '서산': 12,
 '청주': 13,
 '대전': 14,
 '추풍령': 15,
 '안동': 16,
 '포항': 17,
 '대구': 18,
 '전주': 19,
 '창원': 20,
 '광주': 21,
 '부산': 22,
 '목포': 23,
 '여수': 24,
 '흑산도': 25,
 '고창': 26,
 '홍성': 27,
 '제주': 28,
 '고산': 29,
 '진주': 30,
 '고창군': 31,
 '영광군': 32,
 '김해시': 33,
 '순창군': 34,
 '북창원': 35,
 '양산시': 36,
 '보성군': 37,
 '강진군': 38,
 '의령군': 39,
 '함양군': 40,
 '광양시': 41,
 '청송군': 42,
 '경주시': 43}

In [117]:
# Directed edge list of the complete graph on the 44 nodes: every ordered
# pair [i, j] with i != j, enumerated row by row (all j for i = 0 first).
edges = [[src, dst] for src in range(44) for dst in range(44) if src != dst]
# print(edges)

In [119]:
np.array(edges).shape

(1892, 2)

In [97]:
# from itertools import permutations
# list(permutations(list(node_ids.values()), 2))

In [22]:
# Number of time steps (rows). Counting rows directly avoids looking up a
# 'date' column, which is no longer in `df` once its first column has been
# sliced off and only the node columns remain.
len(df)

18250

In [115]:
df2

Unnamed: 0,북춘천,철원,대관령,춘천,백령도,북강릉,강릉,서울,인천,원주,...,순창군,북창원,양산시,보성군,강진군,의령군,함양군,광양시,청송군,경주시
0,1.0,0.962367,0.909826,0.985657,0.871681,0.900659,0.890927,0.943081,0.936899,0.944473,...,0.879136,0.853483,0.858169,0.863974,0.859685,0.866783,0.875598,0.865172,0.883101,0.848641
1,0.962367,1.0,0.890696,0.960057,0.886264,0.884161,0.874764,0.945527,0.941293,0.930826,...,0.868607,0.841086,0.846036,0.855756,0.855439,0.85131,0.863689,0.854172,0.866795,0.834444
2,0.909826,0.890696,1.0,0.905515,0.819283,0.953088,0.947341,0.884961,0.882525,0.921151,...,0.877524,0.872609,0.877105,0.868256,0.857598,0.879016,0.890071,0.870398,0.91697,0.888561
3,0.985657,0.960057,0.905515,1.0,0.874211,0.898639,0.888139,0.943508,0.938434,0.942972,...,0.88094,0.852272,0.857021,0.864633,0.860694,0.866294,0.875104,0.864782,0.88234,0.846321
4,0.871681,0.886264,0.819283,0.874211,1.0,0.820746,0.813769,0.875487,0.898814,0.841837,...,0.832971,0.80322,0.803683,0.82839,0.8302,0.816005,0.825798,0.824297,0.805963,0.786086
5,0.900659,0.884161,0.953088,0.898639,0.820746,1.0,0.977052,0.87656,0.878769,0.908658,...,0.874547,0.870073,0.873436,0.867076,0.856008,0.878188,0.880953,0.865367,0.907306,0.88659
6,0.890927,0.874764,0.947341,0.888139,0.813769,0.977052,1.0,0.867627,0.869775,0.90069,...,0.867068,0.863608,0.865448,0.86131,0.848392,0.871227,0.87462,0.858868,0.90092,0.877792
7,0.943081,0.945527,0.884961,0.943508,0.875487,0.87656,0.867627,1.0,0.959453,0.934546,...,0.876382,0.847805,0.847221,0.857383,0.858097,0.85502,0.865729,0.858174,0.867781,0.834465
8,0.936899,0.941293,0.882525,0.938434,0.898814,0.878769,0.869775,0.959453,1.0,0.923601,...,0.884033,0.854269,0.855463,0.874037,0.870921,0.867257,0.878578,0.870185,0.870756,0.843745
9,0.944473,0.930826,0.921151,0.942972,0.841837,0.908658,0.90069,0.934546,0.923601,1.0,...,0.898942,0.869611,0.872624,0.879347,0.874973,0.881663,0.8935,0.878524,0.905335,0.867475


In [120]:
# Flatten the off-diagonal entries of the weight table into a list, visiting
# (i, j) in the same row-major order used to build `edges`, so weights[k]
# belongs to edges[k].
#
# NOTE(review): the displayed `df2` appears to start with an 'Unnamed: 0'
# index column left over from the CSV. If it is really present, positional
# access would be shifted by one column and would pull integer index values
# into the weights, so it is dropped here; `errors='ignore'` makes this a
# no-op when the column does not exist. Confirm against the CSV.
weight_matrix = df2.drop(columns='Unnamed: 0', errors='ignore').to_numpy()

# One array conversion replaces 1892 scalar `iloc` lookups; `float(...)`
# yields built-in floats, which `json.dump` can always serialize.
weights = [
    float(weight_matrix[i, j])
    for i in range(44)
    for j in range(44)
    if i != j
]

In [114]:
np.array(weights).shape

(1892,)

In [157]:
# Node features as nested lists: one inner list per row (time step) of `df`,
# one value per node column. Converting the whole frame at once removes the
# hard-coded row count (18250) and the per-row `iloc` lookups, and `tolist()`
# produces built-in Python scalars that `json.dump` can serialize.
FX = df.to_numpy().tolist()
#FX

In [158]:
np.array(FX).shape

(18250, 44)

`-` weights, edges, node_ids, FX

In [159]:
# Bundle the four dataset components under the keys used in the JSON file.
data_dict = dict(edges=edges, node_ids=node_ids, weights=weights, FX=FX)

In [161]:
data_dict.keys()

dict_keys(['edges', 'node_ids', 'weights', 'FX'])

In [170]:
# Location of the JSON file that `data_dict` is written to and read back from.
file_path = './data/solar.json'

In [171]:
# Write the dataset to disk as JSON. The parent directory is created first so
# a fresh checkout without ./data does not fail with FileNotFoundError, and
# the encoding is pinned so the file does not depend on the platform default.
from pathlib import Path

Path(file_path).parent.mkdir(parents=True, exist_ok=True)
with open(file_path, 'w', encoding='utf-8') as f:
    json.dump(data_dict, f)

In [172]:
# Read the file back to check the round trip. `json.load` no longer accepts
# an `encoding` keyword (removed in Python 3.9, where passing it raises
# TypeError); the text encoding is chosen when opening the file instead.
with open(file_path, 'r', encoding='utf-8') as f:
    test = json.load(f)

In [169]:
# json_data = json.dumps(data_dict, ensure_ascii=False)
# json_data

- ${\bf f}=\begin{bmatrix} {\bf f}_1\\ {\bf f}_2\\ \vdots \\ {\bf f}_{18250} \end{bmatrix}=\begin{bmatrix} f(t=1,v=\text{북춘천}) & \dots & f(t=1,v=\text{경주시}) \\ f(t=2,v=\text{북춘천}) & \dots & f(t=2,v=\text{경주시}) \\ \vdots & \ddots & \vdots \\ f(t=18250,v=\text{북춘천}) & \dots & f(t=18250,v=\text{경주시}) \end{bmatrix}$