# Imports

In [None]:
import os
import timm
import torch
import wandb
import pickle
import optuna
import random
import numpy as np
import pandas as pd
import torch.nn as nn
import networkx as nx
import torch.nn.functional as F
import torch_geometric.transforms as T
import torchvision.transforms as transforms

from torch_geometric import seed_everything
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_networkx
from torchvision.transforms.functional import to_pil_image

from glob import glob
from tqdm.notebook import tqdm
from sklearn.metrics import r2_score

import sys

sys.path.append("../../src")

from models import InductiveGATwithIMGS, InductiveGCNwithIMGS, CNNBlock
from utils import set_seed, EarlyStoppingR2, train_CFG
from training_utils import train_inductive, inference_inductive, cross_val_inductive

In [None]:
# Global parameters for the whole notebook.

# IMPORTANT: use this seed so results stay comparable across experiments.
SEED = 111
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")

# Whether to impute missing node features (encoded as -1 in graph.x) with feature propagation.
use_features_propagation = False

# Region currently being worked on; must be a key of regions_mapper.
region = 812

regions_mapper = {
    777: "Moscow",
    812: "Saint-Petersburg",
    287: "Kazan",
    473: "Sochi",
}
assert region in regions_mapper, f"unknown region {region}; expected one of {sorted(regions_mapper)}"


# Paths
# path_for_graph = f"../../../data/graph_preprocessing/{regions_mapper[region]}/graph_with_cv_full.pickle"
# Directory holding the per-region image graphs. The default is the original shared location;
# set the GRAPH_DATA_DIR environment variable to run from anywhere else.
graph_data_dir = os.environ.get(
    "GRAPH_DATA_DIR",
    "/home/jupyter/datasphere/s3/s3-sirius/sirius_2024_participants/twwist/graph_with_cv_full_and_images",
)
path_for_graph = os.path.join(graph_data_dir, f"images_graph_{region}.pickle")

# Relative to the notebook directory; must start with ../../chkps/
checkpoints_path = "../../chkps/inductive_gcn_pipeline"
assert os.path.exists(checkpoints_path), "path for checkpoints must exist"


# Model params
hidden_dim = 64
n_layers = 4
n_head = 2  # attention heads per conv; only used by the GAT model
cnn_out_channels = 64

# Training params
# Name of a torch.optim class, e.g. "Adam", "AdamW", "RMSprop" (note the lowercase "prop").
optimizer_name = "AdamW"
use_scheduler = True

# Early stopper params
use_stopper = True
stopper_patience = 100
stopper_delta = 0.001

verbose = 10

num_epochs = 1000

# Number of epochs before the scheduler and early stopper start being used
started_patience = 1

# Image features
add_image_features = True

image_features_path = f"../../data/image_embeddings/image_features_{region}.pickle"

# Loaders

In [None]:
# Load the region graph; only x and edge_index are moved to `device`.
# NOTE(review): torch.load unpickles arbitrary objects — only load files from a trusted
# location. On torch>=2.6 (weights_only=True by default) a pickled PyG Data object may
# need weights_only=False — confirm for the installed version.
graph = torch.load(path_for_graph).to(device, "x", "edge_index")

if use_features_propagation:
    # Missing features are stored as -1; turn them into NaN so the missing mask below
    # marks exactly those entries for FeaturePropagation to impute.
    graph.x[graph.x == -1] = torch.nan
    graph = T.FeaturePropagation(missing_mask=torch.isnan(graph.x), num_iterations=400)(graph)

if add_image_features:
    # Load onto the same device as graph.x (otherwise torch.cat fails with a CPU/CUDA
    # mismatch when running on GPU) and match its dtype so node features stay homogeneous.
    image_features = torch.load(image_features_path, map_location=graph.x.device)
    image_features = image_features.to(dtype=graph.x.dtype)
    assert image_features.shape[0] == graph.x.shape[0], (
        "image features must have one row per graph node"
    )
    graph.x = torch.cat([graph.x, image_features], dim=-1)

In [None]:
# Dataloaders for default (non-CV) training.
loader_batch_size = 256

# num_neighbors has one entry per hop; -1 takes every neighbour. Deriving it from n_layers keeps
# the sampled neighbourhood as deep as the model's message passing, even if n_layers changes.
train_loader = NeighborLoader(
    graph,
    input_nodes=graph.train_mask,
    num_neighbors=[-1] * n_layers,
    batch_size=loader_batch_size,
    shuffle=True,
)

# Validation loader: same full neighbourhoods, fixed order.
val_loader = NeighborLoader(
    graph,
    input_nodes=graph.val_mask,
    num_neighbors=[-1] * n_layers,
    batch_size=loader_batch_size,
    shuffle=False,
)

# Special loader for efficient inference: every node (input_nodes=None), one hop of full
# neighbourhoods. NOTE(review): presumably consumed layer-wise by inference_inductive — confirm.
inference_loader = NeighborLoader(
    graph,
    input_nodes=None,
    num_neighbors=[-1],
    batch_size=4056,
)

# Test model training

In [None]:
set_seed(SEED)

# Model: InductiveGCNwithIMGS or InductiveGATwithIMGS.
# InductiveGATwithIMGS additionally needs `head` (number of heads in each conv).
gcn_kwargs = dict(
    n_in=graph.num_features,
    n_out=1,
    hidden_dim=hidden_dim,
    n_layers=n_layers,
    # first dim of one node's image tensor — presumably (C, H, W), i.e. the channel count
    cnn_in_channels=graph.imgs[0].shape[0],
    cnn_out_channels=cnn_out_channels,
)
model = InductiveGCNwithIMGS(**gcn_kwargs).to(device)

# Optimizer, looked up by name in torch.optim (same literal lr as the cross-validation run).
base_lr = 0.001771619056705244
optimizer_cls = getattr(torch.optim, optimizer_name)
optimizer = optimizer_cls(model.parameters(), lr=base_lr)

# LR scheduler.
# NOTE(review): ReduceLROnPlateau defaults to mode="min"; if train_inductive steps it with an
# R2 score (higher is better) it should be mode="max" — confirm.
if use_scheduler:
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer,
        factor=0.7,
        patience=30,
        threshold=0.01,
        min_lr=1e-5 / 5,
    )

# Regression loss.
loss_fn = torch.nn.MSELoss()

# Early stopping; the best model is written to checkpoints_path/best_model_train.pt.
if use_stopper:
    earlystopper = EarlyStoppingR2(
        path=checkpoints_path,
        model_name="best_model_train.pt",
        patience=stopper_patience,
        delta=stopper_delta,
        verbose=False,
        trace_func=print,
    )

In [None]:
# Pretrained image backbone, handed to train/inference together with the use_pretrained flag.
pretrained_model = timm.create_model('resnet18', pretrained=True, in_chans=12, num_classes=0).to(device)
use_pretrained = False

base_train = True  # set to False to skip this training run

if base_train:
    # Base training run configuration.
    train_cfg = train_CFG()
    train_cfg("num_epochs", 100)
    train_cfg("verbose", 1)
    # Same convention as the cross-validation config: the flag follows use_scheduler.
    train_cfg("scheduler", True if use_scheduler else None)

    train_inductive(
        train_loader=train_loader,
        val_loader=val_loader,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        train_cfg=train_cfg,
        scheduler=(scheduler if use_scheduler else None),
        started_patience=started_patience,
        earlystopper=(earlystopper if use_stopper else None),
        use_pretrained=use_pretrained,
        pretrained_model=pretrained_model,
    )

    # Evaluation uses the same helper and arguments as the standalone inference cell.
    metrics = inference_inductive(
        graph,
        model,
        inference_loader,
        use_pretrained=use_pretrained,
        pretrained_model=pretrained_model,
    )

    # Report inside the branch so `metrics` always comes from this run.
    print("Eval metrics: ")
    print(f'Train R2: {metrics[0]}\n'
          f'Val R2: {metrics[1]}\n'
          f'Test R2: {metrics[2]}')

In [None]:
# Score the trained model on every node; metrics[0..2] are printed below as train/val/test R2.
metrics = inference_inductive(
    graph,
    model,
    inference_loader,
    use_pretrained=use_pretrained,
    pretrained_model=pretrained_model,
)

In [None]:
# One line per split, in the order the metrics are stored.
print("\n".join(
    f"{split} R2: {metrics[idx]}"
    for idx, split in enumerate(("Train", "Val", "Test"))
))

# Cross-validation

In [None]:
# Pretrained image backbone, re-created so this section can run independently of base training.
pretrained_model = timm.create_model('resnet18', pretrained=True, in_chans=12, num_classes=0).to(device)
use_pretrained = False

cv = True  # set to False to skip cross-validation

if cv:
    # Cross-validation config, built from the global parameters.
    cv_cfg = train_CFG()
    cv_cfg("num_epochs", num_epochs)
    cv_cfg("verbose", verbose)
    cv_cfg("scheduler", (True if use_scheduler else None))
    cv_cfg("stopper_patience", stopper_patience)
    cv_cfg("stopper_delta", stopper_delta)
    cv_cfg("started_patience", started_patience)

    model_params = dict(
        n_in=graph.num_features,
        n_out=1,
        hidden_dim=hidden_dim,
        n_layers=n_layers,
        head=n_head,  # number of heads; relevant for the GAT architecture
        cnn_in_channels=graph.imgs[0].shape[0],
        cnn_out_channels=cnn_out_channels,
    )

    # NOTE(review): model_name is "GAT" but the checkpoint directory is named
    # "inductive_gcn_cv_pipeline" — confirm that is intended.
    val_score = cross_val_inductive(
        num_folds=5,
        dataset=graph,
        model_name="GAT",  # model architecture name
        model_params=model_params,
        optimizer_params={"lr" : 0.001771619056705244},
        optimizer_name=optimizer_name,
        cv_cfg=cv_cfg,
        checkpoints_path="../../chkps/inductive_gcn_cv_pipeline",  # checkpoints path
        eval_test=False,  # if True, every fold model is also evaluated on graph's test_mask
        device=device,
        use_pretrained=use_pretrained,
        pretrained_model=pretrained_model,
    )

    # An assignment inside `if` produces no cell output, so show the score explicitly.
    print(f"CV val score: {val_score}")

|                                   | **GAT** |          | **GCN** |          |
|-----------------------------------|---------|----------|---------|----------|
| **without images**                | 0.461   |          | 0.494   |          |
| **basic CNN**                     | 0.452   |          | 0.311   |          |
| **pretrained EfficientNet**      | 0.376   |          | 0.357   |          |
| **pretrained EfficientNet with image features** | 0.416   |          | 0.450   |          |


In [None]:
# repo/chkps
# NOTE(review): this text was left in a code cell and raised NameError on "Run All";
# kept as a comment — presumably a note about where checkpoints live.