Concatenated Features (DNN feature extractors)

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_max_pool
import networkx as nx
import numpy as np
import random
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Seed every random number generator used below with one shared value so
# repeated runs draw the same numbers
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Request deterministic kernels from PyTorch and switch off cuDNN auto-tuning
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Function to normalize features across datasets
def normalize_features(datasets):
    """Standardize node features in place with statistics pooled over all graphs.

    One ``StandardScaler`` is fitted on the row-wise concatenation of every
    graph's ``x`` and then applied to each graph separately, so all graphs
    share the same per-column scaling.

    Args:
        datasets: Iterable of graph objects exposing a 2-D tensor ``x``.

    Returns:
        The same ``datasets`` object; each graph's ``x`` has been replaced by
        a float tensor holding the scaled features.
    """
    pooled = torch.cat([graph.x for graph in datasets], dim=0).numpy()
    scaler = StandardScaler().fit(pooled)

    for graph in datasets:
        scaled = scaler.transform(graph.x.numpy())
        graph.x = torch.tensor(scaled, dtype=torch.float)
    return datasets


def generate_image_graph(num_nodes, num_features, prob_edge=0.1):
    """Build one synthetic "image" graph.

    The topology is an Erdős–Rényi random graph. Node features are drawn from
    N(0.5, 0.1), clamped to [0, 1], then min-max rescaled per feature column.

    Args:
        num_nodes: Number of nodes in the graph.
        num_features: Number of feature columns per node.
        prob_edge: Probability of creating each possible edge.

    Returns:
        A ``Data`` object with ``x`` of shape ``(num_nodes, num_features)``
        and ``edge_index`` of shape ``(2, num_edges)``. Each undirected edge
        is listed once.
        NOTE(review): message-passing layers usually treat ``edge_index`` as
        directed -- confirm whether reverse edges are needed before using
        these graphs with a GNN.
    """
    G = nx.erdos_renyi_graph(n=num_nodes, p=prob_edge)

    x = torch.normal(mean=0.5, std=0.1, size=(num_nodes, num_features)).clamp(0, 1)
    col_min = x.min(dim=0)[0]
    col_max = x.max(dim=0)[0]
    # The small epsilon keeps constant columns from dividing by zero.
    x = (x - col_min) / (col_max - col_min + 1e-5)

    edges = list(G.edges)
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # A random graph can come out edgeless; an empty list would otherwise
        # become a 1-D tensor instead of the expected (2, 0) layout.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return Data(x=x, edge_index=edge_index)

def generate_video_graph(num_nodes, num_features, prob_edge=0.05):
    """Build one synthetic "video" graph.

    The topology is a Barabási–Albert graph in which every new node attaches
    to two existing nodes. Node features are ``exp(-c / 5)`` for Poisson(2)
    counts ``c``, multiplied by a Bernoulli(0.2) mask so that most entries
    end up zero.

    Args:
        num_nodes: Number of nodes in the graph.
        num_features: Number of feature columns per node.
        prob_edge: Accepted for signature parity with the other generators;
            it is not used by this function.

    Returns:
        A ``Data`` object holding the features ``x`` and the ``edge_index``.
    """
    shape = (num_nodes, num_features)
    graph = nx.barabasi_albert_graph(n=num_nodes, m=2)

    counts = torch.poisson(torch.full(shape, 2.0))
    decayed = torch.exp(-counts / 5)
    keep_mask = torch.bernoulli(torch.full(shape, 0.2))
    x = decayed * keep_mask

    edge_list = list(graph.edges)
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    return Data(x=x, edge_index=edge_index)

def generate_text_graph(num_nodes, num_features, prob_edge=0.08):
    """Build one synthetic "text" graph.

    The topology is a Watts–Strogatz small-world graph with four neighbours
    per node. Every node feature is an integer in {0, 1, 2} stored as float.

    Args:
        num_nodes: Number of nodes in the graph.
        num_features: Number of feature columns per node.
        prob_edge: Rewiring probability passed to the Watts–Strogatz model.

    Returns:
        A ``Data`` object holding the features ``x`` and the ``edge_index``.
    """
    graph = nx.watts_strogatz_graph(n=num_nodes, k=4, p=prob_edge)

    # Every column of this binary draw is overwritten in the loop below, so
    # none of its values survive; the draw itself still consumes random
    # numbers from the torch generator.
    x = torch.randint(0, 2, (num_nodes, num_features)).float()
    for column in range(num_features):
        column_values = torch.randint(0, 3, (num_nodes,))
        x[:, column] = column_values.float()

    edge_list = list(graph.edges)
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    return Data(x=x, edge_index=edge_index)

# Generate initial feature sets for each dataset
def generate_initial_feature_set(num_graphs, num_nodes, num_features, graph_type):
    """Generate a list of synthetic graphs of a single modality.

    Args:
        num_graphs: Number of graphs to create.
        num_nodes: Number of nodes in every graph.
        num_features: Number of feature columns per node.
        graph_type: One of ``'image'``, ``'video'`` or ``'text'``.

    Returns:
        A list with ``num_graphs`` graphs produced by the matching generator.

    Raises:
        ValueError: If ``graph_type`` is not one of the supported modalities.
    """
    # Reject unknown modalities up front instead of failing later on an
    # unbound local variable.
    if graph_type not in ('image', 'video', 'text'):
        raise ValueError(
            f"Unknown graph_type {graph_type!r}; expected 'image', 'video' or 'text'"
        )

    # Pick the generator once rather than re-testing the type for every graph.
    if graph_type == 'image':
        make_graph = generate_image_graph
    elif graph_type == 'video':
        make_graph = generate_video_graph
    else:
        make_graph = generate_text_graph

    return [make_graph(num_nodes, num_features) for _ in range(num_graphs)]

# compute common labels based on all feature sets
def compute_common_labels(image_dataset, video_dataset, text_dataset):
    """Derive one binary label per graph triple from all three modalities.

    The graphs at position ``i`` of the three datasets are combined through a
    fixed nonlinear formula; the label is 1 when
    ``cos(1.2 * s) + sin(1.2 * s) >= 0.5`` for the resulting statistic ``s``
    and 0 otherwise.

    Args:
        image_dataset: Sequence of graphs with a 2-D feature tensor ``x``;
            its length determines how many labels are produced.
        video_dataset: Sequence indexed in parallel with ``image_dataset``.
        text_dataset: Sequence indexed in parallel with ``image_dataset``.

    Returns:
        A list of Python ints (0 or 1), one per entry of ``image_dataset``.
    """
    common_labels = []
    for idx, image_graph in enumerate(image_dataset):
        img = image_graph.x
        vid = video_dataset[idx].x
        txt = text_dataset[idx].x

        # First stage: statistics over the leading columns of each modality.
        img_head = img[:, :img.shape[1] // 3]
        vid_head = vid[:, :vid.shape[1] // 3]
        txt_head = txt[:, :txt.shape[1] // 4]
        sin_total = torch.sin(img_head * 2).sum()
        tan_product = torch.tan(vid_head + 0.5).prod()
        abs_cos_total = torch.cos(txt_head * 1.5).abs().sum()
        combined_stat = sin_total + tan_product - abs_cos_total

        # Second stage: mix the first-stage value with further statistics.
        damped = torch.log1p(combined_stat.abs()) * torch.exp(-combined_stat) * 0.5
        image_mean = torch.tanh(img).mean()
        # Unary minus binds tighter than //, so the floor division acts on the
        # negated width.
        video_tail = vid[:, -vid.shape[1] // 4:]
        video_sigmoid_total = torch.sigmoid(video_tail).sum()
        # NOTE(review): this index selects a single column, not a trailing
        # slice like the video term above -- confirm that is intended.
        text_column = txt[:, -txt.shape[1] // 5]
        text_root = torch.sqrt(text_column.sum())
        transformed_stat = damped + image_mean + video_sigmoid_total - text_root

        phase = transformed_stat * 1.2
        score = torch.cos(phase) + torch.sin(phase)
        common_labels.append(int(score >= 0.5))
    return common_labels

# Assign common labels to each dataset
def assign_common_labels(dataset, common_labels):
    """Attach the shared labels to the graphs of one dataset, in place.

    Args:
        dataset: Sequence of graph objects; graph ``i`` receives label ``i``.
        common_labels: Sequence of integer labels indexed like ``dataset``.

    Returns:
        The same ``dataset`` object; every graph now carries ``y`` as a
        one-element long tensor.
    """
    for position, graph in enumerate(dataset):
        label_value = common_labels[position]
        graph.y = torch.tensor([label_value], dtype=torch.long)
    return dataset

# Define DNN feature extractor
class DNNFeatureExtractor(nn.Module):
    """Three-layer MLP mapping node features to a compact embedding.

    Layer widths shrink from ``hidden_dim`` to ``hidden_dim // 2`` to
    ``hidden_dim // 4``. Dropout (p=0.3) follows the first activation only.
    """

    def __init__(self, input_dim, hidden_dim):
        """Create the layers.

        Args:
            input_dim: Width of the incoming feature vectors.
            hidden_dim: Width of the first hidden layer.
        """
        super().__init__()
        half_dim = hidden_dim // 2
        quarter_dim = hidden_dim // 4
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, half_dim)
        self.fc3 = nn.Linear(half_dim, quarter_dim)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        """Return the ReLU-activated embedding of ``x`` (last dim ``hidden_dim // 4``)."""
        first = self.dropout(F.relu(self.fc1(x)))
        second = F.relu(self.fc2(first))
        return F.relu(self.fc3(second))

# Train DNN feature extractors
def train_dnn_extractors(datasets, dnn_models, epochs=100, lr=0.002):
    """Train one feature extractor per modality against the graph labels.

    For every graph, each model embeds the node features, the embeddings are
    averaged over nodes into a single row, and that row is used directly as
    class logits for a cross-entropy loss. The model's output width must
    therefore exceed the largest label value. Each model has its own Adam
    optimizer and is updated once per graph per epoch.

    Args:
        datasets: Sequence of datasets, one per model and in the same order;
            each dataset is a sequence of graphs with ``x`` and ``y``.
            The first dataset's length sets the number of graphs visited.
        dnn_models: Sequence of ``nn.Module`` extractors to train in place.
        epochs: Number of passes over the graphs.
        lr: Learning rate for every Adam optimizer.

    Returns:
        None. Models are updated in place and left in training mode; the
        per-epoch loss (summed over models, averaged over graphs) is printed.
    """
    optimizers = [torch.optim.Adam(dnn.parameters(), lr=lr) for dnn in dnn_models]
    criterion = nn.CrossEntropyLoss()

    # Make sure dropout and similar layers are active even if the caller
    # handed over models that were previously switched to eval mode.
    for dnn in dnn_models:
        dnn.train()

    num_graphs = len(datasets[0]) if datasets else 0

    for epoch in range(epochs):
        total_loss = 0.0
        for i in range(num_graphs):
            for j, (dnn, optimizer) in enumerate(zip(dnn_models, optimizers)):
                graph = datasets[j][i]
                optimizer.zero_grad()
                # Mean over nodes, kept as a (1, D) row to match the (1,) label.
                out = dnn(graph.x).mean(dim=0, keepdim=True)
                loss = criterion(out, graph.y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
        # max(..., 1) avoids a division by zero when there are no graphs.
        print(f"Epoch {epoch + 1}, Loss: {total_loss / max(num_graphs, 1):.4f}")

# Extract features
def extract_features(dataset, dnn_model):
    """Compute one graph-level feature vector per graph.

    Each graph's node features are passed through ``dnn_model`` and averaged
    over the node dimension. The model is evaluated in eval mode so dropout
    does not randomly perturb the extracted features; its previous
    train/eval mode is restored afterwards.

    Args:
        dataset: Iterable of graphs exposing a node-feature tensor ``x``.
        dnn_model: ``nn.Module`` mapping node features to embeddings.

    Returns:
        A tensor of shape ``(num_graphs, embedding_dim)``.

    Raises:
        RuntimeError: If ``dataset`` is empty (``torch.stack`` needs at least
            one tensor).
    """
    was_training = dnn_model.training
    dnn_model.eval()
    try:
        with torch.no_grad():
            features = [dnn_model(data.x).mean(dim=0) for data in dataset]
    finally:
        # Leave the model in whatever mode the caller had it in.
        dnn_model.train(was_training)
    return torch.stack(features)

# Final DNN classifier applied to the concatenated extractor features
class DNNClassifier(nn.Module):
    """MLP classifier with two hidden layers that outputs log-probabilities.

    Hidden widths are ``hidden_dim`` and ``hidden_dim // 2``. Dropout (p=0.4)
    follows the first activation only.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        """Create the layers.

        Args:
            input_dim: Width of the incoming feature vectors.
            hidden_dim: Width of the first hidden layer.
            num_classes: Number of output classes.
        """
        super().__init__()
        half_dim = hidden_dim // 2
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, half_dim)
        self.fc3 = nn.Linear(half_dim, num_classes)
        self.dropout = nn.Dropout(p=0.4)

    def forward(self, x):
        """Return per-class log-probabilities for a batch ``x`` of shape (N, input_dim)."""
        hidden = self.dropout(F.relu(self.fc1(x)))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1)

def train_final_dnn(train_features, val_features, train_labels, val_labels,
                    patience=10, weight_decay=1e-4,
                    max_epochs=200, lr=0.002, hidden_dim=128):
    """Train the final two-class classifier with early stopping.

    The classifier is trained full-batch with Adam. After every epoch the
    validation loss is measured; training stops once it has failed to improve
    for ``patience`` consecutive epochs. The weights from the epoch with the
    lowest validation loss are restored before returning.

    Args:
        train_features: Float tensor of shape ``(N_train, D)``.
        val_features: Float tensor of shape ``(N_val, D)``.
        train_labels: Long tensor of shape ``(N_train,)`` with class indices.
        val_labels: Long tensor of shape ``(N_val,)`` with class indices.
        patience: Number of non-improving epochs tolerated before stopping.
        weight_decay: L2 penalty passed to Adam.
        max_epochs: Upper bound on the number of training epochs.
        lr: Adam learning rate.
        hidden_dim: Width of the classifier's first hidden layer.

    Returns:
        The trained ``DNNClassifier`` carrying its best-validation weights.
    """
    input_dim = train_features.shape[1]
    dnn = DNNClassifier(input_dim=input_dim, hidden_dim=hidden_dim, num_classes=2)
    optimizer = torch.optim.Adam(dnn.parameters(), lr=lr, weight_decay=weight_decay)
    # DNNClassifier.forward already returns log-probabilities, so the matching
    # criterion is NLLLoss; CrossEntropyLoss would apply log-softmax again.
    criterion = nn.NLLLoss()

    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0

    for epoch in range(max_epochs):
        dnn.train()
        optimizer.zero_grad()
        out = dnn(train_features)
        loss = criterion(out, train_labels)
        loss.backward()
        optimizer.step()

        dnn.eval()
        with torch.no_grad():
            val_out = dnn(val_features)
            val_loss = criterion(val_out, val_labels).item()
            val_accuracy = (val_out.argmax(dim=1) == val_labels).float().mean().item()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Snapshot the weights so the best model, not the last one, is
            # what gets returned after early stopping.
            best_state = {key: value.detach().clone()
                          for key, value in dnn.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

        print(f"Epoch {epoch + 1}, Training Loss: {loss.item():.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy * 100:.2f}%")

    if best_state is not None:
        dnn.load_state_dict(best_state)
    return dnn


# Generate datasets
num_graphs, num_nodes, num_features = 2000, 50, 40
image_dataset = generate_initial_feature_set(num_graphs, num_nodes, num_features, 'image')
video_dataset = generate_initial_feature_set(num_graphs, num_nodes, num_features, 'video')
text_dataset = generate_initial_feature_set(num_graphs, num_nodes, num_features, 'text')

# Assign common labels (computed on the raw, un-normalized features)
common_labels = compute_common_labels(image_dataset, video_dataset, text_dataset)
image_dataset = assign_common_labels(image_dataset, common_labels)
video_dataset = assign_common_labels(video_dataset, common_labels)
text_dataset = assign_common_labels(text_dataset, common_labels)

# NOTE(review): normalization is fitted on every graph before the train/test
# split below, so test-set statistics influence the scaling of the training
# data. Fit the scaler on the training split only if a leakage-free estimate
# is required.
image_dataset = normalize_features(image_dataset)
video_dataset = normalize_features(video_dataset)
text_dataset = normalize_features(text_dataset)

# The identical random_state keeps the three modalities aligned graph by graph.
train_image, test_image = train_test_split(image_dataset, test_size=0.2, random_state=1)
train_video, test_video = train_test_split(video_dataset, test_size=0.2, random_state=1)
train_text, test_text = train_test_split(text_dataset, test_size=0.2, random_state=1)

image_dnn = DNNFeatureExtractor(num_features, 64)
video_dnn = DNNFeatureExtractor(num_features, 64)
text_dnn = DNNFeatureExtractor(num_features, 64)

train_dnn_extractors([train_image, train_video, train_text], [image_dnn, video_dnn, text_dnn])

# Switch the extractors to eval mode so dropout does not randomly perturb the
# features handed to the final classifier.
for extractor in (image_dnn, video_dnn, text_dnn):
    extractor.eval()

train_features_image = extract_features(train_image, image_dnn)
train_features_video = extract_features(train_video, video_dnn)
train_features_text = extract_features(train_text, text_dnn)

test_features_image = extract_features(test_image, image_dnn)
test_features_video = extract_features(test_video, video_dnn)
test_features_text = extract_features(test_text, text_dnn)

# Concatenate the per-modality embeddings into one feature vector per graph
train_features = torch.cat([train_features_image, train_features_video, train_features_text], dim=1)
test_features = torch.cat([test_features_image, test_features_video, test_features_text], dim=1)

# All modalities share the same labels, so the image split supplies them
train_labels = torch.cat([data.y for data in train_image])
test_labels = torch.cat([data.y for data in test_image])

# Hold out part of the training data for early stopping
train_features, val_features, train_labels, val_labels = train_test_split(
    train_features, train_labels, test_size=0.2, random_state=1
)


final_model = train_final_dnn(train_features, val_features, train_labels, val_labels)

final_model.eval()
with torch.no_grad():
    test_out = final_model(test_features)
    test_accuracy = (test_out.argmax(dim=1) == test_labels).float().mean().item()

# test_accuracy is a fraction in [0, 1]; scale it before printing as a percentage
print(f"Final Test Accuracy: {test_accuracy * 100:.2f}%")


Epoch 1, Loss: 2.0269
Epoch 2, Loss: 1.8656
Epoch 3, Loss: 1.8555
Epoch 4, Loss: 1.8481
Epoch 5, Loss: 1.8404
Epoch 6, Loss: 1.8324
Epoch 7, Loss: 1.8194
Epoch 8, Loss: 1.8026
Epoch 9, Loss: 1.7827
Epoch 10, Loss: 1.7577
Epoch 11, Loss: 1.7320
Epoch 12, Loss: 1.7049
Epoch 13, Loss: 1.6673
Epoch 14, Loss: 1.6138
Epoch 15, Loss: 1.5718
Epoch 16, Loss: 1.5220
Epoch 17, Loss: 1.4885
Epoch 18, Loss: 1.4401
Epoch 19, Loss: 1.3832
Epoch 20, Loss: 1.3293
Epoch 21, Loss: 1.3010
Epoch 22, Loss: 1.2693
Epoch 23, Loss: 1.2246
Epoch 24, Loss: 1.1735
Epoch 25, Loss: 1.1607
Epoch 26, Loss: 1.1283
Epoch 27, Loss: 1.0918
Epoch 28, Loss: 1.0790
Epoch 29, Loss: 1.0453
Epoch 30, Loss: 1.0554
Epoch 31, Loss: 0.9509
Epoch 32, Loss: 0.9735
Epoch 33, Loss: 1.0036
Epoch 34, Loss: 0.9797
Epoch 35, Loss: 0.9065
Epoch 36, Loss: 0.9285
Epoch 37, Loss: 0.8939
Epoch 38, Loss: 0.8329
Epoch 39, Loss: 0.8621
Epoch 40, Loss: 0.8687
Epoch 41, Loss: 0.8315
Epoch 42, Loss: 0.8192
Epoch 43, Loss: 0.8187
Epoch 44, Loss: 0.79