
---

# **Tomato Segmentation with U-Net**  

This notebook provides an **end-to-end implementation** of tomato image segmentation using a U-Net convolutional neural network. It is designed to run on **Google Colab with free GPU acceleration**.

### 🚀 **How to Use This Notebook**  
1. **Enable GPU**:  
   Go to `Runtime` → `Change runtime type` → Select `GPU`  
   *(Verify with `torch.cuda.is_available()` in Cell 1)*  

2. **Upload Your Data**:  
   - Run Cell 2 to upload `Tomato_dataset.zip`  
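   - Expected structure (each mask must have the same filename as its image, e.g. `Test2/Img1.jpg` and `Mask2/Img1.jpg`):  
     ```
     Tomato_dataset/
     ├── Train/    # Training images
     ├── Mask/     # Training masks
     ├── Test2/    # Test images
     └── Mask2/    # Ground-truth masks for the test images (for evaluation)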
     ```  

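3. **Execute Cells Sequentially**:  
   - Cells 3-7: Define the dataset, U-Net model and metrics, then train  
   - Cell 8: Save the trained weights  
   - Cell 9: Evaluate on test images and visualize results  
   - Cells 10-11: Zip the dataset folder (including the model and predictions) and download it  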
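Before running Cell 3, it can help to confirm that the upload matches this layout and that every image has a mask with the same filename, since `TomatoDataset` looks up each mask by its image's filename. A minimal standard-library sketch; `check_dataset` is a hypothetical helper, not part of the notebook, and it assumes the `.jpg` naming used throughout:

```python
import os


def check_dataset(root, pairs=(("Train", "Mask"), ("Test2", "Mask2"))):
    """Return (missing_folders, unpaired_images) for the expected layout.

    Hypothetical helper: `pairs` maps each image folder to its mask folder,
    and a mask is expected to share its image's filename.
    """
    missing_folders = [
        name for pair in pairs for name in pair
        if not os.path.isdir(os.path.join(root, name))
    ]
    unpaired_images = []
    for image_folder, mask_folder in pairs:
        image_path = os.path.join(root, image_folder)
        mask_path = os.path.join(root, mask_folder)
        if not (os.path.isdir(image_path) and os.path.isdir(mask_path)):
            continue  # Already reported in missing_folders
        masks = set(os.listdir(mask_path))
        for name in sorted(os.listdir(image_path)):
            # Same filter as TomatoDataset: .jpg files that are not hidden
            if name.endswith(".jpg") and not name.startswith(".") and name not in masks:
                unpaired_images.append(os.path.join(image_folder, name))
    return missing_folders, unpaired_images


# Example (in Colab, after Cell 2):
# print(check_dataset('/content/Tomato_dataset'))
```

Two empty lists mean the folders are in place and every image can be paired with its mask.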



### 📈 **Example Output**  
After training, you'll get:  
- Trained model (`unet_tomato.pth`)  
- Predicted masks for test images  
- Performance metrics in this format (illustrative values; the cell outputs below show an actual run):  
  ```
  IoU: 0.82 | Dice: 0.90  
  Precision: 0.88 | Recall: 0.85  
  ```
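For reference, the metrics reported by this notebook are computed from pixel counts of true positives (TP), false positives (FP), false negatives (FN) and true negatives (TN), as in `calculate_metrics` (the code also adds a small ε to each denominator to avoid division by zero). For a single mask, Dice is a monotone function of IoU:

```latex
\begin{aligned}
\mathrm{IoU} &= \frac{TP}{TP + FP + FN}, &
\mathrm{Dice} &= \frac{2\,TP}{2\,TP + FP + FN}, \\
\mathrm{Precision} &= \frac{TP}{TP + FP}, &
\mathrm{Recall} &= \frac{TP}{TP + FN}, \\
\mathrm{Accuracy} &= \frac{TP + TN}{TP + TN + FP + FN}, &
\mathrm{Dice} &= \frac{2\,\mathrm{IoU}}{1 + \mathrm{IoU}}.
\end{aligned}
```

The last identity holds per mask; averages over several images (as in the test-set summary) need not satisfy it exactly.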

📂 **Output Folder Contents**  
After processing, `/content/Tomato_dataset/output_masks/` contains:  

- `[image_name]_mask.png` → Predicted segmentation mask (binary image, resized back to the original resolution)  
- `vis_[image_name].jpg` → **Visual comparison** with three panels, saved only for images that have a ground-truth mask in `Mask2/`:  
  - **Left**: Original input image  
  - **Middle**: Predicted mask (white=tomato, black=background)  
  - **Right**: Ground truth mask  
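If some images fail during processing (the evaluation cell prints `Error processing ...` and moves on), comparing the folders by the naming rule above shows which inputs have no saved mask. A standard-library sketch; `missing_predictions` is a hypothetical helper, not part of the notebook:

```python
import os


def missing_predictions(image_dir, output_dir):
    """List input .jpg files whose `<stem>_mask.png` is absent from output_dir."""
    saved = set(os.listdir(output_dir)) if os.path.isdir(output_dir) else set()
    missing = []
    for name in sorted(os.listdir(image_dir)):
        if not name.endswith(".jpg") or name.startswith("."):
            continue  # Same filter as process_images
        stem = os.path.splitext(name)[0]
        if stem + "_mask.png" not in saved:
            missing.append(name)
    return missing


# Example (in Colab, after Cell 9):
# print(missing_predictions('/content/Tomato_dataset/Test2',
#                           '/content/Tomato_dataset/output_masks'))
```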

---




In [1]:
import torch
print(torch.cuda.is_available())  # Should return True

True


## Step 1/2: Dataset loading and preparation

In [2]:
from google.colab import files
import zipfile
import os

# Upload the zip file: running this cell opens a file picker below it
uploaded = files.upload()  # Select Tomato_dataset.zip

# Extract the zip file
with zipfile.ZipFile('Tomato_dataset.zip', 'r') as zip_ref:
    zip_ref.extractall('/content/')  # Extracts to /content/Tomato_dataset/

# Verify extraction
print(os.listdir('/content/Tomato_dataset/'))  # Should include 'Train', 'Mask', 'Test2' and 'Mask2'

Saving Tomato_dataset.zip to Tomato_dataset.zip
['Mask2', 'Mask', 'Test', 'Train2', 'Test2', 'Train']
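Cell 2 assumes the archive contains a top-level `Tomato_dataset/` folder. If the zip was created from inside that folder instead, the subfolders land directly in `/content/` and the later paths break. A hedged standard-library sketch that extracts into a dedicated folder and then locates the directory holding `Train/`; `extract_dataset` is a hypothetical helper:

```python
import os
import zipfile


def extract_dataset(zip_path, dest_dir, marker="Train"):
    """Extract zip_path into dest_dir and return the folder that contains `marker`."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
    # Walk top-down and return the first folder with a `marker` subfolder
    for current, subdirs, _files in os.walk(dest_dir):
        # Sort for a deterministic order and skip macOS metadata folders
        subdirs[:] = sorted(d for d in subdirs if d != "__MACOSX")
        if marker in subdirs:
            return current
    raise FileNotFoundError(f"No '{marker}' folder found in {zip_path}")


# Example (in Colab, instead of the hard-coded extraction path):
# root = extract_dataset('Tomato_dataset.zip', '/content/data')
```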


## Step 3/6: Dataset class with preprocessing


In [3]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

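class TomatoDataset(Dataset):
    """Pairs each image in image_dir with the same-named mask in mask_dir."""

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

        # Only include image files (exclude hidden files/folders); sorted for a stable order
        self.images = sorted(
            f for f in os.listdir(image_dir)
            if f.endswith(('.jpg',))  # Extend this tuple for other image formats
            and not f.startswith('.')  # Exclude hidden files
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)  # Mask shares the image's filename

        try:
            image = Image.open(img_path).convert('RGB')
            mask = Image.open(mask_path).convert('L')
        except Exception as e:
            # Returning None here would crash the DataLoader's default collate_fn,
            # so fail with a clear message instead
            raise RuntimeError(f"Error loading {img_name}: {e}") from e

        if self.transform:
            image = self.transform(image)

        # Always turn the mask into a tensor, even without an image transform:
        # nearest-neighbour resizing avoids blended labels at object edges, and the
        # 0.5 cut-off (about 128 on the 0-255 scale, as in the evaluation cell)
        # ignores faint JPEG compression noise in the background
        mask = transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST)(mask)
        mask = transforms.ToTensor()(mask)
        mask = (mask > 0.5).float()

        return image, mask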
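The masks in this dataset are stored as JPEG (for example `Mask2/Img10.jpg`), and lossy compression can leave small non-zero values in background pixels and grey values along edges. A small numpy illustration, with made-up pixel values, of how the binarization threshold changes the label of such pixels once `ToTensor()` has scaled them to [0, 1]:

```python
import numpy as np

# Hypothetical 8-bit mask row: clean background, JPEG noise, edge blur, clean foreground
mask_uint8 = np.array([0, 3, 9, 120, 200, 252, 255], dtype=np.uint8)
mask = mask_uint8.astype(np.float32) / 255.0  # Same scaling as transforms.ToTensor()

any_nonzero = (mask > 0).astype(np.float32)  # Every non-black pixel becomes foreground
midpoint = (mask > 0.5).astype(np.float32)   # Only pixels brighter than ~128 become foreground

print(any_nonzero)  # [0. 1. 1. 1. 1. 1. 1.]
print(midpoint)     # [0. 0. 0. 0. 1. 1. 1.]
```

The evaluation cell thresholds masks at 128, so a mid-range cut-off keeps the training labels consistent with how the test masks are scored.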

## Step 4/5: U-Net implementation

In [4]:
import torch.nn as nn

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super(UNet, self).__init__()

        # Encoder (Downsampling)
        self.down1 = DoubleConv(n_channels, 64)
        self.down2 = DoubleConv(64, 128)
        self.down3 = DoubleConv(128, 256)
        self.down4 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder (Upsampling)
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv1 = DoubleConv(1024, 512)

        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv2 = DoubleConv(512, 256)

        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv3 = DoubleConv(256, 128)

        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv4 = DoubleConv(128, 64)

        # Output layer
        self.out = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder
        x1 = self.down1(x)
        x2 = self.pool(x1)
        x2 = self.down2(x2)
        x3 = self.pool(x2)
        x3 = self.down3(x3)
        x4 = self.pool(x3)
        x4 = self.down4(x4)
        x5 = self.pool(x4)

        # Bottleneck
        x5 = self.bottleneck(x5)

        # Decoder with skip connections
        x = self.up1(x5)
        x = torch.cat([x, x4], dim=1)
        x = self.up_conv1(x)

        x = self.up2(x)
        x = torch.cat([x, x3], dim=1)
        x = self.up_conv2(x)

        x = self.up3(x)
        x = torch.cat([x, x2], dim=1)
        x = self.up_conv3(x)

        x = self.up4(x)
        x = torch.cat([x, x1], dim=1)
        x = self.up_conv4(x)

        # Output
        return torch.sigmoid(self.out(x))
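Because the encoder halves the resolution four times and each `ConvTranspose2d(kernel_size=2, stride=2)` exactly doubles it, the `torch.cat` calls in `forward` only see matching spatial sizes when the input height and width are divisible by 16; the 256×256 inputs used here satisfy this. A plain-Python sketch of that arithmetic, using two hypothetical helpers that mirror the layer widths above without running the model:

```python
def unet_shapes(size, widths=(64, 128, 256, 512, 1024)):
    """Return (channels, height/width) of each encoder stage and the bottleneck.

    Hypothetical helper mirroring UNet above: DoubleConv keeps the spatial size
    (3x3 kernels, padding=1) and MaxPool2d(2) halves it, rounding down.
    """
    shapes = []
    for channels in widths:
        shapes.append((channels, size))
        size //= 2
    return shapes


def skips_align(size, depth=4):
    """True if every ConvTranspose2d(kernel_size=2, stride=2) output matches its skip tensor."""
    sizes = [size]
    for _ in range(depth):
        sizes.append(sizes[-1] // 2)  # Encoder pooling
    # Upsampling a level doubles its size, which must equal the skip one level up
    return all(2 * sizes[i + 1] == sizes[i] for i in range(depth))


print(unet_shapes(256))
# [(64, 256), (128, 128), (256, 64), (512, 32), (1024, 16)]
print(skips_align(256), skips_align(100))  # True False
```

For a 100×100 input, pooling floors 25 to 12, and upsampling 12 gives 24, which cannot be concatenated with the 25×25 skip tensor.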

## Step 12: Evaluation metrics

In [5]:
# ==================== EVALUATION METRICS ====================
def calculate_metrics(pred, target, threshold=0.5):
    """Computes IoU, Dice, Accuracy, Precision, Recall"""
    pred_bin = (pred > threshold).float()
    target_bin = target.float()

    tp = torch.sum(pred_bin * target_bin).item()
    fp = torch.sum(pred_bin * (1 - target_bin)).item()
    fn = torch.sum((1 - pred_bin) * target_bin).item()
    tn = torch.sum((1 - pred_bin) * (1 - target_bin)).item()

    eps = 1e-6
    return {
        'iou': tp / (tp + fp + fn + eps),
        'dice': (2 * tp) / (2 * tp + fp + fn + eps),
        'accuracy': (tp + tn) / (tp + tn + fp + fn + eps),
        'precision': tp / (tp + fp + eps),
        'recall': tp / (tp + fn + eps)
    }

def evaluate_model(model, dataloader, device):
    """Evaluates model on a DataLoader; metrics are computed per batch, then averaged over batches"""
    model.eval()
    metrics = {'iou':0, 'dice':0, 'accuracy':0, 'precision':0, 'recall':0}

    with torch.no_grad():
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            batch_metrics = calculate_metrics(outputs, masks)

            for key in metrics:
                metrics[key] += batch_metrics[key]

    for key in metrics:
        metrics[key] /= len(dataloader)

    return metrics
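A small worked example helps confirm the formulas. The sketch below repeats the same pixel counts in numpy (so it runs without a GPU or a model) on a hypothetical 2×2 prediction; `metrics_np` is an illustrative stand-in for `calculate_metrics`, not part of the notebook:

```python
import numpy as np


def metrics_np(pred, target, threshold=0.5, eps=1e-6):
    """numpy version of calculate_metrics: pixel-count metrics for one mask or batch."""
    pred_bin = (np.asarray(pred) > threshold).astype(np.float64)
    target_bin = np.asarray(target).astype(np.float64)
    tp = np.sum(pred_bin * target_bin)
    fp = np.sum(pred_bin * (1 - target_bin))
    fn = np.sum((1 - pred_bin) * target_bin)
    tn = np.sum((1 - pred_bin) * (1 - target_bin))
    return {
        'iou': tp / (tp + fp + fn + eps),
        'dice': (2 * tp) / (2 * tp + fp + fn + eps),
        'accuracy': (tp + tn) / (tp + tn + fp + fn + eps),
        'precision': tp / (tp + fp + eps),
        'recall': tp / (tp + fn + eps),
    }


# Probabilities from a hypothetical model and the matching ground truth
pred = [[0.9, 0.2],
        [0.7, 0.1]]
target = [[1, 0],
          [0, 1]]
# After thresholding: TP = 1 (top-left), FP = 1 (bottom-left), FN = 1 (bottom-right), TN = 1
m = metrics_np(pred, target)
print({k: round(float(v), 3) for k, v in m.items()})
# {'iou': 0.333, 'dice': 0.5, 'accuracy': 0.5, 'precision': 0.5, 'recall': 0.5}
```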

## Step 7/8/9: Training setup

In [6]:
import torch.optim as optim


# Step 6: Preprocessing
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Step 2-3: Dataset preparation
# Use absolute paths in Colab
dataset = TomatoDataset(
    image_dir='/content/Tomato_dataset/Train',  # Full path
    mask_dir='/content/Tomato_dataset/Mask',    # Full path
    transform=transform
)

# Split into train and validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False)



# Step 5: Model implementation
# Initialize model, loss, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)
criterion = nn.BCELoss()   # Step 8: Loss function : Binary Cross Entropy Loss
optimizer = optim.Adam(model.parameters(), lr=0.001) # Step 9: Optimization
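`nn.BCELoss` expects probabilities, which is why `UNet.forward` ends with `torch.sigmoid`. With its default `reduction='mean'`, it averages the per-pixel binary cross-entropy over every pixel in the batch. A numpy sketch of that formula, for intuition only; PyTorch additionally clamps the log terms at -100 so a probability of exactly 0 or 1 never produces an infinite loss, and the sketch imitates this:

```python
import numpy as np


def bce(probs, targets):
    """Mean binary cross-entropy, like nn.BCELoss(reduction='mean')."""
    p = np.asarray(probs, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    with np.errstate(divide="ignore"):  # log(0) -> -inf is clamped just below
        log_p = np.maximum(np.log(p), -100.0)
        log_1mp = np.maximum(np.log(1.0 - p), -100.0)
    return float(np.mean(-(y * log_p + (1.0 - y) * log_1mp)))


print(round(bce([0.5], [1]), 4))            # 0.6931 (= ln 2)
print(round(bce([0.9, 0.1], [1, 0]), 4))    # 0.1054
print(bce([0.0], [1]))                      # 100.0 (clamped instead of inf)
```

A common alternative is to drop the final sigmoid and use `nn.BCEWithLogitsLoss`, which fuses the sigmoid and the loss into one numerically more stable operation; the notebook's sigmoid plus `BCELoss` pairing is also valid.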

## Step 7/10: Training loop

In [7]:
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}')

    # Validation
    model.eval()
    val_loss = 0.0
    val_metrics = {'iou':0, 'dice':0, 'accuracy':0, 'precision':0, 'recall':0}

    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)

            val_loss += criterion(outputs, masks).item()
            batch_metrics = calculate_metrics(outputs, masks)

            for key in val_metrics:
                val_metrics[key] += batch_metrics[key]

    # Average metrics
    for key in val_metrics:
        val_metrics[key] /= len(val_loader)

    print(f'Validation Loss: {val_loss/len(val_loader):.4f}')
    print(f'IoU: {val_metrics["iou"]:.4f} | Dice: {val_metrics["dice"]:.4f} | Acc: {val_metrics["accuracy"]:.4f}')
    print(f'Precision: {val_metrics["precision"]:.4f} | Recall: {val_metrics["recall"]:.4f}')

Epoch 1, Loss: 0.3615
Validation Loss: 0.2918
IoU: 0.7463 | Dice: 0.8480 | Acc: 0.9594
Precision: 0.8775 | Recall: 0.8347
Epoch 2, Loss: 0.2444
Validation Loss: 0.2101
IoU: 0.8038 | Dice: 0.8828 | Acc: 0.9641
Precision: 0.8392 | Recall: 0.9526
Epoch 3, Loss: 0.2012
Validation Loss: 0.1756
IoU: 0.7951 | Dice: 0.8767 | Acc: 0.9610
Precision: 0.8179 | Recall: 0.9677
Epoch 4, Loss: 0.1792
Validation Loss: 0.1488
IoU: 0.8002 | Dice: 0.8796 | Acc: 0.9626
Precision: 0.8321 | Recall: 0.9557
Epoch 5, Loss: 0.1546
Validation Loss: 0.1359
IoU: 0.8076 | Dice: 0.8856 | Acc: 0.9658
Precision: 0.8514 | Recall: 0.9432
Epoch 6, Loss: 0.1460
Validation Loss: 0.1339
IoU: 0.7829 | Dice: 0.8682 | Acc: 0.9576
Precision: 0.8088 | Recall: 0.9613
Epoch 7, Loss: 0.1322
Validation Loss: 0.1084
IoU: 0.8086 | Dice: 0.8861 | Acc: 0.9656
Precision: 0.8416 | Recall: 0.9560
Epoch 8, Loss: 0.1305
Validation Loss: 0.1233
IoU: 0.7958 | Dice: 0.8769 | Acc: 0.9612
Precision: 0.8186 | Recall: 0.9670
Epoch 9, Loss: 0.1220
Va
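The validation log above is not monotone: IoU peaks at 0.8086 in epoch 7 and drops to 0.7958 in epoch 8, while Cell 8 saves whatever weights the final epoch leaves behind. One option is to keep the checkpoint with the best validation IoU. A minimal sketch; `BestCheckpoint` and the file name `unet_tomato_best.pth` are hypothetical, and in the notebook you would call `torch.save(model.state_dict(), path)` whenever `update` returns `True`:

```python
class BestCheckpoint:
    """Track the best validation score seen so far (higher is better)."""

    def __init__(self):
        self.best_score = float("-inf")
        self.best_epoch = None

    def update(self, epoch, score):
        """Record `score` for `epoch`; return True if it is a new best."""
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            return True
        return False


# Inside the training loop, after the validation metrics are averaged (hypothetical path):
#     if tracker.update(epoch + 1, val_metrics['iou']):
#         torch.save(model.state_dict(), '/content/Tomato_dataset/unet_tomato_best.pth')

# Replaying the validation IoUs printed for epochs 1-8:
tracker = BestCheckpoint()
logged_iou = [0.7463, 0.8038, 0.7951, 0.8002, 0.8076, 0.7829, 0.8086, 0.7958]
for epoch, iou in enumerate(logged_iou, start=1):
    tracker.update(epoch, iou)
print(tracker.best_epoch, tracker.best_score)  # 7 0.8086
```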

In [8]:
torch.save(model.state_dict(), '/content/Tomato_dataset/unet_tomato.pth')

In [9]:
import os
import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Load the trained model
def load_model(model_path, device):
    model = UNet().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model

# Preprocess the image (same as during training)
def preprocess_image(image_path, transform):
    image = Image.open(image_path).convert('RGB')
    original_size = image.size  # Store original size
    image = transform(image).unsqueeze(0)  # Add batch dimension
    return image, original_size

# Postprocess the mask (convert to binary and resize to original)
def postprocess_mask(mask_tensor, original_size, threshold=0.5):
    mask = mask_tensor.squeeze().cpu().numpy()  # Remove batch dim and convert to numpy
    mask = (mask > threshold).astype(np.uint8) * 255  # Threshold and scale to 0-255
    mask = Image.fromarray(mask).resize(original_size, Image.NEAREST)
    return mask

def process_images(model, image_dir, mask_dir, output_dir, transform, device):
    """
    Process images, generate masks, and evaluate against ground truth

    Args:
        model: Trained U-Net model
        image_dir: Directory with input images
        mask_dir: Directory with ground truth masks
        output_dir: Where to save predictions
        transform: Image transformations
        device: CUDA/CPU device
    """
    os.makedirs(output_dir, exist_ok=True)
    image_files = [f for f in os.listdir(image_dir) if f.endswith('.jpg') and not f.startswith('.')]

    # Initialize metrics
    metrics = {
        'iou': [],
        'dice': [],
        'accuracy': [],
        'precision': [],
        'recall': []
    }

    for img_file in image_files:
        try:
            # Load and preprocess image
            img_path = os.path.join(image_dir, img_file)
            image_tensor, original_size = preprocess_image(img_path, transform)
            image_tensor = image_tensor.to(device)

            # Generate prediction
            with torch.no_grad():
                mask_tensor = model(image_tensor)

            # Postprocess predicted mask
            pred_mask = postprocess_mask(mask_tensor, original_size)
            pred_mask_np = np.array(pred_mask)

            # Save predicted mask
            mask_filename = os.path.splitext(img_file)[0] + '_mask.png'
            pred_mask.save(os.path.join(output_dir, mask_filename))

            # Load ground truth mask
            true_mask_path = os.path.join(mask_dir, img_file)
            if os.path.exists(true_mask_path):
                true_mask = np.array(Image.open(true_mask_path).convert('L'))

                # Calculate metrics
                pred_bin = (pred_mask_np > 128).astype(np.uint8)  # Threshold at 128
                true_bin = (true_mask > 128).astype(np.uint8)     # Threshold at 128

                intersection = np.logical_and(pred_bin, true_bin)
                union = np.logical_or(pred_bin, true_bin)

                tp = np.sum(intersection)
                fp = np.sum(pred_bin) - tp
                fn = np.sum(true_bin) - tp
                tn = np.sum(np.logical_not(union))

                eps = 1e-6  # Avoid division by zero

                # Store metrics
                metrics['iou'].append(tp / (tp + fp + fn + eps))
                metrics['dice'].append((2 * tp) / (2 * tp + fp + fn + eps))
                metrics['accuracy'].append((tp + tn) / (tp + tn + fp + fn + eps))
                metrics['precision'].append(tp / (tp + fp + eps))
                metrics['recall'].append(tp / (tp + fn + eps))

                # Visualize with metrics
                plt.figure(figsize=(18, 6))

                plt.subplot(1, 3, 1)
                plt.imshow(Image.open(img_path))
                plt.title('Original Image')
                plt.axis('off')

                plt.subplot(1, 3, 2)
                plt.imshow(pred_mask, cmap='gray')
                plt.title('Predicted Mask')
                plt.axis('off')

                plt.subplot(1, 3, 3)
                plt.imshow(true_mask, cmap='gray')
                plt.title('Ground Truth')
                plt.axis('off')

                plt.savefig(os.path.join(output_dir, f'vis_{img_file}'), bbox_inches='tight')
                plt.close()

            print(f"Processed {img_file}")

        except Exception as e:
            print(f"Error processing {img_file}: {str(e)}")

    # Print average metrics if ground truth was available
    if metrics['iou']:
        print("\nAverage Evaluation Metrics:")
        print(f"IoU: {np.mean(metrics['iou']):.4f}")
        print(f"Dice Coefficient: {np.mean(metrics['dice']):.4f}")
        print(f"Accuracy: {np.mean(metrics['accuracy']):.4f}")
        print(f"Precision: {np.mean(metrics['precision']):.4f}")
        print(f"Recall: {np.mean(metrics['recall']):.4f}")

# Visualization function
def visualize_result(image_path, mask, save_path=None):
    image = Image.open(image_path)

    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(mask, cmap='gray')
    plt.title('Predicted Mask')
    plt.axis('off')

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
        plt.close()
    else:
        plt.show()



# Main execution
if __name__ == "__main__":
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define transforms (same as training)
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Paths
    model_path = '/content/Tomato_dataset/unet_tomato.pth'  # Update with your saved model path
    test_image_dir = '/content/Tomato_dataset/Test2'  # Directory with new images to test
    output_dir = '/content/Tomato_dataset/output_masks'    # Where to save results
    mask_dir = '/content/Tomato_dataset/Mask2'             # Ground-truth masks for evaluation

    # Load model
    model = load_model(model_path, device)

    # Process images
    process_images(model, test_image_dir, mask_dir, output_dir, transform, device)

Processed Img10.jpg
Processed Img6.jpg
Processed Img20.jpg
Processed Img12.jpg
Processed Img11.jpg
Processed Img18.jpg
Processed Img13.jpg
Processed Img3.jpg
Processed Img15.jpg
Processed Img9.jpg
Processed Img4.jpg
Processed Img1.jpg
Processed Img16.jpg
Processed Img14.jpg
Processed Img8.jpg
Processed Img17.jpg
Processed Img2.jpg
Processed Img5.jpg
Processed Img7.jpg
Processed Img19.jpg

Average Evaluation Metrics:
IoU: 0.8048
Dice Coefficient: 0.8491
Accuracy: 0.9676
Precision: 0.8130
Recall: 0.8902
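These test-set figures are means of per-image scores (`np.mean` over the lists filled in `process_images`), whereas the validation figures in the training loop pool pixel counts within each batch of four and then average over batches. The two can differ, because an image with little tomato area weighs as much as a large one in a per-image mean. A numpy sketch with two hypothetical images comparing the per-image mean IoU with IoU computed from pooled pixel counts:

```python
import numpy as np


def iou(pred_bin, true_bin, eps=1e-6):
    """IoU of two boolean masks from pixel counts (same formula as the notebook)."""
    tp = np.sum(pred_bin & true_bin)
    fp = np.sum(pred_bin & ~true_bin)
    fn = np.sum(~pred_bin & true_bin)
    return tp / (tp + fp + fn + eps)


# Hypothetical flattened masks: a large, well-segmented tomato and a small, half-missed one
big_true = np.array([1] * 90 + [0] * 10, dtype=bool)
big_pred = np.array([1] * 90 + [0] * 10, dtype=bool)    # perfect prediction
small_true = np.array([1] * 10 + [0] * 90, dtype=bool)
small_pred = np.array([1] * 5 + [0] * 95, dtype=bool)   # finds half of the tomato

per_image_mean = np.mean([iou(big_pred, big_true), iou(small_pred, small_true)])
pooled = iou(np.concatenate([big_pred, small_pred]), np.concatenate([big_true, small_true]))
print(round(float(per_image_mean), 3), round(float(pooled), 3))  # 0.75 0.95
```

Neither number is wrong; they answer different questions, so keep the averaging scheme in mind when comparing validation and test scores.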


In [10]:
# Zip the whole dataset folder, including the trained model and output_masks/ (adjust the path if your folder name differs)
!zip -r tomato_dataset.zip /content/Tomato_dataset


  adding: content/Tomato_dataset/ (stored 0%)
  adding: content/Tomato_dataset/Mask2/ (stored 0%)
  adding: content/Tomato_dataset/Mask2/Img10.jpg (deflated 39%)
  adding: content/Tomato_dataset/Mask2/Img6.jpg (deflated 47%)
  adding: content/Tomato_dataset/Mask2/Img20.jpg (deflated 52%)
  adding: content/Tomato_dataset/Mask2/Img12.jpg (deflated 49%)
  adding: content/Tomato_dataset/Mask2/Img11.jpg (deflated 39%)
  adding: content/Tomato_dataset/Mask2/Img18.jpg (deflated 40%)
  adding: content/Tomato_dataset/Mask2/Img13.jpg (deflated 44%)
  adding: content/Tomato_dataset/Mask2/Img3.jpg (deflated 84%)
  adding: content/Tomato_dataset/Mask2/Img15.jpg (deflated 41%)
  adding: content/Tomato_dataset/Mask2/Img9.jpg (deflated 46%)
  adding: content/Tomato_dataset/Mask2/Img4.jpg (deflated 31%)
  adding: content/Tomato_dataset/Mask2/Img1.jpg (deflated 44%)
  adding: content/Tomato_dataset/Mask2/Img16.jpg (deflated 44%)
  adding: content/Tomato_dataset/Mask2/Img14.jpg (deflated 84%)
  adding: c
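Cell 10 archives the entire dataset folder, including the uploaded images, so the download is larger than the results alone. If only the predictions and weights are needed, a standard-library sketch can zip just those; `zip_results` and the archive name `tomato_results.zip` are hypothetical, chosen for illustration:

```python
import os
import zipfile


def zip_results(dataset_dir, zip_path, items=("output_masks", "unet_tomato.pth")):
    """Zip only the listed files/folders from dataset_dir; missing items are skipped."""
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for item in items:
            path = os.path.join(dataset_dir, item)
            if os.path.isfile(path):
                zf.write(path, arcname=item)
            elif os.path.isdir(path):
                for current, _dirs, files in os.walk(path):
                    for name in sorted(files):
                        full = os.path.join(current, name)
                        # Store paths relative to the dataset folder
                        zf.write(full, arcname=os.path.relpath(full, dataset_dir))
    return zip_path


# In Colab, followed by files.download('/content/tomato_results.zip'):
# zip_results('/content/Tomato_dataset', '/content/tomato_results.zip')
```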

In [11]:
from google.colab import files
files.download('tomato_dataset.zip')  # Download outputs
