# Vector Embedding Extraction


`data_path` is assumed to have the splitted data named as "train/", "val/" and "test/"

In [27]:
# Define the model name
model_name = "mobilenetv4_r448_pretrained"#"efficientnet_b3" #EfficientNet_B7_Weights.IMAGENET1K_V1
weights_path = '/home/sebastian/codes/QuantumVE/q_Net/pretrain/mobilenetv4_r448/checkpoint-99.pth'
feat_space = 16
batch_size = 64
data_path = '../data/ABGQI_mel_spectrograms'
device = 'cuda'
output_dir = "embeddings"

# Create experiment directory
EXPERIMENT_NAME = f"./{output_dir}/{model_name}_{feat_space}_bs{batch_size}"
import os
os.makedirs(EXPERIMENT_NAME, exist_ok=True)

# Load pre-trained model with Timm

In [44]:
import timm
import torch
import torch.nn as nn

def count_parameters(model, message=""):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{message} Trainable params: {trainable_params} of {total_params}")

# Define extract_embeddings class
class extract_embeddings(nn.Module):
    def __init__(self, base_model, feat_space, model_name):
        super(extract_embeddings, self).__init__()
        self.base_model = base_model
        self.model_name = model_name
        self.feat_space = feat_space
        
        # Example: Adding a new classifier layer
        if model_name == "mobilenetv4_r448_pretrained":
            self.new_classifier = nn.Linear(self.base_model.conv_head.out_channels, out_features=self.feat_space)

    def forward(self, x):
        x = self.base_model(x)
        x = self.new_classifier(x)
        return x
    
def load_and_initialize_model(model_name, weights_path, feat_space):
    model = timm.create_model('mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k', pretrained=False, num_classes=0)

    # Count parameters before loading the checkpoint
    count_parameters(model, message="Before loading checkpoint")

    checkpoint = torch.load(weights_path, map_location='cpu')
    checkpoint_model = checkpoint['model']

    # Count parameters after loading the checkpoint
    count_parameters(model, message="After loading checkpoint")

    # Initialize the extract_embeddings with the base model and new classifier
    model = extract_embeddings(base_model=model, feat_space=feat_space, model_name=model_name)
    # Load updated checkpoint into the model
    model.load_state_dict(checkpoint_model, strict=False)

    # Count parameters of the custom model
    count_parameters(model, message="Custom model parameters")
    
    return model

# Pre-trained model with 5 classes so need to be initialized with with the same
model = load_and_initialize_model(model_name, weights_path, 5)

# Create a dummy input tensor (e.g., batch size 1, 3 channels, 448x448 image size)
dummy_input = torch.randn(1, 3, 448, 448)  # Adjust size if necessary

# Test the model
model.eval()
with torch.no_grad():
    output_features = model(dummy_input)
print("Output features shape:", output_features.shape)

[16:35:52.249011] [16:35:52.249008] [16:35:52.249067] [16:35:52.248985] [16:35:52.249076] [16:35:52.249074] [16:35:52.249082] Before loading checkpoint Trainable params: 31309864 of 31309864
[16:35:52.384900] [16:35:52.384897] [16:35:52.384919] [16:35:52.384882] [16:35:52.384929] [16:35:52.384927] [16:35:52.384935] After loading checkpoint Trainable params: 31309864 of 31309864
[16:35:52.402295] [16:35:52.402293] [16:35:52.402315] [16:35:52.402277] [16:35:52.402325] [16:35:52.402323] [16:35:52.402331] Custom model parameters Trainable params: 31316269 of 31316269
[16:35:52.477369] [16:35:52.477366] [16:35:52.477414] [16:35:52.477352] [16:35:52.477423] [16:35:52.477421] [16:35:52.477428] Output features shape: torch.Size([1, 5])


Adapt model for N embedding-sizes

In [29]:
# Initialize the extract_embeddings with the base model and new classifier
model = extract_embeddings(base_model=model.base_model, feat_space=feat_space, model_name=model_name)
# Create a dummy input tensor (e.g., batch size 1, 3 channels, 448x448 image size)
dummy_input = torch.randn(1, 3, 448, 448)  # Adjust size if necessary

model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient calculation for inference
    output_features = model(dummy_input)

# Print the shape of the output features
print("Output features shape:", output_features.shape)

[16:07:21.856092] [16:07:21.856076] [16:07:21.856151] Output features shape: torch.Size([1, 16])


## Utility function

In [30]:
import torch
import timm
import torch.nn as nn
import argparse
from util.datasets import build_dataset
import sys
sys.path.insert(0,'../') 
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import random_split
from torch.utils.data import Subset, DataLoader, random_split
from torchvision import datasets, transforms
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt

import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import pandas as pd 
# from MAE code
from util.datasets import build_dataset
import argparse
import util.misc as misc
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

import timm

from timm.models.layers import trunc_normal_
from timm.data.mixup import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

import util.lr_decay as lrd
import util.misc as misc
from util.datasets import build_dataset
from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler

# import models_vit
import sys
import os
import torch
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image
import torch; print(f'numpy version: {np.__version__}\nCUDA version: {torch.version.cuda} - Torch versteion: {torch.__version__} - device count: {torch.cuda.device_count()}')

from timm.data import Mixup
from timm.utils import accuracy
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
from itertools import cycle
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
import torch.optim as optim
import torchvision.models as models
import torch.nn as nn
import torch
import pandas as pd
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, fbeta_score
from sklearn.metrics import precision_score, recall_score, f1_score, fbeta_score
import numpy as np

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

def count_parameters(model, message=""):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{message} Trainable params: {trainable_params} of {total_params}")

def show_image(image, title=''):
    # image is [H, W, 3]
    assert image.shape[2] == 3
    plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')
    return

def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    # build model
    model = getattr(models_mae, arch)()
    # load model
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model

def plot_multiclass_roc_curve(all_labels, all_predictions, EXPERIMENT_NAME="."):
    # Step 1: Label Binarization
    label_binarizer = LabelBinarizer()
    y_onehot = label_binarizer.fit_transform(all_labels)
    all_predictions_hot = label_binarizer.transform(all_predictions)

    # Step 2: Calculate ROC curves
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    unique_classes = range(y_onehot.shape[1])
    for i in unique_classes:
        fpr[i], tpr[i], _ = roc_curve(y_onehot[:, i], all_predictions_hot[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Step 3: Plot ROC curves
    fig, ax = plt.subplots(figsize=(8, 8))

    # Micro-average ROC curve
    fpr_micro, tpr_micro, _ = roc_curve(y_onehot.ravel(), all_predictions_hot.ravel())
    roc_auc_micro = auc(fpr_micro, tpr_micro)
    plt.plot(
        fpr_micro,
        tpr_micro,
        label=f"micro-average ROC curve (AUC = {roc_auc_micro:.2f})",
        color="deeppink",
        linestyle=":",
        linewidth=4,
    )

    # Macro-average ROC curve
    all_fpr = np.unique(np.concatenate([fpr[i] for i in unique_classes]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in unique_classes:
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= len(unique_classes)
    fpr_macro = all_fpr
    tpr_macro = mean_tpr
    roc_auc_macro = auc(fpr_macro, tpr_macro)
    plt.plot(
        fpr_macro,
        tpr_macro,
        label=f"macro-average ROC curve (AUC = {roc_auc_macro:.2f})",
        color="navy",
        linestyle=":",
        linewidth=4,
    )

    # Individual class ROC curves with unique colors
    colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_classes)))
    for class_id, color in zip(unique_classes, colors):
        plt.plot(
            fpr[class_id],
            tpr[class_id],
            color=color,
            label=f"ROC curve for Class {class_id} (AUC = {roc_auc[class_id]:.2f})",
            linewidth=2,
        )

    plt.plot([0, 1], [0, 1], color='gray', linestyle='--', linewidth=2)  # Add diagonal line for reference
    plt.axis("equal")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Extension of Receiver Operating Characteristic\n to One-vs-Rest multiclass")
    plt.legend()
    plt.savefig(f'{EXPERIMENT_NAME}/roc_curve.png')
    plt.show()
    

[16:07:23.217299] [16:07:23.217291] [16:07:23.217324] numpy version: 1.26.4
CUDA version: 11.8 - Torch versteion: 2.0.0+cu118 - device count: 2


## Parametrize and initialize seed

In [31]:
parser = argparse.ArgumentParser('VE extraction', add_help=False)
parser.add_argument('--batch_size', default=batch_size, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--accum_iter', default=4, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
# Model parameters
parser.add_argument('--model', default=model_name, type=str, metavar='MODEL',
                        help='Name of model to train')

parser.add_argument('--input_size', default=224, type=int,
                        help='images input size')

parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    # Optimizer parameters
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=5e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--layer_decay', type=float, default=0.65,
                        help='layer-wise lr decay from ELECTRA/BEiT')

parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')

parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR')

    # Augmentation parameters
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')

    # * Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')
    # * Mixup params
parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0.')
parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0.')
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup_prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')

parser.add_argument('--mixup_mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # * Finetuning params
parser.add_argument('--finetune', default='mae_pretrain_vit_base.pth',
                        help='finetune from checkpoint')
parser.add_argument('--global_pool', action='store_true')
parser.set_defaults(global_pool=True)
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
                        help='Use class token instead of global pool for classification')
# Dataset parameters
parser.add_argument('--data_path', default=data_path, type=str,
                        help='dataset path')
parser.add_argument('--nb_classes', default=5, type=int,
                        help='number of the classification types')
parser.add_argument('--output_dir', default=EXPERIMENT_NAME,
                        help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')

parser.add_argument('--device', default=device,
                        help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default=".",
                        help='resume from checkpoint')

parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
parser.add_argument('--eval',default=True, action='store_true',
                        help='Perform evaluation only')
parser.add_argument('--dist_eval', action='store_true', default=False,
                        help='Enabling distributed evaluation (recommended during training for faster monitor')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)

    # distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

args, unknown = parser.parse_known_args()
print("{}".format(args).replace(', ', ',\n'))
os.makedirs(args.output_dir, exist_ok=True)
device = torch.device(args.device)


# set seeds
misc.init_distributed_mode(args)
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True

[16:07:24.254396] [16:07:24.254382] [16:07:24.254453] Namespace(batch_size=64,
epochs=50,
accum_iter=4,
model='mobilenetv4_r448_pretrained',
input_size=224,
drop_path=0.1,
clip_grad=None,
weight_decay=0.05,
lr=None,
blr=0.0005,
layer_decay=0.65,
min_lr=1e-06,
warmup_epochs=5,
color_jitter=None,
aa='rand-m9-mstd0.5-inc1',
smoothing=0.1,
reprob=0.25,
remode='pixel',
recount=1,
resplit=False,
mixup=0.8,
cutmix=1.0,
cutmix_minmax=None,
mixup_prob=1.0,
mixup_switch_prob=0.5,
mixup_mode='batch',
finetune='mae_pretrain_vit_base.pth',
global_pool=True,
data_path='../data/ABGQI_mel_spectrograms',
nb_classes=5,
output_dir='./embeddings/mobilenetv4_r448_pretrained_16_bs64',
log_dir='./output_dir',
device='cuda',
seed=0,
resume='.',
start_epoch=0,
eval=True,
dist_eval=False,
num_workers=10,
pin_mem=True,
world_size=1,
local_rank=-1,
dist_on_itp=False,
dist_url='env://')
[16:07:24.254680] [16:07:24.254671] [16:07:24.254699] Not using distributed mode


## Adapt dataloaders

In [32]:
dataset_train = build_dataset(is_train=True, args=args)
dataset_val = build_dataset(is_train=False, args=args)

if True:  # args.distributed:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

if global_rank == 0 and args.log_dir is not None and not args.eval:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
else:
        log_writer = None

data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
)

data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
)

[16:07:24.814370] [16:07:24.814365] [16:07:24.814423] [16:07:24.814348] [16:07:24.814443] [16:07:24.814439] [16:07:24.814455] Dataset ImageFolder
    Number of datapoints: 7814
    Root location: ../data/ABGQI_mel_spectrograms/train
    StandardTransform
Transform: Compose(
               RandomResizedCropAndInterpolation(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bicubic)
               RandomHorizontalFlip(p=0.5)
               ColorJitter(brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=None)
               MaybeToTensor()
               Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
           )
[16:07:24.818972] [16:07:24.818968] [16:07:24.818993] [16:07:24.818955] [16:07:24.819009] [16:07:24.819006] [16:07:24.819021] Dataset ImageFolder
    Number of datapoints: 850
    Root location: ../data/ABGQI_mel_spectrograms/val
    StandardTransform
Transform: Compose(
               Resize(size=256

## Extract embeddings

In [33]:
def extract_embeddings(model, data_loader, save_path, device, preprocess=None,data_config=None, transforms=None):
    embeddings_list = []
    targets_list = []
    total_batches = len(data_loader)
    with torch.no_grad(), tqdm(total=total_batches) as pbar:
        model.eval()  # Set the model to evaluation mode
        model.to(device)
        for images, targets in data_loader:
            if preprocess:
                print("required processing")
                images = preprocess(images).squeeze()
                images = images.to(device)
                embeddings = model(images)
            if transforms: # for timm models
                
                # get model specific transforms (normalization, resize)
                data_config = timm.data.resolve_model_data_config(model)
                transforms = timm.data.create_transform(**data_config, is_training=False)
                images = images.to(device)                
                embeddings = model(transforms(images))# output is (batch_size, num_features) shaped tensor

            embeddings_list.append(embeddings.cpu().detach().numpy())  # Move to CPU and convert to NumPy
            targets_list.append(targets.numpy())  # Convert targets to NumPy
            pbar.update(1)

    # Concatenate embeddings and targets from all batches
    embeddings = np.concatenate(embeddings_list).squeeze()
    targets = np.concatenate(targets_list)
    num_embeddings = embeddings.shape[1]
    column_names = [f"feat_{i}" for i in range(num_embeddings)]
    column_names.append("label")

    embeddings_with_targets = np.hstack((embeddings, np.expand_dims(targets, axis=1)))

    # Create a DataFrame with column names
    df = pd.DataFrame(embeddings_with_targets, columns=column_names)
    
    df.to_csv(save_path, index=False)
    
preprocess=None
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
# Extract embeddings for training data
extract_embeddings(model, data_loader_train, f'{EXPERIMENT_NAME}/train_embeddings.csv', device, preprocess, data_config, transforms)
    
# Extract embeddings for validation data
extract_embeddings(model, data_loader_val, f'{EXPERIMENT_NAME}/val_embeddings.csv', device, preprocess,data_config, transforms)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 122/122 [00:05<00:00, 23.74it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00, 11.48it/s]
