<a href="https://colab.research.google.com/github/sahilsait/credit-risk-assessment-using-GNNs/blob/main/model_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
!pip install torch-geometric



In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GraphSAGE, TopKPooling
from torch_geometric.nn import global_mean_pool

In [25]:
class CreditRiskGNN(torch.nn.Module):
    """
    Graph-level classifier that maps each input graph to class log-probabilities.

    Pipeline: three single-layer GraphSAGE modules (each followed by ReLU),
    TopK node pooling, a global mean readout, then a two-layer MLP head.

    Args:
        num_features: Number of input features per node (default 1).
        hidden_dim: Width of every hidden representation (default 32).
        num_classes: Number of output classes (default 3: low, medium, high risk).
    """

    def __init__(self, num_features=1, hidden_dim=32, num_classes=3):
        super().__init__()

        # The attribute names below become the state_dict keys, so renaming
        # them would stop previously saved checkpoints from loading.
        self.sage1 = GraphSAGE(in_channels=num_features, hidden_channels=hidden_dim, num_layers=1)
        self.sage2 = GraphSAGE(in_channels=hidden_dim, hidden_channels=hidden_dim, num_layers=1)
        self.sage3 = GraphSAGE(in_channels=hidden_dim, hidden_channels=hidden_dim, num_layers=1)

        self.pool = TopKPooling(hidden_dim)

        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        """
        Score a batch of graphs.

        Args:
            x: Node feature matrix [num_nodes, num_features].
            edge_index: Graph connectivity [2, num_edges].
            batch: Graph index of every node [num_nodes].

        Returns:
            Log-probabilities, one row per graph and one column per class.
        """
        h = x
        for sage_layer in (self.sage1, self.sage2, self.sage3):
            h = F.relu(sage_layer(h, edge_index))

        # The pooling layer returns a 6-tuple; only the surviving node
        # features (1st) and their graph assignment (4th) are needed here.
        h, _, _, pooled_batch, _, _ = self.pool(h, edge_index, None, batch)

        graph_embedding = global_mean_pool(h, pooled_batch)
        hidden = F.relu(self.fc1(graph_embedding))
        return F.log_softmax(self.fc2(hidden), dim=1)

# Function to verify model
def verify_model(num_nodes=20, num_edges=38, seed=None):
    """
    Build a CreditRiskGNN and push one random single-graph batch through it.

    Forward-pass errors are printed and then re-raised, and the output is
    checked to be one row of normalised log-probabilities, so a broken model
    actually fails verification.

    Args:
        num_nodes: Nodes in the synthetic graph (default 20, one per financial
            indicator).
        num_edges: Random directed edges to sample (default 38); self-loops and
            duplicate edges are possible.
        seed: If given, seeds torch's global RNG first so the model weights and
            the sample graph are reproducible; None leaves the RNG untouched.

    Returns:
        torch.Tensor: the model output (one row for the single graph).

    Raises:
        Exception: whatever the forward pass raises (after printing it).
        AssertionError: if the output shape or normalisation is wrong.
    """
    if seed is not None:
        torch.manual_seed(seed)

    model = CreditRiskGNN(num_features=1)
    print("Model Architecture:")
    print(model)

    # Single synthetic graph: one scalar feature per node, random edges,
    # and every node assigned to graph 0.
    x = torch.randn(num_nodes, 1)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    batch = torch.zeros(num_nodes, dtype=torch.long)

    try:
        out = model(x, edge_index, batch)
    except Exception as e:
        print("Error in forward pass:", str(e))
        raise  # propagate so the caller does not carry on as if verification passed

    print("\nForward pass successful!")
    print("Input shape:", x.shape)
    print("Output shape:", out.shape)
    print("Output values (log probabilities):", out)

    # One graph in -> one row out; log-softmax rows must exponentiate to sum 1.
    assert out.dim() == 2 and out.shape[0] == 1, f"unexpected output shape {tuple(out.shape)}"
    row_sums = out.detach().exp().sum(dim=1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5), \
        "output rows are not normalised log-probabilities"

    return out

# Create and test model
def main():
    """
    Smoke-test CreditRiskGNN via verify_model(), then print a short summary.

    Any exception raised along the way is reported on stdout rather than
    propagated.
    """
    summary = (
        "\nModel Information:",
        "- Input: Financial indicators as nodes",
        "- Hidden layers: 3 GraphSAGE layers",
        "- Pooling: TopK pooling",
        "- Output: 3 risk categories (low, medium, high)",
    )
    print("Testing model architecture...")
    try:
        verify_model()
        for line in summary:
            print(line)
    except Exception as err:
        print("Error:", str(err))


if __name__ == "__main__":
    main()

Testing model architecture...
Model Architecture:
CreditRiskGNN(
  (sage1): GraphSAGE(1, 32, num_layers=1)
  (sage2): GraphSAGE(32, 32, num_layers=1)
  (sage3): GraphSAGE(32, 32, num_layers=1)
  (pool): TopKPooling(32, ratio=0.5, multiplier=1.0)
  (fc1): Linear(in_features=32, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=3, bias=True)
)

Forward pass successful!
Input shape: torch.Size([20, 1])
Output shape: torch.Size([1, 3])
Output values (log probabilities): tensor([[-1.1496, -1.0495, -1.0992]], grad_fn=<LogSoftmaxBackward0>)

Model Information:
- Input: Financial indicators as nodes
- Hidden layers: 3 GraphSAGE layers
- Pooling: TopK pooling
- Output: 3 risk categories (low, medium, high)


In [26]:
# Mount Google Drive (Colab-only API) so the dataset files under
# /content/drive/MyDrive/datasets can be read by the loading cell below.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [27]:
def calculate_zscore(df):
    """
    Compute an Altman-style Z-score for each row of ``df``.

    The five component ratios are also written back to ``df`` in place as
    columns 'z1'..'z5'. The weighted total is returned as a new Series:

        Z = 1.2*z1 + 1.4*z2 + 3.3*z3 + 0.6*z4 + 1.0*z5

    Args:
        df: DataFrame with columns act, lct, at, ni, oibdp, ceq, lt, sale.

    Returns:
        pandas Series of Z-scores, index-aligned with ``df``. A zero ``at`` or
        ``lt`` yields inf/NaN rather than raising.

    NOTE(review): if these are Compustat mnemonics (please confirm), several
    ratios are proxies for Altman's originals:
      - z2 uses net income (``ni``) instead of retained earnings;
      - z3 uses ``oibdp`` (operating income before depreciation) instead of EBIT;
      - z4 uses book equity (``ceq``), although the 0.6 weight was estimated
        with market value of equity.
    """
    # name -> (numerator, denominator, weight), in Altman's component order
    components = {
        'z1': (df['act'] - df['lct'], df['at'], 1.2),  # working capital / total assets
        'z2': (df['ni'], df['at'], 1.4),
        'z3': (df['oibdp'], df['at'], 3.3),
        'z4': (df['ceq'], df['lt'], 0.6),
        'z5': (df['sale'], df['at'], 1.0),             # sales / total assets
    }

    for name, (numerator, denominator, _weight) in components.items():
        df[name] = numerator / denominator

    total = None
    for name, (_numerator, _denominator, weight) in components.items():
        weighted = weight * df[name]
        total = weighted if total is None else total + weighted
    return total

In [28]:
def get_risk_category(z_score, high_risk_below=1.81, low_risk_from=2.99):
    """
    Map a Z-score to a risk-class index.

    Args:
        z_score: Numeric Z-score (Python or NumPy int/float).
        high_risk_below: Scores strictly below this are high risk
            (default 1.81, the conventional Altman distress cut-off).
        low_risk_from: Scores at or above this are low risk
            (default 2.99, the conventional Altman safe-zone cut-off).

    Returns:
        int: 2 = high risk, 1 = medium risk, 0 = low risk.

    Raises:
        ValueError: if ``z_score`` is NaN. Every comparison against NaN is
            False, so without this check an undefined score (e.g. from a zero
            denominator or a missing input) would silently be labelled
            "low risk".
    """
    import math

    if math.isnan(z_score):
        raise ValueError("Z-score is NaN; cannot assign a risk category")
    if z_score < high_risk_below:
        return 2  # High risk
    if z_score < low_risk_from:
        return 1  # Medium risk
    return 0  # Low risk

In [29]:
def prepare_data_with_labels(graph_data_list, df):
    """
    Attach a Z-score risk label (``graph.y``) to every graph.

    Labels are assigned by position: the i-th graph receives the label of the
    i-th row of ``df`` after sorting by ['gvkey', 'fyear']. This assumes the
    graphs were built in exactly that order. TODO: confirm against the graph
    construction code, because nothing here can verify the pairing.

    Side effects: adds 'z_score' and 'risk_category' columns (plus z1..z5 via
    calculate_zscore) to ``df``, and sets ``.y`` on each graph in place.

    Args:
        graph_data_list: Sequence of graph objects, one per firm-year row.
        df: DataFrame with the columns calculate_zscore needs plus 'gvkey'
            and 'fyear'.

    Returns:
        The same ``graph_data_list``, with labels attached.

    Raises:
        ValueError: if the number of graphs differs from the number of rows,
            since positional alignment would then certainly be wrong.
    """
    if len(graph_data_list) != len(df):
        raise ValueError(
            f"Got {len(graph_data_list)} graphs but {len(df)} data rows; "
            "labels are matched by position, so the counts must be equal"
        )

    df['z_score'] = calculate_zscore(df)
    df['risk_category'] = df['z_score'].apply(get_risk_category)

    # Same ordering the graphs are assumed to have been created in.
    df_sorted = df.sort_values(['gvkey', 'fyear'])

    # Extract the labels once instead of doing a pandas .iloc lookup per graph.
    labels = df_sorted['risk_category'].to_numpy()
    for graph, label in zip(graph_data_list, labels):
        graph.y = torch.tensor([int(label)], dtype=torch.long)

    print("\nRisk Category Distribution:")
    print(df_sorted['risk_category'].value_counts().sort_index())

    return graph_data_list

In [30]:
import pandas as pd
import torch

# Dataset location on the mounted Google Drive.
DATA_DIR = '/content/drive/MyDrive/datasets'

# Firm-year rows used to compute the Z-score labels.
df = pd.read_csv(f'{DATA_DIR}/preprocessed_data.csv')

# Load the pre-built graphs and attach labels.
# graphs.pt holds a pickled list of graph objects, not a tensor/state_dict, so
# weights_only=False must be passed explicitly (recent PyTorch releases default
# to weights_only=True, which refuses arbitrary pickled classes). Unpickling can
# execute arbitrary code: only load files you created or trust.
graph_data_list = torch.load(f'{DATA_DIR}/graphs.pt', weights_only=False)
graph_data_list = prepare_data_with_labels(graph_data_list, df)

  graph_data_list = torch.load('/content/drive/MyDrive/datasets/graphs.pt')



Risk Category Distribution:
risk_category
0    44246
1      824
2     4822
Name: count, dtype: int64


In [32]:
import torch
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

def _evaluate(model, loader, device):
    """
    Run ``model`` over every batch in ``loader`` with gradients disabled.

    Returns:
        (correct, total, predictions, true_labels): the number of correct
        argmax predictions, the number of graphs seen, and the per-graph
        predicted and true class indices (lists of NumPy scalars).
    """
    model.eval()
    correct = 0
    total = 0
    predictions = []
    true_labels = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch)
            pred = out.argmax(dim=1)
            correct += pred.eq(batch.y).sum().item()
            total += batch.y.size(0)
            predictions.extend(pred.cpu().numpy())
            true_labels.extend(batch.y.cpu().numpy())
    return correct, total, predictions, true_labels


def train_model(graphs=None, seed=42, checkpoint_path='best_model.pt',
                num_epochs=100, patience=10):
    """
    Split the labelled graphs, train CreditRiskGNN with early stopping on
    validation accuracy, and evaluate the best checkpoint on the test split.

    Args:
        graphs: List of labelled graphs (each with ``.x``, ``.edge_index``,
            ``.y``). Defaults to the notebook-level ``graph_data_list``.
        seed: Seed for torch's global RNG (weight initialisation and
            DataLoader shuffling); None leaves it unseeded. The data split
            always uses random_state=42.
        checkpoint_path: File the best model's state_dict is written to.
        num_epochs: Maximum number of training epochs (must be >= 1).
        patience: Epochs without validation-accuracy improvement before stopping.

    Returns:
        (model, test_acc): the model loaded with the best checkpoint and its
        accuracy on the test split.

    Raises:
        ValueError: if there are no graphs or ``num_epochs`` < 1.
    """
    from sklearn.metrics import classification_report

    if graphs is None:
        graphs = graph_data_list  # notebook-level list produced by the labelling cell
    if len(graphs) == 0:
        raise ValueError("train_model needs at least one labelled graph")
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")
    if seed is not None:
        torch.manual_seed(seed)

    # 1. Split into train/val/test: 64% / 16% / 20% of the graphs.
    # NOTE(review): the split is not stratified; with a rare medium-risk class
    # consider stratify=[int(g.y) for g in graphs].
    train_data, test_data = train_test_split(graphs, test_size=0.2, random_state=42)
    train_data, val_data = train_test_split(train_data, test_size=0.2, random_state=42)

    print("\nSplit sizes:")
    print(f"Train: {len(train_data)}")
    print(f"Validation: {len(val_data)}")
    print(f"Test: {len(test_data)}")

    # 2. Data loaders; only the training loader is shuffled.
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)
    test_loader = DataLoader(test_data, batch_size=32)

    # 3. Model and optimizer.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}")

    model = CreditRiskGNN(num_features=1, hidden_dim=32, num_classes=3).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # 4. Training with early stopping on validation accuracy.
    # Starting below any achievable accuracy guarantees the first epoch writes
    # a checkpoint, so the reload below never picks up a stale file from an
    # earlier run.
    best_val_acc = float('-inf')
    patience_counter = 0

    print("\nStarting training...")

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.batch)
            # The model outputs log-probabilities, so NLL is the matching loss.
            loss = F.nll_loss(out, batch.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)

        correct, total, _, _ = _evaluate(model, val_loader, device)
        val_acc = correct / total

        print(f'Epoch: {epoch+1:03d}, Loss: {avg_loss:.4f}, Val Acc: {val_acc:.4f}')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping!")
                break

    # 5. Test the best checkpoint. It is a plain state_dict, so weights_only=True
    # is safe; map_location lets a checkpoint saved on GPU load on CPU.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
    test_correct, test_total, predictions, true_labels = _evaluate(model, test_loader, device)
    test_acc = test_correct / test_total

    print("\nFinal Results:")
    print(f"Best Validation Accuracy: {best_val_acc:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")

    # labels= keeps target_names aligned even if a class is absent from the
    # test split; zero_division=0 reports 0 (instead of warning) for classes
    # the model never predicts.
    print("\nClassification Report:")
    print(classification_report(true_labels, predictions,
                                labels=[0, 1, 2],
                                target_names=['Low Risk', 'Medium Risk', 'High Risk'],
                                zero_division=0))

    return model, test_acc

def main():
    """
    Train and evaluate the model via train_model(), printing a short summary.

    Errors are printed with context and then re-raised so the cell still fails.
    """
    try:
        print("Starting model training...")
        _, test_acc = train_model()
        print("\nTraining completed successfully!")
        print("Final Test Accuracy: {:.4f}".format(test_acc))
    except Exception as exc:
        print("Error during training: " + str(exc))
        raise


if __name__ == "__main__":
    main()

Starting model training...

Split sizes:
Train: 31930
Validation: 7983
Test: 9979

Using device: cpu

Starting training...


Epoch 1/100: 100%|██████████| 998/998 [00:19<00:00, 50.58it/s] 


Epoch: 001, Loss: 0.2685, Val Acc: 0.9088


Epoch 2/100: 100%|██████████| 998/998 [00:09<00:00, 105.62it/s]


Epoch: 002, Loss: 0.2001, Val Acc: 0.9168


Epoch 3/100: 100%|██████████| 998/998 [00:09<00:00, 107.89it/s]


Epoch: 003, Loss: 0.1891, Val Acc: 0.9220


Epoch 4/100: 100%|██████████| 998/998 [00:10<00:00, 96.63it/s] 


Epoch: 004, Loss: 0.1812, Val Acc: 0.9248


Epoch 5/100: 100%|██████████| 998/998 [00:10<00:00, 95.74it/s] 


Epoch: 005, Loss: 0.1751, Val Acc: 0.9226


Epoch 6/100: 100%|██████████| 998/998 [00:10<00:00, 99.20it/s]


Epoch: 006, Loss: 0.1712, Val Acc: 0.9277


Epoch 7/100: 100%|██████████| 998/998 [00:08<00:00, 113.60it/s]


Epoch: 007, Loss: 0.1671, Val Acc: 0.9314


Epoch 8/100: 100%|██████████| 998/998 [00:09<00:00, 101.79it/s]


Epoch: 008, Loss: 0.1655, Val Acc: 0.9306


Epoch 9/100: 100%|██████████| 998/998 [00:10<00:00, 96.37it/s] 


Epoch: 009, Loss: 0.1720, Val Acc: 0.9340


Epoch 10/100: 100%|██████████| 998/998 [00:10<00:00, 97.62it/s] 


Epoch: 010, Loss: 0.1654, Val Acc: 0.9359


Epoch 11/100: 100%|██████████| 998/998 [00:09<00:00, 104.93it/s]


Epoch: 011, Loss: 0.1611, Val Acc: 0.9291


Epoch 12/100: 100%|██████████| 998/998 [00:08<00:00, 112.77it/s]


Epoch: 012, Loss: 0.1624, Val Acc: 0.9330


Epoch 13/100: 100%|██████████| 998/998 [00:10<00:00, 97.54it/s] 


Epoch: 013, Loss: 0.1584, Val Acc: 0.9345


Epoch 14/100: 100%|██████████| 998/998 [00:10<00:00, 99.38it/s] 


Epoch: 014, Loss: 0.1585, Val Acc: 0.9336


Epoch 15/100: 100%|██████████| 998/998 [00:09<00:00, 102.67it/s]


Epoch: 015, Loss: 0.1575, Val Acc: 0.9315


Epoch 16/100: 100%|██████████| 998/998 [00:10<00:00, 96.28it/s]


Epoch: 016, Loss: 0.1560, Val Acc: 0.9327


Epoch 17/100: 100%|██████████| 998/998 [00:09<00:00, 105.15it/s]


Epoch: 017, Loss: 0.1570, Val Acc: 0.9292


Epoch 18/100: 100%|██████████| 998/998 [00:10<00:00, 97.39it/s] 


Epoch: 018, Loss: 0.1558, Val Acc: 0.9326


Epoch 19/100: 100%|██████████| 998/998 [00:10<00:00, 98.29it/s] 


Epoch: 019, Loss: 0.1575, Val Acc: 0.9355


Epoch 20/100: 100%|██████████| 998/998 [00:09<00:00, 104.07it/s]


Epoch: 020, Loss: 0.1531, Val Acc: 0.9364


Epoch 21/100: 100%|██████████| 998/998 [00:08<00:00, 113.59it/s]


Epoch: 021, Loss: 0.1526, Val Acc: 0.9221


Epoch 22/100: 100%|██████████| 998/998 [00:10<00:00, 97.90it/s] 


Epoch: 022, Loss: 0.1524, Val Acc: 0.9372


Epoch 23/100: 100%|██████████| 998/998 [00:10<00:00, 98.76it/s] 


Epoch: 023, Loss: 0.1546, Val Acc: 0.9331


Epoch 24/100: 100%|██████████| 998/998 [00:10<00:00, 98.01it/s]


Epoch: 024, Loss: 0.1588, Val Acc: 0.9346


Epoch 25/100: 100%|██████████| 998/998 [00:09<00:00, 106.09it/s]


Epoch: 025, Loss: 0.1493, Val Acc: 0.9329


Epoch 26/100: 100%|██████████| 998/998 [00:09<00:00, 103.95it/s]


Epoch: 026, Loss: 0.1490, Val Acc: 0.9326


Epoch 27/100: 100%|██████████| 998/998 [00:10<00:00, 96.18it/s] 


Epoch: 027, Loss: 0.1488, Val Acc: 0.9337


Epoch 28/100: 100%|██████████| 998/998 [00:10<00:00, 97.16it/s] 


Epoch: 028, Loss: 0.1505, Val Acc: 0.9321


Epoch 29/100: 100%|██████████| 998/998 [00:09<00:00, 108.31it/s]


Epoch: 029, Loss: 0.1520, Val Acc: 0.9364


Epoch 30/100: 100%|██████████| 998/998 [00:09<00:00, 110.24it/s]


Epoch: 030, Loss: 0.1509, Val Acc: 0.9317


Epoch 31/100: 100%|██████████| 998/998 [00:10<00:00, 97.11it/s] 


Epoch: 031, Loss: 0.1524, Val Acc: 0.9354


Epoch 32/100: 100%|██████████| 998/998 [00:10<00:00, 97.92it/s] 


Epoch: 032, Loss: 0.1499, Val Acc: 0.9342
Early stopping!


  model.load_state_dict(torch.load('best_model.pt'))



Final Results:
Best Validation Accuracy: 0.9372
Test Accuracy: 0.9382

Classification Report:
              precision    recall  f1-score   support

    Low Risk       0.97      0.98      0.97      8854
 Medium Risk       0.00      0.00      0.00       164
   High Risk       0.71      0.76      0.73       961

    accuracy                           0.94      9979
   macro avg       0.56      0.58      0.57      9979
weighted avg       0.92      0.94      0.93      9979


Training completed successfully!
Final Test Accuracy: 0.9382


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
