In [1]:
import pandas as pd
import numpy as np
import pickle as pkl
import datetime as datetime
from tqdm import tqdm
import networkx as nx
import random

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from src.utils.CreateFeatures import CreateFeatures
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader
from sklearn.metrics import f1_score, fbeta_score
from src.utils.functions import dist_labels_to_changepoint_labels, dist_labels_to_changepoint_labels_adjusted
from sklearn.model_selection import train_test_split
from src.pygcn.batched_model import Model

import torch
import torch.nn as nn
import torch
import torch_geometric.data as data
from src.synthetic_experiments.sample import sample_pairs, sample_pairs_in_window
import itertools

In [2]:
years = range(1962,2019)

In [3]:
all_nodes = ['ABW', 'AFG', 'AGO', 'ALB', 'AND', 'ARE', 'ARG', 'ARM', 'ASM',
       'ATG', 'AUS', 'AUT', 'AZE', 'BDI', 'BEL', 'BEN', 'BFA', 'BGD',
       'BGR', 'BHR', 'BHS', 'BIH', 'BLR', 'BLZ', 'BMU', 'BOL', 'BRA',
       'BRB', 'BRN', 'BTN', 'BWA', 'CAF', 'CAN', 'CHE', 'CHL', 'CHN',
       'CIV', 'CMR', 'COD', 'COG', 'COL', 'COM', 'CPV', 'CRI', 'CUB',
       'CUW', 'CYM', 'CYP', 'CZE', 'DEU', 'DMA', 'DNK', 'DOM', 'DZA',
       'ECU', 'EGY', 'ESP', 'EST', 'ETH', 'FIN', 'FJI', 'FRA', 'FSM',
       'GAB', 'GBR', 'GEO', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ', 'GRC',
       'GRD', 'GRL', 'GTM', 'GUM', 'GUY', 'HKG', 'HND', 'HRV', 'HTI',
       'HUN', 'IDN', 'IND', 'IRL', 'IRN', 'IRQ', 'ISL', 'ISR', 'ITA',
       'JAM', 'JOR', 'JPN', 'KAZ', 'KEN', 'KGZ', 'KHM', 'KNA', 'KOR',
       'KWT', 'LAO', 'LBN', 'LBR', 'LBY', 'LCA', 'LKA', 'LSO', 'LTU',
       'LUX', 'LVA', 'MAC', 'MAR', 'MDA', 'MDG', 'MDV', 'MEX', 'MHL',
       'MKD', 'MLI', 'MLT', 'MMR', 'MNE', 'MNG', 'MNP', 'MOZ', 'MRT',
       'MUS', 'MWI', 'MYS', 'NAM', 'NER', 'NGA', 'NIC', 'NLD', 'NOR',
       'NPL', 'NRU', 'NZL', 'OMN', 'PAK', 'PAN', 'PER', 'PHL', 'PLW',
       'PNG', 'POL', 'PRT', 'PRY', 'PSE', 'PYF', 'QAT', 'ROU', 'RUS',
       'RWA', 'SAU', 'SDN', 'SEN', 'SGP', 'SLB', 'SLE', 'SLV', 'SMR',
       'SRB', 'SSD', 'STP', 'SUR', 'SVK', 'SVN', 'SWE', 'SWZ', 'SXM',
       'SYC', 'SYR', 'TCD', 'TGO', 'THA', 'TJK', 'TKM', 'TLS', 'TON',
       'TTO', 'TUN', 'TUR', 'TUV', 'TZA', 'UGA', 'UKR', 'URY', 'USA',
       'UZB', 'VCT', 'VEN', 'VNM', 'VUT', 'WSM', 'YEM', 'ZAF', 'ZMB',
       'ZWE']

In [4]:
with open("src/pygcn/all_graphs.pkl", "rb") as f:         
    all_graphs = pkl.load(f)

In [5]:
crisis_years = [1962, 1967, 1973, 1978, 1981, 1989, 1993, 1996, 2002, 2007, 2012, 2014, 2016]
phases = []
p = -1
for i in range(1962,2019):
    if i in crisis_years:
        p += 1
    phases.append(p)

In [6]:
def train_epoch(model, loss_fn, optimiser, training_dataloader):
    total, correct = 0, 0
    losses = []
    for batch in training_dataloader:
        g1, g2, labels = batch
        labels = labels.float().unsqueeze(1)
        predictions = model(g1, g2)
        labels = labels.squeeze(1)
        loss = loss_fn(predictions, labels)
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()
        losses.append(loss.detach())
        correct += torch.sum((predictions>0.5).long() == labels).item()
        total += len(labels)
    accuracy = correct / total
    return torch.mean(torch.tensor(losses)).item(), accuracy

def validate(model, loss_fn, validation_dataloader):
    total, correct = 0, 0
    losses = []
    with torch.no_grad():
        for batch in validation_dataloader:
            g1, g2, labels = batch
            labels = labels.float().unsqueeze(1)
            labels = labels.squeeze(1)
            predictions = model(g1, g2)
            loss = loss_fn(predictions, labels)
            losses.append(loss.detach())
            correct += torch.sum((predictions>0.5).long() == labels).item()
            total += len(labels)
    accuracy = correct / total
    return torch.mean(torch.tensor(losses)).item(), accuracy

In [7]:
import torch_geometric as pyg
class MergeDataset(Dataset):

    def __init__(self, pairs) -> None:
        super().__init__()

        prog_bar = tqdm(desc="Building dataset.")

        self.dataset = []
        for p in pairs:

            p[0].edge_index = (p[0].edge_index).int()
            p[1].edge_index = (p[1].edge_index).int()
            self.dataset.append((
                p[0],
                p[1],
                p[2]
            ))
            prog_bar.update(1)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> tuple[pyg.data.Data, pyg.data.Data, torch.Tensor]:
        return self.dataset[index]

In [8]:
with open("feature_dicts/mis_norm.pkl", "rb") as f:
    feat_dict = pkl.load(f)

In [9]:
def add_features(years, graphs, feat_dict, dim):

    zeros = torch.zeros(dim)

    for i in range(len(years)):
        new_x = torch.empty(0, dim)
        year = years[i]

        feat_dict_year = feat_dict[year].combined_features

        for j, country in enumerate(all_nodes):
            if j == 0:
                new_x = torch.stack([zeros])

            elif country in feat_dict_year["country_code"].values:
                tensor_before = graphs[i].x[j]
                country_row = feat_dict_year[feat_dict_year["country_code"] == country]
                country_row = country_row.drop(columns = ["country_code", "current_gdp_growth"])
                row_values = country_row.values.tolist()
                row_tensor = torch.tensor(row_values)[0]
                combined_values = torch.cat((tensor_before, row_tensor))

                new_x = torch.cat((new_x, combined_values.unsqueeze(0)), dim=0)

            else:
                new_x = torch.cat((new_x, zeros.unsqueeze(0)), dim=0)

        graphs[i].x = new_x

    return graphs

In [10]:
with open("src/pygcn/all_graphs.pkl", "rb") as f:         
    all_graphs = pkl.load(f)

In [11]:
all_graphs = add_features(years, all_graphs, feat_dict, 27)

In [12]:
#For window sampling
train_graphs = all_graphs[:34]
val_graphs = all_graphs[34:45]
test_graphs = all_graphs[45:]

labels = dist_labels_to_changepoint_labels(phases)
graph_pairs_train = sample_pairs(train_graphs,labels[:34])
graph_pairs_val = sample_pairs(val_graphs,labels[34:45])
graph_pairs_test = sample_pairs(test_graphs,labels[45:])

280 positive and 168 negative examples
27 positive and 18 negative examples
33 positive and 32 negative examples


In [13]:
import torch_geometric as pyg
def collate_fn(batch):
    return (
        pyg.data.Batch.from_data_list([triple[0] for triple in batch]),
        pyg.data.Batch.from_data_list([triple[1] for triple in batch]),
        torch.stack([triple[2] for triple in batch])
    )

training_data = MergeDataset(graph_pairs_train)
validation_data = MergeDataset(graph_pairs_val)
training_dataloader = DataLoader(training_data, batch_size=32, collate_fn=collate_fn, shuffle=True)
validation_dataloader =  DataLoader(validation_data, batch_size=16, collate_fn=collate_fn)

Building dataset.: 561it [00:00, 66125.35it/s]
Building dataset.: 55it [00:00, 17909.07it/s]


In [18]:
model = Model()
loss_fn = nn.BCEWithLogitsLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(12):
    train_loss, train_accuracy = train_epoch(model, loss_fn, optimiser, training_dataloader)
    valid_loss, valid_accuracy = validate(model, loss_fn, validation_dataloader)
    print(
        epoch, 
        round(train_loss, 4), 
        round(train_accuracy, 2),
        round(valid_loss, 4), 
        round(valid_accuracy, 2),
        sep='\t'
    )

0	0.6207	0.66	0.6441	0.65
1	0.606	0.69	0.6428	0.65
2	0.6066	0.68	0.6531	0.65
3	0.6006	0.68	0.6581	0.58
4	0.6008	0.68	0.6416	0.6
5	0.5951	0.64	0.6505	0.56
6	0.5872	0.64	0.6456	0.56
7	0.5804	0.67	0.65	0.56
8	0.582	0.68	0.6596	0.56
9	0.5721	0.67	0.6355	0.58
10	0.564	0.7	0.6203	0.64
11	0.5321	0.77	0.6679	0.51
