In [17]:
# Cell 1: Imports and Setup
import pandas as pd
from sentence_transformers import SentenceTransformer
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import pickle
from PIL import Image
import faiss  # Library for efficient similarity search
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define the folder path for images if needed
image_base_folder = ""  # Update this to your actual base folder path if needed

# Cell 2: Define Model Class and Preprocessing
class ImageTextSimilarityModel:
    def __init__(self, text_model_name='all-MiniLM-L6-v2', image_model_name='resnet50'):
        # Set up device
        self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

        # Load text and image models
        self.text_model = SentenceTransformer(text_model_name)
        self.image_model = models.resnet50(pretrained=True)
        self.image_model = torch.nn.Sequential(*(list(self.image_model.children())[:-1]))  # Remove last layer
        self.image_model.to(self.device)  # Move image model to the specified device

        # Other initializations
        self.index = None
        self.ecg_ids = None
        self.average_text_embedding = None  # To use as a proxy for text embeddings in inference

    def preprocess_image(self, image_path):
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        image = Image.open(image_path).convert("RGB")
        return preprocess(image).unsqueeze(0).to(self.device)  # Ensure image tensor is on the same device

    # Cell 3: Load Data and Compute Embeddings with Full Relative Path and Logging
    def compute_embeddings(self, df):
        logging.info("Computing text embeddings")
        text_embeddings = self.text_model.encode(df['report'].tolist(), convert_to_tensor=True)
        
        # Calculate average text embedding for inference proxy
        self.average_text_embedding = torch.mean(text_embeddings, dim=0).to(self.device)  # Move to device

        logging.info("Computing image embeddings")
        image_embeddings = []
        for relative_path in df['filename_lr']:  # Use full relative path to construct the image path
            image_path = f"{image_base_folder}{relative_path}.png"  # Construct full image path with .png
            
            # Log the image path being processed
            logging.info(f"Processing image: {image_path}")
            
            # Process and embed image
            image_tensor = self.preprocess_image(image_path)
            image_embedding = self.image_model(image_tensor).squeeze().detach()
            image_embeddings.append(image_embedding)

        # Concatenate image and text embeddings
        image_embeddings = torch.stack(image_embeddings)
        combined_embeddings = torch.cat((image_embeddings, text_embeddings.to(self.device)), dim=1).cpu().numpy()
        
        # Set ECG IDs and prepare FAISS index
        self.ecg_ids = df['ecg_id'].tolist()
        embedding_dim = combined_embeddings.shape[1]
        self.index = faiss.IndexFlatL2(embedding_dim)
        self.index.add(combined_embeddings)

    # Cell 4: Save and Load Model
    def save_model(self, embeddings_path='faiss_embeddings.pkl', faiss_index_path='faiss_index.bin'):
        with open(embeddings_path, 'wb') as f:
            pickle.dump({'ecg_ids': self.ecg_ids, 'average_text_embedding': self.average_text_embedding}, f)
        faiss.write_index(self.index, faiss_index_path)
        logging.info("Model and FAISS index saved successfully.")

    def load_model(self, embeddings_path='faiss_embeddings.pkl', faiss_index_path='faiss_index.bin'):
        with open(embeddings_path, 'rb') as f:
            data = pickle.load(f)
            self.ecg_ids = data['ecg_ids']
            self.average_text_embedding = data['average_text_embedding'].to(self.device)  # Move to device
        self.index = faiss.read_index(faiss_index_path)
        logging.info("Model and FAISS index loaded successfully.")

    # Cell 5: Inference using FAISS for L2 Distance
    def find_similar_text_reports_faiss(self, image_relative_path, top_k=5):
        image_path = f"{image_base_folder}{image_relative_path}.png"  # Construct full image path
        logging.info(f"Finding similar text reports for image (FAISS): {image_path}")
        
        # Process and embed the input image
        image_tensor = self.preprocess_image(image_path)  # Image tensor already on device
        image_embedding = self.image_model(image_tensor).squeeze().detach()
        
        # Concatenate with average text embedding for inference
        input_embedding = torch.cat((image_embedding, self.average_text_embedding.detach())).cpu().numpy()
        
        # Perform FAISS similarity search
        _, indices = self.index.search(np.array([input_embedding]), top_k)
        
        # Retrieve corresponding ECG IDs
        similar_reports = [(self.ecg_ids[i], i) for i in indices[0]]
        
        return similar_reports

    # Cell 6: Inference using Cosine Similarity
    def find_similar_text_reports_cosine(self, image_relative_path, top_k=5):
        image_path = f"{image_base_folder}{image_relative_path}.png"  # Construct full image path
        logging.info(f"Finding similar text reports for image (Cosine Similarity): {image_path}")
        
        # Process and embed the input image
        image_tensor = self.preprocess_image(image_path)  # Image tensor already on device
        image_embedding = self.image_model(image_tensor).squeeze().detach()
        
        # Concatenate with average text embedding for inference
        input_embedding = torch.cat((image_embedding, self.average_text_embedding.detach())).cpu().numpy()
        
        # Calculate cosine similarity for each embedding in the index
        similar_reports = []
        for i in range(self.index.ntotal):
            candidate_embedding = self.index.reconstruct(i)
            similarity_score = cosine_similarity([input_embedding], [candidate_embedding])[0][0]
            similar_reports.append((self.ecg_ids[i], similarity_score))
        
        # Sort by similarity score in descending order and return top_k results
        similar_reports = sorted(similar_reports, key=lambda x: x[1], reverse=True)[:top_k]
        
        return similar_reports

# Example Usage
if __name__ == "__main__":
    model = ImageTextSimilarityModel()
    df = pd.read_csv('ecg_reports_outpu100_filtered.csv')  # Load your CSV file with data
    
    # Compute and save embeddings
    # model.compute_embeddings(df)
    # model.save_model()

    # Load model for inference
    model.load_model()
    image_relative_path = "records100/00000/00004_lr"  # Input image relative path without .png extension
    
    # Find similar reports using FAISS
    print("Top 5 similar text reports using FAISS (L2 Distance):")
    similar_reports_faiss = model.find_similar_text_reports_faiss(image_relative_path)
    for ecg_id, score in similar_reports_faiss:
        print(f"ECG ID: {ecg_id}, Index: {score}")
    
    # Find similar reports using Cosine Similarity
    print("\nTop 5 similar text reports using Cosine Similarity:")
    similar_reports_cosine = model.find_similar_text_reports_cosine(image_relative_path)
    for ecg_id, similarity_score in similar_reports_cosine:
        print(f"ECG ID: {ecg_id}, Cosine Similarity Score: {similarity_score:.4f}")


2024-10-31 18:05:58,290 - INFO - Use pytorch device_name: mps
2024-10-31 18:05:58,291 - INFO - Load pretrained SentenceTransformer: all-MiniLM-L6-v2
2024-10-31 18:06:02,069 - INFO - Model and FAISS index loaded successfully.
2024-10-31 18:06:02,070 - INFO - Finding similar text reports for image (FAISS): records100/00000/00004_lr.png
2024-10-31 18:06:02,190 - INFO - Finding similar text reports for image (Cosine Similarity): records100/00000/00004_lr.png


Top 5 similar text reports using FAISS (L2 Distance):
ECG ID: 4, Index: 3
ECG ID: 624, Index: 611
ECG ID: 353, Index: 345
ECG ID: 79, Index: 78
ECG ID: 317, Index: 309

Top 5 similar text reports using Cosine Similarity:
ECG ID: 4, Cosine Similarity Score: 0.9998
ECG ID: 624, Cosine Similarity Score: 0.9970
ECG ID: 353, Cosine Similarity Score: 0.9970
ECG ID: 317, Cosine Similarity Score: 0.9969
ECG ID: 79, Cosine Similarity Score: 0.9969
