In [9]:
from awpy import Demo

import torch
from torch_geometric_temporal.signal import DynamicHeteroGraphTemporalSignal

import pandas as pd
import polars as pl
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

from termcolor import colored
import time
import json
import pickle
import sys
import os

# Display options for interactive inspection of DataFrames.
pd.set_option('display.max_columns', 100)
# Opt in to pandas' future behaviour of not silently downcasting result dtypes
# (option available in pandas >= 2.2).
pd.set_option('future.no_silent_downcasting', True)

pl.Config.set_tbl_rows(10)

# Make the local CS2 package importable. Guarded so that re-running this cell
# does not keep appending duplicate entries to sys.path.
PACKAGE_PATH = os.path.abspath('../../package')
if PACKAGE_PATH not in sys.path:
    sys.path.append(PACKAGE_PATH)

from CS2.graph import TabularGraphSnapshot, HeteroGraphSnapshot, TemporalHeteroGraphSnapshot
from CS2.token import Tokenizer
from CS2.preprocess import Dictionary, NormalizePosition, NormalizeTabularGraphSnapshot, ImputeTabularGraphSnapshot
from CS2.visualize import HeteroGraphVisualizer

# Folder holding the per-match temporal hetero-graph files ('<match_id>.pt');
# the train/val split files are written here as well.
DATA_PATH = '../../data/matches-processed/cs2/temporal-hetero-graph/'
# Sub-folder for the multi-match chunk files produced by the concat cells below.
DATA_SAVE_PATH = os.path.join(DATA_PATH, 'concat/')
# Not referenced in the cells shown here; kept for other consumers of this notebook.
PROCESS_SAVE_PATH = './parses/temp-hetero-parse-2024.10.13/'

### Concatenate single-match files into chunks of up to ten matches

In [18]:
# Concatenate the lists loaded from the per-match files 100070-100079 into one chunk file.
# The input file list and the output file name are derived from the same ID range,
# so they cannot drift apart when this cell is copied for another range.
first_match_id, last_match_id = 100070, 100079
files = [f'{match_id}.pt' for match_id in range(first_match_id, last_match_id + 1)]

lens = 0
concat_list = []

for file_name in files:
    # weights_only=False allows arbitrary unpickling (the files presumably hold
    # non-tensor graph objects) — only load trusted, locally produced files.
    match_graphs = torch.load(os.path.join(DATA_PATH, file_name), weights_only=False)
    lens += len(match_graphs)
    concat_list.extend(match_graphs)

# Sum of per-file lengths vs. concatenated length: printed as a sanity check.
print(lens)
print(len(concat_list))

torch.save(concat_list, os.path.join(DATA_SAVE_PATH, f'{first_match_id}-{last_match_id}.pt'))

8266
8266


In [19]:
# Concatenate the lists loaded from the per-match files 100080-100089 into one chunk file.
# The input file list and the output file name are derived from the same ID range,
# so they cannot drift apart when this cell is copied for another range.
first_match_id, last_match_id = 100080, 100089
files = [f'{match_id}.pt' for match_id in range(first_match_id, last_match_id + 1)]

lens = 0
concat_list = []

for file_name in files:
    # weights_only=False allows arbitrary unpickling (the files presumably hold
    # non-tensor graph objects) — only load trusted, locally produced files.
    match_graphs = torch.load(os.path.join(DATA_PATH, file_name), weights_only=False)
    lens += len(match_graphs)
    concat_list.extend(match_graphs)

# Sum of per-file lengths vs. concatenated length: printed as a sanity check.
print(lens)
print(len(concat_list))

torch.save(concat_list, os.path.join(DATA_SAVE_PATH, f'{first_match_id}-{last_match_id}.pt'))

8956
8956


In [20]:
# Concatenate the lists loaded from the per-match files 100090-100099 into one chunk file.
# The input file list and the output file name are derived from the same ID range,
# so they cannot drift apart when this cell is copied for another range.
first_match_id, last_match_id = 100090, 100099
files = [f'{match_id}.pt' for match_id in range(first_match_id, last_match_id + 1)]

lens = 0
concat_list = []

for file_name in files:
    # weights_only=False allows arbitrary unpickling (the files presumably hold
    # non-tensor graph objects) — only load trusted, locally produced files.
    match_graphs = torch.load(os.path.join(DATA_PATH, file_name), weights_only=False)
    lens += len(match_graphs)
    concat_list.extend(match_graphs)

# Sum of per-file lengths vs. concatenated length: printed as a sanity check.
print(lens)
print(len(concat_list))

torch.save(concat_list, os.path.join(DATA_SAVE_PATH, f'{first_match_id}-{last_match_id}.pt'))

9468
9468


In [21]:
# Concatenate the lists loaded from the per-match files 100100-100109 into one chunk file.
# The input file list and the output file name are derived from the same ID range,
# so they cannot drift apart when this cell is copied for another range.
first_match_id, last_match_id = 100100, 100109
files = [f'{match_id}.pt' for match_id in range(first_match_id, last_match_id + 1)]

lens = 0
concat_list = []

for file_name in files:
    # weights_only=False allows arbitrary unpickling (the files presumably hold
    # non-tensor graph objects) — only load trusted, locally produced files.
    match_graphs = torch.load(os.path.join(DATA_PATH, file_name), weights_only=False)
    lens += len(match_graphs)
    concat_list.extend(match_graphs)

# Sum of per-file lengths vs. concatenated length: printed as a sanity check.
print(lens)
print(len(concat_list))

torch.save(concat_list, os.path.join(DATA_SAVE_PATH, f'{first_match_id}-{last_match_id}.pt'))

9395
9395


In [22]:
# Concatenate the lists loaded from the per-match files 100110-100113 into one chunk file
# (the last, partial chunk: only four matches).
# The input file list and the output file name are derived from the same ID range,
# so they cannot drift apart when this cell is copied for another range.
first_match_id, last_match_id = 100110, 100113
files = [f'{match_id}.pt' for match_id in range(first_match_id, last_match_id + 1)]

lens = 0
concat_list = []

for file_name in files:
    # weights_only=False allows arbitrary unpickling (the files presumably hold
    # non-tensor graph objects) — only load trusted, locally produced files.
    match_graphs = torch.load(os.path.join(DATA_PATH, file_name), weights_only=False)
    lens += len(match_graphs)
    concat_list.extend(match_graphs)

# Sum of per-file lengths vs. concatenated length: printed as a sanity check.
print(lens)
print(len(concat_list))

torch.save(concat_list, os.path.join(DATA_SAVE_PATH, f'{first_match_id}-{last_match_id}.pt'))

3010
3010


### Create train and validation datasets

In [24]:
# Training split: the chunk files covering matches 100000-100079.
file_names = [
    '100000-100009.pt',
    '100010-100019.pt',
    '100020-100029.pt',
    '100030-100039.pt',
    '100040-100049.pt',
    '100050-100059.pt',
    '100060-100069.pt',
    '100070-100079.pt',
]

train_data = []

for file in file_names:
    # Trusted local files only: weights_only=False allows arbitrary unpickling.
    match_graphs = torch.load(os.path.join(DATA_SAVE_PATH, file), weights_only=False)
    train_data.extend(match_graphs)
    print(f'Loaded {file}')

# Join instead of concatenating: DATA_PATH already ends with '/', so
# DATA_PATH + '/train.pt' produced a '//' in the saved path.
torch.save(train_data, os.path.join(DATA_PATH, 'train.pt'))

Loaded 100000-100009.pt
Loaded 100010-100019.pt
Loaded 100020-100029.pt
Loaded 100030-100039.pt
Loaded 100040-100049.pt
Loaded 100050-100059.pt
Loaded 100060-100069.pt
Loaded 100070-100079.pt


In [25]:
# Validation split.
# NOTE(review): chunk '100100-100109.pt' (created above) is in neither this list
# nor the training list — presumably held out as a test split; confirm this is
# intentional and not an accidental omission.
file_names = [
    '100080-100089.pt',
    '100090-100099.pt',
    '100110-100113.pt',
]

val_data = []

for file in file_names:
    # Trusted local files only: weights_only=False allows arbitrary unpickling.
    match_graphs = torch.load(os.path.join(DATA_SAVE_PATH, file), weights_only=False)
    val_data.extend(match_graphs)
    print(f'Loaded {file}')

# Join instead of concatenating: DATA_PATH already ends with '/', so
# DATA_PATH + '/val.pt' produced a '//' in the saved path.
torch.save(val_data, os.path.join(DATA_PATH, 'val.pt'))

Loaded 100080-100089.pt
Loaded 100090-100099.pt
Loaded 100110-100113.pt


In [29]:
# Total number of items in the concatenated training split.
n_train_snapshots = len(train_data)
n_train_snapshots

67269

# Test: build temporal graph snapshots for a single match

In [2]:
# Load a single match file for a quick check of the snapshot builder.
# NOTE(review): match 100000 is part of the training split (chunk 100000-100009),
# and DATA_PATH is the temporal-hetero-graph output folder — confirm that
# process_match in the next cell expects this already-processed format.
test_match_path = os.path.join(DATA_PATH, '100000.pt')
test_data = torch.load(test_match_path, weights_only=False)

In [4]:
# Turn the loaded match into a sequence of temporal hetero-graph snapshots.
# NOTE(review): the unit of `interval` (ticks? seconds?) is defined by
# TemporalHeteroGraphSnapshot.process_match — confirm there.
thgs = TemporalHeteroGraphSnapshot()
dyn_graphs = thgs.process_match(
    test_data,
    shifted_intervals=True,
    interval=20,
)