# Multimodal Network

Create GNN for each of the dimensions and later combine them into a single network.

1. Each dimension is a separate graph and gets its own GNN.
2. Dimensions are connected in a MLP layer.

In [68]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import seaborn as sns
import matplotlib.pyplot as plt
import networkx as nx
import os
import country_converter as coco
import functools

# Set up device (is available use GPU to speed up computations)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

os.chdir('/home/jovyan/dlss-project')
print("Current working directory:", os.getcwd())

cuda
Current working directory: /home/jovyan/dlss-project


## Data

In [69]:
path_geo_edges = "data_collection/geography/edges_yearly_dist_enc.parquet"
path_geo_nodes = "data_collection/geography/nodes_enc.parquet"
df_geo_edges = pd.read_parquet(path_geo_edges)
df_geo_nodes = pd.read_parquet(path_geo_nodes)

path_pol_edges = "data_collection/political/data/edge_features.parquet"
path_pol_nodes = "data_collection/political/data/node_features.parquet"
df_pol_edges = pd.read_parquet(path_pol_edges)
df_pol_nodes = pd.read_parquet(path_pol_nodes)

path_cult_edges = "data_collection/culture/culture_edges.parquet"
path_cult_nodes = "data_collection/culture/culture_nodes.parquet"
df_cult_edges = pd.read_parquet(path_cult_edges)
df_cult_nodes = pd.read_parquet(path_cult_nodes)

path_lang_edges = "data_collection/culture/language_religion_edges.parquet"
path_lang_nodes = "data_collection/culture/language_religion_nodes.parquet"
df_lang_edges = pd.read_parquet(path_lang_edges)
df_lang_nodes = pd.read_parquet(path_lang_nodes)

In [122]:
coco.CountryConverter().data[coco.CountryConverter().data['UNmember'].notna()]

Unnamed: 0,APEC,BASIC,BRIC,CC41,CIS,Cecilia2050,Continent_7,DACcode,EEA,EU,...,UNcode,UNmember,UNregion,WIOD,ccTLD,continent,name_official,name_short,obsolete,regex
0,,,,Rest of World,,RoW,Asia,625.0,,,...,4,1946,Southern Asia,RoW,af,Asia,Islamic Republic of Afghanistan,Afghanistan,,afghan
2,,,,Rest of World,,RoW,Europe,71.0,,,...,8,1955,Southern Europe,RoW,al,Europe,Republic of Albania,Albania,,albania
3,,,,Rest of World,,RoW,Africa,130.0,,,...,12,1962,Northern Africa,RoW,dz,Africa,People's Democratic Republic of Algeria,Algeria,,algeria
5,,,,Rest of World,,RoW,Europe,,,,...,20,1993,Southern Europe,RoW,ad,Europe,Principality of Andorra,Andorra,,andorra
6,,,,Rest of World,,RoW,Africa,225.0,,,...,24,1976,Middle Africa,RoW,ao,Africa,Republic of Angola,Angola,,angola
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
243,,,,Rest of World,,RoW,South America,463.0,,,...,862,1945,South America,RoW,ve,America,Bolivarian Republic of Venezuela,Venezuela,,venezuela
244,APEC,,,Rest of World,,RoW,Asia,769.0,,,...,704,1977,South-eastern Asia,RoW,vn,Asia,Socialist Republic of Vietnam,Vietnam,,^((?!n|s|.*republic)|(?=.*socialist)).*viet.?n...
247,,,,Rest of World,,RoW,Asia,580.0,,,...,887,1947,Western Asia,RoW,ye,Asia,Republic of Yemen,Yemen,,yemen
248,,,,Rest of World,,RoW,Africa,288.0,,,...,894,1964,Eastern Africa,RoW,zm,Africa,Republic of Zambia,Zambia,,zambia|northern.?rhodesia


In [125]:
coco.convert(['AUS'], to='UNnumeric', not_found=None)

36

In [144]:
# Pre-compute UN member countries once
@functools.lru_cache(maxsize=1)
def get_un_countries():
    """Cache UN member countries to avoid repeated lookups"""
    iso3 = coco.CountryConverter().data[coco.CountryConverter().data['UNmember'].notna()]['ISO3'].dropna().tolist()
    unnumeric = coco.CountryConverter().data[coco.CountryConverter().data['UNmember'].notna()]['UNcode'].dropna().tolist()
    return iso3, unnumeric


@functools.lru_cache(maxsize=1000)
def convert_country_code(country, target_format='UNnumeric'):
    """Cache country code conversions"""
    return coco.convert(names=country, to=target_format, not_found=None)

def create_data(edge_df, node_df, edge_country_a_col, edge_country_b_col, node_country_col, year_col="year"):
    # Get UN countries once
    uno_iso3_codes, uno_unnumeric_codes = get_un_countries()

    # Pre-filter dataframes
    edge_mask = edge_df[edge_country_a_col].isin(uno_iso3_codes) & edge_df[edge_country_b_col].isin(uno_iso3_codes)
    node_mask = node_df[node_country_col].isin(uno_iso3_codes)

    edge_df = edge_df[edge_mask].copy()
    node_df = node_df[node_mask].copy()

    # Ensure year is int before unique extraction
    edge_df[year_col] = edge_df[year_col].astype(int)
    node_df[year_col] = node_df[year_col].astype(int)

    print(f"Number of edges after filtering: {len(edge_df)}")
    print(f"Number of nodes after filtering: {len(node_df)}")

    # Vectorized country code conversion
    unique_countries_edges = pd.concat([edge_df[edge_country_a_col], edge_df[edge_country_b_col]]).unique()
    unique_countries_nodes = node_df[node_country_col].unique()
    all_unique_countries = np.unique(np.concatenate([unique_countries_edges, unique_countries_nodes]))

    country_to_id_map = dict(zip(
        all_unique_countries,
        coco.convert(all_unique_countries.tolist(), to='UNnumeric', not_found=None)
    ))

    # Apply mapping
    edge_df['country_id_a'] = edge_df[edge_country_a_col].map(country_to_id_map)
    edge_df['country_id_b'] = edge_df[edge_country_b_col].map(country_to_id_map)
    node_df['country_id'] = node_df[node_country_col].map(country_to_id_map)

    # Drop any rows where mapping failed (shouldn't if filters were correct)
    edge_df = edge_df.dropna(subset=['country_id_a', 'country_id_b'])
    node_df = node_df.dropna(subset=['country_id'])

    # Cast country_id to int
    edge_df['country_id_a'] = edge_df['country_id_a'].astype(int)
    edge_df['country_id_b'] = edge_df['country_id_b'].astype(int)
    node_df['country_id'] = node_df['country_id'].astype(int)

    data_dict = {}
    years = edge_df[year_col].unique()
    
    # limit years from 2000 to 2022
    years = [year for year in years if 2000 <= year <= 2022]

    for year in years:
        edge_df_year = edge_df[edge_df[year_col] == year]
        node_df_year = node_df[node_df[year_col] == year].copy()  # copy because we'll potentially modify
        
        print(len(node_df_year['country_id'].unique()), "countries in node_df_year")
        
        # --- Edge features ---
        edge_features_cols = [
            col for col in edge_df_year.columns
            if col not in ['country_id_a', 'country_id_b', edge_country_a_col, edge_country_b_col, year_col]
        ]
        edge_features_df = edge_df_year[edge_features_cols].copy()

        # Boolean to int
        bool_cols_edges = edge_features_df.select_dtypes(include='bool').columns
        if len(bool_cols_edges) > 0:
            edge_features_df[bool_cols_edges] = edge_features_df[bool_cols_edges].astype(int)

        # Coerce all edge feature columns to numeric, fill NaN with 0
        edge_features_df = edge_features_df.apply(pd.to_numeric, errors='coerce').fillna(0)

        # --- Node features ---
        node_features_cols = [
            col for col in node_df_year.columns
            if col not in ['country_id', node_country_col, year_col]
        ]

        # Add missing countries (so each year has all UN countries)
        existing_countries = set(node_df_year['country_id'])
        missing_countries = set(uno_unnumeric_codes) - existing_countries
        
        # Boolean to int for node features
        bool_cols_nodes = node_df_year.select_dtypes(include='bool').columns
        if len(bool_cols_nodes) > 0:
            node_df_year[bool_cols_nodes] = node_df_year[bool_cols_nodes].astype(int)
            
        # add column to state if country does exist in this year/dataset
        node_df_year['exists'] = 1
        print(len(missing_countries))
        if missing_countries:
            # Build missing rows once
            for country in missing_countries:
                missing_row = {col: 0 for col in node_features_cols}
                missing_row['country_id'] = country
                missing_row['year'] = year
                missing_row['exists'] = 0
                row_df = pd.DataFrame([missing_row])
                node_df_year = pd.concat([node_df_year, row_df], ignore_index=True)

        # Ensure node_df_year is sorted by country_id and align features accordingly
        node_df_year_sorted = node_df_year.sort_values(by='country_id').reset_index(drop=True)
        
        # select only feature columns
        node_df_year_features = node_df_year_sorted[node_features_cols].copy()
        print(len(node_df_year_features))
        # --- Tensor creation ---
        edge_features_array = edge_features_df.values.astype(np.float32)
        edge_attr = torch.from_numpy(edge_features_array)

        node_features_tensor = torch.tensor(node_df_year_features.values, dtype=torch.float32)
        edge_index = torch.tensor(
            edge_df_year[['country_id_a', 'country_id_b']].values.T,
            dtype=torch.long
        )
        data = Data(x=node_features_tensor, edge_index=edge_index, edge_attr=edge_attr)
        data_dict[year] = data

    return data_dict

In [145]:
# geo_data = create_data(df_geo_edges, df_geo_nodes, 
#                         edge_country_a_col='iso_o', edge_country_b_col='iso_d',
#                         node_country_col='code_3', year_col='year')

# pol_data = create_data(df_pol_edges, df_pol_nodes,
#                         edge_country_a_col='country_id_a', edge_country_b_col='country_id_b',
#                         node_country_col='country_id', year_col='year')

cul_data = create_data(df_cult_edges, df_cult_nodes,
                        edge_country_a_col='ISO3_a', edge_country_b_col='ISO3_b',
                        node_country_col='ISO3', year_col='year')

# lang_data = create_data(df_lang_edges, df_lang_nodes,
#                         edge_country_a_col='country_a', edge_country_b_col='country_b',
#                         node_country_col='ISO3', year_col='year')

Number of edges after filtering: 54716
Number of nodes after filtering: 2704
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199
98 countries in node_df_year
95
199


In [139]:
lang_data

{np.int64(2000): Data(x=[193, 7], edge_index=[2, 37056], edge_attr=[37056, 2]),
 np.int64(2001): Data(x=[193, 7], edge_index=[2, 37056], edge_attr=[37056, 2]),
 np.int64(2002): Data(x=[193, 7], edge_index=[2, 37056], edge_attr=[37056, 2]),
 np.int64(2003): Data(x=[193, 7], edge_index=[2, 37056], edge_attr=[37056, 2]),
 np.int64(2004): Data(x=[193, 7], edge_index=[2, 37056], edge_attr=[37056, 2]),
 np.int64(2005): Data(x=[193, 7], edge_index=[2, 37056], edge_attr=[37056, 2]),
 np.int64(2006): Data(x=[193, 7], edge_index=[2, 37056], edge_attr=[37056, 2]),
 np.int64(2007): Data(x=[193, 7], edge_index=[2, 37056], edge_attr=[37056, 2]),
 np.int64(2008): Data(x=[193, 7], edge_index=[2, 37056], edge_attr=[37056, 2]),
 np.int64(2009): Data(x=[193, 7], edge_index=[2, 37056], edge_attr=[37056, 2]),
 np.int64(2010): Data(x=[193, 7], edge_index=[2, 37056], edge_attr=[37056, 2]),
 np.int64(2011): Data(x=[193, 7], edge_index=[2, 37056], edge_attr=[37056, 2]),
 np.int64(2012): Data(x=[193, 7], edge_i

## Model Setup

In [None]:
class ModalityGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x