# Heterogeneous GNN

In [63]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch_geometric.data import HeteroData, DataLoader, Dataset
from torch_geometric.nn import RGCNConv
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, to_hetero, GCNConv
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, GraphConv, Linear
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import trim_to_layer
from torch_geometric.loader import DataLoader
from torch.nn.functional import normalize
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

class InfernoDataset(Dataset):
    def __init__(self, data_list):
        super(InfernoDataset, self).__init__()
        self.data_list = data_list

    def len(self):
        return len(self.data_list)

    def get(self, idx):
        return self.data_list[idx]

In [91]:
dataset_raw = torch.load('data/inferno_graph_dataset.pt')
dataset_raw.len()

# dataset2 = torch.load('data/inferno_graph_dataset_2.pt')
# dataset2.len()

# dataset = torch.load('data/inferno_graph_dataset_3.pt')
# dataset.len()

# Lengths of the datasets
# 64686+165182+61479

64686

In [92]:
def format_dataset(dataset):
  for data in dataset:
    data['player'].x = data['player'].x[:,:-42]
    data['player'].x = F.normalize(data['player'].x, p=2, dim=0)
    data['map'].x = F.normalize(data['map'].x, p=2, dim=0)
    if data['player'].x.shape == torch.Size([10,43]):
      data['player'].x = torch.cat((data['player'].x[:, : -3], torch.zeros((10,1)), data['player'].x[:, -3:]), dim=1)
  return dataset

dataset = format_dataset(dataset_raw)

data = dataset[0]
data

HeteroData(
  y={
    roundNum=1.0,
    sec=0.0,
    team1AliveNum=5.0,
    team2AliveNum=5.0,
    CTwinsRound=1
  },
  [1mplayer[0m={ x=[10, 44] },
  [1mmap[0m={ x=[181, 3] },
  [1m(map, connected_to, map)[0m={ edge_index=[2, 204] },
  [1m(player, closest_to, map)[0m={ edge_index=[2, 10] }
)

In [67]:
y = []
for data in dataset:
    y.append(data.y['CTwinsRound'].item())
pd.DataFrame(y).value_counts()

1    34022
0    30664
Name: count, dtype: int64

In [114]:
class HeterogeneousGNN(torch.nn.Module):

    def __init__(self, hidden_channels, num_layers, edge_types):
        super().__init__()
        
        torch.manual_seed(42)
        
        self.convs = torch.nn.ModuleList()
        for layernum in range(num_layers):
            conv = HeteroConv({
                    edge_type: SAGEConv((-1, -1), hidden_channels)
                    for edge_type in edge_types
                }, aggr='sum')
            self.convs.append(conv)
        self.lin1 = Linear(-1, 256)
        self.lin2 = Linear(256, 128)
        self.lin3 = Linear(128, 32)
        self.lin4 = Linear(32, 1)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            temp = conv(x_dict, edge_index_dict)
            x_dict['map'] = temp['map']
            x_dict = {key: x.relu() for key, x in x_dict.items()}
            
        x = torch.cat([torch.flatten(x_dict['player']), torch.flatten(x_dict['map'])])
        x = self.lin1(x).relu()
        #print(torch.sum(torch.isnan(x)))
        x = self.lin2(x).relu()
        x = self.lin3(x).relu()
        x = self.lin4(x).sigmoid()
        return x

model = HeterogeneousGNN(hidden_channels=20, num_layers=4, edge_types=data.edge_types)
print(model);

HeterogeneousGNN(
  (convs): ModuleList(
    (0-3): 4 x HeteroConv(num_relations=2)
  )
  (lin1): Linear(-1, 256, bias=True)
  (lin2): Linear(256, 128, bias=True)
  (lin3): Linear(128, 32, bias=True)
  (lin4): Linear(32, 1, bias=True)
)




In [115]:
train_loader = DataLoader(dataset[:64], batch_size=1, shuffle=True)
val_loader = DataLoader(dataset[64:128], batch_size=1, shuffle=True)
data = dataset[0].to('cuda')

model = HeterogeneousGNN(hidden_channels=20, num_layers=10, edge_types=data.edge_types).to('cuda')
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)
loss_function = torch.nn.BCELoss()

with torch.no_grad():  # Initialize lazy modules.
     out = model(data.x_dict, data.edge_index_dict)
     print(out)

tensor([0.6016], device='cuda:0')


In [110]:
def train():
    model.train()
    for data in train_loader:  # Iterate in batches over the training dataset.
        data.to('cuda')
        out = model(data.x_dict, data.edge_index_dict).to(torch.float32)  # Perform a single forward pass.
        target = data.y['CTwinsRound'].to(torch.float32)
        loss = loss_function(out, target)  # Compute the loss.
        optimizer.zero_grad()
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)  
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.


def validate(val_loader):
    model.eval()  # Átkapcsoljuk a modellt értékelési üzemmódba.
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for data in val_loader:  # Iterálunk a validációs adatokon.
            data.to('cuda')
            out = model(data.x_dict, data.edge_index_dict).to(torch.float32)
            target = data.y['CTwinsRound'].to(torch.float32)
            loss = loss_function(out, target)
            total_loss += loss.item()
            total_samples += len(target)

            # Ellenőrizzük a helyes előrejelzéseket (például egy bináris probléma esetében).
            predictions = (out > 0.5).float()
            correct_predictions += (predictions == target).sum().item()

    # Kiszámítjuk az átlagos veszteséget és a pontosságot.
    avg_loss = total_loss / len(val_loader)
    accuracy = correct_predictions / total_samples

    return avg_loss, accuracy


In [None]:
data.to('cpu')
model.to('cpu')

In [116]:
for epoch in range(1, 10):
    train()
    train_acc = validate(val_loader)
    print('Epoch ', epoch, ': (avg_loss, accuracy) ', train_acc)

Epoch  1 : (avg_loss, accuracy)  (0.0, 1.0)
Epoch  2 : (avg_loss, accuracy)  (0.0, 1.0)
Epoch  3 : (avg_loss, accuracy)  (0.0, 1.0)
Epoch  4 : (avg_loss, accuracy)  (0.0, 1.0)
Epoch  5 : (avg_loss, accuracy)  (0.0, 1.0)
Epoch  6 : (avg_loss, accuracy)  (0.0, 1.0)
Epoch  7 : (avg_loss, accuracy)  (0.0, 1.0)
Epoch  8 : (avg_loss, accuracy)  (0.0, 1.0)
Epoch  9 : (avg_loss, accuracy)  (0.0, 1.0)
