# Optimized MoS2 Flake Detection Pipeline

## Improvements in This Version:
- **Multiple color space analysis** (HSV, LAB, RGB)
- **Adaptive thresholding** for varying lighting conditions
- **Contrast enhancement** specifically for microscopy images
- **Multiple detection methods** with result combination
- **HSV parameter tuning** cell for comparing candidate color ranges on your own images

Upload your images to `/content/images/` and run the cells below in order (later cells use objects defined by earlier ones).

In [None]:
# Install required packages
!pip install opencv-python-headless matplotlib numpy scipy scikit-image
!pip install plotly ipywidgets

import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
from scipy import ndimage
from skimage import measure, morphology, filters, segmentation, feature
from skimage.color import rgb2lab, lab2rgb
import os
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Create the input-image, results and debug-figure folders used by later cells.
# exist_ok=True makes re-running this cell harmless.
for _workdir in ('/content/images', '/content/results', '/content/debug'):
    os.makedirs(_workdir, exist_ok=True)

print("✓ Environment setup complete")

In [None]:
def analyze_image_colors(image_path, sample_points=None):
    """Plot an image's HSV and LAB channels and print each channel's value range.

    Intended as a visual aid for choosing detection thresholds such as the
    HSV ranges used by the detector class below.

    Args:
        image_path: Path (str) to an image file readable by ``cv2.imread``.
        sample_points: Currently unused; accepted for backward compatibility.

    Returns:
        Tuple ``(img_rgb, hsv, lab)``: the image converted from OpenCV's BGR
        order to RGB, and its HSV and LAB conversions.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise surface
            as an obscure error inside ``cv2.cvtColor``.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image (missing or unsupported format): {image_path}")
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Convert to the colour spaces the detectors work in.
    hsv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV)
    lab = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2LAB)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Original image
    axes[0, 0].imshow(img_rgb)
    axes[0, 0].set_title('Original Image')
    axes[0, 0].axis('off')

    # HSV channels
    axes[0, 1].imshow(hsv[:, :, 0], cmap='hsv')
    axes[0, 1].set_title('HSV - Hue')
    axes[0, 1].axis('off')

    axes[0, 2].imshow(hsv[:, :, 1], cmap='gray')
    axes[0, 2].set_title('HSV - Saturation')
    axes[0, 2].axis('off')

    axes[1, 0].imshow(hsv[:, :, 2], cmap='gray')
    axes[1, 0].set_title('HSV - Value')
    axes[1, 0].axis('off')

    # LAB chroma channels (the L channel is not plotted)
    axes[1, 1].imshow(lab[:, :, 1], cmap='RdYlGn_r')
    axes[1, 1].set_title('LAB - A channel (Green-Red)')
    axes[1, 1].axis('off')

    axes[1, 2].imshow(lab[:, :, 2], cmap='YlOrBr_r')
    axes[1, 2].set_title('LAB - B channel (Blue-Yellow)')
    axes[1, 2].axis('off')

    plt.tight_layout()
    plt.show()

    # Print min-max per channel to help pick thresholds.
    print("=== COLOR ANALYSIS ===")
    print("HSV ranges:")
    print(f"  Hue: {hsv[:,:,0].min()}-{hsv[:,:,0].max()}")
    print(f"  Saturation: {hsv[:,:,1].min()}-{hsv[:,:,1].max()}")
    print(f"  Value: {hsv[:,:,2].min()}-{hsv[:,:,2].max()}")
    print("LAB ranges:")
    print(f"  L: {lab[:,:,0].min()}-{lab[:,:,0].max()}")
    print(f"  A: {lab[:,:,1].min()}-{lab[:,:,1].max()}")
    print(f"  B: {lab[:,:,2].min()}-{lab[:,:,2].max()}")

    return img_rgb, hsv, lab

# Analyze the first image if available.
# Path.glob yields files in arbitrary filesystem order, so sort to make
# "first" reproducible; match extensions case-insensitively and use the same
# extension set as the batch-detection cell.
image_files = sorted(
    p for p in Path('/content/images').iterdir()
    if p.is_file() and p.suffix.lower() in {'.png', '.jpg', '.jpeg', '.tiff', '.bmp'}
)
if image_files:
    print(f"Analyzing colors in: {image_files[0].name}")
    img_rgb, hsv, lab = analyze_image_colors(str(image_files[0]))
else:
    print("No images found. Please upload images to /content/images/ first.")

In [None]:
class OptimizedMoS2Analyzer:
    """Detect MoS2 flakes with several independent methods and merge the results.

    Four detectors (HSV colour ranges, LAB channel thresholds, Canny edges,
    adaptive threshold) each build a binary mask. Shape-filtered contours from
    every mask are pooled and de-duplicated by centroid distance.
    """

    def __init__(self):
        # Neither attribute is read by any method in this class.
        self.debug_mode = True
        self.results = {}

    @staticmethod
    def _read_rgb(image_path):
        """Load ``image_path`` with OpenCV and return it in RGB channel order.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file. ``cv2.imread``
                returns ``None`` rather than raising for missing/unsupported files.
        """
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image (missing or unsupported format): {image_path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def preprocess_image(self, image_path):
        """Load an image and build three contrast-enhanced RGB variants.

        Args:
            image_path: Path (str) to an image readable by ``cv2.imread``.

        Returns:
            ``(img_rgb, processed_images)``. ``processed_images`` maps
            ``'clahe'``, ``'hist_eq'`` and ``'gamma'`` to enhanced RGB images.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        img_rgb = self._read_rgb(image_path)
        processed_images = {}

        # 1. CLAHE on the LAB lightness channel only, leaving A/B (colour) intact.
        lab = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        l_enhanced = clahe.apply(l)
        enhanced_lab = cv2.merge([l_enhanced, a, b])
        processed_images['clahe'] = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)

        # 2. Global histogram equalization of the YUV luma channel.
        yuv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2YUV)
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])
        processed_images['hist_eq'] = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB)

        # 3. Gamma correction on [0, 1]-normalised intensities; with gamma > 1
        #    x**gamma < x, so mid-tones are darkened.
        gamma = 1.2
        processed_images['gamma'] = (np.power(img_rgb / 255.0, gamma) * 255).astype(np.uint8)

        return img_rgb, processed_images

    def detect_flakes_multiple_methods(self, image, enhanced_images):
        """Run all four detectors and merge their flakes.

        Args:
            image: Original RGB image (not used by the detectors themselves).
            enhanced_images: Dict from ``preprocess_image`` with keys
                ``'clahe'``, ``'gamma'`` and ``'hist_eq'``.

        Returns:
            ``(combined_flakes, debug_masks)``: merged flake dicts (each tagged
            with ``'detection_method'``) and a dict of each method's binary mask.
        """
        all_flakes = []
        debug_masks = {}

        # Method 1: HSV color-based detection
        flakes1, mask1 = self.detect_hsv_method(enhanced_images['clahe'])
        all_flakes.extend([(f, 'hsv') for f in flakes1])
        debug_masks['hsv'] = mask1

        # Method 2: LAB color space detection
        flakes2, mask2 = self.detect_lab_method(enhanced_images['clahe'])
        all_flakes.extend([(f, 'lab') for f in flakes2])
        debug_masks['lab'] = mask2

        # Method 3: Edge-based detection
        flakes3, mask3 = self.detect_edge_method(enhanced_images['gamma'])
        all_flakes.extend([(f, 'edge') for f in flakes3])
        debug_masks['edge'] = mask3

        # Method 4: Adaptive threshold method
        flakes4, mask4 = self.detect_adaptive_method(enhanced_images['hist_eq'])
        all_flakes.extend([(f, 'adaptive') for f in flakes4])
        debug_masks['adaptive'] = mask4

        combined_flakes = self.combine_detections(all_flakes)
        return combined_flakes, debug_masks

    def detect_hsv_method(self, image):
        """Mask pixels falling in any of several blue/purple HSV ranges.

        Returns ``(flakes, mask)`` where ``mask`` is a 0/255 uint8 image.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

        # ([h, s, v] lower, [h, s, v] upper); OpenCV 8-bit hue spans 0-179.
        color_ranges = [
            # Blue range (primary)
            ([100, 30, 30], [140, 255, 255]),
            # Purple range
            ([140, 30, 30], [170, 255, 255]),
            # Dark blue range
            ([90, 50, 20], [120, 255, 200]),
            # Light purple range
            ([120, 20, 50], [160, 150, 255])
        ]

        combined_mask = np.zeros(hsv.shape[:2], dtype=np.uint8)
        for lower, upper in color_ranges:
            mask = cv2.inRange(hsv, np.array(lower), np.array(upper))
            combined_mask = cv2.bitwise_or(combined_mask, mask)

        # Close small gaps, then remove isolated specks.
        kernel = np.ones((2, 2), np.uint8)
        combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel)
        combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_OPEN, kernel)

        return self.extract_triangular_contours(combined_mask), combined_mask

    def detect_lab_method(self, image):
        """Mask pixels with low LAB B (bluish) or high LAB A (magenta/reddish).

        In OpenCV's 8-bit LAB, A and B are offset by 128, so B < 128 leans blue
        and A > 128 leans magenta/red.
        Returns ``(flakes, mask)`` where ``mask`` is a 0/255 uint8 image.
        """
        lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)

        b_channel = lab[:, :, 2]
        blue_mask = b_channel < 120  # tunable threshold

        a_channel = lab[:, :, 1]
        purple_mask = a_channel > 135  # tunable threshold

        combined_mask = np.logical_or(blue_mask, purple_mask).astype(np.uint8) * 255

        kernel = np.ones((3, 3), np.uint8)
        combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel)
        combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_OPEN, kernel)

        return self.extract_triangular_contours(combined_mask), combined_mask

    def detect_edge_method(self, image):
        """Fill the outer contours of combined Canny edge maps.

        Returns ``(flakes, mask)`` where ``mask`` holds filled regions whose
        outer edge contour encloses more than 50 px².
        """
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        # Union of a low- and a mid-threshold Canny pass.
        edges1 = cv2.Canny(gray, 30, 100)
        edges2 = cv2.Canny(gray, 50, 150)
        edges = cv2.bitwise_or(edges1, edges2)

        # Dilate so nearby edge fragments join into closed outlines.
        kernel = np.ones((3, 3), np.uint8)
        edges = cv2.dilate(edges, kernel, iterations=1)

        contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        filled_mask = np.zeros_like(edges)
        for contour in contours:
            if cv2.contourArea(contour) > 50:  # drop tiny outlines
                cv2.fillPoly(filled_mask, [contour], 255)

        return self.extract_triangular_contours(filled_mask), filled_mask

    def detect_adaptive_method(self, image):
        """Mask locally-dark pixels via inverted Gaussian adaptive thresholding.

        Returns ``(flakes, mask)`` where ``mask`` is a 0/255 uint8 image.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)

        # Block size 11, constant 2; THRESH_BINARY_INV marks pixels darker than
        # their neighbourhood as foreground.
        adaptive_mask = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                              cv2.THRESH_BINARY_INV, 11, 2)

        kernel = np.ones((3, 3), np.uint8)
        adaptive_mask = cv2.morphologyEx(adaptive_mask, cv2.MORPH_CLOSE, kernel)
        adaptive_mask = cv2.morphologyEx(adaptive_mask, cv2.MORPH_OPEN, kernel)

        return self.extract_triangular_contours(adaptive_mask), adaptive_mask

    def extract_triangular_contours(self, mask, min_area=50):
        """Return outer contours of ``mask`` that pass a loose shape filter.

        A contour is kept when its area is at least ``min_area`` px² and either
        its polygon approximation has 3-6 vertices, or its solidity
        (area / convex-hull area) is in (0.6, 1.0) with bounding-box aspect
        ratio below 3. This accepts more than strictly triangular shapes.

        Args:
            mask: Single-channel binary uint8 image.
            min_area: Minimum contour area in pixels (default 50).

        Returns:
            List of dicts with keys ``contour``, ``approx``, ``area``,
            ``vertices``, ``solidity`` and ``aspect_ratio``.
        """
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        triangular_flakes = []
        for contour in contours:
            area = cv2.contourArea(contour)
            if area < min_area:
                continue

            perimeter = cv2.arcLength(contour, True)
            if perimeter == 0:
                continue

            hull_area = cv2.contourArea(cv2.convexHull(contour))
            if hull_area == 0:  # guard the solidity division
                continue

            # Tolerance of 1% of the perimeter keeps fairly fine polygon detail.
            approx = cv2.approxPolyDP(contour, 0.01 * perimeter, True)
            aspect_ratio = self.calculate_aspect_ratio(contour)
            solidity = area / hull_area

            is_triangular = (
                (3 <= len(approx) <= 6) or
                (0.6 < solidity < 1.0 and aspect_ratio < 3.0)
            )

            if is_triangular:
                triangular_flakes.append({
                    'contour': contour,
                    'approx': approx,
                    'area': area,
                    'vertices': len(approx),
                    'solidity': solidity,
                    'aspect_ratio': aspect_ratio
                })

        return triangular_flakes

    def calculate_aspect_ratio(self, contour):
        """Return long side / short side of the contour's bounding box (1 if degenerate)."""
        x, y, w, h = cv2.boundingRect(contour)
        return max(w, h) / min(w, h) if min(w, h) > 0 else 1

    def combine_detections(self, all_flakes, duplicate_distance=30):
        """Merge flakes from all methods, collapsing near-coincident centroids.

        Each flake dict is tagged in place with ``'detection_method'``. Flakes
        with zero contour area moment (undefined centroid) are dropped. When a
        flake's centroid is closer than ``duplicate_distance`` pixels to an
        already kept flake, only the larger-area one is kept.

        Args:
            all_flakes: List of ``(flake_dict, method_name)`` pairs.
            duplicate_distance: Centroid distance (pixels) below which two
                flakes are treated as the same flake (default 30).

        Returns:
            List of de-duplicated flake dicts.
        """
        if not all_flakes:
            return []

        candidates = []
        for flake_data, method in all_flakes:
            flake_data['detection_method'] = method
            M = cv2.moments(flake_data['contour'])
            if M['m00'] == 0:
                continue
            candidates.append((flake_data, (M['m10'] / M['m00'], M['m01'] / M['m00'])))

        return self._deduplicate_by_centroid(candidates, duplicate_distance)

    @staticmethod
    def _deduplicate_by_centroid(candidates, max_distance):
        """Keep the larger of any two flakes whose centroids are < ``max_distance`` apart.

        ``candidates`` is a list of ``(flake_dict, (cx, cy))``. A flake that
        replaces a smaller duplicate moves to the end of the result.
        Replacement is done by index: ``list.remove`` would compare flake dicts
        with ``==``, which compares their numpy contour arrays and raises
        ``ValueError`` (ambiguous truth value / shape mismatch).
        """
        kept = []  # list of (flake_dict, (cx, cy)); centroids cached once
        for flake, (cx, cy) in candidates:
            for idx, (existing, (ex, ey)) in enumerate(kept):
                if np.hypot(cx - ex, cy - ey) < max_distance:
                    if flake['area'] > existing['area']:
                        del kept[idx]
                        kept.append((flake, (cx, cy)))
                    break  # duplicate handled either way
            else:
                kept.append((flake, (cx, cy)))
        return [flake for flake, _ in kept]

    def visualize_debug_results(self, image, flakes, debug_masks):
        """Build a 3x3 figure: original, the four method masks, and the merged result.

        Returns the matplotlib figure (not shown or saved here).
        """
        fig, axes = plt.subplots(3, 3, figsize=(18, 15))
        # Only six panels are used. Hide every axis up front so the unused
        # ones (1,2), (2,0), (2,1) render blank rather than as empty plots.
        for ax in axes.flat:
            ax.axis('off')

        axes[0, 0].imshow(image)
        axes[0, 0].set_title('Original Image')

        # Masks fill panels (0,1), (0,2), (1,0), (1,1) in order.
        mask_titles = ['HSV Method', 'LAB Method', 'Edge Method', 'Adaptive Method']
        mask_keys = ['hsv', 'lab', 'edge', 'adaptive']
        for i, (key, title) in enumerate(zip(mask_keys, mask_titles)):
            row, col = divmod(i + 1, 3)
            if key in debug_masks:
                axes[row, col].imshow(debug_masks[key], cmap='gray')
            axes[row, col].set_title(title)

        # Draw each flake's contour and "<index>(<method>)" label at its centroid.
        result_img = image.copy()
        colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255)]
        for i, flake in enumerate(flakes):
            color = colors[i % len(colors)]
            cv2.drawContours(result_img, [flake['contour']], -1, color, 2)

            M = cv2.moments(flake['contour'])
            if M['m00'] != 0:
                cx = int(M['m10'] / M['m00'])
                cy = int(M['m01'] / M['m00'])
                cv2.putText(result_img, f"{i+1}({flake['detection_method']})",
                            (cx - 20, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)

        axes[2, 2].imshow(result_img)
        axes[2, 2].set_title(f'Combined Results ({len(flakes)} flakes)')

        plt.tight_layout()
        return fig

# Initialize optimized analyzer
analyzer = OptimizedMoS2Analyzer()
print("✓ Optimized MoS2 Analyzer initialized")

In [None]:
# Run optimized detection on all images
image_dir = Path('/content/images')
image_extensions = ['.jpg', '.jpeg', '.png', '.tiff', '.bmp']

# Collect images in one case-insensitive pass. Globbing '*.jpg' and '*.JPG'
# separately lists every file twice on case-insensitive filesystems; sorting
# makes the processing order (and image_files[0] in later cells) reproducible.
image_files = sorted(
    p for p in image_dir.iterdir()
    if p.is_file() and p.suffix.lower() in image_extensions
)

print(f"Found {len(image_files)} images to process with optimized detection\n")

for image_path in image_files:
    print(f"Processing: {image_path.name}")

    try:
        # Preprocess with multiple enhancement methods
        original_img, enhanced_images = analyzer.preprocess_image(str(image_path))

        # Apply multiple detection methods
        flakes, debug_masks = analyzer.detect_flakes_multiple_methods(original_img, enhanced_images)
        print(f"  ✓ Found {len(flakes)} flakes using combined methods")

        # Per-method tally of the merged flakes
        method_counts = {}
        for flake in flakes:
            method = flake['detection_method']
            method_counts[method] = method_counts.get(method, 0) + 1
        print(f"  Detection breakdown: {method_counts}")

        # Save through the figure handle (plt.savefig writes whichever figure
        # is "current"), then close it so figures don't accumulate per image.
        debug_fig = analyzer.visualize_debug_results(original_img, flakes, debug_masks)
        debug_fig.savefig(f'/content/debug/{image_path.stem}_debug.png', dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(debug_fig)

        # JSON-serialisable per-flake summary (numpy scalars cast to Python types)
        flake_data = [
            {
                'id': flake_id,
                'area': float(flake['area']),
                'vertices': int(flake['vertices']),
                'solidity': float(flake['solidity']),
                'aspect_ratio': float(flake['aspect_ratio']),
                'detection_method': flake['detection_method']
            }
            for flake_id, flake in enumerate(flakes, start=1)
        ]

        results = {
            'filename': image_path.name,
            'total_flakes': len(flakes),
            'method_breakdown': method_counts,
            'flake_details': flake_data
        }

        # NOTE: keyed by stem, so e.g. a.png and a.jpg write the same file.
        results_path = f'/content/results/{image_path.stem}_optimized_results.json'
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2)

        print(f"  ✓ Results saved to {results_path}\n")

    except Exception as e:
        # Top-level boundary: report the failure and continue with the next image.
        print(f"  ❌ Error processing {image_path.name}: {e}\n")

print("=== OPTIMIZED DETECTION COMPLETE ===")
print("Check the debug visualizations above to see how each method performed.")
print("Results are saved in /content/results/ and debug images in /content/debug/")

In [None]:
# Interactive parameter tuning (run this if results still need improvement)
def test_hsv_parameters(image_path, h_min=90, h_max=170, s_min=20, s_max=255, v_min=20, v_max=255,
                        min_area=50):
    """Threshold an image with one HSV range and plot the mask and detected contours.

    Args:
        image_path: Path (str) to an image readable by ``cv2.imread``.
        h_min, h_max: Hue bounds (OpenCV 8-bit hue spans 0-179).
        s_min, s_max, v_min, v_max: Saturation/value bounds (0-255).
        min_area: Contours must have area strictly greater than this many
            pixels to be counted (default 50).

    Returns:
        Number of outer contours in the cleaned mask whose area exceeds
        ``min_area``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image (missing or unsupported format): {image_path}")
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    hsv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV)

    lower = np.array([h_min, s_min, v_min])
    upper = np.array([h_max, s_max, v_max])
    mask = cv2.inRange(hsv, lower, upper)

    # Close small gaps, then remove isolated specks.
    kernel = np.ones((3, 3), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    valid_contours = [c for c in contours if cv2.contourArea(c) > min_area]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(img_rgb)
    axes[0].set_title('Original')
    axes[0].axis('off')

    axes[1].imshow(mask, cmap='gray')
    axes[1].set_title(f'Mask (H:{h_min}-{h_max}, S:{s_min}-{s_max}, V:{v_min}-{v_max})')
    axes[1].axis('off')

    result_img = img_rgb.copy()
    cv2.drawContours(result_img, valid_contours, -1, (255, 0, 0), 2)
    axes[2].imshow(result_img)
    axes[2].set_title(f'Detected: {len(valid_contours)} objects')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    return len(valid_contours)

# Compare a few candidate HSV ranges on the first image (adjust as needed).
if not image_files:
    print("No images available for parameter tuning. Please upload images first.")
else:
    test_image = str(image_files[0])
    print("Testing different HSV parameters...")
    print("Try adjusting these values to improve detection:")

    # Each entry: (h_min, h_max, s_min, s_max, v_min, v_max)
    test_params = [
        (90, 170, 20, 255, 20, 255),   # Wide range
        (100, 140, 30, 255, 30, 200),  # Blue focus
        (140, 170, 30, 255, 30, 200),  # Purple focus
        (90, 160, 40, 200, 40, 180),   # Mid-tone focus
    ]

    for test_no, params in enumerate(test_params, start=1):
        h_min, h_max, s_min, s_max, v_min, v_max = params
        print(f"\nTest {test_no}: H({h_min}-{h_max}), S({s_min}-{s_max}), V({v_min}-{v_max})")
        count = test_hsv_parameters(test_image, *params)

    recommendations = (
        "\n=== TUNING RECOMMENDATIONS ===",
        "1. If too many false positives: increase minimum saturation/value",
        "2. If missing flakes: expand hue range or decrease saturation/value minimums",
        "3. Update the parameters in the detect_hsv_method() function above",
    )
    for line in recommendations:
        print(line)

## Optimization Guide

### If Detection is Still Not Accurate:

1. **Analyze the debug visualizations above** to see which method works best
2. **Use the HSV parameter-tuning cell** to compare candidate HSV ranges side by side
3. **Modify detection parameters** in the code:
   - Adjust color ranges in `detect_hsv_method()`
   - Change area thresholds in `extract_triangular_contours()`
   - Modify morphological operations kernel sizes

### Key Parameters to Adjust:
- **HSV color ranges**: `([h_min, s_min, v_min], [h_max, s_max, v_max])`
- **Minimum area**: Currently 50 px² (contour area in pixels); increase it to filter out noise
- **Epsilon for contour approximation**: Currently 0.01 * perimeter
- **Morphological kernel size**: Currently (2,2) in the HSV method and (3,3) in the LAB, edge, and adaptive methods

### Next Steps:
Once Stage 1 detection is optimized, we can proceed to:
- **Stage 2**: Multilayer structure identification
- **Stage 3**: Twist angle calculations
