In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
import torch

# ============================================================================
# STEP 1: LOAD DATA AND DEFINE TARGET
# ============================================================================

df = pd.read_csv("global.csv")

# Define target BEFORE dropping any columns
# Use multiple indicators for robustness
df["has_REE"] = df["REE_Mins"].notna().astype(int)

print(f"Target distribution:\n{df['has_REE'].value_counts()}\n")

# ============================================================================
# STEP 2: IDENTIFY AND REMOVE REE-SPECIFIC COLUMNS (PREVENT DATA LEAKAGE)
# ============================================================================

# These columns contain information that would only be known AFTER 
# determining REE presence - they cannot be used as features
ree_specific_cols = [
    'REE_Mins',      # Specific REE minerals present
    'REE',           # REE grade/content information
    'HREE_Note',     # Heavy REE notes
    'LREE_Note',     # Light REE notes
    'REE_Ratio',     # HREE/LREE ratio
    'Sig_Mins',      # Often contains REE mineral names
]

# Remove these columns from the dataset
existing_ree_cols = [col for col in ree_specific_cols if col in df.columns]
df = df.drop(columns=existing_ree_cols)

print(f"Removed REE-specific columns to prevent leakage: {existing_ree_cols}\n")

df['Rec_Type_clean'] = df['Rec_Type'].str.replace(r'\(\?\)', '', regex=True).str.strip()


# ============================================================================
# STEP 3: HANDLE MISSING VALUES FOR KEY COLUMNS
# ============================================================================

# Fill nulls for columns we need for feature engineering
key_columns_fillna = {
    'Dep_Type': 'unknown',
    'Dep_Form': 'unknown',
    'Dep_Note': '',
    'Rec_Note': '',
    'Stat_Note': '',
    'Commods': '',
    'Components': '',
    'Part_of': '',
}

for col, fill_value in key_columns_fillna.items():
    if col in df.columns:
        df[col] = df[col].fillna(fill_value)

print("Filled missing values for key columns")

# ============================================================================
# STEP 4: ANALYZE NULL RATIOS
# ============================================================================

null_counts = df.isna().sum()
null_ratios = null_counts / len(df)

print(f"\nColumns with >50% missing values:")
high_null_cols = null_ratios[null_ratios > 0.5]
print(high_null_cols)



Target distribution:
has_REE
1    2026
0    1088
Name: count, dtype: int64

Removed REE-specific columns to prevent leakage: ['REE_Mins', 'REE', 'HREE_Note', 'LREE_Note', 'REE_Ratio', 'Sig_Mins']

Filled missing values for key columns

Columns with >50% missing values:
Oth_Mins      0.598908
Age_Mzn       0.748555
Age_Ma        0.958253
Host_Age      0.736994
HAge_Ma       0.941875
Host_Unit     0.885678
Assoc_Rock    0.883109
Alteration    0.925819
Company       0.818240
Comments      0.644188
Discov_Yr     0.947977
Expl_Note     0.990366
Mine_Meth     0.969171
PStat_Note    0.947013
P_Years       0.984265
P_refs        0.983622
P_Note        0.956969
RR_Ore_Mt     0.897559
RR_TREO_Mt    0.937701
RR_TREOgrd    0.927425
RR_REE_grd    0.985870
RR_Cutoff     0.988439
RR_HM_Mt      0.987797
RR_HM_pct     0.982659
RR_min_Mt     0.990687
RR_min_pct    0.999037
RR_mon_Mt     0.955363
RR_mon_pct    0.963070
RR_oth_grd    0.979769
RR_Yr_Est     0.927425
RR_Refs       0.826911
RR_RegCode    0.8

In [2]:
df['Region'].value_counts()

Region
South and Central Asia    532
Oceania                   504
North America             423
Europe                    383
East Asia                 358
South America             253
Africa                    235
China                     214
Russian Federation        147
Middle East                63
Antarctica                  2
Name: count, dtype: int64

In [3]:
def check_data_leakage(df, target_col='has_REE'):
    """
    Comprehensive check for data leakage in features.
    
    This function checks:
    1. Text fields that mention REE-related terms
    2. Correlation between text mentions and target
    3. Feature importance for potential leakage
    """
    
    print("\n" + "="*70)
    print("DATA LEAKAGE ANALYSIS")
    print("="*70)
    
    # REE-related keywords to search for
    ree_keywords = [
        r'\bREE\b', r'\bREEs\b', r'\bREO\b',
        r'rare earth', r'rare-earth',
        r'lanthanide', r'lanthanoid',
        r'cerium', r'lanthanum', r'neodymium', r'praseodymium',
        r'samarium', r'europium', r'gadolinium', r'terbium',
        r'dysprosium', r'holmium', r'erbium', r'thulium',
        r'ytterbium', r'lutetium', r'yttrium', r'scandium',
        r'bastnasite', r'bastnäsite', r'monazite', r'xenotime',
        r'loparite', r'eudialyte', r'allanite', r'apatite.*REE'
    ]
    
    ree_pattern = '|'.join(ree_keywords)
    
    # Text columns to check
    text_cols = ['Dep_Type', 'Dep_Note', 'Rec_Note', 'Commods', 
                 'Status', 'Stat_Note', 'P_Status', 'Dep_Form']
    
    leakage_report = []
    
    print("\n1. CHECKING TEXT FIELDS FOR REE MENTIONS:")
    print("-" * 70)
    
    for col in text_cols:
        if col not in df.columns:
            continue
            
        # Check for REE mentions
        ree_mentions = df[col].fillna('').str.contains(
            ree_pattern, 
            case=False, 
            regex=True
        )
        
        n_mentions = ree_mentions.sum()
        
        if n_mentions > 0:
            # Calculate correlation with target
            has_ree_when_mentioned = df.loc[ree_mentions, target_col].mean()
            has_ree_when_not_mentioned = df.loc[~ree_mentions, target_col].mean()
            
            # Calculate lift (how much more likely to have REE when mentioned)
            baseline = df[target_col].mean()
            lift = has_ree_when_mentioned / baseline if baseline > 0 else 0
            
            leakage_report.append({
                'column': col,
                'mentions': n_mentions,
                'percent': n_mentions / len(df) * 100,
                'has_ree_when_mentioned': has_ree_when_mentioned,
                'has_ree_when_not': has_ree_when_not_mentioned,
                'lift': lift
            })
            
            print(f"\n{col}:")
            print(f"  Mentions REE: {n_mentions} rows ({n_mentions/len(df)*100:.1f}%)")
            print(f"  has_REE when mentioned: {has_ree_when_mentioned:.1%}")
            print(f"  has_REE when NOT mentioned: {has_ree_when_not_mentioned:.1%}")
            print(f"  Lift: {lift:.2f}x")
            
            if lift > 1.5:
                print(f"  ⚠️  WARNING: Strong leakage signal (lift > 1.5x)")
    
    # Create summary dataframe
    if leakage_report:
        leakage_df = pd.DataFrame(leakage_report)
        leakage_df = leakage_df.sort_values('lift', ascending=False)
        
        print("\n2. LEAKAGE SUMMARY (sorted by lift):")
        print("-" * 70)
        print(leakage_df.to_string(index=False))
        
        print("\n3. RECOMMENDATIONS:")
        print("-" * 70)
        high_leakage = leakage_df[leakage_df['lift'] > 1.5]
        if len(high_leakage) > 0:
            print("⚠️  HIGH RISK columns (lift > 1.5x):")
            for col in high_leakage['column']:
                print(f"  - {col}: Should NOT be used for classification")
        
        moderate_leakage = leakage_df[(leakage_df['lift'] > 1.2) & (leakage_df['lift'] <= 1.5)]
        if len(moderate_leakage) > 0:
            print("\n⚠️  MODERATE RISK columns (lift 1.2-1.5x):")
            for col in moderate_leakage['column']:
                print(f"  - {col}: Use with caution, consider removing")
        
        low_leakage = leakage_df[leakage_df['lift'] <= 1.2]
        if len(low_leakage) > 0:
            print("\n✓ LOW RISK columns (lift <= 1.2x):")
            for col in low_leakage['column']:
                print(f"  - {col}: Probably safe to use")
    else:
        print("\n✓ No obvious REE mentions found in text fields")
    
    return leakage_df if leakage_report else None

In [4]:
def check_train_test_connectivity(edge_index, train_idx, test_idx):
    """
    Check how many edges connect train and test nodes.
    High connectivity can cause information leakage through the graph.
    """
    print(f"\n{'='*70}")
    print("TRAIN-TEST GRAPH CONNECTIVITY CHECK")
    print(f"{'='*70}")
    
    train_set = set(train_idx)
    test_set = set(test_idx)
    
    cross_edges = 0
    train_edges = 0
    test_edges = 0
    
    for i in range(edge_index.shape[1]):
        src, dst = edge_index[0, i], edge_index[1, i]
        
        if src in train_set and dst in train_set:
            train_edges += 1
        elif src in test_set and dst in test_set:
            test_edges += 1
        elif (src in train_set and dst in test_set) or (src in test_set and dst in train_set):
            cross_edges += 1
    
    total_edges = edge_index.shape[1]
    
    print(f"\nEdge distribution:")
    print(f"  Train-Train edges: {train_edges} ({train_edges/total_edges*100:.1f}%)")
    print(f"  Test-Test edges: {test_edges} ({test_edges/total_edges*100:.1f}%)")
    print(f"  Train-Test edges: {cross_edges} ({cross_edges/total_edges*100:.1f}%)")
    
    if cross_edges / total_edges > 0.2:
        print(f"\n⚠️  WARNING: {cross_edges/total_edges*100:.1f}% of edges cross train-test boundary")
        print("  This can cause information leakage through the graph structure.")
        print("  Consider: reducing k, using different test region, or using inductive GNN")
    else:
        print(f"\n✓ Reasonable train-test separation ({cross_edges/total_edges*100:.1f}% cross-edges)")
    
    return {
        'train_edges': train_edges,
        'test_edges': test_edges,
        'cross_edges': cross_edges,
        'cross_edge_ratio': cross_edges / total_edges
    }


In [5]:
import re

# ============================================================================
# STEP 5: CLEAN FEATURE ENGINEERING (NO LEAKAGE)
# ============================================================================

print("\n" + "="*70)
print("FEATURE ENGINEERING (LEAKAGE-FREE)")
print("="*70)

# Run leakage check first
leakage_report = check_data_leakage(df)

# Based on the leakage check, we'll exclude risky features
print("\n" + "="*70)
print("CREATING SAFE FEATURES")
print("="*70)

# ---------------------------------------------------------------------------
# 5.1 Spatial Features (Standardized) - SAFE
# ---------------------------------------------------------------------------
scaler = StandardScaler()
df[['lat_z', 'lon_z']] = scaler.fit_transform(df[['Latitude', 'Longitude']])
print("✓ Created standardized spatial features (lat_z, lon_z)")

# ---------------------------------------------------------------------------
# 5.2 Regional Features (One-Hot Encoding) - SAFE
# ---------------------------------------------------------------------------
region_dummies = pd.get_dummies(df['Region'], prefix='region', drop_first=False)
df = pd.concat([df, region_dummies], axis=1)
print(f"✓ Created regional features ({len(region_dummies.columns)} regions)")

# ---------------------------------------------------------------------------
# 5.3 Mineral System Hierarchy - SAFE
# ---------------------------------------------------------------------------
df['is_composite_system'] = df['Components'].notna().astype(int)
df['is_part_of_complex'] = df['Part_of'].notna().astype(int)
print("✓ Created mineral system hierarchy features")

# ---------------------------------------------------------------------------
# 5.4 Record Type Features - SAFE
# ---------------------------------------------------------------------------
# rec_type_dummies = pd.get_dummies(df['Rec_Type_clean'], prefix='rec_type', drop_first=False)
rec_type_dummies = pd.get_dummies(df['Rec_Type'], prefix='rec_type', drop_first=False)
df = pd.concat([df, rec_type_dummies], axis=1)
print(f"✓ Created record type features ({len(rec_type_dummies.columns)} types)")

# ---------------------------------------------------------------------------
# 5.5 Deposit Classification - RISKY (based on text)
# ---------------------------------------------------------------------------
# Option A: Use with cleaned text (remove REE mentions)
# Option B: Skip entirely and rely on other features
# Let's use Option A with cleaned text

def classify_system(row):
    """
    Classify deposit type WITHOUT using REE-specific information.
    This removes any mention of REE-related terms before classification.
    """
    # Get text fields
    text = ' '.join([
        str(row.get('Dep_Type', '')),
        str(row.get('Dep_Note', '')),
        str(row.get('Rec_Note', ''))
    ])
    
    # Remove REE-related keywords to prevent leakage
    ree_pattern = r'(REE|rare earth|lanthanide|monazite|bastn[aä]site|xenotime|loparite|eudialyte)'
    text_clean = re.sub(ree_pattern, '', text, flags=re.IGNORECASE).lower()
    
    # Now classify based on geological characteristics only
    if 'carbonatite' in text_clean:
        return 'carbonatite'
    if 'alkaline' in text_clean or 'alkali' in text_clean:
        return 'alkaline_intrusive'
    if 'placer' in text_clean:
        return 'placer'
    if any(k in text_clean for k in ['clay', 'laterite', 'ion adsorption', 'ion-adsorption']):
        return 'clay_laterite'
    if 'pegmatite' in text_clean:
        return 'pegmatite'
    if 'vein' in text_clean:
        return 'vein'
    if any(k in text_clean for k in ['skarn', 'contact']):
        return 'skarn'
    if 'hydrothermal' in text_clean:
        return 'hydrothermal'
    if str(row.get('Dep_Type', '')).strip().lower() not in ['', 'nan', 'unknown']:
        return 'other'
    return 'unknown'

def classify_system_prediscovery(row):
    """
    Classify based ONLY on host rock type, not mineralization.
    This should be knowable before any REE analysis.
    """
    dep_type = str(row.get('Dep_Type', '')).lower()
    
    # Use only host rock terms, not mineral names
    if 'carbonatite' in dep_type:
        return 'carbonatite_host'
    if any(x in dep_type for x in ['alkaline', 'alkali', 'syenite', 'nepheline']):
        return 'alkaline_host'
    if any(x in dep_type for x in ['granite', 'pegmatite']):
        return 'felsic_host'
    if any(x in dep_type for x in ['sediment', 'sandstone', 'conglomerate']):
        return 'sedimentary_host'
    if 'metamorphic' in dep_type:
        return 'metamorphic_host'
    
    return 'unknown_host'

df['system_class'] = df.apply(classify_system, axis=1)
system_class_dummies = pd.get_dummies(df['system_class'], prefix='system', drop_first=False)
df = pd.concat([df, system_class_dummies], axis=1)
print(f"✓ Created deposit classification features (REE-sanitized) ({len(system_class_dummies.columns)} classes)")

# ---------------------------------------------------------------------------
# 5.6 Deposit Form Features - POTENTIALLY RISKY
# ---------------------------------------------------------------------------
# Skip this for now as it might contain REE-specific information
print("⚠️  Skipping Dep_Form features (potential leakage risk)")

# ---------------------------------------------------------------------------
# 5.7 Commodity Association Features - SAFE (excluding REE)
# ---------------------------------------------------------------------------
def has_element_safe(series, element):
    """Check if element is present, but exclude rows that mention REE."""
    # First check for the element
    has_elem = series.str.contains(element, case=False, na=False)
    
    # Exclude if REE is also mentioned (to avoid correlation through REE deposits)
    mentions_ree = series.str.contains(r'\bREE\b|rare earth', case=False, na=False)
    
    # Only count as having element if REE is NOT mentioned
    return (has_elem & ~mentions_ree).astype(int)

# Use safe version that excludes REE-mentioning commodities
df['has_Nb'] = has_element_safe(df['Commods'], r'\bNb\b')
df['has_Ta'] = has_element_safe(df['Commods'], r'\bTa\b')
df['has_Th'] = has_element_safe(df['Commods'], r'\bTh\b')
df['has_U'] = has_element_safe(df['Commods'], r'\bU\b')
df['has_P'] = has_element_safe(df['Commods'], r'\bP\b')
df['has_F'] = has_element_safe(df['Commods'], r'\bF\b')
df['has_Zr'] = has_element_safe(df['Commods'], r'\bZr\b')

# Count of associated commodities
df['commodity_count'] = (
    df[['has_Nb', 'has_Ta', 'has_Th', 'has_U', 'has_P', 'has_F', 'has_Zr']].sum(axis=1)
)
print("✓ Created commodity association features (REE-sanitized)")

# ---------------------------------------------------------------------------
# 5.8 Development Status Features - RISKY (known after discovery)
# ---------------------------------------------------------------------------
# These are known AFTER a deposit is characterized, so they can leak information
# SKIP these features
print("⚠️  Skipping development status features (temporal leakage risk)")

# ============================================================================
# STEP 6: SELECT FEATURES FOR MODEL (CONSERVATIVE)
# ============================================================================

# Base features (safe)
base_features = [
    'lat_z', 'lon_z',
    'is_composite_system', 'is_part_of_complex',
    'has_Nb', 'has_Ta', 'has_Th', 'has_U', 'has_P', 'has_F', 'has_Zr',
    'commodity_count'
]

# Categorical features (safe)
categorical_features = (
    list(region_dummies.columns) +
    list(rec_type_dummies.columns) 
)

# +
#     list(system_class_dummies.columns)

# Combine all features
gnn_features = base_features + categorical_features
# gnn_features = base_features

# Create feature matrix
X = df[gnn_features].astype(float)
y = df['has_REE'].astype(int)

print(f"\n{'='*70}")
print(f"FEATURE MATRIX CREATED (LEAKAGE-FREE)")
print(f"{'='*70}")
print(f"Total features: {X.shape[1]}")
print(f"Total samples: {X.shape[0]}")
print(f"Missing values: {X.isna().sum().sum()}")
print(f"\nFeature breakdown:")
print(f"  - Base features: {len(base_features)}")
print(f"  - Regional features: {len(region_dummies.columns)}")
print(f"  - Record type features: {len(rec_type_dummies.columns)}")
print(f"  - System class features: {len(system_class_dummies.columns)}")


FEATURE ENGINEERING (LEAKAGE-FREE)

DATA LEAKAGE ANALYSIS

1. CHECKING TEXT FIELDS FOR REE MENTIONS:
----------------------------------------------------------------------

Dep_Note:
  Mentions REE: 104 rows (3.3%)
  has_REE when mentioned: 51.0%
  has_REE when NOT mentioned: 65.5%
  Lift: 0.78x

Rec_Note:
  Mentions REE: 28 rows (0.9%)
  has_REE when mentioned: 71.4%
  has_REE when NOT mentioned: 65.0%
  Lift: 1.10x

Commods:
  Mentions REE: 3109 rows (99.8%)
  has_REE when mentioned: 65.0%
  has_REE when NOT mentioned: 80.0%
  Lift: 1.00x

2. LEAKAGE SUMMARY (sorted by lift):
----------------------------------------------------------------------
  column  mentions   percent  has_ree_when_mentioned  has_ree_when_not     lift
Rec_Note        28  0.899165                0.714286          0.650032 1.097871
 Commods      3109 99.839435                0.650370          0.800000 0.999631
Dep_Note       104  3.339756                0.509615          0.655482 0.783288

3. RECOMMENDATIONS:
--

In [6]:
region_dummies.columns

Index(['region_Africa', 'region_Antarctica', 'region_China',
       'region_East Asia', 'region_Europe', 'region_Middle East',
       'region_North America', 'region_Oceania', 'region_Russian Federation',
       'region_South America', 'region_South and Central Asia'],
      dtype='str')

In [7]:
rec_type_dummies.columns

Index(['rec_type_district or area', 'rec_type_intrusion or complex',
       'rec_type_site'],
      dtype='str')

In [8]:
system_class_dummies.columns

Index(['system_alkaline_intrusive', 'system_carbonatite',
       'system_clay_laterite', 'system_hydrothermal', 'system_other',
       'system_pegmatite', 'system_placer', 'system_skarn', 'system_vein'],
      dtype='str')

In [9]:
X.shape

(3114, 26)

In [10]:
# ============================================================================
# STEP 7: BUILD GRAPH STRUCTURE (k-NN SPATIAL EDGES)
# ============================================================================

print(f"\n{'='*70}")
print("BUILDING GRAPH STRUCTURE")
print(f"{'='*70}")

coords = df[['Latitude', 'Longitude']].values

# Build k-nearest neighbors graph
k = 10  # Number of nearest neighbors
nbrs = NearestNeighbors(n_neighbors=k+1, metric='haversine')
nbrs.fit(np.radians(coords))

distances, indices = nbrs.kneighbors(np.radians(coords))

# Create edge list
edge_list = []
for i in range(indices.shape[0]):
    for j in indices[i][1:]:  # Skip self-loop
        edge_list.append((i, j))

# Convert to edge index (remove duplicates and make bidirectional)
edge_index = np.array(list(set(edge_list))).T
edge_index = np.concatenate([edge_index, edge_index[::-1]], axis=1)

print(f"✓ Created k-NN graph (k={k})")
print(f"  - Total edges: {edge_index.shape[1]}")
print(f"  - Average degree: {edge_index.shape[1] / len(df):.1f}")




BUILDING GRAPH STRUCTURE
✓ Created k-NN graph (k=10)
  - Total edges: 62280
  - Average degree: 20.0


In [11]:
# ============================================================================
# STEP 8: TRAIN/TEST SPLIT FUNCTION
# ============================================================================

def create_spatial_split(df, X, y, edge_index, test_region, verbose=True):
    """
    Create train/test split based on spatial holdout (by region).
    
    Parameters:
    -----------
    df : pd.DataFrame
        Original dataframe with 'Region' column
    X : pd.DataFrame
        Feature matrix
    y : pd.Series
        Target labels
    edge_index : np.ndarray
        Graph edge indices
    test_region : str
        Region to hold out for testing
    verbose : bool
        Whether to print statistics
        
    Returns:
    --------
    dict containing:
        - X_train, X_test: Feature matrices for train/test
        - y_train, y_test: Labels for train/test
        - train_idx, test_idx: Node indices for train/test
        - train_mask, test_mask: PyTorch boolean masks
    """
    
    if verbose:
        print(f"\n{'='*70}")
        print("TRAIN/TEST SPLIT (SPATIAL HOLDOUT)")
        print(f"{'='*70}")
        
        regions = df['Region'].unique()
        print(f"\nAvailable regions: {list(regions)}")
        
        if test_region not in regions:
            raise ValueError(f"Test region '{test_region}' not found in data. Available: {list(regions)}")
    
    # Create boolean masks
    test_idx_bool = df['Region'] == test_region
    train_idx_bool = ~test_idx_bool
    
    # Get integer indices
    train_idx = np.where(train_idx_bool)[0]
    test_idx = np.where(test_idx_bool)[0]
    
    # Create train/test splits
    X_train = X.iloc[train_idx] if hasattr(X, 'iloc') else X[train_idx]
    X_test = X.iloc[test_idx] if hasattr(X, 'iloc') else X[test_idx]
    y_train = y.iloc[train_idx] if hasattr(y, 'iloc') else y[train_idx]
    y_test = y.iloc[test_idx] if hasattr(y, 'iloc') else y[test_idx]
    
    # Create PyTorch masks for GNN
    num_nodes = len(df)
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[train_idx] = True
    test_mask[test_idx] = True
    
    if verbose:
        print(f"\nTest region: {test_region}")
        print(f"Train nodes: {len(train_idx)} ({len(train_idx)/len(df):.1%})")
        print(f"Test nodes: {len(test_idx)} ({len(test_idx)/len(df):.1%})")
        print(f"\nClass distribution:")
        print(f"  Train - has_REE: {y_train.mean():.2%}")
        print(f"  Test  - has_REE: {y_test.mean():.2%}")
    
    split_dict = {
        'X_train': X_train,
        'X_test': X_test,
        'y_train': y_train,
        'y_test': y_test,
        'train_idx': train_idx,
        'test_idx': test_idx,
        'train_mask': train_mask,
        'test_mask': test_mask,
    }
    
    # Check train-test connectivity
    if verbose:
        check_train_test_connectivity(edge_index, train_idx, test_idx)
    
    return split_dict

In [15]:
from gnn_model import train_gnn_model
test_region = 'China'  # Change this to test different regions
split_data = create_spatial_split(df, X, y, edge_index, test_region)

model, history, metrics = train_gnn_model(
        X=X,
        y=y,
        edge_index=edge_index,
        split_data=split_data,
        model_type='GCN',
        hidden_channels=64,
        num_layers=3,
        dropout=0.5,
        lr=0.01,
        epochs=200,
        patience=20
    )



TRAIN/TEST SPLIT (SPATIAL HOLDOUT)

Available regions: ['Africa', 'Antarctica', 'China', 'East Asia', 'Europe', 'Middle East', 'North America', 'Oceania', 'Russian Federation', 'South America', 'South and Central Asia']

Test region: China
Train nodes: 2900 (93.1%)
Test nodes: 214 (6.9%)

Class distribution:
  Train - has_REE: 66.38%
  Test  - has_REE: 47.20%

TRAIN-TEST GRAPH CONNECTIVITY CHECK

Edge distribution:
  Train-Train edges: 57792 (92.8%)
  Test-Test edges: 3708 (6.0%)
  Train-Test edges: 780 (1.3%)

✓ Reasonable train-test separation (1.3% cross-edges)

Using device: cuda

MODEL CONFIGURATION
Architecture: GCN
Hidden channels: 64
Number of layers: 3
Dropout: 0.5
Learning rate: 0.01
Weight decay: 0.0005
Max epochs: 200
Early stopping patience: 20
Preparing data for GNN...
✓ Created PyG Data object
  - Nodes: 3114
  - Edges: 62280
  - Features: 26
  - Train nodes: 2900
  - Test nodes: 214

✓ Initialized GCN model
  Total parameters: 6274

TRAINING GNN MODEL



Training:  10%|██████▌                                                          | 20/200 [00:00<00:01, 143.23it/s, Loss=0.5264, Train Acc=0.7566, Test F1=0.0917, Test AUC=0.5870]


Early stopping at epoch 21

FINAL EVALUATION RESULTS

TRAIN SET:
  Accuracy:  0.7393
  Precision: 0.8562
  Recall:    0.7299
  F1 Score:  0.7880
  AUC-ROC:   0.8194

TEST SET:
  Accuracy:  0.5374
  Precision: 0.6250
  Recall:    0.0495
  F1 Score:  0.0917
  AUC-ROC:   0.5895

CONFUSION MATRIX (Test Set):
[[110   3]
 [ 96   5]]

TN: 110, FP: 3
FN: 96, TP: 5

CLASSIFICATION REPORT (Test Set):
              precision    recall  f1-score   support

      No REE       0.53      0.97      0.69       113
     Has REE       0.62      0.05      0.09       101

    accuracy                           0.54       214
   macro avg       0.58      0.51      0.39       214
weighted avg       0.58      0.54      0.41       214







✓ Saved training history plot to training_history.png


In [16]:
model, history, metrics = train_gnn_model(
        X=X,
        y=y,
        edge_index=edge_index,
        split_data=split_data,
        model_type='GAT',
        hidden_channels=64,
        num_layers=3,
        dropout=0.5,
        lr=0.01,
        epochs=200,
        patience=20
    )


Using device: cuda

MODEL CONFIGURATION
Architecture: GAT
Hidden channels: 64
Number of layers: 3
Dropout: 0.5
Learning rate: 0.01
Weight decay: 0.0005
Max epochs: 200
Early stopping patience: 20
Preparing data for GNN...
✓ Created PyG Data object
  - Nodes: 3114
  - Edges: 62280
  - Features: 26
  - Train nodes: 2900
  - Test nodes: 214

✓ Initialized GAT model
  Total parameters: 74246

TRAINING GNN MODEL



Training:  12%|███████▌                                                          | 23/200 [00:00<00:03, 55.69it/s, Loss=0.6021, Train Acc=0.7341, Test F1=0.4183, Test AUC=0.5906]



Early stopping at epoch 24

FINAL EVALUATION RESULTS

TRAIN SET:
  Accuracy:  0.6710
  Precision: 0.8332
  Recall:    0.6306
  F1 Score:  0.7179
  AUC-ROC:   0.7702

TEST SET:
  Accuracy:  0.5374
  Precision: 1.0000
  Recall:    0.0198
  F1 Score:  0.0388
  AUC-ROC:   0.5959

CONFUSION MATRIX (Test Set):
[[113   0]
 [ 99   2]]

TN: 113, FP: 0
FN: 99, TP: 2

CLASSIFICATION REPORT (Test Set):
              precision    recall  f1-score   support

      No REE       0.53      1.00      0.70       113
     Has REE       1.00      0.02      0.04       101

    accuracy                           0.54       214
   macro avg       0.77      0.51      0.37       214
weighted avg       0.75      0.54      0.39       214


✓ Saved training history plot to training_history.png


In [14]:
# Check REE prevalence by system class
print("REE prevalence by system_class:")
system_ree = df.groupby('system_class')['has_REE'].agg(['mean', 'count']).sort_values('mean', ascending=False)
print(system_ree)
print(f"\nOverall baseline: {df['has_REE'].mean():.1%}")

REE prevalence by system_class:
                        mean  count
system_class                       
placer              0.903970   1083
pegmatite           0.712727    275
hydrothermal        0.693878     49
carbonatite         0.632075    424
alkaline_intrusive  0.604905    367
vein                0.567901     81
skarn               0.474576     59
clay_laterite       0.366197     71
other               0.321986    705

Overall baseline: 65.1%
