In [1]:
import os
import pandas as pd
import torch
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.models import efficientnet_b3
import torch.nn as nn
import numpy as np
from tqdm import tqdm

# --- Locations and hyper-parameters -----------------------------------------
data_path = "/home/sebastian/codes/data/quantum/MNIST"
output_dir = "/home/sebastian/codes/data/quantum/embeddings"
batch_size = 64
embedding_size = 512

# One sub-directory per split, both living under the common output root.
train_output_dir = os.path.join(output_dir, "train_embeddings")
test_output_dir = os.path.join(output_dir, "test_embeddings")

# Create the root first, then the per-split folders (no error if they exist).
for _directory in (output_dir, train_output_dir, test_output_dir):
    os.makedirs(_directory, exist_ok=True)

# Pre-processing applied to every MNIST image before it reaches the network:
# replicate the single grey channel to 3 channels, upscale to 300x300,
# convert to a tensor and normalise each channel with mean 0.5 / std 0.5.
# NOTE(review): the pretrained IMAGENET1K_V1 weights loaded below are published
# with ImageNet mean/std preprocessing; the 0.5/0.5 values are kept here so the
# generated embeddings do not change -- confirm whether this is intentional.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Downloads MNIST into data_path on first use.
train_dataset = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=data_path, train=False, download=True, transform=transform)

# This is a pure inference pass, so shuffling brings no benefit. With
# shuffle=False the running sample index used in the per-sample file names
# equals the sample's position in the dataset, which makes the output
# reproducible across runs and traceable back to the source image.
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
dataloader_test = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ImageNet-pretrained EfficientNet-B3 used as a fixed feature extractor.
model = efficientnet_b3(weights="IMAGENET1K_V1")

# The replacement head is a freshly initialised (untrained) linear projection
# down to embedding_size. Seed the RNG right before creating it so that the
# projection -- and therefore every saved embedding -- is identical from one
# run to the next; without the seed each run produces a different embedding space.
# Note: this sets PyTorch's global RNG seed.
torch.manual_seed(0)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, out_features=embedding_size)

# Inference mode: disables dropout and freezes BatchNorm statistics.
model.eval()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

def extract_and_save_embeddings(model, dataloader, device, output_dir, name):
    """Run ``model`` over ``dataloader`` and write the embeddings to CSV files.

    For every sample a one-row CSV named ``<label>_<index>.csv`` is written to
    ``output_dir``, where ``index`` counts samples in the order the dataloader
    yields them. After the pass, a consolidated ``<name>_all_embeddings.csv``
    containing every row is written to the same directory. Each row holds the
    embedding values followed by a ``label`` column.

    Args:
        model: Callable module mapping a batch of images to a batch of
            embeddings (one row per sample). It is switched to eval mode for
            the duration of the call and restored to its previous mode after.
        dataloader: Iterable yielding ``(images, targets)`` batches; both are
            expected to be tensors.
        device: Device the image batches are moved to before the forward pass.
            The model is assumed to already live on this device.
        output_dir: Directory receiving all CSV files; created if missing.
        name: Prefix of the consolidated CSV file name.

    Returns:
        None. Prints the number of consolidated rows and of per-sample files.
    """
    os.makedirs(output_dir, exist_ok=True)

    all_embeddings = []
    labels = []
    index = 0

    # Dropout and BatchNorm behave differently in training mode (BatchNorm even
    # updates its running statistics under no_grad), so force eval mode here
    # instead of relying on the caller, and put the previous mode back at the end.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, targets in tqdm(dataloader, desc="Extracting embeddings"):
                # Convert once per batch rather than once (or more) per sample.
                batch_embeddings = model(images.to(device)).cpu().numpy()
                batch_labels = targets.tolist()

                for embedding, label in zip(batch_embeddings, batch_labels):
                    all_embeddings.append(embedding)
                    labels.append(label)

                    # One-row frame: embedding values as columns, then the label.
                    embedding_df = pd.DataFrame(embedding.reshape(1, -1))
                    embedding_df['label'] = label
                    embedding_csv_path = os.path.join(output_dir, f"{label}_{index}.csv")
                    embedding_df.to_csv(embedding_csv_path, index=False)
                    index += 1
    finally:
        model.train(was_training)

    consolidated_df = pd.DataFrame(all_embeddings)
    consolidated_df['label'] = labels
    consolidated_csv_path = os.path.join(output_dir, f"{name}_all_embeddings.csv")
    consolidated_df.to_csv(consolidated_csv_path, index=False)
    print(len(consolidated_df), index)
    
# Run the extraction once per split; each split writes into its own directory.
for _loader, _split_dir, _split_name in (
    (dataloader, train_output_dir, "train"),
    (dataloader_test, test_output_dir, "test"),
):
    extract_and_save_embeddings(model, _loader, device, _split_dir, _split_name)

print("Embeddings extracted and saved successfully!")

Extracting embeddings: 100%|██████████████████████████████| 938/938 [03:19<00:00,  4.71it/s]


60000 60000


Extracting embeddings: 100%|██████████████████████████████| 157/157 [00:28<00:00,  5.43it/s]


10000 10000
Embeddings extracted and saved successfully!
