In [None]:
import sys
import os
from pathlib import Path

# Resolve the project root as the directory two levels above the current
# working directory, and make it importable so that project packages
# (e.g. the `src` package imported in later cells) can be found.
project_root = Path(os.path.abspath('')).parent.parent
# Guard the append so that re-running this cell does not pile up duplicate
# entries on sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from dotenv import load_dotenv
# Load variables from a .env file (if one is found) into the process
# environment so the os.getenv lookup below can see them.
load_dotenv()

# Dataset root directory as configured through the NIH_CXR14_DATASET_DIR
# environment variable; the value is None when the variable is not set.
nih_dataset_root_dir = os.environ.get("NIH_CXR14_DATASET_DIR")

# Directory that receives the files written by this notebook; create it up
# front so later cells can write into it without further checks.
main_output_dir = "../data"
Path(main_output_dir).mkdir(parents=True, exist_ok=True)


In [None]:
from src.datasets import NIHFindingLabels


# Build the label dataset from the notebook's output directory.
# NOTE(review): presumably this reads artifacts written by an earlier
# preprocessing step into main_output_dir — confirm against NIHFindingLabels.
nih_finding_labels = NIHFindingLabels.load_from_processed(main_output_dir)


# Sanity check: print the first record to inspect its structure. Later in this
# notebook, CustomDataset.__getitem__ unpacks each record into three values
# (binary_labels, image_id, and one ignored field).

sample = nih_finding_labels[0]

print(sample)


In [None]:
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    """Pair every image id with a caption listing its positive labels.

    The wrapped ``label_dataset`` must support ``len()``, integer indexing that
    yields a three-element record ``(binary_labels, image_id, <ignored>)``, and
    expose a ``label_counts`` mapping whose keys are the label names.

    NOTE(review): this assumes the key order of ``label_counts`` matches the
    positions in ``binary_labels`` — confirm against the label dataset.
    """

    def __init__(self, label_dataset):
        """Keep a reference to ``label_dataset`` and snapshot its label names."""
        self.label_dataset = label_dataset
        # Materialise the key view so label names can be looked up by position.
        self.str_labels = [*label_dataset.label_counts.keys()]

    def __len__(self):
        """Return the number of records in the wrapped label dataset."""
        return len(self.label_dataset)

    def __getitem__(self, idx):
        """Return ``(image_id, caption)`` for the record at ``idx``.

        The caption is the comma-separated list of label names whose flag is
        truthy; it is an empty string when no flag is set.
        """
        flags, image_id, _unused = self.label_dataset[idx]

        # Collect the names of every label that is switched on for this image.
        positive_names = [
            self.str_labels[position]
            for position, flag in enumerate(flags)
            if flag
        ]

        return image_id, ", ".join(positive_names)


    

In [None]:
import os
import torch

# Run on the GPU when one is available; otherwise fall back to the CPU instead
# of failing later when the model is moved to a device that does not exist.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 64
# Never request more loader workers than the host has CPU cores (upper bound 24).
num_workers = min(24, os.cpu_count() or 1)
# Keep the dataset order so batches are produced deterministically.
shuffle = False
# Pinned host memory only helps host-to-GPU copies, so enable it only with CUDA.
pin_memory = device == "cuda"

# Wrap the label records so that indexing yields (image_id, caption) pairs.
dataset = CustomDataset(nih_finding_labels)

# Batched iterator over those pairs, configured by the settings defined above.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size, shuffle=shuffle,
    num_workers=num_workers, pin_memory=pin_memory,
)

In [None]:
from src.pipelines import CLIPTextProcessor


# Text encoder wrapper created for the configured device; its encode_text()
# method is what the embedding loop later in this notebook calls on each batch
# of captions.
clip_text_processor = CLIPTextProcessor(device=device)

In [None]:
import pickle
import os

def update_pickle(pickle_file, data):
    """Merge ``data`` into the dictionary pickled at ``pickle_file``.

    If ``pickle_file`` exists, its content is loaded, updated with ``data``
    (new keys win over stored ones) and written back. Otherwise ``data`` itself
    is written. The result is first written to a sibling ``<pickle_file>.tmp``
    file and then moved over the target, so a failure while serialising never
    truncates or corrupts a previously saved file.

    Args:
        pickle_file: Path of the pickle file to create or update.
        data: Mapping with the entries to add.

    Returns:
        True when the file was written successfully.

    Raises:
        Exception: If reading, merging or writing fails; the original error is
            attached as ``__cause__``.
    """
    tmp_path = f"{pickle_file}.tmp"

    try:
        if os.path.exists(pickle_file):
            # SECURITY: pickle.load can execute arbitrary code; only use this
            # on files produced by this notebook, never on untrusted input.
            with open(pickle_file, 'rb') as f:
                merged = pickle.load(f)
            merged.update(data)
        else:
            # Nothing stored yet, so the new data is the whole content.
            merged = data

        # Serialise to a temporary file first; opening the target directly with
        # 'wb' would truncate it before the dump is known to succeed.
        with open(tmp_path, 'wb') as f:
            pickle.dump(merged, f)
        os.replace(tmp_path, pickle_file)

        return True

    except Exception as e:
        raise Exception(f"Error updating pickle file: {str(e)}") from e

    finally:
        # Best-effort cleanup of a leftover temporary file after a failure;
        # after a successful replace the path no longer exists.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass

In [None]:
from tqdm.notebook import tqdm
import os
import torch

# Destination for the embeddings: a pickled dict keyed by image id.
output_file_path = os.path.join(main_output_dir, "clip_text_embeds.pkl")
# Flush the in-memory buffer to disk once it holds at least this many entries.
save_every = 1000
text_embeds = {}

try:
    for image_ids, sentences in tqdm(dataloader):
        # Generate embeddings. This is inference only, so skip autograd
        # bookkeeping while encoding.
        with torch.no_grad():
            embeds = clip_text_processor.encode_text(sentences)
        embeds = embeds.detach().cpu().numpy()

        # Store embeddings for each image
        for i, image_id in enumerate(image_ids):
            text_embeds[image_id] = embeds[i]

        # Save periodically to avoid data loss. Compare with >= rather than
        # testing for an exact multiple: the buffer grows by a whole batch at a
        # time, so its size usually steps over the threshold without hitting it.
        if len(text_embeds) >= save_every:
            print(f"\nSaving {len(text_embeds)} embeddings...")
            update_pickle(output_file_path, text_embeds)
            text_embeds = {}  # Clear memory

    # Save any remaining embeddings
    if text_embeds:
        print(f"\nSaving final {len(text_embeds)} embeddings...")
        update_pickle(output_file_path, text_embeds)

except Exception as e:
    print(f"Error in generating text embeddings: {str(e)}")

    # Emergency save of any processed embeddings that were not flushed yet.
    # They go to a separate file so a broken main file cannot block the save.
    if text_embeds:
        emergency_path = os.path.join(main_output_dir, "clip_text_embeds_emergency.pkl")
        print(f"Attempting emergency save to {emergency_path}")
        try:
            update_pickle(emergency_path, text_embeds)
        except Exception as save_error:
            print(f"Emergency save failed: {str(save_error)}")
    # Re-raise the active exception unchanged so the original traceback is kept.
    raise