In [1]:
# @title Cell 1: Setup and Imports

# **Important:** Run this cell first to install necessary libraries.

!pip install torch torchvision opencv-python scikit-image albumentations tqdm matplotlib

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
import os
from PIL import Image
import numpy as np
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from skimage.restoration import denoise_nl_means, denoise_bilateral, denoise_wavelet
from skimage import img_as_float, img_as_ubyte
from skimage.transform import rotate
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

print("Setup and imports complete.")

Setup and imports complete.


In [22]:
!pip install 'git+https://github.com/CASIA-IVA-Lab/FastSAM.git'
!pip install ultralytics

from fastsam import FastSAM, FastSAMPrompt
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os

def fastsam_segmentation(image_path, device="cuda", model_path="FastSAM.pt"):  # Or "FastSAM-x.pt"
    """Segment the largest object in an image (assumed to be the leaf) with FastSAM.

    Runs FastSAM in "everything" mode and keeps the mask with the most foreground pixels.

    Args:
        image_path: Path to an image readable by ``cv2.imread``.
        device: Torch device string. If a CUDA device is requested but CUDA is not
            available, the function falls back to ``"cpu"`` instead of failing.
        model_path: Path to the FastSAM checkpoint (e.g. ``"FastSAM.pt"`` or ``"FastSAM-x.pt"``).

    Returns:
        2-D ``np.uint8`` mask with values 0/1. When nothing is detected, an all-zero
        mask of the image's (H, W) is returned.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not open the image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Fall back to CPU so the function also works on machines without a GPU.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    # NOTE: the model is (re)loaded on every call; hoist it out when segmenting many images.
    model = FastSAM(model_path)
    model.to(device)

    everything_results = model(image, device=device, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
    # The prompt object supports several prompt types; "everything" returns all masks.
    prompt_process = FastSAMPrompt(image, everything_results, device=device)
    ann = prompt_process.everything_prompt()
    # Alternatives:
    #   prompt_process.box_prompt(bbox=[[x1, y1, x2, y2]])            # top-left, bottom-right
    #   prompt_process.point_prompt(points=[[x, y]], point_label=[1])  # 1 = foreground, 0 = background

    empty_mask = np.zeros(image.shape[:2], dtype=np.uint8)
    if len(ann) == 0:
        print("No objects detected.")
        return empty_mask

    # Keep the largest mask (assumption: the leaf is the biggest segment in the picture).
    largest_mask = None
    largest_area = 0
    for mask in ann:
        # Masks are normally torch tensors; accept array-likes too.
        mask_np = mask.cpu().numpy() if hasattr(mask, "cpu") else np.asarray(mask)
        area = np.count_nonzero(mask_np)
        if area > largest_area:
            largest_area = area
            largest_mask = mask_np

    if largest_mask is None:
        return empty_mask
    # uint8 0/1 so callers can scale by 255 and save it as an image.
    return (largest_mask > 0).astype(np.uint8)
# --- Example Usage ---
if __name__ == "__main__":
    # Download the FastSAM checkpoint once (if it is not already present).
    if not os.path.exists("FastSAM.pt"):
        print("Downloading FastSAM.pt...")
        exit_code = os.system("wget -q https://github.com/CASIA-IVA-Lab/FastSAM/releases/download/v1.0/FastSAM.pt")
        # A failed wget can leave a partial file that the existence check above would
        # accept on the next run; remove it and stop with a clear message instead.
        if exit_code != 0:
            if os.path.exists("FastSAM.pt"):
                os.remove("FastSAM.pt")
            raise RuntimeError("Downloading FastSAM.pt failed; download it manually into the working directory.")
    image_path = "/teamspace/studios/this_studio/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Augmented Dataset/Bacterial Blight/augmented_Bacterial Blight0001.jpg"  # Replace with your image path.
    output_mask = fastsam_segmentation(image_path)  # 0/1 uint8 mask
    # Scale 0/1 to 0/255 so the saved PNG is a visible black/white image.
    cv2.imwrite("fastsam_output_mask.png", output_mask * 255)
    print("Segmentation mask saved to fastsam_output_mask.png")
    plt.imshow(output_mask, cmap="gray")
    plt.title("FastSAM mask")
    plt.axis("off")
    plt.show()

Collecting git+https://github.com/CASIA-IVA-Lab/FastSAM.git
  Cloning https://github.com/CASIA-IVA-Lab/FastSAM.git to /tmp/pip-req-build-q7uy3weg
  Running command git clone --filter=blob:none --quiet https://github.com/CASIA-IVA-Lab/FastSAM.git /tmp/pip-req-build-q7uy3weg

  Resolved https://github.com/CASIA-IVA-Lab/FastSAM.git to commit b4ed20c2fed75eadc5aa7d8b09fedd137b873b52
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting CLIP@ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33#egg=CLIP (from fastsam==0.1.1)
  Using cached clip-1.0-py3-none-any.whl


ModuleNotFoundError: No module named 'ultralytics.yolo.cfg'

In [12]:
# Debug: inspect the last encoder block of a ViT-style model.
# NOTE(review): no visible cell defines a module-level `model` with an `.encoder`
# attribute, so an unguarded access raises NameError on a fresh kernel
# (Restart & Run All). Guard it so the rest of the notebook still runs.
_model_to_inspect = globals().get("model")
if _model_to_inspect is not None and hasattr(_model_to_inspect, "encoder"):
    print(_model_to_inspect.encoder.layers[-1])
else:
    print("Skipping encoder inspection: no `model` with an `.encoder` attribute is defined.")

EncoderBlock(
  (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (self_attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (mlp): MLPBlock(
    (0): Linear(in_features=768, out_features=3072, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=3072, out_features=768, bias=True)
    (4): Dropout(p=0.0, inplace=False)
  )
)


In [5]:
# @title Cell 2: Leaf Segmentation (U-Net) - Model and Dataset Definition

# --- Simplified U-Net Model ---
class UNet(nn.Module):
    """Simplified U-Net returning per-pixel logits.

    The encoder applies one ``DoubleConv`` per entry of ``features`` followed by 2x2
    max-pooling. A bottleneck ``DoubleConv`` doubles the deepest width. The decoder
    mirrors the encoder with 2x2 transposed convolutions and skip connections.
    Output shape is ``(N, out_channels, H, W)`` (raw logits, no sigmoid).

    Args:
        in_channels: Number of input image channels.
        out_channels: Number of output channels (1 for a binary mask).
        features: Encoder widths, shallowest first. The default is a tuple so it cannot
            be mutated and shared between instances; any sequence of ints works.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNet
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part: (transposed conv, DoubleConv) pairs, deepest level first.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            # Input is the upsampled map concatenated with the skip connection.
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest skip first, matching self.ups

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # Odd spatial sizes lose a pixel in max-pooling, so the upsampled map can be
            # smaller than its skip connection; resize the spatial dims to match.
            if x.shape[-2:] != skip_connection.shape[-2:]:
                x = nn.functional.interpolate(
                    x, size=tuple(skip_connection.shape[-2:]), mode="bilinear", align_corners=False
                )

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)


class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding 1 keeps the spatial size."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            # bias=False: the following BatchNorm has its own learnable shift.
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)



# --- Dataset ---
class CottonLeafSegmentationDataset(Dataset):
    """Image/mask pairs for leaf segmentation, organised by class sub-directory.

    Expected layout::

        image_dir/<class>/<name>.jpg|.jpeg|.png
        mask_dir/<class>/<name>_mask.png

    ``mask_dir`` may live inside ``image_dir``; it is then skipped when scanning
    for class folders.

    Attributes:
        image_paths: Image paths relative to ``image_dir`` ("<class>/<file>").
        labels: Class index of each entry of ``image_paths``.
        class_names: Sorted class folder names; ``labels`` index into this list.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Sorted for a deterministic class -> index mapping (os.listdir order is arbitrary),
        # computed once instead of on every __getitem__ call.
        mask_dir_abs = os.path.abspath(mask_dir)
        self.class_names = sorted(
            item for item in os.listdir(image_dir)
            if os.path.isdir(os.path.join(image_dir, item))
            and os.path.abspath(os.path.join(image_dir, item)) != mask_dir_abs
        )

        for class_idx, class_name in enumerate(self.class_names):
            class_path = os.path.join(image_dir, class_name)
            for filename in sorted(os.listdir(class_path)):
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_name, filename))
                    self.labels.append(class_idx)

    def __len__(self):
        return len(self.image_paths)

    def _mask_path(self, index):
        """Mask path for item ``index``: ``mask_dir/<class>/<stem>_mask.png``."""
        # image_paths already holds "<class>/<file>", so the class folder must not be
        # joined a second time.
        stem = os.path.splitext(self.image_paths[index])[0]
        return os.path.join(self.mask_dir, stem + "_mask.png")

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.image_paths[index])
        mask_path = self._mask_path(index)

        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as mask_img:
            mask = np.array(mask_img.convert("L"), dtype=np.float32)  # L: grayscale
        mask[mask > 0.0] = 1.0  # Ensure binary mask (0 and 1)

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask




# --- Training Loop ---
def train_segmentation_model(image_dir, mask_dir, epochs=25, batch_size=8, learning_rate=1e-4,
                             save_path="unet_segmentation_model.pth"):
    """Train a ``UNet`` on ``CottonLeafSegmentationDataset`` and save its weights.

    Args:
        image_dir: Root folder with one sub-folder of images per class.
        mask_dir: Root folder with the matching ``<name>_mask.png`` files per class.
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size.
        learning_rate: Adam learning rate.
        save_path: Where the trained ``state_dict`` is written.

    Returns:
        The trained ``UNet`` (left on the training device).

    Raises:
        ValueError: If no training images are found (previously this surfaced later
            as an UnboundLocalError on ``loss``).
    """
    # Augmentations are applied identically to image and mask; Normalize touches the image only.
    transform = A.Compose([
        A.Resize(height=256, width=256),
        A.Rotate(limit=35, p=0.7),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),  # Convert to tensor
    ])
    dataset = CottonLeafSegmentationDataset(image_dir, mask_dir, transform=transform)
    if len(dataset) == 0:
        raise ValueError(f"No training images found under {image_dir!r}.")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(in_channels=3, out_channels=1).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.BCEWithLogitsLoss()  # sigmoid + BCE in one numerically stable op

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        loop = tqdm(dataloader, total=len(dataloader), leave=False,
                    desc=f"Epoch [{epoch + 1}/{epochs}]")
        for data, targets in loop:
            data = data.to(device=device)
            targets = targets.float().unsqueeze(1).to(device=device)  # Add channel dimension

            predictions = model(data)
            loss = loss_fn(predictions, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller last batch does not skew the epoch mean.
            running_loss += loss.item() * data.size(0)
            loop.set_postfix(loss=loss.item())

        epoch_loss = running_loss / len(dataset)
        print(f"Epoch {epoch + 1}/{epochs}, Mean loss: {epoch_loss:.4f}")

    torch.save(model.state_dict(), save_path)
    return model


# --- Inference (using the trained model) ---
def segment_leaf(image_path, model, device, resize_to_original=False):
    """Predict a binary leaf mask for one image with a trained segmentation network.

    Args:
        image_path: Path to an image readable by PIL.
        model: Trained network producing one logit map per image (e.g. ``UNet``);
            it is switched to eval mode here.
        device: Torch device the model lives on.
        resize_to_original: If False (default), the mask is returned at the 256x256
            network resolution. If True, probabilities are resized to the input image's
            (H, W) before thresholding, so the mask can be applied to the original pixels.

    Returns:
        float32 numpy array of 0.0/1.0 values (batch and channel dims squeezed).
    """
    model.eval()
    image = Image.open(image_path).convert("RGB")

    transform = A.Compose([  # Inference transform (no augmentation)
        A.Resize(height=256, width=256),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ])

    augmented = transform(image=np.array(image))
    image_tensor = augmented["image"].unsqueeze(0).to(device)  # Add batch dimension

    with torch.no_grad():
        probabilities = torch.sigmoid(model(image_tensor))
        if resize_to_original:
            # Resize probabilities (not the thresholded mask) for smoother edges.
            probabilities = nn.functional.interpolate(
                probabilities, size=(image.height, image.width), mode="bilinear", align_corners=False
            )
        predicted_mask = (probabilities > 0.5).float()  # Threshold to get binary mask

    return predicted_mask.squeeze().cpu().numpy()

In [6]:
# @title Cell 3: U-Net Training and Inference - Example Usage
import sys

# --- VERY IMPORTANT: Set up paths correctly! ---

# 1. Image Directory: Points to your 'Original Dataset' folder.
# NOTE(review): the FastSAM cell uses a different dataset root
# ("SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection"); confirm which is correct.
image_directory = "../SAR-CLD-2024 A Comprehensive Review of Current Research, Challenges, and Future Directions/Original Dataset"

# 2. Mask Directory:  Points to a 'masks' folder *inside* 'Original Dataset'.
#    This cell only creates the folders; YOU MUST CREATE THE MASK FILES MANUALLY.
mask_directory = "../SAR-CLD-2024 A Comprehensive Review of Current Research, Challenges, and Future Directions/Original Dataset/masks"

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _class_dirs(root, exclude=None):
    """Sorted sub-directory names of ``root``, skipping the ``exclude`` path (the nested mask folder)."""
    excluded = os.path.abspath(exclude) if exclude else None
    return sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
        and os.path.abspath(os.path.join(root, name)) != excluded
    )


def _has_any_mask(mask_root):
    """True if at least one ``*_mask.png`` file exists anywhere under ``mask_root``."""
    return any(
        filename.endswith("_mask.png")
        for _, _, filenames in os.walk(mask_root)
        for filename in filenames
    )


# os.makedirs below also creates missing parents, which would silently fabricate a
# wrong dataset path; fail early if the dataset itself is missing.
if not os.path.isdir(image_directory):
    raise FileNotFoundError(f"Image directory not found: {image_directory}")

# --- Create the 'masks' directory and its subfolders if they don't exist ---
if not os.path.exists(mask_directory):
    os.makedirs(mask_directory)
    # One mask subfolder per class folder. 'masks' itself now lives inside
    # image_directory and must not get a 'masks/masks' subfolder.
    for class_name in _class_dirs(image_directory, exclude=mask_directory):
        os.makedirs(os.path.join(mask_directory, class_name), exist_ok=True)
    print(f"Created mask directory and subdirectories: {mask_directory}")
    print("YOU MUST NOW MANUALLY CREATE SEGMENTATION MASKS.")
    print("Place the masks inside the corresponding class subfolders within the 'masks' directory.")
    print("For example, the mask for 'Original Dataset/Bacterial Blight/image1.jpg' should be at:")
    print("  'Original Dataset/masks/Bacterial Blight/image1_mask.png'")
    # Stop here: training needs mask files that do not exist yet.
    sys.exit("Please create masks and then re-run this cell and the following cells.")


# --- Create dummy images (ONLY for demonstration if a class directory is empty) ---
# In a REAL scenario you already have images and create the masks manually.
for class_name in _class_dirs(image_directory, exclude=mask_directory):
    class_path = os.path.join(image_directory, class_name)
    if not os.listdir(class_path):
        print(f"Creating dummy images in: {class_path}")
        for i in range(5):  # Create a few dummy images
            img = Image.new('RGB', (256, 256))  # Create a blank image
            img.save(os.path.join(class_path, f"image_{i}.jpg"))


# --- Train the U-Net Model (only if at least one mask file exists) ---
# A non-empty mask_directory is not enough: the class subfolders created above make it
# non-empty even before any mask has been drawn, and training would then crash.
if _has_any_mask(mask_directory):
    print("Training U-Net model...")
    trained_model = train_segmentation_model(image_directory, mask_directory)

    # --- Example Inference ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trained_model.to(device)

    # First image (sorted) of the first class folder that has one; skips the nested
    # 'masks' folder and stray files at the top level.
    example_image_path = None
    for class_name in _class_dirs(image_directory, exclude=mask_directory):
        class_path = os.path.join(image_directory, class_name)
        candidates = sorted(f for f in os.listdir(class_path) if f.lower().endswith(IMAGE_EXTENSIONS))
        if candidates:
            example_image_path = os.path.join(class_path, candidates[0])
            break

    if example_image_path is not None:
        segmented_mask = segment_leaf(example_image_path, trained_model, device)

        original_image = Image.open(example_image_path).convert("RGB")
        original_image_np = np.array(original_image)
        # segment_leaf predicts at 256x256; bring the mask to the image resolution
        # (nearest-neighbour keeps it binary) so it can be applied to the pixels.
        height, width = original_image_np.shape[:2]
        segmented_mask = cv2.resize(
            segmented_mask.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST
        )

        # Display the result
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(original_image)
        plt.title("Original Image")

        plt.subplot(1, 2, 2)
        plt.imshow(segmented_mask, cmap='gray')
        plt.title("Segmented Mask")
        plt.show()

        # --- Applying the mask ---
        # 1. Masking (setting background to black)
        masked_image = (original_image_np * segmented_mask[:, :, np.newaxis]).astype(np.uint8)

        # 2. Cropping (bounding box of the mask)
        coords = np.argwhere(segmented_mask > 0)
        if coords.size > 0:
            y_min, x_min = coords.min(axis=0)
            y_max, x_max = coords.max(axis=0)
            cropped_image = original_image_np[y_min:y_max + 1, x_min:x_max + 1].astype(np.uint8)

            plt.figure(figsize=(10, 5))
            plt.subplot(1, 2, 1)
            plt.imshow(masked_image)
            plt.title("Masked Image")

            plt.subplot(1, 2, 2)
            plt.imshow(cropped_image)
            plt.title("Cropped Image")
            plt.show()
        else:
            print("No leaf detected in the image.")
    else:
        print("Example Image not found.")
else:
    print("No '*_mask.png' files found in the mask directory.  Skipping U-Net training and inference.")
    print("Please create segmentation masks and place them in the 'masks' directory.")

Created mask directory and subdirectories: ../SAR-CLD-2024 A Comprehensive Review of Current Research, Challenges, and Future Directions/Original Dataset/masks
YOU MUST NOW MANUALLY CREATE SEGMENTATION MASKS.
Place the masks inside the corresponding class subfolders within the 'masks' directory.
For example, the mask for 'Original Dataset/Bacterial Blight/image1.jpg' should be at:
  'Original Dataset/masks/Bacterial Blight/image1_mask.png'


SystemExit: Please create masks and then re-run this cell and the following cells.

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
# @title Cell 4: Noise Reduction

def apply_noise_reduction(image_path):
    """Apply, visualise and return several denoising variants of one image.

    Args:
        image_path: Path to an image readable by ``cv2.imread``.

    Returns:
        Tuple of uint8 RGB arrays: (original, skimage NL-means, skimage bilateral,
        skimage wavelet, OpenCV NL-means).

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    original_bgr = cv2.imread(image_path)
    if original_bgr is None:
        raise FileNotFoundError(f"Could not open or find the image at {image_path}")

    original_image = cv2.cvtColor(original_bgr, cv2.COLOR_BGR2RGB)
    image = img_as_float(original_image)  # skimage denoisers work on floats in [0, 1]

    # `multichannel=True` was deprecated in scikit-image 0.19 and later removed;
    # `channel_axis=-1` is its replacement (the last axis holds the colour channels).
    denoised_nl_means = denoise_nl_means(image, patch_size=7, patch_distance=11, h=0.1, channel_axis=-1)
    denoised_bilateral = denoise_bilateral(image, sigma_color=0.05, sigma_spatial=15, channel_axis=-1)
    denoised_wavelet = denoise_wavelet(image, channel_axis=-1, convert2ycbcr=True, method='BayesShrink')

    # OpenCV's colour NL-means expects BGR input: feed it the BGR image from imread and
    # convert only the result to RGB (applying BGR2RGB to an RGB result swaps red/blue).
    denoised_cv2 = cv2.cvtColor(
        cv2.fastNlMeansDenoisingColored(original_bgr, None, 10, 10, 7, 21),
        cv2.COLOR_BGR2RGB,
    )

    # img_as_ubyte raises on floats outside [-1, 1]; clip any tiny overshoot first.
    nl_means_u8, bilateral_u8, wavelet_u8 = (
        img_as_ubyte(np.clip(result, 0.0, 1.0))
        for result in (denoised_nl_means, denoised_bilateral, denoised_wavelet)
    )

    # --- Visualization ---
    fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(20, 5))
    panels = [
        (original_image, 'Original'),
        (nl_means_u8, 'NL Means (skimage)'),
        (bilateral_u8, 'Bilateral (skimage)'),
        (wavelet_u8, 'Wavelet (skimage)'),
        (denoised_cv2, 'NL Means (OpenCV)'),
    ]
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return original_image, nl_means_u8, bilateral_u8, wavelet_u8, denoised_cv2


# --- Example Usage ---

# Create a dummy noisy image (for demonstration). Seeded so every re-run produces the
# same image, and therefore the same denoising comparison.
image_path = "noisy_image.jpg"
if not os.path.exists(image_path):
    noise_rng = np.random.default_rng(0)
    img = np.zeros((256, 256, 3), dtype=np.uint8)
    img[:] = (50, 100, 150)  # Fill with a solid RGB color
    noise = noise_rng.integers(0, 50, size=img.shape, dtype=np.uint8)  # values in [0, 50)
    noisy_img = cv2.add(img, noise)  # saturating add (no uint8 wrap-around)
    cv2.imwrite(image_path, cv2.cvtColor(noisy_img, cv2.COLOR_RGB2BGR))  # Save as BGR for OpenCV


original, nl_means, bilateral, wavelet, cv2_nl_means = apply_noise_reduction(image_path)

# Example of integrating into a PyTorch data pipeline:
# Convert the denoised image (e.g., nl_means) to a PIL Image
# denoised_pil = Image.fromarray(nl_means)
# Then, you can use torchvision.transforms to convert it to a tensor:
# transform = transforms.Compose([transforms.ToTensor()])
# denoised_tensor = transform(denoised_pil)

In [None]:
# @title Cell 5: Leaf Orientation Normalization

def normalize_leaf_orientation(image_path, segmented_mask=None):
    """Rotate an image so the leaf's principal axis (PCA of mask pixels) is horizontal.

    Args:
        image_path: Path to an image readable by ``cv2.imread``.
        segmented_mask: Optional 2-D leaf mask; any non-zero pixel counts as leaf, so 0/1,
            bool and 0/255 masks all work. If its size differs from the image (e.g. the
            256x256 U-Net output), it is resized with nearest-neighbour interpolation.
            If None, a fixed grayscale threshold at 127 is used as a fallback.

    Returns:
        ``(rotated_image, angle_deg, rotated_mask)``. ``rotated_image`` is the uint8 RGB
        image rotated by ``angle_deg`` degrees with ``resize=True``; ``angle_deg`` is folded
        into (-90, 90]; ``rotated_mask`` is a 0/1 uint8 mask. If fewer than two leaf pixels
        are found, the unrotated image, 0 and the unrotated 0/1 mask are returned.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not open or find image at {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    height, width = image.shape[:2]

    if segmented_mask is None:
        # Fallback without a mask: simple global threshold.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        _, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
        mask = thresh.astype(np.uint8)
    else:
        # Normalise any binary encoding to 0/255 (scaling an already-0/255 mask by 255
        # would overflow uint8).
        mask = (np.asarray(segmented_mask) > 0).astype(np.uint8) * 255
        if mask.shape[:2] != (height, width):
            mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    coords = np.argwhere(mask > 0)  # (row, col) of leaf pixels
    # The covariance needs at least two samples; otherwise no principal axis exists.
    if len(coords) < 2:
        print("Warning: Not enough leaf pixels found. Returning original image.")
        return image, 0, (mask > 0).astype(np.uint8)

    # PCA on pixel coordinates.
    mean = coords.mean(axis=0)
    covariance_matrix = np.cov(coords - mean, rowvar=False)
    # eigh is meant for symmetric matrices and returns eigenvalues in ascending order, so
    # the last column is the principal axis (np.linalg.eig gives no ordering guarantee).
    _, eigenvectors = np.linalg.eigh(covariance_matrix)
    principal_eigenvector = eigenvectors[:, -1]  # (d_row, d_col)
    angle_deg = np.degrees(np.arctan2(principal_eigenvector[0], principal_eigenvector[1]))
    # An eigenvector's sign is arbitrary; fold into (-90, 90] so the smaller of the two
    # equivalent rotations is applied.
    if angle_deg > 90:
        angle_deg -= 180
    elif angle_deg <= -90:
        angle_deg += 180
    angle_rad = np.radians(angle_deg)

    # Rotate
    rotated_image = rotate(image, angle_deg, resize=True, preserve_range=True)
    rotated_mask = rotate(mask, angle_deg, resize=True, preserve_range=True, order=0)

    # preserve_range=True yields floats in 0..255, which img_as_ubyte rejects (it expects
    # floats in [-1, 1]); round, clip and cast directly instead.
    rotated_image = np.clip(np.rint(rotated_image), 0, 255).astype(np.uint8)
    rotated_mask = (rotated_mask > 0.5).astype(np.uint8)

    # Visualization
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[1].imshow(mask, cmap='gray')
    axes[1].set_title('Segmentation Mask')
    center_y, center_x = mean
    # Arrow along the principal axis in display coordinates (x right, y down).
    axes[1].arrow(center_x, center_y, 50 * np.cos(angle_rad), 50 * np.sin(angle_rad),
                  color='red', width=2, head_width=8)
    axes[2].imshow(rotated_image)
    axes[2].set_title(f'Rotated Image ({angle_deg:.2f}°)')

    for ax in axes:
        ax.axis('off')
    plt.tight_layout()
    plt.show()
    return rotated_image, angle_deg, rotated_mask

In [None]:
# @title Cell 6: Orientation Normalization - Example Usage

# --- Dummy Image and Mask Creation (for demonstration) ---
image_path = "orientation_image.jpg"
mask_path = "orientation_mask.png"

if not os.path.exists(image_path) or not os.path.exists(mask_path):
    diagonal = np.arange(256)
    img = np.zeros((256, 256, 3), dtype=np.uint8)
    img[diagonal, diagonal] = (255, 255, 255)  # White diagonal line
    cv2.imwrite(image_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    mask = np.zeros((256, 256), dtype=np.uint8)
    mask[diagonal, diagonal] = 255
    cv2.imwrite(mask_path, mask)

# --- Example without a provided mask (using simple thresholding) ---
rotated_image, angle, _ = normalize_leaf_orientation(image_path)
print(f"Detected angle (no mask provided): {angle:.2f} degrees")

# --- Example *with* a mask (using U-Net output) ---
# Relies on Cells 2 and 3 (UNet, segment_leaf, image_directory, mask_directory) and on the
# weights file written by train_segmentation_model.
unet_weights_path = "unet_segmentation_model.pth"
if os.path.exists(unet_weights_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seg_model = UNet(in_channels=3, out_channels=1)
    # map_location lets weights saved on a GPU load on a CPU-only machine;
    # weights_only=True avoids unpickling arbitrary objects from the file.
    seg_model.load_state_dict(torch.load(unet_weights_path, map_location=device, weights_only=True))
    seg_model.to(device)

    # Use a *real* image: first image (sorted) of the first class folder, skipping the
    # nested 'masks' folder and stray top-level files.
    example_image_path = None
    if os.path.isdir(image_directory):
        mask_dir_abs = os.path.abspath(mask_directory)
        for class_name in sorted(os.listdir(image_directory)):
            class_path = os.path.join(image_directory, class_name)
            if not os.path.isdir(class_path) or os.path.abspath(class_path) == mask_dir_abs:
                continue
            candidates = sorted(
                f for f in os.listdir(class_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            )
            if candidates:
                example_image_path = os.path.join(class_path, candidates[0])
                break

    if example_image_path is not None:
        segmented_mask_example = segment_leaf(example_image_path, seg_model, device)
        # segment_leaf predicts at 256x256; resize to the image as read by cv2 (which
        # normalize_leaf_orientation uses) so the mask and image share one pixel grid.
        example_height, example_width = cv2.imread(example_image_path).shape[:2]
        segmented_mask_example = cv2.resize(
            segmented_mask_example.astype(np.uint8), (example_width, example_height),
            interpolation=cv2.INTER_NEAREST,
        )
        rotated_image_with_mask, angle_with_mask, rotated_mask = normalize_leaf_orientation(
            example_image_path, segmented_mask_example
        )
        print(f"Detected angle (with mask provided): {angle_with_mask:.2f} degrees")

        # Example of integrating into a PyTorch dataset via albumentations.
        # Use the image rotated together with this mask (`rotated_image` above comes from
        # the dummy example and has nothing to do with this mask).
        transform = A.Compose([
            A.Resize(height=256, width=256),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ])
        augmented = transform(
            image=rotated_image_with_mask,
            mask=(rotated_mask * 255).astype(np.uint8),
        )
        rotated_image_tensor = augmented['image']
        rotated_mask_tensor = augmented['mask']
    else:
        print("Skipping mask example: 'Original Dataset' is empty or does not exist.")

else:
    print("Skipping mask example: 'unet_segmentation_model.pth' not found. Run U-Net example first.")

In [None]:
# @title Cell 7: GAN-Based Synthetic Data Generation (cGAN)

# --- Simplified cGAN Generator ---
class Generator(nn.Module):
    """DCGAN-style conditional generator.

    The class label is embedded into an ``img_size * img_size`` vector and appended to the
    noise as extra channels of a 1x1 map. Five transposed convolutions then upsample
    1 -> 4 -> 8 -> 16 -> 32 -> 64, so the output is ``(N, img_channels, 64, 64)`` in
    ``[-1, 1]`` (Tanh) regardless of ``img_size``.

    NOTE(review): ``img_size`` only sizes the label embedding here; pairing with the
    Discriminator (which embeds labels as an img_size x img_size plane) effectively
    requires ``img_size == 64``.
    """

    def __init__(self, z_dim, img_channels, num_classes, features_g, img_size):
        super().__init__()
        self.img_size = img_size
        self.embed = nn.Embedding(num_classes, img_size * img_size)
        self.net = nn.Sequential(
            # Input: N x (z_dim + img_size*img_size) x 1 x 1  (noise + embedded label)
            self._block(z_dim + img_size * img_size, features_g * 16, 4, 1, 0),  # N x f_g*16 x 4 x 4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),  # N x f_g*8 x 8 x 8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),  # N x f_g*4 x 16 x 16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),  # N x f_g*2 x 32 x 32
            nn.ConvTranspose2d(
                features_g * 2, img_channels, kernel_size=4, stride=2, padding=1
            ),  # N x img_channels x 64 x 64
            nn.Tanh(),  # [-1, 1] range
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Transposed conv (no bias, BatchNorm follows) -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x, labels):
        """x: noise of shape (N, z_dim, 1, 1); labels: integer class indices of shape (N,)."""
        # The first layer expects z_dim + img_size*img_size channels at 1x1, so the label
        # embedding must be (N, img_size*img_size, 1, 1). Viewing it as
        # (N, img_size, img_size) made torch.cat fail on 3-D vs 4-D tensors.
        embedding = self.embed(labels).view(labels.shape[0], self.img_size * self.img_size, 1, 1)
        x = torch.cat([x, embedding], dim=1)  # Concatenate noise with embedded label
        return self.net(x)


# --- Simplified cGAN Discriminator ---
class Discriminator(nn.Module):
    """DCGAN-style conditional discriminator.

    The label is embedded as one extra ``img_size x img_size`` channel that is appended to
    the image, so inputs must be ``img_size`` square. With 64x64 inputs the spatial size
    goes 64 -> 32 -> 16 -> 8 -> 4 -> 1, and the output is a probability tensor of shape
    ``(N, 1, 1, 1)``.
    """

    def __init__(self, img_channels, num_classes, features_d, img_size):
        super().__init__()
        self.img_size = img_size
        self.embed = nn.Embedding(num_classes, img_size * img_size)

        # Stem: image channels + 1 label plane; no BatchNorm on the first layer.
        layers = [
            nn.Conv2d(img_channels + 1, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three strided blocks, doubling the width each time: f -> 2f -> 4f -> 8f.
        widths = [features_d * (2 ** level) for level in range(4)]
        for width_in, width_out in zip(widths, widths[1:]):
            layers.append(self._block(width_in, width_out, 4, 2, 1))
        # Head: 4x4 -> 1x1 (for 64x64 inputs), squashed to a real/fake probability.
        layers.extend([
            nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        ])
        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv (no bias, BatchNorm follows) -> BatchNorm -> LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2))

    def forward(self, x, labels):
        """x: images (N, img_channels, img_size, img_size); labels: class indices (N,)."""
        batch = labels.shape[0]
        label_plane = self.embed(labels).view(batch, 1, self.img_size, self.img_size)
        return self.net(torch.cat([x, label_plane], dim=1))

# --- Dataset ---
class CottonLeafDiseaseDataset(Dataset):
    """Class-labelled leaf images for cGAN training.

    Two folder layouts are supported:
      * numeric class folders ``root_dir/0`` ... ``root_dir/<num_classes-1>``; the folder
        number is the label;
      * named class folders such as ``root_dir/Bacterial Blight``, used only when no
        numeric folder exists. Labels follow the sorted folder names. Without this
        fallback a dataset organised by disease name loaded as empty.

    Attributes:
        image_paths: Full paths of the images.
        labels: Integer label for each entry of ``image_paths``.
        label_to_dir: Mapping from label to the class folder name it came from.

    Raises:
        FileNotFoundError: If no numeric folders exist and ``root_dir`` is missing.
        ValueError: If named folders are used and there are more than ``num_classes``
            (their labels would not fit the cGAN label embeddings).
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None, num_classes=4):
        self.root_dir = root_dir
        self.transform = transform
        self.num_classes = num_classes
        self.image_paths = []
        self.labels = []

        numeric_dirs = [
            (class_idx, str(class_idx)) for class_idx in range(num_classes)
            if os.path.isdir(os.path.join(root_dir, str(class_idx)))
        ]
        if numeric_dirs:
            class_dirs = numeric_dirs
        else:
            named_dirs = sorted(
                item for item in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, item))
            )
            if len(named_dirs) > num_classes:
                raise ValueError(
                    f"Found {len(named_dirs)} class folders in {root_dir!r} but num_classes={num_classes}."
                )
            class_dirs = list(enumerate(named_dirs))
        self.label_to_dir = dict(class_dirs)

        for class_idx, dir_name in class_dirs:
            class_dir = os.path.join(root_dir, dir_name)
            for filename in sorted(os.listdir(class_dir)):
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_dir, filename))
                    self.labels.append(class_idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            # albumentations-style transform: keyword `image=` in, dict with "image" out.
            image = self.transform(image=np.array(image))['image']

        return image, label

# --- Training Loop ---
def _save_class_samples(netG, num_classes, samples_per_class, z_dim, output_dir, epoch, device):
    """Generate and save ``samples_per_class`` PNGs per class for one epoch.

    Files are named ``epoch_{epoch}_class_{class_idx}_img_{i}.png``. Fresh
    noise is drawn for every class, so saved samples stay diverse.
    """
    netG.eval()
    with torch.no_grad():
        for class_idx in range(num_classes):
            noise = torch.randn(samples_per_class, z_dim, 1, 1, device=device)
            class_labels = torch.full((samples_per_class,), class_idx, dtype=torch.long, device=device)
            # Undo the mean=0.5 / std=0.5 input normalization and clamp so values
            # can never wrap around when cast to uint8.
            images = (netG(noise, class_labels) * 0.5 + 0.5).clamp(0.0, 1.0)
            images = (images.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)  # NCHW -> NHWC
            for i, img in enumerate(images):
                Image.fromarray(img).save(
                    os.path.join(output_dir, f"epoch_{epoch}_class_{class_idx}_img_{i}.png")
                )


def train_cgan(root_dir, epochs=100, batch_size=64, z_dim=100, lr=2e-4, img_size=64, num_classes=4,
               samples_per_class=16, output_dir="synthetic_images"):
    """Train a conditional GAN on ``root_dir`` and save per-class samples after every epoch.

    Args:
        root_dir: dataset root laid out as ``root_dir/<class_idx>/<image>``.
        epochs: number of passes over the dataset.
        batch_size: mini-batch size.
        z_dim: length of the latent noise vector fed to the generator.
        lr: Adam learning rate for both networks.
        img_size: side length images are resized to; also passed to G and D.
        num_classes: number of class labels the networks are conditioned on.
        samples_per_class: images saved per class after each epoch.
        output_dir: directory the per-epoch sample PNGs are written to.

    Returns:
        ``(netG, netD)``: the trained generator and discriminator.

    Raises:
        ValueError: if no training images are found under ``root_dir``.

    The per-epoch log line reports losses averaged over all batches of the epoch.
    """
    # Resize + random flip, then map pixels to [-1, 1] (mean = std = 0.5).
    transform = A.Compose(
        [
            A.Resize(height=img_size, width=img_size),
            A.HorizontalFlip(p=0.5),
            A.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5],
                max_pixel_value=255,
            ),
            ToTensorV2(),
        ]
    )

    dataset = CottonLeafDiseaseDataset(root_dir, transform=transform, num_classes=num_classes)
    if len(dataset) == 0:
        raise ValueError(
            f"No training images found under {root_dir!r} "
            f"(expected class folders 0..{num_classes - 1})"
        )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    netG = Generator(z_dim, 3, num_classes, 64, img_size).to(device)  # 3 for RGB channels
    netD = Discriminator(3, num_classes, 64, img_size).to(device)

    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    os.makedirs(output_dir, exist_ok=True)

    print("Starting Training...")
    num_batches = len(dataloader)  # >= 1 because the dataset is non-empty
    for epoch in range(epochs):
        netG.train()
        netD.train()
        sum_loss_d = 0.0
        sum_loss_g = 0.0
        loop = tqdm(dataloader, total=num_batches, leave=False, desc=f"Epoch [{epoch + 1}/{epochs}]")
        for real_images, labels in loop:
            real_images = real_images.to(device)
            labels = labels.to(device)
            n = real_images.size(0)  # the last batch may be smaller
            target_real = torch.ones(n, 1, 1, 1, device=device)
            target_fake = torch.zeros(n, 1, 1, 1, device=device)

            # --- Discriminator: maximize log D(x) + log(1 - D(G(z))) ---
            netD.zero_grad()
            output_real = netD(real_images, labels).view(-1, 1, 1, 1)
            lossD_real = criterion(output_real, target_real)
            lossD_real.backward()

            noise = torch.randn(n, z_dim, 1, 1, device=device)
            fake_labels = torch.randint(0, num_classes, (n,), device=device)  # random class labels
            fake_images = netG(noise, fake_labels)
            # detach(): the discriminator step must not backpropagate into G.
            output_fake = netD(fake_images.detach(), fake_labels).view(-1, 1, 1, 1)
            lossD_fake = criterion(output_fake, target_fake)
            lossD_fake.backward()
            optimizerD.step()
            loss_d = lossD_real.item() + lossD_fake.item()

            # --- Generator: non-saturating loss, maximize log D(G(z)) ---
            netG.zero_grad()
            output = netD(fake_images, fake_labels).view(-1, 1, 1, 1)
            lossG = criterion(output, target_real)  # G wants D to predict "real"
            lossG.backward()
            optimizerG.step()
            loss_g = lossG.item()

            sum_loss_d += loss_d
            sum_loss_g += loss_g
            loop.set_postfix(lossD=loss_d, lossG=loss_g)

        print(
            f"Epoch {epoch + 1}/{epochs}, Loss D: {sum_loss_d / num_batches:.4f}, "
            f"Loss G: {sum_loss_g / num_batches:.4f}"
        )

        _save_class_samples(netG, num_classes, samples_per_class, z_dim, output_dir, epoch, device)

    return netG, netD

In [None]:
# @title Cell 8: GAN Training and Image Generation

# --- Dataset location; dummy data is generated only if it is missing or empty ---
# NOTE(review): CottonLeafDiseaseDataset only scans class folders named
# "0".."num_classes-1". If the real dataset uses disease names as folder names,
# no images will be found -- TODO confirm the on-disk layout.
root_directory = "../SAR-CLD-2024 A Comprehensive Review of Current Research, Challenges, and Future Directions/Original Dataset"
num_classes = 4  # Example: 4 disease classes
img_size = 64

if os.path.isdir(root_directory):
    with os.scandir(root_directory) as entries:  # context manager closes the directory handle
        needs_dummy_data = not any(entries)
else:
    needs_dummy_data = True

if needs_dummy_data:
    os.makedirs(root_directory, exist_ok=True)
    for i in range(num_classes):
        class_dir = os.path.join(root_directory, str(i))
        os.makedirs(class_dir, exist_ok=True)
        for j in range(20):
            img = np.random.randint(0, 256, size=(img_size, img_size, 3), dtype=np.uint8)
            Image.fromarray(img).save(os.path.join(class_dir, f"img_{j}.jpg"))
    print(f"Created dummy dataset at: {root_directory}")


# --- Train the cGAN ---
# Use a smaller number of epochs for demonstration.  Increase for real training.
trained_generator, trained_discriminator = train_cgan(
    root_directory, epochs=5, img_size=img_size, num_classes=num_classes
)


# --- Example of using the trained generator ---
def generate_synthetic_images(generator, z_dim, num_classes, num_samples_per_class, device):
    """Draw ``num_samples_per_class`` samples for every class from ``generator``.

    Returns:
        ``(images, labels)``: ``images`` holds the generator outputs concatenated
        class by class along dim 0, rescaled with ``x * 0.5 + 0.5`` to undo the
        mean=0.5/std=0.5 training normalization; ``labels`` is a plain Python
        list with the matching class index for each image.
    """
    generator.eval()
    per_class_batches = []
    label_list = []
    with torch.no_grad():
        for cls in range(num_classes):
            z = torch.randn(num_samples_per_class, z_dim, 1, 1).to(device)
            cls_labels = torch.full((num_samples_per_class,), cls, dtype=torch.long).to(device)
            per_class_batches.append(generator(z, cls_labels) * 0.5 + 0.5)
            label_list += [cls] * num_samples_per_class
    return torch.cat(per_class_batches, dim=0), label_list

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
z_dim = 100  # must match the z_dim train_cgan was called with (its default is 100)
num_samples_per_class = 10
synthetic_images, synthetic_labels = generate_synthetic_images(trained_generator, z_dim, num_classes, num_samples_per_class, device)

# --- Visualize Synthetic Images ---
# squeeze=False keeps `axes` 2-D even when num_classes or num_samples_per_class is 1.
fig, axes = plt.subplots(nrows=num_classes, ncols=num_samples_per_class, figsize=(12, 6), squeeze=False)
# Clamp before the uint8 cast so out-of-range values cannot wrap around.
images_hwc = synthetic_images.clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
for i, img in enumerate(images_hwc):
    row, col = divmod(i, num_samples_per_class)
    ax = axes[row, col]
    ax.imshow((img * 255).astype(np.uint8))
    # Hide ticks and spines individually: ax.axis('off') would also hide the
    # per-row class label set below.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    if col == 0:
        ax.set_ylabel(f"Class {synthetic_labels[i]}", rotation=0, labelpad=30)
fig.suptitle("cGAN samples per class")
plt.tight_layout()
plt.show()

In [None]:
# @title Cell 9: Datasets for Real and Synthetic Data

# --- Dataset for REAL Cotton Leaf Images (for training and validation) ---
class RealCottonLeafDataset(Dataset):
    """Real cotton-leaf images with a deterministic, stratified train/validation split.

    Expects ``root_dir/<class_idx>/<image>`` with class folders named ``"0"`` ..
    ``str(num_classes - 1)``. All images are collected, then split with
    ``train_test_split(..., random_state=42, stratify=labels)``; ``train``
    selects which side of that split this instance exposes.

    Args:
        root_dir: dataset root directory.
        transform: albumentations-style callable, invoked as
            ``transform(image=np_array)["image"]``.
        train: True for the training split, False for the validation split.
        validation_split: fraction of images assigned to the validation split.
        num_classes: number of numbered class folders to scan (default 4).

    Raises:
        ValueError: if no images are found under ``root_dir``.
    """

    # Matched case-insensitively, so e.g. ".JPG" files are picked up too.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None, train=True, validation_split=0.2, num_classes=4):
        self.root_dir = root_dir
        self.transform = transform
        self.train = train
        self.image_paths = []
        self.labels = []

        for class_idx in range(num_classes):
            class_dir = os.path.join(root_dir, str(class_idx))
            if not os.path.isdir(class_dir):
                continue  # Skip if the class directory doesn't exist
            # Sorted: os.listdir order is filesystem-dependent, and a fixed
            # random_state only reproduces the split if the input order is fixed too.
            for filename in sorted(os.listdir(class_dir)):
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_dir, filename))
                    self.labels.append(class_idx)

        if not self.image_paths:
            raise ValueError(
                f"No images found under {root_dir!r} "
                f"(expected class folders 0..{num_classes - 1})"
            )

        train_paths, val_paths, train_labels, val_labels = train_test_split(
            self.image_paths, self.labels, test_size=validation_split, random_state=42, stratify=self.labels
        )

        if self.train:
            self.image_paths = train_paths
            self.labels = train_labels
        else:
            self.image_paths = val_paths
            self.labels = val_labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager closes the file handle; convert() loads the pixels first.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image=np.array(image))['image']
        return image, label


# --- Dataset for SYNTHETIC Cotton Leaf Images ---
class SyntheticCottonLeafDataset(Dataset):
    """Labelled cGAN samples read from one flat directory.

    Only files named ``epoch_{epoch}_class_{class_idx}_img_{i}.png`` are used;
    the label is the ``class_idx`` field. Other ``epoch_*.png`` files are
    skipped with a warning, and all remaining files are ignored silently.

    Args:
        root_dir: directory containing the generated PNGs.
        transform: albumentations-style callable, invoked as
            ``transform(image=np_array)["image"]``.
        min_epoch: keep only samples with ``epoch >= min_epoch``. Raise it to
            drop images produced by early, under-trained generator checkpoints.
            The default of 0 keeps everything.
    """

    def __init__(self, root_dir, transform=None, min_epoch=0):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Sorted so the sample order does not depend on the filesystem.
        for filename in sorted(os.listdir(self.root_dir)):
            if not (filename.endswith(".png") and filename.startswith("epoch_")):
                continue
            parsed = self._parse_filename(filename)
            if parsed is None:
                print(f"Warning: Skipping malformed filename: {filename}")
                continue
            epoch, class_idx = parsed
            if epoch < min_epoch:
                continue
            self.image_paths.append(os.path.join(self.root_dir, filename))
            self.labels.append(class_idx)

    @staticmethod
    def _parse_filename(filename):
        """Return ``(epoch, class_idx)`` for a well-formed sample name, else None."""
        parts = filename[:-len(".png")].split("_")
        if len(parts) != 6 or parts[0] != "epoch" or parts[2] != "class" or parts[4] != "img":
            return None
        if not (parts[1].isdecimal() and parts[3].isdecimal() and parts[5].isdecimal()):
            return None
        return int(parts[1]), int(parts[3])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager closes the file handle; convert() loads the pixels first.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image=np.array(image))['image']

        return image, label

In [None]:
# @title Cell 10: Combined Training Function

# --- Combined Training Function (Example with a simple CNN) ---
class _MixedBatchSampler(torch.utils.data.Sampler):
    """Yield index batches with a fixed real/synthetic mix over ``ConcatDataset([real, synthetic])``.

    Each batch holds ``num_real`` real and ``num_synthetic`` synthetic indices,
    shuffled together. A pool that runs out within an epoch is cycled again.
    """

    def __init__(self, real_dataset, synthetic_dataset, batch_size, synthetic_ratio):
        num_real_items = len(real_dataset)
        num_synth_items = len(synthetic_dataset)
        self.real_indices = list(range(num_real_items))
        # Offset by len(real): ConcatDataset places the synthetic items after the real ones.
        self.synthetic_indices = list(range(num_real_items, num_real_items + num_synth_items))
        self.batch_size = batch_size
        self.synthetic_ratio = synthetic_ratio
        self.num_synthetic = int(batch_size * synthetic_ratio)
        self.num_real = batch_size - self.num_synthetic
        if self.num_real > 0 and num_real_items == 0:
            raise ValueError("No real training images available for the real part of each batch.")
        if self.num_synthetic > 0 and num_synth_items == 0:
            raise ValueError(
                "No synthetic images found; run GAN training first or use synthetic_ratio=0."
            )
        # One epoch is one pass over the real pool (or over the synthetic pool when
        # batches are purely synthetic). Always yield at least one batch.
        if self.num_real > 0:
            self.length = max(1, num_real_items // self.num_real)
        else:
            self.length = max(1, num_synth_items // self.num_synthetic)

    @staticmethod
    def _endless(indices):
        """Cycle through ``indices`` forever; ends immediately if it is empty."""
        while indices:
            yield from indices

    def __iter__(self):
        np.random.shuffle(self.real_indices)
        np.random.shuffle(self.synthetic_indices)
        real_stream = self._endless(self.real_indices)
        synth_stream = self._endless(self.synthetic_indices)
        for _ in range(self.length):
            batch = [next(real_stream) for _ in range(self.num_real)]
            batch += [next(synth_stream) for _ in range(self.num_synthetic)]
            np.random.shuffle(batch)  # mix real and synthetic samples within the batch
            yield batch

    def __len__(self):
        return self.length  # number of *batches*


def _run_classifier_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Run one pass over ``loader``: trains when ``optimizer`` is given, else only evaluates.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``; both are NaN if the loader yields
        no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen, batches = 0.0, 0, 0, 0
    iterator = tqdm(loader, total=len(loader), leave=False, desc=desc) if training else loader
    with torch.set_grad_enabled(training):
        for inputs, labels in iterator:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            batches += 1
            seen += labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            if training:
                iterator.set_postfix(loss=loss.item(), acc=(100 * correct / seen))
    if seen == 0:
        return float("nan"), float("nan")
    return total_loss / batches, 100 * correct / seen


def train_combined(real_data_dir, synthetic_data_dir, epochs=20, batch_size=32, learning_rate=1e-4, synthetic_ratio=0.5):
    """
    Trains a classifier on a mix of real and synthetic images; validates on real images only.

    Args:
        real_data_dir: Path to the directory containing real cotton leaf images.
        synthetic_data_dir: Path to the directory containing synthetic images.
        epochs: Number of training epochs.
        batch_size: Batch size (must be >= 1).
        learning_rate: Adam learning rate.
        synthetic_ratio: Fraction of every training batch drawn from the synthetic
            images, in [0.0, 1.0] (``int(batch_size * synthetic_ratio)`` per batch).

    Returns:
        The trained model.

    Raises:
        ValueError: if ``synthetic_ratio`` is outside [0, 1], ``batch_size < 1``,
            or a data pool required by the batch mix is empty.
    """
    if not 0.0 <= synthetic_ratio <= 1.0:
        raise ValueError(f"synthetic_ratio must be in [0, 1], got {synthetic_ratio}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Transforms (adjust as needed for your main model)
    train_transform = A.Compose([
        A.Resize(height=224, width=224),  # Example size
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.RandomRotate90(p=0.5),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    val_transform = A.Compose([  # no augmentation for validation
        A.Resize(height=224, width=224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    real_train_dataset = RealCottonLeafDataset(real_data_dir, transform=train_transform, train=True)
    real_val_dataset = RealCottonLeafDataset(real_data_dir, transform=val_transform, train=False)  # only real data
    synthetic_train_dataset = SyntheticCottonLeafDataset(synthetic_data_dir, transform=train_transform)
    combined_dataset = ConcatDataset([real_train_dataset, synthetic_train_dataset])

    # A custom batch sampler enforces the real/synthetic ratio *within each batch*,
    # which a plain shuffled sampler over the ConcatDataset would not guarantee.
    sampler = _MixedBatchSampler(real_train_dataset, synthetic_train_dataset, batch_size, synthetic_ratio)
    train_loader = DataLoader(combined_dataset, batch_sampler=sampler, num_workers=4)
    val_loader = DataLoader(real_val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    # --- Model (replace with your ViT or ensemble) ---
    model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 4)  # 4 classes, matching RealCottonLeafDataset's folders
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        train_loss, train_accuracy = _run_classifier_epoch(
            model, train_loader, criterion, device, optimizer=optimizer, desc=f"Epoch [{epoch+1}/{epochs}]"
        )
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")

        # Validation on REAL data only.
        val_loss, val_accuracy = _run_classifier_epoch(model, val_loader, criterion, device)
        print(f"Epoch {epoch+1}/{epochs}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")

    print("Finished Training")
    return model

In [None]:
# @title Cell 11: Example Usage of Combined Training

# --- Set up paths ---
real_data_root = "../SAR-CLD-2024 A Comprehensive Review of Current Research, Challenges, and Future Directions/Original Dataset"
synthetic_data_root = "synthetic_images"  # Should have been created by the GAN

# --- Create Dummy Real Data (if needed) ---
# For demonstration only. Dummy images are written only when the dataset root is
# missing or empty, so an existing dataset (whatever its folder layout) is never
# mixed with random-noise images.
num_classes = 4
img_size = 224
if os.path.isdir(real_data_root):
    with os.scandir(real_data_root) as entries:  # context manager closes the directory handle
        needs_dummy_real_data = not any(entries)
else:
    needs_dummy_real_data = True

if needs_dummy_real_data:
    os.makedirs(real_data_root, exist_ok=True)
    for i in range(num_classes):
        class_dir = os.path.join(real_data_root, str(i))
        os.makedirs(class_dir, exist_ok=True)
        for j in range(30):  # more images per class than the GAN dummy data
            img = np.random.randint(0, 256, size=(img_size, img_size, 3), dtype=np.uint8)
            Image.fromarray(img).save(os.path.join(class_dir, f"real_img_{j}.jpg"))
    print(f"Created dummy REAL dataset at: {real_data_root}")

# --- Check that synthetic data exists and is non-empty ---
if not os.path.isdir(synthetic_data_root) or not os.listdir(synthetic_data_root):
    print("ERROR: Synthetic data directory not found or empty. Run GAN training first.")
else:
    # --- Train the combined model ---
    trained_model = train_combined(real_data_root, synthetic_data_root, epochs=10, synthetic_ratio=0.3)

    # --- Save the trained model ---
    torch.save(trained_model.state_dict(), "combined_model.pth")
    print("Trained model saved to 'combined_model.pth'")

In [None]:
# @title 4.3 Example Usage and Model Saving (Cell 12)
# Demo run of train_combined against a self-contained dummy dataset.
# Replace these paths with your actual data locations.
real_data_root = "dummy_real_dataset"
synthetic_data_root = "synthetic_images"  # written by the GAN training cell

# First run only: populate a random-noise "real" dataset that uses the same
# numbered class-folder layout as the GAN dummy data.
if not os.path.exists(real_data_root):
    os.makedirs(real_data_root)
    num_classes = 4
    img_size = 224  # should match the training configuration
    for class_id in range(num_classes):
        class_folder = os.path.join(real_data_root, str(class_id))
        os.makedirs(class_folder)
        for img_no in range(30):  # more images per class than the GAN dummy data
            noise_pixels = np.random.randint(0, 256, size=(img_size, img_size, 3), dtype=np.uint8)
            Image.fromarray(noise_pixels).save(os.path.join(class_folder, f"real_img_{img_no}.jpg"))

    print(f"Created dummy REAL dataset at: {real_data_root}")

# Training needs the synthetic images produced by the GAN step.
synthetic_missing = not os.path.exists(synthetic_data_root)
if synthetic_missing:
    print("ERROR: Synthetic data directory not found.  Run the GAN training code first.")
else:
    trained_model = train_combined(
        real_data_root,
        synthetic_data_root,
        epochs=10,
        synthetic_ratio=0.3,  # 30% of every batch is synthetic
    )
    # Persist the weights so they can be reloaded for evaluation.
    torch.save(trained_model.state_dict(), "combined_model.pth")

In [None]:
# @title Cell 12 (Optional): Loading and Evaluating the Combined Model

# --- Example of loading and evaluating the model ---
# This is optional and needs to be adapted to your specific evaluation setup.
# It assumes you have a separate *test set* of real images.

def evaluate_model(model_path, test_data_dir):
    """Load saved weights into a 4-class ResNet-18 and evaluate it on ``test_data_dir``.

    Every image under ``test_data_dir/<class_idx>/`` is scored. Both halves of
    RealCottonLeafDataset's internal train/validation split are concatenated,
    because a dedicated test directory should be used in full. Prints the
    accuracy, a confusion matrix (also plotted) and a classification report.

    Args:
        model_path: file written by ``torch.save(model.state_dict(), ...)``.
        test_data_dir: dataset root laid out like the training data.

    Returns:
        Test-set accuracy in percent.
    """
    from sklearn.metrics import confusion_matrix, classification_report
    import seaborn as sns

    num_classes = 4  # must match the classification head used during training
    class_ids = list(range(num_classes))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Model (must match the architecture used during training) ---
    # ImageNet weights are not downloaded: load_state_dict (strict) replaces
    # every parameter anyway.
    model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # --- Transforms for the test set (no augmentation) ---
    test_transform = A.Compose([
        A.Resize(height=224, width=224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    # train=False alone would return only the 20% validation split; combine both
    # splits so that every test image is evaluated.
    test_dataset = ConcatDataset([
        RealCottonLeafDataset(test_data_dir, transform=test_transform, train=True),
        RealCottonLeafDataset(test_data_dir, transform=test_transform, train=False),
    ])
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # --- Evaluation Loop ---
    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))
            all_predictions.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.numpy())

    correct = sum(int(pred == true) for pred, true in zip(all_predictions, all_labels))
    accuracy = 100 * correct / len(all_labels)
    print(f"Accuracy on the test set: {accuracy:.2f}%")

    # --- Confusion Matrix ---
    # labels=class_ids keeps the matrix num_classes x num_classes even if a class
    # never occurs, so it always matches the tick labels.
    cm = confusion_matrix(all_labels, all_predictions, labels=class_ids)
    print("Confusion Matrix:\n", cm)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_ids, yticklabels=class_ids, ax=ax)
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title("Confusion Matrix")
    plt.show()

    # --- Classification Report ---
    print(classification_report(all_labels, all_predictions, labels=class_ids))
    return accuracy
# --- Example Usage (uncomment and modify paths if you have a test set) ---
# test_data_directory = "path/to/your/test/data"  # Replace with your test data path
# evaluate_model("combined_model.pth", test_data_directory)