In [8]:
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.data import DataLoader
from torch_geometric.nn import GATConv, GCNConv, SAGEConv, global_max_pool
import pickle

# Custom dataset loading
def load_custom_dataset(pickle_path):
    """Load the dataset from a pickle file."""
    with open(pickle_path, 'rb') as f:
        graphs = pickle.load(f)
    return graphs

# Load your dataset
pickle_path = './charliehebdo-all-rnr-threads_graphs.pkl'  # Replace with the actual path
graphs = load_custom_dataset(pickle_path)

dataset = graphs

In [9]:
dataset[3]

HeteroData(
  graph_label=[1],
  tweet={
    x=[10, 2],
    text=[10],
  },
  user={ x=[8, 1] },
  (tweet, replies_to, tweet)={ edge_index=[2, 7] },
  (user, authors, tweet)={ edge_index=[2, 10] }
)

In [7]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import HGTConv
import numpy as np
from tqdm import tqdm

parser = argparse.ArgumentParser(description='Training HGT for Tweet Classification')

'''
    Dataset arguments
'''
parser.add_argument('--model_dir', type=str, default='./model_save',
                    help='The address for storing the models and optimization results.')
parser.add_argument('--task_name', type=str, default='tweet_classification',
                    help='The name of the stored models and optimization results.')
parser.add_argument('--cuda', type=int, default=0,
                    help='Available GPU ID')

'''
   Model arguments 
'''
parser.add_argument('--n_hid', type=int, default=128,
                    help='Number of hidden dimension')
parser.add_argument('--n_heads', type=int, default=4,
                    help='Number of attention head')
parser.add_argument('--n_layers', type=int, default=2,
                    help='Number of GNN layers')
parser.add_argument('--dropout', type=float, default=0.2,
                    help='Dropout ratio')

'''
    Optimization arguments
'''
parser.add_argument('--optimizer', type=str, default='adamw',
                    choices=['adamw', 'adam', 'sgd', 'adagrad'],
                    help='optimizer to use.')
parser.add_argument('--n_epoch', type=int, default=100,
                    help='Number of epoch to run')
parser.add_argument('--batch_size', type=int, default=32,
                    help='Number of graphs in each batch')    
parser.add_argument('--lr', type=float, default=0.001,
                    help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=0.01,
                    help='Weight decay')
parser.add_argument('--clip', type=float, default=0.25,
                    help='Gradient Norm Clipping')

class TweetHGT(torch.nn.Module):
    def __init__(self, hidden_channels, num_heads, num_layers, metadata, dropout=0.2):
        super().__init__()
        
        # Save node types for later use
        self.node_types = metadata[0]
        self.edge_types = metadata[1]
        
        # Input feature dimensions for each node type
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in self.node_types:
            if node_type == 'tweet':
                in_dim = 2
            else:  # user
                in_dim = 1
            self.lin_dict[node_type] = torch.nn.Linear(in_dim, hidden_channels)
        
        # HGT layers
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                metadata=metadata,
                heads=num_heads
            )
            self.convs.append(conv)
        
        # Output layer for tweet classification
        self.classifier = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, 2)  # Binary classification: real/fake
        )
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x_dict, edge_index_dict):
        # Initial feature transformation
        h_dict = {}
        for node_type, x in x_dict.items():
            h_dict[node_type] = self.dropout(F.relu(self.lin_dict[node_type](x)))
        
        # Apply HGT layers
        for conv in self.convs[:-1]:
            h_dict_new = conv(h_dict, edge_index_dict)
            # Apply dropout and non-linearity between HGT layers
            h_dict = {
                node_type: self.dropout(F.relu(h))
                for node_type, h in h_dict_new.items()
            }
        
        # Last HGT layer (no dropout after last conv)
        h_dict = self.convs[-1](h_dict, edge_index_dict)
        h_dict = {
            node_type: F.relu(h)
            for node_type, h in h_dict.items()
        }
        
        # Get tweet node representations and classify
        tweet_repr = h_dict['tweet']
        out = self.classifier(tweet_repr)
        
        return out

def train(model, train_loader, optimizer, device, criterion):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch in tqdm(train_loader, desc='Training'):
        batch = batch.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        out = model(batch.x_dict, batch.edge_index_dict)
        
        # Get labels for tweet nodes
        labels = batch.graph_label.squeeze()
        
        # Calculate loss and accuracy
        loss = criterion(out, labels)
        pred = out.argmax(dim=1)
        correct += int((pred == labels).sum())
        total += len(labels)
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(train_loader), correct / total

@torch.no_grad()
def evaluate(model, loader, device, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x_dict, batch.edge_index_dict)
        labels = batch.graph_label.squeeze()
        
        loss = criterion(out, labels)
        pred = out.argmax(dim=1)
        
        total_loss += loss.item()
        correct += int((pred == labels).sum())
        total += len(labels)
    
    return total_loss / len(loader), correct / total

def main():
    args, unknown = parser.parse_known_args()
    
    device = torch.device(f'cuda:{args.cuda}' if args.cuda >= 0 and torch.cuda.is_available() else 'cpu')
    
    # Load your dataset - replace with your actual dataset loading code
    dataset = graphs
    
    # Split dataset (80/10/10)
    n_samples = len(dataset)
    train_idx = int(0.8 * n_samples)
    val_idx = int(0.9 * n_samples)
    
    train_dataset = dataset[:train_idx]
    val_dataset = dataset[train_idx:val_idx]
    test_dataset = dataset[val_idx:]
    
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)
    
    # Define metadata properly for HGTConv
    node_types = ['tweet', 'user']
    edge_types = [
        ('tweet', 'replies_to', 'tweet'),
        ('user', 'authors', 'tweet')
    ]
    metadata = (node_types, edge_types)
    
    # Initialize model
    model = TweetHGT(
        hidden_channels=args.n_hid,
        num_heads=args.n_heads,
        num_layers=args.n_layers,
        metadata=metadata,
        dropout=args.dropout
    ).to(device)
    
    # Initialize optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )
    
    # Loss function
    criterion = nn.CrossEntropyLoss()
    
    # Training loop
    best_val_acc = 0
    best_epoch = 0
    patience = 10
    no_improve = 0
    
    for epoch in range(args.n_epoch):
        # Train
        train_loss, train_acc = train(model, train_loader, optimizer, device, criterion)
        
        # Validate
        val_loss, val_acc = evaluate(model, val_loader, device, criterion)
        
        # Print metrics
        print(f'Epoch: {epoch:02d}, '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            torch.save(model.state_dict(), f'{args.model_dir}/{args.task_name}_best.pt')
            no_improve = 0
        else:
            no_improve += 1
        
        # Early stopping
        if no_improve >= patience:
            print(f'Early stopping after {epoch} epochs!')
            break
    
    print(f'Best validation accuracy: {best_val_acc:.4f} at epoch {best_epoch}')
    
    # Test best model
    model.load_state_dict(torch.load(f'{args.model_dir}/{args.task_name}_best.pt'))
    test_loss, test_acc = evaluate(model, test_loader, device, criterion)
    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}')

if __name__ == '__main__':
    main()

Training:   0%|          | 0/52 [00:00<?, ?it/s]


KeyError: 'user'