# Multimodal Network

Create a GNN for each of the dimensions and later combine them into a single network.

1. Each dimension is a separate graph and gets its own GNN.
2. The dimensions are connected in an MLP layer.

In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import seaborn as sns
import matplotlib.pyplot as plt
import networkx as nx
import os
import country_converter as coco
import functools

# Set up device (use the GPU if one is available to speed up computations)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Run from the project root so that the relative data paths resolve.
# The root can be overridden with the DLSS_PROJECT_ROOT environment variable;
# when it is unset, the original JupyterHub location is used.
PROJECT_ROOT = os.environ.get("DLSS_PROJECT_ROOT", "/home/jovyan/dlss-project")
os.chdir(PROJECT_ROOT)
print("Current working directory:", os.getcwd())

cuda
Current working directory: /home/jovyan/dlss-project


## Data

In [45]:
# Parquet locations, relative to the current working directory.
# Geography
path_geo_edges = "data_collection/geography/edges_yearly_dist_enc.parquet"
path_geo_nodes = "data_collection/geography/nodes_enc.parquet"
# Political
path_pol_edges = "data_collection/political/data/edge_features.parquet"
path_pol_nodes = "data_collection/political/data/node_features.parquet"
# Culture
path_cult_edges = "data_collection/culture/culture_edges.parquet"
path_cult_nodes = "data_collection/culture/culture_nodes.parquet"
# Language / religion
path_lang_edges = "data_collection/culture/language_religion_edges.parquet"
path_lang_nodes = "data_collection/culture/language_religion_nodes.parquet"

# Read the edge table and the node table of each dimension.
df_geo_edges, df_geo_nodes = pd.read_parquet(path_geo_edges), pd.read_parquet(path_geo_nodes)
df_pol_edges, df_pol_nodes = pd.read_parquet(path_pol_edges), pd.read_parquet(path_pol_nodes)
df_cult_edges, df_cult_nodes = pd.read_parquet(path_cult_edges), pd.read_parquet(path_cult_nodes)
df_lang_edges, df_lang_nodes = pd.read_parquet(path_lang_edges), pd.read_parquet(path_lang_nodes)

In [38]:
# Pre-compute UN member countries once
@functools.lru_cache(maxsize=1)
def get_un_countries():
    """Return the ISO3 codes of all countries flagged as UN members.

    The lookup table comes from ``coco.CountryConverter().data``; rows with a
    non-null ``UNmember`` entry are kept and their non-null ``ISO3`` values
    are returned.

    Returns:
        list: ISO3 codes. The result is cached, so every call returns the
        same list object -- treat it as read-only.
    """
    # Build the converter a single time instead of once per attribute access.
    country_table = coco.CountryConverter().data
    un_members = country_table[country_table['UNmember'].notna()]
    return un_members['ISO3'].dropna().tolist()

@functools.lru_cache(maxsize=1000)
def convert_country_code(country, target_format='UNnumeric'):
    """Convert a country identifier to ``target_format`` with memoisation.

    Args:
        country: Identifier handed to ``coco.convert`` as ``names``. It must be
            hashable because the call is wrapped in ``lru_cache``.
        target_format: Classification name passed to ``coco.convert`` as ``to``.

    Returns:
        Whatever ``coco.convert`` returns for this input (``not_found=None``
        is passed for unmatched names).
    """
    converted = coco.convert(names=country, to=target_format, not_found=None)
    return converted

def create_data(edge_df, node_df, edge_country_a_col, edge_country_b_col, node_country_col, year_col="year"):
    """Build one torch_geometric ``Data`` graph per year from edge and node tables.

    Both tables are first restricted to rows whose country codes appear in
    ``get_un_countries()``. The remaining codes are converted to UN numeric ids
    with ``coco.convert``, and for every distinct value of ``edge_df[year_col]``
    a ``Data`` object is assembled from that year's edge and node rows.

    Args:
        edge_df: DataFrame with one row per (country a, country b, year).
        node_df: DataFrame with one row per (country, year).
        edge_country_a_col: Edge column holding the first country's code.
        edge_country_b_col: Edge column holding the second country's code.
        node_country_col: Node column holding the country's code.
        year_col: Name of the year column; it must exist in both tables.

    Returns:
        dict: Maps each year value found in ``edge_df[year_col]`` to a ``Data``
        with ``x`` (node features, rows ordered by UN numeric id),
        ``edge_index`` (shape 2 x E, holding UN numeric ids) and ``edge_attr``
        (edge features with missing values replaced by 0).
    """
    # Get UN countries once
    uno_iso3_codes = get_un_countries()

    # Keep only rows whose countries are UN members
    edge_mask = edge_df[edge_country_a_col].isin(uno_iso3_codes) & edge_df[edge_country_b_col].isin(uno_iso3_codes)
    node_mask = node_df[node_country_col].isin(uno_iso3_codes)

    # Copy so that the id columns added below never touch the caller's frames
    edge_df = edge_df[edge_mask].copy()
    node_df = node_df[node_mask].copy()

    print(f"Number of edges after filtering: {len(edge_df)}")
    print(f"Number of nodes after filtering: {len(node_df)}")

    # Collect every distinct country code so the conversion runs as one batch
    unique_countries_edges = pd.concat([edge_df[edge_country_a_col], edge_df[edge_country_b_col]]).unique()
    unique_countries_nodes = node_df[node_country_col].unique()
    all_unique_countries = np.unique(np.concatenate([unique_countries_edges, unique_countries_nodes]))

    converted_ids = coco.convert(all_unique_countries.tolist(), to='UNnumeric', not_found=None)
    # coco.convert can hand back a bare value instead of a list when only one
    # name is passed; wrap it so the zip below always pairs code -> id.
    if not isinstance(converted_ids, list):
        converted_ids = [converted_ids]
    country_to_id_map = dict(zip(all_unique_countries, converted_ids))

    # Apply mapping
    edge_df['country_id_a'] = edge_df[edge_country_a_col].map(country_to_id_map)
    edge_df['country_id_b'] = edge_df[edge_country_b_col].map(country_to_id_map)
    node_df['country_id'] = node_df[node_country_col].map(country_to_id_map)

    node_df = node_df.sort_values(by='country_id')

    # The feature columns do not depend on the year, so determine them once.
    edge_features_cols = [col for col in edge_df.columns
                          if col not in ['country_id_a', 'country_id_b', edge_country_a_col, edge_country_b_col, year_col]]
    node_features_cols = [col for col in node_df.columns
                          if col not in ['country_id', node_country_col, year_col]]

    # Compare years as strings on both sides so node rows are found whether the
    # node table stores the year as text or as a number.
    node_year_labels = node_df[year_col].astype(str)

    data_dict = {}
    years = edge_df[year_col].unique()

    for year in years:
        edge_df_year = edge_df[edge_df[year_col] == year]
        node_df_year = node_df[node_year_labels == str(year)]

        # Edge features: missing values become 0, booleans become 0/1
        edge_features_df = edge_df_year[edge_features_cols].fillna(0)
        bool_cols = edge_features_df.select_dtypes(include='bool').columns
        if len(bool_cols) > 0:
            edge_features_df[bool_cols] = edge_features_df[bool_cols].astype(int)

        # Node features: booleans become 0/1 (missing values are left as they are)
        node_df_year_features = node_df_year[node_features_cols].copy()
        bool_cols_nodes = node_df_year_features.select_dtypes(include='bool').columns
        if len(bool_cols_nodes) > 0:
            # Column-by-column cast avoids the pandas FutureWarning noted originally
            for col in bool_cols_nodes:
                node_df_year_features[col] = node_df_year_features[col].astype(int)

        edge_features_array = edge_features_df.values.astype(np.float32)
        edge_attr = torch.from_numpy(edge_features_array)

        node_features = torch.tensor(node_df_year_features.values, dtype=torch.float)
        # NOTE(review): edge_index carries UN numeric ids, not row positions in
        # `node_features`. Graph layers that index x by edge_index need ids in
        # 0..num_nodes-1 -- confirm a remapping happens before message passing.
        edge_index = torch.tensor(edge_df_year[['country_id_a', 'country_id_b']].values.T, dtype=torch.long)

        data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)
        data_dict[year] = data

    return data_dict

In [44]:
# Preview the first five rows of the political edge table to inspect its columns
df_pol_edges.head(n=5)

Unnamed: 0,state1_convert,state2_convert,dyad_st_year,dyad_end_year,left_censor,right_censor,defense,neutrality,nonaggression,entente,asymmetric
0,Czechia,Russia,1945,1989,0,0,1,1,0,1.0,0
1,United States,Cuba,1945,1947,0,0,1,0,0,1.0,0
2,United States,Haiti,1945,1947,0,0,1,0,0,1.0,0
3,United States,Dominican Republic,1945,1947,0,0,1,0,0,1.0,0
4,United States,Mexico,1945,1947,0,0,1,0,0,1.0,0


In [None]:
# Build the per-year geography graphs
geo_data = create_data(
    edge_df=df_geo_edges,
    node_df=df_geo_nodes,
    edge_country_a_col='iso_o',
    edge_country_b_col='iso_d',
    node_country_col='code_3',
    year_col='year',
)

# Political graphs are not built yet.
# NOTE(review): df_pol_edges.head() shows state1_convert / state2_convert and
# dyad_st_year / dyad_end_year rather than the column names used below --
# confirm the column names before enabling this call.
# pol_data = create_data(df_pol_edges, df_pol_nodes,
#                         edge_country_a_col='country_id_a', edge_country_b_col='country_id_b',
#                         node_country_col='country_id', year_col='year')

Number of edges after filtering: 782920
Number of nodes after filtering: 4402


In [40]:
# Display the dict of per-year geography graphs (year -> Data) as a sanity check
geo_data

{np.int64(2000): Data(x=[189, 28], edge_index=[2, 34040], edge_attr=[34040, 16]),
 np.int64(2001): Data(x=[189, 28], edge_index=[2, 34040], edge_attr=[34040, 16]),
 np.int64(2002): Data(x=[189, 28], edge_index=[2, 34040], edge_attr=[34040, 16]),
 np.int64(2003): Data(x=[189, 28], edge_index=[2, 34040], edge_attr=[34040, 16]),
 np.int64(2004): Data(x=[189, 28], edge_index=[2, 34040], edge_attr=[34040, 16]),
 np.int64(2005): Data(x=[189, 28], edge_index=[2, 34040], edge_attr=[34040, 16]),
 np.int64(2006): Data(x=[191, 28], edge_index=[2, 34040], edge_attr=[34040, 16]),
 np.int64(2007): Data(x=[191, 28], edge_index=[2, 34040], edge_attr=[34040, 16]),
 np.int64(2008): Data(x=[191, 28], edge_index=[2, 34040], edge_attr=[34040, 16]),
 np.int64(2009): Data(x=[191, 28], edge_index=[2, 34040], edge_attr=[34040, 16]),
 np.int64(2010): Data(x=[191, 28], edge_index=[2, 34040], edge_attr=[34040, 16]),
 np.int64(2011): Data(x=[191, 28], edge_index=[2, 34040], edge_attr=[34040, 16]),
 np.int64(2012):