# Graph Convolutional Network for Root Prediction
This notebook demonstrates a non-linear approach using a GCN to predict the root node in parsed dependency trees for one language.

In [1]:
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [2]:
# !pip install torch-geometric

In [68]:
import pandas as pd
import ast
import torch
import torch.nn.functional as F
import random
import itertools

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv

In [69]:
EXP_PATH     = '../../data/expanded_with_features_non-linear.csv'
TRAIN_PATH   = '../../data/train.csv'
TEST_FEATS   = '../../data/expanded_test_with_features_non-linear.csv'
RAW_TEST     = '../../data/test.csv'
CENT_COLS    = [
    'degree','closeness','harmonic','betweeness','load','pagerank',
    'eigenvector','katz','information','current_flow_betweeness',
    'percolation','second_order','laplacian','pos_norm','max_branch_size','subtree_entropy'
]
BATCH_SIZE   = 32
PATIENCE     = 10
MAX_EPOCHS   = 100
DEVICE       = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [70]:
def build_graphs(df_sub):
    graphs = []
    for (_, sent), grp in df_sub.groupby(['language','sentence'], sort=False):
        x = torch.tensor(grp[CENT_COLS].values, dtype=torch.float)
        edges = grp.edgelist.iloc[0]
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() - 1
        y = torch.zeros(len(grp), dtype=torch.long)
        y[grp.root.iloc[0] - 1] = 1
        graphs.append(Data(x=x, edge_index=edge_index, y=y))
    return graphs

In [71]:
def train_and_validate(train_graphs, val_graphs, hidden, dropout, lr, wd):
    """Returns (val_acc, model_trained_on_train_graphs)"""
    train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(val_graphs,   batch_size=BATCH_SIZE, shuffle=False)

    class GCN(torch.nn.Module):
        def __init__(self, in_feats):
            super().__init__()
            self.conv1 = GCNConv(in_feats, hidden)
            self.conv2 = GCNConv(hidden, 2)
            self.drop  = torch.nn.Dropout(dropout)
        def forward(self, data):
            x, edge = data.x.to(DEVICE), data.edge_index.to(DEVICE)
            x = F.relu(self.conv1(x, edge))
            x = self.drop(x)
            x = self.conv2(x, edge)
            return F.log_softmax(x, dim=1)

    def root_acc(pred, data):
        probs = pred.exp()[:,1]
        correct, total = 0, 0
        batch = data.batch if hasattr(data, 'batch') else torch.zeros(len(data.y), dtype=torch.long)
        for i in range(data.num_graphs):
            mask = (batch == i)
            idx  = probs[mask].argmax()
            correct += int(data.y.to(DEVICE)[mask][idx] == 1)
            total   += 1
        return correct / total

    model = GCN(len(CENT_COLS)).to(DEVICE)
    opt   = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'max', patience=3, factor=0.5)

    best_val, cnt = 0.0, 0
    for ep in range(1, MAX_EPOCHS+1):
        model.train()
        for batch in train_loader:
            batch = batch.to(DEVICE)
            opt.zero_grad()
            loss = F.nll_loss(model(batch), batch.y.to(DEVICE))
            loss.backward(); opt.step()

        model.eval()
        val_acc = sum(root_acc(model(b), b) for b in val_loader) / len(val_loader)
        sched.step(val_acc)

        if val_acc > best_val + 1e-4:
            best_val, cnt = val_acc, 0
        else:
            cnt += 1
            if cnt >= PATIENCE:
                break

    return best_val, model

In [72]:
param_grid = {
    'hidden':  [32, 64],
    'dropout': [0.3, 0.5],
    'lr':      [1e-2, 5e-3],
    'wd':      [1e-3, 1e-4],
}

In [73]:
def run_group(df_group, name):
    print(f"\n>>> Group: {name}")
    graphs = build_graphs(df_group)
    random.shuffle(graphs)

    # 300 / 100 / 100
    train_graphs = graphs[:300]
    val_graphs   = graphs[300:400]
    test_graphs  = graphs[400:500]

    # hyper‐search
    best_cfg, best_val = None, 0.0
    for combo in itertools.product(*param_grid.values()):
        cfg = dict(zip(param_grid.keys(), combo))
        val_acc, _ = train_and_validate(train_graphs, val_graphs, **cfg)
        if val_acc > best_val:
            best_val, best_cfg = val_acc, cfg

    print(f"  best_val_acc = {best_val:.4f} with {best_cfg}")

    # retrain on train+val
    combined = train_graphs + val_graphs
    _, model = train_and_validate(combined, val_graphs, **best_cfg)

    # evaluate on test
    test_loader = DataLoader(test_graphs, batch_size=BATCH_SIZE, shuffle=False)
    def root_acc_batch(pred, data):
        probs = pred.exp()[:,1]
        correct, total = 0, 0
        batch = data.batch if hasattr(data, 'batch') else torch.zeros(len(data.y), dtype=torch.long)
        for i in range(data.num_graphs):
            mask = (batch == i)
            idx  = probs[mask].argmax()
            correct += int(data.y.to(DEVICE)[mask][idx] == 1)
            total   += 1
        return correct / total

    test_acc = sum(root_acc_batch(model(b.to(DEVICE)), b.to(DEVICE))
                   for b in test_loader) / len(test_loader)
    print(f"  test_acc     = {test_acc:.4f}")

    return model, best_val, test_acc

In [74]:
# Load & merge
exp   = pd.read_csv(EXP_PATH)
train = pd.read_csv(TRAIN_PATH)
train['edgelist'] = train['edgelist'].apply(ast.literal_eval)
df    = exp.merge(train[['language','sentence','edgelist','root']],
                  on=['language','sentence'])

# 6) Run on both groups
model_nonjp, val_nonjp, test_nonjp = run_group(df[df.language != 'Japanese'], "NON_JAPANESE")
model_jp,    val_jp,    test_jp    = run_group(df[df.language == 'Japanese'],  "JAPANESE_ONLY")

print("\nSummary:")
print(f" NON_JP: val={val_nonjp:.4f}  test={test_nonjp:.4f}")
print(f" JAPANESE_ONLY: val={val_jp:.4f}  test={test_jp:.4f}")
print(f" AVG_TEST_ACC: {(test_nonjp+test_jp)/2:.4f}")


>>> Group: NON_JAPANESE




  best_val_acc = 0.3828 with {'hidden': 64, 'dropout': 0.3, 'lr': 0.01, 'wd': 0.0001}
  test_acc     = 0.4141

>>> Group: JAPANESE_ONLY




  best_val_acc = 0.3047 with {'hidden': 32, 'dropout': 0.5, 'lr': 0.005, 'wd': 0.001}
  test_acc     = 0.1641

Summary:
 NON_JP: val=0.3828  test=0.4141
 JAPANESE_ONLY: val=0.3047  test=0.1641
 AVG_TEST_ACC: 0.2891


## Test data
Now let's use the best estimators found to predict the test data:

In [75]:
import pandas as pd, ast, torch
from torch_geometric.data import Data

In [76]:
test_feats = pd.read_csv(TEST_FEATS)
raw_test   = pd.read_csv(RAW_TEST)
raw_test['edgelist'] = raw_test['edgelist'].apply(ast.literal_eval)

df_test = (
    test_feats
    .merge(raw_test[['id','language','sentence','edgelist']],
           on=['id','language','sentence'])
)

In [77]:
# build graphs with a “vertex” field
test_graphs = []
for tid, grp in df_test.groupby('id', sort=False):
    x          = torch.tensor(grp[CENT_COLS].values, dtype=torch.float)
    edge_index = torch.tensor(grp.edgelist.iloc[0], dtype=torch.long).t().contiguous() - 1
    verts      = torch.tensor(grp['vertex'].values, dtype=torch.long)
    lang       = grp.language.iloc[0]
    data = Data(x=x, edge_index=edge_index, vertex=verts,
                id=torch.tensor(tid), language=lang)
    test_graphs.append(data)

In [78]:
# predict
results = []
for data in test_graphs:
    data = data.to(DEVICE)
    model = model_jp if data.language=='Japanese' else model_nonjp
    model.eval()
    with torch.no_grad():
        out   = model(data)       # [num_nodes, 2]
        probs = out.exp()[:,1]    # P(root)
    best_idx    = probs.argmax().item()
    best_vertex = data.vertex[best_idx].item()
    results.append({'id':int(data.id.item()), 'root':best_vertex})

submission_gcn = pd.DataFrame(results)
submission_gcn.to_csv('../../data/submission_GCN.csv', index=False)
print(f"\nWrote {len(submission_gcn)} rows to ../../data/submission_GCN.csv")


Wrote 10395 rows to ../../data/submission_GCN.csv


In [79]:
# compare with linear model
sub_lin = pd.read_csv('../../data/labeled_test.csv')
sub_gcn = submission_gcn

cmp = sub_lin.merge(sub_gcn, on='id', suffixes=('_lin','_gcn'))
cmp['match'] = cmp['root_lin'] == cmp['root_gcn']
print(f"\nAgreement rate: {cmp['match'].mean():.3%} ({cmp['match'].sum()}/{len(cmp)})")
print("\nExamples of mismatches:")
print(cmp.loc[~cmp['match']].head())


Agreement rate: 2.030% (211/10395)

Examples of mismatches:
   id  root_lin  root_gcn  match
0   1         4        33  False
1   2        17        34  False
2   3         5         6  False
3   4        15         9  False
4   5         9        10  False
