In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
import random
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

# ---------------------------
# 1. Configuration Parameters
# ---------------------------
# Data / optimisation settings
DATA_FOLDER = './saved_examples'  # directory holding the saved (data, model) .pt pairs
TRAIN_RATIO = 0.9                 # fraction of pairs used for training; the rest is held out
NUM_EPOCHS = 100                  # passes over the training split
LEARNING_RATE = 1e-4              # Adam step size
BATCH_SIZE = 512                  # saved subsets per mini-batch

# Shape of each target network and of the data subset it was trained on
INPUT_SIZE = 784     # flattened 28x28 image
HIDDEN_SIZE = 8      # width of the single hidden layer
OUTPUT_SIZE = 10     # number of classes
SUBSET_SIZE = 100    # examples per saved subset

# Derived sizes: each example is its pixels followed by a one-hot label.
PER_EXAMPLE_INPUT_SIZE = INPUT_SIZE + OUTPUT_SIZE         # 794
TOTAL_INPUT_SIZE = SUBSET_SIZE * PER_EXAMPLE_INPUT_SIZE   # 79,400

# Flat parameter count of a one-hidden-layer SimpleNN:
# hidden Linear (weights + biases) + output Linear (weights + biases) = 6,280 + 90.
PARAM_SIZE = (
    INPUT_SIZE * HIDDEN_SIZE + HIDDEN_SIZE
    + HIDDEN_SIZE * OUTPUT_SIZE + OUTPUT_SIZE
)  # 6,370

# Hypernetwork settings
NUM_EXAMPLES = SUBSET_SIZE   # one embedding branch per example slot in a subset
EMBED_DIM = 256              # width of each per-example embedding

# Run on the GPU when one is visible.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# ---------------------------
# 2. Model Definitions
# ---------------------------

class SimpleNN(nn.Module):
    """
    Small fully connected classifier whose flattened parameters the
    hypernetwork learns to predict.

    Layout: ``num_layers`` blocks of ``Linear -> ReLU`` followed by a final
    ``Linear`` producing ``output_size`` logits. Everything lives in
    ``self.network`` (an ``nn.Sequential``), so state_dict keys are
    ``network.<idx>.weight`` / ``network.<idx>.bias``.
    """
    def __init__(self, input_size=784, hidden_size=8, num_layers=1, output_size=10):
        super().__init__()
        # Input width seen by each Linear: the first gets input_size, later ones hidden_size.
        in_widths = [input_size] + [hidden_size] * num_layers
        modules = []
        for width in in_widths[:-1]:
            modules.extend([nn.Linear(width, hidden_size), nn.ReLU()])
        modules.append(nn.Linear(in_widths[-1], output_size))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Return raw (unnormalised) logits for ``x``; the last dim becomes ``output_size``."""
        return self.network(x)

class EmbeddingBranch(nn.Module):
    """
    Two-layer MLP (Linear -> ReLU -> Linear -> ReLU) that maps one flattened
    example to an ``embed_dim``-dimensional, non-negative embedding.
    """
    def __init__(self, input_size=794, embed_dim=256):
        super().__init__()
        project_in = nn.Linear(input_size, embed_dim)
        refine = nn.Linear(embed_dim, embed_dim)
        self.network = nn.Sequential(project_in, nn.ReLU(), refine, nn.ReLU())

    def forward(self, x):
        """Map ``(batch_size, input_size)`` to ``(batch_size, embed_dim)``."""
        return self.network(x)

class FusionLayer(nn.Module):
    """
    Post-norm, transformer-encoder-style block that lets the per-example
    embeddings attend to one another.

    Structure: self-attention -> dropout -> residual add -> LayerNorm, then
    feed-forward -> dropout -> residual add -> the same LayerNorm.

    NOTE(review): a single ``LayerNorm`` instance is shared by both residual
    sub-layers (a standard encoder block uses two). It is kept as is because
    changing it alters the state_dict layout of saved hypernetworks.
    ``num_examples`` is stored but not used by ``forward``. ``embed_dim``
    must be divisible by the 8 attention heads.
    """
    def __init__(self, embed_dim=256, num_examples=100):
        super().__init__()
        self.num_examples = num_examples
        # Assignment order fixes parameter order and init RNG consumption; keep it stable.
        self.attention = nn.MultiheadAttention(embed_dim, num_heads=8, dropout=0.1)
        self.layer_norm = nn.LayerNorm(embed_dim)
        expand = nn.Linear(embed_dim, embed_dim)
        contract = nn.Linear(embed_dim, embed_dim)
        self.feedforward = nn.Sequential(expand, nn.ReLU(), contract)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (sequence_length=num_examples, batch_size, embed_dim);
               the attention module is built without ``batch_first`` so it
               expects this sequence-first layout.
        Returns:
            Tensor with the same shape as ``x``.
        """
        attended, _ = self.attention(x, x, x)
        after_attention = self.layer_norm(x + self.dropout(attended))
        transformed = self.feedforward(after_attention)
        return self.layer_norm(after_attention + self.dropout(transformed))

class ParameterGenerator(nn.Module):
    """
    Decode a pooled embedding into a flat parameter vector for the target network.

    Structure: ``Linear(embed_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, param_size)``.
    There is no output activation, so generated parameters are unbounded.

    Args:
        embed_dim (int): Width of the incoming embedding.
        param_size (int): Length of the generated parameter vector.
        hidden_dim (int): Width of the hidden layer. It was previously
            hard-coded to 4096; the default keeps existing checkpoints loadable.
    """
    def __init__(self, embed_dim=256, param_size=6370, hidden_dim=4096):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, param_size),
        )

    def forward(self, x):
        """Map ``(batch_size, embed_dim)`` to ``(batch_size, param_size)``."""
        return self.network(x)

class MultiBranchHypernetwork(nn.Module):
    """
    Hypernetwork mapping a subset of labelled examples to a flat parameter vector.

    Pipeline:
      1. each of the ``num_examples`` example slots has its own
         ``EmbeddingBranch`` (branches do not share weights);
      2. the per-slot embeddings are mixed by a ``FusionLayer``;
      3. the fused sequence is mean-pooled and decoded by a ``ParameterGenerator``.
    """
    def __init__(self, num_examples=100, per_example_input_size=794, embed_dim=256, param_size=6370):
        super().__init__()
        self.num_examples = num_examples
        self.per_example_input_size = per_example_input_size
        self.embed_dim = embed_dim

        # One independent branch per example slot, so slot order matters.
        branches = [
            EmbeddingBranch(input_size=per_example_input_size, embed_dim=embed_dim)
            for _ in range(num_examples)
        ]
        self.embedding_branches = nn.ModuleList(branches)
        self.fusion_layer = FusionLayer(embed_dim=embed_dim, num_examples=num_examples)
        self.param_generator = ParameterGenerator(embed_dim=embed_dim, param_size=param_size)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, num_examples * per_example_input_size)
        Returns:
            param_vector: Tensor of shape (batch_size, param_size)
        """
        per_example = x.view(x.size(0), self.num_examples, self.per_example_input_size)
        # Route slot k of every subset through branch k. Stacking on dim 0 gives the
        # (num_examples, batch_size, embed_dim) layout the fusion layer expects.
        sequence = torch.stack(
            [branch(per_example[:, slot, :]) for slot, branch in enumerate(self.embedding_branches)],
            dim=0,
        )
        # Mean-pool over the example axis, then decode to parameters.
        pooled = self.fusion_layer(sequence).mean(dim=0)
        return self.param_generator(pooled)

# ---------------------------
# 3. Dataset and DataLoader
# ---------------------------

class HypernetworkDataset(Dataset):
    """
    Map-style dataset over an in-memory list of ``(input_vector, param_vector)`` pairs.
    """
    def __init__(self, data_list):
        """
        Args:
            data_list (list): Sequence of ``(input_vector, param_vector)`` pairs,
                stored by reference (not copied).
        """
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        inputs, params = self.data_list[idx]
        return inputs, params

def load_data(data_folder):
    """
    Load every ``*.pt`` data/model pair found in ``data_folder``.

    Each file is read as a dict with keys ``'data'`` (the subset's inputs),
    ``'targets'`` (integer class labels) and ``'model_state_dict'`` (a
    SimpleNN state dict). Files whose state dict does not fit SimpleNN are
    reported and skipped.

    NOTE(review): ``torch.load`` unpickles the file; only use this on
    trusted folders.

    Args:
        data_folder (str): Directory containing the saved examples.

    Returns:
        data_list (list): ``(input_vector, param_vector)`` tuples in sorted
            file-name order. ``input_vector`` holds each example's flattened
            input followed by its one-hot label, concatenated over the subset.
            ``param_vector`` holds the flattened SimpleNN parameters, detached
            from autograd.
    """
    data_list = []
    # os.listdir order is arbitrary; sorting makes the loaded order deterministic.
    for file_name in sorted(os.listdir(data_folder)):
        if not file_name.endswith('.pt'):
            continue
        file_path = os.path.join(data_folder, file_name)
        checkpoint = torch.load(file_path, map_location='cpu')  # Load on CPU to avoid CUDA issues in workers
        data = checkpoint['data']
        targets = checkpoint['targets']
        model_state_dict = checkpoint['model_state_dict']

        model = SimpleNN(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE, num_layers=1, output_size=OUTPUT_SIZE)
        try:
            model.load_state_dict(model_state_dict)
        except RuntimeError as e:
            print(f"Error loading state_dict for file {file_name}: {e}")
            continue  # Skip this file if there's an error

        # detach(): without it every stored target keeps an autograd link to its
        # throwaway SimpleNN. The MSE backward pass would then also run into, and
        # accumulate .grad on, every loaded model.
        param_vector = flatten_model_parameters(model).detach()

        # one_hot requires int64 labels.
        labels_one_hot = torch.nn.functional.one_hot(targets.long(), num_classes=OUTPUT_SIZE).float()
        # reshape(n, -1) also accepts image-shaped data such as (n, 1, 28, 28).
        flat_data = data.reshape(data.size(0), -1)
        input_vector = torch.cat([flat_data, labels_one_hot], dim=1).view(-1)

        data_list.append((input_vector, param_vector))
    return data_list

def split_data(data_list, train_ratio=TRAIN_RATIO, seed=None):
    """
    Randomly split data-model pairs into training and test lists.

    A shuffled copy is split, so the caller's list keeps its original order.
    Previously the caller's list was shuffled in place.

    Args:
        data_list (list): List of data-model pairs.
        train_ratio (float): Fraction in [0, 1] used for training. The split
            index is ``int(len(data_list) * train_ratio)``, i.e. rounded down.
        seed (int | None): If given, shuffle with a private ``random.Random(seed)``
            for a reproducible split. If None, use the global ``random`` module.

    Returns:
        train_data (list): Training data.
        test_data (list): Test data.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    shuffled = list(data_list)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(shuffled)
    split_idx = int(len(shuffled) * train_ratio)
    return shuffled[:split_idx], shuffled[split_idx:]

def flatten_model_parameters(model):
    """
    Concatenate every parameter of ``model`` into one 1-D tensor.

    Parameters are taken in ``model.parameters()`` order. The result comes
    from ``torch.cat`` and so is still attached to the autograd graph of the
    model's parameters.

    Args:
        model (nn.Module): The neural network model.

    Returns:
        flat_params (Tensor): Flattened parameter vector.
    """
    return torch.cat([tensor.view(-1) for tensor in model.parameters()])

# ---------------------------
# 4. Training and Evaluation Functions
# ---------------------------

def train_hypernetwork(train_loader, hypernetwork, optimizer, mse_criterion, ce_criterion, epoch):
    """
    Train the hypernetwork for one epoch on ``mse + ce``.

    ``mse`` compares predicted and stored parameter vectors. ``ce`` is the
    cross-entropy of the *generated* one-hidden-layer target networks on their
    own subsets. The generated networks are evaluated functionally, straight
    from the predicted parameter tensor with batched matmuls, so the CE term
    back-propagates into the hypernetwork. Copying predictions into fresh
    SimpleNN modules via ``param.data.copy_`` severs the autograd graph, which
    meant the CE term had no effect on training.

    Args:
        train_loader (DataLoader): DataLoader for training data.
        hypernetwork (nn.Module): Hypernetwork to generate model parameters.
        optimizer (Optimizer): Optimizer for the hypernetwork.
        mse_criterion (Loss): Mean Squared Error loss for parameter regression.
        ce_criterion (Loss): Cross-Entropy loss for model performance.
        epoch (int): Current epoch number (0-based; only used for the progress bar).

    Returns:
        avg_loss (float): Mean combined loss over batches (nan if the loader is empty).
        avg_train_acc (float): Mean per-batch accuracy in percent (nan if empty).
    """
    hypernetwork.train()
    num_batches = len(train_loader)
    if num_batches == 0:
        return float('nan'), float('nan')

    # Slice boundaries of a flat SimpleNN parameter vector, in model.parameters()
    # order: hidden weight (H, I), hidden bias (H), output weight (O, H), output bias (O).
    w1_end = HIDDEN_SIZE * INPUT_SIZE
    b1_end = w1_end + HIDDEN_SIZE
    w2_end = b1_end + OUTPUT_SIZE * HIDDEN_SIZE
    b2_end = w2_end + OUTPUT_SIZE

    total_loss = 0.0
    total_train_acc = 0.0
    for input_batch, param_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        input_batch = input_batch.to(device)       # (B, SUBSET_SIZE * PER_EXAMPLE_INPUT_SIZE)
        param_batch = param_batch.to(device)       # (B, PARAM_SIZE)
        batch_size = input_batch.size(0)

        predicted_params = hypernetwork(input_batch)  # (B, PARAM_SIZE)
        mse_loss = mse_criterion(predicted_params, param_batch)

        # Recover each subset's inputs and integer labels from the packed vector.
        subsets = input_batch.view(batch_size, SUBSET_SIZE, PER_EXAMPLE_INPUT_SIZE)
        subset_inputs = subsets[:, :, :INPUT_SIZE]                # (B, S, I)
        subset_labels = subsets[:, :, INPUT_SIZE:].argmax(dim=2)  # (B, S)

        # Run all B generated networks at once; every op is differentiable
        # w.r.t. predicted_params.
        w1 = predicted_params[:, :w1_end].reshape(batch_size, HIDDEN_SIZE, INPUT_SIZE)
        b1 = predicted_params[:, w1_end:b1_end]
        w2 = predicted_params[:, b1_end:w2_end].reshape(batch_size, OUTPUT_SIZE, HIDDEN_SIZE)
        b2 = predicted_params[:, w2_end:b2_end]
        hidden = torch.relu(torch.bmm(subset_inputs, w1.transpose(1, 2)) + b1.unsqueeze(1))
        logits = torch.bmm(hidden, w2.transpose(1, 2)) + b2.unsqueeze(1)  # (B, S, O)

        # With mean reduction and equal-size subsets this equals the mean of the
        # per-subset CE losses.
        ce_loss = ce_criterion(logits.reshape(-1, OUTPUT_SIZE), subset_labels.reshape(-1))
        combined_loss = mse_loss + ce_loss

        optimizer.zero_grad()
        combined_loss.backward()
        optimizer.step()

        total_loss += combined_loss.item()
        batch_acc = (logits.argmax(dim=2) == subset_labels).float().mean().item()
        total_train_acc += batch_acc * 100  # Percentage

    return total_loss / num_batches, total_train_acc / num_batches

def evaluate_hypernetwork(test_loader, hypernetwork, mse_criterion, ce_criterion):
    """
    Evaluate the hypernetwork on held-out data without updating it.

    Per batch, the loss is ``mse + ce``, as in training: ``mse`` between
    predicted and stored parameter vectors, and ``ce`` of the generated
    one-hidden-layer networks on their own subsets. All generated networks
    of a batch run together as batched matmuls instead of instantiating one
    SimpleNN per subset. Leaves ``hypernetwork`` in eval mode.

    Args:
        test_loader (DataLoader): DataLoader for test data.
        hypernetwork (nn.Module): Hypernetwork to generate model parameters.
        mse_criterion (Loss): Mean Squared Error loss for parameter regression.
        ce_criterion (Loss): Cross-Entropy loss for model performance.

    Returns:
        avg_test_loss (float): Mean combined loss over batches (nan if the loader is empty).
        avg_test_acc (float): Mean per-batch accuracy in percent (nan if empty).
    """
    hypernetwork.eval()
    num_batches = len(test_loader)
    if num_batches == 0:
        return float('nan'), float('nan')

    # Slice boundaries of a flat SimpleNN parameter vector, in model.parameters()
    # order: hidden weight (H, I), hidden bias (H), output weight (O, H), output bias (O).
    w1_end = HIDDEN_SIZE * INPUT_SIZE
    b1_end = w1_end + HIDDEN_SIZE
    w2_end = b1_end + OUTPUT_SIZE * HIDDEN_SIZE
    b2_end = w2_end + OUTPUT_SIZE

    total_loss = 0.0
    total_test_acc = 0.0
    with torch.no_grad():
        for input_batch, param_batch in tqdm(test_loader, desc="Evaluating on Test Data"):
            input_batch = input_batch.to(device)       # (B, SUBSET_SIZE * PER_EXAMPLE_INPUT_SIZE)
            param_batch = param_batch.to(device)       # (B, PARAM_SIZE)
            batch_size = input_batch.size(0)

            predicted_params = hypernetwork(input_batch)  # (B, PARAM_SIZE)
            mse_loss = mse_criterion(predicted_params, param_batch)

            subsets = input_batch.view(batch_size, SUBSET_SIZE, PER_EXAMPLE_INPUT_SIZE)
            subset_inputs = subsets[:, :, :INPUT_SIZE]                # (B, S, I)
            subset_labels = subsets[:, :, INPUT_SIZE:].argmax(dim=2)  # (B, S)

            w1 = predicted_params[:, :w1_end].reshape(batch_size, HIDDEN_SIZE, INPUT_SIZE)
            b1 = predicted_params[:, w1_end:b1_end]
            w2 = predicted_params[:, b1_end:w2_end].reshape(batch_size, OUTPUT_SIZE, HIDDEN_SIZE)
            b2 = predicted_params[:, w2_end:b2_end]
            hidden = torch.relu(torch.bmm(subset_inputs, w1.transpose(1, 2)) + b1.unsqueeze(1))
            logits = torch.bmm(hidden, w2.transpose(1, 2)) + b2.unsqueeze(1)  # (B, S, O)

            # With mean reduction and equal-size subsets this equals the mean of the
            # per-subset CE losses.
            ce_loss = ce_criterion(logits.reshape(-1, OUTPUT_SIZE), subset_labels.reshape(-1))
            combined_loss = mse_loss + ce_loss

            total_loss += combined_loss.item()
            batch_acc = (logits.argmax(dim=2) == subset_labels).float().mean().item()
            total_test_acc += batch_acc * 100  # Percentage

    return total_loss / num_batches, total_test_acc / num_batches

# ---------------------------
# 5. Utility Functions
# ---------------------------
def load_parameters_into_model(model, flat_params):
    """
    Copy a flat parameter vector into ``model``'s parameters in place.

    Values are consumed in ``model.parameters()`` order, the same order
    ``flatten_model_parameters`` produces.

    NOTE: this is a copy, so it is NOT differentiable with respect to
    ``flat_params``. Gradients of a loss computed with ``model`` will not
    reach whatever produced ``flat_params``.

    Args:
        model (nn.Module): The neural network model.
        flat_params (Tensor): Tensor with exactly as many elements as the model
            has parameters. It is flattened first, so any shape is accepted.

    Raises:
        ValueError: If ``flat_params`` does not have exactly that many elements.
    """
    params = list(model.parameters())
    expected = sum(p.numel() for p in params)
    flat = flat_params.reshape(-1)
    if flat.numel() != expected:
        raise ValueError(
            f"flat_params has {flat.numel()} elements, but the model has {expected} parameters"
        )
    offset = 0
    # no_grad + copy_ is the supported way to overwrite leaf parameters in place
    # (instead of going through .data).
    with torch.no_grad():
        for param in params:
            count = param.numel()
            param.copy_(flat[offset:offset + count].view_as(param))
            offset += count

# ---------------------------
# 6. Main Execution
# ---------------------------
def main():
    """
    Load data, train the hypernetwork for NUM_EPOCHS, evaluate it on the
    held-out split, and save its state_dict to
    'trained_multi_branch_hypernetwork.pt'.
    """
    print("Loading data...")
    data_list = load_data(DATA_FOLDER)
    print(f"Total data-model pairs loaded: {len(data_list)}\n")

    if not data_list:
        print("No data found in the specified DATA_FOLDER. Please ensure that the folder contains .pt files.")
        return

    train_data, test_data = split_data(data_list)
    print(f"Training data size: {len(train_data)}, Test data size: {len(test_data)}\n")

    # int(len * TRAIN_RATIO) can be 0 for very few pairs; training on an empty
    # loader would otherwise divide by zero.
    if not train_data:
        print("Training split is empty; add more data or increase TRAIN_RATIO.")
        return

    # num_workers=0 keeps loading in the main process (avoids CUDA initialization
    # issues in workers); pinned memory only helps when copying to a GPU.
    loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0, pin_memory=torch.cuda.is_available())
    train_loader = DataLoader(HypernetworkDataset(train_data), shuffle=True, **loader_kwargs)
    test_loader = DataLoader(HypernetworkDataset(test_data), shuffle=False, **loader_kwargs)

    hypernetwork = MultiBranchHypernetwork(
        num_examples=NUM_EXAMPLES,
        per_example_input_size=PER_EXAMPLE_INPUT_SIZE,
        embed_dim=EMBED_DIM,
        param_size=PARAM_SIZE,
    ).to(device)

    mse_criterion = nn.MSELoss()
    ce_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(hypernetwork.parameters(), lr=LEARNING_RATE)

    for epoch in range(NUM_EPOCHS):
        avg_loss, avg_train_acc = train_hypernetwork(
            train_loader, hypernetwork, optimizer, mse_criterion, ce_criterion, epoch
        )
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] - Combined Loss: {avg_loss:.6f}, "
              f"Training Accuracy: {avg_train_acc:.2f}%\n")

    # With TRAIN_RATIO == 1.0 there is nothing to evaluate on.
    if test_data:
        final_test_loss, final_test_acc = evaluate_hypernetwork(test_loader, hypernetwork, mse_criterion, ce_criterion)
        print(f"Final Evaluation on Test Data - Combined Loss: {final_test_loss:.6f}, "
              f"Test Accuracy: {final_test_acc:.2f}%")
    else:
        print("Test split is empty; skipping final evaluation.")

    save_path = 'trained_multi_branch_hypernetwork.pt'
    torch.save({
        'hypernetwork_state_dict': hypernetwork.state_dict(),
    }, save_path)
    print(f"Hypernetwork training completed and saved to '{save_path}'.")


if __name__ == "__main__":
    main()


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
import random
from torch.utils.data import Dataset, DataLoader

# ---------------------------
# 1. Configuration Parameters
# ---------------------------
# Data / optimisation settings
DATA_FOLDER = './saved_examples'  # directory holding the saved (data, model) .pt pairs
TRAIN_RATIO = 0.99                # fraction of pairs used for training; the rest is held out
NUM_EPOCHS = 100                  # passes over the training split
LEARNING_RATE = 1e-4              # Adam step size
BATCH_SIZE = 512                  # saved subsets per mini-batch

# Shape of each target network and of the data subset it was trained on
INPUT_SIZE = 784     # flattened 28x28 image
HIDDEN_SIZE = 8      # width of the single hidden layer
OUTPUT_SIZE = 10     # number of classes
SUBSET_SIZE = 100    # examples per saved subset

# Derived sizes: each example is its pixels followed by a one-hot label.
PER_EXAMPLE_INPUT_SIZE = INPUT_SIZE + OUTPUT_SIZE         # 794
TOTAL_INPUT_SIZE = SUBSET_SIZE * PER_EXAMPLE_INPUT_SIZE   # 79,400

# Flat parameter count of a one-hidden-layer SimpleNN:
# hidden Linear (weights + biases) + output Linear (weights + biases) = 6,280 + 90.
PARAM_SIZE = (
    INPUT_SIZE * HIDDEN_SIZE + HIDDEN_SIZE
    + HIDDEN_SIZE * OUTPUT_SIZE + OUTPUT_SIZE
)  # 6,370

# Hypernetwork settings
NUM_EXAMPLES = SUBSET_SIZE   # one embedding branch per example slot in a subset
EMBED_DIM = 256              # width of each per-example embedding

# Run on the GPU when one is visible.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# ---------------------------
# 2. Model Definitions
# ---------------------------

class SimpleNN(nn.Module):
    """
    Small fully connected classifier whose flattened parameters the
    hypernetwork learns to predict.

    Layout: ``num_layers`` blocks of ``Linear -> ReLU`` followed by a final
    ``Linear`` producing ``output_size`` logits. Everything lives in
    ``self.network`` (an ``nn.Sequential``), so state_dict keys are
    ``network.<idx>.weight`` / ``network.<idx>.bias``.
    """
    def __init__(self, input_size=784, hidden_size=8, num_layers=1, output_size=10):
        super().__init__()
        # Input width seen by each Linear: the first gets input_size, later ones hidden_size.
        in_widths = [input_size] + [hidden_size] * num_layers
        modules = []
        for width in in_widths[:-1]:
            modules.extend([nn.Linear(width, hidden_size), nn.ReLU()])
        modules.append(nn.Linear(in_widths[-1], output_size))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Return raw (unnormalised) logits for ``x``; the last dim becomes ``output_size``."""
        return self.network(x)

class EmbeddingBranch(nn.Module):
    """
    Two-layer MLP (Linear -> ReLU -> Linear -> ReLU) that maps one flattened
    example to an ``embed_dim``-dimensional, non-negative embedding.
    """
    def __init__(self, input_size=794, embed_dim=256):
        super().__init__()
        project_in = nn.Linear(input_size, embed_dim)
        refine = nn.Linear(embed_dim, embed_dim)
        self.network = nn.Sequential(project_in, nn.ReLU(), refine, nn.ReLU())

    def forward(self, x):
        """Map ``(batch_size, input_size)`` to ``(batch_size, embed_dim)``."""
        return self.network(x)

class FusionLayer(nn.Module):
    """
    Post-norm, transformer-encoder-style block that lets the per-example
    embeddings attend to one another.

    Structure: self-attention -> dropout -> residual add -> LayerNorm, then
    feed-forward -> dropout -> residual add -> the same LayerNorm.

    NOTE(review): a single ``LayerNorm`` instance is shared by both residual
    sub-layers (a standard encoder block uses two). It is kept as is because
    changing it alters the state_dict layout of saved hypernetworks.
    ``num_examples`` is stored but not used by ``forward``. ``embed_dim``
    must be divisible by the 8 attention heads.
    """
    def __init__(self, embed_dim=256, num_examples=100):
        super().__init__()
        self.num_examples = num_examples
        # Assignment order fixes parameter order and init RNG consumption; keep it stable.
        self.attention = nn.MultiheadAttention(embed_dim, num_heads=8, dropout=0.1)
        self.layer_norm = nn.LayerNorm(embed_dim)
        expand = nn.Linear(embed_dim, embed_dim)
        contract = nn.Linear(embed_dim, embed_dim)
        self.feedforward = nn.Sequential(expand, nn.ReLU(), contract)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (sequence_length=num_examples, batch_size, embed_dim);
               the attention module is built without ``batch_first`` so it
               expects this sequence-first layout.
        Returns:
            Tensor with the same shape as ``x``.
        """
        attended, _ = self.attention(x, x, x)
        after_attention = self.layer_norm(x + self.dropout(attended))
        transformed = self.feedforward(after_attention)
        return self.layer_norm(after_attention + self.dropout(transformed))

class ParameterGenerator(nn.Module):
    """
    Decode a pooled embedding into a flat parameter vector for the target network.

    Structure: ``Linear(embed_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, param_size)``.
    There is no output activation, so generated parameters are unbounded.

    Args:
        embed_dim (int): Width of the incoming embedding.
        param_size (int): Length of the generated parameter vector.
        hidden_dim (int): Width of the hidden layer. It was previously
            hard-coded to 4096; the default keeps existing checkpoints loadable.
    """
    def __init__(self, embed_dim=256, param_size=6370, hidden_dim=4096):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, param_size),
        )

    def forward(self, x):
        """Map ``(batch_size, embed_dim)`` to ``(batch_size, param_size)``."""
        return self.network(x)

class MultiBranchHypernetwork(nn.Module):
    """
    Hypernetwork mapping a subset of labelled examples to a flat parameter vector.

    Pipeline:
      1. each of the ``num_examples`` example slots has its own
         ``EmbeddingBranch`` (branches do not share weights);
      2. the per-slot embeddings are mixed by a ``FusionLayer``;
      3. the fused sequence is mean-pooled and decoded by a ``ParameterGenerator``.
    """
    def __init__(self, num_examples=100, per_example_input_size=794, embed_dim=256, param_size=6370):
        super().__init__()
        self.num_examples = num_examples
        self.per_example_input_size = per_example_input_size
        self.embed_dim = embed_dim

        # One independent branch per example slot, so slot order matters.
        branches = [
            EmbeddingBranch(input_size=per_example_input_size, embed_dim=embed_dim)
            for _ in range(num_examples)
        ]
        self.embedding_branches = nn.ModuleList(branches)
        self.fusion_layer = FusionLayer(embed_dim=embed_dim, num_examples=num_examples)
        self.param_generator = ParameterGenerator(embed_dim=embed_dim, param_size=param_size)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, num_examples * per_example_input_size)
        Returns:
            param_vector: Tensor of shape (batch_size, param_size)
        """
        per_example = x.view(x.size(0), self.num_examples, self.per_example_input_size)
        # Route slot k of every subset through branch k. Stacking on dim 0 gives the
        # (num_examples, batch_size, embed_dim) layout the fusion layer expects.
        sequence = torch.stack(
            [branch(per_example[:, slot, :]) for slot, branch in enumerate(self.embedding_branches)],
            dim=0,
        )
        # Mean-pool over the example axis, then decode to parameters.
        pooled = self.fusion_layer(sequence).mean(dim=0)
        return self.param_generator(pooled)

# ---------------------------
# 3. Dataset and DataLoader
# ---------------------------

class HypernetworkDataset(Dataset):
    """
    Map-style dataset over an in-memory list of ``(input_vector, param_vector)`` pairs.
    """
    def __init__(self, data_list):
        """
        Args:
            data_list (list): Sequence of ``(input_vector, param_vector)`` pairs,
                stored by reference (not copied).
        """
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        inputs, params = self.data_list[idx]
        return inputs, params

def load_data(data_folder):
    """
    Load every ``*.pt`` data/model pair found in ``data_folder``.

    Each file is read as a dict with keys ``'data'`` (the subset's inputs),
    ``'targets'`` (integer class labels) and ``'model_state_dict'`` (a
    SimpleNN state dict). Files whose state dict does not fit SimpleNN are
    reported and skipped.

    NOTE(review): ``torch.load`` unpickles the file; only use this on
    trusted folders.

    Args:
        data_folder (str): Directory containing the saved examples.

    Returns:
        data_list (list): ``(input_vector, param_vector)`` tuples in sorted
            file-name order. ``input_vector`` holds each example's flattened
            input followed by its one-hot label, concatenated over the subset.
            ``param_vector`` holds the flattened SimpleNN parameters, detached
            from autograd.
    """
    data_list = []
    # os.listdir order is arbitrary; sorting makes the loaded order deterministic.
    for file_name in sorted(os.listdir(data_folder)):
        if not file_name.endswith('.pt'):
            continue
        file_path = os.path.join(data_folder, file_name)
        checkpoint = torch.load(file_path, map_location='cpu')  # Load on CPU to avoid CUDA issues in workers
        data = checkpoint['data']
        targets = checkpoint['targets']
        model_state_dict = checkpoint['model_state_dict']

        model = SimpleNN(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE, num_layers=1, output_size=OUTPUT_SIZE)
        try:
            model.load_state_dict(model_state_dict)
        except RuntimeError as e:
            print(f"Error loading state_dict for file {file_name}: {e}")
            continue  # Skip this file if there's an error

        # detach(): without it every stored target keeps an autograd link to its
        # throwaway SimpleNN. The MSE backward pass would then also run into, and
        # accumulate .grad on, every loaded model.
        param_vector = flatten_model_parameters(model).detach()

        # one_hot requires int64 labels.
        labels_one_hot = torch.nn.functional.one_hot(targets.long(), num_classes=OUTPUT_SIZE).float()
        # reshape(n, -1) also accepts image-shaped data such as (n, 1, 28, 28).
        flat_data = data.reshape(data.size(0), -1)
        input_vector = torch.cat([flat_data, labels_one_hot], dim=1).view(-1)

        data_list.append((input_vector, param_vector))
    return data_list

def split_data(data_list, train_ratio=TRAIN_RATIO, seed=None):
    """
    Randomly split data-model pairs into training and test lists.

    A shuffled copy is split, so the caller's list keeps its original order.
    Previously the caller's list was shuffled in place.

    Args:
        data_list (list): List of data-model pairs.
        train_ratio (float): Fraction in [0, 1] used for training. The split
            index is ``int(len(data_list) * train_ratio)``, i.e. rounded down.
        seed (int | None): If given, shuffle with a private ``random.Random(seed)``
            for a reproducible split. If None, use the global ``random`` module.

    Returns:
        train_data (list): Training data.
        test_data (list): Test data.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    shuffled = list(data_list)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(shuffled)
    split_idx = int(len(shuffled) * train_ratio)
    return shuffled[:split_idx], shuffled[split_idx:]

def flatten_model_parameters(model):
    """
    Concatenate every parameter of ``model`` into one 1-D tensor.

    Parameters are taken in ``model.parameters()`` order. The result comes
    from ``torch.cat`` and so is still attached to the autograd graph of the
    model's parameters.

    Args:
        model (nn.Module): The neural network model.

    Returns:
        flat_params (Tensor): Flattened parameter vector.
    """
    return torch.cat([tensor.view(-1) for tensor in model.parameters()])

# ---------------------------
# 4. Training and Evaluation Functions
# ---------------------------

def train_hypernetwork(train_loader, hypernetwork, optimizer, mse_criterion, ce_criterion):
    """
    Train the hypernetwork for one epoch on ``mse + ce``.

    ``mse`` compares predicted and stored parameter vectors. ``ce`` is the
    cross-entropy of the *generated* one-hidden-layer target networks on their
    own subsets. The generated networks are evaluated functionally, straight
    from the predicted parameter tensor with batched matmuls, so the CE term
    back-propagates into the hypernetwork. Copying predictions into fresh
    SimpleNN modules via ``param.data.copy_`` severs the autograd graph, which
    meant the CE term had no effect on training.

    Args:
        train_loader (DataLoader): DataLoader for training data.
        hypernetwork (nn.Module): Hypernetwork to generate model parameters.
        optimizer (Optimizer): Optimizer for the hypernetwork.
        mse_criterion (Loss): Mean Squared Error loss for parameter regression.
        ce_criterion (Loss): Cross-Entropy loss for model performance.

    Returns:
        avg_loss (float): Mean combined loss over batches (nan if the loader is empty).
        avg_train_acc (float): Mean per-batch accuracy in percent (nan if empty).
    """
    hypernetwork.train()
    num_batches = len(train_loader)
    if num_batches == 0:
        return float('nan'), float('nan')

    # Slice boundaries of a flat SimpleNN parameter vector, in model.parameters()
    # order: hidden weight (H, I), hidden bias (H), output weight (O, H), output bias (O).
    w1_end = HIDDEN_SIZE * INPUT_SIZE
    b1_end = w1_end + HIDDEN_SIZE
    w2_end = b1_end + OUTPUT_SIZE * HIDDEN_SIZE
    b2_end = w2_end + OUTPUT_SIZE

    total_loss = 0.0
    total_train_acc = 0.0
    for input_batch, param_batch in train_loader:
        input_batch = input_batch.to(device)       # (B, SUBSET_SIZE * PER_EXAMPLE_INPUT_SIZE)
        param_batch = param_batch.to(device)       # (B, PARAM_SIZE)
        batch_size = input_batch.size(0)

        predicted_params = hypernetwork(input_batch)  # (B, PARAM_SIZE)
        mse_loss = mse_criterion(predicted_params, param_batch)

        # Recover each subset's inputs and integer labels from the packed vector.
        subsets = input_batch.view(batch_size, SUBSET_SIZE, PER_EXAMPLE_INPUT_SIZE)
        subset_inputs = subsets[:, :, :INPUT_SIZE]                # (B, S, I)
        subset_labels = subsets[:, :, INPUT_SIZE:].argmax(dim=2)  # (B, S)

        # Run all B generated networks at once; every op is differentiable
        # w.r.t. predicted_params.
        w1 = predicted_params[:, :w1_end].reshape(batch_size, HIDDEN_SIZE, INPUT_SIZE)
        b1 = predicted_params[:, w1_end:b1_end]
        w2 = predicted_params[:, b1_end:w2_end].reshape(batch_size, OUTPUT_SIZE, HIDDEN_SIZE)
        b2 = predicted_params[:, w2_end:b2_end]
        hidden = torch.relu(torch.bmm(subset_inputs, w1.transpose(1, 2)) + b1.unsqueeze(1))
        logits = torch.bmm(hidden, w2.transpose(1, 2)) + b2.unsqueeze(1)  # (B, S, O)

        # With mean reduction and equal-size subsets this equals the mean of the
        # per-subset CE losses.
        ce_loss = ce_criterion(logits.reshape(-1, OUTPUT_SIZE), subset_labels.reshape(-1))
        combined_loss = mse_loss + ce_loss

        optimizer.zero_grad()
        combined_loss.backward()
        optimizer.step()

        total_loss += combined_loss.item()
        batch_acc = (logits.argmax(dim=2) == subset_labels).float().mean().item()
        total_train_acc += batch_acc * 100  # Percentage

    return total_loss / num_batches, total_train_acc / num_batches

def evaluate_hypernetwork(test_loader, hypernetwork, mse_criterion, ce_criterion):
    """
    Evaluates the hypernetwork on the test data.

    For every subset in every batch, the predicted parameter vector is loaded
    into a SimpleNN which is then run on that subset's own examples. The loss
    is MSE(predicted params, true params) + mean cross-entropy of the
    generated models.

    Metrics are pooled over the whole loader rather than averaged per batch,
    so a short final batch does not get the same weight as a full one.

    Args:
        test_loader (DataLoader): DataLoader for test data; yields
            (input_batch, param_batch) pairs.
        hypernetwork (nn.Module): Hypernetwork to generate model parameters.
        mse_criterion (Loss): Mean Squared Error loss for parameter regression.
        ce_criterion (Loss): Cross-Entropy loss for model performance.

    Returns:
        avg_test_loss (float): Combined loss averaged over subsets (each
            batch's loss weighted by its number of subsets), or NaN if the
            loader yields nothing.
        avg_test_acc (float): Test accuracy in percent over every example of
            every subset, or NaN if the loader yields nothing.
    """
    hypernetwork.eval()
    weighted_loss_sum = 0.0
    num_subsets = 0
    correct_predictions = 0
    total_predictions = 0

    # One target model whose weights are overwritten for each subset; building
    # a fresh SimpleNN per subset only costs allocations.
    generated_model = SimpleNN(input_size=784, hidden_size=8, num_layers=1, output_size=10).to(device)

    with torch.no_grad():
        for input_batch, param_batch in test_loader:
            input_batch = input_batch.to(device)       # Shape: (B, 79,400)
            param_batch = param_batch.to(device)       # Shape: (B, 6370)
            batch_size = input_batch.size(0)

            # Generate predicted parameters
            predicted_params = hypernetwork(input_batch)        # Shape: (B, 6370)

            # MSE between predicted and true parameter vectors
            mse_loss = mse_criterion(predicted_params, param_batch)

            ce_losses = []
            for i in range(batch_size):
                # Not differentiable, which is fine here: no gradients are needed.
                load_parameters_into_model(generated_model, predicted_params[i])  # Shape: (6370,)

                # Each row = 784 input features followed by a 10-way one-hot label
                examples = input_batch[i].view(SUBSET_SIZE, PER_EXAMPLE_INPUT_SIZE)
                subset_input = examples[:, :INPUT_SIZE]                # Shape: (100, 784)
                subset_labels = examples[:, INPUT_SIZE:].argmax(dim=1)  # Shape: (100,)

                outputs = generated_model(subset_input)  # Shape: (100, 10)
                ce_losses.append(ce_criterion(outputs, subset_labels))

                _, predicted_labels = torch.max(outputs, 1)
                correct_predictions += (predicted_labels == subset_labels).sum().item()
                total_predictions += subset_labels.size(0)

            ce_loss = torch.stack(ce_losses).mean()
            combined_loss = mse_loss + ce_loss

            # Both terms are means over the batch's subsets, so weight by size.
            weighted_loss_sum += combined_loss.item() * batch_size
            num_subsets += batch_size

    if num_subsets == 0:
        # Empty test set: avoid ZeroDivisionError and report "no measurement".
        return float('nan'), float('nan')

    avg_test_loss = weighted_loss_sum / num_subsets
    avg_test_acc = (correct_predictions / total_predictions) * 100  # Percentage
    return avg_test_loss, avg_test_acc

# ---------------------------
# 5. Utility Functions
# ---------------------------
def load_parameters_into_model(model, flat_params):
    """
    Copies a flat parameter vector into the model's parameters, in place.

    Values are consumed in ``model.parameters()`` order, each slice reshaped
    to the corresponding parameter's shape.

    The copy happens under ``torch.no_grad()`` and is NOT differentiable: no
    gradient flows from the model's outputs back to ``flat_params``. Use this
    for inference/evaluation, not when a loss on the model's outputs must
    train whatever produced ``flat_params``.

    Args:
        model (nn.Module): The neural network model; its parameters are
            overwritten.
        flat_params (Tensor): Tensor with exactly as many elements as the
            model has parameters (it is flattened before slicing).

    Raises:
        ValueError: If ``flat_params`` does not contain exactly the model's
            total number of parameter elements.
    """
    params = list(model.parameters())
    expected = sum(p.numel() for p in params)
    if flat_params.numel() != expected:
        raise ValueError(
            f"flat_params has {flat_params.numel()} elements, "
            f"but the model has {expected} parameter elements"
        )

    with torch.no_grad():
        flat = flat_params.reshape(-1)
        current_index = 0
        for param in params:
            param_size = param.numel()
            param.copy_(flat[current_index:current_index + param_size].reshape_as(param))
            current_index += param_size

# ---------------------------
# 6. Main Execution
# ---------------------------
def main():
    """
    Entry point: load the data-model pairs, train the hypernetwork for
    NUM_EPOCHS (evaluating on the held-out split after each epoch when it is
    non-empty), and save the trained weights.
    """
    # Load data
    print("Loading data...")
    data_list = load_data(DATA_FOLDER)
    print(f"Total data-model pairs loaded: {len(data_list)}\n")

    if not data_list:
        print("No data found in the specified DATA_FOLDER. Please ensure that the folder contains .pt files.")
        return

    # Split data
    train_data, test_data = split_data(data_list)
    print(f"Training data size: {len(train_data)}, Test data size: {len(test_data)}\n")

    if not train_data:
        # The train/eval helpers average by the number of batches, so an empty
        # split would be a division by zero.
        print("Training split is empty; add more .pt files or increase TRAIN_RATIO.")
        return

    # NOTE(review): pin_memory only applies to CPU tensors. If HypernetworkDataset
    # yields tensors that are already on the GPU (load_data may use
    # map_location=device), pinning can fail -- confirm and set False if so.
    pin_memory = torch.cuda.is_available()

    # Create Datasets and DataLoaders
    train_loader = DataLoader(
        HypernetworkDataset(train_data),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,  # Set to 0 to avoid CUDA initialization issues
        pin_memory=pin_memory,
    )
    test_loader = None
    if test_data:
        test_loader = DataLoader(
            HypernetworkDataset(test_data),
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=0,  # Set to 0 to avoid CUDA initialization issues
            pin_memory=pin_memory,
        )
    else:
        print("Test split is empty; per-epoch evaluation will be skipped.\n")

    # Initialize Hypernetwork
    hypernetwork = MultiBranchHypernetwork(
        num_examples=NUM_EXAMPLES,
        per_example_input_size=PER_EXAMPLE_INPUT_SIZE,
        embed_dim=EMBED_DIM,
        param_size=PARAM_SIZE
    ).to(device)

    # Define loss functions and optimizer
    mse_criterion = nn.MSELoss()
    ce_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(hypernetwork.parameters(), lr=LEARNING_RATE)

    # Training loop
    for epoch in range(NUM_EPOCHS):
        avg_loss, avg_train_acc = train_hypernetwork(
            train_loader, hypernetwork, optimizer, mse_criterion, ce_criterion
        )
        message = (f"Epoch [{epoch+1}/{NUM_EPOCHS}] - Loss: {avg_loss:.6f}, "
                   f"Train Acc: {avg_train_acc:.2f}%")
        if test_loader is not None:
            avg_test_loss, avg_test_acc = evaluate_hypernetwork(
                test_loader, hypernetwork, mse_criterion, ce_criterion
            )
            message += f", Test Loss: {avg_test_loss:.6f}, Test Acc: {avg_test_acc:.2f}%"
        print(message)

    # Save the trained hypernetwork
    save_path = 'trained_multi_branch_hypernetwork.pt'
    torch.save({
        'hypernetwork_state_dict': hypernetwork.state_dict(),
    }, save_path)
    print(f"\nHypernetwork training completed and saved to '{save_path}'.")

if __name__ == "__main__":
    main()

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
import random
from tqdm import tqdm

# ---------------------------
# 1. Configuration Parameters
# ---------------------------
# Data / optimisation settings
DATA_FOLDER = './large'     # Directory holding the saved data-model pairs (*.pt files)
TRAIN_RATIO = 0.9           # Fraction of pairs used for training; the remainder is the test split
NUM_EPOCHS = 100            # Passes over the training split
LEARNING_RATE = 1e-4        # Adam learning rate

# Target model (SimpleNN) and subset geometry
INPUT_SIZE = 784            # 28x28 image, flattened
HIDDEN_SIZE = 8             # Width of the single hidden layer
OUTPUT_SIZE = 10            # Number of classes
SUBSET_SIZE = 100           # Labelled examples per data-model pair

# Derived sizes
# One example = image pixels followed by its one-hot label: 784 + 10 = 794
PER_EXAMPLE_INPUT_SIZE = INPUT_SIZE + OUTPUT_SIZE
# A whole subset, flattened: 100 * 794 = 79,400
TOTAL_INPUT_SIZE = PER_EXAMPLE_INPUT_SIZE * SUBSET_SIZE

# Flat parameter count of SimpleNN. Each Linear layer holds (fan_in + 1) * fan_out
# values (weights plus bias): 8 * 785 + 10 * 9 = 6,280 + 90 = 6,370
PARAM_SIZE = HIDDEN_SIZE * (INPUT_SIZE + 1) + OUTPUT_SIZE * (HIDDEN_SIZE + 1)

# Hypernetwork settings
NUM_EXAMPLES = SUBSET_SIZE  # One embedding branch per example in a subset
EMBED_DIM = 256             # Per-example embedding width

# Device selection
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# ---------------------------
# 2. Model Definitions
# ---------------------------

class SimpleNN(nn.Module):
    """
    Small fully connected classifier whose weights the hypernetwork learns to predict.

    Architecture: ``num_layers`` blocks of (Linear -> ReLU) of width
    ``hidden_size``, followed by a final Linear projection to ``output_size``
    outputs with no activation. All layers live in ``self.network`` (an
    ``nn.Sequential``), so state-dict keys are ``network.<index>.weight`` /
    ``network.<index>.bias``.
    """

    def __init__(self, input_size=784, hidden_size=8, num_layers=1, output_size=10):
        super().__init__()
        # Feature widths seen by each hidden Linear layer, input first.
        widths = [input_size] + [hidden_size] * num_layers
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Output head: no activation, so raw logits are returned.
        modules.append(nn.Linear(widths[-1], output_size))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Return the network's raw outputs (logits) for the input batch ``x``."""
        logits = self.network(x)
        return logits

class EmbeddingBranch(nn.Module):
    """
    Per-example encoder: two (Linear -> ReLU) layers that map one example's
    ``input_size`` features to an ``embed_dim``-dimensional embedding.

    Input ``(batch_size, input_size)`` -> output ``(batch_size, embed_dim)``.
    """

    def __init__(self, input_size=794, embed_dim=256):
        super().__init__()
        project_in = nn.Linear(input_size, embed_dim)
        refine = nn.Linear(embed_dim, embed_dim)
        self.network = nn.Sequential(project_in, nn.ReLU(), refine, nn.ReLU())

    def forward(self, x):
        """Embed a batch of examples; the final ReLU makes every output non-negative."""
        embedding = self.network(x)
        return embedding

class FusionLayer(nn.Module):
    """
    Single post-norm self-attention block that mixes the per-example embeddings.

    ``nn.MultiheadAttention`` is built without ``batch_first``, so the input is
    sequence-first: ``(num_examples, batch_size, embed_dim)``. The output has the
    same shape. ``embed_dim`` must be divisible by the 8 attention heads.
    """

    def __init__(self, embed_dim=256, num_examples=100):
        super().__init__()
        self.num_examples = num_examples  # stored only; forward() does not read it
        self.attention = nn.MultiheadAttention(embed_dim, num_heads=8, dropout=0.1)
        # NOTE(review): this one LayerNorm is applied after BOTH residual
        # sub-layers, so they share scale/shift parameters. A standard
        # transformer block uses two separate norms; splitting it would change
        # the state_dict keys, so confirm intent before changing.
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.feedforward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
        )
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (num_examples, batch_size, embed_dim).
        Returns:
            Tensor of the same shape: residual self-attention, then a residual
            feed-forward layer, each followed by layer normalisation.
        """
        attended = self.attention(x, x, x)[0]  # attention weights are discarded
        hidden = self.layer_norm(x + self.dropout(attended))
        transformed = self.feedforward(hidden)
        return self.layer_norm(hidden + self.dropout(transformed))

class ParameterGenerator(nn.Module):
    """
    Decodes a pooled embedding into a flat parameter vector.

    Two-layer MLP: ``embed_dim -> hidden_dim -> param_size`` with a ReLU in
    between and no output activation.

    Args:
        embed_dim (int): Width of the incoming embedding.
        param_size (int): Length of the generated parameter vector.
        hidden_dim (int): Width of the hidden layer. Defaults to 4096, the
            previously hard-coded value, so existing checkpoints still load.
            The second Linear alone holds ``hidden_dim * param_size`` weights,
            so this is the main knob for the generator's size.
    """

    def __init__(self, embed_dim=256, param_size=6370, hidden_dim=4096):
        super(ParameterGenerator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, param_size)
        )

    def forward(self, x):
        """Map ``(batch_size, embed_dim)`` to ``(batch_size, param_size)``."""
        return self.network(x)

class MultiBranchHypernetwork(nn.Module):
    """
    Hypernetwork that maps a subset of labelled examples to a flat parameter vector.

    Pipeline:
      1. position ``i`` of the subset is encoded by its own ``EmbeddingBranch``
         (one branch per position; weights are not shared across positions);
      2. the ``num_examples`` embeddings are mixed by a ``FusionLayer``;
      3. the fused sequence is mean-pooled and decoded by a
         ``ParameterGenerator`` into ``param_size`` values.
    """

    def __init__(self, num_examples=100, per_example_input_size=794, embed_dim=256, param_size=6370):
        super().__init__()
        self.num_examples = num_examples
        self.per_example_input_size = per_example_input_size
        self.embed_dim = embed_dim

        branches = [
            EmbeddingBranch(input_size=per_example_input_size, embed_dim=embed_dim)
            for _ in range(num_examples)
        ]
        self.embedding_branches = nn.ModuleList(branches)
        self.fusion_layer = FusionLayer(embed_dim=embed_dim, num_examples=num_examples)
        self.param_generator = ParameterGenerator(embed_dim=embed_dim, param_size=param_size)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, num_examples * per_example_input_size).
        Returns:
            Tensor of shape (batch_size, param_size).
        """
        # (batch_size, num_examples, per_example_input_size)
        per_example = x.view(x.size(0), self.num_examples, self.per_example_input_size)

        # Sequence-first stack, as FusionLayer expects: (num_examples, batch_size, embed_dim)
        sequence = torch.stack(
            [branch(per_example[:, position, :]) for position, branch in enumerate(self.embedding_branches)],
            dim=0,
        )

        # Mean-pool over the example axis -> (batch_size, embed_dim)
        pooled = self.fusion_layer(sequence).mean(dim=0)
        return self.param_generator(pooled)

# ---------------------------
# 3. Data Loading Functions
# ---------------------------
def load_data(data_folder):
    """
    Loads the data-model pairs from the specified folder.

    Each ``*.pt`` file must be a checkpoint dict with keys ``'data'``,
    ``'targets'`` and ``'model_state_dict'``. Files are visited in sorted name
    order so the result is reproducible (``os.listdir`` order is arbitrary).
    A file is reported and skipped when its state dict does not fit
    ``SimpleNN`` or when it does not contain exactly ``SUBSET_SIZE`` examples
    of ``INPUT_SIZE`` features each.

    NOTE: ``torch.load`` unpickles the file, so only point this at trusted data.

    Args:
        data_folder (str): Directory containing the saved examples.

    Returns:
        data_list (list): List of tuples (input_vector, param_vector).
            input_vector is the flattened (SUBSET_SIZE, INPUT_SIZE + OUTPUT_SIZE)
            float matrix of features followed by one-hot labels; param_vector
            is the stored model's flat parameter vector, detached from autograd.
    """
    data_list = []
    for file_name in sorted(os.listdir(data_folder)):
        if not file_name.endswith('.pt'):
            continue
        file_path = os.path.join(data_folder, file_name)
        checkpoint = torch.load(file_path, map_location=device)
        data = checkpoint['data']                          # expected (100, 784)
        targets = checkpoint['targets']                    # expected (100,)
        model_state_dict = checkpoint['model_state_dict']  # SimpleNN state dict

        # Initialize the model and load state_dict
        model = SimpleNN(input_size=784, hidden_size=8, num_layers=1, output_size=10).to(device)
        try:
            model.load_state_dict(model_state_dict)
        except RuntimeError as e:
            print(f"Error loading state_dict for file {file_name}: {e}")
            continue  # Skip this file if there's an error

        # One feature row per example; float so it concatenates cleanly with the
        # float one-hot labels and matches the float32 hypernetwork.
        data = data.reshape(data.size(0), -1).float()
        if data.shape != (SUBSET_SIZE, INPUT_SIZE) or targets.numel() != SUBSET_SIZE:
            # A mis-sized subset would only fail later, inside the hypernetwork's view().
            print(f"Skipping {file_name}: expected {SUBSET_SIZE} examples of {INPUT_SIZE} features, "
                  f"got data {tuple(data.shape)} and {targets.numel()} targets")
            continue

        # Fixed regression target: detach so no loss can backpropagate into `model`.
        param_vector = flatten_model_parameters(model).detach()  # Shape: (6370,)

        # one_hot requires an integer (int64) tensor
        labels_one_hot = torch.nn.functional.one_hot(
            targets.reshape(-1).long(), num_classes=OUTPUT_SIZE
        ).float()                                                          # Shape: (100, 10)
        input_vector = torch.cat([data, labels_one_hot], dim=1).view(-1)  # Shape: (79,400,)

        data_list.append((input_vector, param_vector))
    return data_list

def split_data(data_list, train_ratio=TRAIN_RATIO, seed=None):
    """
    Splits the data into training and test sets.

    A shuffled copy is split, so the caller's ``data_list`` is left unmodified.

    Args:
        data_list (list): List of data-model pairs.
        train_ratio (float): Fraction in [0, 1] of items assigned to the
            training set (the count is rounded down).
        seed (int, optional): When given, shuffle with a private
            ``random.Random(seed)`` for a reproducible split; otherwise use the
            global ``random`` module state.

    Returns:
        train_data (list): Training data.
        test_data (list): Test data.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    shuffled = list(data_list)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(shuffled)

    split_idx = int(len(shuffled) * train_ratio)
    return shuffled[:split_idx], shuffled[split_idx:]

def flatten_model_parameters(model):
    """
    Flattens all parameters of the model into a single detached 1-D vector.

    Parameters are concatenated in ``model.parameters()`` order. The result is
    a copy that is detached from autograd. It is used as a fixed regression
    target, so a loss computed on it must not send gradients into (or keep a
    graph alive for) the source model's parameters.

    Args:
        model (nn.Module): The neural network model.

    Returns:
        flat_params (Tensor): Flattened parameter vector; empty if the model
            has no parameters.
    """
    pieces = [param.detach().reshape(-1) for param in model.parameters()]
    if not pieces:
        # torch.cat rejects an empty list
        return torch.empty(0)
    return torch.cat(pieces)

# ---------------------------
# 4. Training and Evaluation Functions
# ---------------------------
def _forward_with_flat_params(model, flat_params, x):
    """
    Run ``model`` on ``x`` with weights sliced from ``flat_params``, differentiably.

    The model's own parameters are ignored; only its layer structure and
    parameter shapes are used. Gradients of the output therefore flow back
    into ``flat_params``. ``model.network`` must be an ``nn.Sequential`` of
    ``nn.Linear`` layers and parameter-free modules (e.g. ``nn.ReLU``). Values
    are consumed as weight then bias per Linear, in layer order, which is the
    same order as ``model.parameters()``.
    """
    offset = 0
    for layer in model.network:
        if isinstance(layer, nn.Linear):
            weight_numel = layer.weight.numel()
            weight = flat_params[offset:offset + weight_numel].reshape_as(layer.weight)
            offset += weight_numel
            bias = None
            if layer.bias is not None:
                bias_numel = layer.bias.numel()
                bias = flat_params[offset:offset + bias_numel].reshape_as(layer.bias)
                offset += bias_numel
            x = torch.nn.functional.linear(x, weight, bias)
        elif next(layer.parameters(), None) is not None:
            raise TypeError(f"Unsupported parameterised layer {type(layer).__name__}")
        else:
            x = layer(x)
    if offset != flat_params.numel():
        raise ValueError(
            f"flat_params has {flat_params.numel()} elements but the model consumed {offset}"
        )
    return x


def train_hypernetwork(train_data, hypernetwork, optimizer, mse_criterion, ce_criterion, epoch):
    """
    Trains the hypernetwork for one epoch using a combined loss function.

    Each training item is its own optimisation step. The hypernetwork
    predicts a flat parameter vector, and the loss is
    MSE(predicted, true parameters) + cross-entropy of a SimpleNN that runs
    with the predicted parameters on the item's own examples.

    The generated SimpleNN is evaluated functionally from ``predicted_params``,
    so the cross-entropy term back-propagates into the hypernetwork. Copying
    the prediction into a module's parameters, as ``load_parameters_into_model``
    does, detaches it, and the CE term would then never train the hypernetwork.

    Args:
        train_data (list): List of training data-model pairs.
        hypernetwork (nn.Module): Hypernetwork to generate model parameters.
        optimizer (Optimizer): Optimizer for the hypernetwork.
        mse_criterion (Loss): Mean Squared Error loss for parameter regression.
        ce_criterion (Loss): Cross-Entropy loss for model performance.
        epoch (int): Current epoch number (0-based; used for the progress bar).

    Returns:
        avg_loss (float): Average combined loss over the epoch (NaN if
            ``train_data`` is empty).
        avg_train_acc (float): Average training accuracy in percent (NaN if
            ``train_data`` is empty).
    """
    hypernetwork.train()
    num_batches = len(train_data)
    if num_batches == 0:
        # Nothing to train on; avoid ZeroDivisionError in the averages below.
        return float('nan'), float('nan')

    total_loss = 0.0
    total_train_acc = 0.0

    # Built once per epoch: supplies layer structure and parameter shapes only.
    # Its own randomly initialised weights never enter the computation.
    template_model = SimpleNN(input_size=784, hidden_size=8, num_layers=1, output_size=10).to(device)

    for idx in tqdm(range(num_batches), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        input_vector, param_vector = train_data[idx]
        input_vector = input_vector.to(device).unsqueeze(0)  # Shape: (1, 79,400)
        # The true parameters are a constant target; never backprop into them.
        param_vector = param_vector.to(device).detach()      # Shape: (6370,)

        # Generate predicted parameters
        predicted_params = hypernetwork(input_vector).squeeze(0)  # Shape: (6370,)

        # MSE between predicted and true parameter vectors
        mse_loss = mse_criterion(predicted_params, param_vector)

        # Each row = 784 input features followed by a 10-way one-hot label
        examples = input_vector.view(SUBSET_SIZE, PER_EXAMPLE_INPUT_SIZE)
        input_data = examples[:, :INPUT_SIZE]                # Shape: (100, 784)
        labels = examples[:, INPUT_SIZE:].argmax(dim=1)      # Shape: (100,)

        # Differentiable forward pass of the generated network
        outputs = _forward_with_flat_params(template_model, predicted_params, input_data)  # (100, 10)

        ce_loss = ce_criterion(outputs, labels)
        combined_loss = mse_loss + ce_loss

        # Backpropagation: both terms now reach the hypernetwork's parameters
        optimizer.zero_grad()
        combined_loss.backward()
        optimizer.step()

        total_loss += combined_loss.item()

        # Training accuracy of the generated model on its own subset
        _, predicted_labels = torch.max(outputs, 1)
        total_train_acc += (predicted_labels == labels).sum().item() / labels.size(0)

    avg_loss = total_loss / num_batches
    avg_train_acc = (total_train_acc / num_batches) * 100  # Percentage
    return avg_loss, avg_train_acc

def evaluate_hypernetwork(test_data, hypernetwork, mse_criterion, ce_criterion):
    """
    Evaluates the hypernetwork on the test data.

    For each test item, the predicted parameter vector is loaded into a
    SimpleNN that is run on the item's own examples. The item's loss is
    MSE(predicted, true parameters) + cross-entropy of that model. No
    gradients are needed, so the non-differentiable parameter copy is fine.

    Args:
        test_data (list): List of test data-model pairs.
        hypernetwork (nn.Module): Hypernetwork to generate model parameters.
        mse_criterion (Loss): Mean Squared Error loss for parameter regression.
        ce_criterion (Loss): Cross-Entropy loss for model performance.

    Returns:
        avg_test_loss (float): Average combined loss per test item (NaN if
            ``test_data`` is empty).
        avg_test_acc (float): Average per-item accuracy in percent (NaN if
            ``test_data`` is empty).
    """
    hypernetwork.eval()
    num_batches = len(test_data)
    if num_batches == 0:
        # Empty test split: avoid ZeroDivisionError and report "no measurement".
        return float('nan'), float('nan')

    total_loss = 0.0
    total_test_acc = 0.0

    # One target model whose weights are overwritten for every item.
    generated_model = SimpleNN(input_size=784, hidden_size=8, num_layers=1, output_size=10).to(device)

    with torch.no_grad():
        for input_vector, param_vector in tqdm(test_data, desc="Evaluating on Test Data"):
            input_vector = input_vector.to(device).unsqueeze(0)  # Shape: (1, 79,400)
            param_vector = param_vector.to(device)               # Shape: (6370,)

            # Generate predicted parameters
            predicted_params = hypernetwork(input_vector).squeeze(0)  # Shape: (6370,)
            load_parameters_into_model(generated_model, predicted_params)

            # MSE between predicted and true parameter vectors
            mse_loss = mse_criterion(predicted_params, param_vector)

            # Each row = 784 input features followed by a 10-way one-hot label
            examples = input_vector.view(SUBSET_SIZE, PER_EXAMPLE_INPUT_SIZE)
            input_data = examples[:, :INPUT_SIZE]                # Shape: (100, 784)
            labels = examples[:, INPUT_SIZE:].argmax(dim=1)      # Shape: (100,)

            outputs = generated_model(input_data)  # Shape: (100, 10)
            ce_loss = ce_criterion(outputs, labels)

            total_loss += (mse_loss + ce_loss).item()

            _, predicted_labels = torch.max(outputs, 1)
            total_test_acc += (predicted_labels == labels).sum().item() / labels.size(0)

    avg_test_loss = total_loss / num_batches
    avg_test_acc = (total_test_acc / num_batches) * 100  # Percentage

    return avg_test_loss, avg_test_acc

# ---------------------------
# 5. Utility Functions
# ---------------------------
def load_parameters_into_model(model, flat_params):
    """
    Copies a flat parameter vector into the model's parameters, in place.

    Values are consumed in ``model.parameters()`` order, each slice reshaped
    to the corresponding parameter's shape.

    The copy happens under ``torch.no_grad()`` and is NOT differentiable: no
    gradient flows from the model's outputs back to ``flat_params``. Use this
    for inference/evaluation, not when a loss on the model's outputs must
    train whatever produced ``flat_params``.

    Args:
        model (nn.Module): The neural network model; its parameters are
            overwritten.
        flat_params (Tensor): Tensor with exactly as many elements as the
            model has parameters (it is flattened before slicing).

    Raises:
        ValueError: If ``flat_params`` does not contain exactly the model's
            total number of parameter elements.
    """
    params = list(model.parameters())
    expected = sum(p.numel() for p in params)
    if flat_params.numel() != expected:
        raise ValueError(
            f"flat_params has {flat_params.numel()} elements, "
            f"but the model has {expected} parameter elements"
        )

    with torch.no_grad():
        flat = flat_params.reshape(-1)
        current_index = 0
        for param in params:
            param_size = param.numel()
            param.copy_(flat[current_index:current_index + param_size].reshape_as(param))
            current_index += param_size

# ---------------------------
# 6. Main Execution
# ---------------------------
def main():
    """
    Entry point: load the data-model pairs, train the hypernetwork for
    NUM_EPOCHS, evaluate it once on the held-out split (when non-empty), and
    save the trained weights.
    """
    # Load data
    print("Loading data...")
    data_list = load_data(DATA_FOLDER)
    print(f"Total data-model pairs loaded: {len(data_list)}\n")

    if not data_list:
        print("No data found in the specified DATA_FOLDER. Please ensure that the folder contains .pt files.")
        return

    # Split data
    train_data, test_data = split_data(data_list)
    print(f"Training data size: {len(train_data)}, Test data size: {len(test_data)}\n")

    if not train_data:
        # With very few files, int(len(data_list) * TRAIN_RATIO) can be 0, and
        # the per-epoch averages would then divide by zero.
        print("Training split is empty; add more .pt files or increase TRAIN_RATIO.")
        return

    # Initialize Hypernetwork
    hypernetwork = MultiBranchHypernetwork(
        num_examples=NUM_EXAMPLES,
        per_example_input_size=PER_EXAMPLE_INPUT_SIZE,
        embed_dim=EMBED_DIM,
        param_size=PARAM_SIZE
    ).to(device)

    # Define loss functions and optimizer
    mse_criterion = nn.MSELoss()
    ce_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(hypernetwork.parameters(), lr=LEARNING_RATE)

    # Training loop
    for epoch in range(NUM_EPOCHS):
        avg_loss, avg_train_acc = train_hypernetwork(
            train_data, hypernetwork, optimizer, mse_criterion, ce_criterion, epoch
        )
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] - Combined Loss: {avg_loss:.6f}, "
              f"Training Accuracy: {avg_train_acc:.2f}%\n")

    # Final Evaluation on Test Data (skipped when the split is empty, which
    # would otherwise be a division by zero inside the evaluation averages)
    if test_data:
        final_test_loss, final_test_acc = evaluate_hypernetwork(test_data, hypernetwork, mse_criterion, ce_criterion)
        print(f"Final Evaluation on Test Data - Combined Loss: {final_test_loss:.6f}, "
              f"Test Accuracy: {final_test_acc:.2f}%")
    else:
        print("Test split is empty; skipping final evaluation.")

    # Save the trained hypernetwork
    save_path = 'trained_multi_branch_hypernetwork.pt'
    torch.save({
        'hypernetwork_state_dict': hypernetwork.state_dict(),
    }, save_path)
    print(f"Hypernetwork training completed and saved to '{save_path}'.")

if __name__ == "__main__":
    main()


Using device: cuda
Loading data...
Total data-model pairs loaded: 547

Training data size: 492, Test data size: 55



Epoch 1/100:  51%|█████     | 252/492 [00:16<00:15, 15.65it/s]


KeyboardInterrupt: 