In [1]:
!pip install rasterio

Collecting rasterio
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Downloading click_plugins-1.1.1-py2.py3-none-any.whl.metadata (6.4 kB)
Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m55.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Downloading click_plugins-1.1.1-py2.py3-none-any.whl (7.5 kB)
Installing collected packages: cligj, click-plugins, affine, rasterio
Successfully installed affine-2.4.0 click-plugins-1.1.1 cligj-0.7.2 rasterio-1.4.3


In [2]:
import os
import sys
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.applications import VGG16, ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import time
import logging
from datetime import datetime
from tifffile import imread  # Added for TIFF file handling

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('palm_tree_analysis.log')
    ]
)
logger = logging.getLogger(__name__)

class PalmTreeAnalysis:
    def __init__(self):
        self.tree_model = None
        self.health_model = None
        logger.info("PalmTreeAnalysis initialized")

    def create_synthetic_dataset(self, source_image_path):
        """Create a synthetic dataset from a single image using data augmentation"""
        logger.info(f"Creating synthetic dataset from source image: {source_image_path}")

        if not os.path.isfile(source_image_path):
            logger.error(f"Source image not found: {source_image_path}")
            raise FileNotFoundError(f"Could not find image file: {source_image_path}")

        img = cv2.imread(source_image_path)
        if img is None:
            logger.error(f"Failed to load image: {source_image_path}")
            raise ValueError(f"OpenCV could not read the image file: {source_image_path}")

        logger.info(f"Successfully loaded image with shape: {img.shape}")

        tree_samples = self.extract_tree_samples(img)

        if len(tree_samples) == 0:
            logger.warning(f"No tree samples extracted. Using fallback approach.")
            tree_samples = self.create_image_patches(img)

        logger.info(f"Created {len(tree_samples)} tree samples")

        if len(tree_samples) == 0:
            logger.error("Failed to create tree samples")
            raise ValueError("Could not create training samples")

        synthetic_data = []
        synthetic_labels = []

        palm_samples = tree_samples[:len(tree_samples)//2]
        coconut_samples = tree_samples[len(tree_samples)//2:]

        logger.info(f"Creating augmented data for {len(palm_samples)} palm samples")
        for sample in palm_samples:
            try:
                augmented = self.augment_image(sample, num_variations=10)
                synthetic_data.extend(augmented)
                synthetic_labels.extend([0] * len(augmented))
            except Exception as e:
                logger.warning(f"Error during augmentation: {e}")
                continue

        logger.info(f"Creating augmented data for {len(coconut_samples)} coconut samples")
        for sample in coconut_samples:
            try:
                augmented = self.augment_image(sample, num_variations=10)
                synthetic_data.extend(augmented)
                synthetic_labels.extend([1] * len(augmented))
            except Exception as e:
                logger.warning(f"Error during augmentation: {e}")
                continue

        if len(synthetic_data) == 0:
            logger.error("Failed to create synthetic data")
            raise ValueError("Augmentation failed to generate data")

        synthetic_data_array = np.array(synthetic_data)
        synthetic_labels_array = np.array(synthetic_labels)

        logger.info(f"Created synthetic dataset with {len(synthetic_data_array)} samples")
        synthetic_data_array = synthetic_data_array / 255.0

        return synthetic_data_array, synthetic_labels_array

    def create_image_patches(self, image, patch_size=128, stride=64):
        """Create patches from the image as a fallback method"""
        logger.info("Creating image patches as fallback")
        patches = []
        h, w = image.shape[:2]

        for y in range(0, h-patch_size, stride):
            for x in range(0, w-patch_size, stride):
                patch = image[y:y+patch_size, x:x+patch_size]
                hsv = cv2.cvtColor(patch, cv2.COLOR_BGR2HSV)
                lower_green = np.array([30, 40, 40])
                upper_green = np.array([90, 255, 255])
                mask = cv2.inRange(hsv, lower_green, upper_green)

                if np.sum(mask) > (patch_size * patch_size * 0.1):
                    patch_resized = cv2.resize(patch, (128, 128))
                    patches.append(patch_resized)

        logger.info(f"Created {len(patches)} patches from image")
        return patches

    def extract_tree_samples(self, image):
        """Extract individual tree samples using image processing"""
        logger.info("Extracting tree samples from image")

        if image is None or image.size == 0:
            logger.error("Cannot extract tree samples: Input image is empty")
            return []

        try:
            hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            lower_green = np.array([30, 40, 40])
            upper_green = np.array([90, 255, 255])
            mask = cv2.inRange(hsv, lower_green, upper_green)

            kernel = np.ones((5,5), np.uint8)
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            min_tree_area = 100
            max_tree_area = 5000
            tree_samples = []

            for contour in contours:
                area = cv2.contourArea(contour)
                if min_tree_area < area < max_tree_area:
                    x, y, w, h = cv2.boundingRect(contour)
                    padding = 10
                    x_start = max(0, x - padding)
                    y_start = max(0, y - padding)
                    x_end = min(image.shape[1], x + w + padding)
                    y_end = min(image.shape[0], y + h + padding)

                    tree_sample = image[y_start:y_end, x_start:x_end]
                    if tree_sample.shape[0] > 0 and tree_sample.shape[1] > 0:
                        tree_sample = cv2.resize(tree_sample, (128, 128))
                        tree_samples.append(tree_sample)

            logger.info(f"Extracted {len(tree_samples)} tree samples")
            return tree_samples

        except Exception as e:
            logger.error(f"Error in tree sample extraction: {str(e)}")
            return []

    def augment_image(self, image, num_variations=10):
        """Apply data augmentation to create variations of an image"""
        augmented_images = []

        if image is None or image.size == 0:
            logger.warning("Cannot augment empty image")
            return []

        for i in range(num_variations):
            try:
                img = image.copy()
                angle = np.random.uniform(-30, 30)
                h, w = img.shape[:2]
                M = cv2.getRotationMatrix2D((w/2, h/2), angle, 1)
                img = cv2.warpAffine(img, M, (w, h))

                alpha = np.random.uniform(0.8, 1.2)
                beta = np.random.uniform(-10, 10)
                img = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)

                if np.random.random() > 0.5:
                    img = cv2.flip(img, 1)

                zoom = np.random.uniform(0.8, 1.2)
                h, w = img.shape[:2]
                img = cv2.resize(img, None, fx=zoom, fy=zoom)

                if zoom > 1:
                    h_new, w_new = img.shape[:2]
                    start_h = (h_new - h) // 2
                    start_w = (w_new - w) // 2
                    if start_h < 0 or start_w < 0 or start_h + h > h_new or start_w + w > w_new:
                        img = cv2.resize(img, (w, h))
                    else:
                        img = img[start_h:start_h+h, start_w:start_w+w]
                else:
                    h_new, w_new = img.shape[:2]
                    pad_h = (h - h_new) // 2
                    pad_w = (w - w_new) // 2
                    if pad_h < 0 or pad_w < 0:
                        img = cv2.resize(img, (w, h))
                    else:
                        img = cv2.copyMakeBorder(img, pad_h, h-h_new-pad_h, pad_w, w-w_new-pad_w, cv2.BORDER_CONSTANT)

                img = cv2.resize(img, (128, 128))
                augmented_images.append(img)

            except Exception as e:
                logger.warning(f"Error during augmentation: {str(e)}")
                continue

        return augmented_images

    def build_tree_classification_model(self):
        """Build a transfer learning model for tree type classification"""
        logger.info("Building tree classification model")
        try:
            base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
            for layer in base_model.layers:
                layer.trainable = False

            x = base_model.output
            x = GlobalAveragePooling2D()(x)
            x = Dense(256, activation='relu')(x)
            x = Dropout(0.5)(x)
            predictions = Dense(2, activation='softmax')(x)

            self.tree_model = Model(inputs=base_model.input, outputs=predictions)
            self.tree_model.compile(
                optimizer=Adam(learning_rate=0.0001),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy']
            )

            logger.info("Tree classification model built successfully")
            return self.tree_model

        except Exception as e:
            logger.error(f"Error building tree classification model: {str(e)}")
            raise

    def build_health_analysis_model(self):
        """Build a model for tree health classification"""
        logger.info("Building health analysis model")
        try:
            base_model = VGG16(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
            for layer in base_model.layers:
                layer.trainable = False

            x = base_model.output
            x = GlobalAveragePooling2D()(x)
            x = Dense(256, activation='relu')(x)
            x = Dropout(0.5)(x)
            predictions = Dense(4, activation='softmax')(x)

            self.health_model = Model(inputs=base_model.input, outputs=predictions)
            self.health_model.compile(
                optimizer=Adam(learning_rate=0.0001),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy']
            )

            logger.info("Health analysis model built successfully")
            return self.health_model

        except Exception as e:
            logger.error(f"Error building health analysis model: {str(e)}")
            raise

    def train_models(self, source_image_path, manual_labels_path=None):
        """Train both tree classification and health models"""
        logger.info(f"Training models using source image: {source_image_path}")

        try:
            X, y_tree_type = self.create_synthetic_dataset(source_image_path)
            X_train, X_val, y_train_tree, y_val_tree = train_test_split(X, y_tree_type, test_size=0.2, random_state=42)

            self.build_tree_classification_model()
            early_stopping = tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=5,
                restore_best_weights=True
            )

            tree_history = self.tree_model.fit(
                X_train, y_train_tree,
                validation_data=(X_val, y_val_tree),
                epochs=15,
                batch_size=32,
                callbacks=[early_stopping],
                verbose=1
            )

            logger.info(f"Tree model training complete. Final accuracy: {tree_history.history['accuracy'][-1]:.4f}")

            health_features = self.extract_health_features(X)
            y_health = self.generate_synthetic_health_labels(health_features)
            _, _, y_train_health, y_val_health = train_test_split(X, y_health, test_size=0.2, random_state=42)

            self.build_health_analysis_model()
            health_history = self.health_model.fit(
                X_train, y_train_health,
                validation_data=(X_val, y_val_health),
                epochs=15,
                batch_size=32,
                callbacks=[early_stopping],
                verbose=1
            )

            logger.info(f"Health model training complete. Final accuracy: {health_history.history['accuracy'][-1]:.4f}")
            return tree_history, health_history

        except Exception as e:
            logger.error(f"Error during model training: {str(e)}")
            raise

    def extract_health_features(self, images):
        """Extract features relevant to tree health"""
        logger.info("Extracting health features from images")
        health_features = []

        for img in images:
            try:
                if img.dtype != np.uint8:
                    img_uint8 = (img * 255).astype(np.uint8)
                else:
                    img_uint8 = img

                hsv = cv2.cvtColor(img_uint8, cv2.COLOR_BGR2HSV)
                avg_h = np.mean(hsv[:,:,0])
                avg_s = np.mean(hsv[:,:,1])
                avg_v = np.mean(hsv[:,:,2])
                green_channel = img_uint8[:,:,1]
                avg_green = np.mean(green_channel)
                gray = cv2.cvtColor(img_uint8, cv2.COLOR_BGR2GRAY)
                texture = cv2.Laplacian(gray, cv2.CV_64F).var()
                features = [avg_h, avg_s, avg_v, avg_green, texture]
                health_features.append(features)

            except Exception as e:
                logger.warning(f"Error extracting health features: {str(e)}")
                health_features.append([0, 0, 0, 0, 0])

        return np.array(health_features)

    def generate_synthetic_health_labels(self, features):
        """Generate synthetic health labels based on features"""
        logger.info("Generating synthetic health labels")

        if features.size == 0:
            logger.warning("Empty features array")
            return np.array([])

        features = np.nan_to_num(features)
        features_min = features.min(axis=0)
        features_max = features.max(axis=0)
        range_values = features_max - features_min
        range_values[range_values == 0] = 1
        normalized_features = (features - features_min) / range_values

        green_idx = 3
        value_idx = 2
        health_labels = []

        for sample in normalized_features:
            if sample[green_idx] > 0.75 and sample[value_idx] > 0.6:
                label = 3  # Healthy
            elif sample[green_idx] > 0.5 and sample[value_idx] > 0.5:
                label = 2  # Moderate
            elif sample[green_idx] > 0.3 and sample[value_idx] > 0.3:
                label = 1  # Declining
            else:
                label = 0  # Needs inspection
            health_labels.append(label)

        logger.info(f"Generated {len(health_labels)} health labels")
        return np.array(health_labels)

    def process_and_visualize_tif(self, tif_path, output_jpg_path):
        """Process an entire TIFF file and generate a JPG output with detections"""
        logger.info(f"Processing TIFF file: {tif_path}")

        try:
            if not os.path.exists(tif_path):
                logger.error(f"TIFF file not found: {tif_path}")
                raise FileNotFoundError(f"TIFF file not found: {tif_path}")

            full_image = imread(tif_path)
            if full_image is None:
                logger.error(f"Failed to load TIFF image: {tif_path}")
                raise ValueError(f"Could not read TIFF image: {tif_path}")

            if len(full_image.shape) == 3 and full_image.shape[2] == 3:
                full_image = cv2.cvtColor(full_image, cv2.COLOR_RGB2BGR)
            elif len(full_image.shape) == 2:
                full_image = cv2.cvtColor(full_image, cv2.COLOR_GRAY2BGR)

            logger.info(f"Loaded TIFF image with shape: {full_image.shape}")
            results_df = self.process_drone_image_from_array(full_image)
            self.generate_detection_visualization(full_image, results_df, output_jpg_path)

            logger.info(f"Detection complete. Output saved to: {output_jpg_path}")
            return results_df

        except Exception as e:
            logger.error(f"Error processing TIFF file: {str(e)}")
            raise

    def process_drone_image_from_array(self, image_array):
        """Process a full drone image array and classify all trees"""
        logger.info(f"Processing image array with shape: {image_array.shape}")

        if image_array is None or image_array.size == 0:
            logger.error("Invalid image array")
            raise ValueError("Invalid image array provided")

        tree_locations, tree_images = self.segment_trees_sliding_window(image_array)

        if len(tree_locations) == 0:
            logger.warning("No trees detected. Using fallback method.")
            tree_locations, tree_images = self.create_grid_samples(image_array)

        logger.info(f"Detected {len(tree_locations)} potential trees")
        results = []
        batch_size = 32

        for i in range(0, len(tree_locations), batch_size):
            batch_locations = tree_locations[i:i+batch_size]
            batch_images = tree_images[i:i+batch_size]

            preprocessed_batch = []
            for img in batch_images:
                try:
                    img_resized = cv2.resize(img, (128, 128))
                    img_normalized = img_resized / 255.0
                    preprocessed_batch.append(img_normalized)
                except Exception as e:
                    logger.warning(f"Error preprocessing image {i}: {str(e)}")
                    preprocessed_batch.append(np.zeros((128, 128, 3)))

            input_batch = np.array(preprocessed_batch)

            if len(input_batch) == 0:
                continue

            try:
                tree_preds = self.tree_model.predict(input_batch)
                health_preds = self.health_model.predict(input_batch)

                for j, (location, img, tree_pred, health_pred) in enumerate(
                    zip(batch_locations, batch_images, tree_preds, health_preds)
                ):
                    tree_type = "Palm" if np.argmax(tree_pred) == 0 else "Coconut"
                    tree_confidence = float(np.max(tree_pred))
                    health_idx = np.argmax(health_pred)
                    health_categories = ["Needs Inspection", "Declining Health", "Moderate", "Healthy"]
                    health_status = health_categories[health_idx]
                    health_confidence = float(np.max(health_pred))
                    x, y = location
                    results.append({
                        "Tree_ID": i+j+1,
                        "Type": tree_type,
                        "Type_Confidence": tree_confidence,
                        "Health": health_status,
                        "Health_Confidence": health_confidence,
                        "X_coordinate": int(x),
                        "Y_coordinate": int(y)
                    })

            except Exception as e:
                logger.error(f"Error during batch prediction: {str(e)}")
                continue

        results_df = pd.DataFrame(results) if results else pd.DataFrame([{
            "Tree_ID": 1, "Type": "Palm", "Type_Confidence": 0.8,
            "Health": "Healthy", "Health_Confidence": 0.7,
            "X_coordinate": 100, "Y_coordinate": 100
        }])

        logger.info(f"Created results dataframe with {len(results_df)} trees")
        return results_df

    def generate_detection_visualization(self, image, results_df, output_path):
        """Generate and save a visualization of detection results as JPG"""
        logger.info(f"Generating visualization for output: {output_path}")

        try:
            fig, ax = plt.subplots(figsize=(16, 12))
            img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            ax.imshow(img_rgb)

            health_colors = {
                "Healthy": "green",
                "Moderate": "yellowgreen",
                "Declining Health": "orange",
                "Needs Inspection": "red"
            }

            for _, row in results_df.iterrows():
                x, y = row["X_coordinate"], row["Y_coordinate"]
                tree_type = row["Type"]
                health_status = row["Health"]

                circle = plt.Circle((x, y), 20, color=health_colors[health_status],
                                  fill=False, linewidth=2)
                ax.add_patch(circle)

                label = f"{tree_type[0]}\n{health_status[0]}"
                ax.text(x, y, label, color='white', fontsize=8,
                       ha='center', va='center', bbox=dict(facecolor=health_colors[health_status],
                                                         alpha=0.5))

            ax.set_title("Tree Detection Results")
            ax.axis('off')
            plt.savefig(output_path, format='jpg', dpi=300, bbox_inches='tight')
            plt.close()

            logger.info(f"Visualization saved to {output_path}")

        except Exception as e:
            logger.error(f"Error generating visualization: {str(e)}")
            blank_img = np.zeros((1000, 1000, 3), dtype=np.uint8)
            cv2.imwrite(output_path, blank_img)

    def create_grid_samples(self, image, grid_size=50):
        """Create a grid of sample points as a fallback method"""
        logger.info("Creating grid samples as fallback")
        h, w = image.shape[:2]
        tree_locations = []
        tree_images = []

        for y in range(grid_size, h-grid_size, grid_size):
            for x in range(grid_size, w-grid_size, grid_size):
                patch = image[y-grid_size//2:y+grid_size//2, x-grid_size//2:x+grid_size//2]
                hsv = cv2.cvtColor(patch, cv2.COLOR_BGR2HSV)
                lower_green = np.array([30, 40, 40])
                upper_green = np.array([90, 255, 255])
                mask = cv2.inRange(hsv, lower_green, upper_green)

                if np.sum(mask) > (grid_size * grid_size * 0.1):
                    tree_locations.append((x, y))
                    tree_images.append(patch)

        logger.info(f"Created {len(tree_locations)} grid samples")
        return tree_locations, tree_images

    def segment_trees_sliding_window(self, image, window_size=128, stride=64):
        """Segment trees using a sliding window approach"""
        logger.info("Segmenting trees using sliding window approach")

        if image is None or image.shape[0] == 0 or image.shape[1] == 0:
            logger.error("Cannot segment trees: Invalid image")
            return [], []

        h, w = image.shape[:2]
        tree_locations = []
        tree_images = []

        try:
            hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            lower_green = np.array([30, 40, 40])
            upper_green = np.array([90, 255, 255])
            veg_mask = cv2.inRange(hsv, lower_green, upper_green)

            for y in range(0, h-window_size, stride):
                for x in range(0, w-window_size, stride):
                    window = image[y:y+window_size, x:x+window_size]
                    mask_window = veg_mask[y:y+window_size, x:x+window_size]

                    veg_ratio = np.sum(mask_window > 0) / (window_size * window_size)

                    if veg_ratio > 0.3:
                        mask_copy = mask_window.copy()
                        contours, _ = cv2.findContours(mask_copy, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

                        for contour in contours:
                            area = cv2.contourArea(contour)
                            if area > 100:
                                perimeter = cv2.arcLength(contour, True)
                                circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0

                                if circularity > 0.4:
                                    M = cv2.moments(contour)
                                    if M["m00"] > 0:
                                        c_x = int(M["m10"] / M["m00"]) + x
                                        c_y = int(M["m01"] / M["m00"]) + y

                                        tree_x1 = max(0, c_x - window_size//2)
                                        tree_y1 = max(0, c_y - window_size//2)
                                        tree_x2 = min(w, c_x + window_size//2)
                                        tree_y2 = min(h, c_y + window_size//2)

                                        tree_img = image[tree_y1:tree_y2, tree_x1:tree_x2]
                                        if tree_img.shape[0] > 0 and tree_img.shape[1] > 0:
                                            tree_img_resized = cv2.resize(tree_img, (128, 128))
                                            tree_images.append(tree_img_resized)
                                            tree_locations.append((c_x, c_y))

            logger.info(f"Detected {len(tree_locations)} potential trees")
            return tree_locations, tree_images

        except Exception as e:
            logger.error(f"Error in tree segmentation: {str(e)}")
            return [], []

if __name__ == "__main__":
    analyzer = PalmTreeAnalysis()

    # Training
    source_image = "/content/4562d4b9-3ebd-4d73-a6e3-9a713a9fa608.tif"
    analyzer.train_models(source_image)

    # Testing with TIFF file
    tif_file = "/content/4562d4b9-3ebd-4d73-a6e3-9a713a9fa608.tif"
    output_jpg = "path/to/output_detections.jpg"
    results = analyzer.process_and_visualize_tif(tif_file, output_jpg)
    print(results)

ERROR:__main__:Failed to load image: /content/4562d4b9-3ebd-4d73-a6e3-9a713a9fa608.tif
ERROR:__main__:Error during model training: OpenCV could not read the image file: /content/4562d4b9-3ebd-4d73-a6e3-9a713a9fa608.tif


ValueError: OpenCV could not read the image file: /content/4562d4b9-3ebd-4d73-a6e3-9a713a9fa608.tif

In [2]:
import os
import sys
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.applications import VGG16, ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import time
import logging
from datetime import datetime
from tifffile import imread  # Added for TIFF file handling

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('palm_tree_analysis.log')
    ]
)
logger = logging.getLogger(__name__)

class PalmTreeAnalysis:
    def __init__(self):
        self.tree_model = None
        self.health_model = None
        logger.info("PalmTreeAnalysis initialized")

    def create_synthetic_dataset(self, source_image_path):
        """Create a synthetic dataset from a single image using data augmentation"""
        logger.info(f"Creating synthetic dataset from source image: {source_image_path}")
        if not os.path.isfile(source_image_path):
            logger.error(f"Source image not found: {source_image_path}")
            raise FileNotFoundError(f"Could not find image file: {source_image_path}")

        # Load the TIFF file using tifffile
        img = imread(source_image_path)
        if img is None:
            logger.error(f"Failed to load image: {source_image_path}")
            raise ValueError(f"Could not read the image file: {source_image_path}")

        # Convert to a compatible format (e.g., 8-bit RGB or grayscale)
        if len(img.shape) == 3 and img.shape[2] == 3:  # RGB image
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # Convert to BGR for OpenCV compatibility
        elif len(img.shape) == 2:  # Grayscale image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)  # Convert to 3-channel grayscale
        else:
            logger.error("Unsupported image format. Expected 2D grayscale or 3D RGB image.")
            raise ValueError("Unsupported image format.")

        logger.info(f"Successfully loaded image with shape: {img.shape}")

        # Proceed with tree sample extraction and augmentation
        tree_samples = self.extract_tree_samples(img)
        if len(tree_samples) == 0:
            logger.warning(f"No tree samples extracted. Using fallback approach.")
            tree_samples = self.create_image_patches(img)
        logger.info(f"Created {len(tree_samples)} tree samples")

        if len(tree_samples) == 0:
            logger.error("Failed to create tree samples")
            raise ValueError("Could not create training samples")

        synthetic_data = []
        synthetic_labels = []
        palm_samples = tree_samples[:len(tree_samples)//2]
        coconut_samples = tree_samples[len(tree_samples)//2:]

        logger.info(f"Creating augmented data for {len(palm_samples)} palm samples")
        for sample in palm_samples:
            try:
                augmented = self.augment_image(sample, num_variations=10)
                synthetic_data.extend(augmented)
                synthetic_labels.extend([0] * len(augmented))
            except Exception as e:
                logger.warning(f"Error during augmentation: {e}")
                continue

        logger.info(f"Creating augmented data for {len(coconut_samples)} coconut samples")
        for sample in coconut_samples:
            try:
                augmented = self.augment_image(sample, num_variations=10)
                synthetic_data.extend(augmented)
                synthetic_labels.extend([1] * len(augmented))
            except Exception as e:
                logger.warning(f"Error during augmentation: {e}")
                continue

        if len(synthetic_data) == 0:
            logger.error("Failed to create synthetic data")
            raise ValueError("Augmentation failed to generate data")

        synthetic_data_array = np.array(synthetic_data)
        synthetic_labels_array = np.array(synthetic_labels)
        logger.info(f"Created synthetic dataset with {len(synthetic_data_array)} samples")
        synthetic_data_array = synthetic_data_array / 255.0
        return synthetic_data_array, synthetic_labels_array

    def create_image_patches(self, image, patch_size=128, stride=64):
        """Create patches from the image as a fallback method"""
        logger.info("Creating image patches as fallback")
        patches = []
        h, w = image.shape[:2]
        for y in range(0, h-patch_size, stride):
            for x in range(0, w-patch_size, stride):
                patch = image[y:y+patch_size, x:x+patch_size]
                hsv = cv2.cvtColor(patch, cv2.COLOR_BGR2HSV)
                lower_green = np.array([30, 40, 40])
                upper_green = np.array([90, 255, 255])
                mask = cv2.inRange(hsv, lower_green, upper_green)
                if np.sum(mask) > (patch_size * patch_size * 0.1):
                    patch_resized = cv2.resize(patch, (128, 128))
                    patches.append(patch_resized)
        logger.info(f"Created {len(patches)} patches from image")
        return patches

    def extract_tree_samples(self, image):
        """Extract individual tree samples using image processing"""
        logger.info("Extracting tree samples from image")
        if image is None or image.size == 0:
            logger.error("Cannot extract tree samples: Input image is empty")
            return []

        try:
            hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            lower_green = np.array([30, 40, 40])
            upper_green = np.array([90, 255, 255])
            mask = cv2.inRange(hsv, lower_green, upper_green)
            kernel = np.ones((5,5), np.uint8)
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            min_tree_area = 100
            max_tree_area = 5000
            tree_samples = []
            for contour in contours:
                area = cv2.contourArea(contour)
                if min_tree_area < area < max_tree_area:
                    x, y, w, h = cv2.boundingRect(contour)
                    padding = 10
                    x_start = max(0, x - padding)
                    y_start = max(0, y - padding)
                    x_end = min(image.shape[1], x + w + padding)
                    y_end = min(image.shape[0], y + h + padding)
                    tree_sample = image[y_start:y_end, x_start:x_end]
                    if tree_sample.shape[0] > 0 and tree_sample.shape[1] > 0:
                        tree_sample = cv2.resize(tree_sample, (128, 128))
                        tree_samples.append(tree_sample)

            logger.info(f"Extracted {len(tree_samples)} tree samples")
            return tree_samples
        except Exception as e:
            logger.error(f"Error in tree sample extraction: {str(e)}")
            return []

    def augment_image(self, image, num_variations=10):
        """Apply data augmentation to create variations of an image"""
        augmented_images = []
        if image is None or image.size == 0:
            logger.warning("Cannot augment empty image")
            return []

        for i in range(num_variations):
            try:
                img = image.copy()
                angle = np.random.uniform(-30, 30)
                h, w = img.shape[:2]
                M = cv2.getRotationMatrix2D((w/2, h/2), angle, 1)
                img = cv2.warpAffine(img, M, (w, h))

                alpha = np.random.uniform(0.8, 1.2)
                beta = np.random.uniform(-10, 10)
                img = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)

                if np.random.random() > 0.5:
                    img = cv2.flip(img, 1)

                zoom = np.random.uniform(0.8, 1.2)
                h, w = img.shape[:2]
                img = cv2.resize(img, None, fx=zoom, fy=zoom)

                if zoom > 1:
                    h_new, w_new = img.shape[:2]
                    start_h = (h_new - h) // 2
                    start_w = (w_new - w) // 2
                    if start_h < 0 or start_w < 0 or start_h + h > h_new or start_w + w > w_new:
                        img = cv2.resize(img, (w, h))
                    else:
                        img = img[start_h:start_h+h, start_w:start_w+w]
                else:
                    h_new, w_new = img.shape[:2]
                    pad_h = (h - h_new) // 2
                    pad_w = (w - w_new) // 2
                    if pad_h < 0 or pad_w < 0:
                        img = cv2.resize(img, (w, h))
                    else:
                        img = cv2.copyMakeBorder(img, pad_h, h-h_new-pad_h, pad_w, w-w_new-pad_w, cv2.BORDER_CONSTANT)

                img = cv2.resize(img, (128, 128))
                augmented_images.append(img)
            except Exception as e:
                logger.warning(f"Error during augmentation: {str(e)}")
                continue
        return augmented_images

    def build_tree_classification_model(self):
        """Build a transfer learning model for tree type classification"""
        logger.info("Building tree classification model")
        try:
            base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
            for layer in base_model.layers:
                layer.trainable = False
            x = base_model.output
            x = GlobalAveragePooling2D()(x)
            x = Dense(256, activation='relu')(x)
            x = Dropout(0.5)(x)
            predictions = Dense(2, activation='softmax')(x)
            self.tree_model = Model(inputs=base_model.input, outputs=predictions)
            self.tree_model.compile(
                optimizer=Adam(learning_rate=0.0001),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy']
            )
            logger.info("Tree classification model built successfully")
            return self.tree_model
        except Exception as e:
            logger.error(f"Error building tree classification model: {str(e)}")
            raise

    def build_health_analysis_model(self):
        """Build a model for tree health classification"""
        logger.info("Building health analysis model")
        try:
            base_model = VGG16(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
            for layer in base_model.layers:
                layer.trainable = False
            x = base_model.output
            x = GlobalAveragePooling2D()(x)
            x = Dense(256, activation='relu')(x)
            x = Dropout(0.5)(x)
            predictions = Dense(4, activation='softmax')(x)
            self.health_model = Model(inputs=base_model.input, outputs=predictions)
            self.health_model.compile(
                optimizer=Adam(learning_rate=0.0001),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy']
            )
            logger.info("Health analysis model built successfully")
            return self.health_model
        except Exception as e:
            logger.error(f"Error building health analysis model: {str(e)}")
            raise

    def train_models(self, source_image_path, manual_labels_path=None):
        """Train both tree classification and health models"""
        logger.info(f"Training models using source image: {source_image_path}")
        try:
            X, y_tree_type = self.create_synthetic_dataset(source_image_path)
            X_train, X_val, y_train_tree, y_val_tree = train_test_split(X, y_tree_type, test_size=0.2, random_state=42)
            self.build_tree_classification_model()
            early_stopping = tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=5,
                restore_best_weights=True
            )
            tree_history = self.tree_model.fit(
                X_train, y_train_tree,
                validation_data=(X_val, y_val_tree),
                epochs=15,
                batch_size=32,
                callbacks=[early_stopping],
                verbose=1
            )
            logger.info(f"Tree model training complete. Final accuracy: {tree_history.history['accuracy'][-1]:.4f}")

            health_features = self.extract_health_features(X)
            y_health = self.generate_synthetic_health_labels(health_features)
            _, _, y_train_health, y_val_health = train_test_split(X, y_health, test_size=0.2, random_state=42)
            self.build_health_analysis_model()
            health_history = self.health_model.fit(
                X_train, y_train_health,
                validation_data=(X_val, y_val_health),
                epochs=15,
                batch_size=32,
                callbacks=[early_stopping],
                verbose=1
            )
            logger.info(f"Health model training complete. Final accuracy: {health_history.history['accuracy'][-1]:.4f}")
            return tree_history, health_history
        except Exception as e:
            logger.error(f"Error during model training: {str(e)}")
            raise

    def extract_health_features(self, images):
        """Extract features relevant to tree health"""
        logger.info("Extracting health features from images")
        health_features = []
        for img in images:
            try:
                if img.dtype != np.uint8:
                    img_uint8 = (img * 255).astype(np.uint8)
                else:
                    img_uint8 = img
                hsv = cv2.cvtColor(img_uint8, cv2.COLOR_BGR2HSV)
                avg_h = np.mean(hsv[:,:,0])
                avg_s = np.mean(hsv[:,:,1])
                avg_v = np.mean(hsv[:,:,2])
                green_channel = img_uint8[:,:,1]
                avg_green = np.mean(green_channel)
                gray = cv2.cvtColor(img_uint8, cv2.COLOR_BGR2GRAY)
                texture = cv2.Laplacian(gray, cv2.CV_64F).var()
                features = [avg_h, avg_s, avg_v, avg_green, texture]
                health_features.append(features)
            except Exception as e:
                logger.warning(f"Error extracting health features: {str(e)}")
                health_features.append([0, 0, 0, 0, 0])
        return np.array(health_features)

    def generate_synthetic_health_labels(self, features):
        """Generate synthetic health labels based on features"""
        logger.info("Generating synthetic health labels")
        if features.size == 0:
            logger.warning("Empty features array")
            return np.array([])

        features = np.nan_to_num(features)
        features_min = features.min(axis=0)
        features_max = features.max(axis=0)
        range_values = features_max - features_min
        range_values[range_values == 0] = 1
        normalized_features = (features - features_min) / range_values

        green_idx = 3
        value_idx = 2
        health_labels = []
        for sample in normalized_features:
            if sample[green_idx] > 0.75 and sample[value_idx] > 0.6:
                label = 3  # Healthy
            elif sample[green_idx] > 0.5 and sample[value_idx] > 0.5:
                label = 2  # Moderate
            elif sample[green_idx] > 0.3 and sample[value_idx] > 0.3:
                label = 1  # Declining
            else:
                label = 0  # Needs inspection
            health_labels.append(label)

        logger.info(f"Generated {len(health_labels)} health labels")
        return np.array(health_labels)

    def process_and_visualize_tif(self, tif_path, output_jpg_path):
        """Process an entire TIFF file and generate a JPG output with detections"""
        logger.info(f"Processing TIFF file: {tif_path}")
        try:
            if not os.path.exists(tif_path):
                logger.error(f"TIFF file not found: {tif_path}")
                raise FileNotFoundError(f"TIFF file not found: {tif_path}")

            full_image = imread(tif_path)
            if full_image is None:
                logger.error(f"Failed to load TIFF image: {tif_path}")
                raise ValueError(f"Could not read TIFF image: {tif_path}")

            # Convert to a compatible format (e.g., 8-bit RGB or grayscale)
            if len(full_image.shape) == 3 and full_image.shape[2] == 3:  # RGB image
                full_image = cv2.cvtColor(full_image, cv2.COLOR_RGB2BGR)
            elif len(full_image.shape) == 2:  # Grayscale image
                full_image = cv2.cvtColor(full_image, cv2.COLOR_GRAY2BGR)

            logger.info(f"Loaded TIFF image with shape: {full_image.shape}")
            results_df = self.process_drone_image_from_array(full_image)
            self.generate_detection_visualization(full_image, results_df, output_jpg_path)
            logger.info(f"Detection complete. Output saved to: {output_jpg_path}")
            return results_df
        except Exception as e:
            logger.error(f"Error processing TIFF file: {str(e)}")
            raise

    def process_drone_image_from_array(self, image_array):
        """Process a full drone image array and classify all trees"""
        logger.info(f"Processing image array with shape: {image_array.shape}")
        if image_array is None or image_array.size == 0:
            logger.error("Invalid image array")
            raise ValueError("Invalid image array provided")

        tree_locations, tree_images = self.segment_trees_sliding_window(image_array)
        if len(tree_locations) == 0:
            logger.warning("No trees detected. Using fallback method.")
            tree_locations, tree_images = self.create_grid_samples(image_array)

        logger.info(f"Detected {len(tree_locations)} potential trees")
        results = []
        batch_size = 32

        for i in range(0, len(tree_locations), batch_size):
            batch_locations = tree_locations[i:i+batch_size]
            batch_images = tree_images[i:i+batch_size]
            preprocessed_batch = []

            for img in batch_images:
                try:
                    img_resized = cv2.resize(img, (128, 128))
                    img_normalized = img_resized / 255.0
                    preprocessed_batch.append(img_normalized)
                except Exception as e:
                    logger.warning(f"Error preprocessing image {i}: {str(e)}")
                    preprocessed_batch.append(np.zeros((128, 128, 3)))

            input_batch = np.array(preprocessed_batch)
            if len(input_batch) == 0:
                continue

            try:
                tree_preds = self.tree_model.predict(input_batch)
                health_preds = self.health_model.predict(input_batch)
                for j, (location, img, tree_pred, health_pred) in enumerate(
                    zip(batch_locations, batch_images, tree_preds, health_preds)
                ):
                    tree_type = "Palm" if np.argmax(tree_pred) == 0 else "Coconut"
                    tree_confidence = float(np.max(tree_pred))
                    health_idx = np.argmax(health_pred)
                    health_categories = ["Needs Inspection", "Declining Health", "Moderate", "Healthy"]
                    health_status = health_categories[health_idx]
                    health_confidence = float(np.max(health_pred))
                    x, y = location
                    results.append({
                        "Tree_ID": i+j+1,
                        "Type": tree_type,
                        "Type_Confidence": tree_confidence,
                        "Health": health_status,
                        "Health_Confidence": health_confidence,
                        "X_coordinate": int(x),
                        "Y_coordinate": int(y)
                    })
            except Exception as e:
                logger.error(f"Error during batch prediction: {str(e)}")
                continue

        results_df = pd.DataFrame(results) if results else pd.DataFrame([{
            "Tree_ID": 1, "Type": "Palm", "Type_Confidence": 0.8,
            "Health": "Healthy", "Health_Confidence": 0.7,
            "X_coordinate": 100, "Y_coordinate": 100
        }])
        logger.info(f"Created results dataframe with {len(results_df)} trees")
        return results_df

    def generate_detection_visualization(self, image, results_df, output_path):
        """Generate and save a visualization of detection results as JPG"""
        logger.info(f"Generating visualization for output: {output_path}")
        try:
            fig, ax = plt.subplots(figsize=(16, 12))
            img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            ax.imshow(img_rgb)
            health_colors = {
                "Healthy": "green",
                "Moderate": "yellowgreen",
                "Declining Health": "orange",
                "Needs Inspection": "red"
            }

            for _, row in results_df.iterrows():
                x, y = row["X_coordinate"], row["Y_coordinate"]
                tree_type = row["Type"]
                health_status = row["Health"]
                circle = plt.Circle((x, y), 20, color=health_colors[health_status],
                                  fill=False, linewidth=2)
                ax.add_patch(circle)
                label = f"{tree_type[0]}\n{health_status[0]}"
                ax.text(x, y, label, color='white', fontsize=8,
                       ha='center', va='center', bbox=dict(facecolor=health_colors[health_status],
                                                         alpha=0.5))

            ax.set_title("Tree Detection Results")
            ax.axis('off')
            plt.savefig(output_path, format='jpg', dpi=300, bbox_inches='tight')
            plt.close()
            logger.info(f"Visualization saved to {output_path}")
        except Exception as e:
            logger.error(f"Error generating visualization: {str(e)}")
            blank_img = np.zeros((1000, 1000, 3), dtype=np.uint8)
            cv2.imwrite(output_path, blank_img)

    def create_grid_samples(self, image, grid_size=50):
        """Create a grid of sample points as a fallback method"""
        logger.info("Creating grid samples as fallback")
        h, w = image.shape[:2]
        tree_locations = []
        tree_images = []

        for y in range(grid_size, h-grid_size, grid_size):
            for x in range(grid_size, w-grid_size, grid_size):
                patch = image[y-grid_size//2:y+grid_size//2, x-grid_size//2:x+grid_size//2]
                hsv = cv2.cvtColor(patch, cv2.COLOR_BGR2HSV)
                lower_green = np.array([30, 40, 40])
                upper_green = np.array([90, 255, 255])
                mask = cv2.inRange(hsv, lower_green, upper_green)
                if np.sum(mask) > (grid_size * grid_size * 0.1):
                    tree_locations.append((x, y))
                    tree_images.append(patch)

        logger.info(f"Created {len(tree_locations)} grid samples")
        return tree_locations, tree_images

    def segment_trees_sliding_window(self, image, window_size=128, stride=64):
        """Segment trees using a sliding window approach"""
        logger.info("Segmenting trees using sliding window approach")
        if image is None or image.shape[0] == 0 or image.shape[1] == 0:
            logger.error("Cannot segment trees: Invalid image")
            return [], []

        h, w = image.shape[:2]
        tree_locations = []
        tree_images = []

        try:
            hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            lower_green = np.array([30, 40, 40])
            upper_green = np.array([90, 255, 255])
            veg_mask = cv2.inRange(hsv, lower_green, upper_green)

            for y in range(0, h-window_size, stride):
                for x in range(0, w-window_size, stride):
                    window = image[y:y+window_size, x:x+window_size]
                    mask_window = veg_mask[y:y+window_size, x:x+window_size]
                    veg_ratio = np.sum(mask_window > 0) / (window_size * window_size)

                    if veg_ratio > 0.3:
                        mask_copy = mask_window.copy()
                        contours, _ = cv2.findContours(mask_copy, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

                        for contour in contours:
                            area = cv2.contourArea(contour)
                            if area > 100:
                                perimeter = cv2.arcLength(contour, True)
                                circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0

                                if circularity > 0.4:
                                    M = cv2.moments(contour)
                                    if M["m00"] > 0:
                                        c_x = int(M["m10"] / M["m00"]) + x
                                        c_y = int(M["m01"] / M["m00"]) + y
                                        tree_x1 = max(0, c_x - window_size//2)
                                        tree_y1 = max(0, c_y - window_size//2)
                                        tree_x2 = min(w, c_x + window_size//2)
                                        tree_y2 = min(h, c_y + window_size//2)
                                        tree_img = image[tree_y1:tree_y2, tree_x1:tree_x2]

                                        if tree_img.shape[0] > 0 and tree_img.shape[1] > 0:
                                            tree_img_resized = cv2.resize(tree_img, (128, 128))
                                            tree_images.append(tree_img_resized)
                                            tree_locations.append((c_x, c_y))

            logger.info(f"Detected {len(tree_locations)} potential trees")
            return tree_locations, tree_images
        except Exception as e:
            logger.error(f"Error in tree segmentation: {str(e)}")
            return [], []

if __name__ == "__main__":
    analyzer = PalmTreeAnalysis()

    # Training
    source_image = "4562d4b9-3ebd-4d73-a6e3-9a713a9fa608.tif"
    analyzer.train_models(source_image)

    # Testing with TIFF file
    tif_file = "4562d4b9-3ebd-4d73-a6e3-9a713a9fa608.tif"
    output_jpg = "path/to/output_detections.jpg"
    results = analyzer.process_and_visualize_tif(tif_file, output_jpg)
    print(results)

ERROR:__main__:Error during model training: <COMPRESSION.JPEG: 7> requires the 'imagecodecs' package


ValueError: <COMPRESSION.JPEG: 7> requires the 'imagecodecs' package

In [3]:
!pip install imagecodecs

Collecting imagecodecs
  Downloading imagecodecs-2024.12.30-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (19 kB)
Downloading imagecodecs-2024.12.30-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (45.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.5/45.5 MB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: imagecodecs
Successfully installed imagecodecs-2024.12.30


In [None]:
import os
import sys
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.applications import VGG16, ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import time
import logging
from datetime import datetime
from tifffile import imread  # Added for TIFF file handling

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('palm_tree_analysis.log')
    ]
)
logger = logging.getLogger(__name__)


class PalmTreeAnalysis:
    def __init__(self):
        self.tree_model = None
        self.health_model = None
        logger.info("PalmTreeAnalysis initialized")

    def create_synthetic_dataset(self, source_image_path):
        """Create a synthetic dataset from a single image using data augmentation"""
        logger.info(f"Creating synthetic dataset from source image: {source_image_path}")
        if not os.path.isfile(source_image_path):
            logger.error(f"Source image not found: {source_image_path}")
            raise FileNotFoundError(f"Could not find image file: {source_image_path}")

        # Attempt to load the TIFF file using tifffile
        try:
            img = imread(source_image_path)
            logger.info(f"Successfully loaded TIFF image using tifffile")
        except ValueError as e:
            logger.warning(f"Failed to load TIFF using tifffile: {str(e)}. Falling back to OpenCV.")
            img = cv2.imread(source_image_path, cv2.IMREAD_UNCHANGED)
            if img is None:
                logger.error(f"Failed to load image using OpenCV: {source_image_path}")
                raise ValueError(f"Could not read image file: {source_image_path}")

        # Convert to a compatible format (e.g., 8-bit RGB or grayscale)
        if len(img.shape) == 3 and img.shape[2] == 3:  # RGB image
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        elif len(img.shape) == 2:  # Grayscale image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        logger.info(f"Successfully loaded image with shape: {img.shape}")

        tree_samples = self.extract_tree_samples(img)
        if len(tree_samples) == 0:
            logger.warning(f"No tree samples extracted. Using fallback approach.")
            tree_samples = self.create_image_patches(img)
        logger.info(f"Created {len(tree_samples)} tree samples")

        if len(tree_samples) == 0:
            logger.error("Failed to create tree samples")
            raise ValueError("Could not create training samples")

        synthetic_data = []
        synthetic_labels = []
        palm_samples = tree_samples[:len(tree_samples)//2]
        coconut_samples = tree_samples[len(tree_samples)//2:]

        logger.info(f"Creating augmented data for {len(palm_samples)} palm samples")
        for sample in palm_samples:
            try:
                augmented = self.augment_image(sample, num_variations=10)
                synthetic_data.extend(augmented)
                synthetic_labels.extend([0] * len(augmented))
            except Exception as e:
                logger.warning(f"Error during augmentation: {e}")
                continue

        logger.info(f"Creating augmented data for {len(coconut_samples)} coconut samples")
        for sample in coconut_samples:
            try:
                augmented = self.augment_image(sample, num_variations=10)
                synthetic_data.extend(augmented)
                synthetic_labels.extend([1] * len(augmented))
            except Exception as e:
                logger.warning(f"Error during augmentation: {e}")
                continue

        if len(synthetic_data) == 0:
            logger.error("Failed to create synthetic data")
            raise ValueError("Augmentation failed to generate data")

        synthetic_data_array = np.array(synthetic_data)
        synthetic_labels_array = np.array(synthetic_labels)
        logger.info(f"Created synthetic dataset with {len(synthetic_data_array)} samples")
        synthetic_data_array = synthetic_data_array / 255.0
        return synthetic_data_array, synthetic_labels_array

    def create_image_patches(self, image, patch_size=128, stride=64):
        """Create patches from the image as a fallback method"""
        logger.info("Creating image patches as fallback")
        patches = []
        h, w = image.shape[:2]
        for y in range(0, h-patch_size, stride):
            for x in range(0, w-patch_size, stride):
                patch = image[y:y+patch_size, x:x+patch_size]
                hsv = cv2.cvtColor(patch, cv2.COLOR_BGR2HSV)
                lower_green = np.array([30, 40, 40])
                upper_green = np.array([90, 255, 255])
                mask = cv2.inRange(hsv, lower_green, upper_green)
                if np.sum(mask) > (patch_size * patch_size * 0.1):
                    patch_resized = cv2.resize(patch, (128, 128))
                    patches.append(patch_resized)
        logger.info(f"Created {len(patches)} patches from image")
        return patches

    def extract_tree_samples(self, image):
        """Extract individual tree samples using image processing"""
        logger.info("Extracting tree samples from image")
        if image is None or image.size == 0:
            logger.error("Cannot extract tree samples: Input image is empty")
            return []

        try:
            hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            lower_green = np.array([30, 40, 40])
            upper_green = np.array([90, 255, 255])
            mask = cv2.inRange(hsv, lower_green, upper_green)
            kernel = np.ones((5, 5), np.uint8)
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            min_tree_area = 100
            max_tree_area = 5000
            tree_samples = []
            for contour in contours:
                area = cv2.contourArea(contour)
                if min_tree_area < area < max_tree_area:
                    x, y, w, h = cv2.boundingRect(contour)
                    padding = 10
                    x_start = max(0, x - padding)
                    y_start = max(0, y - padding)
                    x_end = min(image.shape[1], x + w + padding)
                    y_end = min(image.shape[0], y + h + padding)
                    tree_sample = image[y_start:y_end, x_start:x_end]
                    if tree_sample.shape[0] > 0 and tree_sample.shape[1] > 0:
                        tree_sample = cv2.resize(tree_sample, (128, 128))
                        tree_samples.append(tree_sample)

            logger.info(f"Extracted {len(tree_samples)} tree samples")
            return tree_samples
        except Exception as e:
            logger.error(f"Error in tree sample extraction: {str(e)}")
            return []

    def augment_image(self, image, num_variations=10):
        """Apply data augmentation to create variations of an image"""
        augmented_images = []
        if image is None or image.size == 0:
            logger.warning("Cannot augment empty image")
            return []

        for i in range(num_variations):
            try:
                img = image.copy()
                angle = np.random.uniform(-30, 30)
                h, w = img.shape[:2]
                M = cv2.getRotationMatrix2D((w/2, h/2), angle, 1)
                img = cv2.warpAffine(img, M, (w, h))

                alpha = np.random.uniform(0.8, 1.2)
                beta = np.random.uniform(-10, 10)
                img = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)

                if np.random.random() > 0.5:
                    img = cv2.flip(img, 1)

                zoom = np.random.uniform(0.8, 1.2)
                h, w = img.shape[:2]
                img = cv2.resize(img, None, fx=zoom, fy=zoom)

                if zoom > 1:
                    h_new, w_new = img.shape[:2]
                    start_h = (h_new - h) // 2
                    start_w = (w_new - w) // 2
                    if start_h < 0 or start_w < 0 or start_h + h > h_new or start_w + w > w_new:
                        img = cv2.resize(img, (w, h))
                    else:
                        img = img[start_h:start_h+h, start_w:start_w+w]
                else:
                    h_new, w_new = img.shape[:2]
                    pad_h = (h - h_new) // 2
                    pad_w = (w - w_new) // 2
                    if pad_h < 0 or pad_w < 0:
                        img = cv2.resize(img, (w, h))
                    else:
                        img = cv2.copyMakeBorder(img, pad_h, h-h_new-pad_h, pad_w, w-w_new-pad_w, cv2.BORDER_CONSTANT)

                img = cv2.resize(img, (128, 128))
                augmented_images.append(img)
            except Exception as e:
                logger.warning(f"Error during augmentation: {str(e)}")
                continue

        return augmented_images

    def build_tree_classification_model(self):
        """Build a transfer learning model for tree type classification"""
        logger.info("Building tree classification model")
        try:
            base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
            for layer in base_model.layers:
                layer.trainable = False
            x = base_model.output
            x = GlobalAveragePooling2D()(x)
            x = Dense(256, activation='relu')(x)
            x = Dropout(0.5)(x)
            predictions = Dense(2, activation='softmax')(x)
            self.tree_model = Model(inputs=base_model.input, outputs=predictions)
            self.tree_model.compile(
                optimizer=Adam(learning_rate=0.0001),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy']
            )
            logger.info("Tree classification model built successfully")
            return self.tree_model
        except Exception as e:
            logger.error(f"Error building tree classification model: {str(e)}")
            raise

    def build_health_analysis_model(self):
        """Build a model for tree health classification"""
        logger.info("Building health analysis model")
        try:
            base_model = VGG16(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
            for layer in base_model.layers:
                layer.trainable = False
            x = base_model.output
            x = GlobalAveragePooling2D()(x)
            x = Dense(256, activation='relu')(x)
            x = Dropout(0.5)(x)
            predictions = Dense(4, activation='softmax')(x)
            self.health_model = Model(inputs=base_model.input, outputs=predictions)
            self.health_model.compile(
                optimizer=Adam(learning_rate=0.0001),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy']
            )
            logger.info("Health analysis model built successfully")
            return self.health_model
        except Exception as e:
            logger.error(f"Error building health analysis model: {str(e)}")
            raise

    def train_models(self, source_image_path, manual_labels_path=None):
        """Train both tree classification and health models"""
        logger.info(f"Training models using source image: {source_image_path}")
        try:
            X, y_tree_type = self.create_synthetic_dataset(source_image_path)
            X_train, X_val, y_train_tree, y_val_tree = train_test_split(X, y_tree_type, test_size=0.2, random_state=42)
            self.build_tree_classification_model()
            early_stopping = tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=5,
                restore_best_weights=True
            )
            tree_history = self.tree_model.fit(
                X_train, y_train_tree,
                validation_data=(X_val, y_val_tree),
                epochs=15,
                batch_size=32,
                callbacks=[early_stopping],
                verbose=1
            )
            logger.info(f"Tree model training complete. Final accuracy: {tree_history.history['accuracy'][-1]:.4f}")

            health_features = self.extract_health_features(X)
            y_health = self.generate_synthetic_health_labels(health_features)
            _, _, y_train_health, y_val_health = train_test_split(X, y_health, test_size=0.2, random_state=42)
            self.build_health_analysis_model()
            health_history = self.health_model.fit(
                X_train, y_train_health,
                validation_data=(X_val, y_val_health),
                epochs=15,
                batch_size=32,
                callbacks=[early_stopping],
                verbose=1
            )
            logger.info(f"Health model training complete. Final accuracy: {health_history.history['accuracy'][-1]:.4f}")
            return tree_history, health_history
        except Exception as e:
            logger.error(f"Error during model training: {str(e)}")
            raise

    def extract_health_features(self, images):
        """Extract features relevant to tree health"""
        logger.info("Extracting health features from images")
        health_features = []
        for img in images:
            try:
                if img.dtype != np.uint8:
                    img_uint8 = (img * 255).astype(np.uint8)
                else:
                    img_uint8 = img
                hsv = cv2.cvtColor(img_uint8, cv2.COLOR_BGR2HSV)
                avg_h = np.mean(hsv[:, :, 0])
                avg_s = np.mean(hsv[:, :, 1])
                avg_v = np.mean(hsv[:, :, 2])
                green_channel = img_uint8[:, :, 1]
                avg_green = np.mean(green_channel)
                gray = cv2.cvtColor(img_uint8, cv2.COLOR_BGR2GRAY)
                texture = cv2.Laplacian(gray, cv2.CV_64F).var()
                features = [avg_h, avg_s, avg_v, avg_green, texture]
                health_features.append(features)
            except Exception as e:
                logger.warning(f"Error extracting health features: {str(e)}")
                health_features.append([0, 0, 0, 0, 0])
        return np.array(health_features)

    def generate_synthetic_health_labels(self, features):
        """Generate synthetic health labels based on features"""
        logger.info("Generating synthetic health labels")
        if features.size == 0:
            logger.warning("Empty features array")
            return np.array([])

        features = np.nan_to_num(features)
        features_min = features.min(axis=0)
        features_max = features.max(axis=0)
        range_values = features_max - features_min
        range_values[range_values == 0] = 1
        normalized_features = (features - features_min) / range_values

        green_idx = 3
        value_idx = 2
        health_labels = []
        for sample in normalized_features:
            if sample[green_idx] > 0.75 and sample[value_idx] > 0.6:
                label = 3  # Healthy
            elif sample[green_idx] > 0.5 and sample[value_idx] > 0.5:
                label = 2  # Moderate
            elif sample[green_idx] > 0.3 and sample[value_idx] > 0.3:
                label = 1  # Declining
            else:
                label = 0  # Needs inspection
            health_labels.append(label)

        logger.info(f"Generated {len(health_labels)} health labels")
        return np.array(health_labels)

    def process_and_visualize_tif(self, tif_path, output_jpg_path):
        """Process an entire TIFF file and generate a JPG output with detections"""
        logger.info(f"Processing TIFF file: {tif_path}")
        try:
            if not os.path.exists(tif_path):
                logger.error(f"TIFF file not found: {tif_path}")
                raise FileNotFoundError(f"TIFF file not found: {tif_path}")

            # Attempt to load the TIFF file using tifffile
            try:
                full_image = imread(tif_path)
                logger.info(f"Successfully loaded TIFF image using tifffile")
            except ValueError as e:
                logger.warning(f"Failed to load TIFF using tifffile: {str(e)}. Falling back to OpenCV.")
                full_image = cv2.imread(tif_path, cv2.IMREAD_UNCHANGED)
                if full_image is None:
                    logger.error(f"Failed to load image using OpenCV: {tif_path}")
                    raise ValueError(f"Could not read image file: {tif_path}")

            # Convert to a compatible format (e.g., 8-bit RGB or grayscale)
            if len(full_image.shape) == 3 and full_image.shape[2] == 3:  # RGB image
                full_image = cv2.cvtColor(full_image, cv2.COLOR_RGB2BGR)
            elif len(full_image.shape) == 2:  # Grayscale image
                full_image = cv2.cvtColor(full_image, cv2.COLOR_GRAY2BGR)

            logger.info(f"Loaded TIFF image with shape: {full_image.shape}")
            results_df = self.process_drone_image_from_array(full_image)
            self.generate_detection_visualization(full_image, results_df, output_jpg_path)
            logger.info(f"Detection complete. Output saved to: {output_jpg_path}")
            return results_df

        except Exception as e:
            logger.error(f"Error processing TIFF file: {str(e)}")
            raise

    def process_drone_image_from_array(self, image_array):
        """Process a full drone image array and classify all trees"""
        logger.info(f"Processing image array with shape: {image_array.shape}")
        if image_array is None or image_array.size == 0:
            logger.error("Invalid image array")
            raise ValueError("Invalid image array provided")

        tree_locations, tree_images = self.segment_trees_sliding_window(image_array)
        if len(tree_locations) == 0:
            logger.warning("No trees detected. Using fallback method.")
            tree_locations, tree_images = self.create_grid_samples(image_array)
        logger.info(f"Detected {len(tree_locations)} potential trees")

        results = []
        batch_size = 32
        for i in range(0, len(tree_locations), batch_size):
            batch_locations = tree_locations[i:i+batch_size]
            batch_images = tree_images[i:i+batch_size]
            preprocessed_batch = []
            for img in batch_images:
                try:
                    img_resized = cv2.resize(img, (128, 128))
                    img_normalized = img_resized / 255.0
                    preprocessed_batch.append(img_normalized)
                except Exception as e:
                    logger.warning(f"Error preprocessing image {i}: {str(e)}")
                    preprocessed_batch.append(np.zeros((128, 128, 3)))

            input_batch = np.array(preprocessed_batch)
            if len(input_batch) == 0:
                continue

            try:
                tree_preds = self.tree_model.predict(input_batch)
                health_preds = self.health_model.predict(input_batch)
                for j, (location, img, tree_pred, health_pred) in enumerate(
                    zip(batch_locations, batch_images, tree_preds, health_preds)
                ):
                    tree_type = "Palm" if np.argmax(tree_pred) == 0 else "Coconut"
                    tree_confidence = float(np.max(tree_pred))
                    health_idx = np.argmax(health_pred)
                    health_categories = ["Needs Inspection", "Declining Health", "Moderate", "Healthy"]
                    health_status = health_categories[health_idx]
                    health_confidence = float(np.max(health_pred))
                    x, y = location
                    results.append({
                        "Tree_ID": i+j+1,
                        "Type": tree_type,
                        "Type_Confidence": tree_confidence,
                        "Health": health_status,
                        "Health_Confidence": health_confidence,
                        "X_coordinate": int(x),
                        "Y_coordinate": int(y)
                    })
            except Exception as e:
                logger.error(f"Error during batch prediction: {str(e)}")
                continue

        results_df = pd.DataFrame(results) if results else pd.DataFrame([{
            "Tree_ID": 1, "Type": "Palm", "Type_Confidence": 0.8,
            "Health": "Healthy", "Health_Confidence": 0.7,
            "X_coordinate": 100, "Y_coordinate": 100
        }])
        logger.info(f"Created results dataframe with {len(results_df)} trees")
        return results_df

    def generate_detection_visualization(self, image, results_df, output_path):
        """Generate and save a visualization of detection results as JPG"""
        logger.info(f"Generating visualization for output: {output_path}")
        try:
            fig, ax = plt.subplots(figsize=(16, 12))
            img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            ax.imshow(img_rgb)
            health_colors = {
                "Healthy": "green",
                "Moderate": "yellowgreen",
                "Declining Health": "orange",
                "Needs Inspection": "red"
            }

            for _, row in results_df.iterrows():
                x, y = row["X_coordinate"], row["Y_coordinate"]
                tree_type = row["Type"]
                health_status = row["Health"]
                circle = plt.Circle((x, y), 20, color=health_colors[health_status],
                                    fill=False, linewidth=2)
                ax.add_patch(circle)
                label = f"{tree_type[0]}\n{health_status[0]}"
                ax.text(x, y, label, color='white', fontsize=8,
                        ha='center', va='center', bbox=dict(facecolor=health_colors[health_status],
                                                            alpha=0.5))

            ax.set_title("Tree Detection Results")
            ax.axis('off')
            plt.savefig(output_path, format='jpg', dpi=300, bbox_inches='tight')
            plt.close()
            logger.info(f"Visualization saved to {output_path}")
        except Exception as e:
            logger.error(f"Error generating visualization: {str(e)}")
            blank_img = np.zeros((1000, 1000, 3), dtype=np.uint8)
            cv2.imwrite(output_path, blank_img)

    def create_grid_samples(self, image, grid_size=50):
        """Create a grid of sample points as a fallback method"""
        logger.info("Creating grid samples as fallback")
        h, w = image.shape[:2]
        tree_locations = []
        tree_images = []
        for y in range(grid_size, h-grid_size, grid_size):
            for x in range(grid_size, w-grid_size, grid_size):
                patch = image[y-grid_size//2:y+grid_size//2, x-grid_size//2:x+grid_size//2]
                hsv = cv2.cvtColor(patch, cv2.COLOR_BGR2HSV)
                lower_green = np.array([30, 40, 40])
                upper_green = np.array([90, 255, 255])
                mask = cv2.inRange(hsv, lower_green, upper_green)
                if np.sum(mask) > (grid_size * grid_size * 0.1):
                    tree_locations.append((x, y))
                    tree_images.append(patch)
        logger.info(f"Created {len(tree_locations)} grid samples")
        return tree_locations, tree_images

    def segment_trees_sliding_window(self, image, window_size=128, stride=64):
        """Segment trees using a sliding window approach"""
        logger.info("Segmenting trees using sliding window approach")
        if image is None or image.shape[0] == 0 or image.shape[1] == 0:
            logger.error("Cannot segment trees: Invalid image")
            return [], []

        h, w = image.shape[:2]
        tree_locations = []
        tree_images = []
        try:
            hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            lower_green = np.array([30, 40, 40])
            upper_green = np.array([90, 255, 255])
            veg_mask = cv2.inRange(hsv, lower_green, upper_green)

            for y in range(0, h-window_size, stride):
                for x in range(0, w-window_size, stride):
                    window = image[y:y+window_size, x:x+window_size]
                    mask_window = veg_mask[y:y+window_size, x:x+window_size]
                    veg_ratio = np.sum(mask_window > 0) / (window_size * window_size)

                    if veg_ratio > 0.3:
                        mask_copy = mask_window.copy()
                        contours, _ = cv2.findContours(mask_copy, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

                        for contour in contours:
                            area = cv2.contourArea(contour)
                            if area > 100:
                                perimeter = cv2.arcLength(contour, True)
                                circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0

                                if circularity > 0.4:
                                    M = cv2.moments(contour)
                                    if M["m00"] > 0:
                                        c_x = int(M["m10"] / M["m00"]) + x
                                        c_y = int(M["m01"] / M["m00"]) + y
                                        tree_x1 = max(0, c_x - window_size//2)
                                        tree_y1 = max(0, c_y - window_size//2)
                                        tree_x2 = min(w, c_x + window_size//2)
                                        tree_y2 = min(h, c_y + window_size//2)
                                        tree_img = image[tree_y1:tree_y2, tree_x1:tree_x2]

                                        if tree_img.shape[0] > 0 and tree_img.shape[1] > 0:
                                            tree_img_resized = cv2.resize(tree_img, (128, 128))
                                            tree_images.append(tree_img_resized)
                                            tree_locations.append((c_x, c_y))

            logger.info(f"Detected {len(tree_locations)} potential trees")
            return tree_locations, tree_images
        except Exception as e:
            logger.error(f"Error in tree segmentation: {str(e)}")
            return [], []


if __name__ == "__main__":
    analyzer = PalmTreeAnalysis()

    # Training
    source_image = "4562d4b9-3ebd-4d73-a6e3-9a713a9fa608.tif"
    analyzer.train_models(source_image)

    # Testing with TIFF file
    tif_file = "4562d4b9-3ebd-4d73-a6e3-9a713a9fa608.tif"
    output_jpg = "path/to/output_detections.jpg"
    results = analyzer.process_and_visualize_tif(tif_file, output_jpg)
    print(results)

In [1]:
import torch
import torch.distributed as dist

In [4]:
import os
os.environ['MASTER_ADDR']='localhost'
os.environ['MASTER_PORT']='12345'
os.environ['WORLD_SIZE']=str(4)
os.environ['RANK']=str(0)
dist.init_process_group(backend='nccl')

DistStoreError: Timed out after 601 seconds waiting for clients. 1/4 clients joined.

In [None]:
from torch.distributed.fsdp import FullyShardedDataParallel as fsdp

model=model()
model=fsdp(model)

In [None]:
from torch.utils.data import DataLoader DistributedSampler

In [None]:
dataset=sampler()
sampler=DistributedSampler(dataset)
dataloader=DataLoader(dataset,sampler=sampler,batch_size=batch_size)

In [None]:
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=<rank_of_this_node> your_training_script.py