In [None]:
import sys
import os
from pathlib import Path

# Get absolute path to project root (two directory levels above the current
# working directory).
project_root = Path(os.path.abspath('')).parent.parent
# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from dotenv import load_dotenv
# Populate os.environ via python-dotenv before reading any settings.
load_dotenv()

# Root directory of the NIH dataset; this is None when the variable is unset.
nih_dataset_root_dir = os.environ.get("NIH_CXR14_DATASET_DIR")

# Directory used for processed data; created if it does not exist yet.
main_output_dir = "../data"
Path(main_output_dir).mkdir(parents=True, exist_ok=True)

In [2]:
from src.datasets import NIHImageDataset, NIHFindingLabels


# Load the processed finding labels from `main_output_dir`.
nih_finding_labels = NIHFindingLabels.load_from_processed(main_output_dir)


# Print the first entry as a sanity check. Each entry is a
# (label tensor, image file name, list of label names) tuple.

sample = nih_finding_labels[0]

print(sample)


(tensor([0., 0., 1., 0., 0.], dtype=torch.float64), '00000001_002.png', ['Infiltration', 'No Finding', 'Effusion', 'Nodule/Mass', 'Atelectasis'])


In [3]:
import torchvision.transforms as transforms


# Image size handed to the dataset.
img_size = 224

# Image transformations: convert to a tensor, then shift and scale with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
# NOTE(review): this maps [0, 1] inputs to [-1, 1]; presumably ToTensor yields
# [0, 1] values here -- confirm against what NIHImageDataset passes in.
image_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Image dataset rooted at the directory configured through the environment.
nih_image_dataset = NIHImageDataset(
    root_dir=nih_dataset_root_dir,
    img_size=img_size,
    transform=image_transform,
)



In [4]:
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    """Dataset that yields ``(image_id, image)`` pairs.

    The finding labels define the length and the order of the dataset; the
    image for each entry is fetched from ``image_dataset`` by its image id.
    """

    def __init__(self, image_dataset, finding_labels):
        """Keep references to the image source and the label collection.

        Args:
            image_dataset: Object indexable by image id that returns a
                two-element ``(image, extra)`` result.
            finding_labels: Sized, integer-indexable collection whose items
                are three-element ``(labels, image_id, label_names)`` tuples.
        """
        self.finding_labels = finding_labels
        self.image_dataset = image_dataset

    def __len__(self):
        """Return the number of entries in the finding labels."""
        return len(self.finding_labels)

    def __getitem__(self, idx):
        """Return ``(image_id, image)`` for the label entry at ``idx``."""
        # Only the image id of the label entry is needed here.
        _labels, image_id, _label_names = self.finding_labels[idx]

        # The image dataset is indexed by the image id, not by position.
        image, _extra = self.image_dataset[image_id]

        return image_id, image

In [5]:

import torch

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when tensors are moved to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32
num_workers = 4


# Pair every finding-label entry with its image, then batch in a fixed order
# (shuffle=False) so that repeated runs visit the images identically.
custom_dataset = CustomDataset(nih_image_dataset, nih_finding_labels)
custom_dataloader = DataLoader(
    custom_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
    pin_memory=(device == "cuda"),
)

In [6]:
from src.pipelines import VaeProcessor

# Build the VAE processor on the configured device.
vae_processor = VaeProcessor(device=device)




In [7]:
import pickle
import os

def update_pickle(pickle_file, data):
    """Merge ``data`` into the dictionary pickled at ``pickle_file``.

    If the file exists, its contents are loaded and updated with ``data``
    (entries in ``data`` win on key collisions); otherwise ``data`` itself is
    written. The result is first written to a temporary sibling file and then
    moved over the target, so an interruption while writing cannot corrupt
    data saved earlier.

    Args:
        pickle_file: Path of the pickle file to create or update.
        data: Dictionary of new entries to store.

    Returns:
        True when the file has been written successfully.

    Raises:
        Exception: If loading, merging or writing fails. The original error
            is attached as ``__cause__``.
    """
    tmp_file = f"{pickle_file}.tmp"

    try:
        # If file exists, load and update
        if os.path.exists(pickle_file):
            # NOTE: pickle.load can execute arbitrary code; only use this on
            # files produced by this notebook, never on untrusted input.
            with open(pickle_file, 'rb') as f:
                merged = pickle.load(f)
            merged.update(data)
        else:
            # If file doesn't exist, use new data directly
            merged = data

        try:
            # Write next to the target, then swap it in with a single rename.
            with open(tmp_file, 'wb') as f:
                pickle.dump(merged, f)
            os.replace(tmp_file, pickle_file)
        finally:
            # Best-effort cleanup of a leftover temp file after a failed
            # write; after a successful replace there is nothing to remove.
            if os.path.exists(tmp_file):
                try:
                    os.remove(tmp_file)
                except OSError:
                    pass

        return True

    except Exception as e:
        raise Exception(f"Error updating pickle file: {str(e)}") from e

In [8]:
from tqdm.notebook import tqdm
import os
import torch

# Output file for the latents and the number of buffered entries that
# triggers a flush to disk.
LATENTS_FILE = "latents.pkl"
FLUSH_THRESHOLD = 1000

# Dictionary to store the latents collected since the last flush
latents_dict = {}

try:
    for image_ids, input_tensors in tqdm(custom_dataloader):

        input_tensors = input_tensors.to(device)

        # No gradients are needed when only extracting latents.
        with torch.no_grad():
            batch_latents = vae_processor.prepare_latent(image=input_tensors)

        # Move each latent to the CPU individually, keyed by its image id.
        for i, image_id in enumerate(image_ids):
            latents_dict[image_id] = batch_latents[i].detach().cpu()

        # Flush periodically so the in-memory buffer stays bounded.
        if len(latents_dict) >= FLUSH_THRESHOLD:
            update_pickle(LATENTS_FILE, latents_dict)
            latents_dict = {}

    # Flush whatever is left after the last batch.
    if latents_dict:
        update_pickle(LATENTS_FILE, latents_dict)
        latents_dict = {}

except Exception as e:
    print(f"Error: {str(e)}")
    # Save the latents gathered before the failure, then re-raise so the cell
    # is reported as failed and later cells do not run on partial output.
    if latents_dict:
        update_pickle(LATENTS_FILE, latents_dict)
        latents_dict = {}
    raise


  0%|          | 0/1510 [00:00<?, ?it/s]