In [1]:
import os
import sys

# Make the parent of the working directory importable so that `from src...`
# resolves. The membership check keeps re-running this cell from piling up
# duplicate entries on sys.path.
PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from dotenv import load_dotenv

# Load settings (e.g. NIH_CXR14_DATASET_DIR, read further down) from a .env file.
load_dotenv()

# Image size setting for this notebook; same value as the img_size given to NIHDataset.
IMAGE_SIZE = 224


# Data Analysis
This section aims to analyze the NIH-CXR14 dataset through these operations:

* **NIHDataset**: Loads tabular data and image paths from NIH-CXR14 dataset 

* **Field Filtering**: Selects essential columns `["Image Index", "Finding Labels", "Image Path"]`

* **Balance Adjustment**: Handles class imbalance using the `Finding Labels` field

* **Label Exclusion**: Removes samples carrying specific labels from the training data, so those labels can later be used to test the model on classes it has never seen

In [2]:
from src.datasets import NIHDataset

In [3]:
# Fail early with an actionable message when the dataset location is not
# configured; otherwise None would be handed to NIHDataset as root_dir.
nih_root_dir = os.getenv("NIH_CXR14_DATASET_DIR")
if nih_root_dir is None:
    raise EnvironmentError(
        "NIH_CXR14_DATASET_DIR is not set; define it in the environment or in .env"
    )

# Reuse the notebook-wide IMAGE_SIZE constant instead of repeating the literal 224.
nih_dataset = NIHDataset(root_dir=nih_root_dir, img_size=IMAGE_SIZE)

In [4]:
# Number of entries the loaded dataset reports via len().
len(nih_dataset)

112120

In [5]:
# List the columns available on the full dataset, before any column selection.
fields = nih_dataset.get_fields()
print(fields)

['Image Index', 'Finding Labels', 'Follow-up #', 'Patient ID', 'Patient Age', 'Patient Gender', 'View Position', 'OriginalImage[Width', 'Height]', 'OriginalImagePixelSpacing[x', 'y]', 'Unnamed: 11', 'Image Path']


In [6]:
# Narrow the dataset down to the three essential columns.
essential_columns = ["Image Index", "Finding Labels", "Image Path"]
field_filtered_dataset = nih_dataset.select_columns(essential_columns)

In [7]:
# Confirm the column selection: list the columns of the filtered dataset.
fields = field_filtered_dataset.get_fields()
print(fields)

['Image Index', 'Finding Labels', 'Image Path']


In [8]:
# Per-label sample counts on the unfiltered dataset, as a baseline for the
# balancing steps that follow.
label_counts = nih_dataset.get_label_counts()
print(label_counts)

{'Effusion': 13317, 'Emphysema': 2516, 'Infiltration': 19894, 'Pleural_Thickening': 3385, 'No Finding': 60361, 'Fibrosis': 1686, 'Mass': 5782, 'Atelectasis': 11559, 'Pneumothorax': 5302, 'Consolidation': 4667, 'Nodule': 6331, 'Cardiomegaly': 2776, 'Edema': 2303, 'Pneumonia': 1431, 'Hernia': 227}


In [9]:
# Cap the number of "No Finding" samples, by far the most frequent label.
capped_label = "No Finding"
capped_label_limit = 20000
dataset = field_filtered_dataset.limit_samples(
    label=capped_label,
    max_samples=capped_label_limit,
)

# Re-count the labels to check the effect of the cap.
label_counts = dataset.get_label_counts()
print(label_counts)

{'No Finding': 20000, 'Effusion': 13317, 'Emphysema': 2516, 'Infiltration': 19894, 'Pleural_Thickening': 3385, 'Fibrosis': 1686, 'Mass': 5782, 'Atelectasis': 11559, 'Pneumothorax': 5302, 'Consolidation': 4667, 'Nodule': 6331, 'Cardiomegaly': 2776, 'Edema': 2303, 'Pneumonia': 1431, 'Hernia': 227}


In [10]:
# Hold out a few disease labels completely, so that we can later check whether
# the model is able to predict labels it was never trained on.
exclude_labels = ["Hernia", "Atelectasis"]

# Apply the exclusions one label at a time; each pass narrows `dataset` further.
for label in exclude_labels:
    dataset = dataset.filter_by_label(label, exclude=True)

# Re-count the labels that remain after the exclusions.
label_counts = dataset.get_label_counts()
print(label_counts)

{'No Finding': 20000, 'Effusion': 10027, 'Emphysema': 2088, 'Infiltration': 16602, 'Pleural_Thickening': 2882, 'Fibrosis': 1459, 'Mass': 5020, 'Pneumothorax': 4525, 'Consolidation': 3440, 'Nodule': 5734, 'Cardiomegaly': 2399, 'Edema': 2080, 'Pneumonia': 1166}


## Creating Image-Text Pair Dataset
**These pairs will be used to feed the `vae` and `clip-text` models.**

In [11]:
from PIL import Image

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np

# Preprocessing applied to every PIL image:
#   1. random horizontal flip (torchvision's default probability),
#   2. conversion to a tensor,
#   3. normalisation with mean 0.5 and std 0.5.
# NOTE(review): there is no resize step, so images keep the resolution stored
# on disk even though IMAGE_SIZE is defined at the top — confirm this is intended.
# NOTE(review): the flip is random and unseeded, and its outcome is baked into
# the latents stored later in this notebook — confirm this is intended.
_image_transform_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
image_transform = transforms.Compose(_image_transform_steps)


class IMGTEXTDataset(Dataset):
    """Pairs each image of an NIH dataset with a sentence built from its labels.

    Every item is a tuple ``(image, sentence)``: ``image`` is the RGB image
    loaded from the sample's ``"Image Path"`` (after the optional transforms)
    and ``sentence`` is the sample's ``"Finding Labels"`` joined with spaces.
    """

    def __init__(self, dataset, transforms=None):
        """
        Args:
            dataset: The NIH dataset instance. It must support ``len()`` and
                integer indexing that yields a mapping with the keys
                ``"Image Path"`` and ``"Finding Labels"``.
            transforms: Optional callable applied to the PIL image
                (for example a torchvision ``Compose``).
        """
        self.dataset = dataset
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    @staticmethod
    def _labels_to_sentence(labels):
        """Build the space-separated sentence for one sample's labels.

        Args:
            labels: An iterable of label strings, or a single string holding
                one or more labels separated by ``|``.

        Returns:
            str: The labels joined with single spaces.
        """
        # A bare string would otherwise be joined character by character
        # ("Mass" -> "M a s s"), so split it into labels first.
        if isinstance(labels, str):
            labels = labels.split("|")
        return " ".join(labels)

    def __getitem__(self, idx):
        """Load the sample at ``idx`` and return ``(image, sentence)``.

        Args:
            idx: Index forwarded to the wrapped dataset.

        Returns:
            tuple: The (optionally transformed) RGB image and the label sentence.
        """
        sample = self.dataset[idx]

        # Get image path and labels
        image_path = sample["Image Path"]
        labels = sample["Finding Labels"]

        # Open the file inside a context manager so its handle is released as
        # soon as the pixel data has been converted, instead of staying open
        # until garbage collection. The converted image is an independent copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Apply transforms if any (directly on the PIL image)
        if self.transforms is not None:
            image = self.transforms(image)

        return image, self._labels_to_sentence(labels)




# Creating Latent-TextEmbed Pair Dataset

**These will be used when training the diffusion model.**

In [12]:
from src.pipelines import VaeProcessor, CLIPTextProcessor



# Device handed to both processors. Defaults to the original "cuda:2"; set the
# TORCH_DEVICE environment variable (e.g. in .env) to use another GPU or the
# CPU without editing the notebook.
device = os.getenv("TORCH_DEVICE", "cuda:2")


vae_processor = VaeProcessor(device=device)  # optional vae model can be passed as argument
clip_processor = CLIPTextProcessor(device=device)






In [13]:
import pickle
import torch
from torch.utils.data import DataLoader
from pathlib import Path

BATCH_SIZE = 16
# Number of samples between progress reports / checkpoints. The original
# `current_idx % 100 == 0` test only fired when the running count was a
# multiple of both 100 and the batch size (every 400 samples at batch size 16);
# an explicit threshold keeps that cadence and no longer depends on the batch size.
CHECKPOINT_EVERY = 400

# Use batch processing for efficiency
dataloader = DataLoader(
    IMGTEXTDataset(dataset, image_transform),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Create directory if it doesn't exist
save_dir = Path("data")
save_dir.mkdir(parents=True, exist_ok=True)
embeddings_path = save_dir / "embeddings.pkl"


def _save_embeddings(payload, target_path):
    """Pickle ``payload`` to ``target_path`` via a temporary sibling file.

    The data is first written to ``<name>.tmp`` and then moved over the target,
    so an interruption in the middle of a write leaves the previous checkpoint
    intact instead of a truncated pickle.
    """
    tmp_path = target_path.with_name(target_path.name + ".tmp")
    with open(tmp_path, "wb") as f:
        pickle.dump(payload, f)
    tmp_path.replace(target_path)


# Initialize storage dictionary: running sample index -> numpy array
embeddings = {
    'latents': {},
    'texts': {}
}

current_idx = 0
next_checkpoint = CHECKPOINT_EVERY

# Encoding is inference only, so skip building autograd graphs.
with torch.no_grad():
    for images, texts in dataloader:
        images = images.to(vae_processor.device)

        # Get text embeddings
        text_embeddings = clip_processor.encode_text(texts)

        # Get image latents
        image_latents = vae_processor.prepare_latent(images)

        # Store individual samples from the batch
        for i in range(len(images)):
            embeddings['latents'][current_idx] = image_latents[i].detach().cpu().numpy()
            embeddings['texts'][current_idx] = text_embeddings[i].detach().cpu().numpy()
            current_idx += 1

        if current_idx >= next_checkpoint:
            print(f"Processed {current_idx} samples")
            # Periodic checkpoint so an interrupted run keeps its progress.
            # NOTE: this does not free memory; every array stays in
            # `embeddings` and the whole dictionary is re-written each time.
            _save_embeddings(embeddings, embeddings_path)
            while next_checkpoint <= current_idx:
                next_checkpoint += CHECKPOINT_EVERY

# Final save
_save_embeddings(embeddings, embeddings_path)

print(f"Total samples processed: {current_idx}")

Processed 400 samples
Processed 800 samples
Processed 1200 samples
Processed 1600 samples
Processed 2000 samples
Processed 2400 samples
Processed 2800 samples
Processed 3200 samples
Processed 3600 samples
Processed 4000 samples
Processed 4400 samples
Processed 4800 samples
Processed 5200 samples
Processed 5600 samples
Processed 6000 samples
Processed 6400 samples
Processed 6800 samples
Processed 7200 samples
Processed 7600 samples
Processed 8000 samples
Processed 8400 samples
Processed 8800 samples
Processed 9200 samples
Processed 9600 samples
Processed 10000 samples
Processed 10400 samples
Processed 10800 samples
Processed 11200 samples
Processed 11600 samples
Processed 12000 samples
Processed 12400 samples
Processed 12800 samples
Processed 13200 samples
Processed 13600 samples
Processed 14000 samples
Processed 14400 samples
Processed 14800 samples
Processed 15200 samples
Processed 15600 samples
Processed 16000 samples
Processed 16400 samples
Processed 16800 samples
Processed 17200 sa

: 