# Multimodal Network

Create GNN for each of the dimensions and later combine them into a single network.

1. Each dimension is a separate graph and gets its own GNN.
2. Dimensions are connected in a MLP layer.

#### Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Data
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import seaborn as sns
import matplotlib.pyplot as plt
import networkx as nx
import os
import country_converter as coco
import functools

# Set up device (is available use GPU to speed up computations)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

os.chdir('/home/jovyan/dlss-project')
print("Current working directory:", os.getcwd())

cuda
Current working directory: /home/jovyan/dlss-project


## Data

In [4]:
path_geo_edges = "data_collection/geography/edges_yearly_dist_enc.parquet"
path_geo_nodes = "data_collection/geography/nodes_enc.parquet"
df_geo_edges = pd.read_parquet(path_geo_edges)
df_geo_nodes = pd.read_parquet(path_geo_nodes)

path_pol_edges = "data_collection/political/data/edge_features.parquet"
path_pol_nodes = "data_collection/political/data/node_features.parquet"
df_pol_edges = pd.read_parquet(path_pol_edges)
df_pol_nodes = pd.read_parquet(path_pol_nodes)

path_cult_edges = "data_collection/culture/culture_edges.parquet"
path_cult_nodes = "data_collection/culture/culture_nodes.parquet"
df_cult_edges = pd.read_parquet(path_cult_edges)
df_cult_nodes = pd.read_parquet(path_cult_nodes)

path_lang_edges = "data_collection/culture/language_religion_edges.parquet"
path_lang_nodes = "data_collection/culture/language_religion_nodes.parquet"
df_lang_edges = pd.read_parquet(path_lang_edges)
df_lang_nodes = pd.read_parquet(path_lang_nodes)

path_eco_edges = "data_collection/economics/edges_economics.parquet"
path_eco_nodes = "data_collection/economics/nodes_economics.parquet"
df_eco_edges = pd.read_parquet(path_eco_edges)
df_eco_nodes = pd.read_parquet(path_eco_nodes)

In [5]:
# Pre-compute UN member countries once
@functools.lru_cache(maxsize=1)
def get_un_countries():
    """Cache UN member countries to avoid repeated lookups"""
    iso3 = coco.CountryConverter().data[coco.CountryConverter().data['UNmember'].notna()]['ISO3'].dropna().tolist()
    unnumeric = coco.CountryConverter().data[coco.CountryConverter().data['UNmember'].notna()]['UNcode'].dropna().tolist()
    return iso3, unnumeric


@functools.lru_cache(maxsize=1000)
def convert_country_code(country, target_format='UNnumeric'):
    """Cache country code conversions"""
    return coco.convert(names=country, to=target_format, not_found=None)

def create_data(edge_df, node_df, edge_country_a_col, edge_country_b_col, node_country_col, year_col="year"):
    # Get UN countries once
    uno_iso3_codes, uno_unnumeric_codes = get_un_countries()

    # Pre-filter dataframes
    edge_mask = edge_df[edge_country_a_col].isin(uno_iso3_codes) & edge_df[edge_country_b_col].isin(uno_iso3_codes)
    node_mask = node_df[node_country_col].isin(uno_iso3_codes)

    edge_df = edge_df[edge_mask].copy()
    node_df = node_df[node_mask].copy()

    # Ensure year is int before unique extraction
    edge_df[year_col] = edge_df[year_col].astype(int)
    node_df[year_col] = node_df[year_col].astype(int)

    print(f"Number of edges after filtering: {len(edge_df)}")
    print(f"Number of nodes after filtering: {len(node_df)}")

    # Vectorized country code conversion
    unique_countries_edges = pd.concat([edge_df[edge_country_a_col], edge_df[edge_country_b_col]]).unique()
    unique_countries_nodes = node_df[node_country_col].unique()
    all_unique_countries = np.unique(np.concatenate([unique_countries_edges, unique_countries_nodes]))

    country_to_id_map = dict(zip(
        all_unique_countries,
        coco.convert(all_unique_countries.tolist(), to='UNnumeric', not_found=None)
    ))

    # Apply mapping
    edge_df['country_id_a'] = edge_df[edge_country_a_col].map(country_to_id_map)
    edge_df['country_id_b'] = edge_df[edge_country_b_col].map(country_to_id_map)
    node_df['country_id'] = node_df[node_country_col].map(country_to_id_map)

    # Drop any rows where mapping failed (shouldn't if filters were correct)
    edge_df = edge_df.dropna(subset=['country_id_a', 'country_id_b'])
    node_df = node_df.dropna(subset=['country_id'])

    # Cast country_id to int
    edge_df['country_id_a'] = edge_df['country_id_a'].astype(int)
    edge_df['country_id_b'] = edge_df['country_id_b'].astype(int)
    node_df['country_id'] = node_df['country_id'].astype(int)

    data_dict = {}
    years = edge_df[year_col].unique()
    
    # limit years from 2000 to 2022
    years = [year for year in years if 2000 <= year <= 2022]

    for year in years:
        edge_df_year = edge_df[edge_df[year_col] == year]
        node_df_year = node_df[node_df[year_col] == year].copy()  # copy because we'll potentially modify
                
        # --- Edge features ---
        edge_features_cols = [
            col for col in edge_df_year.columns
            if col not in ['country_id_a', 'country_id_b', edge_country_a_col, edge_country_b_col, year_col]
        ]
        edge_features_df = edge_df_year[edge_features_cols].copy()

        # Boolean to int
        bool_cols_edges = edge_features_df.select_dtypes(include='bool').columns
        if len(bool_cols_edges) > 0:
            edge_features_df[bool_cols_edges] = edge_features_df[bool_cols_edges].astype(int)

        # Coerce all edge feature columns to numeric, fill NaN with 0
        edge_features_df = edge_features_df.apply(pd.to_numeric, errors='coerce').fillna(0)

        # --- Node features ---
        node_features_cols = [
            col for col in node_df_year.columns
            if col not in ['country_id', node_country_col, year_col]
        ]

        # Add missing countries (so each year has all UN countries)
        existing_countries = set(node_df_year['country_id'])
        missing_countries = set(uno_unnumeric_codes) - existing_countries
        
        # Boolean to int for node features
        bool_cols_nodes = node_df_year.select_dtypes(include='bool').columns
        if len(bool_cols_nodes) > 0:
            node_df_year[bool_cols_nodes] = node_df_year[bool_cols_nodes].astype(int)
            
        # add column to state if country does exist in this year/dataset
        node_df_year['exists'] = 1

        if missing_countries:
            # Build missing rows once
            for country in missing_countries:
                missing_row = {col: 0 for col in node_features_cols}
                missing_row['country_id'] = country
                missing_row['year'] = year
                missing_row['exists'] = 0
                row_df = pd.DataFrame([missing_row])
                node_df_year = pd.concat([node_df_year, row_df], ignore_index=True)

        # Ensure node_df_year is sorted by country_id and align features accordingly
        node_df_year_sorted = node_df_year.sort_values(by='country_id').reset_index(drop=True)
        
        # select only feature columns
        node_df_year_features = node_df_year_sorted[node_features_cols].copy()
        
        country_id_tensor = torch.tensor(node_df_year_sorted['country_id'].values, dtype=torch.long)

        # --- Tensor creation ---
        edge_features_array = edge_features_df.values.astype(np.float32)
        edge_attr = torch.from_numpy(edge_features_array)

        node_features_tensor = torch.tensor(node_df_year_features.values, dtype=torch.float32)
        edge_index = torch.tensor(
            edge_df_year[['country_id_a', 'country_id_b']].values.T,
            dtype=torch.long
        )
        data = Data(x=node_features_tensor, edge_index=edge_index, edge_attr=edge_attr, country_id = country_id_tensor)
        data_dict[year] = data

    return data_dict

In [10]:
df_eco_nodes

Unnamed: 0,ISO3,year,gdp,gdp_per_capita,gov_spending
0,ABW,2000,-0.170822,0.852447,-1.123562
1,AFG,2000,-0.170034,-0.621073,0.189950
2,AGO,2000,-0.167353,-0.593135,0.269855
3,ALB,2000,-0.170004,-0.550258,0.402861
4,AND,2000,-0.171032,0.933591,-1.123562
...,...,...,...,...,...
4779,WSM,2022,-0.207001,-0.547226,0.838002
4780,YEM,2022,-0.203019,-0.653621,-1.183771
4781,ZAF,2022,0.000000,0.000000,0.000000
4782,ZMB,2022,-0.201570,-0.626434,0.311007


In [9]:
geo_data = create_data(df_geo_edges, df_geo_nodes, 
                        edge_country_a_col='iso_o', edge_country_b_col='iso_d',
                        node_country_col='code_3', year_col='year')

# pol_data = create_data(df_pol_edges, df_pol_nodes,
#                         edge_country_a_col='country_id_a', edge_country_b_col='country_id_b',
#                         node_country_col='country_id', year_col='year')

cul_data = create_data(df_cult_edges, df_cult_nodes,
                        edge_country_a_col='ISO3_a', edge_country_b_col='ISO3_b',
                        node_country_col='ISO3', year_col='year')

lang_data = create_data(df_lang_edges, df_lang_nodes,
                        edge_country_a_col='country_a', edge_country_b_col='country_b',
                        node_country_col='ISO3', year_col='year')

eco_data = create_data(df_eco_edges, df_eco_nodes,
                        edge_country_a_col='src_ISO3', edge_country_b_col='tgt_ISO3',
                        node_country_col='ISO3', year_col='year')

Number of edges after filtering: 782920
Number of nodes after filtering: 4402
Number of edges after filtering: 54716
Number of nodes after filtering: 2548
Number of edges after filtering: 852288
Number of nodes after filtering: 4186
Number of edges after filtering: 589247
Number of nodes after filtering: 4393


In [16]:
df_cult_nodes

Unnamed: 0,ISO3,year,pdi,idv,mas,uai,ltowvs,ivr
0,ALB,2000,68.0,25.0,57.0,29.0,61.0,15.0
1,ALB,2001,68.0,25.0,57.0,29.0,61.0,15.0
2,ALB,2002,68.0,25.0,57.0,29.0,61.0,15.0
3,ALB,2003,68.0,25.0,57.0,29.0,61.0,15.0
4,ALB,2004,68.0,25.0,57.0,29.0,61.0,15.0
...,...,...,...,...,...,...,...,...
2621,ZWE,2021,70.0,46.0,53.0,68.0,15.0,28.0
2622,ZWE,2022,70.0,46.0,53.0,68.0,15.0,28.0
2623,ZWE,2023,70.0,46.0,53.0,68.0,15.0,28.0
2624,ZWE,2024,70.0,46.0,53.0,68.0,15.0,28.0


In [11]:
# TODO extract GINI and add to complete data

data_years = {}

for year in geo_data.keys():
    network_data = {
        'geography': geo_data[year],
        'political': pol_data[year] if 'pol_data' in locals() else None,
        'culture': cul_data[year],
        'language': lang_data[year] if 'lang_data' in locals() else None,
        'economy': eco_data[year],  # Placeholder for future economy data
        'GINI': None  # Placeholder for GINI data
    }
    data_years[year] = network_data
    
# get the dimesions of each network: number of node_features, number of edge-features
network_dimesions = {
    'geography': {
        'node_features': geo_data[2000].x.shape[1],
        'edge_features': geo_data[2000].edge_attr.shape[1]
    },
    'political': {
        'node_features': pol_data[2000].x.shape[1],
        'edge_features': pol_data[2000].edge_attr.shape[1]
    } if 'pol_data' in locals() else None,
    'culture': {
        'node_features': cul_data[2000].x.shape[1],
        'edge_features': cul_data[2000].edge_attr.shape[1]
    },
    'language': {
        'node_features': lang_data[2000].x.shape[1],
        'edge_features': lang_data[2000].edge_attr.shape[1]
    } if 'lang_data' in locals() else None,
    'economy': None,  # Placeholder for future economy data
}

KeyError: np.int64(2015)

## Model Setup

In [None]:
# GNN model for each of the different networks

class DimensionModel(torch.nn.Module):
    def __init__(self, num_node_features, num_edge_features, hidden_size=32, target_size=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_node_features = num_node_features
        self.num_edge_features = num_edge_features
        self.target_size = target_size
        self.convs = [GATv2Conv(self.num_node_features, self.hidden_size, edge_dim = num_edge_features),
                      GATv2Conv(self.hidden_size, self.hidden_size, edge_dim = num_edge_features)]

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        for conv in self.convs[:-1]:
            x = conv(x, edge_index, edge_attr=edge_attr) # adding edge features here!
            x = F.relu(x)
            x = F.dropout(x, training=self.training)
        x = self.convs[-1](x, edge_index, edge_attr=edge_attr)

        return x
    
    
class MultiModalGNNModel(nn.Module):
    def __init__(self, 
                 input_dims,  # list of (num_node_features, num_edge_features) per modality
                 hidden_size=32, 
                 fusion_hidden_size=64, 
                 final_output_size=1):
        super().__init__()
        
        self.geo_model = DimensionModel(input_dims['geography']['n_node_features'], 
                                        input_dims['geography']['n_edge_features'], hidden_size)
        self.pol_model = DimensionModel(input_dims['political']['n_node_features'], 
                                        input_dims['political']['n_edge_features'], hidden_size)
        self.cult_model = DimensionModel(input_dims['culture']['n_node_features'], 
                                        input_dims['culture']['n_edge_features'], hidden_size)
        self.lang_model = DimensionModel(input_dims['language']['n_node_features'], 
                                        input_dims['language']['n_edge_features'], hidden_size)

        self.fusion_input_size = hidden_size * len(input_dims)

        self.mlp = nn.Sequential(
            nn.Linear(self.fusion_input_size, fusion_hidden_size),
            nn.ReLU(),
            nn.Linear(fusion_hidden_size, final_output_size)
        )

    def forward(self, data_dict):
        geo_emb = self.geo_model(data_dict['geography'])
        pol_emb = self.pol_model(data_dict['political'])
        cult_emb = self.cult_model(data_dict['culture'])
        lang_emb = self.lang_model(data_dict['language'])
        
        embeddings = [geo_emb, pol_emb, cult_emb, lang_emb]

        fused = torch.cat(embeddings, dim=-1)  # shape: [num_nodes, hidden_size * num_modalities]
        output = self.mlp(fused)               # shape: [num_nodes, final_output_size]
        return output

## Training Setup