In [1]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class GCNLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(GCNLayer, self).__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, output_dim))

    def forward(self, node_features, edge_index):
        # Create sparse adjacency matrix
        num_nodes = node_features.size(0)
        adj = torch.sparse_coo_tensor(edge_index, torch.ones(edge_index.shape[1]), (num_nodes, num_nodes))
        
        # Normalize the adjacency matrix
        adj_sum = torch.sparse.sum(adj, dim=1).to_dense()
        D_inv_sqrt = torch.diag(adj_sum ** -0.5)
        norm_adj = torch.sparse.mm(D_inv_sqrt, torch.sparse.mm(adj, D_inv_sqrt))
        
        # Perform graph convolution
        out = torch.sparse.mm(norm_adj, node_features @ self.weight)
        return out

In [3]:
# Step 2: Update WGAN Generator and Discriminator
class WGANGenerator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(WGANGenerator, self).__init__()
        self.gcn1 = GCNLayer(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, z, edge_index, node_features):
        z = self.gcn1(node_features, edge_index)
        z = F.relu(z)
        return self.fc(z)

In [4]:
class WGANDiscriminator(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(WGANDiscriminator, self).__init__()
        self.gcn1 = GCNLayer(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x, edge_index):
        x = self.gcn1(x, edge_index)
        x = F.relu(x)
        return self.fc(x)

In [5]:
# Load graph dataset
node_features, edge_index, labels = torch.load('graph_dataset.pt')

In [6]:
# WGAN parameters
input_size = node_features.shape[1]
hidden_size = 32
output_size = input_size
latent_size = 16

In [7]:
# Initialize WGAN components
generator = WGANGenerator(input_size, hidden_size, output_size)
discriminator = WGANDiscriminator(input_size, hidden_size)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = generator.to(device)
discriminator = discriminator.to(device)
node_features = node_features.to(device)

In [8]:
# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001)

In [9]:
# Function to compute statistics: mean, variance, and standard deviation
def compute_statistics(features):
    mean = torch.mean(features, dim=0)
    var = torch.var(features, dim=0)
    std = torch.std(features, dim=0)
    return mean, var, std

In [10]:
# Print initial statistics
initial_mean, initial_var, initial_std = compute_statistics(node_features)
print(f"Initial Mean: {initial_mean}, Initial Variance: {initial_var}, Initial Std Dev: {initial_std}")

Initial Mean: tensor([2.3957e-03, 6.2278e+00, 4.4526e+02, 2.6677e+01], device='cuda:0'), Initial Variance: tensor([3.0670e-05, 1.5315e+01, 6.6874e+04, 2.0538e+02], device='cuda:0'), Initial Std Dev: tensor([5.5381e-03, 3.9134e+00, 2.5860e+02, 1.4331e+01], device='cuda:0')


In [11]:
# Training with early stopping
num_epochs = 5
target_minority_class = torch.sum(labels == 0)
real_data = node_features[labels == 1]

In [12]:
# Early stopping parameters
best_loss_d = float('inf')
patience = 1
trigger_times = 0

In [13]:
# Parameters
batch_size = 16  # Set a smaller batch size
latent_size = 100  # Size of the latent vector for the generator
num_epochs = 10  # Total number of epochs
target_minority_class = 10000  # Target number of minority class samples
patience = 10  # Early stopping patience
best_loss_d = float('inf')
trigger_times = 0

In [14]:
# Assuming you want to keep only the first 7506 labels
labels = labels[:real_data.size(0)]  # Truncate labels to match real_data size

# Create a dataset and DataLoader for real data
dataset = TensorDataset(real_data, labels)  # Wrap data in a TensorDataset
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [15]:
import psutil

# Function to print memory usage
def print_memory_usage():
    process = psutil.Process()
    memory_info = process.memory_info()
    print(f"Memory usage: {memory_info.rss / (1024 * 1024):.2f} MB")

In [None]:
# Efficient training loop with memory optimizations
for epoch in range(num_epochs):
    current_minority_count = torch.sum(labels == 1).item()
    print(f"Epoch: {epoch}, Current Minority Count: {current_minority_count}")
    
    if current_minority_count >= target_minority_class:
        print("Target minority class reached, ending training.")
        break
    
    for real_batch, label_batch in data_loader:
        z = torch.randn(real_batch.size(0), latent_size)
        
        # Memory-efficient forward pass (use detached tensors where possible)
        fake_data = generator(z, edge_index, node_features).detach()
        
        optimizer_d.zero_grad()
        d_real = discriminator(real_batch.detach(), edge_index)
        d_fake = discriminator(fake_data.detach(), edge_index)
        loss_d = -torch.mean(d_real) + torch.mean(d_fake)
        loss_d.backward()
        optimizer_d.step()

        for p in discriminator.parameters():
            p.data.clamp_(-0.01, 0.01)

        optimizer_g.zero_grad()
        fake_data = generator(torch.randn(real_batch.size(0), latent_size), edge_index, node_features.detach())
        loss_g = -torch.mean(discriminator(fake_data.detach(), edge_index))
        loss_g.backward()
        optimizer_g.step()

    print_memory_usage()  # After each epoch, check memory usage

    torch.cuda.empty_cache()  # Free up GPU memory (if using CUDA)

Epoch: 0, Current Minority Count: 40


In [None]:
# Generate enough samples to match the class distribution
num_generated_samples = target_minority_class - current_minority_count
generated_data = generator(torch.randn(num_generated_samples, latent_size), edge_index, node_features)
y_generated = torch.ones(num_generated_samples, dtype=torch.long)

# Combine generated data with the original data
x_augmented = torch.cat([node_features, generated_data], dim=0)
y_augmented = torch.cat([labels, y_generated], dim=0)

In [None]:
# Print final statistics
final_mean, final_var, final_std = compute_statistics(x_augmented)
print(f"Final Mean: {final_mean}, Final Variance: {final_var}, Final Std Dev: {final_std}")

In [None]:
# Graph visualization function
import networkx as nx
import matplotlib.pyplot as plt

In [None]:
def plot_graph(node_features, edge_index, labels):
    G = nx.DiGraph()
    for i in range(node_features.shape[0]):
        G.add_node(i, features=node_features[i].detach().numpy(), label=labels[i].item())
    for from_node, to_node in edge_index.T:
        G.add_edge(from_node.item(), to_node.item())
    color_map = ['red' if G.nodes[node]['label'] == 1 else 'blue' for node in G.nodes()]
    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G)
    nx.draw(G, pos, with_labels=True, node_color=color_map, node_size=100, font_size=8, font_color='white', alpha=0.8)
    plt.title("Graph Dataset Visualization")
    plt.show()

In [None]:
# Call the function with your graph data
plot_graph(node_features, edge_index, labels)