In [6]:
from awpy import Demo

import torch
from torch_geometric_temporal.signal import DynamicHeteroGraphTemporalSignal

import pandas as pd
import polars as pl
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

from termcolor import colored
import time
import json
import pickle
import random
import sys
import os

# pandas display/behaviour options: show up to 100 columns and opt in to the
# 'future.no_silent_downcasting' behaviour.
pd.set_option('display.max_columns', 100)
pd.set_option('future.no_silent_downcasting', True)

# Limit polars table output to 10 rows.
pl.Config.set_tbl_rows(10)

# Add the repository's package directory to the import path.
# NOTE(review): presumably this is where the CS2 package imported right below lives — confirm.
sys.path.append(os.path.abspath('../../package'))

from CS2.graph import TabularGraphSnapshot, HeteroGraphSnapshot, TemporalHeteroGraphSnapshot
from CS2.token import Tokenizer
from CS2.preprocess import Dictionary, NormalizePosition, NormalizeTabularGraphSnapshot, ImputeTabularGraphSnapshot
from CS2.visualize import HeteroGraphVisualizer

# Directory the processing loop reads per-match graph files from.
DATA_PATH = '../../data/matches-processed/cs2/temporal-hetero-graph/'
# Directory the processing loop writes the per-match dynamic graphs to.
DATA_SAVE_PATH = '../../data/matches-processed/cs2/temporal-hetero-graph-round/'
# Directory holding process.txt, the log of match files that were already processed.
PROCESS_SAVE_PATH = './parses/temp-hetero-parse-2024.11.06_round/'

# Path constants
# NOTE(review): PATH_TEMP_GRAPH_DATA_R is nested inside temporal-hetero-graph/, whereas
# DATA_SAVE_PATH points to a sibling directory of it — confirm which one holds train.pt.
# PATH_TEMP_GRAPH_DATA_20 and PATH_TEMP_GRAPH_DATA_60 are not referenced in the visible cells.
PATH_TEMP_GRAPH_DATA_20 = '../../data/matches-processed/cs2/temporal-hetero-graph/temporal-hetero-graph_20/'
PATH_TEMP_GRAPH_DATA_60 = '../../data/matches-processed/cs2/temporal-hetero-graph/temporal-hetero-graph-start_end_60/'
PATH_TEMP_GRAPH_DATA_R = '../../data/matches-processed/cs2/temporal-hetero-graph/temporal-hetero-graph-round/'


# Create Temporal dataset

In [9]:
# Convert every per-match graph file in DATA_PATH into a list of dynamic graphs
# (one call to TemporalHeteroGraphSnapshot.process_match per match) and store the
# result under the same file name in DATA_SAVE_PATH. Finished files are logged in
# process.txt so that an interrupted run can be resumed without redoing work.

# Files in DATA_PATH that are aggregated datasets rather than single matches.
# 'train_r.pt' is written into DATA_PATH by a later cell of this notebook, so it has
# to be excluded as well or a re-run would try to parse it as a match.
NON_MATCH_FILES = {'train.pt', 'val.pt', 'train_r.pt'}
process_log_path = os.path.join(PROCESS_SAVE_PATH, 'process.txt')

# Both output locations must exist before the first torch.save / log append,
# otherwise the loop fails after the (expensive) processing of the first match.
os.makedirs(DATA_SAVE_PATH, exist_ok=True)
os.makedirs(PROCESS_SAVE_PATH, exist_ok=True)

# os.listdir returns entries in arbitrary order; sort for a deterministic run order.
hetero_graph_matches = sorted(
    name for name in os.listdir(DATA_PATH)
    if os.path.isfile(os.path.join(DATA_PATH, name)) and name not in NON_MATCH_FILES
)

dataset_lengths = []
overall_length = 0

# Load the names of matches handled by a previous run (one file name per line).
processed_matches = []
if os.path.exists(process_log_path):
    with open(process_log_path, 'r') as f:
        processed_matches = [line.strip() for line in f]

# Set copy for O(1) membership checks inside the loop.
already_processed = set(processed_matches)

for file in hetero_graph_matches:

    if file in already_processed:
        print(colored(f'{file} already processed. Skipping...', 'yellow'))
        continue

    print(colored(f'Processing {file}...', 'light_blue'))

    # NOTE: weights_only=False unpickles arbitrary Python objects; only load trusted files.
    match = torch.load(os.path.join(DATA_PATH, file), weights_only=False)
    thgs = TemporalHeteroGraphSnapshot()
    dyn_graphs = thgs.process_match(match, interval=60, round_process_strategy='round')

    n_graphs = len(dyn_graphs)
    dataset_lengths.append(n_graphs)
    overall_length += n_graphs

    print('DTDG Dataset Length:', n_graphs)
    print(colored(f'{file} processed.', 'green'))

    torch.save(dyn_graphs, os.path.join(DATA_SAVE_PATH, file))

    # Log only after the save succeeded, so a crash never marks an unsaved match as done.
    with open(process_log_path, 'a') as f:
        f.write(f'{file}\n')

[94mProcessing 100000.pt...[0m
DTDG Dataset Length: 18
[32m100000.pt processed.[0m
[94mProcessing 100001.pt...[0m
DTDG Dataset Length: 21
[32m100001.pt processed.[0m
[94mProcessing 100002.pt...[0m
DTDG Dataset Length: 24
[32m100002.pt processed.[0m
[94mProcessing 100003.pt...[0m
DTDG Dataset Length: 20
[32m100003.pt processed.[0m
[94mProcessing 100004.pt...[0m
[1m[31mError:[0mError: There are missing ticks in the graph sequence. The error occured while parsing match 100004.0 at round                         0.1666666716337204. Skipping the round.
DTDG Dataset Length: 20
[32m100004.pt processed.[0m
[94mProcessing 100005.pt...[0m
DTDG Dataset Length: 21
[32m100005.pt processed.[0m
[94mProcessing 100006.pt...[0m
DTDG Dataset Length: 20
[32m100006.pt processed.[0m
[94mProcessing 100007.pt...[0m
DTDG Dataset Length: 18
[32m100007.pt processed.[0m
[94mProcessing 100008.pt...[0m
DTDG Dataset Length: 24
[32m100008.pt processed.[0m
[94mProcessing 100009.p

In [12]:
# Total number of dynamic graphs counted by the processing loop in this kernel session;
# matches skipped as already processed do not contribute to this number.
print(overall_length)

98098


# Create temporal dataset with variable-length dynamic graphs

In [11]:
# Load the training samples stored in train.pt under PATH_TEMP_GRAPH_DATA_R.
# NOTE: weights_only=False unpickles arbitrary Python objects; only load trusted files.
train_data = torch.load(os.path.join(PATH_TEMP_GRAPH_DATA_R, 'train.pt'), weights_only=False)

# Shuffle in place with a dedicated, seeded generator so that any subset taken from
# the shuffled list is reproducible across kernel restarts.
SHUFFLE_SEED = 42
random.Random(SHUFFLE_SEED).shuffle(train_data)
print('Train data length:', len(train_data))

Train data length: 1490


In [12]:
# Number of training samples kept in the reduced training set.
SHORT_TRAIN_SIZE = 750

# Take the leading SHORT_TRAIN_SIZE entries of train_data; this is a random subset
# only if train_data has been shuffled beforehand.
short_train_data = train_data[:SHORT_TRAIN_SIZE]

In [13]:
# Persist the reduced training set as train_r.pt inside DATA_PATH.
# DATA_PATH already ends with a separator, so join the bare file name instead of
# concatenating '/train_r.pt' (which produced a double slash in the path).
# NOTE(review): DATA_PATH is also the directory the match-processing loop scans for
# input files; make sure train_r.pt is excluded there or consider a separate folder.
torch.save(short_train_data, os.path.join(DATA_PATH, 'train_r.pt'))

# Test

In [2]:
# Load a single match file (100000.pt) from DATA_PATH for a manual test run.
test_match_path = os.path.join(DATA_PATH, '100000.pt')
test_data = torch.load(test_match_path, weights_only=False)

In [4]:
# Run the snapshot processor on the single test match with a different configuration
# than the batch loop (interval=20 and shifted_intervals=True instead of
# interval=60 and round_process_strategy='round').
# NOTE(review): the exact meaning of `interval` and `shifted_intervals` is defined in
# CS2.graph and not visible here — confirm before relying on the output.
# This overwrites the `thgs` and `dyn_graphs` variables left over from the batch loop.
thgs = TemporalHeteroGraphSnapshot()
dyn_graphs = thgs.process_match(test_data, interval=20, shifted_intervals=True)