# Setup

In [None]:
!pip install torch===2.3.1
import torch

def format_pytorch_version(version):
    """Drop the local build suffix from a torch version string.

    Example: ``'2.3.1+cu121'`` -> ``'2.3.1'``; a string without ``'+'`` is
    returned unchanged.
    """
    return version.split('+')[0]


TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)


def format_cuda_version(version):
    """Turn a CUDA version string into the tag used in the wheel index URL.

    Example: ``'12.1'`` -> ``'cu121'``.

    ``torch.version.cuda`` is ``None`` when the installed torch build has no
    CUDA support; in that case the ``'cpu'`` tag is returned instead of
    failing on ``None.replace``.
    """
    if version is None:
        return 'cpu'
    return 'cu' + version.replace('.', '')


CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.1+{CUDA}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.3.1+{CUDA}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.3.1+{CUDA}.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.3.1+{CUDA}.html
!pip install torch-geometric

In [None]:
# Colab-only cell: mounts the user's Google Drive at /content/drive.
# NOTE(review): presumably the data/output/log paths configured further down
# live under this mount point — confirm before running outside Colab.
from google.colab import drive
# force_remount=True is passed so the mount is redone even if one already exists.
drive.mount('/content/drive', force_remount=True)

# Graph Dataset Definition and Feature Engineering

In [None]:
import torch
import math
import csv
import os
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import List, Tuple
from torch.utils.data import Dataset
from tqdm import tqdm
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import laplacian

class MyDataset(Dataset):
    """Sliding-window stock graphs that are generated once and cached on disk.

    Every item is a dict loaded from ``graph_{i}.pt`` with three entries:

    * ``'X'`` - z-scored node features, one row per company and
      ``window * 7`` columns (7 features per trading day).
    * ``'A'`` - sparse COO adjacency built from signal-energy similarity.
    * ``'Y'`` - one 0/1 label per company: 1 when the close log-return on the
      day right after the window is positive.

    Graph ``i`` uses trading days ``i .. i + window - 1`` as input and day
    ``i + window`` as the labelling day. The day after ``end`` (``next_day``)
    is used as the labelling day of the last window when it exists.

    At most ``max_graphs`` graphs are generated, and ``len(dataset)`` reports
    exactly the number of graphs that generation produces.
    """

    # Default cap on generated graphs; also used by instances that were not
    # built through __init__.
    max_graphs = 120

    def __init__(self, root: str, desti: str, market: str, comlist: List[str], start: str, end: str, window: int, dataset_type: str, k: float, threshold: float, max_graphs: int = 120):
        """Collect the common trading days and make sure all graphs exist on disk.

        Args:
            root: Directory that contains one sub-directory per market.
            desti: Directory under which the graph cache folder is created.
            market: Market sub-directory name (also part of the cache name).
            comlist: Company names; ``{root}/{market}/{name}.csv`` is read for each.
            start: First day of the range, ``YYYY-MM-DD``.
            end: Last day of the range, ``YYYY-MM-DD``.
            window: Number of trading days per graph.
            dataset_type: Free-form tag that is part of the cache folder name.
            k: Scale of the energy-difference kernel used for the adjacency.
            threshold: Connection probabilities below this value are dropped.
            max_graphs: Upper bound on the number of graphs generated/served.
        """
        super().__init__()

        self.comlist = comlist
        self.market = market
        self.root = root
        self.desti = desti
        self.start = start
        self.end = end
        self.window = window
        self.dataset_type = dataset_type
        self.k = k
        self.threshold = threshold
        self.max_graphs = max_graphs
        self.dates, self.next_day = self.find_dates(start, end, root, comlist, market)

        graph_dir = self._graph_dir()
        graph_files_exist = all(
            os.path.exists(os.path.join(graph_dir, f'graph_{i}.pt'))
            for i in range(len(self))
        )
        if not graph_files_exist:
            self._create_graphs(self.dates, desti, comlist, market, root, window)

    def _graph_dir(self, desti=None, market=None, window=None) -> str:
        """Return the cache folder; arguments default to the instance values."""
        desti = self.desti if desti is None else desti
        market = self.market if market is None else market
        window = self.window if window is None else window
        return os.path.join(desti, f'{market}_{self.dataset_type}_{self.start}_{self.end}_{window}_{self.k}_{self.threshold}')

    def _num_graphs(self, num_dates: int, window: int) -> int:
        """Number of complete (input window + labelling day) graphs available.

        ``num_dates`` counts the in-range days only; the labelling day after
        ``end`` adds one more usable day when it exists.
        """
        usable_days = num_dates + (1 if self.next_day is not None else 0)
        return max(0, min(usable_days - window, self.max_graphs))

    def __len__(self):
        return self._num_graphs(len(self.dates), self.window)

    def __getitem__(self, idx: int):
        """Load graph ``idx`` from the cache folder.

        Raises:
            FileNotFoundError: If ``graph_{idx}.pt`` does not exist.
        """
        data_path = os.path.join(self._graph_dir(), f'graph_{idx}.pt')
        if os.path.exists(data_path):
            return torch.load(data_path)
        else:
            raise FileNotFoundError(f"No graph data found for index {idx}")

    def check_years(self, date_str: str, start_str: str, end_str: str) -> bool:
        """Return True when ``start_str <= date_str <= end_str`` (all ``YYYY-MM-DD``)."""
        date_format = "%Y-%m-%d"
        # strptime() turns a str into a datetime object, which allows comparisons among dates
        date = datetime.strptime(date_str, date_format)
        start = datetime.strptime(start_str, date_format)
        end = datetime.strptime(end_str, date_format)
        return start <= date <= end

    def find_dates(self, start: str, end: str, path: str, comlist: List[str], market: str, data_end: str = '2024-07-11') -> Tuple[List[str], str]:
        """Find the trading days shared by every company.

        Args:
            start: First day of the range (inclusive).
            end: Last day of the range (inclusive).
            path: Directory that contains the market sub-directory.
            comlist: Company names whose CSV files are scanned.
            market: Market sub-directory name.
            data_end: Last day considered when looking for the day after ``end``.

        Returns:
            ``(dates, next_day)``: the sorted common days within
            ``[start, end]`` and the earliest common day after ``end`` (up to
            ``data_end``), or ``None`` when there is no such day.

        Raises:
            ValueError: If ``comlist`` is empty or a CSV row holds a malformed date.
        """
        if not comlist:
            raise ValueError('comlist is empty: there are no price files to take common dates from')

        date_format = "%Y-%m-%d"
        # Parse the bounds once instead of once per CSV row.
        start_dt = datetime.strptime(start, date_format)
        end_dt = datetime.strptime(end, date_format)
        limit_dt = datetime.strptime(data_end, date_format)

        date_sets = []
        after_end_date_sets = []

        for h in comlist:
            dates = set()
            after_end_dates = set()
            d_path = os.path.join(path, market, f'{h}.csv')

            with open(d_path, newline='') as f:
                reader = csv.reader(f)
                next(reader, None)  # Skip the header row
                for line in reader:
                    if not line:
                        continue  # blank line in the CSV
                    date_str = line[0][:10]
                    day = datetime.strptime(date_str, date_format)

                    if start_dt <= day <= end_dt:
                        dates.add(date_str)
                    elif end_dt <= day <= limit_dt:
                        after_end_dates.add(date_str)

            date_sets.append(dates)
            after_end_date_sets.append(after_end_dates)

        all_dates = set.intersection(*date_sets)
        all_after_end_dates = set.intersection(*after_end_date_sets)
        # ISO dates sort chronologically as plain strings.
        next_common_day = min(all_after_end_dates) if all_after_end_dates else None

        return sorted(all_dates), next_common_day

    def signal_energy(self, x_tuple: Tuple[float, ...]) -> float:
        """Return the sum of squares of the given values."""
        x = np.array(x_tuple)
        return np.sum(np.square(x))

    def adjacency(self, X: torch.Tensor) -> torch.Tensor:
        """Build a sparse adjacency from the similarity of node signal energies.

        Row ``i`` holds ``exp(-|E_i - E_j| / (k * window))`` normalised to sum
        to one; entries below ``self.threshold`` are removed.

        Returns:
            A sparse COO tensor of shape ``(num_nodes, num_nodes)``.
        """
        energies = torch.tensor([self.signal_energy(tuple(x)) for x in X])
        num_nodes = X.shape[0]
        A = torch.zeros((num_nodes, num_nodes))
        scale = self.k * self.window

        for i in range(num_nodes):
            energy_diffs = torch.exp(-torch.abs(energies[i] - energies) / scale)

            # Probability of connection i -> j
            A[i] = energy_diffs / energy_diffs.sum()

        A[A < self.threshold] = 0

        row_indices, col_indices = A.nonzero(as_tuple=True)
        values = A[row_indices, col_indices]

        return torch.sparse_coo_tensor(torch.stack([row_indices, col_indices]), values, A.size())

    def random(self, X: torch.Tensor) -> torch.Tensor:
        """Build a random sparse adjacency: uniform weights, values below 0.5 dropped."""
        num_nodes = X.shape[0]
        A = torch.rand((num_nodes, num_nodes))

        A[A < 0.5] = 0

        row_indices, col_indices = A.nonzero(as_tuple=True)
        values = A[row_indices, col_indices]

        return torch.sparse_coo_tensor(torch.stack([row_indices, col_indices]), values, A.size())

    def _read_prices(self, path: str, market: str, company: str) -> pd.DataFrame:
        """Read one company's CSV, indexed by its parsed ``Date`` column."""
        d_path = os.path.join(path, market, f'{company}.csv')
        return pd.read_csv(d_path, parse_dates=['Date'], index_col='Date')

    def node_feature_matrix(self, dates: List[str], comlist: List[str], market: str, path: str, frames=None) -> torch.Tensor:
        """Compute the per-day features of every company over ``dates``.

        Args:
            dates: Days to cover; ``None`` entries are ignored.
            comlist: Company names (one node each).
            market: Market sub-directory name.
            path: Directory that contains the market sub-directory.
            frames: Optional ``{company: DataFrame}`` of already loaded price
                tables; companies missing from it are read from disk.

        Returns:
            Tensor of shape ``[7, len(comlist), number of non-None dates]``.
            Features 0-3 are the day-over-day log returns of close, high, low
            and open (0 on the first day); features 4-6 are open-to-close
            return, overnight gap and high-low range.
        """
        dates_dt = pd.to_datetime([d for d in dates if d is not None])
        X = torch.zeros((7, len(comlist), len(dates_dt)), dtype=torch.float32)

        eps = 1e-8  # keeps log() finite for non-positive prices
        for idx, h in enumerate(comlist):
            if frames is not None and h in frames:
                raw = frames[h]
            else:
                raw = self._read_prices(path, market, h)
            df = raw.reindex(dates_dt)

            # log price
            logC = np.log(df['Adj Close'].astype(float).clip(lower=eps))
            logH = np.log(df['High'].astype(float).clip(lower=eps))
            logL = np.log(df['Low'].astype(float).clip(lower=eps))
            logO = np.log(df['Open'].astype(float).clip(lower=eps))

            r_close = logC.diff().fillna(0.0)
            r_high = logH.diff().fillna(0.0)
            r_low = logL.diff().fillna(0.0)
            r_open = logO.diff().fillna(0.0)

            # Open-to-Close Return, Overnight Gap, High-Low Range.
            # All three get the same NaN fill as the return features so one
            # missing price cannot turn the whole z-scored matrix into NaN.
            oc = (logC - logO).fillna(0.0)
            gap = (logO - logC.shift(1)).fillna(0.0)
            hl_range = (logH - logL).clip(lower=0.0).fillna(0.0)

            X[:, idx, :] = torch.from_numpy(
                np.vstack([
                    r_close.values, r_high.values, r_low.values, r_open.values,
                    oc.values, gap.values, hl_range.values
                ])
            ).to(torch.float32)

        return X

    def z_score(self, x: torch.Tensor) -> torch.Tensor:
        """Standardise ``x`` with its global mean and standard deviation.

        Returns zeros when the deviation is undefined or zero (fewer than two
        elements, or a constant tensor) instead of NaN/inf.
        """
        if x.numel() < 2:
            return torch.zeros_like(x)
        centered = x - x.mean()
        std = x.std()
        if std == 0:
            return torch.zeros_like(centered)
        return centered / std

    def _create_graphs(self, dates: List[str], desti: str, comlist: List[str], market: str, root: str, window: int):
        """Generate and save every graph that is not on disk yet.

        ``dates`` is left untouched: it is normally ``self.dates``, and
        appending the labelling day to it in place would change ``len(self)``
        after construction.
        """
        timeline = list(dates)
        if self.next_day is not None:
            timeline.append(self.next_day)

        total = self._num_graphs(len(dates), window)
        directory_path = self._graph_dir(desti, market, window)
        os.makedirs(directory_path, exist_ok=True)

        frames = None  # price tables, loaded lazily and only once
        for i in tqdm(range(total)):
            filename = os.path.join(directory_path, f'graph_{i}.pt')

            if os.path.exists(filename):
                print(f"Graph {i}/{total} already exists, skipping...")
                continue

            print(f'Generating graph {i}/{total}...')

            if frames is None:
                frames = {h: self._read_prices(root, market, h) for h in comlist}

            # window input days followed by one labelling day
            box = timeline[i:i + window + 1]
            X = self.node_feature_matrix(box, comlist, market, root, frames)
            # Label: is the close log-return on the labelling day positive?
            C = (X[0, :, -1] > 0).to(torch.long)

            # Drop the labelling day, then flatten to [nodes, window * 7].
            X = X[:, :, :-1]
            X = X.permute(1, 2, 0)
            X = X.reshape(X.shape[0], -1)
            X = self.z_score(X)
            A = self.adjacency(X)

            torch.save({'X': X, 'A': A, 'Y': C}, filename)

# Create Graph Dataset

In [None]:
root = 'PATH TO YOUR DATA DIRECTORY'
desti = 'PATH TO YOUR OUTPUT DIRECTORY'

markets = ['LSE'] # Market options: 'FTSE', 'LSE', 'NASDAQ', 'NYSE', 'SP'
start_train = '2021-07-12'
end_train = '2023-07-11'
start_test = '2023-07-12'
end_test = '2024-07-11'
dataset_type = 'train'
dataset_trains = []
dataset_tests = []

# Hyperparameters
k_value = 0.08      # scale of the energy-difference kernel (MyDataset `k`)
window_size = 14    # trading days per graph
s = 0.55            # adjacency threshold (MyDataset `threshold`)

for market in markets:
    folder_path = os.path.join(root, market)
    # One company per '<name>.csv' file. os.listdir returns entries in
    # arbitrary order, so sort them: node order must be identical across runs
    # because graphs are cached on disk and reused.
    comlist = sorted(
        os.path.splitext(filename)[0]
        for filename in os.listdir(folder_path)
        if filename.endswith('.csv')
    )

    dataset_train = MyDataset(root, desti, market, comlist, start_train, end_train, window_size, dataset_type, k_value, s)
    dataset_trains.append(dataset_train)

    dataset_test = MyDataset(root, desti, market, comlist, start_test, end_test, window_size, 'test', k_value, s)
    dataset_tests.append(dataset_test)

# Model Definition

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Data
import torch_geometric.transforms as T
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

class GATABlock(nn.Module):
    """One graph-attention stage followed by multi-head self-attention.

    ``forward`` returns the graph-convolved node features, an attention
    summary projected to ``time_stamp`` columns, and the attention weights.
    """

    def __init__(self, in_dim, out_dim, class_dim, num_heads, time_stamp, window, is_first_block=False, is_concat=True):
        super().__init__()
        # Sub-modules are created in a fixed order so that parameter
        # registration and random initialisation stay reproducible.
        self.conv = GATv2Conv(in_dim, out_dim // num_heads, heads=num_heads, concat=True)
        self.time_stamp = time_stamp
        self.transfer_layer = nn.Linear(out_dim, time_stamp)
        self.is_first_block = is_first_block
        self.is_concat = is_concat
        self.window = window
        # The attention runs over `window` tokens of width out_dim // window.
        self.attention = nn.MultiheadAttention(out_dim // window, num_heads, batch_first=True)

        if is_first_block:
            # Only the first block applies activations.
            self.activation = nn.PReLU()
            self.activation_att = nn.PReLU()
        else:
            # Later blocks add a linear skip connection from their input.
            self.skip_layer = nn.Linear(in_dim, out_dim)

        if is_concat:
            # Fuses [previous attention summary, graph features] back to out_dim.
            self.temp_linear = nn.Linear(out_dim + time_stamp, out_dim)
        else:
            self.temp_linear = nn.Identity()

    def forward(self, x, edge_index, h_prime):
        """Run the block.

        Args:
            x: Node features passed to the graph convolution.
            edge_index: Graph connectivity passed to the graph convolution.
            h_prime: Attention summary of the previous block, or ``None``.

        Returns:
            ``(h, h_att, w)``: graph features, attention summary projected by
            ``transfer_layer``, and the attention weights.
        """
        h = self.conv(x, edge_index)

        if self.is_first_block:
            h = self.activation(h)
        else:
            h += self.skip_layer(x)

        fuse_previous = self.is_concat and h_prime is not None
        if fuse_previous:
            fused = self.temp_linear(torch.cat((h_prime, h), dim=1))
        else:
            fused = h

        num_nodes = x.size(0)
        tokens = fused.view(num_nodes, self.window, -1)
        attended, w = self.attention(tokens, tokens, tokens)
        if self.is_first_block:
            attended = self.activation_att(attended)

        flat = attended.reshape(num_nodes, -1)
        h_att = self.transfer_layer(flat)

        return h, h_att, w

class GATA(nn.Module):
    """Stack of ``GATABlock`` stages whose attention summaries are summed and classified."""

    def __init__(self, num_blocks, dim_list, class_dim, num_heads, window, time_stamp, is_concat=True):
        super().__init__()
        self.blocks = nn.ModuleList()
        self.num_blocks = num_blocks
        self.time_stamp = time_stamp
        self.window = window
        self.class_dim = class_dim
        self.final_layer = nn.Linear(time_stamp, class_dim)
        self.is_concat = is_concat

        # Block i maps dim_list[i] -> dim_list[i + 1]; only block 0 is "first".
        for depth, (in_dim, out_dim) in enumerate(zip(dim_list[:num_blocks], dim_list[1:num_blocks + 1])):
            self.blocks.append(
                GATABlock(in_dim, out_dim, class_dim, num_heads, time_stamp, window, depth == 0, is_concat)
            )

    def forward(self, data):
        """Classify every node of one graph.

        Args:
            data: Mapping with node features under ``'X'`` and connectivity
                under ``'A'``.

        Returns:
            ``(logits, w_list)``: output of ``final_layer`` applied to the sum
            of all block attention summaries, and one attention-weight entry
            per block.
        """
        node_features, connectivity = data['X'], data['A']
        summaries = []
        w_list = []

        hidden = node_features
        previous_summary = None
        for stage in self.blocks:
            hidden, previous_summary, weights = stage(hidden, connectivity, previous_summary)
            summaries.append(previous_summary)
            w_list.append(weights)

        pooled = torch.stack(summaries).sum(dim=0)
        logits = self.final_layer(pooled)

        return logits, w_list

# Training

In [None]:
from sklearn.metrics import accuracy_score, matthews_corrcoef, f1_score
dim_head = [1, 18, 12, 6]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
attention_maps = {}

# Per-epoch caps on how many graphs of each split are used, and the epoch budget.
max_train_graphs = 100
max_test_graphs = 20
num_epochs = 500


def _sample_on_device(dataset, idx):
    """Load graph `idx` of `dataset`, move its tensors to `device`, labels as int64."""
    sample = {key: value.to(device) for key, value in dataset[idx].items()}
    sample['Y'] = sample['Y'].to(torch.long)
    return sample


for j in range(1):
    dataset_train = dataset_trains[j]
    dataset_test = dataset_tests[j]
    time_stamp = dataset_train[0]['X'].shape[1]
    window = time_stamp // 7  # 7 features per trading day
    dim_list = [time_stamp*rate for rate in dim_head]
    # GATA(num_blocks, dim_list, class_dim, num_heads, window, time_stamp)
    model = GATA(3, dim_list, 2, 14, window, time_stamp, True).to(device)

    # Average over the graphs that are actually used instead of assuming
    # exactly 100 / 20 are available.
    n_train = min(len(dataset_train), max_train_graphs)
    n_test = min(len(dataset_test), max_test_graphs)
    if n_train == 0 or n_test == 0:
        raise ValueError('Both the training and the test split need at least one graph')

    for run in range(1):
        log_dir = 'PATH TO YOUR LOG DIRECTORY'
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, 'training_attention_map.txt')
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=6e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)
        min_loss = float('inf')
        max_acc = 0
        max_test = 0
        all_true_labels = []
        all_pred_labels = []

        # The context manager closes the log even if training raises.
        with open(log_path, 'w') as log_file:
            for epoch in range(num_epochs):

                # One full-batch step: mean loss over the training graphs.
                model.train()
                objective = 0
                optimizer.zero_grad()

                for idx in range(n_train):
                    data = _sample_on_device(dataset_train, idx)
                    out, _ = model(data)
                    objective = objective + F.cross_entropy(out, data['Y'])

                objective = objective / n_train
                objective.backward()
                optimizer.step()
                scheduler.step()
                # Plain float: keeps the running minimum free of autograd state.
                loss_value = objective.item()

                model.eval()
                acc_t = 0.0
                test_a = 0.0
                true_labels = []
                pred_labels = []
                with torch.no_grad():
                    # Per-graph accuracy on the training graphs.
                    for idx in range(n_train):
                        data = _sample_on_device(dataset_train, idx)
                        pred, _ = model(data)
                        pred = pred.argmax(dim=1)
                        correct = (pred == data['Y']).sum().item()
                        acc_t += int(correct) / int(data['Y'].shape[0])

                    # Per-graph accuracy on the test graphs; w_list keeps the
                    # attention weights of the last test graph.
                    for idx in range(n_test):
                        data_test = _sample_on_device(dataset_test, idx)
                        pred_test, w_list = model(data_test)
                        pred_test = pred_test.argmax(dim=1)
                        correct_pred = (pred_test == data_test['Y']).sum().item()
                        test_a += correct_pred / data_test['Y'].shape[0]
                        true_labels.extend(data_test['Y'].cpu().numpy())
                        pred_labels.extend(pred_test.cpu().numpy())

                avg_test = test_a / n_test
                avg_acc = acc_t / n_train
                min_loss = min(min_loss, loss_value)
                max_acc = max(max_acc, avg_acc)
                max_test = max(max_test, avg_test)
                # NOTE(review): predictions of every epoch are pooled, so the
                # "Final Test Results" below average over the whole training
                # run rather than describing the final model - confirm that
                # this is intended.
                all_true_labels.extend(true_labels)
                all_pred_labels.extend(pred_labels)

                if epoch in (0, 399):
                    attention_maps[epoch] = w_list

                log_file.write(f'Epoch {epoch+1}: Accuracy={avg_acc:.4f}, Loss={loss_value:.4f}, Test={avg_test:.4f}\n')
                print(f'Epoch {epoch+1}: Accuracy={avg_acc:.4f}, Loss={loss_value:.4f}, Test={avg_test:.4f}, Max={max_test:.4f}')

            final_acc_test = accuracy_score(all_true_labels, all_pred_labels) * 100
            final_mcc_test = matthews_corrcoef(all_true_labels, all_pred_labels)
            final_f1_test = f1_score(all_true_labels, all_pred_labels, average='macro')

            log_file.write(f'Final Test Results: ACC={final_acc_test:.2f}%, MCC={final_mcc_test:.4f}, F1-Score={final_f1_test:.4f}\n')
            print(f'Final Test Results: ACC={final_acc_test:.2f}%, MCC={final_mcc_test:.4f}, F1-Score={final_f1_test:.4f}')

            print(f'Max Accuracy={max_acc:.4f}, Min Loss={min_loss:.4f}, Max Test={max_test:.4f}')
            log_file.write(f'Max Accuracy={max_acc:.4f}, Min Loss={min_loss:.4f}, Max Test={max_test:.4f}\n')