In [1]:
import os
import numpy as np
import cv2
from PIL import Image, ImageEnhance, ImageFilter
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage import exposure, filters, morphology, segmentation, color
from skimage.restoration import denoise_bilateral, denoise_nl_means
from skimage.transform import rotate
from scipy import ndimage
import random
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Global matplotlib defaults: Times New Roman text, 24 pt base font size,
# and a figure resolution of 1000 dpi.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 24,
    'figure.dpi': 1000,
})

class HistopathologyPreprocessor:
    def __init__(self, input_dir, output_dir):
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.classes = ['normal', 'oscc']
        
        # Create output directories for each preprocessing technique
        self.techniques = [
            'original',
            'gaussian_blur',
            'bilateral_filter',
            'nlm_denoising',
            'histogram_equalization',
            'clahe',
            'gamma_correction',
            'contrast_enhancement',
            'brightness_adjustment',
            'color_normalization',
            'edge_enhancement',
            'morphological_opening',
            'morphological_closing',
            'gaussian_noise',
            'salt_pepper_noise',
            'rotation_augmentation',
            'stain_normalization',
            'median_filter',
            'unsharp_masking',
            'adaptive_threshold'
        ]
        
        self.create_output_directories()
    
    def create_output_directories(self):
        """Create output directories for each technique and class"""
        for technique in self.techniques:
            for class_name in self.classes:
                dir_path = os.path.join(self.output_dir, technique, class_name)
                os.makedirs(dir_path, exist_ok=True)
    
    def load_image(self, image_path):
        """Read an image file and return it with channels in RGB order.

        Args:
            image_path: Path of the image file to read.

        Returns:
            The decoded image converted from OpenCV's BGR order to RGB.

        Raises:
            OSError: If OpenCV could not read or decode the file.
        """
        image = cv2.imread(image_path)
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of inside cvtColor.
        if image is None:
            raise OSError(f"Could not read image: {image_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    def save_image(self, image, technique, class_name, filename):
        """Write an RGB image to ``output_dir/<technique>/<class_name>/<filename>``.

        Args:
            image: RGB image to store; it is converted to BGR for OpenCV.
            technique: Name of the technique sub-folder.
            class_name: Name of the class sub-folder.
            filename: File name (with extension) of the output image.

        Returns:
            The path the image was written to.

        Raises:
            OSError: If OpenCV reports that the file was not written.
        """
        output_path = os.path.join(self.output_dir, technique, class_name, filename)
        written = cv2.imwrite(output_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        # cv2.imwrite returns False (no exception) when the write fails,
        # e.g. because the target folder is missing or not writable.
        if not written:
            raise OSError(f"Could not write image: {output_path}")
        return output_path
    
    def gaussian_blur(self, image, kernel_size=5):
        """Smooth ``image`` with a square Gaussian kernel.

        Args:
            image: Image to blur.
            kernel_size: Side length of the square kernel.

        Returns:
            The blurred image.
        """
        window = (kernel_size, kernel_size)
        # A sigma of 0 is passed, leaving the choice of sigma to OpenCV.
        blurred = cv2.GaussianBlur(image, window, 0)
        return blurred
    
    def bilateral_filter(self, image, d=9, sigma_color=75, sigma_space=75):
        """Denoise ``image`` with an edge-preserving bilateral filter.

        Args:
            image: Image to filter.
            d: Pixel neighbourhood diameter passed to OpenCV.
            sigma_color: Filter sigma in colour space.
            sigma_space: Filter sigma in coordinate space.

        Returns:
            The filtered image.
        """
        smoothed = cv2.bilateralFilter(
            image,
            d=d,
            sigmaColor=sigma_color,
            sigmaSpace=sigma_space,
        )
        return smoothed
    
    def nlm_denoising(self, image):
        """Apply Non-Local Means denoising and return an 8-bit image.

        The image is scaled to floats in [0, 1] for scikit-image, denoised,
        and scaled back to 0-255.

        Args:
            image: Three-channel image with values in the 0-255 range.

        Returns:
            The denoised image as ``uint8``.
        """
        # Convert to float for skimage processing
        image_float = image.astype(np.float32) / 255.0

        # The keyword selecting the channel axis differs between
        # scikit-image releases, so try the known spellings in turn.
        try:
            # For newer versions of scikit-image (>= 0.19)
            denoised = denoise_nl_means(image_float, h=0.1, fast_mode=True, channel_axis=-1)
        except TypeError:
            try:
                # For older versions that still support multichannel
                denoised = denoise_nl_means(image_float, h=0.1, fast_mode=True, multichannel=True)
            except TypeError:
                # Fallback: process each channel separately
                denoised = np.zeros_like(image_float)
                for i in range(3):
                    denoised[:, :, i] = denoise_nl_means(image_float[:, :, i], h=0.1, fast_mode=True)

        # Round to nearest (a plain cast truncates, e.g. 254.9999 -> 254)
        # and clip so float error can never wrap around in uint8.
        return np.clip(np.rint(denoised * 255.0), 0, 255).astype(np.uint8)
    
    def histogram_equalization(self, image):
        """Equalize the histogram of each of the first three channels independently.

        Args:
            image: Image whose channels are equalized one by one.

        Returns:
            An array shaped like ``image`` holding the equalized channels.
        """
        equalized = np.zeros_like(image)
        for channel in (0, 1, 2):
            plane = image[:, :, channel]
            equalized[:, :, channel] = cv2.equalizeHist(plane)
        return equalized
    
    def clahe(self, image, clip_limit=2.0, tile_grid_size=(8, 8)):
        """Run CLAHE (Contrast Limited Adaptive Histogram Equalization) per channel.

        Args:
            image: Image whose first three channels are processed one by one.
            clip_limit: Contrast clip limit handed to OpenCV.
            tile_grid_size: Tile grid size handed to OpenCV.

        Returns:
            An array shaped like ``image`` holding the processed channels.
        """
        equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
        output = np.zeros_like(image)
        for channel in (0, 1, 2):
            plane = image[:, :, channel]
            output[:, :, channel] = equalizer.apply(plane)
        return output
    
    def gamma_correction(self, image, gamma=1.2):
        """Apply gamma correction"""
        gamma_corrected = np.power(image / 255.0, gamma)
        return (gamma_corrected * 255).astype(np.uint8)
    
    def contrast_enhancement(self, image, alpha=1.3):
        """Scale pixel intensities by ``alpha`` to change contrast.

        Args:
            image: Image to adjust.
            alpha: Multiplicative gain passed to ``cv2.convertScaleAbs``.

        Returns:
            The scaled image produced by ``cv2.convertScaleAbs``.
        """
        gain, offset = alpha, 0
        return cv2.convertScaleAbs(image, alpha=gain, beta=offset)
    
    def brightness_adjustment(self, image, beta=20):
        """Shift pixel intensities by ``beta`` to change brightness.

        Args:
            image: Image to adjust.
            beta: Additive offset passed to ``cv2.convertScaleAbs``.

        Returns:
            The shifted image produced by ``cv2.convertScaleAbs``.
        """
        gain, offset = 1.0, beta
        return cv2.convertScaleAbs(image, alpha=gain, beta=offset)
    
    def color_normalization(self, image):
        """Normalize color channels"""
        result = np.zeros_like(image, dtype=np.float32)
        for i in range(3):
            channel = image[:, :, i].astype(np.float32)
            mean = np.mean(channel)
            std = np.std(channel)
            result[:, :, i] = (channel - mean) / (std + 1e-8)
        
        # Rescale to 0-255
        result = ((result - result.min()) / (result.max() - result.min()) * 255).astype(np.uint8)
        return result
    
    def edge_enhancement(self, image):
        """Enhance edges by adding a Laplacian edge map to every channel.

        The edge map is the absolute Laplacian of the grayscale image; it is
        blended into each of the three channels with a weight of 0.3.

        Args:
            image: Three-channel RGB image.

        Returns:
            An array shaped like ``image`` holding the edge-enhanced channels.
        """
        # Edges are detected once on the grayscale image and reused per channel.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        laplacian = cv2.Laplacian(gray, cv2.CV_64F)
        # Saturate at 255 before the cast: a bare uint8 conversion wraps
        # modulo 256, which turns the strongest edges into weak ones.
        edges = np.clip(np.absolute(laplacian), 0, 255).astype(np.uint8)

        # Apply edge enhancement to each channel
        result = np.zeros_like(image)
        for i in range(3):
            result[:, :, i] = cv2.addWeighted(image[:, :, i], 1.0, edges, 0.3, 0)

        return result
    
    def morphological_opening(self, image, kernel_size=5):
        """Apply a morphological opening to each of the first three channels.

        Args:
            image: Image whose channels are processed one by one.
            kernel_size: Side length of the elliptical structuring element.

        Returns:
            An array shaped like ``image`` holding the opened channels.
        """
        footprint = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
        )
        opened = np.zeros_like(image)
        for channel in (0, 1, 2):
            plane = image[:, :, channel]
            opened[:, :, channel] = cv2.morphologyEx(plane, cv2.MORPH_OPEN, footprint)
        return opened
    
    def morphological_closing(self, image, kernel_size=5):
        """Apply a morphological closing to each of the first three channels.

        Args:
            image: Image whose channels are processed one by one.
            kernel_size: Side length of the elliptical structuring element.

        Returns:
            An array shaped like ``image`` holding the closed channels.
        """
        footprint = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
        )
        closed = np.zeros_like(image)
        for channel in (0, 1, 2):
            plane = image[:, :, channel]
            closed[:, :, channel] = cv2.morphologyEx(plane, cv2.MORPH_CLOSE, footprint)
        return closed
    
    def add_gaussian_noise(self, image, std=25):
        """Add Gaussian noise"""
        noise = np.random.normal(0, std, image.shape).astype(np.int16)
        noisy_image = image.astype(np.int16) + noise
        noisy_image = np.clip(noisy_image, 0, 255).astype(np.uint8)
        return noisy_image
    
    def add_salt_pepper_noise(self, image, noise_ratio=0.05):
        """Add salt and pepper noise"""
        result = image.copy()
        h, w, c = image.shape
        
        # Salt noise
        num_salt = np.ceil(noise_ratio * image.size * 0.5)
        coords = [np.random.randint(0, i - 1, int(num_salt)) for i in image.shape[:2]]
        result[coords[0], coords[1], :] = 255
        
        # Pepper noise
        num_pepper = np.ceil(noise_ratio * image.size * 0.5)
        coords = [np.random.randint(0, i - 1, int(num_pepper)) for i in image.shape[:2]]
        result[coords[0], coords[1], :] = 0
        
        return result
    
    def rotation_augmentation(self, image, angle=15):
        """Rotate ``image`` by ``angle`` about its centre, keeping the original size.

        Args:
            image: Image to rotate.
            angle: Rotation angle handed to ``cv2.getRotationMatrix2D``.

        Returns:
            The rotated image; regions outside the source are filled by
            reflecting the border.
        """
        height, width = image.shape[:2]
        pivot = (width // 2, height // 2)
        # Scale factor 1.0: rotate only, no zoom.
        matrix = cv2.getRotationMatrix2D(pivot, angle, 1.0)
        return cv2.warpAffine(
            image, matrix, (width, height), borderMode=cv2.BORDER_REFLECT
        )
    
    def stain_normalization(self, image):
        """Simple stain normalization by matching channel statistics in LAB.

        The image is converted to LAB and every channel is shifted and scaled
        to a fixed target mean and standard deviation (mean 128 for all
        channels; std 30 for the first channel, 20 for the other two). The
        values are clipped to 0-255 and converted back to RGB.

        Args:
            image: RGB image.

        Returns:
            The normalized image in RGB.
        """
        lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
        adjusted = np.zeros_like(lab, dtype=np.float32)

        target_mean = 128
        # Target spread per LAB channel, in channel order.
        for idx, target_std in enumerate((30, 20, 20)):
            plane = lab[:, :, idx].astype(np.float32)
            mu = np.mean(plane)
            sigma = np.std(plane)

            if sigma > 0:
                rescaled = (plane - mu) / sigma * target_std + target_mean
                adjusted[:, :, idx] = np.clip(rescaled, 0, 255)
            else:
                # A flat channel cannot be rescaled; keep it unchanged.
                adjusted[:, :, idx] = plane

        lab_uint8 = adjusted.astype(np.uint8)
        return cv2.cvtColor(lab_uint8, cv2.COLOR_LAB2RGB)
    
    def median_filter(self, image, kernel_size=5):
        """Median-filter each of the first three channels independently.

        Args:
            image: Image whose channels are filtered one by one.
            kernel_size: Aperture size passed to ``cv2.medianBlur``.

        Returns:
            An array shaped like ``image`` holding the filtered channels.
        """
        filtered = np.zeros_like(image)
        for channel in (0, 1, 2):
            plane = image[:, :, channel]
            filtered[:, :, channel] = cv2.medianBlur(plane, kernel_size)
        return filtered
    
    def unsharp_masking(self, image, strength=1.5, radius=1, threshold=0):
        """Sharpen ``image`` with unsharp masking.

        The result is ``image + strength * (image - blurred)``, where
        ``blurred`` is a Gaussian blur of the image with sigma ``radius``.

        Args:
            image: Image to sharpen.
            strength: Weight of the detail (image minus blur) that is added.
            radius: Gaussian sigma used for the blur.
            threshold: Minimum absolute difference between image and blur
                for a value to be sharpened; values below it are copied
                from the input unchanged. The default of 0 sharpens
                everything.

        Returns:
            The sharpened image.
        """
        # A (0, 0) kernel size lets OpenCV derive the kernel from the sigma.
        blurred = cv2.GaussianBlur(image, (0, 0), radius)

        # addWeighted saturates, so the result stays in the input's range.
        sharpened = cv2.addWeighted(image, 1 + strength, blurred, -strength, 0)

        if threshold > 0:
            # Compare in a signed type so the difference cannot wrap around.
            detail = np.abs(image.astype(np.int16) - blurred.astype(np.int16))
            # Leave low-contrast areas alone so flat-region noise is not amplified.
            sharpened = np.where(detail < threshold, image, sharpened)

        return sharpened
    
    def adaptive_threshold(self, image):
        """Binarize ``image`` with Gaussian adaptive thresholding.

        The image is converted to grayscale, thresholded, and expanded back
        to three channels.

        Args:
            image: RGB image.

        Returns:
            The thresholded image as a three-channel image.
        """
        block_size = 11  # neighbourhood size passed to OpenCV
        offset = 2       # constant C passed to OpenCV
        max_value = 255  # value assigned to pixels that pass the threshold

        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        binary = cv2.adaptiveThreshold(
            gray,
            max_value,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            block_size,
            offset,
        )
        return cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
    
    def process_image(self, image_path, filename):
        """Run every preprocessing technique on one image and save the results.

        The class name is taken from the name of the folder containing
        ``image_path``. A failure in a single technique is reported and
        skipped, so the remaining techniques still run.

        Args:
            image_path: Path of the image to process.
            filename: File name used for every saved variant.

        Returns:
            A dict mapping technique name (including ``'original'``) to the
            resulting image, for every technique that succeeded.
        """
        image = self.load_image(image_path)
        class_name = os.path.basename(os.path.dirname(image_path))

        # Dictionary to store processed images for visualization
        processed_images = {'original': image.copy()}

        # When the 'original' output folder is the input folder itself, the
        # target is the source file. Re-saving would re-encode it and
        # overwrite the dataset in place, so the write is skipped then.
        original_target = os.path.join(self.output_dir, 'original', class_name, filename)
        is_source_file = (
            os.path.exists(original_target)
            and os.path.samefile(original_target, image_path)
        )
        if not is_source_file:
            self.save_image(image, 'original', class_name, filename)

        # Technique name -> bound method; every method is called with its
        # default parameters.
        techniques_functions = {
            'gaussian_blur': self.gaussian_blur,
            'bilateral_filter': self.bilateral_filter,
            'nlm_denoising': self.nlm_denoising,
            'histogram_equalization': self.histogram_equalization,
            'clahe': self.clahe,
            'gamma_correction': self.gamma_correction,
            'contrast_enhancement': self.contrast_enhancement,
            'brightness_adjustment': self.brightness_adjustment,
            'color_normalization': self.color_normalization,
            'edge_enhancement': self.edge_enhancement,
            'morphological_opening': self.morphological_opening,
            'morphological_closing': self.morphological_closing,
            'gaussian_noise': self.add_gaussian_noise,
            'salt_pepper_noise': self.add_salt_pepper_noise,
            'rotation_augmentation': self.rotation_augmentation,
            'stain_normalization': self.stain_normalization,
            'median_filter': self.median_filter,
            'unsharp_masking': self.unsharp_masking,
            'adaptive_threshold': self.adaptive_threshold,
        }

        for technique_name, technique_func in techniques_functions.items():
            # Deliberately broad: one failing technique must not stop the
            # others, and the failure is reported rather than hidden.
            try:
                processed_img = technique_func(image)
                self.save_image(processed_img, technique_name, class_name, filename)
                processed_images[technique_name] = processed_img.copy()
            except Exception as e:
                print(f"Error applying {technique_name} to {filename}: {str(e)}")

        return processed_images
    
    def process_dataset(self):
        """Process the entire dataset"""
        sample_images = {}  # Store sample images for visualization
        
        for class_name in self.classes:
            class_dir = os.path.join(self.input_dir, class_name)
            if not os.path.exists(class_dir):
                print(f"Warning: Directory {class_dir} does not exist!")
                continue
            
            image_files = [f for f in os.listdir(class_dir) 
                          if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp'))]
            
            print(f"Processing {len(image_files)} images in {class_name} class...")
            
            # Process each image
            for filename in tqdm(image_files, desc=f"Processing {class_name}"):
                image_path = os.path.join(class_dir, filename)
                processed_images = self.process_image(image_path, filename)
                
                # Store the first image as sample for visualization
                if class_name not in sample_images:
                    sample_images[class_name] = {
                        'filename': filename,
                        'images': processed_images
                    }
        
        return sample_images
    
    def create_visualization_figure(self, sample_images, output_path, dpi=1000):
        """Create and save a grid comparing preprocessing techniques per class.

        The grid has one row per class and one column per selected
        technique. Cells for which no sample image is available stay blank.
        The figure is written to ``<output_path>.png`` and
        ``<output_path>.pdf`` and then shown.

        Args:
            sample_images: Dict mapping class name to a dict with an
                ``'images'`` entry, itself a dict of technique name to image
                (the structure returned by ``process_dataset``).
            output_path: Output file path without extension.
            dpi: Resolution used when saving. The default of 1000 on a
                figure this size produces a very large raster image; pass a
                lower value if saving is too slow or memory-hungry.
        """
        # Select a subset of techniques for visualization (to fit in the figure)
        selected_techniques = [
            'original', 'gaussian_blur', 'bilateral_filter', 'histogram_equalization',
            'clahe', 'gamma_correction', 'contrast_enhancement', 'edge_enhancement',
            'stain_normalization', 'unsharp_masking'
        ]

        n_techniques = len(selected_techniques)
        n_classes = len(self.classes)

        # squeeze=False always yields a 2-D axes array, including for a
        # single class or a single technique.
        fig, axes = plt.subplots(
            n_classes, n_techniques,
            figsize=(n_techniques * 4, n_classes * 4),
            squeeze=False,
        )
        fig.patch.set_facecolor('white')  # Set figure background to white

        for class_idx, class_name in enumerate(self.classes):
            sample = sample_images.get(class_name)
            if sample is None:
                continue

            for tech_idx, technique in enumerate(selected_techniques):
                img = sample['images'].get(technique)
                if img is None:
                    continue
                ax = axes[class_idx, tech_idx]
                ax.imshow(img)
                ax.set_title(
                    f"{technique.replace('_', ' ').title()}",
                    fontsize=20, fontname='Times New Roman'
                )

            # Add class label on the left
            axes[class_idx, 0].text(-0.1, 0.5, f"{class_name.upper()}",
                                   rotation=90, fontsize=24, fontname='Times New Roman',
                                   ha='center', va='center', transform=axes[class_idx, 0].transAxes)

        # Hide ticks, frames and grids on every cell, including cells that
        # received no image. Each axis is configured directly, not through
        # global rcParams, so other figures are unaffected.
        for ax in axes.flat:
            ax.axis('off')
            ax.grid(False)

        fig.tight_layout()

        # Save as PNG and PDF: transparent axes on a white figure background
        fig.savefig(f"{output_path}.png", dpi=dpi, bbox_inches='tight',
                    transparent=True, facecolor='white')
        fig.savefig(f"{output_path}.pdf", dpi=dpi, bbox_inches='tight',
                    transparent=True, facecolor='white')

        plt.show()
        plt.close(fig)

        print(f"Visualization figure saved as {output_path}.png and {output_path}.pdf")

# Dataset locations: class folders are read from input_dir and every
# technique's results are written below output_dir.
input_dir = "dataset/original"
output_dir = "dataset"

# Constructing the preprocessor also creates the output folder tree.
preprocessor = HistopathologyPreprocessor(input_dir, output_dir)

# Run every technique over every image of every class.
print("Starting preprocessing of oral cancer histopathology dataset...")
sample_images = preprocessor.process_dataset()

# Render the side-by-side comparison figure (saved as .png and .pdf).
print("Creating visualization figure...")
preprocessor.create_visualization_figure(
    sample_images=sample_images,
    output_path=os.path.join(output_dir, "preprocessing_techniques_comparison"),
)

# Final summary, one message per line.
print(
    "\nPreprocessing completed successfully!",
    f"Processed images saved in: {output_dir}",
    f"Applied {len(preprocessor.techniques)} different preprocessing techniques",
    sep="\n",
)

Starting preprocessing of oral cancer histopathology dataset...
Processing 89 images in normal class...


Processing normal: 100%|██████████| 89/89 [12:48<00:00,  8.63s/it]


Processing 439 images in oscc class...


Processing oscc: 100%|██████████| 439/439 [1:09:48<00:00,  9.54s/it]


Creating visualization figure...
