In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
import random
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np



# Define transforms (resize all images to 224x224)
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),  # Resize all images to 224x224
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
# ])
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])
# Path to your ImageNet data
data_dir = '/home/pratibha/nas_vision/vit_nas_imgnet/imagenet200'

# Load ImageNet dataset and filter only the first 200 classes
filtered_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
# Use only the first 200 classes

In [2]:
train_size = int(0.8 * len(filtered_dataset))
test_size = len(filtered_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(filtered_dataset, [train_size, test_size])

# Create DataLoader for training and testing
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Check the number of samples in each set
print(f"Training set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")

Training set size: 99548
Test set size: 24888


## dynamic vit model

In [None]:

# import torch
# import torch.nn as nn

# class DynamicPatchEmbed(nn.Module):
#     def __init__(self, img_size=224, patch_size=16, embed_dim=768):
#         super().__init__()
#         self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
#         self.num_patches = (img_size // patch_size) ** 2

#     def forward(self, x):
#         x = self.proj(x)
#         return x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)




# class DynamicMultiHeadAttention(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super().__init__()
#         self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
#         self.proj = nn.Linear(embed_dim, embed_dim)
#         self.scale = (embed_dim // num_heads) ** -0.5
#         self.num_heads = num_heads  # Store num_heads as a class attribute

#         # Ensure that the number of heads divides the embedding dimension
#         assert embed_dim % num_heads == 0, f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"

#     def forward(self, x):
#         B, N, C = x.shape
#         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

#         q, k, v = qkv[0], qkv[1], qkv[2]
#         attn = (q @ k.transpose(-2, -1)) * self.scale
#         attn = attn.softmax(dim=-1)
#         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
#         return self.proj(x)

# class MLPBlock(nn.Module):  
#     def __init__(self, embed_dim, mlp_ratio):
#         super().__init__()
#         hidden_dim = int(embed_dim * mlp_ratio)
#         self.fc1 = nn.Linear(embed_dim, hidden_dim)  # Matches `mlp.fc1`
#         self.act = nn.GELU()
#         self.fc2 = nn.Linear(hidden_dim, embed_dim)  # Matches `mlp.fc2`

#     def forward(self, x):
#         return self.fc2(self.act(self.fc1(x)))

# class DynamicTransformerBlock(nn.Module):
#     def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
#         super().__init__()
#         self.norm1 = nn.LayerNorm(embed_dim)
#         self.attn = DynamicMultiHeadAttention(embed_dim, num_heads)
#         self.norm2 = nn.LayerNorm(embed_dim)
        
#         #  Fix: Wrap MLP inside a separate module to match ViT
#         self.mlp = MLPBlock(embed_dim, mlp_ratio)  

#     def forward(self, x):
#         x = x + self.attn(self.norm1(x))
#         x = x + self.mlp(self.norm2(x))
#         return x

# class DynamicViT(nn.Module):
#     def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, num_classes=10):
#         super().__init__()
#         self.patch_embed = DynamicPatchEmbed(img_size, patch_size, embed_dim)
        
#         #  Fix: Correct positional embedding key
#         self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))
        
#         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
#         self.blocks = nn.ModuleList([DynamicTransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)])
#         self.norm = nn.LayerNorm(embed_dim)
#         self.head = nn.Linear(embed_dim, num_classes)

#     def forward(self, x):
#         x = self.patch_embed(x)
#         B = x.shape[0]

#         # Add class token
#         cls_tokens = self.cls_token.expand(B, -1, -1)
#         x = torch.cat((cls_tokens, x), dim=1)
        
#         x = x + self.pos_embed

#         for block in self.blocks:
#             x = block(x)

#         x = self.norm(x[:, 0])
#         return self.head(x)

In [None]:
# from timm import create_model

# # Load pretrained ViT-Base
# pretrained_vit = create_model("vit_base_patch16_224", pretrained=True)
# pretrained_state_dict = pretrained_vit.state_dict()

# # Initialize our super network
# super_vit = DynamicViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, num_classes=1000)

# ## check this whether it is 1000 or 200 and finetune 

# # Filter matching weights
# model_state_dict = super_vit.state_dict()
# filtered_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}

# # Load pretrained weights
# super_vit.load_state_dict(filtered_dict, strict=False)

In [None]:
# just defining model again here for easily avaliability
import torch
import torch.nn as nn

class DynamicPatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        x = self.proj(x)
        return x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)


class DynamicMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.scale = (embed_dim // num_heads) ** -0.5
        self.num_heads = num_heads  # Store num_heads as a class attribute

        # Ensure that the number of heads divides the embedding dimension
        assert embed_dim % num_heads == 0, f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

class MLPBlock(nn.Module):  
    def __init__(self, embed_dim, mlp_ratio):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)  # Matches `mlp.fc1`
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)  # Matches `mlp.fc2`

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class DynamicTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = DynamicMultiHeadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        #  Fix: Wrap MLP inside a separate module to match ViT
        self.mlp = MLPBlock(embed_dim, mlp_ratio)  

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

# class DynamicViT(nn.Module):
#     def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, num_classes=10):
#         super().__init__()
#         self.patch_embed = DynamicPatchEmbed(img_size, patch_size, embed_dim)
        
#         #  Fix: Correct positional embedding key
#         self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))
        
#         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
#         self.blocks = nn.ModuleList([DynamicTransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)])
#         self.norm = nn.LayerNorm(embed_dim)
#         self.head = nn.Linear(embed_dim, num_classes)

#     def forward(self, x):
#         x = self.patch_embed(x)
#         B = x.shape[0]

#         # Add class token
#         cls_tokens = self.cls_token.expand(B, -1, -1)
#         x = torch.cat((cls_tokens, x), dim=1)
        
#         x = x + self.pos_embed

#         for block in self.blocks:
#             x = block(x)

#         x = self.norm(x[:, 0])
#         return self.head(x)


class DynamicViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, num_classes=200):
        super().__init__()
        self.depth = depth  # Store depth as an instance variable
        self.num_heads = num_heads  # Store num_heads as an instance variable
        self.mlp_ratio = mlp_ratio  # Store mlp_ratio as an instance variable
        
        self.patch_embed = DynamicPatchEmbed(img_size, patch_size, embed_dim)
        
        # Fix: Correct positional embedding key
        self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.blocks = nn.ModuleList([DynamicTransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        B = x.shape[0]

        # Add class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        x = x + self.pos_embed

        for block in self.blocks:
            x = block(x)

        x = self.norm(x[:, 0])
        return self.head(x)


In [None]:
# import random
# import os
# import torch
# import torch.nn as nn
# from torch.optim import Adam
# from timm import create_model
# import time

# # Path to save the models after fine-tuning
# SAVE_PATH = '/home/pratibha/nas_vision/weights-cifar5'
# # SAVE_PATH = '/kaggle/working/'

# # Set the device (GPU if available, else CPU)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # First-time loading pretrained weights for initialization
# def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
#     pretrained_vit = create_model(pretrained_model_name, pretrained=True)
#     pretrained_state_dict = pretrained_vit.state_dict()
    
#     # Match keys between pretrained and current model
#     model_state_dict = model.state_dict()
#     filtered_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}

#     # Load pretrained weights
#     model.load_state_dict(filtered_dict, strict=False)
#     print(f"Pretrained weights loaded into {model.__class__.__name__} successfully.")

# # Check if pretrained weights are loaded correctly
# def check_pretrained_weights(model, generation=0, model_type="subnetwork"):
#     pretrained_vit = create_model("vit_base_patch16_224", pretrained=True)
#     pretrained_state_dict = pretrained_vit.state_dict()
    
#     model_state_dict = model.state_dict()
#     matching_keys = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
    
#     if len(matching_keys) > 0:
#         print(f"Generation {generation + 1}: {model_type} model has loaded {len(matching_keys)} layers from pretrained weights.")
#     else:
#         print(f"Generation {generation + 1}: {model_type} model has NOT loaded any pretrained weights.")

# # Sample Subnetwork - Randomly sample hyperparameters (depth, num_heads, etc.)
# def sample_subnetwork(seen_architectures):
#     while True:
#         depth = random.choice([4, 6, 8, 10, 12])
#         num_heads = random.choice([4, 8, 12, 16])
#         mlp_ratio = random.choice([2.0, 4.0, 6.0])
#         embed_dim = 768  # Fixed embedding dimension
        
#         architecture = (depth, num_heads, mlp_ratio, embed_dim)
        
#         # Skip if architecture has already been sampled
#         if architecture not in seen_architectures:
#             seen_architectures.add(architecture)
#             print(f"Sampled architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
            
#             # Create the model to calculate its number of parameters
#             sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000)
#             num_params = count_parameters(sampled_model)
#             print(f"Number of parameters in the sampled model: {num_params:,}")
            
#             return architecture
#         else:
#             print(f"Repeated architecture found, resampling...")

# # Count number of trainable parameters
# def count_parameters(model):
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # Evaluate architecture: accuracy, latency, and memory usage
# def evaluate_architecture(model, test_loader):
#     model.eval()
#     correct = 0
#     total = 0
#     running_loss = 0.0
#     y_true = []
#     y_pred = []
    
#     criterion = nn.CrossEntropyLoss()

#     # Start measuring inference latency
#     start_time = time.time()

#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to(device), labels.to(device)  # Move to the same device
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             running_loss += loss.item()

#             _, predicted = torch.max(outputs, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#             y_true.extend(labels.cpu().numpy())
#             y_pred.extend(predicted.cpu().numpy())

#     # Measure total time for inference (latency)
#     latency = (time.time() - start_time) / len(test_loader.dataset)

#     # Compute accuracy
#     accuracy = 100 * correct / total

#     # Compute memory usage (rough estimation)
#     memory_usage = estimate_memory_usage(model)

#     # Compute average loss
#     test_loss = running_loss / len(test_loader)

#     print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%, Latency: {latency:.6f} seconds/image, Memory Usage: {memory_usage:.2f} MB")

#     return accuracy, test_loss, latency, memory_usage

# # Estimate memory usage of a model during inference (rough estimation)
# def estimate_memory_usage(model):
#     # Create dummy input matching the expected shape of the input tensor
#     dummy_input = torch.randn(1, 3, 224, 224).to(device)  # Example for ViT (3-channel image of size 224x224)
    
#     # Use torch.utils.benchmark to measure memory usage during inference
#     start_mem = torch.cuda.memory_allocated()
    
#     # Run the model once with the dummy input
#     with torch.no_grad():
#         model(dummy_input)
    
#     end_mem = torch.cuda.memory_allocated()
#     memory_usage = (end_mem - start_mem) / (1024 ** 2)  # Convert bytes to MB
#     return memory_usage


# def calculate_crowding_distance(population, test_loader):
#     crowding_distances = [0] * len(population)
#     num_objectives = 3  # Accuracy, Latency, Memory

#     # Evaluate each architecture once, then reuse the results
#     evaluated_results = []
#     for arch in population:
#         model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
#                            depth=arch[0], num_heads=arch[1],
#                            mlp_ratio=arch[2], num_classes=10).to(device)
#         accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
#         memory = count_parameters(model) * 4  # memory in bytes
        
#         evaluated_results.append((accuracy, latency, memory))
#         del model
#         torch.cuda.empty_cache()

#     for objective_index in range(num_objectives):
#         sorted_indices = sorted(range(len(population)),
#                                 key=lambda idx: evaluated_results[idx][objective_index])
        
#         crowding_distances[sorted_indices[0]] = crowding_distances[sorted_indices[-1]] = float('inf')

#         for i in range(1, len(sorted_indices) - 1):
#             prev_value = evaluated_results[sorted_indices[i - 1]][objective_index]
#             next_value = evaluated_results[sorted_indices[i + 1]][objective_index]
#             distance = next_value - prev_value
#             crowding_distances[sorted_indices[i]] += distance

#     return crowding_distances


# def dominates(model1, model2, test_loader):
#     # Evaluate both models on the test set
#     accuracy1, latency1, _, _ = evaluate_architecture(model1, test_loader)
#     accuracy2, latency2, _, _ = evaluate_architecture(model2, test_loader)
    
#     # Calculate memory usage as the number of parameters * 4 bytes (FP32)
#     memory1 = count_parameters(model1) * 4  # Memory in bytes
#     memory2 = count_parameters(model2) * 4  # Memory in bytes
    
#     # Compare performance metrics
#     dominates_in_accuracy = accuracy1 >= accuracy2
#     dominates_in_latency = latency1 <= latency2
#     dominates_in_memory = memory1 <= memory2

#     # Return True if model1 dominates model2 in all aspects
#     return dominates_in_accuracy and dominates_in_latency and dominates_in_memory


# # Mutation: Randomly mutate architecture's hyperparameters
# def mutate(architecture):
#     depth, num_heads, mlp_ratio, embed_dim = architecture
#     if random.random() < 0.5: depth = random.choice([4, 6, 8, 10, 12])
#     if random.random() < 0.5: num_heads = random.choice([4, 8, 12, 16])
#     if random.random() < 0.5: mlp_ratio = random.choice([2.0, 4.0, 6.0])
#     print(f"Mutated architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
#     return depth, num_heads, mlp_ratio, embed_dim

# # One-Point Crossover: Combine two parent architectures to create new architectures
# def one_point_crossover(parent1, parent2):
#     crossover_point = random.choice([0, 1, 2, 3])  # Crossover at depth, num_heads, etc.
#     child1 = parent1[:crossover_point] + parent2[crossover_point:]
#     child2 = parent2[:crossover_point] + parent1[crossover_point:]
#     print(f"Crossover result: Child1={child1}, Child2={child2}")
#     return child1, child2




# # Optimized Pareto selection based on stored performance metrics
# def pareto_selection(arch_performance):
#     def dominates(perf1, perf2):
#         acc1, lat1, mem1 = perf1
#         acc2, lat2, mem2 = perf2
#         return (acc1 >= acc2 and lat1 <= lat2 and mem1 <= mem2) and (acc1 > acc2 or lat1 < lat2 or mem1 < mem2)

#     ranks = {}
#     for arch1, perf1 in arch_performance.items():
#         dominated_count = 0
#         for arch2, perf2 in arch_performance.items():
#             if arch1 != arch2 and dominates(perf2, perf1):
#                 dominated_count += 1
#         ranks[arch1] = dominated_count

#     # Sort architectures by rank (lower dominated_count = better)
#     sorted_population = sorted(ranks.keys(), key=lambda arch: ranks[arch])
#     return sorted_population

# # Fine-tune model on dataset (train for a few epochs)
# def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
#     print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
#     sampled_model.to(device)  # Ensure the model is on the correct device
#     criterion = nn.CrossEntropyLoss()
#     optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
#     for epoch in range(epochs):
#         sampled_model.train()
#         running_loss = 0.0
#         for images, labels in train_loader:
#             images, labels = images.to(device), labels.to(device)  # Ensure inputs are on the same device
#             optimizer.zero_grad()
#             outputs = sampled_model(images)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()
        
#         test_accuracy, test_loss, test_latency, memory_usage = evaluate_architecture(sampled_model, test_loader)
#         print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%, Test Latency: {test_latency:.6f} seconds/image")

#     # Save the model after fine-tuning
#     if architecture_folder:
#         os.makedirs(architecture_folder, exist_ok=True)
#         torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
#     return sampled_model





# def save_top_ranked_models(population, arch_performance, generation):
#     top_n = min(5, len(population))
#     for idx, arch in enumerate(population[:top_n]):
#         depth, num_heads, mlp_ratio, embed_dim = arch
#         model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
#                            num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000).to(device)

#         architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
#         checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
#         model.load_state_dict(torch.load(checkpoint_path))

#         top_model_path = os.path.join(SAVE_PATH, f'top_ranked_model_gen{generation+1}_rank_{idx+1}.pth')
#         torch.save(model.state_dict(), top_model_path)
        
#         acc, lat, mem = arch_performance[arch]

#         with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
#             f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
#             f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")

#         print(f"Saved top-ranked model: Generation {generation+1}, Rank {idx+1} (Acc={acc:.2f}%, Lat={lat:.6f}, Mem={mem/(1024**2):.2f}MB)")
        




# def evolutionary_algorithm(population_size=5, generations=2, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
#     seen_architectures = set()
#     population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]

#     # Evaluate and store metrics only once per architecture per generation
#     arch_performance = {}

#     for generation in range(generations):
#         print(f"\n--- Generation {generation + 1}/{generations} ---")
        
#         for arch in population:
#             depth, num_heads, mlp_ratio, embed_dim = arch
#             architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")

#             model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
#                                depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
#                                num_classes=1000).to(device)

#             # if generation == 0:
#             #     load_pretrained_weights(model)
#             # else:
#             #     checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
#             #     model.load_state_dict(torch.load(checkpoint_path))
#             #     print(f"Loaded weights from previous generation for {arch}")
#             if generation == 0:
#                 load_pretrained_weights(model)
#             else:
#                 checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
#                 if os.path.exists(checkpoint_path):
#                     model.load_state_dict(torch.load(checkpoint_path))
#                     print(f"Generation {generation + 1}: Loaded fine-tuned weights from previous generation for {arch}.")
#                 else:
#                     print(f"Generation {generation + 1}: Fine-tuned weights not found for {arch}. Loading pretrained ViT weights.")
#                     load_pretrained_weights(model)


#             check_pretrained_weights(model, generation=generation, model_type="subnetwork")

#             fine_tune_model(
#                 model, train_loader, test_loader, epochs=3, architecture_folder=architecture_folder
#             )

#             # Evaluate once per architecture
#             accuracy, _, latency, memory = evaluate_architecture(model, test_loader)
#             arch_performance[arch] = (accuracy, latency, memory)

#             del model
#             torch.cuda.empty_cache()

#         # Pareto selection
#         population = pareto_selection(arch_performance)

#         # Save top models clearly ranked (1 = best)
#         save_top_ranked_models(population, arch_performance, generation)

#         # Generate offspring using crossover and mutation
#         offspring = []
#         for i in range(0, len(population)-1, 2):
#             if random.random() < crossover_rate:
#                 child1, child2 = one_point_crossover(population[i], population[i + 1])
#                 offspring.extend([child1, child2])
#             else:
#                 offspring.extend([population[i], population[i + 1]])

#         offspring = [mutate(child) if random.random() < mutation_rate else child for child in offspring]

#         # Next-generation combines top parents and offspring
#         population = population[:len(population)//2] + offspring

#     return population

# # Run the evolutionary algorithm
# evolutionary_algorithm(population_size=5, generations=2, train_loader=train_loader, test_loader=test_loader)


In [None]:
# import random
# import os
# import torch
# import torch.nn as nn
# from torch.optim import Adam
# from timm import create_model
# import time

# # Path to save models
# SAVE_PATH = '/home/pratibha/nas_vision/weights-img-evol1'

# # Device setup
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Load pretrained ViT weights initially
# def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
#     pretrained_vit = create_model(pretrained_model_name, pretrained=True)
#     pretrained_state_dict = pretrained_vit.state_dict()
#     model_state_dict = model.state_dict()
#     filtered_dict = {k: v for k, v in pretrained_state_dict.items()
#                      if k in model_state_dict and v.shape == model_state_dict[k].shape}
#     model.load_state_dict(filtered_dict, strict=False)
#     print(f"Loaded pretrained weights into {model.__class__.__name__}")

# # Count parameters
# def count_parameters(model):
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # Evaluate model clearly with corrected memory usage
# def evaluate_architecture(model, test_loader):
#     model.eval()
#     correct, total, running_loss = 0, 0, 0.0
#     criterion = nn.CrossEntropyLoss()
#     start_time = time.time()

#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to(device), labels.to(device)
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             running_loss += loss.item()
#             _, predicted = torch.max(outputs, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#     latency = (time.time() - start_time) / len(test_loader.dataset)
#     accuracy = 100 * correct / total
#     num_params = count_parameters(model)
#     memory_usage = (num_params * 4) / (1024 ** 2)  # MB for FP32
#     test_loss = running_loss / len(test_loader)

#     print(f"Loss: {test_loss:.4f}, Acc: {accuracy:.2f}%, Lat: {latency:.6f}s/img, Mem: {memory_usage:.2f}MB")

#     return accuracy, test_loss, latency, memory_usage

# # Mutate architecture
# def mutate(arch):
#     depth, num_heads, mlp_ratio, embed_dim = arch
#     if random.random() < 0.5: depth = random.choice([4, 6, 8, 10, 12])
#     if random.random() < 0.5: num_heads = random.choice([4, 8, 12, 16])
#     if random.random() < 0.5: mlp_ratio = random.choice([2.0, 4.0, 6.0])
#     return depth, num_heads, mlp_ratio, embed_dim

# # Crossover architectures
# def one_point_crossover(p1, p2):
#     cp = random.randint(1, 3)
#     return p1[:cp] + p2[cp:], p2[:cp] + p1[cp:]

# # Pareto ranking
# def pareto_selection(perf):
#     def dominates(a, b):
#         return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))

#     ranks = {}
#     for arch1, p1 in perf.items():
#         dominated = sum(dominates(p2, p1) for arch2, p2 in perf.items() if arch2 != arch1)
#         ranks[arch1] = dominated
#     return sorted(ranks, key=ranks.get)

# # Fine-tune model
# def fine_tune_model(model, train_loader, test_loader, epochs, folder):
#     criterion, optimizer = nn.CrossEntropyLoss(), Adam(model.parameters(), lr=1e-4)
#     model.to(device)

#     for epoch in range(epochs):
#         model.train()
#         running_loss = 0
#         for images, labels in train_loader:
#             images, labels = images.to(device), labels.to(device)
#             optimizer.zero_grad()
#             loss = criterion(model(images), labels)
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()

#         acc, loss, lat, mem = evaluate_architecture(model, test_loader)
#         print(f"Epoch {epoch+1}/{epochs} done.")

#     os.makedirs(folder, exist_ok=True)
#     torch.save(model.state_dict(), os.path.join(folder, 'checkpoint.pth'))

# # Evolutionary NAS main loop
# def evolutionary_algorithm(population_size, generations, train_loader, test_loader):
#     seen, perf = set(), {}
#     population = [(random.choice([4,6,8,10,12]), random.choice([4,8,12,16]), random.choice([2.0,4.0,6.0]), 768)
#                   for _ in range(population_size)]

#     for gen in range(generations):
#         print(f"\nGeneration {gen+1}/{generations}")

#         for arch in population:
#             folder = os.path.join(SAVE_PATH, f"arch_{arch[0]}_{arch[1]}_{arch[2]}_{arch[3]}")
#             model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3], depth=arch[0],
#                                num_heads=arch[1], mlp_ratio=arch[2], num_classes=200).to(device)

#             ckpt = os.path.join(folder, 'checkpoint.pth')
#             if os.path.exists(ckpt):
#                 model.load_state_dict(torch.load(ckpt))
#                 print(f"Loaded previous weights for {arch}")
#             else:
#                 print(f"No previous weights, loading pretrained for {arch}")
#                 load_pretrained_weights(model)

#             fine_tune_model(model, train_loader, test_loader, 3, folder)
#             perf[arch] = evaluate_architecture(model, test_loader)
#             del model; torch.cuda.empty_cache()

#         population = pareto_selection(perf)

#         for i, arch in enumerate(population[:5]):
#             print(f"Rank {i+1}: {arch}, Acc: {perf[arch][0]:.2f}%")

#         offspring = []
#         for i in range(0, len(population)-1, 2):
#             child1, child2 = one_point_crossover(population[i], population[i+1])
#             offspring.extend([mutate(child1), mutate(child2)])

#         population = population[:len(population)//2] + offspring

#     return population

# # Example run:
# # evolutionary_algorithm(5, 2, train_loader, test_loader)
# evolutionary_algorithm(population_size=5, generations=2, train_loader=train_loader, test_loader=test_loader)


  from .autonotebook import tqdm as notebook_tqdm



Generation 1/2
No previous weights, loading pretrained for (6, 8, 2.0, 768)
Loaded pretrained weights into DynamicViT


KeyboardInterrupt: 

## 24 march  here ranking after first generation finetuning is not good see results

In [6]:
import random
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from timm import create_model
import time

# Path to save the models after fine-tuning
SAVE_PATH = '/home/pratibha/nas_vision/weights-img-evol1'
# SAVE_PATH = '/kaggle/working/'

# Set the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# First-time loading pretrained weights for initialization
def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
    pretrained_vit = create_model(pretrained_model_name, pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    # Match keys between pretrained and current model
    model_state_dict = model.state_dict()
    filtered_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}

    # Load pretrained weights
    model.load_state_dict(filtered_dict, strict=False)
    print(f"Pretrained weights loaded into {model.__class__.__name__} successfully.")

# Check if pretrained weights are loaded correctly
def check_pretrained_weights(model, generation=0, model_type="subnetwork"):
    pretrained_vit = create_model("vit_base_patch16_224", pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    model_state_dict = model.state_dict()
    matching_keys = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
    
    if len(matching_keys) > 0:
        print(f"Generation {generation + 1}: {model_type} model has loaded {len(matching_keys)} layers from pretrained weights.")
    else:
        print(f"Generation {generation + 1}: {model_type} model has NOT loaded any pretrained weights.")

# Sample Subnetwork - Randomly sample hyperparameters (depth, num_heads, etc.)
def sample_subnetwork(seen_architectures):
    while True:
        depth = random.choice([4, 6, 8, 10, 12])
        num_heads = random.choice([4, 8, 12, 16])
        mlp_ratio = random.choice([2.0, 4.0, 6.0])
        embed_dim = 768  # Fixed embedding dimension
        
        architecture = (depth, num_heads, mlp_ratio, embed_dim)
        
        # Skip if architecture has already been sampled
        if architecture not in seen_architectures:
            seen_architectures.add(architecture)
            print(f"Sampled architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
            
            # Create the model to calculate its number of parameters
            # sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000)
            sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                                        depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, 
                                        num_classes=200  
                                    )
            num_params = count_parameters(sampled_model)
            print(f"Number of parameters in the sampled model: {num_params:,}")
            
            return architecture
        else:
            print(f"Repeated architecture found, resampling...")

# Count number of trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Evaluate architecture: accuracy, latency, and memory usage
def evaluate_architecture(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    y_true = []
    y_pred = []
    
    criterion = nn.CrossEntropyLoss()

    # Start measuring inference latency
    start_time = time.time()

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)  # Move to the same device
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    # Measure total time for inference (latency)
    latency = (time.time() - start_time) / len(test_loader.dataset)

    # Compute accuracy
    accuracy = 100 * correct / total

    # Compute memory usage (rough estimation)
    num_params = count_parameters(model)
    memory_usage = (num_params * 4) / (1024 ** 2)  # Convert bytes to MB (FP32)

    # Compute average loss
    test_loss = running_loss / len(test_loader)

    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%, Latency: {latency:.6f} seconds/image, Memory Usage: {memory_usage:.2f} MB")

    return accuracy, test_loss, latency, memory_usage

# def evaluate_architecture(model, test_loader):
#     model.eval()
#     correct = 0
#     total = 0
#     running_loss = 0.0
    
#     criterion = nn.CrossEntropyLoss()
#     start_time = time.time()

#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to(device), labels.to(device)
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             running_loss += loss.item()

#             _, predicted = torch.max(outputs, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#     latency = (time.time() - start_time) / len(test_loader.dataset)
#     accuracy = 100 * correct / total

#     num_params = count_parameters(model)
#     memory_usage = (num_params * 4) / (1024 ** 2)  # Convert bytes to MB (FP32)

#     test_loss = running_loss / len(test_loader)

#     print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%, Latency: {latency:.6f} seconds/image, Memory Usage: {memory_usage:.2f} MB")

#     return accuracy, test_loss, latency, memory_usage


# Estimate memory usage of a model during inference (rough estimation)
def estimate_memory_usage(model):
    # Create dummy input matching the expected shape of the input tensor
    dummy_input = torch.randn(1, 3, 224, 224).to(device)  # Example for ViT (3-channel image of size 224x224)
    
    # Use torch.utils.benchmark to measure memory usage during inference
    start_mem = torch.cuda.memory_allocated()
    
    # Run the model once with the dummy input
    with torch.no_grad():
        model(dummy_input)
    
    end_mem = torch.cuda.memory_allocated()
    memory_usage = (end_mem - start_mem) / (1024 ** 2)  # Convert bytes to MB
    return memory_usage


def calculate_crowding_distance(population, test_loader):
    crowding_distances = [0] * len(population)
    num_objectives = 3  # Accuracy, Latency, Memory

    # Evaluate each architecture once, then reuse the results
    evaluated_results = []
    for arch in population:
        # # model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
        #                    depth=arch[0], num_heads=arch[1],
        #                    mlp_ratio=arch[2], num_classes=10).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
                            depth=arch[0], num_heads=arch[1], mlp_ratio=arch[2], 
                            num_classes=200).to(device)

        accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
        memory = count_parameters(model) * 4  # memory in bytes
        
        evaluated_results.append((accuracy, latency, memory))
        del model
        torch.cuda.empty_cache()

    for objective_index in range(num_objectives):
        sorted_indices = sorted(range(len(population)),
                                key=lambda idx: evaluated_results[idx][objective_index])
        
        crowding_distances[sorted_indices[0]] = crowding_distances[sorted_indices[-1]] = float('inf')

        for i in range(1, len(sorted_indices) - 1):
            prev_value = evaluated_results[sorted_indices[i - 1]][objective_index]
            next_value = evaluated_results[sorted_indices[i + 1]][objective_index]
            distance = next_value - prev_value
            crowding_distances[sorted_indices[i]] += distance

    return crowding_distances


def dominates(model1, model2, test_loader):
    # Evaluate both models on the test set
    accuracy1, latency1, _, _ = evaluate_architecture(model1, test_loader)
    accuracy2, latency2, _, _ = evaluate_architecture(model2, test_loader)
    
    # Calculate memory usage as the number of parameters * 4 bytes (FP32)
    memory1 = count_parameters(model1) * 4  # Memory in bytes
    memory2 = count_parameters(model2) * 4  # Memory in bytes
    
    # Compare performance metrics
    dominates_in_accuracy = accuracy1 >= accuracy2
    dominates_in_latency = latency1 <= latency2
    dominates_in_memory = memory1 <= memory2

    # Return True if model1 dominates model2 in all aspects
    return dominates_in_accuracy and dominates_in_latency and dominates_in_memory


# Mutation: Randomly mutate architecture's hyperparameters
def mutate(architecture):
    depth, num_heads, mlp_ratio, embed_dim = architecture
    if random.random() < 0.5: depth = random.choice([4, 6, 8, 10, 12])
    if random.random() < 0.5: num_heads = random.choice([4, 8, 12, 16])
    if random.random() < 0.5: mlp_ratio = random.choice([2.0, 4.0, 6.0])
    print(f"Mutated architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
    return depth, num_heads, mlp_ratio, embed_dim

# One-Point Crossover: Combine two parent architectures to create new architectures
def one_point_crossover(parent1, parent2):
    crossover_point = random.choice([0, 1, 2, 3])  # Crossover at depth, num_heads, etc.
    child1 = parent1[:crossover_point] + parent2[crossover_point:]
    child2 = parent2[:crossover_point] + parent1[crossover_point:]
    print(f"Crossover result: Child1={child1}, Child2={child2}")
    return child1, child2




# Optimized Pareto selection based on stored performance metrics
def pareto_selection(arch_performance):
    def dominates(perf1, perf2):
        acc1, lat1, mem1 = perf1
        acc2, lat2, mem2 = perf2
        return (acc1 >= acc2 and lat1 <= lat2 and mem1 <= mem2) and (acc1 > acc2 or lat1 < lat2 or mem1 < mem2)

    ranks = {}
    for arch1, perf1 in arch_performance.items():
        dominated_count = 0
        for arch2, perf2 in arch_performance.items():
            if arch1 != arch2 and dominates(perf2, perf1):
                dominated_count += 1
        ranks[arch1] = dominated_count

    # Sort architectures by rank (lower dominated_count = better)
    sorted_population = sorted(ranks.keys(), key=lambda arch: ranks[arch])
    return sorted_population

# Fine-tune model on dataset (train for a few epochs)
def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
    print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
    sampled_model.to(device)  # Ensure the model is on the correct device
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
    for epoch in range(epochs):
        sampled_model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)  # Ensure inputs are on the same device
            optimizer.zero_grad()
            outputs = sampled_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        test_accuracy, test_loss, test_latency, memory_usage = evaluate_architecture(sampled_model, test_loader)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%, Test Latency: {test_latency:.6f} seconds/image")

    # Save the model after fine-tuning
    if architecture_folder:
        os.makedirs(architecture_folder, exist_ok=True)
        torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
    return sampled_model





def save_top_ranked_models(population, arch_performance, generation):
    top_n = min(5, len(population))
    for idx, arch in enumerate(population[:top_n]):
        depth, num_heads, mlp_ratio, embed_dim = arch
        # model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
        #                    num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
                            num_heads=num_heads, mlp_ratio=mlp_ratio, 
                            num_classes=200).to(device)


        architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
        checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
        model.load_state_dict(torch.load(checkpoint_path))

        top_model_path = os.path.join(SAVE_PATH, f'top_ranked_model_gen{generation+1}_rank_{idx+1}.pth')
        torch.save(model.state_dict(), top_model_path)
        
        acc, lat, mem = arch_performance[arch]

        # with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
        #     f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
        #     f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")
        with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
            f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
            f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")


        print(f"Saved top-ranked model: Generation {generation+1}, Rank {idx+1} (Acc={acc:.2f}%, Lat={lat:.6f}, Mem={mem/(1024**2):.2f}MB)")
        

def evolutionary_algorithm(population_size=5, generations=2, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
    seen_architectures = set()
    population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]

    # Evaluate and store metrics only once per architecture per generation
    arch_performance = {}

    for generation in range(generations):
        print(f"\n--- Generation {generation + 1}/{generations} ---")
        
        for arch in population:
            depth, num_heads, mlp_ratio, embed_dim = arch
            architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")

            # model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
            #                    depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
            #                    num_classes=1000).to(device)
            model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                                depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                                num_classes=200).to(device)


            # if generation == 0:
            #     load_pretrained_weights(model)
            # else:
            #     checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
            #     model.load_state_dict(torch.load(checkpoint_path))
            #     print(f"Loaded weights from previous generation for {arch}")
            
            if generation == 0:
                load_pretrained_weights(model)
            else:
                checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
                if os.path.exists(checkpoint_path):
                    model.load_state_dict(torch.load(checkpoint_path))
                    print(f"Generation {generation + 1}: Loaded fine-tuned weights from previous generation for {arch}.")
                else:
                    print(f"Generation {generation + 1}: Fine-tuned weights not found for {arch}. Loading pretrained ViT weights.")
                    load_pretrained_weights(model)

            check_pretrained_weights(model, generation=generation, model_type="subnetwork")

            fine_tune_model(
                model, train_loader, test_loader, epochs=3, architecture_folder=architecture_folder
            )

            # Evaluate once per architecture
            accuracy, _, latency, memory = evaluate_architecture(model, test_loader)
            arch_performance[arch] = (accuracy, latency, memory)

            del model
            torch.cuda.empty_cache()

        # Pareto selection
        population = pareto_selection(arch_performance)

        # Save top models clearly ranked (1 = best)
        save_top_ranked_models(population, arch_performance, generation)

        # Generate offspring using crossover and mutation
        offspring = []
        for i in range(0, len(population)-1, 2):
            if random.random() < crossover_rate:
                child1, child2 = one_point_crossover(population[i], population[i + 1])
                offspring.extend([child1, child2])
            else:
                offspring.extend([population[i], population[i + 1]])

        offspring = [mutate(child) if random.random() < mutation_rate else child for child in offspring]

        # Next-generation combines top parents and offspring
        population = population[:len(population)//2] + offspring

    return population

# Run the evolutionary algorithm
evolutionary_algorithm(population_size=5, generations=2, train_loader=train_loader, test_loader=test_loader)


Sampled architecture: Depth=10, Num Heads=4, MLP Ratio=2.0, Embed Dim=768
Number of parameters in the sampled model: 48,168,392
Sampled architecture: Depth=8, Num Heads=4, MLP Ratio=2.0, Embed Dim=768
Number of parameters in the sampled model: 38,714,312
Sampled architecture: Depth=8, Num Heads=8, MLP Ratio=4.0, Embed Dim=768
Number of parameters in the sampled model: 57,600,968
Sampled architecture: Depth=8, Num Heads=16, MLP Ratio=6.0, Embed Dim=768
Number of parameters in the sampled model: 76,487,624
Sampled architecture: Depth=12, Num Heads=12, MLP Ratio=6.0, Embed Dim=768
Number of parameters in the sampled model: 114,282,440

--- Generation 1/2 ---
Pretrained weights loaded into DynamicViT successfully.
Generation 1: subnetwork model has loaded 96 layers from pretrained weights.
Fine-tuning model with architecture: Depth=10, Num Heads=4, MLP Ratio=2.0
Test Loss: 3.3444, Test Accuracy: 22.95%, Latency: 0.004883 seconds/image, Memory Usage: 183.75 MB
Epoch 1/3, Loss: 6370.4857, Te

[(10, 4, 2.0, 768),
 (8, 4, 2.0, 768),
 (12, 4, 4.0, 768),
 (10, 4, 2.0, 768),
 (8, 16, 6.0, 768),
 (8, 8, 4.0, 768)]

In [None]:
import random
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from timm import create_model
import time

# Path to save the models after fine-tuning
SAVE_PATH = '/home/pratibha/nas_vision/weights-img-evol1'
# SAVE_PATH = '/kaggle/working/'

# Set the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# First-time loading pretrained weights for initialization
def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
    pretrained_vit = create_model(pretrained_model_name, pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    # Match keys between pretrained and current model
    model_state_dict = model.state_dict()
    filtered_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}

    # Load pretrained weights
    model.load_state_dict(filtered_dict, strict=False)
    print(f"Pretrained weights loaded into {model.__class__.__name__} successfully.")

# Check if pretrained weights are loaded correctly
def check_pretrained_weights(model, generation=0, model_type="subnetwork"):
    pretrained_vit = create_model("vit_base_patch16_224", pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    model_state_dict = model.state_dict()
    matching_keys = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
    
    if len(matching_keys) > 0:
        print(f"Generation {generation + 1}: {model_type} model has loaded {len(matching_keys)} layers from pretrained weights.")
    else:
        print(f"Generation {generation + 1}: {model_type} model has NOT loaded any pretrained weights.")

# Sample Subnetwork - Randomly sample hyperparameters (depth, num_heads, etc.)
def sample_subnetwork(seen_architectures):
    while True:
        depth = random.choice([4, 6, 8, 10, 12])
        num_heads = random.choice([4, 8, 12, 16])
        mlp_ratio = random.choice([2.0, 4.0, 6.0])
        embed_dim = 768  # Fixed embedding dimension
        
        architecture = (depth, num_heads, mlp_ratio, embed_dim)
        
        # Skip if architecture has already been sampled
        if architecture not in seen_architectures:
            seen_architectures.add(architecture)
            print(f"Sampled architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
            
            # Create the model to calculate its number of parameters
            # sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000)
            sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                                        depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, 
                                        num_classes=200  
                                    )
            num_params = count_parameters(sampled_model)
            print(f"Number of parameters in the sampled model: {num_params:,}")
            
            return architecture
        else:
            print(f"Repeated architecture found, resampling...")

# Count number of trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Evaluate architecture: accuracy, latency, and memory usage
# def evaluate_architecture(model, test_loader):
#     model.eval()
#     correct = 0
#     total = 0
#     running_loss = 0.0
#     y_true = []
#     y_pred = []
    
#     criterion = nn.CrossEntropyLoss()

#     # Start measuring inference latency
#     start_time = time.time()

#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to(device), labels.to(device)  # Move to the same device
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             running_loss += loss.item()

#             _, predicted = torch.max(outputs, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#             y_true.extend(labels.cpu().numpy())
#             y_pred.extend(predicted.cpu().numpy())

#     # Measure total time for inference (latency)
#     latency = (time.time() - start_time) / len(test_loader.dataset)

#     # Compute accuracy
#     accuracy = 100 * correct / total

#     # Compute memory usage (rough estimation)
#     num_params = count_parameters(model)
#     memory_usage = (num_params * 4) / (1024 ** 2)  # Convert bytes to MB (FP32)

#     # Compute average loss
#     test_loss = running_loss / len(test_loader)

#     print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%, Latency: {latency:.6f} seconds/image, Memory Usage: {memory_usage:.2f} MB")

#     return accuracy, test_loss, latency, memory_usage

def evaluate_architecture(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    
    criterion = nn.CrossEntropyLoss()
    
    # Start measuring inference latency
    start_time = time.time()
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    # Measure total time for inference (latency)
    latency = (time.time() - start_time) / len(test_loader.dataset)
    
    # Compute accuracy
    accuracy = 100 * correct / total
    
    # Compute memory usage (improved)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    memory_usage = (num_params * 4) / (1024 ** 2)  # Convert bytes to MB (FP32)
    if torch.cuda.is_available():
        memory_usage += torch.cuda.max_memory_allocated() / (1024 ** 2)  # Add GPU memory usage
    
    # Compute average loss
    test_loss = running_loss / len(test_loader)
    
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%, Latency: {latency:.6f} seconds/image, Memory Usage: {memory_usage:.2f} MB")
    
    return accuracy, test_loss, latency, memory_usage


# Estimate memory usage of a model during inference (rough estimation)
# def estimate_memory_usage(model):
#     # Create dummy input matching the expected shape of the input tensor
#     dummy_input = torch.randn(1, 3, 224, 224).to(device)  # Example for ViT (3-channel image of size 224x224)
    
#     # Use torch.utils.benchmark to measure memory usage during inference
#     start_mem = torch.cuda.memory_allocated()
    
#     # Run the model once with the dummy input
#     with torch.no_grad():
#         model(dummy_input)
    
#     end_mem = torch.cuda.memory_allocated()
#     memory_usage = (end_mem - start_mem) / (1024 ** 2)  # Convert bytes to MB
#     return memory_usage


def calculate_crowding_distance(population, test_loader):
    crowding_distances = [0] * len(population)
    num_objectives = 3  # Accuracy, Latency, Memory

    # Evaluate each architecture once, then reuse the results
    evaluated_results = []
    for arch in population:
        # # model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
        #                    depth=arch[0], num_heads=arch[1],
        #                    mlp_ratio=arch[2], num_classes=10).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
                            depth=arch[0], num_heads=arch[1], mlp_ratio=arch[2], 
                            num_classes=200).to(device)

        accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
        memory = count_parameters(model) * 4  # memory in bytes
        
        evaluated_results.append((accuracy, latency, memory))
        del model
        torch.cuda.empty_cache()

    for objective_index in range(num_objectives):
        sorted_indices = sorted(range(len(population)),
                                key=lambda idx: evaluated_results[idx][objective_index])
        
        crowding_distances[sorted_indices[0]] = crowding_distances[sorted_indices[-1]] = float('inf')

        for i in range(1, len(sorted_indices) - 1):
            prev_value = evaluated_results[sorted_indices[i - 1]][objective_index]
            next_value = evaluated_results[sorted_indices[i + 1]][objective_index]
            distance = next_value - prev_value
            crowding_distances[sorted_indices[i]] += distance

    return crowding_distances


def dominates(model1, model2, test_loader):
    # Evaluate both models on the test set
    accuracy1, latency1, _, _ = evaluate_architecture(model1, test_loader)
    accuracy2, latency2, _, _ = evaluate_architecture(model2, test_loader)
    
    # Calculate memory usage as the number of parameters * 4 bytes (FP32)
    memory1 = count_parameters(model1) * 4  # Memory in bytes
    memory2 = count_parameters(model2) * 4  # Memory in bytes
    
    # Compare performance metrics
    dominates_in_accuracy = accuracy1 >= accuracy2
    dominates_in_latency = latency1 <= latency2
    dominates_in_memory = memory1 <= memory2

    # Return True if model1 dominates model2 in all aspects
    return dominates_in_accuracy and dominates_in_latency and dominates_in_memory


# Mutation: Randomly mutate architecture's hyperparameters
def mutate(architecture):
    depth, num_heads, mlp_ratio, embed_dim = architecture
    if random.random() < 0.5: depth = random.choice([4, 6, 8, 10, 12])
    if random.random() < 0.5: num_heads = random.choice([4, 8, 12, 16])
    if random.random() < 0.5: mlp_ratio = random.choice([2.0, 4.0, 6.0])
    print(f"Mutated architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
    return depth, num_heads, mlp_ratio, embed_dim

# One-Point Crossover: Combine two parent architectures to create new architectures
def one_point_crossover(parent1, parent2):
    crossover_point = random.choice([0, 1, 2, 3])  # Crossover at depth, num_heads, etc.
    child1 = parent1[:crossover_point] + parent2[crossover_point:]
    child2 = parent2[:crossover_point] + parent1[crossover_point:]
    print(f"Crossover result: Child1={child1}, Child2={child2}")
    return child1, child2


# Optimized Pareto selection based on stored performance metrics
# def pareto_selection(arch_performance):
#     def dominates(perf1, perf2):
#         acc1, lat1, mem1 = perf1
#         acc2, lat2, mem2 = perf2
#         return (acc1 >= acc2 and lat1 <= lat2 and mem1 <= mem2) and (acc1 > acc2 or lat1 < lat2 or mem1 < mem2)

#     ranks = {}
#     for arch1, perf1 in arch_performance.items():
#         dominated_count = 0
#         for arch2, perf2 in arch_performance.items():
#             if arch1 != arch2 and dominates(perf2, perf1):
#                 dominated_count += 1
#         ranks[arch1] = dominated_count

#     # Sort architectures by rank (lower dominated_count = better)
#     sorted_population = sorted(ranks.keys(), key=lambda arch: ranks[arch])
#     return sorted_population

def pareto_selection(arch_performance, population_size):
    def dominates(perf1, perf2):
        acc1, lat1, mem1 = perf1
        acc2, lat2, mem2 = perf2
        # For accuracy, higher is better; for latency and memory, lower is better
        return (acc1 >= acc2 and lat1 <= lat2 and mem1 <= mem2) and (acc1 > acc2 or lat1 < lat2 or mem1 < mem2)

    # Initialize fronts
    fronts = [[]]
    dominated = {arch: set() for arch in arch_performance}
    domination_count = {arch: 0 for arch in arch_performance}

    # Calculate dominance relationships
    for arch1, perf1 in arch_performance.items():
        for arch2, perf2 in arch_performance.items():
            if arch1 != arch2:
                if dominates(perf1, perf2):
                    dominated[arch1].add(arch2)
                elif dominates(perf2, perf1):
                    domination_count[arch1] += 1
        
        # If not dominated by any other architecture, add to first front
        if domination_count[arch1] == 0:
            fronts[0].append(arch1)

    # Calculate remaining fronts
    i = 0
    while len(fronts[i]) > 0:
        next_front = []
        for arch in fronts[i]:
            for dominated_arch in dominated[arch]:
                domination_count[dominated_arch] -= 1
                if domination_count[dominated_arch] == 0:
                    next_front.append(dominated_arch)
        i += 1
        fronts.append(next_front)

    # Select architectures based on fronts and crowding distance
    selected = []
    for front in fronts:
        if len(selected) + len(front) <= population_size:
            selected.extend(front)
        else:
            # Calculate crowding distance for the current front
            crowding_distance = {}
            for arch in front:
                crowding_distance[arch] = 0
            
            # For each objective (accuracy, latency, memory)
            for obj_idx in range(3):
                # Sort front by the current objective
                if obj_idx == 0:  # Accuracy (higher is better)
                    sorted_front = sorted(front, key=lambda x: arch_performance[x][obj_idx], reverse=True)
                else:  # Latency and memory (lower is better)
                    sorted_front = sorted(front, key=lambda x: arch_performance[x][obj_idx])
                
                # Set boundary points to infinity
                crowding_distance[sorted_front[0]] = float('inf')
                crowding_distance[sorted_front[-1]] = float('inf')
                
                # Calculate crowding distance for intermediate points
                for i in range(1, len(sorted_front) - 1):
                    if crowding_distance[sorted_front[i]] != float('inf'):
                        # Normalize by the range of the objective
                        obj_range = arch_performance[sorted_front[0]][obj_idx] - arch_performance[sorted_front[-1]][obj_idx]
                        if obj_range != 0:
                            prev_val = arch_performance[sorted_front[i-1]][obj_idx]
                            next_val = arch_performance[sorted_front[i+1]][obj_idx]
                            crowding_distance[sorted_front[i]] += abs(next_val - prev_val) / obj_range
            
            # Sort by crowding distance (higher is better)
            sorted_front = sorted(front, key=lambda x: crowding_distance[x], reverse=True)
            selected.extend(sorted_front[:population_size - len(selected)])
            break

    return selected

import time
import matplotlib.pyplot as plt
import os
import random
import torch
import torch.nn as nn
from torch.optim import Adam

def plot_training_progress(train_losses, test_losses, test_accuracies, depth, num_heads, mlp_ratio):
    epochs = range(1, len(train_losses) + 1)
    
    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 3, 1)
    plt.plot(epochs, train_losses, 'b-', label='Train Loss')
    plt.title('Train Loss vs Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 3, 2)
    plt.plot(epochs, test_losses, 'r-', label='Test Loss')
    plt.title('Test Loss vs Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 3, 3)
    plt.plot(epochs, test_accuracies, 'g-', label='Test Accuracy')
    plt.title('Test Accuracy vs Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    
    plt.suptitle(f'Training Progress: Depth={depth}, Heads={num_heads}, MLP Ratio={mlp_ratio}')
    plt.tight_layout()
    
    # Save the plot
    architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{768}")
    os.makedirs(architecture_folder, exist_ok=True)
    plt.savefig(os.path.join(architecture_folder, 'training_progress.png'))
    plt.close()
    
    
    
# Fine-tune model on dataset (train for a few epochs)
# def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
#     print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
#     sampled_model.to(device)  # Ensure the model is on the correct device
#     criterion = nn.CrossEntropyLoss()
#     optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
#     for epoch in range(epochs):
#         sampled_model.train()
#         running_loss = 0.0
#         for images, labels in train_loader:
#             images, labels = images.to(device), labels.to(device)  # Ensure inputs are on the same device
#             optimizer.zero_grad()
#             outputs = sampled_model(images)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()
        
#         test_accuracy, test_loss, test_latency, memory_usage = evaluate_architecture(sampled_model, test_loader)
#         print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%, Test Latency: {test_latency:.6f} seconds/image")

#     # Save the model after fine-tuning
#     if architecture_folder:
#         os.makedirs(architecture_folder, exist_ok=True)
#         torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
#     return sampled_model

def fine_tune_model(sampled_model, train_loader, test_loader, epochs=5, architecture_folder=None):
    print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
    sampled_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
    train_losses = []
    test_losses = []
    test_accuracies = []
    
    for epoch in range(epochs):
        epoch_start_time = time.time()
        
        # Training phase
        sampled_model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = sampled_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)
        
        # Evaluation phase
        test_accuracy, test_loss, test_latency, memory_usage = evaluate_architecture(sampled_model, test_loader)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        
        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%, Execution Time: {epoch_time:.2f} seconds")
    
    # Plot training progress
    plot_training_progress(train_losses, test_losses, test_accuracies, sampled_model.depth, sampled_model.num_heads, sampled_model.mlp_ratio)
    
    # Save the model
    if architecture_folder:
        os.makedirs(architecture_folder, exist_ok=True)
        torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
    
    return sampled_model

def save_top_ranked_models(population, arch_performance, generation):
    top_n = min(5, len(population))
    for idx, arch in enumerate(population[:top_n]):
        depth, num_heads, mlp_ratio, embed_dim = arch
        # model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
        #                    num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
                            num_heads=num_heads, mlp_ratio=mlp_ratio, 
                            num_classes=200).to(device)


        architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
        checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
        model.load_state_dict(torch.load(checkpoint_path))

        top_model_path = os.path.join(SAVE_PATH, f'top_ranked_model_gen{generation+1}_rank_{idx+1}.pth')
        torch.save(model.state_dict(), top_model_path)
        
        acc, lat, mem = arch_performance[arch]

        
        with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
            f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
            f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")


        print(f"Saved top-ranked model: Generation {generation+1}, Rank {idx+1} (Acc={acc:.2f}%, Lat={lat:.6f}, Mem={mem/(1024**2):.2f}MB)")
        

# def evolutionary_algorithm(population_size=5, generations=2, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
#     seen_architectures = set()
#     population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]

#     # Evaluate and store metrics only once per architecture per generation
#     arch_performance = {}

#     for generation in range(generations):
#         print(f"\n--- Generation {generation + 1}/{generations} ---")
        
#         for arch in population:
#             depth, num_heads, mlp_ratio, embed_dim = arch
#             architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")

            
#             model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
#                                 depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
#                                 num_classes=200).to(device)


            
#             if generation == 0:
#                 load_pretrained_weights(model)
#             else:
#                 checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
#                 if os.path.exists(checkpoint_path):
#                     model.load_state_dict(torch.load(checkpoint_path))
#                     print(f"Generation {generation + 1}: Loaded fine-tuned weights from previous generation for {arch}.")
#                 else:
#                     print(f"Generation {generation + 1}: Fine-tuned weights not found for {arch}. Loading pretrained ViT weights.")
#                     load_pretrained_weights(model)

#             check_pretrained_weights(model, generation=generation, model_type="subnetwork")

#             fine_tune_model(
#                 model, train_loader, test_loader, epochs=3, architecture_folder=architecture_folder
#             )

#             # Evaluate once per architecture
#             accuracy, _, latency, memory = evaluate_architecture(model, test_loader)
#             arch_performance[arch] = (accuracy, latency, memory)

#             del model
#             torch.cuda.empty_cache()

#         # Pareto selection
#         population = pareto_selection(arch_performance)

#         # Save top models clearly ranked (1 = best)
#         save_top_ranked_models(population, arch_performance, generation)

#         # Generate offspring using crossover and mutation
#         offspring = []
#         for i in range(0, len(population)-1, 2):
#             if random.random() < crossover_rate:
#                 child1, child2 = one_point_crossover(population[i], population[i + 1])
#                 offspring.extend([child1, child2])
#             else:
#                 offspring.extend([population[i], population[i + 1]])

#         offspring = [mutate(child) if random.random() < mutation_rate else child for child in offspring]

#         # Next-generation combines top parents and offspring
#         population = population[:len(population)//2] + offspring

#     return population

# # Run the evolutionary algorithm
# evolutionary_algorithm(population_size=5, generations=2, train_loader=train_loader, test_loader=test_loader)

def evolutionary_algorithm(population_size=5, generations=2, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
    seen_architectures = set()
    population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]
    
    # Evaluate and store metrics only once per architecture per generation
    arch_performance = {}
    
    for generation in range(generations):
        print(f"\n--- Generation {generation + 1}/{generations} ---")
        generation_start_time = time.time()
        
        for arch in population:
            depth, num_heads, mlp_ratio, embed_dim = arch
            architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
            
            model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                              depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                              num_classes=200).to(device)
            
            if generation == 0:
                load_pretrained_weights(model)
            else:
                checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
                if os.path.exists(checkpoint_path):
                    model.load_state_dict(torch.load(checkpoint_path))
                    print(f"Generation {generation + 1}: Loaded fine-tuned weights from previous generation for {arch}.")
                else:
                    print(f"Generation {generation + 1}: Fine-tuned weights not found for {arch}. Loading pretrained ViT weights.")
                    load_pretrained_weights(model)
            
            check_pretrained_weights(model, generation=generation, model_type="subnetwork")
            
            # Time the fine-tuning process
            arch_start_time = time.time()
            fine_tune_model(
                model, train_loader, test_loader, epochs=5, architecture_folder=architecture_folder
            )
            arch_time = time.time() - arch_start_time
            
            # Evaluate once per architecture
            accuracy, _, latency, memory = evaluate_architecture(model, test_loader)
            arch_performance[arch] = (accuracy, latency, memory)
            
            print(f"Architecture {arch} total execution time: {arch_time:.2f} seconds")
            
            del model
            torch.cuda.empty_cache()
        
        # Pareto selection with population size
        population = pareto_selection(arch_performance, population_size)
        
        # Save top models clearly ranked (1 = best)
        save_top_ranked_models(population, arch_performance, generation)
        
        # Generate offspring using crossover and mutation
        offspring = []
        for i in range(0, len(population)-1, 2):
            if random.random() < crossover_rate:
                child1, child2 = one_point_crossover(population[i], population[i + 1])
                offspring.extend([child1, child2])
            else:
                offspring.extend([population[i], population[i + 1]])
        
        offspring = [mutate(child) if random.random() < mutation_rate else child for child in offspring]
        
        # Next-generation combines top parents and offspring
        population = population[:len(population)//2] + offspring
        
        generation_time = time.time() - generation_start_time
        print(f"Generation {generation + 1} completed in {generation_time:.2f} seconds")
    
    return population


# # Run the evolutionary algorithm
evolutionary_algorithm(population_size=5, generations=2, train_loader=train_loader, test_loader=test_loader)

In [None]:
# def pareto_selection(arch_performance, population_size):
#     def dominates(perf1, perf2):
#         acc1, lat1, mem1 = perf1
#         acc2, lat2, mem2 = perf2
#         # For accuracy, higher is better; for latency and memory, lower is better
#         return (acc1 >= acc2 and lat1 <= lat2 and mem1 <= mem2) and (acc1 > acc2 or lat1 < lat2 or mem1 < mem2)

#     # Initialize fronts
#     fronts = [[]]
#     dominated = {arch: set() for arch in arch_performance}
#     domination_count = {arch: 0 for arch in arch_performance}

#     # Calculate dominance relationships
#     for arch1, perf1 in arch_performance.items():
#         for arch2, perf2 in arch_performance.items():
#             if arch1 != arch2:
#                 if dominates(perf1, perf2):
#                     dominated[arch1].add(arch2)
#                 elif dominates(perf2, perf1):
#                     domination_count[arch1] += 1
        
#         # If not dominated by any other architecture, add to first front
#         if domination_count[arch1] == 0:
#             fronts[0].append(arch1)

#     # Calculate remaining fronts
#     i = 0
#     while len(fronts[i]) > 0:
#         next_front = []
#         for arch in fronts[i]:
#             for dominated_arch in dominated[arch]:
#                 domination_count[dominated_arch] -= 1
#                 if domination_count[dominated_arch] == 0:
#                     next_front.append(dominated_arch)
#         i += 1
#         fronts.append(next_front)

#     # Select architectures based on fronts and crowding distance
#     selected = []
#     for front in fronts:
#         if len(selected) + len(front) <= population_size:
#             selected.extend(front)
#         else:
#             # Calculate crowding distance for the current front
#             crowding_distance = {}
#             for arch in front:
#                 crowding_distance[arch] = 0
            
#             # For each objective (accuracy, latency, memory)
#             for obj_idx in range(3):
#                 # Sort front by the current objective
#                 if obj_idx == 0:  # Accuracy (higher is better)
#                     sorted_front = sorted(front, key=lambda x: arch_performance[x][obj_idx], reverse=True)
#                 else:  # Latency and memory (lower is better)
#                     sorted_front = sorted(front, key=lambda x: arch_performance[x][obj_idx])
                
#                 # Set boundary points to infinity
#                 crowding_distance[sorted_front[0]] = float('inf')
#                 crowding_distance[sorted_front[-1]] = float('inf')
                
#                 # Calculate crowding distance for intermediate points
#                 for i in range(1, len(sorted_front) - 1):
#                     if crowding_distance[sorted_front[i]] != float('inf'):
#                         # Normalize by the range of the objective
#                         obj_range = arch_performance[sorted_front[0]][obj_idx] - arch_performance[sorted_front[-1]][obj_idx]
#                         if obj_range != 0:
#                             prev_val = arch_performance[sorted_front[i-1]][obj_idx]
#                             next_val = arch_performance[sorted_front[i+1]][obj_idx]
#                             crowding_distance[sorted_front[i]] += abs(next_val - prev_val) / obj_range
            
#             # Sort by crowding distance (higher is better)
#             sorted_front = sorted(front, key=lambda x: crowding_distance[x], reverse=True)
#             selected.extend(sorted_front[:population_size - len(selected)])
#             break

#     return selected


## use below functins

In [None]:
# import matplotlib.pyplot as plt

# def plot_pareto_front(arch_performance):
#     accuracies = [v[0] for v in arch_performance.values()]
#     latencies = [v[1] for v in arch_performance.values()]
#     memories = [v[2] / (1024**2) for v in arch_performance.values()]  # convert to MB

#     # Accuracy vs Latency
#     plt.figure(figsize=(8,6))
#     plt.scatter(latencies, accuracies, c='blue')
#     plt.xlabel('Latency (s/image)')
#     plt.ylabel('Accuracy (%)')
#     plt.title('Pareto Front (Accuracy vs Latency)')
#     plt.grid()
#     plt.show()

#     # Accuracy vs Memory
#     plt.figure(figsize=(8,6))
#     plt.scatter(memories, accuracies, c='green')
#     plt.xlabel('Memory (MB)')
#     plt.ylabel('Accuracy (%)')
#     plt.title('Pareto Front (Accuracy vs Memory)')
#     plt.grid()
#     plt.show()

# # Call after your last generation completes:
# plot_pareto_front(arch_performance)


In [None]:
# def evolutionary_algorithm(population_size=10, generations=5, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
#     seen_architectures = set()
#     population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]
#     arch_performance = {}

#     prev_best_accuracy = 0
#     no_improvement_count = 0

#     for generation in range(generations):
#         print(f"\n--- Generation {generation + 1}/{generations} ---")

#         # Fine-tuning and evaluation
#         for arch in population:
#             depth, num_heads, mlp_ratio, embed_dim = arch
#             architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")

#             model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
#                                depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
#                                num_classes=200).to(device)

#             checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')

#             # Load weights only once clearly
#             if generation == 0 or not os.path.exists(checkpoint_path):
#                 load_pretrained_weights(model)
#             else:
#                 model.load_state_dict(torch.load(checkpoint_path))
#                 print(f"Loaded weights from previous generation for architecture {arch}")

#             fine_tune_model(model, train_loader, test_loader, epochs=3, architecture_folder=architecture_folder)

#             accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
#             memory = count_parameters(model) * 4 / (1024 ** 2)  # MB
#             arch_performance[arch] = (accuracy, latency, memory)

#             del model
#             torch.cuda.empty_cache()

#         # Pareto selection
#         population = pareto_selection(arch_performance)

#         # Saving clearly ranked top models
#         save_top_ranked_models(population, arch_performance, generation)

#         # Check Pareto front convergence (stopping criteria)
#         current_best_accuracy = arch_performance[population[0]][0]
#         if current_best_accuracy - prev_best_accuracy < 1.0:
#             no_improvement_count += 1
#             print(f"Minimal improvement detected: {current_best_accuracy - prev_best_accuracy:.2f}%")
#             if no_improvement_count >= 2:
#                 print("Pareto front has converged. Stopping early.")
#                 break
#         else:
#             no_improvement_count = 0
#         prev_best_accuracy = current_best_accuracy

#         # Generate offspring
#         next_population = population[:len(population)//2]  # Only top half
#         offspring = []

#         for i in range(0, len(next_population)-1, 2):
#             parent1, parent2 = next_population[i], next_population[i+1]

#             if random.random() < crossover_rate:
#                 child1, child2 = one_point_crossover(parent1, parent2)
#                 print(f"Crossover parents: {parent1} & {parent2}")
#                 offspring.extend([child1, child2])
#             else:
#                 offspring.extend([parent1, parent2])

#         # Mutation with clear logging
#         mutated_offspring = []
#         for child in offspring:
#             if random.random() < mutation_rate:
#                 original_child = child
#                 child = mutate(child)
#                 print(f"Mutated from {original_child} to {child}")
#             mutated_offspring.append(child)

#         population = next_population + mutated_offspring

#     # Plot Pareto Front at end
#     plot_pareto_front(arch_performance)

#     return population

# # Call the algorithm
# evolutionary_algorithm(population_size=10, generations=5, train_loader=train_loader, test_loader=test_loader)


In [None]:
# import random
# import os
# import torch
# import torch.nn as nn
# from torch.optim import Adam
# from timm import create_model
# import time

# # Path to save models
# SAVE_PATH = '/home/pratibha/nas_vision/weights-img-evol1'

# # Device setup
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Load pretrained ViT weights initially
# def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
#     pretrained_vit = create_model(pretrained_model_name, pretrained=True)
#     pretrained_state_dict = pretrained_vit.state_dict()
#     model_state_dict = model.state_dict()
#     filtered_dict = {k: v for k, v in pretrained_state_dict.items()
#                      if k in model_state_dict and v.shape == model_state_dict[k].shape}
#     model.load_state_dict(filtered_dict, strict=False)
#     print(f"Loaded pretrained weights into {model.__class__.__name__}")

# # Count parameters
# def count_parameters(model):
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # Evaluate model clearly with corrected memory usage
# def evaluate_architecture(model, test_loader):
#     model.eval()
#     correct, total, running_loss = 0, 0, 0.0
#     criterion = nn.CrossEntropyLoss()
#     start_time = time.time()

#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to(device), labels.to(device)
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             running_loss += loss.item()
#             _, predicted = torch.max(outputs, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#     latency = (time.time() - start_time) / len(test_loader.dataset)
#     accuracy = 100 * correct / total
#     num_params = count_parameters(model)
#     memory_usage = (num_params * 4) / (1024 ** 2)  # MB for FP32
#     test_loss = running_loss / len(test_loader)

#     print(f"Loss: {test_loss:.4f}, Acc: {accuracy:.2f}%, Lat: {latency:.6f}s/img, Mem: {memory_usage:.2f}MB")

#     return accuracy, test_loss, latency, memory_usage

# # Mutate architecture
# def mutate(arch):
#     depth, num_heads, mlp_ratio, embed_dim = arch
#     if random.random() < 0.5: depth = random.choice([4, 6, 8, 10, 12])
#     if random.random() < 0.5: num_heads = random.choice([4, 8, 12, 16])
#     if random.random() < 0.5: mlp_ratio = random.choice([2.0, 4.0, 6.0])
#     return depth, num_heads, mlp_ratio, embed_dim

# # Crossover architectures
# def one_point_crossover(p1, p2):
#     cp = random.randint(1, 3)
#     return p1[:cp] + p2[cp:], p2[:cp] + p1[cp:]

# # Pareto ranking
# def pareto_selection(perf):
#     def dominates(a, b):
#         return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))

#     ranks = {}
#     for arch1, p1 in perf.items():
#         dominated = sum(dominates(p2, p1) for arch2, p2 in perf.items() if arch2 != arch1)
#         ranks[arch1] = dominated
#     return sorted(ranks, key=ranks.get)

# # Fine-tune model
# def fine_tune_model(model, train_loader, test_loader, epochs, folder):
#     criterion, optimizer = nn.CrossEntropyLoss(), Adam(model.parameters(), lr=1e-4)
#     model.to(device)

#     for epoch in range(epochs):
#         model.train()
#         running_loss = 0
#         for images, labels in train_loader:
#             images, labels = images.to(device), labels.to(device)
#             optimizer.zero_grad()
#             loss = criterion(model(images), labels)
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()

#         acc, loss, lat, mem = evaluate_architecture(model, test_loader)
#         print(f"Epoch {epoch+1}/{epochs} done.")

#     os.makedirs(folder, exist_ok=True)
#     torch.save(model.state_dict(), os.path.join(folder, 'checkpoint.pth'))

# # Evolutionary NAS main loop
# def evolutionary_algorithm(population_size, generations, train_loader, test_loader):
#     seen, perf = set(), {}
#     population = [(random.choice([4,6,8,10,12]), random.choice([4,8,12,16]), random.choice([2.0,4.0,6.0]), 768)
#                   for _ in range(population_size)]

#     for gen in range(generations):
#         print(f"\nGeneration {gen+1}/{generations}")

#         for arch in population:
#             folder = os.path.join(SAVE_PATH, f"arch_{arch[0]}_{arch[1]}_{arch[2]}_{arch[3]}")
#             model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3], depth=arch[0],
#                                num_heads=arch[1], mlp_ratio=arch[2], num_classes=200).to(device)

#             ckpt = os.path.join(folder, 'checkpoint.pth')
#             if os.path.exists(ckpt):
#                 model.load_state_dict(torch.load(ckpt))
#                 print(f"Loaded previous weights for {arch}")
#             else:
#                 print(f"No previous weights, loading pretrained for {arch}")
#                 load_pretrained_weights(model)

#             fine_tune_model(model, train_loader, test_loader, 3, folder)
#             perf[arch] = evaluate_architecture(model, test_loader)
#             del model; torch.cuda.empty_cache()

#         population = pareto_selection(perf)

#         for i, arch in enumerate(population[:5]):
#             print(f"Rank {i+1}: {arch}, Acc: {perf[arch][0]:.2f}%")

#         offspring = []
#         for i in range(0, len(population)-1, 2):
#             child1, child2 = one_point_crossover(population[i], population[i+1])
#             offspring.extend([mutate(child1), mutate(child2)])

#         population = population[:len(population)//2] + offspring

#     return population

# # Example run:
# # evolutionary_algorithm(5, 2, train_loader, test_loader)
# evolutionary_algorithm(population_size=5, generations=2, train_loader=train_loader, test_loader=test_loader)


In [8]:
!pip install matplotlib

Collecting matplotlib
  Downloading matplotlib-3.10.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.4 kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.56.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (101 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.2 kB)
Collecting pyparsing>=2.3.1 (from matplotlib)
  Downloading pyparsing-3.2.3-py3-none-any.whl.metadata (5.0 kB)
Downloading matplotlib-3.10.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚î

In [9]:
import matplotlib.pyplot as plt

In [None]:
 # Call after your last generation completes:
# plot_pareto_front(arch_performance)


# def evolutionary_algorithm(population_size=5, generations=2, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
#     seen_architectures = set()
#     population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]

#     # Evaluate and store metrics only once per architecture per generation
#     arch_performance = {}

#     for generation in range(generations):
#         print(f"\n--- Generation {generation + 1}/{generations} ---")
        
#         for arch in population:
#             depth, num_heads, mlp_ratio, embed_dim = arch
#             architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")

#             # model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
#             #                    depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
#             #                    num_classes=1000).to(device)
#             model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
#                                 depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
#                                 num_classes=200).to(device)


#             # if generation == 0:
#             #     load_pretrained_weights(model)
#             # else:
#             #     checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
#             #     model.load_state_dict(torch.load(checkpoint_path))
#             #     print(f"Loaded weights from previous generation for {arch}")
            
#             if generation == 0:
#                 load_pretrained_weights(model)
#             else:
#                 checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
#                 if os.path.exists(checkpoint_path):
#                     model.load_state_dict(torch.load(checkpoint_path))
#                     print(f"Generation {generation + 1}: Loaded fine-tuned weights from previous generation for {arch}.")
#                 else:
#                     print(f"Generation {generation + 1}: Fine-tuned weights not found for {arch}. Loading pretrained ViT weights.")
#                     load_pretrained_weights(model)

#             check_pretrained_weights(model, generation=generation, model_type="subnetwork")

#             fine_tune_model(
#                 model, train_loader, test_loader, epochs=3, architecture_folder=architecture_folder
#             )

#             # Evaluate once per architecture
#             accuracy, _, latency, memory = evaluate_architecture(model, test_loader)
#             arch_performance[arch] = (accuracy, latency, memory)

#             del model
#             torch.cuda.empty_cache()

#         # Pareto selection
#         population = pareto_selection(arch_performance)

#         # Save top models clearly ranked (1 = best)
#         save_top_ranked_models(population, arch_performance, generation)

#         # Generate offspring using crossover and mutation
#         offspring = []
#         for i in range(0, len(population)-1, 2):
#             if random.random() < crossover_rate:
#                 child1, child2 = one_point_crossover(population[i], population[i + 1])
#                 offspring.extend([child1, child2])
#             else:
#                 offspring.extend([population[i], population[i + 1]])

#         offspring = [mutate(child) if random.random() < mutation_rate else child for child in offspring]

#         # Next-generation combines top parents and offspring
#         population = population[:len(population)//2] + offspring

#     return population

# def evolutionary_algorithm(population_size=10, generations=5, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
#     seen_architectures = set()
#     population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]
#     arch_performance = {}

#     prev_best_accuracy = 0
#     no_improvement_count = 0

#     for generation in range(generations):
#         print(f"\n--- Generation {generation + 1}/{generations} ---")

#         # Fine-tuning and evaluation
#         for arch in population:
#             depth, num_heads, mlp_ratio, embed_dim = arch
#             architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")

#             model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
#                                depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
#                                num_classes=200).to(device)

#             checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')

#             # Load weights only once clearly
#             if generation == 0 or not os.path.exists(checkpoint_path):
#                 load_pretrained_weights(model)
#             else:
#                 model.load_state_dict(torch.load(checkpoint_path))
#                 print(f"Loaded weights from previous generation for architecture {arch}")

#             fine_tune_model(model, train_loader, test_loader, epochs=3, architecture_folder=architecture_folder)

#             accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
#             memory = count_parameters(model) * 4 / (1024 ** 2)  # MB
#             arch_performance[arch] = (accuracy, latency, memory)

#             del model
#             torch.cuda.empty_cache()

#         # Pareto selection
#         population = pareto_selection(arch_performance)

#         # Saving clearly ranked top models
#         save_top_ranked_models(population, arch_performance, generation)

#         # Check Pareto front convergence (stopping criteria)
#         current_best_accuracy = arch_performance[population[0]][0]
#         if current_best_accuracy - prev_best_accuracy < 1.0:
#             no_improvement_count += 1
#             print(f"Minimal improvement detected: {current_best_accuracy - prev_best_accuracy:.2f}%")
#             if no_improvement_count >= 2:
#                 print("Pareto front has converged. Stopping early.")
#                 break
#         else:
#             no_improvement_count = 0
#         prev_best_accuracy = current_best_accuracy

#         # Generate offspring
#         next_population = population[:len(population)//2]  # Only top half
#         offspring = []

#         for i in range(0, len(next_population)-1, 2):
#             parent1, parent2 = next_population[i], next_population[i+1]

#             if random.random() < crossover_rate:
#                 child1, child2 = one_point_crossover(parent1, parent2)
#                 print(f"Crossover parents: {parent1} & {parent2}")
#                 offspring.extend([child1, child2])
#             else:
#                 offspring.extend([parent1, parent2])

#         # Mutation with clear logging
#         mutated_offspring = []
#         for child in offspring:
#             if random.random() < mutation_rate:
#                 original_child = child
#                 child = mutate(child)
#                 print(f"Mutated from {original_child} to {child}")
#             mutated_offspring.append(child)

#         population = next_population + mutated_offspring

#     # Plot Pareto Front at end
#     plot_pareto_front(arch_performance)

#     return population

## 25 march

In [None]:
## after below run above code

In [None]:
import random
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from timm import create_model
import time

# Path to save the models after fine-tuning
SAVE_PATH = '/home/pratibha/nas_vision/weights-img-evol2'
# SAVE_PATH = '/kaggle/working/'

# Set the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# First-time loading pretrained weights for initialization
def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
    pretrained_vit = create_model(pretrained_model_name, pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    # Match keys between pretrained and current model
    model_state_dict = model.state_dict()
    filtered_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}

    # Load pretrained weights
    model.load_state_dict(filtered_dict, strict=False)
    print(f"Pretrained weights loaded into {model.__class__.__name__} successfully.")

# Check if pretrained weights are loaded correctly
def check_pretrained_weights(model, generation=0, model_type="subnetwork"):
    pretrained_vit = create_model("vit_base_patch16_224", pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    model_state_dict = model.state_dict()
    matching_keys = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
    
    if len(matching_keys) > 0:
        print(f"Generation {generation + 1}: {model_type} model has loaded {len(matching_keys)} layers from pretrained weights.")
    else:
        print(f"Generation {generation + 1}: {model_type} model has NOT loaded any pretrained weights.")

# Sample Subnetwork - Randomly sample hyperparameters (depth, num_heads, etc.)
def sample_subnetwork(seen_architectures):
    while True:
        depth = random.choice([4, 6, 8, 10, 12])
        num_heads = random.choice([4, 8, 12, 16])
        mlp_ratio = random.choice([2.0, 4.0, 6.0])
        embed_dim = 768  # Fixed embedding dimension
        
        architecture = (depth, num_heads, mlp_ratio, embed_dim)
        
        # Skip if architecture has already been sampled
        if architecture not in seen_architectures:
            seen_architectures.add(architecture)
            print(f"Sampled architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
            
            # Create the model to calculate its number of parameters
            # sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000)
            sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                                        depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, 
                                        num_classes=200  
                                    )
            num_params = count_parameters(sampled_model)
            print(f"Number of parameters in the sampled model: {num_params:,}")
            
            return architecture
        else:
            print(f"Repeated architecture found, resampling...")

# Count number of trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Evaluate architecture: accuracy, latency, and memory usage
def evaluate_architecture(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    y_true = []
    y_pred = []
    
    criterion = nn.CrossEntropyLoss()

    # Start measuring inference latency
    start_time = time.time()

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)  # Move to the same device
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    # Measure total time for inference (latency)
    latency = (time.time() - start_time) / len(test_loader.dataset)

    # Compute accuracy
    accuracy = 100 * correct / total

    # Compute memory usage (rough estimation)
    num_params = count_parameters(model)
    memory_usage = (num_params * 4) / (1024 ** 2)  # Convert bytes to MB (FP32)

    # Compute average loss
    test_loss = running_loss / len(test_loader)

    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%, Latency: {latency:.6f} seconds/image, Memory Usage: {memory_usage:.2f} MB")

    return accuracy, test_loss, latency, memory_usage




# Estimate memory usage of a model during inference (rough estimation)
def estimate_memory_usage(model):                                             ## this funtion is not needed
    # Create dummy input matching the expected shape of the input tensor
    dummy_input = torch.randn(1, 3, 224, 224).to(device)  # Example for ViT (3-channel image of size 224x224)
    
    # Use torch.utils.benchmark to measure memory usage during inference
    start_mem = torch.cuda.memory_allocated()
    
    # Run the model once with the dummy input
    with torch.no_grad():
        model(dummy_input)
    
    end_mem = torch.cuda.memory_allocated()
    memory_usage = (end_mem - start_mem) / (1024 ** 2)  # Convert bytes to MB
    return memory_usage


def calculate_crowding_distance(population, test_loader):
    crowding_distances = [0] * len(population)
    num_objectives = 3  # Accuracy, Latency, Memory

    # Evaluate each architecture once, then reuse the results
    evaluated_results = []
    for arch in population:
        # # model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
        #                    depth=arch[0], num_heads=arch[1],
        #                    mlp_ratio=arch[2], num_classes=10).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
                            depth=arch[0], num_heads=arch[1], mlp_ratio=arch[2], 
                            num_classes=200).to(device)

        accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
        memory = count_parameters(model) * 4  # memory in bytes
        
        evaluated_results.append((accuracy, latency, memory))
        del model
        torch.cuda.empty_cache()

    for objective_index in range(num_objectives):
        sorted_indices = sorted(range(len(population)),
                                key=lambda idx: evaluated_results[idx][objective_index])
        
        crowding_distances[sorted_indices[0]] = crowding_distances[sorted_indices[-1]] = float('inf')

        for i in range(1, len(sorted_indices) - 1):
            prev_value = evaluated_results[sorted_indices[i - 1]][objective_index]
            next_value = evaluated_results[sorted_indices[i + 1]][objective_index]
            distance = next_value - prev_value
            crowding_distances[sorted_indices[i]] += distance

    return crowding_distances


def dominates(model1, model2, test_loader):
    # Evaluate both models on the test set
    accuracy1, latency1, _, _ = evaluate_architecture(model1, test_loader)
    accuracy2, latency2, _, _ = evaluate_architecture(model2, test_loader)
    
    # Calculate memory usage as the number of parameters * 4 bytes (FP32)
    memory1 = count_parameters(model1) * 4  # Memory in bytes
    memory2 = count_parameters(model2) * 4  # Memory in bytes
    
    # Compare performance metrics
    dominates_in_accuracy = accuracy1 >= accuracy2
    dominates_in_latency = latency1 <= latency2
    dominates_in_memory = memory1 <= memory2

    # Return True if model1 dominates model2 in all aspects
    return dominates_in_accuracy and dominates_in_latency and dominates_in_memory


# Mutation: Randomly mutate architecture's hyperparameters
def mutate(architecture):
    depth, num_heads, mlp_ratio, embed_dim = architecture
    if random.random() < 0.5: depth = random.choice([4, 6, 8, 10, 12])
    if random.random() < 0.5: num_heads = random.choice([4, 8, 12, 16])
    if random.random() < 0.5: mlp_ratio = random.choice([2.0, 4.0, 6.0])
    print(f"Mutated architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
    return depth, num_heads, mlp_ratio, embed_dim

# One-Point Crossover: Combine two parent architectures to create new architectures
def one_point_crossover(parent1, parent2):
    crossover_point = random.choice([0, 1, 2, 3])  # Crossover at depth, num_heads, etc.
    child1 = parent1[:crossover_point] + parent2[crossover_point:]
    child2 = parent2[:crossover_point] + parent1[crossover_point:]
    print(f"Crossover result: Child1={child1}, Child2={child2}")
    return child1, child2




# Optimized Pareto selection based on stored performance metrics
def pareto_selection(arch_performance):
    def dominates(perf1, perf2):
        acc1, lat1, mem1 = perf1
        acc2, lat2, mem2 = perf2
        return (acc1 >= acc2 and lat1 <= lat2 and mem1 <= mem2) and (acc1 > acc2 or lat1 < lat2 or mem1 < mem2)

    ranks = {}
    for arch1, perf1 in arch_performance.items():
        dominated_count = 0
        for arch2, perf2 in arch_performance.items():
            if arch1 != arch2 and dominates(perf2, perf1):
                dominated_count += 1
        ranks[arch1] = dominated_count

    # Sort architectures by rank (lower dominated_count = better)
    sorted_population = sorted(ranks.keys(), key=lambda arch: ranks[arch])
    return sorted_population

# Fine-tune model on dataset (train for a few epochs)
def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
    print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
    sampled_model.to(device)  # Ensure the model is on the correct device
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
    for epoch in range(epochs):
        sampled_model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)  # Ensure inputs are on the same device
            optimizer.zero_grad()
            outputs = sampled_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        test_accuracy, test_loss, test_latency, memory_usage = evaluate_architecture(sampled_model, test_loader)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%, Test Latency: {test_latency:.6f} seconds/image")

    # Save the model after fine-tuning
    if architecture_folder:
        os.makedirs(architecture_folder, exist_ok=True)
        torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
    return sampled_model





def save_top_ranked_models(population, arch_performance, generation):
    top_n = min(5, len(population))
    for idx, arch in enumerate(population[:top_n]):
        depth, num_heads, mlp_ratio, embed_dim = arch
        # model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
        #                    num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
                            num_heads=num_heads, mlp_ratio=mlp_ratio, 
                            num_classes=200).to(device)


        architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
        checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
        model.load_state_dict(torch.load(checkpoint_path))

        top_model_path = os.path.join(SAVE_PATH, f'top_ranked_model_gen{generation+1}_rank_{idx+1}.pth')
        torch.save(model.state_dict(), top_model_path)
        
        acc, lat, mem = arch_performance[arch]

        # with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
        #     f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
        #     f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")
        with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
            f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
            f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")


        print(f"Saved top-ranked model: Generation {generation+1}, Rank {idx+1} (Acc={acc:.2f}%, Lat={lat:.6f}, Mem={mem/(1024**2):.2f}MB)")
        
        


def plot_pareto_front(arch_performance):
    accuracies = [v[0] for v in arch_performance.values()]
    latencies = [v[1] for v in arch_performance.values()]
    memories = [v[2] / (1024**2) for v in arch_performance.values()]  # convert to MB

    # Accuracy vs Latency
    plt.figure(figsize=(8,6))
    plt.scatter(latencies, accuracies, c='blue')
    plt.xlabel('Latency (s/image)')
    plt.ylabel('Accuracy (%)')
    plt.title('Pareto Front (Accuracy vs Latency)')
    plt.grid()
    plt.show()

    # Accuracy vs Memory
    plt.figure(figsize=(8,6))
    plt.scatter(memories, accuracies, c='green')
    plt.xlabel('Memory (MB)')
    plt.ylabel('Accuracy (%)')
    plt.title('Pareto Front (Accuracy vs Memory)')
    plt.grid()
    plt.show()

#

def evolutionary_algorithm(population_size=10, generations=5, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
    seen_architectures = set()
    population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]
    arch_performance = {}

    prev_best_accuracy = 0
    no_improvement_count = 0

    for generation in range(generations):
        print(f"\n--- Generation {generation + 1}/{generations} ---")

        for arch in population:
            depth, num_heads, mlp_ratio, embed_dim = arch
            architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
            checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')

            model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                               depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                               num_classes=200).to(device)

            # Clearly load weights once per architecture
            if os.path.exists(checkpoint_path):
                model.load_state_dict(torch.load(checkpoint_path))
                print(f"Loaded weights from previous generation for architecture {arch}")
            else:
                load_pretrained_weights(model)

            fine_tune_model(model, train_loader, test_loader, epochs=3, architecture_folder=architecture_folder)

            accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
            memory = count_parameters(model) * 4 / (1024 ** 2)  # MB
            arch_performance[arch] = (accuracy, latency, memory)

            del model
            torch.cuda.empty_cache()

        # Pareto selection
        population = pareto_selection(arch_performance)

        # Saving top-ranked models
        save_top_ranked_models(population, arch_performance, generation)

        # Check for Pareto front convergence (early stopping criteria)
        current_best_accuracy = arch_performance[population[0]][0]
        if current_best_accuracy - prev_best_accuracy < 1.0:
            no_improvement_count += 1
            print(f"Minimal improvement detected: {current_best_accuracy - prev_best_accuracy:.2f}%")
            if no_improvement_count >= 2:
                print("Pareto front has converged. Stopping early.")
                break
        else:
            no_improvement_count = 0
        prev_best_accuracy = current_best_accuracy

        # Generate offspring
        next_population = population[:len(population)//2]  # Only top half
        offspring = []

        for i in range(0, len(next_population)-1, 2):
            parent1, parent2 = next_population[i], next_population[i+1]

            if random.random() < crossover_rate:
                child1, child2 = one_point_crossover(parent1, parent2)
                print(f"Crossover parents: {parent1} & {parent2}")
                offspring.extend([child1, child2])
            else:
                offspring.extend([parent1, parent2])

        # Mutation with clear logging
        mutated_offspring = []
        for child in offspring:
            if random.random() < mutation_rate:
                original_child = child
                child = mutate(child)
                print(f"Mutated from {original_child} to {child}")
            mutated_offspring.append(child)

        population = next_population + mutated_offspring

    # Plot Pareto Front at the end
    plot_pareto_front(arch_performance)

    return population

# Run the evolutionary algorithm
evolutionary_algorithm(population_size=10, generations=5, train_loader=train_loader, test_loader=test_loader)

# Call the algorithm
# evolutionary_algorithm(population_size=10, generations=5, train_loader=train_loader, test_loader=test_loader)

# Run the evolutionary algorithm
# evolutionary_algorithm(population_size=5, generations=2, train_loader=train_loader, test_loader=test_loader)


Sampled architecture: Depth=10, Num Heads=8, MLP Ratio=4.0, Embed Dim=768
Number of parameters in the sampled model: 71,776,712
Sampled architecture: Depth=12, Num Heads=4, MLP Ratio=6.0, Embed Dim=768
Number of parameters in the sampled model: 114,282,440
Sampled architecture: Depth=4, Num Heads=4, MLP Ratio=4.0, Embed Dim=768
Number of parameters in the sampled model: 29,249,480
Sampled architecture: Depth=10, Num Heads=4, MLP Ratio=6.0, Embed Dim=768
Number of parameters in the sampled model: 95,385,032
Sampled architecture: Depth=6, Num Heads=16, MLP Ratio=6.0, Embed Dim=768
Number of parameters in the sampled model: 57,590,216
Sampled architecture: Depth=4, Num Heads=4, MLP Ratio=2.0, Embed Dim=768
Number of parameters in the sampled model: 19,806,152
Sampled architecture: Depth=8, Num Heads=8, MLP Ratio=2.0, Embed Dim=768
Number of parameters in the sampled model: 38,714,312
Sampled architecture: Depth=12, Num Heads=8, MLP Ratio=4.0, Embed Dim=768
Number of parameters in the samp

RuntimeError: [enforce fail at inline_container.cc:626] . unexpected pos 50239104 vs 50238992

In [None]:
# # just defining model again here for easily avaliability
# import torch
# import torch.nn as nn

# class DynamicPatchEmbed(nn.Module):
#     def __init__(self, img_size=224, patch_size=16, embed_dim=768):
#         super().__init__()
#         self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
#         self.num_patches = (img_size // patch_size) ** 2

#     def forward(self, x):
#         x = self.proj(x)
#         return x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)


# class DynamicMultiHeadAttention(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super().__init__()
#         self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
#         self.proj = nn.Linear(embed_dim, embed_dim)
#         self.scale = (embed_dim // num_heads) ** -0.5
#         self.num_heads = num_heads  # Store num_heads as a class attribute

#         # Ensure that the number of heads divides the embedding dimension
#         assert embed_dim % num_heads == 0, f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"

#     def forward(self, x):
#         B, N, C = x.shape
#         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

#         q, k, v = qkv[0], qkv[1], qkv[2]
#         attn = (q @ k.transpose(-2, -1)) * self.scale
#         attn = attn.softmax(dim=-1)
#         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
#         return self.proj(x)

# class MLPBlock(nn.Module):  
#     def __init__(self, embed_dim, mlp_ratio):
#         super().__init__()
#         hidden_dim = int(embed_dim * mlp_ratio)
#         self.fc1 = nn.Linear(embed_dim, hidden_dim)  # Matches `mlp.fc1`
#         self.act = nn.GELU()
#         self.fc2 = nn.Linear(hidden_dim, embed_dim)  # Matches `mlp.fc2`

#     def forward(self, x):
#         return self.fc2(self.act(self.fc1(x)))

# class DynamicTransformerBlock(nn.Module):
#     def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
#         super().__init__()
#         self.norm1 = nn.LayerNorm(embed_dim)
#         self.attn = DynamicMultiHeadAttention(embed_dim, num_heads)
#         self.norm2 = nn.LayerNorm(embed_dim)
        
#         #  Fix: Wrap MLP inside a separate module to match ViT
#         self.mlp = MLPBlock(embed_dim, mlp_ratio)  

#     def forward(self, x):
#         x = x + self.attn(self.norm1(x))
#         x = x + self.mlp(self.norm2(x))
#         return x
# class DynamicViT(nn.Module):
#     def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, num_classes=10):
#         super().__init__()
#         self.depth = depth  # Store depth as an instance variable
#         self.num_heads = num_heads  # Store num_heads as an instance variable
#         self.mlp_ratio = mlp_ratio  # Store mlp_ratio as an instance variable
        
#         self.patch_embed = DynamicPatchEmbed(img_size, patch_size, embed_dim)
        
#         # Fix: Correct positional embedding key
#         self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))
        
#         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
#         self.blocks = nn.ModuleList([DynamicTransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)])
#         self.norm = nn.LayerNorm(embed_dim)
#         self.head = nn.Linear(embed_dim, num_classes)

#     def forward(self, x):
#         x = self.patch_embed(x)
#         B = x.shape[0]

#         # Add class token
#         cls_tokens = self.cls_token.expand(B, -1, -1)
#         x = torch.cat((cls_tokens, x), dim=1)
        
#         x = x + self.pos_embed

#         for block in self.blocks:
#             x = block(x)

#         x = self.norm(x[:, 0])
#         return self.head(x)


# # Path to save the models after fine-tuning
# SAVE_PATH = '/home/pratibha/nas_vision/weights-img-evol1'
# # SAVE_PATH = '/kaggle/working/'

# # Set the device (GPU if available, else CPU)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # First-time loading pretrained weights for initialization
# def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
#     pretrained_vit = create_model(pretrained_model_name, pretrained=True)
#     pretrained_state_dict = pretrained_vit.state_dict()
    
#     # Match keys between pretrained and current model
#     model_state_dict = model.state_dict()
#     filtered_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}

#     # Load pretrained weights
#     model.load_state_dict(filtered_dict, strict=False)
#     print(f"Pretrained weights loaded into {model.__class__.__name__} successfully.")

# # Check if pretrained weights are loaded correctly
# def check_pretrained_weights(model, generation=0, model_type="subnetwork"):
#     pretrained_vit = create_model("vit_base_patch16_224", pretrained=True)
#     pretrained_state_dict = pretrained_vit.state_dict()
    
#     model_state_dict = model.state_dict()
#     matching_keys = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
    
#     if len(matching_keys) > 0:
#         print(f"Generation {generation + 1}: {model_type} model has loaded {len(matching_keys)} layers from pretrained weights.")
#     else:
#         print(f"Generation {generation + 1}: {model_type} model has NOT loaded any pretrained weights.")

# # Sample Subnetwork - Randomly sample hyperparameters (depth, num_heads, etc.)
# def sample_subnetwork(seen_architectures):
#     while True:
#         depth = random.choice([4, 6, 8, 10, 12])
#         num_heads = random.choice([4, 8, 12, 16])
#         mlp_ratio = random.choice([2.0, 4.0, 6.0])
#         embed_dim = 768  # Fixed embedding dimension
        
#         architecture = (depth, num_heads, mlp_ratio, embed_dim)
        
#         # Skip if architecture has already been sampled
#         if architecture not in seen_architectures:
#             seen_architectures.add(architecture)
#             print(f"Sampled architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
            
#             # Create the model to calculate its number of parameters
#             # sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000)
#             sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
#                                         depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, 
#                                         num_classes=200  
#                                     )
#             num_params = count_parameters(sampled_model)
#             print(f"Number of parameters in the sampled model: {num_params:,}")
            
#             return architecture
#         else:
#             print(f"Repeated architecture found, resampling...")

# # Count number of trainable parameters
# def count_parameters(model):
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# # Evaluate architecture: accuracy, latency, and memory usage
# def evaluate_architecture(model, test_loader):
#     model.eval()
#     correct = 0
#     total = 0
#     running_loss = 0.0
#     y_true = []
#     y_pred = []
    
#     criterion = nn.CrossEntropyLoss()

#     # Start measuring inference latency
#     start_time = time.time()

#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to(device), labels.to(device)  # Move to the same device
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             running_loss += loss.item()

#             _, predicted = torch.max(outputs, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#             y_true.extend(labels.cpu().numpy())
#             y_pred.extend(predicted.cpu().numpy())

#     # Measure total time for inference (latency)
#     latency = (time.time() - start_time) / len(test_loader.dataset)

#     # Compute accuracy
#     accuracy = 100 * correct / total

#     # Compute memory usage (rough estimation)
#     num_params = count_parameters(model)
#     memory_usage = (num_params * 4) / (1024 ** 2)  # Convert bytes to MB (FP32)

#     # Compute average loss
#     test_loss = running_loss / len(test_loader)

#     print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%, Latency: {latency:.6f} seconds/image, Memory Usage: {memory_usage:.2f} MB")

#     return accuracy, test_loss, latency, memory_usage

# def calculate_crowding_distance(population, test_loader):
#     crowding_distances = [0] * len(population)
#     num_objectives = 3  # Accuracy, Latency, Memory

#     # Evaluate each architecture once, then reuse the results
#     evaluated_results = []
#     for arch in population:
#         # # model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
#         #                    depth=arch[0], num_heads=arch[1],
#         #                    mlp_ratio=arch[2], num_classes=10).to(device)
#         model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
#                             depth=arch[0], num_heads=arch[1], mlp_ratio=arch[2], 
#                             num_classes=200).to(device)

#         accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
#         memory = count_parameters(model) * 4  # memory in bytes
        
#         evaluated_results.append((accuracy, latency, memory))
#         del model
#         torch.cuda.empty_cache()

#     for objective_index in range(num_objectives):
#         sorted_indices = sorted(range(len(population)),
#                                 key=lambda idx: evaluated_results[idx][objective_index])
        
#         crowding_distances[sorted_indices[0]] = crowding_distances[sorted_indices[-1]] = float('inf')

#         for i in range(1, len(sorted_indices) - 1):
#             prev_value = evaluated_results[sorted_indices[i - 1]][objective_index]
#             next_value = evaluated_results[sorted_indices[i + 1]][objective_index]
#             distance = next_value - prev_value
#             crowding_distances[sorted_indices[i]] += distance

#     return crowding_distances


# def dominates(model1, model2, test_loader):
#     # Evaluate both models on the test set
#     accuracy1, latency1, _, _ = evaluate_architecture(model1, test_loader)
#     accuracy2, latency2, _, _ = evaluate_architecture(model2, test_loader)
    
#     # Calculate memory usage as the number of parameters * 4 bytes (FP32)
#     memory1 = count_parameters(model1) * 4  # Memory in bytes
#     memory2 = count_parameters(model2) * 4  # Memory in bytes
    
#     # Compare performance metrics
#     dominates_in_accuracy = accuracy1 >= accuracy2
#     dominates_in_latency = latency1 <= latency2
#     dominates_in_memory = memory1 <= memory2

#     # Return True if model1 dominates model2 in all aspects
#     return dominates_in_accuracy and dominates_in_latency and dominates_in_memory


# # Mutation: Randomly mutate architecture's hyperparameters
# def mutate(architecture):
#     depth, num_heads, mlp_ratio, embed_dim = architecture
#     if random.random() < 0.5: depth = random.choice([4, 6, 8, 10, 12])
#     if random.random() < 0.5: num_heads = random.choice([4, 8, 12, 16])
#     if random.random() < 0.5: mlp_ratio = random.choice([2.0, 4.0, 6.0])
#     print(f"Mutated architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
#     return depth, num_heads, mlp_ratio, embed_dim

# # One-Point Crossover: Combine two parent architectures to create new architectures
# def one_point_crossover(parent1, parent2):
#     crossover_point = random.choice([0, 1, 2, 3])  # Crossover at depth, num_heads, etc.
#     child1 = parent1[:crossover_point] + parent2[crossover_point:]
#     child2 = parent2[:crossover_point] + parent1[crossover_point:]
#     print(f"Crossover result: Child1={child1}, Child2={child2}")
#     return child1, child2

# # Optimized Pareto selection based on stored performance metrics
# def pareto_selection(arch_performance):
#     def dominates(perf1, perf2):
#         acc1, lat1, mem1 = perf1
#         acc2, lat2, mem2 = perf2
#         return (acc1 >= acc2 and lat1 <= lat2 and mem1 <= mem2) and (acc1 > acc2 or lat1 < lat2 or mem1 < mem2)

#     ranks = {}
#     for arch1, perf1 in arch_performance.items():
#         dominated_count = 0
#         for arch2, perf2 in arch_performance.items():
#             if arch1 != arch2 and dominates(perf2, perf1):
#                 dominated_count += 1
#         ranks[arch1] = dominated_count

#     # Sort architectures by rank (lower dominated_count = better)
#     sorted_population = sorted(ranks.keys(), key=lambda arch: ranks[arch])
#     return sorted_population

# # Fine-tune model on dataset (train for a few epochs)
# def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
#     print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
#     sampled_model.to(device)  # Ensure the model is on the correct device
#     criterion = nn.CrossEntropyLoss()
#     optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
#     for epoch in range(epochs):
#         sampled_model.train()
#         running_loss = 0.0
#         for images, labels in train_loader:
#             images, labels = images.to(device), labels.to(device)  # Ensure inputs are on the same device
#             optimizer.zero_grad()
#             outputs = sampled_model(images)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()
        
#         test_accuracy, test_loss, test_latency, memory_usage = evaluate_architecture(sampled_model, test_loader)
#         print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%, Test Latency: {test_latency:.6f} seconds/image")

#     # Save the model after fine-tuning
#     if architecture_folder:
#         os.makedirs(architecture_folder, exist_ok=True)
#         torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
#     return sampled_model

# def save_top_ranked_models(population, arch_performance, generation):
#     top_n = min(5, len(population))
#     for idx, arch in enumerate(population[:top_n]):
#         depth, num_heads, mlp_ratio, embed_dim = arch
       
#         model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
#                             num_heads=num_heads, mlp_ratio=mlp_ratio, 
#                             num_classes=200).to(device)


#         architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
#         checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
#         model.load_state_dict(torch.load(checkpoint_path))

#         top_model_path = os.path.join(SAVE_PATH, f'top_ranked_model_gen{generation+1}_rank_{idx+1}.pth')
#         torch.save(model.state_dict(), top_model_path)
        
#         acc, lat, mem = arch_performance[arch]

#         # with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
#         #     f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
#         #     f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")
#         with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
#             f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
#             f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")


#         print(f"Saved top-ranked model: Generation {generation+1}, Rank {idx+1} (Acc={acc:.2f}%, Lat={lat:.6f}, Mem={mem/(1024**2):.2f}MB)")
        

# def evolutionary_algorithm(population_size=5, generations=2, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
#     seen_architectures = set()
#     population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]

#     # Evaluate and store metrics only once per architecture per generation
#     arch_performance = {}

#     for generation in range(generations):
#         print(f"\n--- Generation {generation + 1}/{generations} ---")
        
#         for arch in population:
#             depth, num_heads, mlp_ratio, embed_dim = arch
#             architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")

           
#             model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
#                                 depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
#                                 num_classes=200).to(device)
            
#             if generation == 0:
#                 load_pretrained_weights(model)
#             else:
#                 checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
#                 if os.path.exists(checkpoint_path):
#                     model.load_state_dict(torch.load(checkpoint_path))
#                     print(f"Generation {generation + 1}: Loaded fine-tuned weights from previous generation for {arch}.")
#                 else:
#                     print(f"Generation {generation + 1}: Fine-tuned weights not found for {arch}. Loading pretrained ViT weights.")
#                     load_pretrained_weights(model)

#             check_pretrained_weights(model, generation=generation, model_type="subnetwork")

#             fine_tune_model(
#                 model, train_loader, test_loader, epochs=3, architecture_folder=architecture_folder
#             )

#             # Evaluate once per architecture
#             accuracy, _, latency, memory = evaluate_architecture(model, test_loader)
#             arch_performance[arch] = (accuracy, latency, memory)

#             del model
#             torch.cuda.empty_cache()

#         # Pareto selection
#         population = pareto_selection(arch_performance)

#         # Save top models clearly ranked (1 = best)
#         save_top_ranked_models(population, arch_performance, generation)

#         # Generate offspring using crossover and mutation
#         offspring = []
#         for i in range(0, len(population)-1, 2):
#             if random.random() < crossover_rate:
#                 child1, child2 = one_point_crossover(population[i], population[i + 1])
#                 offspring.extend([child1, child2])
#             else:
#                 offspring.extend([population[i], population[i + 1]])

#         offspring = [mutate(child) if random.random() < mutation_rate else child for child in offspring]

#         # Next-generation combines top parents and offspring
#         population = population[:len(population)//2] + offspring

#     return population

# # Run the evolutionary algorithm
# evolutionary_algorithm(population_size=5, generations=2, train_loader=train_loader, test_loader=test_loader)

# Saved top-ranked model: Generation 1, Rank 1 (Acc=37.52%, Lat=0.003791, Mem=0.00MB)
# Saved top-ranked model: Generation 1, Rank 2 (Acc=36.44%, Lat=0.003658, Mem=0.00MB)
# Saved top-ranked model: Generation 1, Rank 3 (Acc=68.70%, Lat=0.003857, Mem=0.00MB)
# Saved top-ranked model: Generation 1, Rank 4 (Acc=38.75%, Lat=0.004327, Mem=0.00MB)
# Saved top-ranked model: Generation 1, Rank 5 (Acc=37.01%, Lat=0.004950, Mem=0.00MB)
# this is first generation but you can see here in output ranking is not good i think this is because the dominates funtion in pareto selection also here in hthis code you are not considering crowding distance in pareto selection correct it and also whenever each model is finetuned for 5 epochs draw its graph of train loss test loss ans test accuracy vs epochs immediately. also print execution time of each architecuture after each epoch analyze all the code and do all these modifications.

In [None]:
## what to include 
crowding distance in pareto selection funtion also check ranking funtion how models are ranked after finetuned because ranking has some error
draw plot after every subnetwork is finetuned


In [None]:
############################################################################################################################
##################################################################################################################################

## may 28

In [None]:
import random
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from timm import create_model
import time

# Path to save the models after fine-tuning
# SAVE_PATH = '/SN02DATA/nas_vision/evol_img1k-wts'
SAVE_PATH = '/home/pratibha/nas_vision/weights-img-evol28-may'

# SAVE_PATH = '/kaggle/working/'

# Set the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# First-time loading pretrained weights for initialization
def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
    pretrained_vit = create_model(pretrained_model_name, pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    # Match keys between pretrained and current model
    model_state_dict = model.state_dict()
    filtered_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}

    # Load pretrained weights
    model.load_state_dict(filtered_dict, strict=False)
    print(f"Pretrained weights loaded into {model.__class__.__name__} successfully.")

# Check if pretrained weights are loaded correctly
def check_pretrained_weights(model, generation=0, model_type="subnetwork"):
    pretrained_vit = create_model("vit_base_patch16_224", pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    model_state_dict = model.state_dict()
    matching_keys = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
    
    if len(matching_keys) > 0:
        print(f"Generation {generation + 1}: {model_type} model has loaded {len(matching_keys)} layers from pretrained weights.")
    else:
        print(f"Generation {generation + 1}: {model_type} model has NOT loaded any pretrained weights.")

# Sample Subnetwork - Randomly sample hyperparameters (depth, num_heads, etc.)
def sample_subnetwork(seen_architectures):
    while True:
        depth = random.choice([6, 8, 10, 12])
        num_heads = random.choice([4, 8, 12, 16])
        mlp_ratio = random.choice([2.0, 4.0, 6.0])
        embed_dim = 768  # Fixed embedding dimension
        
        architecture = (depth, num_heads, mlp_ratio, embed_dim)
        
        # Skip if architecture has already been sampled
        if architecture not in seen_architectures:
            seen_architectures.add(architecture)
            print(f"Sampled architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
            
            # Create the model to calculate its number of parameters
            # sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000)
            sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                                        depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, 
                                        num_classes=200
                                    )
            num_params = count_parameters(sampled_model)
            print(f"Number of parameters in the sampled model: {num_params:,}")
            
            return architecture
        else:
            print(f"Repeated architecture found, resampling...")

# Count number of trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def topk_accuracy(output, target, topk=(1,5)):
    """Computes the top-k accuracy for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size).item())
    return res  # [top1, top5]



from ptflops import get_model_complexity_info

def get_macs(model):
    with torch.cuda.device(0):
        macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=False, print_per_layer_stat=False)
    return macs

def evaluate_architecture(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    top1_total = 0
    top5_total = 0
    criterion = nn.CrossEntropyLoss()
    start_time = time.time()

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            top1, top5 = topk_accuracy(outputs, labels, topk=(1,5))
            top1_total += top1 * labels.size(0) / 100.0
            top5_total += top5 * labels.size(0) / 100.0
            total += labels.size(0)

    latency = (time.time() - start_time) / total
    accuracy = 100 * top1_total / total
    top5_accuracy = 100 * top5_total / total
    num_params = count_parameters(model)
    memory_usage = (num_params * 4) / (1024 ** 2)
    test_loss = running_loss / len(test_loader)

    print(f"Test Loss: {test_loss:.4f}, Top-1 Acc: {accuracy:.2f}%, Top-5 Acc: {top5_accuracy:.2f}%, Latency: {latency:.6f}s/img, Mem: {memory_usage:.2f}MB")
    macs = get_macs(model)
    print(f"MACs: {macs / 1e6:.2f} M")
    return accuracy, top5_accuracy, test_loss, latency, memory_usage, macs





# Estimate memory usage of a model during inference (rough estimation)
def estimate_memory_usage(model):                                             ## this funtion is not needed
    # Create dummy input matching the expected shape of the input tensor
    dummy_input = torch.randn(1, 3, 224, 224).to(device)  # Example for ViT (3-channel image of size 224x224)
    
    # Use torch.utils.benchmark to measure memory usage during inference
    start_mem = torch.cuda.memory_allocated()
    
    # Run the model once with the dummy input
    with torch.no_grad():
        model(dummy_input)
    
    end_mem = torch.cuda.memory_allocated()
    memory_usage = (end_mem - start_mem) / (1024 ** 2)  # Convert bytes to MB
    return memory_usage


def calculate_crowding_distance(population, test_loader):
    crowding_distances = [0] * len(population)
    num_objectives = 3  # Accuracy, Latency, Memory

    # Evaluate each architecture once, then reuse the results
    evaluated_results = []
    for arch in population:
        # # model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
        #                    depth=arch[0], num_heads=arch[1],
        #                    mlp_ratio=arch[2], num_classes=10).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
                            depth=arch[0], num_heads=arch[1], mlp_ratio=arch[2], 
                            num_classes=200).to(device)

        accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
        memory = count_parameters(model) * 4  # memory in bytes
        
        evaluated_results.append((accuracy, latency, memory))
        del model
        torch.cuda.empty_cache()

    for objective_index in range(num_objectives):
        sorted_indices = sorted(range(len(population)),
                                key=lambda idx: evaluated_results[idx][objective_index])
        
        crowding_distances[sorted_indices[0]] = crowding_distances[sorted_indices[-1]] = float('inf')

        for i in range(1, len(sorted_indices) - 1):
            prev_value = evaluated_results[sorted_indices[i - 1]][objective_index]
            next_value = evaluated_results[sorted_indices[i + 1]][objective_index]
            distance = next_value - prev_value
            crowding_distances[sorted_indices[i]] += distance

    return crowding_distances


def dominates(model1, model2, test_loader):
    # Evaluate both models on the test set
    accuracy1, _, _,latency1, _, _ = evaluate_architecture(model1, test_loader)
    accuracy2, _, _,latency2, _ , _= evaluate_architecture(model2, test_loader)
    
    # Calculate memory usage as the number of parameters * 4 bytes (FP32)
    memory1 = count_parameters(model1) * 4  # Memory in bytes
    memory2 = count_parameters(model2) * 4  # Memory in bytes
    
    # Compare performance metrics
    dominates_in_accuracy = accuracy1 >= accuracy2
    dominates_in_latency = latency1 <= latency2
    dominates_in_memory = memory1 <= memory2

    # Return True if model1 dominates model2 in all aspects
    return dominates_in_accuracy and dominates_in_latency and dominates_in_memory


# Mutation: Randomly mutate architecture's hyperparameters
def mutate(architecture):
    depth, num_heads, mlp_ratio, embed_dim = architecture
    if random.random() < 0.5: depth = random.choice([ 6, 8, 10, 12])
    if random.random() < 0.5: num_heads = random.choice([4, 8, 12, 16])
    if random.random() < 0.5: mlp_ratio = random.choice([2.0, 4.0, 6.0])
    print(f"Mutated architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
    return depth, num_heads, mlp_ratio, embed_dim

# One-Point Crossover: Combine two parent architectures to create new architectures
def one_point_crossover(parent1, parent2):
    crossover_point = random.choice([0, 1, 2, 3])  # Crossover at depth, num_heads, etc.
    child1 = parent1[:crossover_point] + parent2[crossover_point:]
    child2 = parent2[:crossover_point] + parent1[crossover_point:]
    print(f"Crossover result: Child1={child1}, Child2={child2}")
    return child1, child2



############################# this is not weight based instead it is pareto selection
# Optimized Pareto selection based on stored performance metrics
def pareto_selection(arch_performance):
    def dominates(perf1, perf2):
        acc1, lat1, mem1 = perf1
        acc2, lat2, mem2 = perf2
        return (acc1 >= acc2 and lat1 <= lat2 and mem1 <= mem2) and (acc1 > acc2 or lat1 < lat2 or mem1 < mem2)

    ranks = {}
    for arch1, perf1 in arch_performance.items():
        dominated_count = 0
        for arch2, perf2 in arch_performance.items():
            if arch1 != arch2 and dominates(perf2, perf1):
                dominated_count += 1
        ranks[arch1] = dominated_count

    # Sort architectures by rank (lower dominated_count = better)
    sorted_population = sorted(ranks.keys(), key=lambda arch: ranks[arch])
    return sorted_population



# def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
#     print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
#     sampled_model.to(device)
#     criterion = nn.CrossEntropyLoss()
#     optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
#     for epoch in range(epochs):
#         start_epoch = time.time()
#         sampled_model.train()
#         running_loss = 0.0
#         for images, labels in train_loader:
#             images, labels = images.to(device), labels.to(device)
#             optimizer.zero_grad()
#             outputs = sampled_model(images)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()
#         epoch_time = time.time() - start_epoch
#         test_accuracy, test_top5, test_loss, test_latency, memory_usage = evaluate_architecture(sampled_model, test_loader)
#         print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss:.4f}, Top-1 Acc: {test_accuracy:.2f}%, Top-5 Acc: {test_top5:.2f}%, Latency: {test_latency:.6f}s/img, Time: {epoch_time:.2f}s")
#     # Save model code unchanged
#     if architecture_folder:
#         os.makedirs(architecture_folder, exist_ok=True)
#         torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
#     return sampled_model

def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
    print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
    sampled_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
    for epoch in range(epochs):
        start_epoch = time.time()
        sampled_model.train()
        running_loss = 0.0
        
        # Training phase
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = sampled_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        # Evaluation phase
        epoch_time = time.time() - start_epoch
        test_accuracy, test_top5, test_loss, test_latency, memory_usage, macs = evaluate_architecture(sampled_model, test_loader)
        
        # Print epoch statistics
        print(f"\nEpoch {epoch + 1}/{epochs} Summary:")
        print(f"| Training Loss: {running_loss/len(train_loader):.4f}")
        print(f"| Test Loss: {test_loss:.4f}")
        print(f"| Top-1 Accuracy: {test_accuracy:.2f}%")
        print(f"| Top-5 Accuracy: {test_top5:.2f}%")
        print(f"| Latency: {test_latency:.6f}s/img")
        print(f"| Memory Usage: {memory_usage:.2f}MB")
        print(f"| MACs: {macs/1e6:.2f}M")
        print(f"| Epoch Time: {epoch_time:.2f}s\n")

    # Save model weights
    if architecture_folder:
        os.makedirs(architecture_folder, exist_ok=True)
        torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
    return sampled_model



def save_top_ranked_models(population, arch_performance, generation):
    top_n = min(5, len(population))
    for idx, arch in enumerate(population[:top_n]):
        depth, num_heads, mlp_ratio, embed_dim = arch
        # model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
        #                    num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
                            num_heads=num_heads, mlp_ratio=mlp_ratio, 
                            num_classes=200).to(device)


        architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
        checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
        model.load_state_dict(torch.load(checkpoint_path))

        top_model_path = os.path.join(SAVE_PATH, f'top_ranked_model_gen{generation+1}_rank_{idx+1}.pth')
        torch.save(model.state_dict(), top_model_path)
        
        acc, lat, mem = arch_performance[arch]

        # with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
        #     f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
        #     f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")
        with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
            f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
            f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")


        print(f"Saved top-ranked model: Generation {generation+1}, Rank {idx+1} (Acc={acc:.2f}%, Lat={lat:.6f}, Mem={mem/(1024**2):.2f}MB)")
        
        


def plot_pareto_front(arch_performance):
    accuracies = [v[0] for v in arch_performance.values()]
    latencies = [v[1] for v in arch_performance.values()]
    memories = [v[2] / (1024**2) for v in arch_performance.values()]  # convert to MB

    # Accuracy vs Latency
    plt.figure(figsize=(8,6))
    plt.scatter(latencies, accuracies, c='blue')
    plt.xlabel('Latency (s/image)')
    plt.ylabel('Accuracy (%)')
    plt.title('Pareto Front (Accuracy vs Latency)')
    plt.grid()
    plt.show()

    # Accuracy vs Memory
    plt.figure(figsize=(8,6))
    plt.scatter(memories, accuracies, c='green')
    plt.xlabel('Memory (MB)')
    plt.ylabel('Accuracy (%)')
    plt.title('Pareto Front (Accuracy vs Memory)')
    plt.grid()
    plt.show()

#

def evolutionary_algorithm(population_size=16, generations=5, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
    seen_architectures = set()
    population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]
    arch_performance = {}

    prev_best_accuracy = 0
    no_improvement_count = 0

    for generation in range(generations):
        print(f"\n--- Generation {generation + 1}/{generations} ---")

        for arch in population:
            depth, num_heads, mlp_ratio, embed_dim = arch
            architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
            checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')

            model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                               depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                               num_classes=200).to(device)

            # Clearly load weights once per architecture
            if os.path.exists(checkpoint_path):
                model.load_state_dict(torch.load(checkpoint_path))
                print(f"Loaded weights from previous generation for architecture {arch}")
            else:
                load_pretrained_weights(model)

            fine_tune_model(model, train_loader, test_loader, epochs=5, architecture_folder=architecture_folder)

            # accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
            accuracy, top5_accuracy, test_loss, latency, memory_usage, macs = evaluate_architecture(model, test_loader)
            memory = count_parameters(model) * 4 / (1024 ** 2)  # MB
            # arch_performance[arch] = (accuracy, latency, memory)
            arch_performance[arch] = (accuracy, top5_accuracy, latency, memory_usage, macs)

            del model
            torch.cuda.empty_cache()

        # Pareto selection
        population = pareto_selection(arch_performance)

        print("\nTop 5 Ranked Models of Generation", generation+1)
        for idx, arch in enumerate(population[:5]):
            acc, top5_acc, lat, mem, macs = arch_performance[arch]
            print(f"Rank {idx+1}: Model {arch} | Top-1 Acc: {acc:.2f}%, Top-5 Acc: {top5_acc:.2f}%, Latency: {lat:.6f}s/img, Mem: {mem:.2f}MB, MACs: {macs/1e6:.2f}M")
            # Saving top-ranked models
            save_top_ranked_models(population, arch_performance, generation)

        # Check for Pareto front convergence (early stopping criteria)
        current_best_accuracy = arch_performance[population[0]][0]
        if current_best_accuracy - prev_best_accuracy < 1.0:
            no_improvement_count += 1
            print(f"Minimal improvement detected: {current_best_accuracy - prev_best_accuracy:.2f}%")
            if no_improvement_count >= 2:
                print("Pareto front has converged. Stopping early.")
                break
        else:
            no_improvement_count = 0
        prev_best_accuracy = current_best_accuracy

        # Generate offspring
        next_population = population[:len(population)//2]  # Only top half
        offspring = []

        for i in range(0, len(next_population)-1, 2):
            parent1, parent2 = next_population[i], next_population[i+1]

            if random.random() < crossover_rate:
                child1, child2 = one_point_crossover(parent1, parent2)
                print(f"Crossover parents: {parent1} & {parent2}")
                offspring.extend([child1, child2])
            else:
                offspring.extend([parent1, parent2])

        # Mutation with clear logging
        mutated_offspring = []
        for child in offspring:
            if random.random() < mutation_rate:
                original_child = child
                child = mutate(child)
                print(f"Mutated from {original_child} to {child}")
            mutated_offspring.append(child)

        population = next_population + mutated_offspring

        print(f"\nAfter mutation and crossover, {len(mutated_offspring)} offspring models generated.")
        print("Only top 5 models will be used for the next generation.")

    # Plot Pareto Front at the end
    plot_pareto_front(arch_performance)

    return population

# Run the evolutionary algorithm
evolutionary_algorithm(population_size=10, generations=5, train_loader=train_loader, test_loader=test_loader)



### may 28 now we need to do pca on super net and then use evol algo

## bring everything here for easy avaliability

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
import random
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])
# Path to your ImageNet data
data_dir = '/home/pratibha/nas_vision/vit_nas_imgnet/imagenet200'

# Load ImageNet dataset and filter only the first 200 classes
filtered_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
# Use only the first 200 classes

In [2]:
train_size = int(0.8 * len(filtered_dataset))
test_size = len(filtered_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(filtered_dataset, [train_size, test_size])

# Create DataLoader for training and testing
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Check the number of samples in each set
print(f"Training set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")

Training set size: 99548
Test set size: 24888


In [None]:
# class DynamicViT(nn.Module):
#     def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, num_classes=10):
#         super().__init__()
#         self.patch_embed = DynamicPatchEmbed(img_size, patch_size, embed_dim)
        
#         #  Fix: Correct positional embedding key
#         self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))
        
#         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
#         self.blocks = nn.ModuleList([DynamicTransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)])
#         self.norm = nn.LayerNorm(embed_dim)
#         self.head = nn.Linear(embed_dim, num_classes)

#     def forward(self, x):
#         x = self.patch_embed(x)
#         B = x.shape[0]

#         # Add class token
#         cls_tokens = self.cls_token.expand(B, -1, -1)
#         x = torch.cat((cls_tokens, x), dim=1)
        
#         x = x + self.pos_embed

#         for block in self.blocks:
#             x = block(x)

#         x = self.norm(x[:, 0])
#         return self.head(x)

In [None]:
## transformer super net
  ## implementing pca on this

In [3]:
# just defining model again here for easily avaliability
import torch
import torch.nn as nn

class DynamicPatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        x = self.proj(x)
        return x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)


class DynamicMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.scale = (embed_dim // num_heads) ** -0.5
        self.num_heads = num_heads  # Store num_heads as a class attribute

        # Ensure that the number of heads divides the embedding dimension
        assert embed_dim % num_heads == 0, f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

class MLPBlock(nn.Module):  
    def __init__(self, embed_dim, mlp_ratio):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)  # Matches `mlp.fc1`
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)  # Matches `mlp.fc2`

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class DynamicTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = DynamicMultiHeadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        #  Fix: Wrap MLP inside a separate module to match ViT
        self.mlp = MLPBlock(embed_dim, mlp_ratio)  

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x




class DynamicViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, num_classes=200):
        super().__init__()
        self.depth = depth  # Store depth as an instance variable
        self.num_heads = num_heads  # Store num_heads as an instance variable
        self.mlp_ratio = mlp_ratio  # Store mlp_ratio as an instance variable
        
        self.patch_embed = DynamicPatchEmbed(img_size, patch_size, embed_dim)
        
        # Fix: Correct positional embedding key
        self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.blocks = nn.ModuleList([DynamicTransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        B = x.shape[0]

        # Add class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        x = x + self.pos_embed

        for block in self.blocks:
            x = block(x)

        x = self.norm(x[:, 0])
        return self.head(x)


In [19]:
model = DynamicViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, num_classes=200).to(device)

## linear projecting layers

In [39]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
import random
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])
# Path to your ImageNet data
data_dir = '/home/pratibha/nas_vision/vit_nas_imgnet/imagenet200'

# Load ImageNet dataset and filter only the first 200 classes
filtered_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
# Use only the first 200 classes

In [40]:
train_size = int(0.8 * len(filtered_dataset))
test_size = len(filtered_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(filtered_dataset, [train_size, test_size])

# Create DataLoader for training and testing
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Check the number of samples in each set
print(f"Training set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")

Training set size: 99548
Test set size: 24888


In [41]:
# just defining model again here for easily avaliability
import torch
import torch.nn as nn

class DynamicPatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2
        self.embed_dim = embed_dim  # Track embedding dimension


    def forward(self, x):
        x = self.proj(x)
        return x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)


class DynamicMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.scale = (embed_dim // num_heads) ** -0.5
        self.num_heads = num_heads  # Store num_heads as a class attribute

        # Ensure that the number of heads divides the embedding dimension
        assert embed_dim % num_heads == 0, f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

class MLPBlock(nn.Module):  
    def __init__(self, embed_dim, mlp_ratio):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)  # Matches `mlp.fc1`
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)  # Matches `mlp.fc2`

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class DynamicTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = DynamicMultiHeadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        #  Fix: Wrap MLP inside a separate module to match ViT
        self.mlp = MLPBlock(embed_dim, mlp_ratio)  

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x




class DynamicViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, 
                 num_heads=12, mlp_ratio=4.0, num_classes=200):
        super().__init__()
        self.embed_dim = embed_dim  # Store embed_dim as instance variable
        self.depth = depth  # Store depth as an instance variable
        self.num_heads = num_heads  # Store num_heads as an instance variable
        self.mlp_ratio = mlp_ratio  # Store mlp_ratio as an instance variable
        
        self.patch_embed = DynamicPatchEmbed(img_size, patch_size, embed_dim)
        
        # Fix: Correct positional embedding key
        self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.blocks = nn.ModuleList([DynamicTransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        B = x.shape[0]

        # Add class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        x = x + self.pos_embed

        for block in self.blocks:
            x = block(x)

        x = self.norm(x[:, 0])
        return self.head(x)


In [None]:
## evolutionary

In [None]:
# First-time loading pretrained weights for initialization
# def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
#     pretrained_vit = create_model(pretrained_model_name, pretrained=True)
#     pretrained_state_dict = pretrained_vit.state_dict()
    
#     # Match keys between pretrained and current model
#     model_state_dict = model.state_dict()
#     filtered_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}

#     # Load pretrained weights
#     model.load_state_dict(filtered_dict, strict=False)
#     print(f"Pretrained weights loaded into {model.__class__.__name__} successfully.")
# def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
#     pretrained_vit = create_model(pretrained_model_name, pretrained=True)
#     pretrained_state_dict = pretrained_vit.state_dict()
    
#     # Get embedding dimensions
#     pretrained_embed_dim = pretrained_state_dict['pos_embed'].shape[-1]
#     current_embed_dim = model.embed_dim
    
#     # Create adaptation modules
#     adaptation_modules = nn.ModuleDict()
#     if pretrained_embed_dim != current_embed_dim:
#         adaptation_modules['pos_embed_proj'] = nn.Linear(pretrained_embed_dim, current_embed_dim)
#         adaptation_modules['cls_token_proj'] = nn.Linear(pretrained_embed_dim, current_embed_dim)
    
#     # Project pretrained weights
#     filtered_dict = {}
#     for k, v in pretrained_state_dict.items():
#         if k == 'pos_embed' and pretrained_embed_dim != current_embed_dim:
#             filtered_dict[k] = adaptation_modules['pos_embed_proj'](v)
#         elif k == 'cls_token' and pretrained_embed_dim != current_embed_dim:
#             filtered_dict[k] = adaptation_modules['cls_token_proj'](v)
#         elif k in model.state_dict() and v.shape == model.state_dict[k].shape:
#             filtered_dict[k] = v
    
#     model.load_state_dict(filtered_dict, strict=False)
#     print(f"Loaded pretrained weights with {'adaptation' if pretrained_embed_dim != current_embed_dim else 'no'} projection")



# Mutation: Randomly mutate architecture's hyperparameters
# def mutate(architecture):
#     depth, num_heads, mlp_ratio, embed_dim = architecture
#     if random.random() < 0.5: depth = random.choice([ 6, 8, 10, 12])
#     if random.random() < 0.5: num_heads = random.choice([4, 8, 12, 16])
#     if random.random() < 0.5: mlp_ratio = random.choice([2.0, 4.0, 6.0])
#     print(f"Mutated architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")

In [None]:
import random
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from timm import create_model
import time

# Path to save the models after fine-tuning
# SAVE_PATH = '/SN02DATA/nas_vision/evol_img1k-wts'
SAVE_PATH = '/SN02DATA/nas_vision/evol_img200-wts'

# SAVE_PATH = '/kaggle/working/'

# Set the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")





def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
    pretrained_vit = create_model(pretrained_model_name, pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    # Get embedding dimensions
    pretrained_embed_dim = pretrained_state_dict['pos_embed'].shape[-1]
    current_embed_dim = model.embed_dim
    
    # Create adaptation modules
    adaptation_modules = nn.ModuleDict()
    if pretrained_embed_dim != current_embed_dim:
        adaptation_modules['pos_embed_proj'] = nn.Linear(pretrained_embed_dim, current_embed_dim)
        adaptation_modules['cls_token_proj'] = nn.Linear(pretrained_embed_dim, current_embed_dim)
    
    # Project pretrained weights
    filtered_dict = {}
    for k, v in pretrained_state_dict.items():
        if k == 'pos_embed' and pretrained_embed_dim != current_embed_dim:
            filtered_dict[k] = adaptation_modules['pos_embed_proj'](v)
        elif k == 'cls_token' and pretrained_embed_dim != current_embed_dim:
            filtered_dict[k] = adaptation_modules['cls_token_proj'](v)
        elif k in model.state_dict() and v.shape == model.state_dict()[k].shape:  # FIXED HERE
            filtered_dict[k] = v
    
    model.load_state_dict(filtered_dict, strict=False)
    print(f"Loaded pretrained weights with {'adaptation' if pretrained_embed_dim != current_embed_dim else 'no'} projection")
    
# Check if pretrained weights are loaded correctly
def check_pretrained_weights(model, generation=0, model_type="subnetwork"):
    pretrained_vit = create_model("vit_base_patch16_224", pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    model_state_dict = model.state_dict()
    matching_keys = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
    
    if len(matching_keys) > 0:
        print(f"Generation {generation + 1}: {model_type} model has loaded {len(matching_keys)} layers from pretrained weights.")
    else:
        print(f"Generation {generation + 1}: {model_type} model has NOT loaded any pretrained weights.")

# Sample Subnetwork - Randomly sample hyperparameters (depth, num_heads, etc.)
def sample_subnetwork(seen_architectures):
    while True:
        depth = random.choice([6, 8, 10, 12])
        num_heads = random.choice([4, 8, 12, 16])
        mlp_ratio = random.choice([2.0, 4.0, 6.0])
        # embed_dim = 768  # Fixed embedding dimension
        embed_dim = random.choice([384, 480, 768])  # Variable embedding dimension
        
        architecture = (depth, num_heads, mlp_ratio, embed_dim)
        
        # Skip if architecture has already been sampled
        if architecture not in seen_architectures:
            seen_architectures.add(architecture)
            print(f"Sampled architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
            
            # Create the model to calculate its number of parameters
            # sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000)
            sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                                        depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, 
                                        num_classes=200
                                    )
            num_params = count_parameters(sampled_model)
            print(f"Number of parameters in the sampled model: {num_params:,}")
            
            return architecture
        else:
            print(f"Repeated architecture found, resampling...")

# Count number of trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def topk_accuracy(output, target, topk=(3,5)):
    """Computes the top-k accuracy for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size).item())
    return res  # [top1, top3]



from ptflops import get_model_complexity_info

def get_macs(model):
    with torch.cuda.device(0):
        macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=False, print_per_layer_stat=False)
    return macs

def evaluate_architecture(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    top1_total = 0
    top5_total = 0
    criterion = nn.CrossEntropyLoss()
    start_time = time.time()

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            top1, top5 = topk_accuracy(outputs, labels, topk=(3,5))
            top1_total += top1 * labels.size(0) / 100.0
            top5_total += top5 * labels.size(0) / 100.0
            total += labels.size(0)

    latency = (time.time() - start_time) / total
    accuracy = 100 * top1_total / total
    top3_accuracy = 100 * top5_total / total
    num_params = count_parameters(model)
    memory_usage = (num_params * 4) / (1024 ** 2)
    test_loss = running_loss / len(test_loader)

    print(f"Test Loss: {test_loss:.4f}, Top-1 Acc: {accuracy:.2f}%, Top-3 Acc: {top3_accuracy:.2f}%, Latency: {latency:.6f}s/img, Mem: {memory_usage:.2f}MB")
    macs = get_macs(model)
    print(f"MACs: {macs / 1e6:.2f} M")
    return accuracy, top3_accuracy, test_loss, latency, memory_usage, macs





# Estimate memory usage of a model during inference (rough estimation)
def estimate_memory_usage(model):                                             ## this funtion is not needed
    # Create dummy input matching the expected shape of the input tensor
    dummy_input = torch.randn(1, 3, 224, 224).to(device)  # Example for ViT (3-channel image of size 224x224)
    
    # Use torch.utils.benchmark to measure memory usage during inference
    start_mem = torch.cuda.memory_allocated()
    
    # Run the model once with the dummy input
    with torch.no_grad():
        model(dummy_input)
    
    end_mem = torch.cuda.memory_allocated()
    memory_usage = (end_mem - start_mem) / (1024 ** 2)  # Convert bytes to MB
    return memory_usage


def calculate_crowding_distance(population, test_loader):
    crowding_distances = [0] * len(population)
    num_objectives = 3  # Accuracy, Latency, Memory

    # Evaluate each architecture once, then reuse the results
    evaluated_results = []
    for arch in population:
        # # model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
        #                    depth=arch[0], num_heads=arch[1],
        #                    mlp_ratio=arch[2], num_classes=10).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
                            depth=arch[0], num_heads=arch[1], mlp_ratio=arch[2], 
                            num_classes=200).to(device)

        accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
        memory = count_parameters(model) * 4  # memory in bytes
        
        evaluated_results.append((accuracy, latency, memory))
        del model
        torch.cuda.empty_cache()

    for objective_index in range(num_objectives):
        sorted_indices = sorted(range(len(population)),
                                key=lambda idx: evaluated_results[idx][objective_index])
        
        crowding_distances[sorted_indices[0]] = crowding_distances[sorted_indices[-1]] = float('inf')

        for i in range(1, len(sorted_indices) - 1):
            prev_value = evaluated_results[sorted_indices[i - 1]][objective_index]
            next_value = evaluated_results[sorted_indices[i + 1]][objective_index]
            distance = next_value - prev_value
            crowding_distances[sorted_indices[i]] += distance

    return crowding_distances


def dominates(model1, model2, test_loader):
    # Evaluate both models on the test set
    accuracy1, _, _,latency1, _, _ = evaluate_architecture(model1, test_loader)
    accuracy2, _, _,latency2, _ , _= evaluate_architecture(model2, test_loader)
    
    # Calculate memory usage as the number of parameters * 4 bytes (FP32)
    memory1 = count_parameters(model1) * 4  # Memory in bytes
    memory2 = count_parameters(model2) * 4  # Memory in bytes
    
    # Compare performance metrics
    dominates_in_accuracy = accuracy1 >= accuracy2
    dominates_in_latency = latency1 <= latency2
    dominates_in_memory = memory1 <= memory2

    # Return True if model1 dominates model2 in all aspects
    return dominates_in_accuracy and dominates_in_latency and dominates_in_memory



#     return depth, num_heads, mlp_ratio, embed_dim
def mutate(architecture):
    depth, num_heads, mlp_ratio, embed_dim = architecture
    mutation_choices = [
        (random.choice([6, 8, 10, 12]), 'depth'),
        (random.choice([4, 8, 12, 16]), 'num_heads'),
        (random.choice([2.0, 4.0, 6.0]), 'mlp_ratio'),
        (random.choice([384, 480, 768]), 'embed_dim')
    ]
    
    # Mutate at least one parameter
    while True:
        for new_val, param in mutation_choices:
            if random.random() < 0.5:
                if param == 'depth': depth = new_val
                elif param == 'num_heads': num_heads = new_val
                elif param == 'mlp_ratio': mlp_ratio = new_val
                elif param == 'embed_dim': embed_dim = new_val
        if (depth, num_heads, mlp_ratio, embed_dim) != architecture:
            break
            
    return (depth, num_heads, mlp_ratio, embed_dim)                            ## check whether tuple is returned or not

# One-Point Crossover: Combine two parent architectures to create new architectures
# def one_point_crossover(parent1, parent2):
#     crossover_point = random.choice([0, 1, 2, 3])  # Crossover at depth, num_heads, etc.
#     child1 = parent1[:crossover_point] + parent2[crossover_point:]
#     child2 = parent2[:crossover_point] + parent1[crossover_point:]
#     print(f"Crossover result: Child1={child1}, Child2={child2}")
#     return child1, child2
def one_point_crossover(parent1, parent2):
    crossover_point = random.randint(0, 3)
    child1 = parent1[:crossover_point] + parent2[crossover_point:]
    child2 = parent2[:crossover_point] + parent1[crossover_point:]
    return child1, child2


############################# this is not weight based instead it is pareto selection
# Optimized Pareto selection based on stored performance metrics
def pareto_selection(arch_performance):
    def dominates(perf1, perf2):
        acc1, lat1, mem1 = perf1
        acc2, lat2, mem2 = perf2
        return (acc1 >= acc2 and lat1 <= lat2 and mem1 <= mem2) and (acc1 > acc2 or lat1 < lat2 or mem1 < mem2)

    ranks = {}
    for arch1, perf1 in arch_performance.items():
        dominated_count = 0
        for arch2, perf2 in arch_performance.items():
            if arch1 != arch2 and dominates(perf2, perf1):
                dominated_count += 1
        ranks[arch1] = dominated_count

    # Sort architectures by rank (lower dominated_count = better)
    sorted_population = sorted(ranks.keys(), key=lambda arch: ranks[arch])
    return sorted_population



# def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
#     print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
#     sampled_model.to(device)
#     criterion = nn.CrossEntropyLoss()
#     optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
#     for epoch in range(epochs):
#         start_epoch = time.time()
#         sampled_model.train()
#         running_loss = 0.0
#         for images, labels in train_loader:
#             images, labels = images.to(device), labels.to(device)
#             optimizer.zero_grad()
#             outputs = sampled_model(images)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()
#         epoch_time = time.time() - start_epoch
#         test_accuracy, test_top5, test_loss, test_latency, memory_usage = evaluate_architecture(sampled_model, test_loader)
#         print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss:.4f}, Top-1 Acc: {test_accuracy:.2f}%, Top-5 Acc: {test_top5:.2f}%, Latency: {test_latency:.6f}s/img, Time: {epoch_time:.2f}s")
#     # Save model code unchanged
#     if architecture_folder:
#         os.makedirs(architecture_folder, exist_ok=True)
#         torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
#     return sampled_model

def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
    print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
    sampled_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
    for epoch in range(epochs):
        start_epoch = time.time()
        sampled_model.train()
        running_loss = 0.0
        
        # Training phase
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = sampled_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        # Evaluation phase
        epoch_time = time.time() - start_epoch
        test_accuracy, test_top5, test_loss, test_latency, memory_usage, macs = evaluate_architecture(sampled_model, test_loader)
        
        # Print epoch statistics
        print(f"\nEpoch {epoch + 1}/{epochs} Summary:")
        print(f"| Training Loss: {running_loss/len(train_loader):.4f}")
        print(f"| Test Loss: {test_loss:.4f}")
        print(f"| Top-1 Accuracy: {test_accuracy:.2f}%")
        print(f"| Top-3 Accuracy: {test_top5:.2f}%")
        print(f"| Latency: {test_latency:.6f}s/img")
        print(f"| Memory Usage: {memory_usage:.2f}MB")
        print(f"| MACs: {macs/1e6:.2f}M")
        print(f"| Epoch Time: {epoch_time:.2f}s\n")

    # Save model weights
    if architecture_folder:
        os.makedirs(architecture_folder, exist_ok=True)
        torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
    return sampled_model



def save_top_ranked_models(population, arch_performance, generation):
    top_n = min(5, len(population))                                             ## how top n is taken ??????????
    for idx, arch in enumerate(population[:top_n]):
        depth, num_heads, mlp_ratio, embed_dim = arch
        # model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
        #                    num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
                            num_heads=num_heads, mlp_ratio=mlp_ratio, 
                            num_classes=200).to(device)


        architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
        checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
        model.load_state_dict(torch.load(checkpoint_path))

        top_model_path = os.path.join(SAVE_PATH, f'top_ranked_model_gen{generation+1}_rank_{idx+1}.pth')
        torch.save(model.state_dict(), top_model_path)
        
        acc, lat, mem = arch_performance[arch]

        # with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
        #     f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
        #     f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")
        with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
            f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
            f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")


        print(f"Saved top-ranked model: Generation {generation+1}, Rank {idx+1} (Acc={acc:.2f}%, Lat={lat:.6f}, Mem={mem/(1024**2):.2f}MB)")
        
        


def plot_pareto_front(arch_performance):
    accuracies = [v[0] for v in arch_performance.values()]
    latencies = [v[1] for v in arch_performance.values()]
    memories = [v[2] / (1024**2) for v in arch_performance.values()]  # convert to MB

    # Accuracy vs Latency
    plt.figure(figsize=(8,6))
    plt.scatter(latencies, accuracies, c='blue')
    plt.xlabel('Latency (s/image)')
    plt.ylabel('Accuracy (%)')
    plt.title('Pareto Front (Accuracy vs Latency)')
    plt.grid()
    plt.show()

    # Accuracy vs Memory
    plt.figure(figsize=(8,6))
    plt.scatter(memories, accuracies, c='green')
    plt.xlabel('Memory (MB)')
    plt.ylabel('Accuracy (%)')
    plt.title('Pareto Front (Accuracy vs Memory)')
    plt.grid()
    plt.show()

#

def evolutionary_algorithm(population_size=16, generations=5, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
    seen_architectures = set()
    population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]
    arch_performance = {}

    prev_best_accuracy = 0
    no_improvement_count = 0

    for generation in range(generations):
        print(f"\n--- Generation {generation + 1}/{generations} ---")

        for arch in population:
            depth, num_heads, mlp_ratio, embed_dim = arch
            architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
            checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')

            model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                               depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                               num_classes=200).to(device)

            # Clearly load weights once per architecture
            if os.path.exists(checkpoint_path):
                model.load_state_dict(torch.load(checkpoint_path))
                print(f"Loaded weights from previous generation for architecture {arch}")
            else:
                load_pretrained_weights(model)

            fine_tune_model(model, train_loader, test_loader, epochs=5, architecture_folder=architecture_folder)

            # accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
            accuracy, top5_accuracy, test_loss, latency, memory_usage, macs = evaluate_architecture(model, test_loader)
            memory = count_parameters(model) * 4 / (1024 ** 2)  # MB
            # arch_performance[arch] = (accuracy, latency, memory)
            arch_performance[arch] = (accuracy, top5_accuracy, latency, memory_usage, macs)

            del model
            torch.cuda.empty_cache()

        # Pareto selection
        population = pareto_selection(arch_performance)

        print("\nTop 5 Ranked Models of Generation", generation+1)
        for idx, arch in enumerate(population[:5]):
            acc, top5_acc, lat, mem, macs = arch_performance[arch]
            print(f"Rank {idx+1}: Model {arch} | Top-1 Acc: {acc:.2f}%, Top-3 Acc: {top5_acc:.2f}%, Latency: {lat:.6f}s/img, Mem: {mem:.2f}MB, MACs: {macs/1e6:.2f}M")
            # Saving top-ranked models
            save_top_ranked_models(population, arch_performance, generation)

        # Check for Pareto front convergence (early stopping criteria)
        current_best_accuracy = arch_performance[population[0]][0]
        if current_best_accuracy - prev_best_accuracy < 1.0:
            no_improvement_count += 1
            print(f"Minimal improvement detected: {current_best_accuracy - prev_best_accuracy:.2f}%")
            if no_improvement_count >= 2:
                print("Pareto front has converged. Stopping early.")
                break
        else:
            no_improvement_count = 0
        prev_best_accuracy = current_best_accuracy

        # Generate offspring
        next_population = population[:len(population)//2]  # Only top half
        offspring = []

        for i in range(0, len(next_population)-1, 2):
            parent1, parent2 = next_population[i], next_population[i+1]

            if random.random() < crossover_rate:
                child1, child2 = one_point_crossover(parent1, parent2)
                print(f"Crossover parents: {parent1} & {parent2}")
                offspring.extend([child1, child2])
            else:
                offspring.extend([parent1, parent2])

        # Mutation with clear logging
        mutated_offspring = []
        for child in offspring:
            if random.random() < mutation_rate:
                original_child = child
                child = mutate(child)
                print(f"Mutated from {original_child} to {child}")
            mutated_offspring.append(child)

        population = next_population + mutated_offspring

        print(f"\nAfter mutation and crossover, {len(mutated_offspring)} offspring models generated.")
        print("Only top 5 models will be used for the next generation.")

    # Plot Pareto Front at the end
    plot_pareto_front(arch_performance)

    return population

# Run the evolutionary algorithm
evolutionary_algorithm(population_size=16, generations=5, train_loader=train_loader, test_loader=test_loader)



Sampled architecture: Depth=6, Num Heads=16, MLP Ratio=4.0, Embed Dim=768
Number of parameters in the sampled model: 43,425,224
Sampled architecture: Depth=10, Num Heads=12, MLP Ratio=6.0, Embed Dim=480
Number of parameters in the sampled model: 37,497,320
Repeated architecture found, resampling...
Sampled architecture: Depth=12, Num Heads=4, MLP Ratio=4.0, Embed Dim=384
Number of parameters in the sampled model: 21,742,664
Sampled architecture: Depth=6, Num Heads=4, MLP Ratio=4.0, Embed Dim=384
Number of parameters in the sampled model: 11,095,880
Sampled architecture: Depth=12, Num Heads=8, MLP Ratio=6.0, Embed Dim=480
Number of parameters in the sampled model: 44,884,520
Sampled architecture: Depth=10, Num Heads=8, MLP Ratio=2.0, Embed Dim=480
Number of parameters in the sampled model: 19,046,120
Sampled architecture: Depth=6, Num Heads=16, MLP Ratio=6.0, Embed Dim=384
Number of parameters in the sampled model: 14,639,432
Sampled architecture: Depth=12, Num Heads=8, MLP Ratio=2.0, E

RuntimeError: [enforce fail at inline_container.cc:626] . unexpected pos 68828736 vs 68828624

## evolutionary

In [None]:
import random
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from timm import create_model
import time

# Path to save the models after fine-tuning
# SAVE_PATH = '/SN02DATA/nas_vision/evol_img1k-wts'
SAVE_PATH = '/SN02DATA/nas_vision/evol_img200-wts'

# SAVE_PATH = '/kaggle/working/'

# Set the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")





def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
    pretrained_vit = create_model(pretrained_model_name, pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    # Get embedding dimensions
    pretrained_embed_dim = pretrained_state_dict['pos_embed'].shape[-1]
    current_embed_dim = model.embed_dim
    
    # Create adaptation modules
    adaptation_modules = nn.ModuleDict()
    if pretrained_embed_dim != current_embed_dim:
        adaptation_modules['pos_embed_proj'] = nn.Linear(pretrained_embed_dim, current_embed_dim)
        adaptation_modules['cls_token_proj'] = nn.Linear(pretrained_embed_dim, current_embed_dim)
    
    # Project pretrained weights
    filtered_dict = {}
    for k, v in pretrained_state_dict.items():
        if k == 'pos_embed' and pretrained_embed_dim != current_embed_dim:
            filtered_dict[k] = adaptation_modules['pos_embed_proj'](v)
        elif k == 'cls_token' and pretrained_embed_dim != current_embed_dim:
            filtered_dict[k] = adaptation_modules['cls_token_proj'](v)
        elif k in model.state_dict() and v.shape == model.state_dict()[k].shape:  # FIXED HERE
            filtered_dict[k] = v
    
    model.load_state_dict(filtered_dict, strict=False)
    print(f"Loaded pretrained weights with {'adaptation' if pretrained_embed_dim != current_embed_dim else 'no'} projection")
    
# Check if pretrained weights are loaded correctly
def check_pretrained_weights(model, generation=0, model_type="subnetwork"):
    pretrained_vit = create_model("vit_base_patch16_224", pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    model_state_dict = model.state_dict()
    matching_keys = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
    
    if len(matching_keys) > 0:
        print(f"Generation {generation + 1}: {model_type} model has loaded {len(matching_keys)} layers from pretrained weights.")
    else:
        print(f"Generation {generation + 1}: {model_type} model has NOT loaded any pretrained weights.")

# Sample Subnetwork - Randomly sample hyperparameters (depth, num_heads, etc.)
def sample_subnetwork(seen_architectures):
    while True:
        depth = random.choice([6, 8, 10, 12])
        num_heads = random.choice([4, 8, 12, 16])
        mlp_ratio = random.choice([2.0, 4.0, 6.0])
        # embed_dim = 768  # Fixed embedding dimension
        embed_dim = random.choice([384, 480, 768])  # Variable embedding dimension
        
        architecture = (depth, num_heads, mlp_ratio, embed_dim)
        
        # Skip if architecture has already been sampled
        if architecture not in seen_architectures:
            seen_architectures.add(architecture)
            print(f"Sampled architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
            
            # Create the model to calculate its number of parameters
            # sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000)
            sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                                        depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, 
                                        num_classes=200
                                    )
            num_params = count_parameters(sampled_model)
            print(f"Number of parameters in the sampled model: {num_params:,}")
            
            return architecture
        else:
            print(f"Repeated architecture found, resampling...")

# Count number of trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def topk_accuracy(output, target, topk=(3,5)):
    """Computes the top-k accuracy for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size).item())
    return res  # [top1, top3]



from ptflops import get_model_complexity_info

def get_macs(model):
    with torch.cuda.device(0):
        macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=False, print_per_layer_stat=False)
    return macs

def evaluate_architecture(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    top1_total = 0
    top5_total = 0
    criterion = nn.CrossEntropyLoss()
    start_time = time.time()

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            top1, top5 = topk_accuracy(outputs, labels, topk=(3,5))
            top1_total += top1 * labels.size(0) / 100.0
            top5_total += top5 * labels.size(0) / 100.0
            total += labels.size(0)

    latency = (time.time() - start_time) / total
    accuracy = 100 * top1_total / total
    top3_accuracy = 100 * top5_total / total
    num_params = count_parameters(model)
    memory_usage = (num_params * 4) / (1024 ** 2)
    test_loss = running_loss / len(test_loader)

    print(f"Test Loss: {test_loss:.4f}, Top-1 Acc: {accuracy:.2f}%, Top-3 Acc: {top3_accuracy:.2f}%, Latency: {latency:.6f}s/img, Mem: {memory_usage:.2f}MB")
    macs = get_macs(model)
    print(f"MACs: {macs / 1e6:.2f} M")
    return accuracy, top3_accuracy, test_loss, latency, memory_usage, macs





# Estimate memory usage of a model during inference (rough estimation)
def estimate_memory_usage(model):                                             ## this funtion is not needed
    # Create dummy input matching the expected shape of the input tensor
    dummy_input = torch.randn(1, 3, 224, 224).to(device)  # Example for ViT (3-channel image of size 224x224)
    
    # Use torch.utils.benchmark to measure memory usage during inference
    start_mem = torch.cuda.memory_allocated()
    
    # Run the model once with the dummy input
    with torch.no_grad():
        model(dummy_input)
    
    end_mem = torch.cuda.memory_allocated()
    memory_usage = (end_mem - start_mem) / (1024 ** 2)  # Convert bytes to MB
    return memory_usage


def calculate_crowding_distance(population, test_loader):
    crowding_distances = [0] * len(population)
    num_objectives = 3  # Accuracy, Latency, Memory

    # Evaluate each architecture once, then reuse the results
    evaluated_results = []
    for arch in population:
        # # model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
        #                    depth=arch[0], num_heads=arch[1],
        #                    mlp_ratio=arch[2], num_classes=10).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
                            depth=arch[0], num_heads=arch[1], mlp_ratio=arch[2], 
                            num_classes=200).to(device)

        accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
        memory = count_parameters(model) * 4  # memory in bytes
        
        evaluated_results.append((accuracy, latency, memory))
        del model
        torch.cuda.empty_cache()

    for objective_index in range(num_objectives):
        sorted_indices = sorted(range(len(population)),
                                key=lambda idx: evaluated_results[idx][objective_index])
        
        crowding_distances[sorted_indices[0]] = crowding_distances[sorted_indices[-1]] = float('inf')

        for i in range(1, len(sorted_indices) - 1):
            prev_value = evaluated_results[sorted_indices[i - 1]][objective_index]
            next_value = evaluated_results[sorted_indices[i + 1]][objective_index]
            distance = next_value - prev_value
            crowding_distances[sorted_indices[i]] += distance

    return crowding_distances


def dominates(model1, model2, test_loader):
    # Evaluate both models on the test set
    accuracy1, _, _,latency1, _, _ = evaluate_architecture(model1, test_loader)
    accuracy2, _, _,latency2, _ , _= evaluate_architecture(model2, test_loader)
    
    # Calculate memory usage as the number of parameters * 4 bytes (FP32)
    memory1 = count_parameters(model1) * 4  # Memory in bytes
    memory2 = count_parameters(model2) * 4  # Memory in bytes
    
    # Compare performance metrics
    dominates_in_accuracy = accuracy1 >= accuracy2
    dominates_in_latency = latency1 <= latency2
    dominates_in_memory = memory1 <= memory2

    # Return True if model1 dominates model2 in all aspects
    return dominates_in_accuracy and dominates_in_latency and dominates_in_memory



#     return depth, num_heads, mlp_ratio, embed_dim
def mutate(architecture):
    depth, num_heads, mlp_ratio, embed_dim = architecture
    mutation_choices = [
        (random.choice([6, 8, 10, 12]), 'depth'),
        (random.choice([4, 8, 12, 16]), 'num_heads'),
        (random.choice([2.0, 4.0, 6.0]), 'mlp_ratio'),
        (random.choice([384, 480, 768]), 'embed_dim')
    ]
    
    # Mutate at least one parameter
    while True:
        for new_val, param in mutation_choices:
            if random.random() < 0.5:
                if param == 'depth': depth = new_val
                elif param == 'num_heads': num_heads = new_val
                elif param == 'mlp_ratio': mlp_ratio = new_val
                elif param == 'embed_dim': embed_dim = new_val
        if (depth, num_heads, mlp_ratio, embed_dim) != architecture:
            break
            
    return (depth, num_heads, mlp_ratio, embed_dim)                            ## check whether tuple is returned or not

# One-Point Crossover: Combine two parent architectures to create new architectures
# def one_point_crossover(parent1, parent2):
#     crossover_point = random.choice([0, 1, 2, 3])  # Crossover at depth, num_heads, etc.
#     child1 = parent1[:crossover_point] + parent2[crossover_point:]
#     child2 = parent2[:crossover_point] + parent1[crossover_point:]
#     print(f"Crossover result: Child1={child1}, Child2={child2}")
#     return child1, child2
def one_point_crossover(parent1, parent2):
    crossover_point = random.randint(0, 3)
    child1 = parent1[:crossover_point] + parent2[crossover_point:]
    child2 = parent2[:crossover_point] + parent1[crossover_point:]
    return child1, child2


############################# this is not weight based instead it is pareto selection
# Optimized Pareto selection based on stored performance metrics
def pareto_selection(arch_performance):
    def dominates(perf1, perf2):
        acc1, lat1, mem1 = perf1
        acc2, lat2, mem2 = perf2
        return (acc1 >= acc2 and lat1 <= lat2 and mem1 <= mem2) and (acc1 > acc2 or lat1 < lat2 or mem1 < mem2)

    ranks = {}
    for arch1, perf1 in arch_performance.items():
        dominated_count = 0
        for arch2, perf2 in arch_performance.items():
            if arch1 != arch2 and dominates(perf2, perf1):
                dominated_count += 1
        ranks[arch1] = dominated_count

    # Sort architectures by rank (lower dominated_count = better)
    sorted_population = sorted(ranks.keys(), key=lambda arch: ranks[arch])
    return sorted_population



# def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
#     print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
#     sampled_model.to(device)
#     criterion = nn.CrossEntropyLoss()
#     optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
#     for epoch in range(epochs):
#         start_epoch = time.time()
#         sampled_model.train()
#         running_loss = 0.0
#         for images, labels in train_loader:
#             images, labels = images.to(device), labels.to(device)
#             optimizer.zero_grad()
#             outputs = sampled_model(images)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()
#         epoch_time = time.time() - start_epoch
#         test_accuracy, test_top5, test_loss, test_latency, memory_usage = evaluate_architecture(sampled_model, test_loader)
#         print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss:.4f}, Top-1 Acc: {test_accuracy:.2f}%, Top-5 Acc: {test_top5:.2f}%, Latency: {test_latency:.6f}s/img, Time: {epoch_time:.2f}s")
#     # Save model code unchanged
#     if architecture_folder:
#         os.makedirs(architecture_folder, exist_ok=True)
#         torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
#     return sampled_model

def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
    print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
    sampled_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
    for epoch in range(epochs):
        start_epoch = time.time()
        sampled_model.train()
        running_loss = 0.0
        
        # Training phase
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = sampled_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        # Evaluation phase
        epoch_time = time.time() - start_epoch
        test_accuracy, test_top5, test_loss, test_latency, memory_usage, macs = evaluate_architecture(sampled_model, test_loader)
        
        # Print epoch statistics
        print(f"\nEpoch {epoch + 1}/{epochs} Summary:")
        print(f"| Training Loss: {running_loss/len(train_loader):.4f}")
        print(f"| Test Loss: {test_loss:.4f}")
        print(f"| Top-1 Accuracy: {test_accuracy:.2f}%")
        print(f"| Top-3 Accuracy: {test_top5:.2f}%")
        print(f"| Latency: {test_latency:.6f}s/img")
        print(f"| Memory Usage: {memory_usage:.2f}MB")
        print(f"| MACs: {macs/1e6:.2f}M")
        print(f"| Epoch Time: {epoch_time:.2f}s\n")

    # Save model weights
    if architecture_folder:
        os.makedirs(architecture_folder, exist_ok=True)
        torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
    return sampled_model



def save_top_ranked_models(population, arch_performance, generation):
    top_n = min(5, len(population))                                             ## how top n is taken ??????????
    for idx, arch in enumerate(population[:top_n]):
        depth, num_heads, mlp_ratio, embed_dim = arch
        # model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
        #                    num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
                            num_heads=num_heads, mlp_ratio=mlp_ratio, 
                            num_classes=200).to(device)


        architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
        checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
        model.load_state_dict(torch.load(checkpoint_path))

        top_model_path = os.path.join(SAVE_PATH, f'top_ranked_model_gen{generation+1}_rank_{idx+1}.pth')
        torch.save(model.state_dict(), top_model_path)
        
        acc, lat, mem = arch_performance[arch]

        # with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
        #     f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
        #     f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")
        with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
            f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
            f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")


        print(f"Saved top-ranked model: Generation {generation+1}, Rank {idx+1} (Acc={acc:.2f}%, Lat={lat:.6f}, Mem={mem/(1024**2):.2f}MB)")
        
        


def plot_pareto_front(arch_performance):
    accuracies = [v[0] for v in arch_performance.values()]
    latencies = [v[1] for v in arch_performance.values()]
    memories = [v[2] / (1024**2) for v in arch_performance.values()]  # convert to MB

    # Accuracy vs Latency
    plt.figure(figsize=(8,6))
    plt.scatter(latencies, accuracies, c='blue')
    plt.xlabel('Latency (s/image)')
    plt.ylabel('Accuracy (%)')
    plt.title('Pareto Front (Accuracy vs Latency)')
    plt.grid()
    plt.show()

    # Accuracy vs Memory
    plt.figure(figsize=(8,6))
    plt.scatter(memories, accuracies, c='green')
    plt.xlabel('Memory (MB)')
    plt.ylabel('Accuracy (%)')
    plt.title('Pareto Front (Accuracy vs Memory)')
    plt.grid()
    plt.show()

#

def evolutionary_algorithm(population_size=16, generations=5, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
    seen_architectures = set()
    population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]
    arch_performance = {}

    prev_best_accuracy = 0
    no_improvement_count = 0

    for generation in range(generations):
        print(f"\n--- Generation {generation + 1}/{generations} ---")

        # for arch in population:
        #     depth, num_heads, mlp_ratio, embed_dim = arch
        #     architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
        #     checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')

        #     model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
        #                        depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
        #                        num_classes=200).to(device)

        for arch in population:
            depth, num_heads, mlp_ratio, embed_dim = arch
            architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
            checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')

            model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                               depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                               num_classes=200).to(device)

            # Determine fine-tuning epochs based on embedding dimension
            fine_tune_epochs = 16 if embed_dim != 768 else 3  # 768 is pretrained model's embed_dim
            
            # Clearly load weights once per architecture
            if os.path.exists(checkpoint_path):
                model.load_state_dict(torch.load(checkpoint_path))
                print(f"Loaded weights from previous generation for architecture {arch}")
            else:
                load_pretrained_weights(model)

            # fine_tune_model(model, train_loader, test_loader, epochs=5, architecture_folder=architecture_folder)
            fine_tune_model(model, train_loader, test_loader, epochs=fine_tune_epochs, 
                           architecture_folder=architecture_folder)

            # accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
            accuracy, top5_accuracy, test_loss, latency, memory_usage, macs = evaluate_architecture(model, test_loader)
            memory = count_parameters(model) * 4 / (1024 ** 2)  # MB
            # arch_performance[arch] = (accuracy, latency, memory)
            arch_performance[arch] = (accuracy, top5_accuracy, latency, memory_usage, macs)

            del model
            torch.cuda.empty_cache()

        # Pareto selection
        population = pareto_selection(arch_performance)

        print("\nTop 5 Ranked Models of Generation", generation+1)
        for idx, arch in enumerate(population[:5]):
            acc, top5_acc, lat, mem, macs = arch_performance[arch]
            print(f"Rank {idx+1}: Model {arch} | Top-1 Acc: {acc:.2f}%, Top-3 Acc: {top5_acc:.2f}%, Latency: {lat:.6f}s/img, Mem: {mem:.2f}MB, MACs: {macs/1e6:.2f}M")
            # Saving top-ranked models
            save_top_ranked_models(population, arch_performance, generation)

        # Check for Pareto front convergence (early stopping criteria)
        current_best_accuracy = arch_performance[population[0]][0]
        if current_best_accuracy - prev_best_accuracy < 1.0:
            no_improvement_count += 1
            print(f"Minimal improvement detected: {current_best_accuracy - prev_best_accuracy:.2f}%")
            if no_improvement_count >= 2:
                print("Pareto front has converged. Stopping early.")
                break
        else:
            no_improvement_count = 0
        prev_best_accuracy = current_best_accuracy

        # Generate offspring
        next_population = population[:len(population)//2]  # Only top half
        offspring = []

        for i in range(0, len(next_population)-1, 2):
            parent1, parent2 = next_population[i], next_population[i+1]

            if random.random() < crossover_rate:
                child1, child2 = one_point_crossover(parent1, parent2)
                print(f"Crossover parents: {parent1} & {parent2}")
                offspring.extend([child1, child2])
            else:
                offspring.extend([parent1, parent2])

        # Mutation with clear logging
        mutated_offspring = []
        for child in offspring:
            if random.random() < mutation_rate:
                original_child = child
                child = mutate(child)
                print(f"Mutated from {original_child} to {child}")
            mutated_offspring.append(child)

        population = next_population + mutated_offspring

        print(f"\nAfter mutation and crossover, {len(mutated_offspring)} offspring models generated.")
        print("Only top 5 models will be used for the next generation.")

    # Plot Pareto Front at the end
    plot_pareto_front(arch_performance)

    return population

# Run the evolutionary algorithm
evolutionary_algorithm(population_size=16, generations=5, train_loader=train_loader, test_loader=test_loader)



Sampled architecture: Depth=6, Num Heads=12, MLP Ratio=6.0, Embed Dim=768
Number of parameters in the sampled model: 57,590,216
Sampled architecture: Depth=10, Num Heads=16, MLP Ratio=4.0, Embed Dim=768
Number of parameters in the sampled model: 71,776,712
Sampled architecture: Depth=6, Num Heads=8, MLP Ratio=2.0, Embed Dim=768
Number of parameters in the sampled model: 29,260,232
Sampled architecture: Depth=6, Num Heads=12, MLP Ratio=4.0, Embed Dim=384
Number of parameters in the sampled model: 11,095,880
Sampled architecture: Depth=8, Num Heads=8, MLP Ratio=4.0, Embed Dim=384
Number of parameters in the sampled model: 14,644,808
Sampled architecture: Depth=6, Num Heads=8, MLP Ratio=6.0, Embed Dim=384
Number of parameters in the sampled model: 14,639,432
Sampled architecture: Depth=8, Num Heads=16, MLP Ratio=4.0, Embed Dim=768
Number of parameters in the sampled model: 57,600,968
Sampled architecture: Depth=12, Num Heads=12, MLP Ratio=2.0, Embed Dim=768
Number of parameters in the sam

## knowledge distillation

## pca

## pca

Step 1: Extract Patch Embeddings from Supernetwork
Before applying PCA, pass a representative dataset through the DynamicPatchEmbed layer to get the high-dimensional (768-dim) embeddings.

In [12]:
from torch.utils.data import DataLoader
from tqdm import tqdm

def extract_patch_embeddings(model, dataloader, num_samples=1000):
    model.eval()
    embeddings = []
    with torch.no_grad():
        for i, (images, _) in enumerate(tqdm(dataloader)):
            if i * images.size(0) >= num_samples:
                break
            images = images.to(device)
            x = model.patch_embed(images)  # Shape: (B, num_patches, 768)
            x = x.flatten(2).transpose(1, 2)  # (B, num_patches, 768)
            embeddings.append(x.reshape(-1, x.shape[-1]))  # (B*num_patches, 768)

    return torch.cat(embeddings, dim=0).cpu()  # (N, 768)



In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# First-time loading pretrained weights for initialization
def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
    pretrained_vit = create_model(pretrained_model_name, pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    # Match keys between pretrained and current model
    model_state_dict = model.state_dict()
    filtered_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}

    # Load pretrained weights
    model.load_state_dict(filtered_dict, strict=False)
    print(f"Pretrained weights loaded into {model.__class__.__name__} successfully.")

In [14]:
from sklearn.decomposition import PCA
import numpy as np

def fit_pca(embeddings, target_dim=384):
    pca = PCA(n_components=target_dim)
    pca.fit(embeddings.numpy())
    return pca



In [None]:
# def apply_pca_to_patch_embed(model, pca):
#     # Original weights shape: (embed_dim, in_channels, kernel_size, kernel_size)
#     old_proj_weight = model.patch_embed.proj.weight.data  # (768, 3, 16, 16)
#     old_proj_bias = model.patch_embed.proj.bias.data      # (768,)

#     # Reshape to (768, -1) for PCA
#     weight_2d = old_proj_weight.view(768, -1).cpu().numpy()
#     bias_1d = old_proj_bias.cpu().numpy()

#     # Apply PCA transform
#     new_weight = pca.transform(weight_2d)  # (768, 384)
#     new_weight = torch.tensor(new_weight).view(384, 3, 16, 16)
#     new_bias = torch.tensor(pca.transform(bias_1d.reshape(1, -1))[0])  # (384,)

#     # Update model patch embedding to 384-dim
#     model.patch_embed.proj = nn.Conv2d(3, 384, kernel_size=16, stride=16)
#     model.patch_embed.num_patches = (model.patch_embed.proj.kernel_size[0] ** 2)

#     # Assign new weights
#     model.patch_embed.proj.weight.data.copy_(new_weight)
#     model.patch_embed.proj.bias.data.copy_(new_bias)


In [15]:
def transform_patch_embed_weights(model, pca):
    # Original weights
    old_weight = model.patch_embed.proj.weight.data.view(768, -1).cpu().numpy()  # (768, 3*16*16)
    old_bias = model.patch_embed.proj.bias.data.cpu().numpy()  # (768,)

    # PCA transform
    new_weight = pca.transform(old_weight)  # (768, 384)
    new_weight = torch.tensor(new_weight).view(384, 3, 16, 16)

    # Bias transform
    new_bias = pca.transform(old_bias.reshape(1, -1))[0]  # (384,)
    new_bias = torch.tensor(new_bias)

    return new_weight, new_bias


In [None]:
def create_vit_384(new_weight, new_bias, num_classes=200):
    model_384 = DynamicViT(
        img_size=224,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,  # Must divide 384
        mlp_ratio=4.0,
        num_classes=num_classes
    ).to(device)

    # Replace patch_embed weights
    model_384.patch_embed.proj.weight.data.copy_(new_weight)
    model_384.patch_embed.proj.bias.data.copy_(new_bias)

    return model_384


In [17]:
def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            total_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    avg_loss = total_loss / total
    accuracy = 100. * correct / total
    print(f"Evaluation - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy


In [21]:
from timm import create_model

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load pretrained weights
load_pretrained_weights(model)

# Extract patch embeddings from multiple images
# Resulting shape: (num_images, num_patches=196, embed_dim=768)
embeddings = extract_patch_embeddings(model, train_loader, num_samples=1000)  # (1000, 196, 768)

# Reshape to (num_images * num_patches, 768)
embeddings = embeddings.reshape(-1, embeddings.shape[-1])  # (1000 * 196, 768)

# Fit PCA to reduce 768 ‚Üí 128
pca = PCA(n_components=384)
pca.fit(embeddings.numpy())


Pretrained weights loaded into DynamicViT successfully.


  1%|          | 16/1556 [00:03<06:12,  4.14it/s]


In [29]:
# Load pretrained weights
# load_pretrained_weights(model)

# # Extract embeddings
# embeddings = extract_patch_embeddings(model, train_loader, num_samples=1000)


# # Fit PCA
# pca = fit_pca(embeddings, target_dim=128)

# Transform weights
new_weight, new_bias = transform_patch_embed_weights(model, pca)

# Create new model
model_384 = create_vit_384(new_weight, new_bias)

# Evaluate
criterion = nn.CrossEntropyLoss()
evaluate_model(model_384, test_loader, criterion, device)


ValueError: X has 768 features, but PCA is expecting 196 features as input.

In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import DataLoader
# from torchvision import datasets, transforms
# from tqdm import tqdm
# from sklearn.decomposition import PCA
# import numpy as np
# import time
# import random
# from timm import create_model  # For loading pretrained ViT

# # Set device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # === Dataset ===
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
# ])
# data_dir = '/home/pratibha/nas_vision/vit_nas_imgnet/imagenet200'
# filtered_dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# train_size = int(0.8 * len(filtered_dataset))
# test_size = len(filtered_dataset) - train_size
# train_dataset, test_dataset = torch.utils.data.random_split(filtered_dataset, [train_size, test_size])

# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# print(f"Training set size: {len(train_dataset)}")
# print(f"Test set size: {len(test_dataset)}")

# # === ViT Components ===
# class DynamicPatchEmbed(nn.Module):
#     def __init__(self, img_size=224, patch_size=16, embed_dim=768):
#         super().__init__()
#         self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
#         self.num_patches = (img_size // patch_size) ** 2

#     def forward(self, x):
#         x = self.proj(x)
#         return x.flatten(2).transpose(1, 2)

# class DynamicMultiHeadAttention(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super().__init__()
#         self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
#         self.proj = nn.Linear(embed_dim, embed_dim)
#         self.scale = (embed_dim // num_heads) ** -0.5
#         self.num_heads = num_heads
#         assert embed_dim % num_heads == 0

#     def forward(self, x):
#         B, N, C = x.shape
#         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
#         q, k, v = qkv[0], qkv[1], qkv[2]
#         attn = (q @ k.transpose(-2, -1)) * self.scale
#         attn = attn.softmax(dim=-1)
#         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
#         return self.proj(x)

# class MLPBlock(nn.Module):  
#     def __init__(self, embed_dim, mlp_ratio):
#         super().__init__()
#         hidden_dim = int(embed_dim * mlp_ratio)
#         self.fc1 = nn.Linear(embed_dim, hidden_dim)
#         self.act = nn.GELU()
#         self.fc2 = nn.Linear(hidden_dim, embed_dim)

#     def forward(self, x):
#         return self.fc2(self.act(self.fc1(x)))

# class DynamicTransformerBlock(nn.Module):
#     def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
#         super().__init__()
#         self.norm1 = nn.LayerNorm(embed_dim)
#         self.attn = DynamicMultiHeadAttention(embed_dim, num_heads)
#         self.norm2 = nn.LayerNorm(embed_dim)
#         self.mlp = MLPBlock(embed_dim, mlp_ratio)

#     def forward(self, x):
#         x = x + self.attn(self.norm1(x))
#         x = x + self.mlp(self.norm2(x))
#         return x

# class DynamicViT(nn.Module):
#     def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, num_classes=200):
#         super().__init__()
#         self.patch_embed = DynamicPatchEmbed(img_size, patch_size, embed_dim)
#         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
#         self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))
#         self.blocks = nn.ModuleList([DynamicTransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)])
#         self.norm = nn.LayerNorm(embed_dim)
#         self.head = nn.Linear(embed_dim, num_classes)

#     def forward(self, x):
#         x = self.patch_embed(x)
#         B = x.shape[0]
#         cls_tokens = self.cls_token.expand(B, -1, -1)
#         x = torch.cat((cls_tokens, x), dim=1)
#         x = x + self.pos_embed
#         for block in self.blocks:
#             x = block(x)
#         x = self.norm(x[:, 0])
#         return self.head(x)

# # === Pretrained weight loader ===
# def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
#     pretrained_vit = create_model(pretrained_model_name, pretrained=True)
#     pretrained_state_dict = pretrained_vit.state_dict()
#     model_state_dict = model.state_dict()
#     filtered_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
#     model.load_state_dict(filtered_dict, strict=False)
#     print("‚úÖ Pretrained weights loaded.")

# # === PCA Embedding Extraction ===
# # def extract_patch_embeddings(model, dataloader, num_samples=1000):
# #     model.eval()
# #     embeddings = []
# #     with torch.no_grad():
# #         for i, (images, _) in enumerate(tqdm(dataloader)):
# #             if i * images.size(0) >= num_samples:
# #                 break
# #             images = images.to(device)
# #             x = model.patch_embed(images)
# #             x = x.flatten(2).transpose(1, 2)
# #             embeddings.append(x.reshape(-1, x.shape[-1]))
# #     return torch.cat(embeddings, dim=0).cpu()

# # def extract_patch_embeddings(model, dataloader, max_samples=10000):
# #     model.eval()
# #     embeddings = []
# #     collected = 0
# #     with torch.no_grad():
# #         for images, _ in tqdm(dataloader):
# #             images = images.to(device)
# #             patches = model.patch_embed(images)  # [B, C, H, W] -> [B, embed_dim, H/patch, W/patch]
# #             B, C, H, W = patches.shape
# #             patches = patches.permute(0, 2, 3, 1).reshape(-1, C)  # Flatten all patches
# #             embeddings.append(patches.cpu())
# #             collected += patches.shape[0]
# #             if collected >= max_samples:
# #                 break
# #     return torch.cat(embeddings, dim=0)[:max_samples]
# def extract_patch_embeddings(model, dataloader, max_samples=10000):
#     model.eval()
#     embeddings = []
#     collected = 0
#     with torch.no_grad():
#         for images, _ in tqdm(dataloader):
#             images = images.to(device)
#             patches = model.patch_embed(images)  # [B, num_patches, embed_dim]
#             B, N, D = patches.shape
#             patches = patches.reshape(-1, D)  # Flatten to [B*N, embed_dim]
#             embeddings.append(patches.cpu())
#             collected += patches.shape[0]
#             if collected >= max_samples:
#                 break
#     return torch.cat(embeddings, dim=0)[:max_samples]


# # === PCA weight transformation ===
# def transform_patch_embed_weights(model, pca):
#     old_weight = model.patch_embed.proj.weight.data.view(768, -1).cpu().numpy()
#     old_bias = model.patch_embed.proj.bias.data.cpu().numpy()
#     new_weight = pca.transform(old_weight)  # (384,)
#     new_weight = torch.tensor(new_weight).view(384, 3, 16, 16)
#     new_bias = pca.transform(old_bias.reshape(1, -1))[0]
#     new_bias = torch.tensor(new_bias)
#     return new_weight, new_bias

# # === Create 384-dim ViT ===
# def create_vit_384(new_weight, new_bias, num_classes=200):
#     model_384 = DynamicViT(
#         img_size=224,
#         patch_size=16,
#         embed_dim=384,
#         depth=12,
#         num_heads=6,
#         mlp_ratio=4.0,
#         num_classes=num_classes
#     ).to(device)
#     model_384.patch_embed.proj.weight.data.copy_(new_weight)
#     model_384.patch_embed.proj.bias.data.copy_(new_bias)
#     return model_384

# # === Evaluate Model ===
# def evaluate_model(model, dataloader, criterion, device):
#     model.eval()
#     total_loss, correct, total = 0.0, 0, 0
#     with torch.no_grad():
#         for inputs, targets in dataloader:
#             inputs, targets = inputs.to(device), targets.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, targets)
#             total_loss += loss.item() * inputs.size(0)
#             _, predicted = outputs.max(1)
#             total += targets.size(0)
#             correct += predicted.eq(targets).sum().item()
#     avg_loss = total_loss / total
#     accuracy = 100. * correct / total
#     print(f"üìä Eval - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
#     return avg_loss, accuracy

# # === Main Execution ===
# model = DynamicViT(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, num_classes=200).to(device)
# load_pretrained_weights(model)
# # embeddings = extract_patch_embeddings(model, train_loader, num_samples=1000)
# # embeddings = embeddings.reshape(-1, embeddings.shape[-1])  # (1000 * 196, 768)

# # pca = PCA(n_components=384)
# # pca.fit(embeddings.numpy())

# embeddings = extract_patch_embeddings(model, train_loader, max_samples=10000)
# pca = PCA(n_components=384)
# pca.fit(embeddings.numpy())


# new_weight, new_bias = transform_patch_embed_weights(model, pca)
# model_384 = create_vit_384(new_weight, new_bias)

# criterion = nn.CrossEntropyLoss()
# evaluate_model(model_384, test_loader, criterion, device)


Training set size: 99548
Test set size: 24888
‚úÖ Pretrained weights loaded.


  0%|          | 0/1556 [00:00<?, ?it/s]


üìä Eval - Loss: 5.4428, Accuracy: 0.51%


(5.442845274037159, 0.5143040822886532)

In [35]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from sklearn.decomposition import PCA
from tqdm import tqdm
import numpy as np
import timm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Data Loading ===
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])

data_dir = '/home/pratibha/nas_vision/vit_nas_imgnet/imagenet200'
full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(f"Training set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")

# === Dynamic ViT Model Definitions ===
class DynamicPatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        return self.proj(x)  # [B, embed_dim, H', W']

class DynamicMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.scale = (embed_dim // num_heads) ** -0.5

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)

class MLPBlock(nn.Module):
    def __init__(self, embed_dim, mlp_ratio=4.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class DynamicTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = DynamicMultiHeadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLPBlock(embed_dim, mlp_ratio)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class DynamicViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, num_classes=200):
        super().__init__()
        self.patch_embed = DynamicPatchEmbed(img_size, patch_size, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, (img_size // patch_size) ** 2 + 1, embed_dim))
        self.blocks = nn.ModuleList([
            DynamicTransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)  # [B, C, H', W']
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, N, C]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x[:, 0])
        return self.head(x)

# === Pretrained Weight Loading ===
def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
    pretrained = timm.create_model(pretrained_model_name, pretrained=True)
    model.patch_embed.proj.weight.data.copy_(pretrained.patch_embed.proj.weight.data)
    model.patch_embed.proj.bias.data.copy_(pretrained.patch_embed.proj.bias.data)
    print("‚úÖ Pretrained patch_embed weights loaded.")

# === Patch Embedding Extraction ===
def extract_patch_embeddings(model, dataloader, max_samples=10000):
    model.eval()
    embeddings = []
    seen = 0

    with torch.no_grad():
        for images, _ in tqdm(dataloader, desc="Extracting Patch Embeddings"):
            if seen >= max_samples:
                break
            images = images.to(device)
            patches = model.patch_embed(images)  # [B, C, H', W']
            B, C, H, W = patches.shape
            patches = patches.permute(0, 2, 3, 1).reshape(-1, C)  # [B*H'*W', C]
            embeddings.append(patches.cpu())
            seen += images.size(0)

    return torch.cat(embeddings, dim=0)

# === PCA + Weight Transform ===
def apply_pca_to_conv_weights(conv_layer, pca):
    old_weight = conv_layer.weight.data.view(768, -1).cpu().numpy()
    new_weight = pca.transform(old_weight)
    new_weight = torch.tensor(new_weight).view(384, 3, 16, 16)
    old_bias = conv_layer.bias.data.cpu().numpy()
    new_bias = pca.transform(old_bias.reshape(1, -1))[0]
    new_bias = torch.tensor(new_bias)
    return new_weight, new_bias

# === New 384-Dim Model ===
def create_vit_384(new_weight, new_bias, num_classes=200):
    model_384 = DynamicViT(embed_dim=384, depth=12, num_heads=6, num_classes=num_classes)
    model_384.patch_embed.proj.weight.data.copy_(new_weight)
    model_384.patch_embed.proj.bias.data.copy_(new_bias)
    return model_384.to(device)

# === Evaluation ===
def evaluate_model(model, dataloader, criterion):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * inputs.size(0)
            _, preds = outputs.max(1)
            total += labels.size(0)
            correct += preds.eq(labels).sum().item()
    acc = 100 * correct / total
    print(f"üìä Eval - Loss: {total_loss / total:.4f}, Accuracy: {acc:.2f}%")
    return total_loss / total, acc

# === Run Full Flow ===
model = DynamicViT(embed_dim=768, num_classes=200).to(device)
load_pretrained_weights(model)

embeddings = extract_patch_embeddings(model, train_loader, max_samples=10000)
print("‚úÖ Embedding shape for PCA:", embeddings.shape)  # Should be [N, 768]

pca = PCA(n_components=384)
pca.fit(embeddings.numpy())

new_weight, new_bias = apply_pca_to_conv_weights(model.patch_embed.proj, pca)
model_384 = create_vit_384(new_weight, new_bias)
evaluate_model(model_384, test_loader, nn.CrossEntropyLoss())


Training set size: 99548
Test set size: 24888
‚úÖ Pretrained patch_embed weights loaded.


Extracting Patch Embeddings:  10%|‚ñà         | 157/1556 [00:38<05:44,  4.06it/s]


‚úÖ Embedding shape for PCA: torch.Size([1969408, 768])
üìä Eval - Loss: 5.4709, Accuracy: 0.54%


(5.47090251757684, 0.5384120861459338)

In [38]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.decomposition import PCA
from timm import create_model
from tqdm import tqdm

# --- SETTINGS ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
target_embed_dim = 384
num_classes = 200
data_dir = '/home/pratibha/nas_vision/vit_nas_imgnet/imagenet200'

# --- DATA LOADERS ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])
dataset = datasets.ImageFolder(root=data_dir, transform=transform)
train_size = int(0.8 * len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# --- PATCH EMBEDDING EXTRACTOR ---
# def extract_patch_embeddings(model, dataloader, max_samples=10000):
#     model.eval()
#     collected = 0
#     embeddings = []

#     with torch.no_grad():
#         for images, _ in tqdm(dataloader, desc="Extracting embeddings"):
#             if collected >= max_samples:
#                 break
#             images = images.to(device)
#             x = model.patch_embed(images)  # (B, C, H, W)
#             B, C, H, W = x.shape
#             patches = x.permute(0, 2, 3, 1).reshape(-1, C)
#             embeddings.append(patches.cpu())
#             collected += images.size(0)

#     return torch.cat(embeddings, dim=0)[:max_samples * 196]  # shape: (N, 768)

def extract_patch_embeddings(model, dataloader, max_samples=10000):
    model.eval()
    collected = 0
    embeddings = []

    with torch.no_grad():
        for images, _ in tqdm(dataloader, desc="Extracting embeddings"):
            if collected >= max_samples:
                break
            images = images.to(device)
            x = model.patch_embed(images)  # [B, N, C] directly
            if x.ndim == 3:
                B, N, C = x.shape
                patches = x.reshape(-1, C)  # Flatten all patches
                embeddings.append(patches.cpu())
            collected += images.size(0)

    return torch.cat(embeddings, dim=0)[:max_samples * 196]  # (N * 196, 768)

# --- PCA FITTER ---
def fit_pca(embeddings, target_dim):
    print(f"‚úÖ Embedding shape for PCA: {embeddings.shape}")
    pca = PCA(n_components=target_dim)
    pca.fit(embeddings.numpy())
    return pca

# --- PATCH EMBED WEIGHT TRANSFORMER ---
def transform_patch_embed_weights(orig_conv: nn.Conv2d, pca: PCA):
    old_weight = orig_conv.weight.data.view(768, -1).cpu().numpy()  # (768, 3*16*16)
    new_weight = torch.tensor(pca.transform(old_weight)).view(target_embed_dim, 3, 16, 16)

    old_bias = orig_conv.bias.data.cpu().numpy().reshape(1, -1)  # (1, 768)
    new_bias = torch.tensor(pca.transform(old_bias)[0])

    return new_weight, new_bias

# --- NEW VIT MODEL WITH 384 EMBEDDING ---
# def create_vit_384(pretrained_vit, pca):
#     model_384 = create_model(
#         'vit_base_patch16_224',
#         pretrained=False,
#         num_classes=num_classes,
#         embed_dim=target_embed_dim
#     ).to(device)

#     # Transform patch embedding weights
#     new_weight, new_bias = transform_patch_embed_weights(pretrained_vit.patch_embed.proj, pca)
#     model_384.patch_embed.proj = nn.Conv2d(3, target_embed_dim, kernel_size=16, stride=16)
#     model_384.patch_embed.proj.weight.data.copy_(new_weight)
#     model_384.patch_embed.proj.bias.data.copy_(new_bias)

#     # Positional embedding PCA projection
#     pos_embed = pretrained_vit.pos_embed.data[:, 1:]  # Exclude class token
#     pos_embed_reduced = torch.tensor(pca.transform(pos_embed[0].cpu().numpy())).unsqueeze(0)
#     cls_token = torch.tensor(pca.transform(pretrained_vit.cls_token.data.cpu().numpy())[0]).unsqueeze(0).unsqueeze(0)
#     model_384.pos_embed = nn.Parameter(torch.cat([cls_token, pos_embed_reduced], dim=1).to(device))
#     model_384.cls_token = nn.Parameter(cls_token.to(device))

#     # Transfer transformer weights block-wise
#     for i in range(len(model_384.blocks)):
#         # Skip weights that depend on embed dim
#         for name, param in pretrained_vit.blocks[i].named_parameters():
#             if param.shape == getattr(model_384.blocks[i], name.split('.')[0]).__getattr__(name.split('.')[1]).shape:
#                 getattr(model_384.blocks[i], name.split('.')[0]).__getattr__(name.split('.')[1]).data.copy_(param.data)

#     # Normalize layer
#     if model_384.norm.weight.shape == pretrained_vit.norm.weight.shape:
#         model_384.norm.load_state_dict(pretrained_vit.norm.state_dict(), strict=False)

#     return model_384

def create_vit_384(pretrained_vit, pca):
    model_384 = timm.create_model('vit_base_patch16_224', pretrained=False)
    model_384.patch_embed.proj = nn.Linear(16*16*3, 384)
    model_384.pos_embed = nn.Parameter(torch.zeros(1, 197, 384))
    model_384.cls_token = nn.Parameter(torch.zeros(1, 1, 384))
    model_384.head = nn.Linear(384, 1000)

    # Transform patch embedding weights
    W = pretrained_vit.patch_embed.proj.weight.data.reshape(768, -1)  # (768, 16*16*3)
    W_reduced = torch.tensor(pca.components_ @ W.cpu().numpy()).float()
    model_384.patch_embed.proj.weight.data = W_reduced.reshape(384, 3, 16, 16).to(device)
    model_384.patch_embed.proj.bias.data = pretrained_vit.patch_embed.proj.bias.data[:384].to(device)

    # Transform positional embeddings
    pos_embed = pretrained_vit.pos_embed.data[:, 1:, :]  # Exclude CLS token: shape [1, 196, 768]
    pos_embed_np = pos_embed.cpu().numpy().reshape(-1, pos_embed.shape[-1])  # (196, 768)
    pos_embed_reduced_np = pca.transform(pos_embed_np)  # (196, 384)
    pos_embed_reduced = torch.tensor(pos_embed_reduced_np).reshape(1, pos_embed.shape[1], 384).to(device)

    # Transform CLS token embedding
    cls_token_np = pretrained_vit.cls_token.data.cpu().numpy().reshape(-1, pretrained_vit.cls_token.shape[-1])  # (1, 768)
    cls_token_reduced_np = pca.transform(cls_token_np)  # (1, 384)
    cls_token_reduced = torch.tensor(cls_token_reduced_np).unsqueeze(0).to(device)  # (1, 1, 384)

    # Set embeddings
    model_384.pos_embed.data = torch.cat([cls_token_reduced, pos_embed_reduced], dim=1)
    model_384.cls_token.data = cls_token_reduced.squeeze(0)

    # Copy transformer weights except patch embedding and head
    model_384.blocks.load_state_dict(pretrained_vit.blocks.state_dict())
    model_384.norm.load_state_dict(pretrained_vit.norm.state_dict())

    return model_384


# --- EVALUATION FUNCTION ---
def evaluate_model(model, dataloader, criterion):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    avg_loss = total_loss / total
    accuracy = 100. * correct / total
    print(f"üìä Eval - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy

# === MAIN EXECUTION ===
if __name__ == "__main__":
    # Load pretrained ViT-Base
    pretrained_vit = create_model('vit_base_patch16_224', pretrained=True)
    pretrained_vit.to(device)
    pretrained_vit.eval()
    print("‚úÖ Pretrained ViT-Base loaded.")

    # Extract patch embeddings and fit PCA
    embeddings = extract_patch_embeddings(pretrained_vit, train_loader, max_samples=10000)
    pca = fit_pca(embeddings, target_embed_dim)

    # Create new ViT model with 384-dim embedding
    model_384 = create_vit_384(pretrained_vit, pca)

    # Final evaluation
    criterion = nn.CrossEntropyLoss()
    evaluate_model(model_384, test_loader, criterion)


‚úÖ Pretrained ViT-Base loaded.


Extracting embeddings:  10%|‚ñà         | 157/1556 [00:37<05:33,  4.20it/s]


‚úÖ Embedding shape for PCA: torch.Size([1960000, 768])


RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D

In [None]:
# # Create a new model with updated embedding dimension
# model_384 = DynamicViT(
#     img_size=224,
#     patch_size=16,
#     embed_dim=384,
#     depth=12,
#     num_heads=6,  # must divide 384 evenly
#     mlp_ratio=4.0,
#     num_classes=200
# ).to(device)


In [None]:
# def evaluate_model(model, dataloader, criterion, device):
#     model.eval()
#     total_loss, correct, total = 0.0, 0, 0

#     with torch.no_grad():
#         for inputs, targets in dataloader:
#             inputs, targets = inputs.to(device), targets.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, targets)

#             total_loss += loss.item() * inputs.size(0)
#             _, predicted = outputs.max(1)
#             total += targets.size(0)
#             correct += predicted.eq(targets).sum().item()

#     avg_loss = total_loss / total
#     accuracy = 100. * correct / total
#     print(f"Evaluation - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
#     return avg_loss, accuracy



In [None]:
# # Define criterion
# criterion = nn.CrossEntropyLoss()

# # Call evaluation function
# evaluate_model(model_384, test_loader, criterion, device)


Evaluation - Loss: 5.4482, Accuracy: 0.64%


(5.448154718459745, 0.6428801028608164)

In [None]:
## evolutioonary

In [None]:
import random
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from timm import create_model
import time

# Path to save the models after fine-tuning
# SAVE_PATH = '/SN02DATA/nas_vision/evol_img1k-wts'
SAVE_PATH = '/home/pratibha/nas_vision/weights-img-evol28-may'

# SAVE_PATH = '/kaggle/working/'

# Set the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# First-time loading pretrained weights for initialization
def load_pretrained_weights(model, pretrained_model_name="vit_base_patch16_224"):
    pretrained_vit = create_model(pretrained_model_name, pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    # Match keys between pretrained and current model
    model_state_dict = model.state_dict()
    filtered_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}

    # Load pretrained weights
    model.load_state_dict(filtered_dict, strict=False)
    print(f"Pretrained weights loaded into {model.__class__.__name__} successfully.")

# Check if pretrained weights are loaded correctly
def check_pretrained_weights(model, generation=0, model_type="subnetwork"):
    pretrained_vit = create_model("vit_base_patch16_224", pretrained=True)
    pretrained_state_dict = pretrained_vit.state_dict()
    
    model_state_dict = model.state_dict()
    matching_keys = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
    
    if len(matching_keys) > 0:
        print(f"Generation {generation + 1}: {model_type} model has loaded {len(matching_keys)} layers from pretrained weights.")
    else:
        print(f"Generation {generation + 1}: {model_type} model has NOT loaded any pretrained weights.")

# Sample Subnetwork - Randomly sample hyperparameters (depth, num_heads, etc.)
def sample_subnetwork(seen_architectures):
    while True:
        depth = random.choice([6, 8, 10, 12])
        num_heads = random.choice([4, 8, 12, 16])
        mlp_ratio = random.choice([2.0, 4.0, 6.0])
        embed_dim = 768  # Fixed embedding dimension
        
        architecture = (depth, num_heads, mlp_ratio, embed_dim)
        
        # Skip if architecture has already been sampled
        if architecture not in seen_architectures:
            seen_architectures.add(architecture)
            print(f"Sampled architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
            
            # Create the model to calculate its number of parameters
            # sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000)
            sampled_model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                                        depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, 
                                        num_classes=200
                                    )
            num_params = count_parameters(sampled_model)
            print(f"Number of parameters in the sampled model: {num_params:,}")
            
            return architecture
        else:
            print(f"Repeated architecture found, resampling...")

# Count number of trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def topk_accuracy(output, target, topk=(1,5)):
    """Computes the top-k accuracy for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size).item())
    return res  # [top1, top5]



from ptflops import get_model_complexity_info

def get_macs(model):
    with torch.cuda.device(0):
        macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=False, print_per_layer_stat=False)
    return macs

def evaluate_architecture(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    top1_total = 0
    top5_total = 0
    criterion = nn.CrossEntropyLoss()
    start_time = time.time()

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            top1, top5 = topk_accuracy(outputs, labels, topk=(1,5))
            top1_total += top1 * labels.size(0) / 100.0
            top5_total += top5 * labels.size(0) / 100.0
            total += labels.size(0)

    latency = (time.time() - start_time) / total
    accuracy = 100 * top1_total / total
    top5_accuracy = 100 * top5_total / total
    num_params = count_parameters(model)
    memory_usage = (num_params * 4) / (1024 ** 2)
    test_loss = running_loss / len(test_loader)

    print(f"Test Loss: {test_loss:.4f}, Top-1 Acc: {accuracy:.2f}%, Top-5 Acc: {top5_accuracy:.2f}%, Latency: {latency:.6f}s/img, Mem: {memory_usage:.2f}MB")
    macs = get_macs(model)
    print(f"MACs: {macs / 1e6:.2f} M")
    return accuracy, top5_accuracy, test_loss, latency, memory_usage, macs





# Estimate memory usage of a model during inference (rough estimation)
def estimate_memory_usage(model):                                             ## this funtion is not needed
    # Create dummy input matching the expected shape of the input tensor
    dummy_input = torch.randn(1, 3, 224, 224).to(device)  # Example for ViT (3-channel image of size 224x224)
    
    # Use torch.utils.benchmark to measure memory usage during inference
    start_mem = torch.cuda.memory_allocated()
    
    # Run the model once with the dummy input
    with torch.no_grad():
        model(dummy_input)
    
    end_mem = torch.cuda.memory_allocated()
    memory_usage = (end_mem - start_mem) / (1024 ** 2)  # Convert bytes to MB
    return memory_usage


def calculate_crowding_distance(population, test_loader):
    crowding_distances = [0] * len(population)
    num_objectives = 3  # Accuracy, Latency, Memory

    # Evaluate each architecture once, then reuse the results
    evaluated_results = []
    for arch in population:
        # # model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
        #                    depth=arch[0], num_heads=arch[1],
        #                    mlp_ratio=arch[2], num_classes=10).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=arch[3],
                            depth=arch[0], num_heads=arch[1], mlp_ratio=arch[2], 
                            num_classes=200).to(device)

        accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
        memory = count_parameters(model) * 4  # memory in bytes
        
        evaluated_results.append((accuracy, latency, memory))
        del model
        torch.cuda.empty_cache()

    for objective_index in range(num_objectives):
        sorted_indices = sorted(range(len(population)),
                                key=lambda idx: evaluated_results[idx][objective_index])
        
        crowding_distances[sorted_indices[0]] = crowding_distances[sorted_indices[-1]] = float('inf')

        for i in range(1, len(sorted_indices) - 1):
            prev_value = evaluated_results[sorted_indices[i - 1]][objective_index]
            next_value = evaluated_results[sorted_indices[i + 1]][objective_index]
            distance = next_value - prev_value
            crowding_distances[sorted_indices[i]] += distance

    return crowding_distances


def dominates(model1, model2, test_loader):
    # Evaluate both models on the test set
    accuracy1, _, _,latency1, _, _ = evaluate_architecture(model1, test_loader)
    accuracy2, _, _,latency2, _ , _= evaluate_architecture(model2, test_loader)
    
    # Calculate memory usage as the number of parameters * 4 bytes (FP32)
    memory1 = count_parameters(model1) * 4  # Memory in bytes
    memory2 = count_parameters(model2) * 4  # Memory in bytes
    
    # Compare performance metrics
    dominates_in_accuracy = accuracy1 >= accuracy2
    dominates_in_latency = latency1 <= latency2
    dominates_in_memory = memory1 <= memory2

    # Return True if model1 dominates model2 in all aspects
    return dominates_in_accuracy and dominates_in_latency and dominates_in_memory


# Mutation: Randomly mutate architecture's hyperparameters
def mutate(architecture):
    depth, num_heads, mlp_ratio, embed_dim = architecture
    if random.random() < 0.5: depth = random.choice([ 6, 8, 10, 12])
    if random.random() < 0.5: num_heads = random.choice([4, 8, 12, 16])
    if random.random() < 0.5: mlp_ratio = random.choice([2.0, 4.0, 6.0])
    print(f"Mutated architecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}")
    return depth, num_heads, mlp_ratio, embed_dim

# One-Point Crossover: Combine two parent architectures to create new architectures
def one_point_crossover(parent1, parent2):
    crossover_point = random.choice([0, 1, 2, 3])  # Crossover at depth, num_heads, etc.
    child1 = parent1[:crossover_point] + parent2[crossover_point:]
    child2 = parent2[:crossover_point] + parent1[crossover_point:]
    print(f"Crossover result: Child1={child1}, Child2={child2}")
    return child1, child2



############################# this is not weight based instead it is pareto selection
# Optimized Pareto selection based on stored performance metrics
def pareto_selection(arch_performance):
    def dominates(perf1, perf2):
        acc1, lat1, mem1 = perf1
        acc2, lat2, mem2 = perf2
        return (acc1 >= acc2 and lat1 <= lat2 and mem1 <= mem2) and (acc1 > acc2 or lat1 < lat2 or mem1 < mem2)

    ranks = {}
    for arch1, perf1 in arch_performance.items():
        dominated_count = 0
        for arch2, perf2 in arch_performance.items():
            if arch1 != arch2 and dominates(perf2, perf1):
                dominated_count += 1
        ranks[arch1] = dominated_count

    # Sort architectures by rank (lower dominated_count = better)
    sorted_population = sorted(ranks.keys(), key=lambda arch: ranks[arch])
    return sorted_population



# def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
#     print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
#     sampled_model.to(device)
#     criterion = nn.CrossEntropyLoss()
#     optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
#     for epoch in range(epochs):
#         start_epoch = time.time()
#         sampled_model.train()
#         running_loss = 0.0
#         for images, labels in train_loader:
#             images, labels = images.to(device), labels.to(device)
#             optimizer.zero_grad()
#             outputs = sampled_model(images)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()
#         epoch_time = time.time() - start_epoch
#         test_accuracy, test_top5, test_loss, test_latency, memory_usage = evaluate_architecture(sampled_model, test_loader)
#         print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss:.4f}, Top-1 Acc: {test_accuracy:.2f}%, Top-5 Acc: {test_top5:.2f}%, Latency: {test_latency:.6f}s/img, Time: {epoch_time:.2f}s")
#     # Save model code unchanged
#     if architecture_folder:
#         os.makedirs(architecture_folder, exist_ok=True)
#         torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
#     return sampled_model

def fine_tune_model(sampled_model, train_loader, test_loader, epochs=3, architecture_folder=None):
    print(f"Fine-tuning model with architecture: Depth={sampled_model.depth}, Num Heads={sampled_model.num_heads}, MLP Ratio={sampled_model.mlp_ratio}")
    sampled_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(sampled_model.parameters(), lr=1e-4)
    
    for epoch in range(epochs):
        start_epoch = time.time()
        sampled_model.train()
        running_loss = 0.0
        
        # Training phase
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = sampled_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        # Evaluation phase
        epoch_time = time.time() - start_epoch
        test_accuracy, test_top5, test_loss, test_latency, memory_usage, macs = evaluate_architecture(sampled_model, test_loader)
        
        # Print epoch statistics
        print(f"\nEpoch {epoch + 1}/{epochs} Summary:")
        print(f"| Training Loss: {running_loss/len(train_loader):.4f}")
        print(f"| Test Loss: {test_loss:.4f}")
        print(f"| Top-1 Accuracy: {test_accuracy:.2f}%")
        print(f"| Top-5 Accuracy: {test_top5:.2f}%")
        print(f"| Latency: {test_latency:.6f}s/img")
        print(f"| Memory Usage: {memory_usage:.2f}MB")
        print(f"| MACs: {macs/1e6:.2f}M")
        print(f"| Epoch Time: {epoch_time:.2f}s\n")

    # Save model weights
    if architecture_folder:
        os.makedirs(architecture_folder, exist_ok=True)
        torch.save(sampled_model.state_dict(), os.path.join(architecture_folder, 'checkpoint.pth'))
    return sampled_model



def save_top_ranked_models(population, arch_performance, generation):
    top_n = min(5, len(population))
    for idx, arch in enumerate(population[:top_n]):
        depth, num_heads, mlp_ratio, embed_dim = arch
        # model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
        #                    num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=1000).to(device)
        model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim, depth=depth,
                            num_heads=num_heads, mlp_ratio=mlp_ratio, 
                            num_classes=200).to(device)


        architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
        checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')
        model.load_state_dict(torch.load(checkpoint_path))

        top_model_path = os.path.join(SAVE_PATH, f'top_ranked_model_gen{generation+1}_rank_{idx+1}.pth')
        torch.save(model.state_dict(), top_model_path)
        
        acc, lat, mem = arch_performance[arch]

        # with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
        #     f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
        #     f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")
        with open(top_model_path.replace('.pth', '.txt'), 'w') as f:
            f.write(f"Rank: {idx+1}\nArchitecture: Depth={depth}, Num Heads={num_heads}, MLP Ratio={mlp_ratio}, Embed Dim={embed_dim}\n")
            f.write(f"Accuracy: {acc:.2f}%, Latency: {lat:.6f}s/image, Memory: {mem / (1024 ** 2):.2f}MB\n")


        print(f"Saved top-ranked model: Generation {generation+1}, Rank {idx+1} (Acc={acc:.2f}%, Lat={lat:.6f}, Mem={mem/(1024**2):.2f}MB)")
        
        


def plot_pareto_front(arch_performance):
    accuracies = [v[0] for v in arch_performance.values()]
    latencies = [v[1] for v in arch_performance.values()]
    memories = [v[2] / (1024**2) for v in arch_performance.values()]  # convert to MB

    # Accuracy vs Latency
    plt.figure(figsize=(8,6))
    plt.scatter(latencies, accuracies, c='blue')
    plt.xlabel('Latency (s/image)')
    plt.ylabel('Accuracy (%)')
    plt.title('Pareto Front (Accuracy vs Latency)')
    plt.grid()
    plt.show()

    # Accuracy vs Memory
    plt.figure(figsize=(8,6))
    plt.scatter(memories, accuracies, c='green')
    plt.xlabel('Memory (MB)')
    plt.ylabel('Accuracy (%)')
    plt.title('Pareto Front (Accuracy vs Memory)')
    plt.grid()
    plt.show()

#

def evolutionary_algorithm(population_size=16, generations=5, mutation_rate=0.1, crossover_rate=0.7, train_loader=None, test_loader=None):
    seen_architectures = set()
    population = [sample_subnetwork(seen_architectures) for _ in range(population_size)]
    arch_performance = {}

    prev_best_accuracy = 0
    no_improvement_count = 0

    for generation in range(generations):
        print(f"\n--- Generation {generation + 1}/{generations} ---")

        for arch in population:
            depth, num_heads, mlp_ratio, embed_dim = arch
            architecture_folder = os.path.join(SAVE_PATH, f"arch_{depth}_{num_heads}_{mlp_ratio}_{embed_dim}")
            checkpoint_path = os.path.join(architecture_folder, 'checkpoint.pth')

            model = DynamicViT(img_size=224, patch_size=16, embed_dim=embed_dim,
                               depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                               num_classes=200).to(device)

            # Clearly load weights once per architecture
            if os.path.exists(checkpoint_path):
                model.load_state_dict(torch.load(checkpoint_path))
                print(f"Loaded weights from previous generation for architecture {arch}")
            else:
                load_pretrained_weights(model)

            fine_tune_model(model, train_loader, test_loader, epochs=5, architecture_folder=architecture_folder)

            # accuracy, _, latency, _ = evaluate_architecture(model, test_loader)
            accuracy, top5_accuracy, test_loss, latency, memory_usage, macs = evaluate_architecture(model, test_loader)
            memory = count_parameters(model) * 4 / (1024 ** 2)  # MB
            # arch_performance[arch] = (accuracy, latency, memory)
            arch_performance[arch] = (accuracy, top5_accuracy, latency, memory_usage, macs)

            del model
            torch.cuda.empty_cache()

        # Pareto selection
        population = pareto_selection(arch_performance)

        print("\nTop 5 Ranked Models of Generation", generation+1)
        for idx, arch in enumerate(population[:5]):
            acc, top5_acc, lat, mem, macs = arch_performance[arch]
            print(f"Rank {idx+1}: Model {arch} | Top-1 Acc: {acc:.2f}%, Top-5 Acc: {top5_acc:.2f}%, Latency: {lat:.6f}s/img, Mem: {mem:.2f}MB, MACs: {macs/1e6:.2f}M")
            # Saving top-ranked models
            save_top_ranked_models(population, arch_performance, generation)

        # Check for Pareto front convergence (early stopping criteria)
        current_best_accuracy = arch_performance[population[0]][0]
        if current_best_accuracy - prev_best_accuracy < 1.0:
            no_improvement_count += 1
            print(f"Minimal improvement detected: {current_best_accuracy - prev_best_accuracy:.2f}%")
            if no_improvement_count >= 2:
                print("Pareto front has converged. Stopping early.")
                break
        else:
            no_improvement_count = 0
        prev_best_accuracy = current_best_accuracy

        # Generate offspring
        next_population = population[:len(population)//2]  # Only top half
        offspring = []

        for i in range(0, len(next_population)-1, 2):
            parent1, parent2 = next_population[i], next_population[i+1]

            if random.random() < crossover_rate:
                child1, child2 = one_point_crossover(parent1, parent2)
                print(f"Crossover parents: {parent1} & {parent2}")
                offspring.extend([child1, child2])
            else:
                offspring.extend([parent1, parent2])

        # Mutation with clear logging
        mutated_offspring = []
        for child in offspring:
            if random.random() < mutation_rate:
                original_child = child
                child = mutate(child)
                print(f"Mutated from {original_child} to {child}")
            mutated_offspring.append(child)

        population = next_population + mutated_offspring

        print(f"\nAfter mutation and crossover, {len(mutated_offspring)} offspring models generated.")
        print("Only top 5 models will be used for the next generation.")

    # Plot Pareto Front at the end
    plot_pareto_front(arch_performance)

    return population

# Run the evolutionary algorithm
evolutionary_algorithm(population_size=10, generations=5, train_loader=train_loader, test_loader=test_loader)

