In [None]:
#@title PyG Installation { form-width: "25%" }
# Install PyTorch Geometric and its compiled dependencies (drop the leading `!` when running in a plain shell)
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git

In [None]:
#@title Module Imports { form-width: "20%" }
import pandas as pd
import torch
from torch.nn import Module,\
                     ModuleList,\
                     Embedding,\
                     BatchNorm1d,\
                     LogSoftmax,\
                     Softmax,\
                     Linear,\
                     NLLLoss,\
                     CrossEntropyLoss
from torch.optim import Adam
import torch.nn.functional as F
import torch_geometric as PyG
from torch_geometric.data import Data, HeteroData
from torch_geometric.nn.conv import RGCNConv, GINConv, GATConv, HeteroConv, GCNConv
from torch_geometric.utils import to_networkx
from collections import OrderedDict as od
import logging
import json
from typing import NoReturn
import typing

In [None]:
#@title Global Variables
# Integer codes for the relation (edge) types handled in this notebook.
# home_result() also returns WON / LOST_TO / TIED_WITH as the label of a match
# seen from the home side.
WON = 0
LOST_TO = 1
TIED_WITH = 2
PLAYED_IN = 3
USED = 4
BEFORE = 5
AFTER = 6

# Device on which the edge generators create their tensors.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device is {DEVICE}')

Device is cpu


In [None]:
#@title GNN Model { form-width: "10%" }
class GNN(Module):
  """R-GCN encoder followed by a dense head that scores (home, away) node pairs.

  Node ids are embedded and passed through a stack of RGCNConv layers
  (batch norm + ReLU between layers, none after the last one). For each
  requested match the home and away node vectors are concatenated and fed to
  the fully connected layers, ending in a log-softmax over 3 outcomes.
  """

  def __init__(self, embedding_dims: tuple, conv_dims: list, fully_connected_dims: list, dropout: dict) -> None:
    """Build the layers.

    Args:
      embedding_dims: (number of embeddings, embedding size).
      conv_dims: output size of every RGCNConv layer, in order (non-empty).
      fully_connected_dims: hidden sizes of the dense head, in order.
      dropout: dropout probabilities stored under the keys "emb", "conv"
        and "fc".
    """
    super(GNN, self).__init__()

    self.mode = None # 'train' or 'test' or 'dev' later
    self.output_dim = 3 #home_result: win, lose, tie
    self.num_relations = 7 #win/lose/tie/play/use/after/before
    self.dropout = dropout

    #one-hot to latent
    self.embed = Embedding(embedding_dims[0], embedding_dims[1])

    # Layer k maps conv_in_dims[k] -> conv_dims[k].
    conv_in_dims = [embedding_dims[1]] + list(conv_dims[:-1])
    conv_list = [
        RGCNConv(in_dim, out_dim, self.num_relations)
        for in_dim, out_dim in zip(conv_in_dims, conv_dims)
    ]

    # One batch norm per conv layer except the last one.
    batch_norm_list = [BatchNorm1d(dim) for dim in conv_dims[:-1]]

    # The head input is the concatenation of two node vectors.
    dense_dims = [2 * conv_dims[-1]] + list(fully_connected_dims) + [self.output_dim]
    fully_connected_list = [
        Linear(in_dim, out_dim)
        for in_dim, out_dim in zip(dense_dims, dense_dims[1:])
    ]

    #graph conv layers
    self.conv_layers = ModuleList(conv_list)
    #batch normalization layers
    self.batch_norm_layers = ModuleList(batch_norm_list)
    #fully connected dense layers
    self.fully_connected_layers = ModuleList(fully_connected_list)

    # Explicit dim: the implicit-dim form is deprecated and emits a warning.
    # The head produces one 1-D logit vector per match, so the last axis is
    # the class axis.
    self.classifier = LogSoftmax(dim=-1)

  def reset_parameters(self):
    """Re-initialise every learnable layer, including the embedding table."""
    self.embed.reset_parameters()
    for conv in self.conv_layers:
      conv.reset_parameters()
    for bn in self.batch_norm_layers:
      bn.reset_parameters()
    for fc in self.fully_connected_layers:
      fc.reset_parameters()

  def forward(self, x:torch.Tensor, edge_index:torch.Tensor, edge_type:torch.Tensor, home_list:list, away_list:list) -> torch.Tensor:
    """Return stacked log-probabilities, one row per (home, away) pair.

    Args:
      x: node ids looked up in the embedding table.
      edge_index: edge index passed to every RGCNConv layer.
      edge_type: relation id per edge, passed to every RGCNConv layer.
      home_list: node indices of the home sides.
      away_list: node indices of the away sides, aligned with home_list.
    """
    x = self.embed(x)
    x = F.dropout(x, p=self.dropout["emb"], training=self.training)

    for conv, bn in zip(self.conv_layers[:-1], self.batch_norm_layers):
      x = conv(x, edge_index=edge_index, edge_type=edge_type)
      x = bn(x)
      x = F.relu(x)
      x = F.dropout(x, p=self.dropout["conv"], training=self.training)

    x = self.conv_layers[-1](x, edge_index, edge_type)
    x = F.dropout(x, p=self.dropout["conv"], training=self.training)

    ##################################### End of Encoder

    pred = list()
    for home_team, away_team in zip(home_list, away_list):
      h = torch.cat((x[home_team], x[away_team]))

      for fc in self.fully_connected_layers[:-1]:
        h = fc(h)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout["fc"], training=self.training)

      h = self.fully_connected_layers[-1](h)
      # NOTE(review): dropout is applied to the output logits here, which
      # randomly zeroes class scores during training. Kept to preserve the
      # existing training behaviour; confirm whether this is intended.
      h = F.dropout(h, p=self.dropout["fc"], training=self.training)
      pred.append(self.classifier(h))

    return torch.stack(pred)

In [None]:
#@title HeteroGNN Model { form-width: "10%" }
class HeteroGNN(Module):
  """Heterogeneous encoder (HeteroConv stack) plus dense head for match outcomes.

  Every conv layer holds one operator per relation between 'team' and
  'player' nodes and sums the per-relation messages. The head concatenates
  the encoded home and away 'team' vectors of each match and outputs
  log-probabilities over 3 outcomes.
  """

  def __init__(self, embedding_dims: tuple, conv_dims: list, fully_connected_dims: list, dropout: dict) -> None:
    """Build the layers.

    Args:
      embedding_dims: (number of embeddings, embedding size).
      conv_dims: output size of every HeteroConv layer, in order (non-empty).
      fully_connected_dims: hidden sizes of the dense head, in order.
      dropout: dropout probabilities stored under the keys "emb", "conv"
        and "fc".
    """
    super(HeteroGNN, self).__init__()

    self.mode = None # 'train' or 'test' or 'dev' later
    self.output_dim = 3 #home_result: win, lose, tie
    self.num_relations = 7 #win/lose/tie/play/use/after/before
    self.dropout = dropout

    #one-hot to latent
    # NOTE(review): a single table is shared by 'player' and 'team' ids, so
    # player k and team k look up the same row - confirm this is intended.
    self.embed = Embedding(embedding_dims[0], embedding_dims[1])

    # Layer k maps conv_in_dims[k] -> conv_dims[k].
    conv_in_dims = [embedding_dims[-1]] + list(conv_dims[:-1])
    conv_list = [
        HeteroConv(self._relation_convs(in_dim, out_dim), aggr='sum')
        for in_dim, out_dim in zip(conv_in_dims, conv_dims)
    ]

    # One batch norm per conv layer except the last one.
    batch_norm_list = [BatchNorm1d(dim) for dim in conv_dims[:-1]]

    # The head input is the concatenation of two team vectors.
    dense_dims = [2 * conv_dims[-1]] + list(fully_connected_dims) + [self.output_dim]
    fully_connected_list = [
        Linear(in_dim, out_dim)
        for in_dim, out_dim in zip(dense_dims, dense_dims[1:])
    ]

    #graph conv layers
    self.conv_layers = ModuleList(conv_list)
    #batch normalization layers
    self.batch_norm_layers = ModuleList(batch_norm_list)
    #fully connected dense layers
    self.fully_connected_layers = ModuleList(fully_connected_list)

    # Explicit dim: the implicit-dim form is deprecated and emits a warning.
    self.classifier = LogSoftmax(dim=-1)

  @staticmethod
  def _relation_convs(in_dim: int, out_dim: int) -> dict:
    """Return one freshly built operator per relation for a HeteroConv layer."""
    return {
        ('team', 'won', 'team'): GCNConv(in_dim, out_dim),
        ('team', 'lost_to', 'team'): GCNConv(in_dim, out_dim),
        ('team', 'tied_with', 'team'): GCNConv(in_dim, out_dim),
        # Source and destination node types differ for the next two
        # relations. Self-loops are disabled there because node i of one
        # type is not node i of the other type.
        ('player', 'played_for', 'team'): GATConv(in_dim, out_dim, heads=1, add_self_loops=False),
        ('team', 'used', 'player'): GATConv(in_dim, out_dim, heads=1, add_self_loops=False),
        ('player', 'is_before', 'player'): GCNConv(in_dim, out_dim),
        ('player', 'is_after', 'player'): GCNConv(in_dim, out_dim),
        ('team', 'is_before', 'team'): GCNConv(in_dim, out_dim),
        ('team', 'is_after', 'team'): GCNConv(in_dim, out_dim)
    }

  def forward(self, data: HeteroData) -> torch.Tensor:
    """Return stacked log-probabilities, one row per (home, away) pair.

    Reads x_dict, edge_index_dict, home_list and away_list from ``data``;
    home_list / away_list index into the encoded 'team' nodes.
    """
    x_dict = data.x_dict
    home_list = data.home_list
    away_list = data.away_list
    edge_index_dict = data.edge_index_dict

    x_dict = {key: self.embed(x) for key, x in x_dict.items()}
    x_dict = {
        key: F.dropout(x, p=self.dropout["emb"], training=self.training)
        for key, x in x_dict.items()
    }

    for conv, bn in zip(self.conv_layers[:-1], self.batch_norm_layers):
      x_dict = conv(x_dict, edge_index_dict=edge_index_dict)
      # The same batch-norm module is applied to every node type.
      x_dict = {key: bn(x) for key, x in x_dict.items()}
      x_dict = {key: F.relu(x) for key, x in x_dict.items()}
      x_dict = {
          key: F.dropout(x, p=self.dropout["conv"], training=self.training)
          for key, x in x_dict.items()
      }

    x_dict = self.conv_layers[-1](x_dict, edge_index_dict=edge_index_dict)
    x_dict = {
        key: F.dropout(x, p=self.dropout["conv"], training=self.training)
        for key, x in x_dict.items()
    }

    ##################################### End of Encoder

    team_x = x_dict['team']
    pred = list()
    for home_team, away_team in zip(home_list, away_list):
      h = torch.cat((team_x[home_team], team_x[away_team]))

      for fc in self.fully_connected_layers[:-1]:
        h = fc(h)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout["fc"], training=self.training)

      h = self.fully_connected_layers[-1](h)
      # NOTE(review): dropout on the output logits is kept to preserve the
      # existing training behaviour; confirm whether this is intended.
      h = F.dropout(h, p=self.dropout["fc"], training=self.training)
      pred.append(self.classifier(h))

    return torch.stack(pred)

In [None]:
#@title home_result(row)
def home_result(row: str) -> int:
  """Translate a raw result label into its outcome code for the home side.

  'home' maps to WON, 'tie' to TIED_WITH and 'away' to LOST_TO; any other
  value yields None.
  """
  outcome_codes = {'home': WON, 'tie': TIED_WITH, 'away': LOST_TO}
  return outcome_codes.get(row)

In [None]:
#@title remove_redundancy(players) { form-width: "15%" }
def remove_redundancy(players: list) -> list:
  """Strip goal annotations from player names and return the cleaned names.

  The markers 'Own', 'Pen. Scored' and 'Pen. Score' are removed from each
  name. A name that still contains 'Own', 'Scored' or 'Score' afterwards is
  printed and left out of the result; every other name is whitespace-stripped
  and kept, in input order.
  """
  # Order matters: 'Pen. Scored' has to go before its prefix 'Pen. Score'.
  markers = ('Own', 'Pen. Scored', 'Pen. Score')
  leftovers = ('Own', 'Scored', 'Score')

  cleaned = []
  for raw_name in players:
    name = raw_name
    for marker in markers:
      name = name.replace(marker, '')
    if any(word in name for word in leftovers):
      print(name)
      #SHOULD NOT PRINT IF CODE IS CORRECT
      continue
    cleaned.append(name.strip())
  return cleaned

In [None]:
#@title extract_players(home_lineup, away_lineup) { form-width: "15%" }
def extract_players(home_lineup: str, away_lineup: str) -> typing.Tuple[list, list]:
  """Split both lineup strings into cleaned player-name lists.

  Returns:
    (home_players, away_players), each produced by remove_redundancy().
  """
  # The last two characters are dropped before splitting on ' - ';
  # presumably they are a trailing separator - confirm against the data.
  home_players = home_lineup[:-2].split(' - ')
  away_players = away_lineup[:-2].split(' - ')

  return remove_redundancy(home_players), remove_redundancy(away_players)

In [None]:
#@title stats(df, show_players, show_teams, show_results) { form-width: "10%" }
def stats(df: pd.DataFrame, show_players: bool=False, show_teams: bool=False, show_results: bool=False) -> None:
  """Print appearance and result counts for the matches in ``df``.

  Each row must unpack into (home team, away team, result, home lineup,
  away lineup). A result of 'home' counts as a home win, 'tie' as a tie for
  both sides, and anything else as an away win.

  Args:
    df: match table.
    show_players: print the number of matches per player.
    show_teams: print the number of matches per team.
    show_results: print win/loss/tie counters sorted by label.
  """
  from collections import Counter  # local import keeps this cell self-contained

  # Counters replace the former list.count() calls inside loops, which were
  # quadratic in the number of appearances.
  player_appearances = Counter()
  team_appearances = Counter()
  results = Counter()

  for _, (h_team, a_team, result, h_lineup, a_lineup) in df.iterrows():
    home_players, away_players = extract_players(h_lineup, a_lineup)
    player_appearances.update(home_players + away_players)
    team_appearances.update([h_team, a_team])

    if result == 'home':
      results[f'{h_team} #Wins'] += 1
      results[f'{a_team} #Losses'] += 1
    elif result == 'tie':
      results[f'{h_team} #Ties'] += 1
      results[f'{a_team} #Ties'] += 1
    else:
      results[f'{a_team} #Wins'] += 1
      results[f'{h_team} #Losses'] += 1

  # Players and teams are listed in first-seen order.
  if show_players:
    for player, matches in player_appearances.items():
      print(f'{player} played in {matches} matches.')
  if show_teams:
    for team, matches in team_appearances.items():
      print(f'{team} played {matches} matches.')
  if show_results:
    for key, val in sorted(results.items()):
      print(f'{key}: {val}')

In [None]:
#@title extract_entities(df) { form-width: "15%" }
def extract_entities(df: pd.DataFrame) -> typing.Tuple[set, set]:
  """Collect the distinct team names and player names found in ``df``.

  Each row must unpack into (home team, away team, result, home lineup,
  away lineup).

  Returns:
    (teams_set, players_set).
  """
  players_set = set()
  teams_set = set()

  for _, (h_team, a_team, result, h_lineup, a_lineup) in df.iterrows():
    home_players, away_players = extract_players(h_lineup, a_lineup)

    players_set.update(home_players + away_players)
    teams_set.update([h_team, a_team])

  return teams_set, players_set

In [None]:
#@title gen_entities(df) { form-width: "15%" }
def gen_entities(df: pd.DataFrame) -> dict:
  """Assign a consecutive integer id to every player and team name in ``df``.

  Players are numbered first, then teams. Names are sorted before numbering
  so the mapping is the same on every run; iterating the sets directly made
  the ids depend on string-hash ordering. If a name occurs both as a player
  and as a team, the team id wins.
  """
  teams, players = extract_entities(df)
  ordered_names = sorted(players, key=str) + sorted(teams, key=str)
  return {entity: index for index, entity in enumerate(ordered_names)}

In [None]:
#@title nodes_gen(df) OK_HETERO { form-width: "15%" }

def nodes_gen(df: pd.DataFrame) -> typing.Tuple[dict, dict]:
  """Number every per-match player and team occurrence in ``df``.

  Player keys look like '<player>@<row label>' and team keys like
  '<team>*<row label>'. Ids are handed out in row order: home players before
  away players, and the home team before the away team.

  Returns:
    (player_nodes, team_nodes), each mapping a key to its integer id.
  """
  player_ids = {}
  team_ids = {}
  next_player_id = 0
  next_team_id = 0

  for match, (h_team, a_team, result, h_lineup, a_lineup) in df.iterrows():
    home_players, away_players = extract_players(h_lineup, a_lineup)

    for name in home_players + away_players:
      player_ids[f'{name}@{match}'] = next_player_id
      next_player_id += 1

    for side in (h_team, a_team):
      team_ids[f'{side}*{match}'] = next_team_id
      next_team_id += 1

  return player_ids, team_ids


In [None]:
#@title show_edges(df, edge, edge_type) USELESS { form-width: "15%" }
def show_edges(df: pd.DataFrame, edge: torch.Tensor, edge_type: torch.Tensor, tt:str) -> NoReturn:
  """Print every edge as '<source key> === <relation> ===> <target key>'.

  Args:
    df: match table used to rebuild the node keys via nodes_gen().
    edge: tensor whose row 0 holds source ids and row 1 target ids.
    edge_type: relation code (0..6) of each edge.
    tt: 'p' to resolve ids against player nodes, 't' against team nodes.
  """
  relation_names = {
      0: 'Won',
      1: 'Lost To',
      2: 'Tied With',
      3: 'Played For',
      4: 'Used As Player',
      5: 'Is Before',
      6: 'Is After'
  }
  map_position = {'p': 0, 't':1}
  node_maps = nodes_gen(df)
  selected_map = node_maps[map_position[tt]]
  # Invert key -> id so ids can be printed as readable keys.
  key_of = {node_id: key for key, node_id in selected_map.items()}

  edge_count = edge_type.shape[0]
  for position in range(edge_count):
    source = int(edge[0][position].item())
    target = int(edge[1][position].item())
    relation_code = int(edge_type[position].item())
    arrow = f'=== {relation_names[relation_code]} ===>'
    print(f'{key_of[source]:<32}   {arrow}   {key_of[target]:>32}')

In [None]:
#@title home_won_gen(df) OK_HETERO { form-width: "15%" }
def home_won_gen(df: pd.DataFrame) -> typing.Tuple[torch.Tensor, torch.Tensor]:
  """Build 'won' / 'lost to' edges for the matches the home side won.

  Node ids come from nodes_gen(df), so they are only comparable with ids
  generated from the same frame.

  Returns:
    (won_edges, lost_edges): won_edges has winners in row 0 and losers in
    row 1; lost_edges is the reverse direction. Both are long tensors on
    DEVICE.
  """
  home_winning_matches = df.loc[df['result'] == 'home']
  _, team_nodes = nodes_gen(df)

  winning_nodes = list()
  losing_nodes = list()
  for home, away, match in zip(
      home_winning_matches['home_team'],
      home_winning_matches['away_team'],
      home_winning_matches.index
  ):
    winning_nodes.append(team_nodes[f'{home}*{match}'])
    losing_nodes.append(team_nodes[f'{away}*{match}'])

  won_edges = torch.tensor(
      [winning_nodes, losing_nodes],
      dtype=torch.long,
      device=DEVICE
  )
  lost_edges = torch.tensor(
      [losing_nodes, winning_nodes],
      dtype=torch.long,
      device=DEVICE
  )

  return won_edges, lost_edges

In [None]:
#@title away_won_gen(df) OK_HETERO { form-width: "15%" }
def away_won_gen(df: pd.DataFrame) -> typing.Tuple[torch.Tensor, torch.Tensor]:
  """Build 'won' / 'lost to' edges for the matches the away side won.

  Node ids come from nodes_gen(df), so they are only comparable with ids
  generated from the same frame.

  Returns:
    (won_edges, lost_edges): won_edges has winners in row 0 and losers in
    row 1; lost_edges is the reverse direction. Both are long tensors on
    DEVICE.
  """
  away_winning_matches = df.loc[df['result'] == 'away']
  _, team_nodes = nodes_gen(df)

  winning_nodes = list()
  losing_nodes = list()
  for home, away, match in zip(
      away_winning_matches['home_team'],
      away_winning_matches['away_team'],
      away_winning_matches.index
  ):
    winning_nodes.append(team_nodes[f'{away}*{match}'])
    losing_nodes.append(team_nodes[f'{home}*{match}'])

  won_edges = torch.tensor(
      [winning_nodes, losing_nodes],
      dtype=torch.long,
      device=DEVICE
  )
  lost_edges = torch.tensor(
      [losing_nodes, winning_nodes],
      dtype=torch.long,
      device=DEVICE
  )

  return won_edges, lost_edges

In [None]:
#@title tied_gen(df) OK_HETERO { form-width: "15%" }
def tied_gen(df: pd.DataFrame) -> typing.Tuple[torch.Tensor, torch.Tensor]:
  """Build 'tied with' edges in both directions for the drawn matches.

  Node ids come from nodes_gen(df), so they are only comparable with ids
  generated from the same frame.

  Returns:
    (home_tied_edges, away_tied_edges): the first goes home -> away, the
    second away -> home. Both are long tensors on DEVICE.
  """
  tied_matches = df.loc[df['result'] == 'tie']
  _, team_nodes = nodes_gen(df)

  home_nodes = list()
  away_nodes = list()
  for home, away, match in zip(
      tied_matches['home_team'],
      tied_matches['away_team'],
      tied_matches.index
  ):
    home_nodes.append(team_nodes[f'{home}*{match}'])
    away_nodes.append(team_nodes[f'{away}*{match}'])

  home_tied_edges = torch.tensor(
      [home_nodes, away_nodes],
      dtype=torch.long,
      device=DEVICE
  )
  away_tied_edges = torch.tensor(
      [away_nodes, home_nodes],
      dtype=torch.long,
      device=DEVICE
  )

  return home_tied_edges, away_tied_edges

In [None]:
#@title played_used_gen(df) OK_HETERO { form-width: "15%" }
def played_used_gen(df: pd.DataFrame) -> typing.Tuple[torch.Tensor, torch.Tensor]:
  """Build player -> team ('played in') and team -> player ('used') edges.

  Every player of every lineup is linked to the team node of the same match.

  Returns:
    (played_in_edges, used_edges): the first has player ids in row 0 and
    team ids in row 1, the second is the reverse direction. Both are long
    tensors on DEVICE.
  """
  from itertools import zip_longest  # local import keeps this cell self-contained

  team_nodes = list()
  player_nodes = list()

  p_nodes, t_nodes = nodes_gen(df)

  for index, (h_team, a_team, result, h_lineup, a_lineup) in df.iterrows():
    home_players, away_players = extract_players(h_lineup, a_lineup)
    home_team_node = t_nodes[f'{h_team}*{index}']
    away_team_node = t_nodes[f'{a_team}*{index}']

    # zip_longest keeps the former home/away interleaving but no longer
    # drops the surplus players when the two lineups differ in length.
    for home_player, away_player in zip_longest(home_players, away_players):
      if home_player is not None:
        player_nodes.append(p_nodes[f'{home_player}@{index}'])
        team_nodes.append(home_team_node)
      if away_player is not None:
        player_nodes.append(p_nodes[f'{away_player}@{index}'])
        team_nodes.append(away_team_node)

  played_in_edges = torch.tensor(
      [player_nodes, team_nodes],
      dtype=torch.long,
      device=DEVICE
  )
  used_edges = torch.tensor(
      [team_nodes, player_nodes],
      dtype=torch.long,
      device=DEVICE
  )

  return played_in_edges, used_edges

In [None]:
#@title players_before_after_gen(df) OK_HETERO { form-width: "15%" }
#TODO
def players_before_after_gen(df: pd.DataFrame) -> typing.Tuple[torch.Tensor, torch.Tensor]:
  """Link each player occurrence to the same player's next occurrence.

  Occurrences are ordered by player name and then by the integer value of
  the row label; consecutive occurrences of the same name are connected.

  Returns:
    (before_edges, after_edges): before_edges goes from the earlier to the
    later occurrence, after_edges is the reverse direction. Both are long
    tensors on DEVICE.
  """
  player_nodes, _ = nodes_gen(df)

  appearances = list()
  for index, (h_team, a_team, result, h_lineup, a_lineup) in df.iterrows():
    home_players, away_players = extract_players(h_lineup, a_lineup)
    for player in home_players + away_players:
      appearances.append((player, index))

  timeline = sorted(appearances, key=lambda item: (item[0], int(item[1])))

  before_nodes = list()
  after_nodes = list()
  # Walking neighbouring pairs removes the need for the former index+1
  # lookup guarded by a bare except.
  for (name, match), (next_name, next_match) in zip(timeline, timeline[1:]):
    if name == next_name:
      before_nodes.append(player_nodes[f'{name}@{match}'])
      after_nodes.append(player_nodes[f'{next_name}@{next_match}'])

  before_edges = torch.tensor(
      [before_nodes, after_nodes],
      dtype=torch.long,
      device=DEVICE
  )
  after_edges = torch.tensor(
      [after_nodes, before_nodes],
      dtype=torch.long,
      device=DEVICE
  )

  return before_edges, after_edges

In [None]:
#@title teams_before_after_gen(df) OK_HETERO { form-width: "15%" }
def teams_before_after_gen(df: pd.DataFrame) -> typing.Tuple[torch.Tensor, torch.Tensor]:
  """Link each team occurrence to the same team's next occurrence.

  Occurrences are ordered by team name and then by the integer value of the
  row label; consecutive occurrences of the same name are connected.

  Returns:
    (before_edges, after_edges): before_edges goes from the earlier to the
    later occurrence, after_edges is the reverse direction. Both are long
    tensors on DEVICE.
  """
  _, team_nodes = nodes_gen(df)

  appearances = list()
  for index, (h_team, a_team, result, h_lineup, a_lineup) in df.iterrows():
    # Names are formatted to text so they match the keys built by nodes_gen.
    appearances.append((f'{h_team}', index))
    appearances.append((f'{a_team}', index))

  timeline = sorted(appearances, key=lambda item: (item[0], int(item[1])))

  before_nodes = list()
  after_nodes = list()
  # Walking neighbouring pairs removes the need for the former index+1
  # lookup guarded by a bare except.
  for (name, match), (next_name, next_match) in zip(timeline, timeline[1:]):
    if name == next_name:
      before_nodes.append(team_nodes[f'{name}*{match}'])
      after_nodes.append(team_nodes[f'{next_name}*{next_match}'])

  before_edges = torch.tensor(
      [before_nodes, after_nodes],
      dtype=torch.long,
      device=DEVICE
  )
  after_edges = torch.tensor(
      [after_nodes, before_nodes],
      dtype=torch.long,
      device=DEVICE
  )

  return before_edges, after_edges

In [None]:
#@title complete_graph_edge_gen(df, for_players, for_teams) OK_HETERO { form-width: "10%" }
def complete_graph_edge_gen(df: pd.DataFrame, for_players: bool=True, for_teams: bool=True) -> dict:
  """Build the edge tensors of every relation from all matches in ``df``.

  Args:
    df: match table.
    for_players: include player before/after edges; when False those
      entries are empty (2, 0) tensors.
    for_teams: include team before/after edges; when False those entries
      are empty (2, 0) tensors.

  Returns:
    dict with the keys 'won', 'lost', 'tied', 'played', 'used', 'p_after',
    'p_before', 't_after' and 't_before', each holding an edge tensor.
  """
  home_won, away_lost = home_won_gen(df)
  away_won, home_lost = away_won_gen(df)
  home_tied, away_tied = tied_gen(df)
  player_played, team_used = played_used_gen(df)

  # Placeholder for disabled relations so the returned dict always has the
  # same keys (previously a disabled flag raised NameError).
  no_edges = torch.empty((2, 0), dtype=torch.long, device=DEVICE)

  if for_players:
    player_before, player_after = players_before_after_gen(df)
  else:
    player_before, player_after = no_edges, no_edges
  if for_teams:
    team_before, team_after = teams_before_after_gen(df)
  else:
    team_before, team_after = no_edges, no_edges

  won_edge_index = torch.cat(
      (home_won, away_won),
      dim=1
  )
  lost_edge_index = torch.cat(
      (away_lost, home_lost),
      dim=1
  )
  tied_edge_index = torch.cat(
      (home_tied, away_tied),
      dim=1
  )
  edge_index = {
      'won': won_edge_index,
      'lost': lost_edge_index,
      'tied': tied_edge_index,
      'played': player_played,
      'used': team_used,
      'p_after': player_after,
      'p_before': player_before,
      't_after': team_after,
      # Fixed: this entry used to hold team_after.
      't_before': team_before
  }
  return edge_index

In [None]:
#@title supervision_graph_gen(df, for_players, for_teams, log_supervision_matches) OK_HETERO { form-width: "10%" }
#TODO idea1: messaging=[1, 2, 3, ..., 10], supervision=[11, 12, ..., 15]
def supervision_graph_gen(df : pd.DataFrame, messaging: list, supervision: list, for_players: bool=True, for_teams: bool=True, log_supervision_matches: bool=False) -> dict:
  """Build the edge tensors with result edges taken only from ``messaging``.

  Won / lost / tied edges are generated from the rows ``df.loc[messaging]``.
  Played / used and before / after edges are generated from the whole frame.

  Args:
    df: match table.
    messaging: row labels whose results may appear as edges.
    supervision: row labels the model is supervised on (only used for the
      log message here).
    for_players: include player before/after edges; when False those
      entries are empty (2, 0) tensors.
    for_teams: include team before/after edges; when False those entries
      are empty (2, 0) tensors.
    log_supervision_matches: log the messaging and supervision ranges.

  Returns:
    dict with the keys 'won', 'lost', 'tied', 'played', 'used', 'p_after',
    'p_before', 't_after' and 't_before', each holding an edge tensor.
  """
  if log_supervision_matches:
    mode_names = {'train': 'training', 'dev': 'validating', 'test': 'testing'}
    # The notebook-level `model` may not exist yet; fall back to a neutral
    # word instead of failing on an unbound name.
    active_model = globals().get('model')
    mode = mode_names.get(getattr(active_model, 'mode', None), 'running')
    if len(messaging) and len(supervision):
      logging.info(
          f'Messaging on matches ({messaging[0] + 1} -> {messaging[-1] + 1:>5}), Model is {mode} on matches ({supervision[0] + 1} -> {supervision[-1] + 1})'
      )

  messaging_df = df.loc[messaging]
  home_won, away_lost = home_won_gen(messaging_df)
  away_won, home_lost = away_won_gen(messaging_df)
  home_tied, away_tied = tied_gen(messaging_df)
  player_played, team_used = played_used_gen(df)

  # The result edges above carry team ids numbered within messaging_df,
  # while every other relation is numbered within the full frame. Translate
  # them through the shared '<team>*<row label>' keys so all relations use
  # one id space (a no-op when messaging is a leading slice of df).
  _, messaging_team_nodes = nodes_gen(messaging_df)
  if messaging_team_nodes:
    _, all_team_nodes = nodes_gen(df)
    to_full_id = torch.zeros(
        max(messaging_team_nodes.values()) + 1,
        dtype=torch.long,
        device=DEVICE
    )
    for key, local_id in messaging_team_nodes.items():
      to_full_id[local_id] = all_team_nodes[key]
    home_won, away_lost = to_full_id[home_won], to_full_id[away_lost]
    away_won, home_lost = to_full_id[away_won], to_full_id[home_lost]
    home_tied, away_tied = to_full_id[home_tied], to_full_id[away_tied]

  # Placeholder for disabled relations so the returned dict always has the
  # same keys (previously a disabled flag raised NameError).
  no_edges = torch.empty((2, 0), dtype=torch.long, device=DEVICE)

  if for_players:
    player_before, player_after = players_before_after_gen(df)
  else:
    player_before, player_after = no_edges, no_edges
  if for_teams:
    team_before, team_after = teams_before_after_gen(df)
  else:
    team_before, team_after = no_edges, no_edges

  won_edge_index = torch.cat(
      (home_won, away_won),
      dim=1
  )
  lost_edge_index = torch.cat(
      (away_lost, home_lost),
      dim=1
  )
  tied_edge_index = torch.cat(
      (home_tied, away_tied),
      dim=1
  )
  edge_index = {
      'won': won_edge_index,
      'lost': lost_edge_index,
      'tied': tied_edge_index,
      'played': player_played,
      'used': team_used,
      'p_after': player_after,
      'p_before': player_before,
      't_after': team_after,
      # Fixed: this entry used to hold team_after.
      't_before': team_before
  }
  return edge_index

In [None]:
#@title data_gen(df, remove_supervision_links, for_players, for_teams, print_edges, log_supervision_matches) OK_HETERO { form-width: "10%" }
def data_gen(df: pd.DataFrame, messaging: list, supervision: list, remove_supervision_links: bool=True, for_players: bool=True, for_teams: bool=True, print_edges: bool=False, log_supervision_matches: bool=False) -> HeteroData:
  """Assemble the HeteroData graph and the labels of the supervision matches.

  Args:
    df: match table.
    messaging: row labels whose results may appear as edges (used when
      remove_supervision_links is True).
    supervision: row labels whose results become the labels ``y``.
    remove_supervision_links: build edges with supervision_graph_gen()
      instead of complete_graph_edge_gen().
    for_players: forwarded to the edge builder.
    for_teams: forwarded to the edge builder.
    print_edges: print the edges of every single-node-type relation.
    log_supervision_matches: forwarded to supervision_graph_gen().

  Returns:
    HeteroData with 'player' / 'team' node ids in ``x``, one edge_index per
    relation, and ``y`` holding home_result() of each supervision row.
  """
  if remove_supervision_links:
    edge_index = supervision_graph_gen(
        df,
        messaging=messaging,
        supervision=supervision,
        for_players=for_players,
        for_teams=for_teams,
        log_supervision_matches=log_supervision_matches
    )
  else:
    edge_index = complete_graph_edge_gen(df, for_players, for_teams)

  if print_edges:
    # Printing happens after the edges exist (it used to run before they
    # were built). show_edges() resolves ids against one node map, so only
    # relations within a single node type are printed; 'played' / 'used'
    # connect players with teams and are skipped.
    printable = {
        'won': (WON, 't'),
        'lost': (LOST_TO, 't'),
        'tied': (TIED_WITH, 't'),
        'p_before': (BEFORE, 'p'),
        'p_after': (AFTER, 'p'),
        't_before': (BEFORE, 't'),
        't_after': (AFTER, 't')
    }
    for relation, (code, node_kind) in printable.items():
      edges = edge_index[relation]
      codes = torch.full((edges.shape[1],), code, dtype=torch.long, device=DEVICE)
      show_edges(df, edges, codes, node_kind)

  y = torch.tensor(
      df.loc[supervision, 'result'].map(home_result).values,
      device=DEVICE
  )

  ############################## OK
  data = HeteroData()
  # Node features are the distinct ids that occur as edge sources.
  data['player'].x = torch.unique(edge_index['played'][0]).to(DEVICE).type(torch.int64)
  data['team'].x = torch.unique(edge_index['used'][0]).to(DEVICE).type(torch.int64)

  data['team', 'won', 'team'].edge_index = edge_index['won']
  data['team', 'lost_to', 'team'].edge_index = edge_index['lost']
  data['team', 'tied_with', 'team'].edge_index = edge_index['tied']
  data['player', 'played_for', 'team'].edge_index = edge_index['played']
  data['team', 'used', 'player'].edge_index = edge_index['used']
  data['player', 'is_before', 'player'].edge_index = edge_index['p_before']
  data['player', 'is_after', 'player'].edge_index = edge_index['p_after']
  data['team', 'is_before', 'team'].edge_index = edge_index['t_before']
  data['team', 'is_after', 'team'].edge_index = edge_index['t_after']
  data.y = y
  # NOTE(review): HeteroGNN.forward reads data.home_list and data.away_list,
  # which are not set here - presumably the caller attaches them; confirm.

  return data

In [None]:
#@title visualize_graph(df, width, height, title, remove_supervision_links) { form-width: "10%" }
def visualize_graph(df:pd.DataFrame, width: int=20, height: int=20, title: str=None, remove_supervision_links: bool=False) -> NoReturn:
  """Draw the graph built from ``df`` with networkx and matplotlib.

  Nodes are coloured by kind (player / team), edges by relation type, and a
  legend listing every colour is added to the figure.

  Args:
    df: Match table handed to ``nodes_gen`` and ``data_gen``.
    width: Figure width in inches.
    height: Figure height in inches.
    title: Title passed to ``plt.title``.
    remove_supervision_links: Forwarded to ``data_gen``.

  NOTE(review): this body looks out of step with the rest of the notebook and
  probably cannot run as written -- see the inline notes below. Confirm before
  relying on it.
  """
  import networkx as nx
  import matplotlib.pyplot as plt
  # NOTE(review): the result is used as one mapping here (.items()), whereas
  # batch_gen in this notebook unpacks nodes_gen(df) into two dicts -- confirm.
  nodes = nodes_gen(df)
  # Reverse lookup: node index -> node key.
  r = {k:v for v, k in nodes.items()}
  # NOTE(review): data_gen as defined in this notebook also requires
  # `messaging` and `supervision`, which are not passed here -- confirm.
  d = data_gen(df, remove_supervision_links=remove_supervision_links)
  G = to_networkx(d)
  # Relation id -> human-readable name (used for the legend).
  types = {
        0: 'Won',
        1: 'Lost To',
        2: 'Tied With',
        3: 'Played For',
        4: 'Used As Player',
        5: 'Is Before',
        6: 'Is After'
  }

  # Relation id -> edge colour.
  type_color = {
      0: '#00ff00', #won
      1: '#ff0000', #lost to
      2: '#e6d70e', #tied with
      3: '#1338f0', #played for
      4: '#f01373', #used as player
      5: '#0f072e', #is before
      6: '#d909cb' #is after
  }

  # Relation id -> edge label text naming the relation and its reverse.
  double_edge_types = {
      0: '(Won[green] - Lost to[red])',
      1: '(Lost to[red] - Won[green])',
      2: '(Tied with[yellow])',
      3: '(Played for[blue] - Used as Player[pink])',
      4: '(Used as Player[pink] - Played for[blue])',
      5: '(Is Before[dark blue] - Is After[purple])',
      6: '(Is After[purple] - Is Before[dark blue])'
  }

  # Relation name -> colour, pairing the two dicts above by position.
  link_colors = dict(zip(
        types.values(),
        type_color.values()
      )
  )

  node_colors = {
      'player-color': '#8f0ba1',
      'team-color': '#02fae1'   
  }

  all_colors = link_colors.copy()
  all_colors.update(node_colors)

  

  # Empty scatter plots: they draw nothing but register one legend entry per colour.
  for color_use in all_colors.keys():
      plt.scatter([],[], c=[all_colors[color_use]], label=f'{color_use}')

  edge_colors = list()
  edge_labels = dict()

  ######################################################## NOT OPTIMIZED
  # For every drawn edge, scan all columns of d.edge_index to find its relation
  # id (one full scan per edge).
  # NOTE(review): data_gen in this notebook stores one edge_index per relation
  # on a HeteroData and never sets d.edge_index / d.edge_type -- confirm.
  for edge in G.edges():
    e = torch.tensor(edge, device=DEVICE)
    for index, node_node in enumerate(d.edge_index.t()):
      if torch.equal(e, node_node):
        edge_colors.append(type_color[d.edge_type[index].item()])
        label = double_edge_types[d.edge_type[index].item()]
        edge_labels.update({edge:label})
  colors = list()
  node_labels = dict()
  # Node keys containing '@' are drawn as players, keys containing '*' as teams;
  # the label is the part of the key before that marker.
  for node in G.nodes():
    if '@' in r[node]:
      colors.append(all_colors['player-color'])
      node_labels.update({node: r[node].split('@')[0]})
    elif '*' in r[node]:
      colors.append(all_colors['team-color'])
      node_labels.update({node:r[node].split('*')[0]})
  ######################################################## NOT OPTIMIZED

  fig = plt.gcf()
  fig.set_size_inches(width, height)
  pos = nx.spring_layout(G)
  nx.draw_networkx_nodes(G, pos, node_color=colors)
  nx.draw_networkx_labels(G, pos, labels=node_labels)
  nx.draw_networkx_edges(G, pos, edge_color=edge_colors, connectionstyle='arc3,rad=0.05')
  nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
  plt.legend()
  plt.title(title)
  fig.show()
  plt.show()

In [None]:
#@title batch_gen(df, entities, log_supervision_matches) OK_HETERO { form-width: "10%" }
def batch_gen(df: pd.DataFrame, messaging: list, supervision: list, entities: dict, remove_supervision_links: bool=True, log_supervision_matches: bool=False) -> HeteroData:
  """Build one training/evaluation batch from a window of matches.

  The graph comes from ``data_gen``; its node indices are then replaced by
  entity ids, and the home/away team nodes of every supervision match are
  attached to the graph.

  Args:
    df: Match table with ``home_team`` and ``away_team`` columns.
    messaging: Forwarded to ``data_gen``.
    supervision: Index labels of the matches to predict. ``graph.y`` is built
      from ``df.loc[supervision]``, so the home/away lists use the same rows
      in the same order.
    entities: Maps a player or team name to its entity id.
    remove_supervision_links: Forwarded to ``data_gen``.
    log_supervision_matches: Forwarded to ``data_gen``.

  Returns:
    The ``HeteroData`` from ``data_gen`` with ``x`` holding entity ids and the
    extra attributes ``home_list`` and ``away_list`` (team node indices of the
    supervision matches).
  """
  graph = data_gen(
      df,
      messaging=messaging,
      supervision=supervision,
      remove_supervision_links=remove_supervision_links,
      log_supervision_matches=log_supervision_matches
  )

  p_nodes, t_nodes = nodes_gen(df)
  nodes = {**p_nodes, **t_nodes}

  # Node index -> entity id. Node keys look like '<player>@...' or '<team>*...';
  # the entity name is the part before the marker.
  indices = dict()
  for node_key, node_index in nodes.items():
    if '@' in node_key:
      indices[node_index] = entities[node_key.split('@')[0]]
    elif '*' in node_key:
      indices[node_index] = entities[node_key.split('*')[0]]

  # Take the supervision rows themselves rather than "the last len(y) rows":
  # the labels were built from df.loc[supervision], so this keeps home/away
  # nodes aligned with y even when supervision is not the tail of df.
  supervision_rows = df.loc[supervision]
  home_teams = [
      nodes[f'{team}*{match}']
      for match, team in supervision_rows['home_team'].items()
  ]
  away_teams = [
      nodes[f'{team}*{match}']
      for match, team in supervision_rows['away_team'].items()
  ]

  graph['player'].x = torch.tensor(
      [indices[node] for node in graph['player'].x.tolist()],
      device=DEVICE
  )
  graph['team'].x = torch.tensor(
      [indices[node] for node in graph['team'].x.tolist()],
      device=DEVICE
  )
  graph.home_list = home_teams
  graph.away_list = away_teams

  return graph

In [None]:
# Smoke test: build a batch from the first two matches, using match 0 for
# message passing and match 1 as the supervision match.
# NOTE: `dataset` is loaded in the "Dataset Loading and Cleaning" cell further
# down, so that cell has to be run before this one.
hd = batch_gen(
    dataset.loc[0:1, :],
    entities=gen_entities(dataset.loc[0:1, :]),
    messaging=[0],
    supervision=[1],
    remove_supervision_links=True
)
hd

HeteroData(
  y=[1],
  home_list=[1],
  away_list=[1],
  [1mplayer[0m={ x=[44] },
  [1mteam[0m={ x=[4] },
  [1m(team, won, team)[0m={ edge_index=[2, 1] },
  [1m(team, lost_to, team)[0m={ edge_index=[2, 1] },
  [1m(team, tied_with, team)[0m={ edge_index=[2, 0] },
  [1m(player, played_for, team)[0m={ edge_index=[2, 44] },
  [1m(team, used, player)[0m={ edge_index=[2, 44] },
  [1m(player, is_before, player)[0m={ edge_index=[2, 0] },
  [1m(player, is_after, player)[0m={ edge_index=[2, 0] },
  [1m(team, is_before, team)[0m={ edge_index=[2, 0] },
  [1m(team, is_after, team)[0m={ edge_index=[2, 0] }
)

In [None]:
#@title train(model, dataset, optimizer, loss_fn) { form-width: "15%" }
def train(model: GNN, graph_data: dict, optimizer: torch.optim, loss_fn: torch.nn.modules.loss) -> typing.Tuple[float, int, int]:
  """Run a single optimisation step on one batch.

  Args:
    model: Network called with the keyword arguments ``x``, ``edge_index``,
      ``edge_type``, ``home_list`` and ``away_list``.
    graph_data: Batch indexed by the keys ``"x"``, ``"edge_index"``,
      ``"edge_type"``, ``"home_teams"``, ``"away_teams"`` and ``"y"``.
      NOTE(review): batch_gen in this notebook returns a HeteroData carrying
      ``home_list``/``away_list`` attributes rather than these keys -- confirm
      that callers pass what this function reads.
    optimizer: Optimiser over the model parameters.
    loss_fn: Called as ``loss_fn(out, y)``.

  Returns:
    ``(batch_loss, correct, total)``: the loss value as a float, the number
    of correct predictions and the number of samples in the batch.
  """
  model.train()

  out = model(
      x=graph_data["x"],
      edge_index=graph_data["edge_index"],
      edge_type=graph_data["edge_type"],
      home_list=graph_data["home_teams"],
      away_list=graph_data["away_teams"]
  )
  targets = graph_data["y"]

  optimizer.zero_grad()
  loss = loss_fn(out, targets)
  batch_loss = loss.item()
  loss.backward()
  optimizer.step()

  prediction = out.argmax(dim=-1)
  # Sum the boolean mask directly; wrapping an existing tensor in
  # torch.tensor(...) copy-constructs it and raises a UserWarning.
  correct = int((prediction == targets).sum().item())
  total = targets.shape[0]

  return batch_loss, correct, total

In [None]:
#@title evaluate(model, dataset) { form-width: "25px" }
@torch.no_grad()
def evaluate(model: GNN, graph_data: dict) -> typing.Tuple[int, int]:
  """Count correct predictions of ``model`` on one batch, without gradients.

  The model is switched to eval mode for the forward pass and put back into
  train mode before returning.

  Args:
    model: Network called with the keyword arguments ``x``, ``edge_index``,
      ``edge_type``, ``home_list`` and ``away_list``.
    graph_data: Batch indexed by the keys ``"x"``, ``"edge_index"``,
      ``"edge_type"``, ``"home_teams"``, ``"away_teams"`` and ``"y"``.
      NOTE(review): batch_gen in this notebook returns a HeteroData carrying
      ``home_list``/``away_list`` attributes rather than these keys -- confirm
      that callers pass what this function reads.

  Returns:
    ``(correct, total)``: number of correct predictions and batch size.
  """
  model.eval()
  out = model(
      x=graph_data["x"],
      edge_index=graph_data["edge_index"],
      edge_type=graph_data["edge_type"],
      home_list=graph_data["home_teams"],
      away_list=graph_data["away_teams"]
  )
  targets = graph_data["y"]

  prediction = out.argmax(dim=-1)
  # Sum the boolean mask directly; wrapping an existing tensor in
  # torch.tensor(...) copy-constructs it and raises a UserWarning.
  correct = int((prediction == targets).sum().item())
  total = targets.shape[0]
  model.train()

  return correct, total

In [None]:
#@title Dataset Download { form-width: "15%" }
import requests
from os import environ, getcwd

url_epl = "https://raw.githubusercontent.com/jokecamp/FootballData/master/EPL%202011-2019/PL_scraped_ord.csv"
# The access token is no longer hard-coded in the URL; if the file needs one,
# supply it through the FAKE_DATA_TOKEN environment variable.
url_fk = "https://raw.githubusercontent.com/masoudmousavi/Sports-Analysis-with-GNNs/main/FakeData_EPL.csv"
filename_rl = 'dataset.csv'
filename_fk = 'fake.csv'


def _download(url, filename, params=None, timeout=30):
  """Download ``url`` into ``filename``.

  Args:
    url: Address to fetch.
    filename: Local path written (binary) when the server answers 200.
    params: Optional query parameters sent with the request.
    timeout: Seconds to wait before giving up on the request.

  Returns:
    The ``requests`` response, or ``None`` when the request itself failed.
    Failures are reported with a printed message instead of an exception.
  """
  try:
    response = requests.get(url, params=params, timeout=timeout)
  except requests.RequestException:
    print(f'Error downloading file at {url}')
    return None
  if response.status_code == 200:
    with open(filename, 'wb') as fp:
      fp.write(response.content)
  else:
    print(f'Error downloading file at {url}')
  return response


_fk_token = environ.get('FAKE_DATA_TOKEN')
req_rl = _download(url_epl, filename_rl)
req_fk = _download(
    url_fk,
    filename_fk,
    params={'token': _fk_token} if _fk_token else None
)

Error downloading file at https://raw.githubusercontent.com/masoudmousavi/Sports-Analysis-with-GNNs/main/FakeData_EPL.csv?token=ARGPVT5GGWQ4RGABHAF2TE3BKNB54


In [None]:
#@title Dataset Loading and Cleaning { form-width: "15px" }
# Load only the columns the pipeline uses.
dataset = pd.read_csv(
    filename_rl,
    usecols=['home_team', 'away_team', 'result', 'home_lineup', 'away_lineup'],
    encoding='latin-1'
)
# Matches with a missing lineup on either side are unusable: remember them in
# `corrupted`, drop them, and renumber the remaining rows from 0.
corrupted = dataset[dataset[['away_lineup', 'home_lineup']].isna().any(axis=1)]
dataset = dataset.drop(index=corrupted.index).reset_index(drop=True)


In [None]:
#@title Log { form-width: "15%" }
# Send INFO-and-above log records to model-logs.log; filemode='w' truncates
# the file each time this configuration takes effect.
logging.basicConfig(
    filename='model-logs.log',
    filemode='w',
    level=logging.INFO
)


In [None]:
#@title Hyperparameters File
# Every tunable value of the experiment; written to hyperparameters.json so
# the model/setup cell can load it from disk.
hyperparameters = {
    "learning_rate": 1e-3,
    "num_epochs": 200,
    "fc_dropout":0.01,
    "conv_dropout": 0.01,
    "emb_dropout": 0.01,
    "train_messaging_graph_size": 440,
    "val_messaging_graph_size": 440,
    "test_messaging_graph_size": 440,
    "iter_size": 10,
    "val_week_denom": 50,
    "test_week_denom": 60,
    "embedding_dim": 32,
    "conv_dims":[
          32,
          32,
          32,
          32
    ],
    "fully_connected_dims":[
              32,
              32
    ]
}

# The context manager closes the file even if json.dump raises.
with open('hyperparameters.json', 'w') as hp_file:
  json.dump(hyperparameters, hp_file)

In [None]:
#@title Model and Model Hyperparameters { form-width: "15%" }
# Run-time switches used by the batch-making cells.
log_supervision_matches = True
remove_supervision_links = False

# Reload the hyperparameters from disk.
with open('hyperparameters.json', 'r') as hp_file:
  hyperparameters = json.load(hp_file)

# Optimisation settings.
learning_rate, num_epochs = (
    hyperparameters[key] for key in ("learning_rate", "num_epochs")
)

# Dropout rates.
fc_dropout, conv_dropout, emb_dropout = (
    hyperparameters[key] for key in ("fc_dropout", "conv_dropout", "emb_dropout")
)

######################################## Scheme 4
# Window sizes, stride and the moduli that route a window to validation/test.
train_messaging_graph_size, val_messaging_graph_size, test_messaging_graph_size = (
    hyperparameters[key]
    for key in (
        "train_messaging_graph_size",
        "val_messaging_graph_size",
        "test_messaging_graph_size"
    )
)
iter_size, val_week_denom, test_week_denom = (
    hyperparameters[key] for key in ("iter_size", "val_week_denom", "test_week_denom")
)
######################################## Parameters

entities = gen_entities(dataset)

# The embedding table needs one row per entity id, hence max id + 1.
model = HeteroGNN(
    embedding_dims=(
        max(entities.values()) + 1,
        hyperparameters["embedding_dim"]
    ),
    conv_dims=hyperparameters["conv_dims"],
    fully_connected_dims=hyperparameters["fully_connected_dims"],
    dropout={
        "emb": emb_dropout,
        "conv": conv_dropout,
        "fc": fc_dropout
    }
).to(DEVICE)

print(model)

optimizer = Adam(model.parameters(), lr=learning_rate)
criterion = NLLLoss()

HeteroGNN(
  (embed): Embedding(1565, 32)
  (conv_layers): ModuleList(
    (0): HeteroConv(num_relations=9)
    (1): HeteroConv(num_relations=9)
    (2): HeteroConv(num_relations=9)
    (3): HeteroConv(num_relations=9)
  )
  (batch_norm_layers): ModuleList(
    (0): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fully_connected_layers): ModuleList(
    (0): Linear(in_features=64, out_features=32, bias=True)
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): Linear(in_features=32, out_features=3, bias=True)
  )
  (classifier): LogSoftmax(dim=None)
)


In [None]:
#@title Data Batch Maker { form-width: "15%" }
# Slide a window over the dataset in steps of `iter_size` and turn every
# window into one batch. A window whose end position `i` is a multiple of
# `val_week_denom` goes to validation, otherwise a multiple of
# `test_week_denom` goes to test, and everything else goes to training.
# NOTE(review): batch_gen as defined in this notebook takes required
# `messaging` and `supervision` arguments which none of the calls below pass --
# confirm the intended split of each window before running this cell.
train_batches = list()
val_batches = list()
test_batches = list()

for i in range(train_messaging_graph_size, dataset.shape[0], iter_size):
      if i % val_week_denom == 0:
        ######################## Validation ########################
        # Window covers index labels from_match..to_match (.loc is inclusive).
        from_match = i - val_messaging_graph_size
        to_match = i - 1
        model.mode = 'dev'

        validation_df = dataset.loc[from_match: to_match, :]
        val_graph_data = batch_gen(
              validation_df,
              entities=entities,
              remove_supervision_links=remove_supervision_links,
              log_supervision_matches=log_supervision_matches
          )
        val_batches.append(val_graph_data)

      elif i % test_week_denom == 0:
        ######################## Test ########################
        model.eval()
        model.mode = 'test'
        
        from_match = i - test_messaging_graph_size
        to_match = i - 1

        test_df = dataset.loc[from_match: to_match, :]
        test_graph_data = batch_gen(
            test_df,
            entities=entities,
            remove_supervision_links=remove_supervision_links,
            log_supervision_matches=log_supervision_matches
        )
        
        test_batches.append(test_graph_data)

      else:
        ######################## Train ########################

        from_match = i - train_messaging_graph_size
        to_match = i - 1
        model.mode = 'train'

        train_df = dataset.loc[from_match: to_match, :]
        train_graph_data = batch_gen(
            train_df,
            entities=entities,
            remove_supervision_links=remove_supervision_links,
            log_supervision_matches=log_supervision_matches
        )

        train_batches.append(train_graph_data)

In [None]:
#@title Data Batch Maker MODIFIED { form-width: "15%" }
# Same windowing as the plain "Data Batch Maker" cell, but collecting into the
# *_modified lists. Previously this cell created those lists and then appended
# to val_batches / test_batches / train_batches, which duplicated the entries
# of the original lists and left the *_modified lists empty.
# NOTE(review): batch_gen as defined in this notebook takes required
# `messaging` and `supervision` arguments which none of the calls below pass --
# confirm the intended split of each window before running this cell.
train_batches_modified = list()
val_batches_modified = list()
test_batches_modified = list()

for i in range(train_messaging_graph_size, dataset.shape[0], iter_size):
  if i % val_week_denom == 0:
    ######################## Validation ########################
    # Window covers index labels from_match..to_match (.loc is inclusive).
    from_match = i - val_messaging_graph_size
    to_match = i - 1
    model.mode = 'dev'

    validation_df = dataset.loc[from_match: to_match, :]
    val_graph_data = batch_gen(
          validation_df,
          entities=entities,
          remove_supervision_links=remove_supervision_links,
          log_supervision_matches=log_supervision_matches
      )
    val_batches_modified.append(val_graph_data)

  elif i % test_week_denom == 0:
    ######################## Test ########################
    model.eval()
    model.mode = 'test'

    from_match = i - test_messaging_graph_size
    to_match = i - 1

    test_df = dataset.loc[from_match: to_match, :]
    test_graph_data = batch_gen(
        test_df,
        entities=entities,
        remove_supervision_links=remove_supervision_links,
        log_supervision_matches=log_supervision_matches
    )

    test_batches_modified.append(test_graph_data)

  else:
    ######################## Train ########################

    from_match = i - train_messaging_graph_size
    to_match = i - 1
    model.mode = 'train'

    train_df = dataset.loc[from_match: to_match, :]
    train_graph_data = batch_gen(
        train_df,
        entities=entities,
        remove_supervision_links=remove_supervision_links,
        log_supervision_matches=log_supervision_matches
    )

    train_batches_modified.append(train_graph_data)

In [None]:
#@title Model Fitting Scheme 4 { form-width: "15%" }
# Fit the model: every epoch runs one optimisation step per training batch and,
# after each step, scores one validation batch (cycling through val_batches).
# Interrupting the kernel stops training but keeps the metrics gathered so far.
try:
  train_losses = list()
  train_accuracies = list()
  val_accuracies = list()

  # Fail early with a clear message instead of a ZeroDivisionError from
  # `index % len(val_batches)` or from the per-epoch averages.
  if not train_batches or not val_batches:
    raise ValueError('train_batches and val_batches must both be non-empty before fitting')

  for epoch in range(num_epochs):
    epoch_loss = 0
    val_correct = 0
    val_all = 0
    train_all = 0
    train_correct = 0

    for index, train_graph_data in enumerate(train_batches):
      ######################## Train ########################
      model.train()
      model.mode = 'train'

      train_batch_loss, train_batch_correct, train_batch_all = train(
          model=model,
          graph_data=train_graph_data,
          optimizer=optimizer,
          loss_fn=criterion
      )

      epoch_loss += train_batch_loss
      train_correct += train_batch_correct
      train_all += train_batch_all

      ######################## Validation ########################
      model.eval()
      model.mode = 'dev'

      val_batch_correct, val_batch_all = evaluate(
          model=model,
          graph_data=val_batches[index % len(val_batches)]
      )

      val_correct += val_batch_correct
      val_all += val_batch_all

    ########## end of epoch ###########
    # Compute every metric once and reuse it for printing, logging and history.
    train_cost = epoch_loss / train_all
    train_accuracy = train_correct * 100 / train_all
    val_accuracy = val_correct * 100 / val_all

    report = [
        f'{"="*32} Epoch {epoch + 1} {"="*32}',
        f'Train Loss:          {epoch_loss:.4f}',
        f'Train Cost:          {train_cost:.4f}',
        f'Train Accuracy:      {train_accuracy:.3f}%',
        f'Validation Accuracy: {val_accuracy:.3f}%'
    ]
    for line in report:
      print(line)
    for line in report:
      logging.info(line)

    train_losses.append(epoch_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

except KeyboardInterrupt:
  # Deliberate: a manual interrupt ends training without a traceback.
  pass

In [None]:
#@title Results
import matplotlib.pyplot as plt
import numpy as np

# One x position per recorded epoch.
t = np.array(list(range(len(train_losses))))
# Metric histories as arrays; only the loss (y1) is plotted below.
y1, y2, y3 = (
    np.array(history)
    for history in (train_losses, train_accuracies, val_accuracies)
)

fig = plt.gcf()
plt.plot(t, y1)
plt.title("Train Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
fig.set_size_inches(20, 10)

In [None]:
#@title Model Test { form-width: "25%" }

# Score the model on every held-out test batch and report overall accuracy.
test_correct, test_all = 0, 0

for test_graph_data in test_batches:
  model.eval()
  model.mode = 'test'

  test_batch_correct, test_batch_all = evaluate(
      graph_data=test_graph_data,
      model=model
  )
  test_correct, test_all = (
      test_correct + test_batch_correct,
      test_all + test_batch_all
  )

# Format once; the same line goes to stdout and to the log file.
test_accuracy_line = f'Test Accuracy: {test_correct * 100 / test_all:.3f}%'
print(test_accuracy_line)
logging.info('=' * 70)
logging.info(test_accuracy_line)



In [None]:
# @title Model Save
# Persist only the model's state_dict (not the whole module object) to model-2.pth.
torch.save(model.state_dict(), 'model-2.pth')