In [1]:
#imports 
import pandas as pd
import numpy as np
import os
import pickle as pkl
import datetime as datetime
from sklearn.preprocessing import StandardScaler
import statsmodels.formula.api as sm
import dgl.function as fn
from tqdm import tqdm

#imports for graph creation
import torch
from sklearn.preprocessing import StandardScaler
from itertools import combinations
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt

#imports for graph learning
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
from tqdm import trange
import torch
import torch_geometric.datasets as datasets
import torch_geometric.data as data
import torch_geometric.transforms as transforms

In [2]:
crises = [1963, 1962, 1967, 1989, 2001, 1986, 1993, 1996, 1983, 1978]

In [40]:
import pandas as pd
import requests

def get_sitc_codes():
    # URL of the JSON file
    url = 'https://comtradeapi.un.org/files/v1/app/reference/S4.json'

    try:
        # Send a GET request to the URL and fetch the data
        response = requests.get(url)
        response.raise_for_status()  # Check that the request was successful
        
        # Load the JSON data
        data = response.json()

        # Since the JSON data might be nested, use json_normalize with appropriate arguments
        if isinstance(data, list):
            # If the top level is a list
            df = pd.json_normalize(data)
        else:
            # If the top level is a dictionary
            # Identify the key that holds the main data (adjust the path as necessary)
            main_data_key = 'results'  # Adjust this based on the actual structure
            df = pd.json_normalize(data[main_data_key])

    except requests.exceptions.RequestException as e:
        print(f"Error fetching data: {e}")
    except ValueError as e:
        print(f"Error parsing JSON: {e}")
    except KeyError as e:
        print(f"Error processing JSON structure: {e}")

    return df

In [49]:
class TradeNetwork:
    """
    We define a class which computes the MST trade network for a given year 
    """
    
    def __init__(self, year = 1962, data_dir = "data"):
        self.year = year
        self.data_dir = data_dir
        
    def prepare_features(self, filter_gdp = True):
        
        ###IMPORT GDP###
        #prepare GDP as a set of features 
        with open('data/all_wb_indicators.pickle', 'rb') as handle:
            features_dict = pkl.load(handle)

        self.gdp = features_dict['NY.GDP.MKTP.CD']
        scaler = StandardScaler()
        self.gdp[["prev_gdp"]] = scaler.fit_transform(np.log(self.gdp[['YR'+str(self.year-1)]]))
        self.gdp[["current_gdp"]] = scaler.fit_transform(np.log(self.gdp[['YR'+str(self.year)]]))
        #rename and keep relevant columns
        self.gdp["country_code"] = self.gdp["economy"]
        self.gdp = self.gdp[["country_code", "prev_gdp", "current_gdp"]].dropna()
        
        ###IMPORT GDP GROWTH###
        #prepare GDP growth
        self.gdp_growth = features_dict['NY.GDP.MKTP.KD.ZG']
        self.gdp_growth["prev_gdp_growth"] = self.gdp_growth['YR'+str(self.year-1)]
        self.gdp_growth["current_gdp_growth"] = self.gdp_growth['YR'+str(self.year)] 
        self.gdp_growth["future_gdp_growth"] = self.gdp_growth['YR'+str(self.year+1)]
        #rename and keep relevant columns
        self.gdp_growth["country_code"] = self.gdp_growth["economy"]
        self.gdp_growth = self.gdp_growth[["country_code", "prev_gdp_growth",
                                "current_gdp_growth", "future_gdp_growth"]].dropna()
        
        ###IMPORT GDP PER CAPITA###
        self.gdp_per_capita = features_dict['NY.GDP.PCAP.CD']
        self.gdp_per_capita["prev_gdp_per_cap"] = self.gdp_per_capita['YR'+str(self.year-1)]
        self.gdp_per_capita["current_gdp_per_cap"] = self.gdp_per_capita['YR'+str(self.year)]
        self.gdp_per_capita["future_gdp_per_cap"] = self.gdp_per_capita['YR'+str(self.year+1)]
        #rename and keep relevant columns
        self.gdp_per_capita["country_code"] = self.gdp_per_capita["economy"]
        self.gdp_per_capita = self.gdp_per_capita[["country_code", "prev_gdp_per_cap",
                                "current_gdp_per_cap", "future_gdp_per_cap"]].dropna()
        
        ###IMPORT GDP PER CAPITA GROWTH###
        self.gdp_per_capita_growth = features_dict['NY.GDP.PCAP.KD.ZG']
        self.gdp_per_capita_growth["prev_gdp_per_cap_growth"] = self.gdp_per_capita_growth['YR'+str(self.year-1)]
        self.gdp_per_capita_growth["current_gdp_per_cap_growth"] = self.gdp_per_capita_growth['YR'+str(self.year)]
        self.gdp_per_capita_growth["future_gdp_per_cap_growth"] = self.gdp_per_capita_growth['YR'+str(self.year+1)]
        
        #rename and keep relevant columns
        self.gdp_per_capita_growth["country_code"] = self.gdp_per_capita_growth["economy"]
        self.gdp_per_capita_growth = self.gdp_per_capita_growth[["country_code", "prev_gdp_per_cap_growth",
                                "current_gdp_per_cap_growth", "future_gdp_per_cap_growth"]].dropna()
        
        ###MERGE ALL DATA FEATURES###
        self.features = pd.merge(self.gdp_growth, self.gdp, on = "country_code").dropna()
        self.features = pd.merge(self.features, self.gdp_per_capita, on = "country_code").dropna()
        self.features = pd.merge(self.features, self.gdp_per_capita_growth, on = "country_code").dropna()

    def prepare_network(self):
        """
        We create an initial, import-centric trade link pandas dataframe for a given year
        """
        #get product codes
        data_dict = get_sitc_codes()
        data_cross = []
        i = 0
        for item_def in list(data_dict["text"]):
            if(i >= 2):
                data_cross.append(item_def.split(" - ", 1))
            i = i+1

        self.product_codes = pd.DataFrame(data_cross, columns = ['code', 'product'])
        self.product_codes["sitc_product_code"] = self.product_codes["code"]
        
        #get country codes
        self.country_codes = pd.read_excel("data/ISO3166.xlsx")
        self.country_codes["location_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["partner_code"] = self.country_codes["Alpha-3 code"]
        self.country_codes["country_i"] = self.country_codes["English short name"]
        self.country_codes["country_j"] = self.country_codes["English short name"]
        
        #get trade data for a given year
        trade_data = pd.read_stata(self.data_dir + "/country_partner_sitcproduct4digit_year_"+ str(self.year)+".dta") 
        #merge with product / country descriptions
        trade_data = pd.merge(trade_data, self.country_codes[["location_code", "country_i"]],on = ["location_code"])
        trade_data = pd.merge(trade_data, self.country_codes[["partner_code", "country_j"]],on = ["partner_code"])
        trade_data = pd.merge(trade_data, self.product_codes[["sitc_product_code", "product"]], 
                              on = ["sitc_product_code"])
        ###select level of product aggregation
        trade_data["product_category"] = trade_data["sitc_product_code"].apply(lambda x: x[0:1])
        
        #keep only nodes that we have features for
        #trade_data = trade_data[trade_data["location_code"].isin(self.features["country_code"])]
        #trade_data = trade_data[trade_data["partner_code"].isin(self.features["country_code"])]
        
        if (len(trade_data.groupby(["location_code", "partner_code", "sitc_product_code"])["import_value"].sum().reset_index()) != len(trade_data)):
            print("import, export, product combination not unique!")
        self.trade_data1 = trade_data
        #from import-export table, create only import table
        #extract imports
        imports1 = trade_data[['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        imports1 = imports1[imports1["import_value"] != 0]
        #transform records of exports into imports
        imports2 = trade_data[['location_id', 'partner_id', 'product_id', 'year',
               'export_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        imports2["temp1"] = imports2['partner_code']
        imports2["temp2"] = imports2['location_code']

        imports2['location_code'] = imports2["temp1"]
        imports2['partner_code'] = imports2["temp2"]
        imports2["import_value"] = imports2["export_value"]
        imports2 = imports2[imports2["import_value"] != 0]
        imports2 = imports2[['location_id', 'partner_id', 'product_id', 'year',
               'import_value', 'sitc_eci', 'sitc_coi', 'location_code', 'partner_code',
               'sitc_product_code', 'country_i', 'country_j', 'product', "product_category"]]
        
        imports_table = pd.concat([imports1, imports2]).drop_duplicates()
        
        #rename columns for better clarity
        imports_table["importer_code"] = imports_table["location_code"]
        imports_table["exporter_code"] = imports_table["partner_code"]
        imports_table["importer_name"] = imports_table["country_i"]
        imports_table["exporter_name"] = imports_table["country_j"]
        
        cols = ["importer_code", "exporter_code", "importer_name", "exporter_name",
               'product_id', 'year', 'import_value', 'sitc_eci', 'sitc_coi',
               'sitc_product_code', 'product', "product_category"]
        imports_table = imports_table[cols]
        
        exporter_total = imports_table.groupby(["exporter_code"])["import_value"].sum().reset_index()
        exporter_total = exporter_total.rename(columns = {"import_value": "export_total"})
        
        importer_total = imports_table.groupby(["importer_code"])["import_value"].sum().reset_index()
        importer_total = importer_total.rename(columns = {"import_value": "import_total"})
        
        #sum imports across all products between countries into single value 
        imports_table_grouped = imports_table.groupby(["importer_code", "exporter_code"])["import_value"].sum().reset_index()
        
        #sum exports in each category 
        self.export_types = imports_table.groupby(["exporter_code", "product_category"])["import_value"].sum().reset_index()
        self.export_types = pd.merge(self.export_types, exporter_total, on = "exporter_code")
        #multiply by 100 to allow weights to scale better in GNN
        self.export_types["category_fraction"] = self.export_types.import_value/self.export_types.export_total*10
        ss = StandardScaler()
        columns = list(set(self.export_types["product_category"]))
        self.export_types = self.export_types[["exporter_code", "product_category", "category_fraction"]]\
        .pivot(index = ["exporter_code"], columns = ["product_category"], values = "category_fraction")\
        .reset_index().fillna(0)
        #rename columns
        rename_columns = []
        for col in self.export_types.columns:
            if(col == "exporter_code"):
                rename_columns.append(col)
            else:
                rename_columns.append("resource_" + col)
        self.export_types.columns = rename_columns
        self.export_types = self.export_types.rename(columns = {"exporter_code": "country_code"})
        self.features = pd.merge(self.features, self.export_types, 
                                on = "country_code", how = "left")
        
        #look at fraction of goods traded with each counterparty
        imports_table_grouped = pd.merge(imports_table_grouped, exporter_total, how = "left")
        imports_table_grouped["export_percent"] = imports_table_grouped["import_value"]/imports_table_grouped["export_total"]
        scaler = StandardScaler()
        imports_table_grouped[["export_percent_feature"]] = scaler.fit_transform(np.log(imports_table_grouped[["export_percent"]]))
        imports_table_grouped["export_percent_feature"] = imports_table_grouped["export_percent_feature"] + abs(min(imports_table_grouped["export_percent_feature"]))
        
        imports_table_grouped = pd.merge(imports_table_grouped, importer_total, how = "left")
        imports_table_grouped["import_percent"] = imports_table_grouped["import_value"]/imports_table_grouped["import_total"]
        scaler = StandardScaler()
        imports_table_grouped[["import_percent_feature"]] = scaler.fit_transform(np.log(imports_table_grouped[["import_percent"]]))
        imports_table_grouped["import_percent_feature"] = imports_table_grouped["import_percent_feature"] + abs(min(imports_table_grouped["import_percent_feature"]))
        
        self.trade_data = imports_table_grouped

    def graph_create(self, exporter = True,
            node_features = ['prev_gdp_growth', 'current_gdp_growth','prev_gdp','current_gdp'],
            node_labels = 'future_gdp_growth'):
        
        if(exporter):
            center_node = "exporter_code"
            neighbors = "importer_code"
            edge_features = 'export_percent'
        
        #filter features and nodes to ones that are connected to others in trade data
        list_active_countries = list(set(list(self.trade_data ["importer_code"])+\
                        list(self.trade_data ["exporter_code"])))
        
        self.features = self.features[self.features["country_code"].isin(list_active_countries)].reset_index()
        self.features.fillna(0, inplace = True)
        self.features["node_numbers"] = self.features.index
        #create lookup dictionary making node number / node features combatible with ordering of nodes
        #in our edge table

        self.node_lookup1 = self.features.set_index('node_numbers').to_dict()['country_code']
        self.node_lookup2 = self.features.set_index('country_code').to_dict()['node_numbers']
        
        #get individual country's features
        self.regression_table = pd.merge(self.features, self.trade_data,
                        left_on = "country_code",
                        right_on = center_node, how = 'right')
        #get features for trade partners
        self.regression_table = pd.merge(self.features, self.regression_table,
                                        left_on = "country_code",
                                        right_on = neighbors, how = "right",
                                        suffixes = ("_neighbors", ""))
        
        self.trade_data = self.trade_data[self.trade_data[neighbors].isin(self.node_lookup2)]
        self.trade_data = self.trade_data[self.trade_data[center_node].isin(self.node_lookup2)]

        self.regression_table["source"] = self.trade_data[neighbors].apply(lambda x: self.node_lookup2[x])
        self.regression_table["target"] = self.trade_data[center_node].apply(lambda x: self.node_lookup2[x])    

        self.regression_table = self.regression_table.dropna()
        #filter only to relevant columns
        self.relevant_columns = ["source", "target"]
        self.relevant_columns.extend(node_features)
        self.relevant_columns.append(node_labels)
        self.graph_table = self.regression_table[self.relevant_columns]
        
        if(self.graph_table.isnull().values.any()): print("edges contain null / inf values")

        self.node_attributes = torch.tensor(np.array(self.features[node_features]))\
        .to(torch.float)
        self.source_nodes = list(self.graph_table["source"])
        self.target_nodes = list(self.graph_table["target"])
        self.edge_attributes = list(self.trade_data[edge_features])

        self.pyg_graph = data.Data(x = self.node_attributes,
                                   edge_index = torch.tensor([self.source_nodes, self.target_nodes]),
                                   edge_attr = torch.tensor(self.edge_attributes).to(torch.float),
                                   y = torch.tensor(list(self.features[node_labels])).to(torch.float))

        

# Graph Setup

In [50]:
years = range(1962,2021)

train_years = years[:30]
val_years = years[30:42]
test_years = years[42:]

train_graphs = []
val_graphs = []
test_graphs = []
i = 0

for year in tqdm(years):
    print(str(year), end='\r')
    
    trade = TradeNetwork(year = year)
    trade.prepare_features()
    trade.prepare_network()
    trade.graph_create(node_features = ['prev_gdp_per_cap_growth', 'current_gdp_per_cap_growth',
    'resource_0', 'resource_1', 'resource_2', 'resource_3', 'resource_4', 'resource_5', 'resource_6', 'resource_7',
       'resource_8', 'resource_9'],
        node_labels = 'future_gdp_per_cap_growth')
    
    if(year in val_years):
        val_graphs.append(trade.pyg_graph)
    elif(year in test_years):
        test_graphs.append(trade.pyg_graph)
    else: 
        train_graphs.append(trade.pyg_graph)
        
    trade.features["year"] = year
    
    if(i == 0):
        trade_df = trade.features
    else: 
        trade_df = pd.concat([trade_df, trade.features])
        
    i = i+1
    print(trade.node_attributes.size())

  0%|          | 0/59 [00:00<?, ?it/s]

1962

  2%|▏         | 1/59 [00:05<05:27,  5.64s/it]

torch.Size([75, 12])
1963

  3%|▎         | 2/59 [00:10<05:06,  5.38s/it]

torch.Size([82, 12])
1964

  5%|▌         | 3/59 [00:16<05:11,  5.56s/it]

torch.Size([82, 12])
1965

  7%|▋         | 4/59 [00:23<05:44,  6.27s/it]

torch.Size([85, 12])
1966

  8%|▊         | 5/59 [00:31<06:08,  6.83s/it]

torch.Size([87, 12])
1967

 10%|█         | 6/59 [00:38<05:51,  6.63s/it]

torch.Size([92, 12])
1968

 12%|█▏        | 7/59 [00:44<05:49,  6.72s/it]

torch.Size([96, 12])
1969

 14%|█▎        | 8/59 [00:51<05:41,  6.70s/it]

torch.Size([94, 12])
1970

 15%|█▌        | 9/59 [00:59<05:48,  6.98s/it]

torch.Size([99, 12])
1971

 17%|█▋        | 10/59 [01:06<05:49,  7.14s/it]

torch.Size([99, 12])
1972

 19%|█▊        | 11/59 [01:14<05:49,  7.27s/it]

torch.Size([108, 12])
1973

 20%|██        | 12/59 [01:22<05:58,  7.62s/it]

torch.Size([108, 12])
1974

 22%|██▏       | 13/59 [01:31<06:03,  7.91s/it]

torch.Size([108, 12])
1975

 24%|██▎       | 14/59 [01:40<06:10,  8.23s/it]

torch.Size([108, 12])
1976

 25%|██▌       | 15/59 [01:44<05:06,  6.97s/it]

torch.Size([110, 12])
1977

 27%|██▋       | 16/59 [01:48<04:28,  6.25s/it]

torch.Size([112, 12])
1978

 29%|██▉       | 17/59 [01:54<04:19,  6.18s/it]

torch.Size([114, 12])
1979

 31%|███       | 18/59 [02:01<04:13,  6.19s/it]

torch.Size([119, 12])
1980

 32%|███▏      | 19/59 [02:07<04:15,  6.38s/it]

torch.Size([119, 12])
1981

 34%|███▍      | 20/59 [02:14<04:15,  6.56s/it]

torch.Size([121, 12])
1982

 36%|███▌      | 21/59 [02:21<04:11,  6.61s/it]

torch.Size([130, 12])
1983

 37%|███▋      | 22/59 [02:28<04:05,  6.63s/it]

torch.Size([133, 12])
1984

 39%|███▉      | 23/59 [02:35<04:09,  6.92s/it]

torch.Size([136, 12])
1985

 41%|████      | 24/59 [02:42<04:01,  6.91s/it]

torch.Size([137, 12])
1986

 42%|████▏     | 25/59 [02:50<03:59,  7.05s/it]

torch.Size([139, 12])
1987

 44%|████▍     | 26/59 [02:58<04:03,  7.37s/it]

torch.Size([139, 12])
1988

 46%|████▌     | 27/59 [03:06<04:00,  7.50s/it]

torch.Size([141, 12])
1989

 47%|████▋     | 28/59 [03:14<04:03,  7.84s/it]

torch.Size([141, 12])
1990

 49%|████▉     | 29/59 [03:23<04:04,  8.17s/it]

torch.Size([141, 12])
1991

 51%|█████     | 30/59 [03:33<04:07,  8.53s/it]

torch.Size([142, 12])
1992

 53%|█████▎    | 31/59 [03:43<04:13,  9.06s/it]

torch.Size([159, 12])
1993

 54%|█████▍    | 32/59 [03:53<04:16,  9.50s/it]

torch.Size([162, 12])
1994

 56%|█████▌    | 33/59 [04:05<04:22, 10.08s/it]

torch.Size([166, 12])
1995

 58%|█████▊    | 34/59 [04:16<04:20, 10.42s/it]

torch.Size([167, 12])
1996

 59%|█████▉    | 35/59 [04:28<04:24, 11.03s/it]

torch.Size([168, 12])
1997

 61%|██████    | 36/59 [04:41<04:21, 11.37s/it]

torch.Size([177, 12])
1998

 63%|██████▎   | 37/59 [04:53<04:17, 11.72s/it]

torch.Size([177, 12])
1999

 64%|██████▍   | 38/59 [05:06<04:14, 12.12s/it]

torch.Size([179, 12])
2000

 66%|██████▌   | 39/59 [05:19<04:09, 12.46s/it]

torch.Size([184, 12])
2001

 68%|██████▊   | 40/59 [05:33<04:03, 12.79s/it]

torch.Size([185, 12])
2002

 69%|██████▉   | 41/59 [05:47<03:55, 13.09s/it]

torch.Size([189, 12])
2003

 71%|███████   | 42/59 [06:02<03:51, 13.61s/it]

torch.Size([190, 12])
2004

 73%|███████▎  | 43/59 [06:17<03:46, 14.14s/it]

torch.Size([194, 12])
2005

 75%|███████▍  | 44/59 [06:34<03:45, 15.05s/it]

torch.Size([195, 12])
2006

 76%|███████▋  | 45/59 [06:54<03:49, 16.39s/it]

torch.Size([196, 12])
2007

 78%|███████▊  | 46/59 [07:11<03:37, 16.74s/it]

torch.Size([196, 12])
2008

 80%|███████▉  | 47/59 [07:30<03:28, 17.39s/it]

torch.Size([197, 12])
2009

 81%|████████▏ | 48/59 [07:49<03:16, 17.82s/it]

torch.Size([197, 12])
2010

 83%|████████▎ | 49/59 [08:07<02:59, 17.93s/it]

torch.Size([197, 12])
2011

 85%|████████▍ | 50/59 [08:25<02:41, 17.95s/it]

torch.Size([197, 12])
2012

 86%|████████▋ | 51/59 [08:44<02:25, 18.15s/it]

torch.Size([200, 12])
2013

 88%|████████▊ | 52/59 [09:03<02:08, 18.41s/it]

torch.Size([201, 12])
2014

 90%|████████▉ | 53/59 [09:22<01:52, 18.69s/it]

torch.Size([200, 12])
2015

 92%|█████████▏| 54/59 [09:42<01:34, 18.96s/it]

torch.Size([201, 12])
2016

 93%|█████████▎| 55/59 [10:01<01:16, 19.02s/it]

torch.Size([201, 12])
2017

 95%|█████████▍| 56/59 [10:20<00:56, 18.95s/it]

torch.Size([201, 12])
2018

 97%|█████████▋| 57/59 [10:40<00:38, 19.21s/it]

torch.Size([199, 12])
2019

 98%|█████████▊| 58/59 [10:59<00:19, 19.38s/it]

torch.Size([197, 12])
2020

100%|██████████| 59/59 [11:18<00:00, 11.50s/it]

torch.Size([192, 12])





In [51]:
import pickle as pkl
with open("pygcn/train_graphs.pickle", "wb") as f:
    pkl.dump(train_graphs, f)

with open("pygcn/val_graphs.pickle", "wb") as f:
    pkl.dump(val_graphs, f)

with open("pygcn/test_graphs.pickle", "wb") as f:
    pkl.dump(test_graphs, f)

In [3]:
import pickle as pkl
with open("pygcn/train_graphs.pickle", "rb") as f:
    train_graphs = pkl.load(f)

with open("pygcn/val_graphs.pickle", "rb") as f:  
    val_graphs = pkl.load(f)

with open("pygcn/test_graphs.pickle", "rb") as f:         
    test_graphs = pkl.load(f)

## GNN

In [55]:
from torch_geometric.data import DataLoader
test_loader = DataLoader(test_graphs, batch_size=32)
train_loader = DataLoader(train_graphs, batch_size=32)
val_loader = DataLoader(val_graphs, batch_size=32)



In [31]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

class GCN(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, 64)
        self.conv2 = GCNConv(64, 32)
        self.conv3 = GCNConv(32, 16)
        self.conv4 = GCNConv(16, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index.long()
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = self.conv4(x, edge_index)
        return x

In [32]:
import math
model = GCN(in_channels=train_loader.dataset[0].num_node_features, out_channels=1)  # Output size is 1 for regression
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train(model, train_loader, val_loader, criterion, optimizer, epochs=80):
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch)
            loss = criterion(out.view(-1), batch.y)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                out = model(batch)
                loss = criterion(out.view(-1), batch.y)
                total_val_loss += loss.item()

        train_rmse = math.sqrt(total_train_loss / len(train_loader))
        val_rmse = math.sqrt(total_val_loss / len(val_loader))

        if epoch % 20 == 0:
            print(f'Epoch {epoch}, Train RMSE: {train_rmse}, Val RMSE: {val_rmse}')

train(model, train_loader, val_loader, criterion, optimizer)

Epoch 0, Train RMSE: 6.167343755443079, Val RMSE: 6.603258553854508
Epoch 20, Train RMSE: 5.995345535565095, Val RMSE: 6.464505927257312
Epoch 40, Train RMSE: 5.95909992257279, Val RMSE: 6.494006695047539
Epoch 60, Train RMSE: 5.948425520616057, Val RMSE: 6.530044955555975


## Simple GNN

In [71]:
class SimpleGCN(torch.nn.Module):
    def __init__(self):
        super(SimpleGCN, self).__init__()
        self.conv1 = GCNConv(train_loader.dataset[0].num_node_features, 16)
        self.conv2 = GCNConv(16, 16)
        self.conv3 = GCNConv(16, 16)
        self.lns = nn.LayerNorm(16)
        self.post_mp = nn.Sequential(
            nn.Linear(16, 16),
            nn.Dropout(0.25),
            nn.Linear(16, 2)
        )

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index.to(torch.int64))
        x = F.relu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.lns(x)

        x = self.conv2(x, edge_index.to(torch.int64))
        x = F.relu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.lns(x)

        x = self.conv3(x, edge_index.to(torch.int64))
        x = F.relu(x)
        x = F.dropout(x, p=0.25, training=self.training)

        x = self.post_mp(x)
        return x

### Getting Embeddings for Each Graph

In [73]:
train_embeddings = []
val_embeddings = []
model = SimpleGCN()

for g in train_loader.dataset:
    train_embeddings.append(model(g))

for g in val_loader.dataset:
    val_embeddings.append(model(g))