# AI-Based MoS2 Flake Analysis Pipeline

## Project Overview
This notebook implements a complete pipeline for analyzing twisted bilayer MoS2 flakes from optical microscopy images:

**Stage 1**: Detect and outline MoS2 flakes  
**Stage 2**: Identify multilayer structures  
**Stage 3**: Measure twist angles in bilayer MoS2

## Setup Instructions
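1. Run the first code cell once so that `/content/images/`, `/content/results/` and `/content/models/` are created
2. Upload your microscopy images to `/content/images/`
3. Run the remaining cells sequentially
4. Results will be saved to `/content/results/`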

In [None]:
# Install required packages
!pip install ultralytics opencv-python-headless matplotlib numpy scipy scikit-image
!pip install roboflow supervision

import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
from scipy import ndimage
from skimage import measure, morphology, filters
from ultralytics import YOLO
import os

# Make sure the input, output and model folders exist before any other cell runs
for required_dir in ('/content/images', '/content/results', '/content/models'):
    os.makedirs(required_dir, exist_ok=True)

print("✓ Environment setup complete")

In [None]:
class MoS2Analyzer:
    """Classical computer-vision pipeline for MoS2 flakes in optical micrographs.

    Stage 1 segments flakes by HSV colour and keeps roughly triangular contours.
    Stage 2 looks for edge-delimited regions inside each flake.
    Stage 3 compares the orientation of every internal region with its host flake.
    """

    def __init__(self):
        # Not read by any method of this class; kept for callers that may use it.
        self.results = {}

    def preprocess_image(self, image_path):
        """Load a microscopy image and build a contrast-enhanced copy.

        Args:
            image_path: Path of the image file (str or path-like).

        Returns:
            Tuple ``(img_rgb, enhanced)``: the image converted to RGB, and a copy
            whose LAB lightness channel was equalised with CLAHE.

        Raises:
            ValueError: If OpenCV cannot read or decode the file.
        """
        img = cv2.imread(str(image_path))
        if img is None:
            # cv2.imread reports failure by returning None instead of raising,
            # which would otherwise surface as an opaque cvtColor error.
            raise ValueError(f"Could not read image file: {image_path}")
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Enhance contrast on the lightness channel only so colours are preserved
        lab = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2LAB)
        lightness, chan_a, chan_b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced = cv2.merge([clahe.apply(lightness), chan_a, chan_b])
        enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2RGB)

        return img_rgb, enhanced

    def detect_triangular_flakes(self, image, lower_hsv=(100, 50, 50), upper_hsv=(130, 255, 255),
                                 min_area=100, epsilon_factor=0.02):
        """Stage 1: Detect triangular MoS2 flakes using color and shape analysis.

        Args:
            image: RGB image to segment.
            lower_hsv: Lower HSV bound of the flake colour (default targets blue/purple).
            upper_hsv: Upper HSV bound of the flake colour.
            min_area: Contours whose area is not larger than this (pixels) are dropped as noise.
            epsilon_factor: Polygon-approximation tolerance as a fraction of the contour perimeter.

        Returns:
            Tuple ``(flakes, mask)``. ``flakes`` is a list of dicts with the keys
            ``'contour'``, ``'approx'``, ``'area'`` and ``'vertices'``; ``mask`` is the
            cleaned binary colour mask.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
        mask = cv2.inRange(hsv, np.array(lower_hsv), np.array(upper_hsv))

        # Close small holes first, then remove isolated specks
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        triangular_flakes = []
        for contour in contours:
            area = cv2.contourArea(contour)
            if area <= min_area:
                continue

            epsilon = epsilon_factor * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)

            # Accept 3 or 4 vertices: a truncated corner still counts as triangular
            if 3 <= len(approx) <= 4:
                triangular_flakes.append({
                    'contour': contour,
                    'approx': approx,
                    'area': area,
                    'vertices': len(approx)
                })

        return triangular_flakes, mask

    def find_multilayer_structures(self, image, flakes, min_area_fraction=0.1, border_margin=3):
        """Stage 2: Identify flakes with internal structures (multilayer).

        Every dict in ``flakes`` is updated in place with ``'is_multilayer'`` and
        ``'internal_structures'`` (a possibly empty list of contours in full-image
        coordinates).

        Args:
            image: RGB image the flakes were detected in.
            flakes: Flake dicts as returned by ``detect_triangular_flakes``.
            min_area_fraction: An internal contour counts only if its area exceeds
                this fraction of the host flake's area.
            border_margin: Number of pixels by which the flake mask is shrunk before
                edges are searched. Pass 0 to search right up to the flake outline.

        Returns:
            The sub-list of ``flakes`` that were flagged as multilayer.
        """
        multilayer_flakes = []

        erode_kernel = None
        if border_margin > 0:
            size = 2 * border_margin + 1
            erode_kernel = np.ones((size, size), np.uint8)

        for flake in flakes:
            # Full-image mask of this flake
            mask = np.zeros(image.shape[:2], dtype=np.uint8)
            cv2.fillPoly(mask, [flake['contour']], 255)
            if erode_kernel is not None:
                # The flake/substrate boundary is itself a strong Canny edge. Without
                # shrinking the mask, that outline survives the masking step and is
                # reported as an "internal" structure nearly as large as the flake.
                mask = cv2.erode(mask, erode_kernel)

            # Restrict the edge search to the flake's bounding box
            x, y, w, h = cv2.boundingRect(flake['contour'])
            roi = image[y:y+h, x:x+w]
            roi_mask = mask[y:y+h, x:x+w]

            gray_roi = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray_roi, 50, 150)
            edges = cv2.bitwise_and(edges, roi_mask)

            internal_contours, _ = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

            internal_structures = []
            for contour in internal_contours:
                if cv2.contourArea(contour) > flake['area'] * min_area_fraction:
                    # Shift back to full-image coordinates. The offset must carry the
                    # contour's own dtype: adding a plain list promotes to int64,
                    # which OpenCV contour functions do not accept.
                    offset = np.array([x, y], dtype=contour.dtype)
                    internal_structures.append(contour + offset)

            flake['internal_structures'] = internal_structures
            flake['is_multilayer'] = bool(internal_structures)
            if internal_structures:
                multilayer_flakes.append(flake)

        return multilayer_flakes

    @staticmethod
    def _normalize_twist_angle(angle_difference):
        """Fold an angle difference in degrees into [0, 60] using 3-fold symmetry."""
        folded = abs(angle_difference) % 120.0
        return min(folded, 120.0 - folded)

    def calculate_twist_angles(self, multilayer_flakes):
        """Stage 3: Calculate twist angles for bilayer structures.

        Flakes that are not flagged ``'is_multilayer'`` are skipped, so the full
        Stage 1 list may be passed.

        Args:
            multilayer_flakes: Flake dicts already processed by
                ``find_multilayer_structures``.

        Returns:
            List of dicts with the keys ``'flake_id'``, ``'main_angle'``,
            ``'internal_angles'``, ``'twist_angles'``, ``'contour'`` and ``'area'``.
            ``'flake_id'`` is the 1-based position of the flake in the input list.
            When the full Stage 1 list is passed it therefore matches the labels
            drawn by ``visualize_results`` and the ids written by ``save_results``.
        """
        results = []

        for position, flake in enumerate(multilayer_flakes, start=1):
            if not flake.get('is_multilayer', False):
                continue

            # Orientation of the host flake
            main_vertices = flake['approx'].reshape(-1, 2)
            main_angle = self.get_triangle_orientation(main_vertices)

            # Orientation of every internal structure
            internal_angles = []
            for internal_contour in flake.get('internal_structures', []):
                # Guard against contours stored with a dtype OpenCV rejects
                internal_contour = np.asarray(internal_contour, dtype=np.int32)
                epsilon = 0.02 * cv2.arcLength(internal_contour, True)
                internal_approx = cv2.approxPolyDP(internal_contour, epsilon, True)

                if len(internal_approx) >= 3:
                    internal_vertices = internal_approx.reshape(-1, 2)
                    internal_angles.append(self.get_triangle_orientation(internal_vertices))

            # Orientations lie in (-180, 180], so raw differences reach almost 360
            # and have to be reduced modulo 120 before folding into [0, 60].
            twist_angles = [
                self._normalize_twist_angle(main_angle - internal_angle)
                for internal_angle in internal_angles
            ]

            results.append({
                'flake_id': position,
                'main_angle': main_angle,
                'internal_angles': internal_angles,
                'twist_angles': twist_angles,
                'contour': flake['contour'],
                'area': flake['area']
            })

        return results

    def get_triangle_orientation(self, vertices):
        """Calculate orientation angle of triangle.

        The orientation is the direction, in degrees within (-180, 180], from the
        vertex centroid to the vertex furthest from it.

        Args:
            vertices: Sequence of (x, y) points.

        Returns:
            The angle as a float; 0.0 when fewer than three vertices are given.
        """
        if len(vertices) < 3:
            return 0.0

        points = np.asarray(vertices, dtype=float)
        centroid = points.mean(axis=0)

        # The vertex furthest from the centroid is treated as the apex
        distances = np.linalg.norm(points - centroid, axis=1)
        apex = points[np.argmax(distances)]

        angle = np.arctan2(apex[1] - centroid[1], apex[0] - centroid[0])
        return float(np.degrees(angle))

    @staticmethod
    def _contour_centroid(contour):
        """Return the integer (cx, cy) centroid of a contour, or None if its area moment is zero."""
        M = cv2.moments(contour)
        if M['m00'] == 0:
            return None
        return int(M['m10'] / M['m00']), int(M['m01'] / M['m00'])

    def visualize_results(self, image, flakes, twist_results):
        """Create visualization of all analysis stages.

        Args:
            image: RGB image to draw on (copies are annotated, the input is untouched).
            flakes: Flake dicts from Stage 1, optionally annotated by Stage 2.
            twist_results: Output of ``calculate_twist_angles``.

        Returns:
            The matplotlib figure holding the 2x2 panel. The caller is responsible
            for saving and closing it.
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # Original image
        axes[0,0].imshow(image)
        axes[0,0].set_title('Original Image')
        axes[0,0].axis('off')

        # Stage 1: all detected flakes, numbered in list order
        img_stage1 = image.copy()
        for i, flake in enumerate(flakes):
            cv2.drawContours(img_stage1, [flake['contour']], -1, (255, 0, 0), 2)
            center = self._contour_centroid(flake['contour'])
            if center is not None:
                cv2.putText(img_stage1, str(i+1), center, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

        axes[0,1].imshow(img_stage1)
        axes[0,1].set_title(f'Stage 1: Detected Flakes ({len(flakes)})')
        axes[0,1].axis('off')

        # Stage 2: multilayer structures
        img_stage2 = image.copy()
        multilayer_count = 0
        for flake in flakes:
            # .get keeps the plot usable when Stage 2 has not annotated the flakes
            if not flake.get('is_multilayer', False):
                continue
            multilayer_count += 1
            cv2.drawContours(img_stage2, [flake['contour']], -1, (0, 255, 0), 2)
            for internal in flake.get('internal_structures', []):
                # drawContours needs int32 points
                cv2.drawContours(img_stage2, [np.asarray(internal, dtype=np.int32)], -1, (0, 255, 255), 1)

        axes[1,0].imshow(img_stage2)
        axes[1,0].set_title(f'Stage 2: Multilayer Structures ({multilayer_count})')
        axes[1,0].axis('off')

        # Stage 3: twist angles
        img_stage3 = image.copy()
        for result in twist_results:
            cv2.drawContours(img_stage3, [result['contour']], -1, (255, 165, 0), 2)

            center = self._contour_centroid(result['contour'])
            if center is None:
                continue
            cx, cy = center
            for i, angle in enumerate(result['twist_angles']):
                # Hershey fonts cover ASCII only; a degree sign would be drawn as '?'
                text = f"{angle:.1f} deg"
                cv2.putText(img_stage3, text, (cx, cy + i*20), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 0), 1)

        axes[1,1].imshow(img_stage3)
        axes[1,1].set_title(f'Stage 3: Twist Angles ({len(twist_results)} bilayers)')
        axes[1,1].axis('off')

        plt.tight_layout()
        return fig

    def save_results(self, filename, flakes, twist_results, output_dir='/content/results'):
        """Save analysis results to JSON file.

        Args:
            filename: Base name; the file is written as ``<filename>_results.json``.
            flakes: Flake dicts from Stage 1, optionally annotated by Stage 2.
            twist_results: Output of ``calculate_twist_angles``.
            output_dir: Folder the JSON file is written to (created if missing).

        Returns:
            The dict that was serialised.
        """
        results_data = {
            'total_flakes': len(flakes),
            'multilayer_flakes': sum(1 for f in flakes if f.get('is_multilayer', False)),
            'bilayer_structures': len(twist_results),
            'flake_data': [],
            'twist_angle_data': []
        }

        # Explicit casts turn NumPy scalars into JSON-serialisable Python types
        for i, flake in enumerate(flakes):
            results_data['flake_data'].append({
                'id': i + 1,
                'area': float(flake['area']),
                'vertices': int(flake['vertices']),
                'is_multilayer': bool(flake.get('is_multilayer', False))
            })

        for result in twist_results:
            results_data['twist_angle_data'].append({
                'flake_id': int(result['flake_id']),
                'main_angle': float(result['main_angle']),
                'twist_angles': [float(angle) for angle in result['twist_angles']],
                'area': float(result['area'])
            })

        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, f'{filename}_results.json'), 'w') as f:
            json.dump(results_data, f, indent=2)

        return results_data

# Initialize analyzer: one shared instance that the processing cell below calls for every image
analyzer = MoS2Analyzer()
print("✓ MoS2 Analyzer initialized")

In [None]:
# Process all images in the images directory
image_dir = Path('/content/images')
results_dir = Path('/content/results')
results_dir.mkdir(parents=True, exist_ok=True)
results_summary = []

# Supported image formats (matched case-insensitively against the file suffix)
image_extensions = ['.jpg', '.jpeg', '.png', '.tiff', '.tif', '.bmp']

# Compare the lower-cased suffix instead of globbing "*.ext" and "*.EXT":
# that also finds mixed-case names such as "scan.Jpg" and cannot list a file twice
# on case-insensitive filesystems. Sorting makes the processing order deterministic.
image_files = sorted(
    path for path in (image_dir.iterdir() if image_dir.is_dir() else [])
    if path.is_file() and path.suffix.lower() in image_extensions
)

print(f"Found {len(image_files)} images to process")

for image_path in image_files:
    print(f"\nProcessing: {image_path.name}")

    fig = None
    try:
        # Load and preprocess image
        original_img, enhanced_img = analyzer.preprocess_image(str(image_path))

        # Stage 1: Detect triangular flakes
        flakes, detection_mask = analyzer.detect_triangular_flakes(enhanced_img)
        print(f"  Stage 1: Found {len(flakes)} potential flakes")

        # Stage 2: Find multilayer structures (also tags every flake dict in place)
        multilayer_flakes = analyzer.find_multilayer_structures(enhanced_img, flakes)
        print(f"  Stage 2: Found {len(multilayer_flakes)} multilayer structures")

        # Stage 3: Calculate twist angles. The full list is passed on purpose:
        # calculate_twist_angles skips flakes that are not tagged as multilayer.
        twist_results = analyzer.calculate_twist_angles(flakes)
        print(f"  Stage 3: Calculated angles for {len(twist_results)} bilayer structures")

        # Visualize results; save through the figure object rather than the
        # implicit "current figure" so the right figure is always written
        fig = analyzer.visualize_results(original_img, flakes, twist_results)
        fig.savefig(results_dir / f'{image_path.stem}_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()

        # Save detailed results
        results_data = analyzer.save_results(image_path.stem, flakes, twist_results)
        results_summary.append({
            'filename': image_path.name,
            'total_flakes': results_data['total_flakes'],
            'multilayer_flakes': results_data['multilayer_flakes'],
            'bilayer_structures': results_data['bilayer_structures']
        })

        # Print twist angles found
        if twist_results:
            print("  Twist angles found:")
            for result in twist_results:
                for angle in result['twist_angles']:
                    print(f"    Flake {result['flake_id']}: {angle:.1f}°")

    except Exception as e:
        # Best effort: report the failure and carry on with the next image
        print(f"  Error processing {image_path.name}: {str(e)}")
    finally:
        # Release the figure so memory does not grow with the number of images
        if fig is not None:
            plt.close(fig)

# Save overall summary
with open(results_dir / 'analysis_summary.json', 'w') as f:
    json.dump(results_summary, f, indent=2)

print(f"\n✓ Analysis complete! Processed {len(results_summary)} images successfully.")
print("Results saved to /content/results/")

In [None]:
# Generate summary statistics and plots
import pandas as pd

if results_summary:
    # Create summary DataFrame
    df_summary = pd.DataFrame(results_summary)

    print("=== ANALYSIS SUMMARY ===")
    print(f"Total images processed: {len(df_summary)}")
    print(f"Total flakes detected: {df_summary['total_flakes'].sum()}")
    print(f"Total multilayer structures: {df_summary['multilayer_flakes'].sum()}")
    print(f"Total bilayer structures with angles: {df_summary['bilayer_structures'].sum()}")
    print(f"Average flakes per image: {df_summary['total_flakes'].mean():.1f}")

    # Display summary table
    print("\n=== PER-IMAGE RESULTS ===")
    print(df_summary.to_string(index=False))

    # Collect twist angles only from the images processed in this run.
    # Globbing every *_results.json would also pick up stale files left in the
    # results folder by earlier runs and skew the statistics.
    all_angles = []
    for entry in results_summary:
        result_file = Path('/content/results') / f"{Path(entry['filename']).stem}_results.json"
        if not result_file.exists():
            continue
        with open(result_file, 'r') as f:
            data = json.load(f)
        for twist_data in data['twist_angle_data']:
            all_angles.extend(twist_data['twist_angles'])

    if all_angles:
        mean_angle = np.mean(all_angles)
        median_angle = np.median(all_angles)

        # Plot twist angle distribution
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.hist(all_angles, bins=20, edgecolor='black', alpha=0.7)
        ax.set_xlabel('Twist Angle (degrees)')
        ax.set_ylabel('Frequency')
        ax.set_title(f'Distribution of Twist Angles (n={len(all_angles)})')
        ax.grid(True, alpha=0.3)

        # Add statistics
        ax.axvline(mean_angle, color='red', linestyle='--', label=f'Mean: {mean_angle:.1f}°')
        ax.axvline(median_angle, color='green', linestyle='--', label=f'Median: {median_angle:.1f}°')
        ax.legend()

        fig.tight_layout()
        fig.savefig('/content/results/twist_angle_distribution.png', dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

        print("\n=== TWIST ANGLE STATISTICS ===")
        print(f"Total measurements: {len(all_angles)}")
        print(f"Mean angle: {mean_angle:.2f}°")
        print(f"Median angle: {median_angle:.2f}°")
        print(f"Standard deviation: {np.std(all_angles):.2f}°")
        print(f"Range: {np.min(all_angles):.1f}° - {np.max(all_angles):.1f}°")

        # Save angle data to CSV
        angle_df = pd.DataFrame({'twist_angle': all_angles})
        angle_df.to_csv('/content/results/all_twist_angles.csv', index=False)

    else:
        print("\nNo twist angles found in the processed images.")
        print("This might indicate:")
        print("1. No multilayer structures were detected")
        print("2. Color thresholds need adjustment")
        print("3. Images need preprocessing")

else:
    print("No images were successfully processed.")
    print("Please check:")
    print("1. Images are uploaded to /content/images/")
    print("2. Images are in supported formats (jpg, png, tiff, bmp)")
    print("3. Image quality and contrast are sufficient")

## Usage Instructions

### 1. Upload Your Images
- Run the setup cell first so that `/content/images/` exists
- Click the folder icon in the left sidebar
- Navigate to `/content/images/`
- Upload your microscopy images (jpg, jpeg, png, tiff, bmp formats supported)

### 2. Run the Analysis
- Execute all cells from top to bottom
- The analysis will process all images automatically

### 3. Results Location
- **Visualizations**: `/content/results/[filename]_analysis.png`
- **Detailed data**: `/content/results/[filename]_results.json`
- **Summary**: `/content/results/analysis_summary.json`
- **Angle data**: `/content/results/all_twist_angles.csv`

### 4. Customization
If results are not accurate, you can adjust:
- Color thresholds in `detect_triangular_flakes()` method
- Size filters (minimum area requirements)
- Shape detection parameters (epsilon values)

### 5. Download Results
- Right-click on result files and select "Download"
- All results are automatically saved to `/content/results/`

In [None]:
# Parameter tuning section - adjust these values if detection is not accurate
# The guide is emitted in one call; entries ending in "\n" leave a blank line after a section.
print("\n".join((
    "=== PARAMETER TUNING GUIDE ===",
    "If the detection is not working well, try adjusting these parameters:\n",

    "1. COLOR DETECTION (in detect_triangular_flakes method):",
    "   Current: lower_blue = [100, 50, 50], upper_blue = [130, 255, 255]",
    "   - Increase lower bound if detecting too much background",
    "   - Decrease lower bound if missing flakes",
    "   - Adjust hue range (first values) for different colors\n",

    "2. SIZE FILTERING:",
    "   Current: minimum area = 100 pixels",
    "   - Increase to filter out smaller objects",
    "   - Decrease to detect smaller flakes\n",

    "3. SHAPE DETECTION:",
    "   Current: epsilon = 0.02 * perimeter",
    "   - Decrease for more precise shape detection",
    "   - Increase for more flexible shape matching\n",

    "4. MULTILAYER DETECTION:",
    "   Current: internal structure > 10% of main flake",
    "   - Decrease percentage to detect smaller internal features",
    "   - Increase to focus on larger internal structures\n",

    "To modify parameters, edit the corresponding values in the MoS2Analyzer class above and re-run the analysis.",
)))