# Image Segmentation using PyTorch

In this notebook, we will perform image segmentation on images stored in the `data/cam1` directory using a pre-trained Fully Convolutional Network (FCN, ResNet-50 backbone) from torchvision.

## Step 1: Setup and Imports

Before importing the libraries, we need to install them.

In [None]:
pip install torch torchvision numpy matplotlib opencv-python

Now we can move on to importing the necessary libraries.

In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models.segmentation import fcn_resnet50

## Step 2: Define the Dataset

We need to define a custom dataset class to handle the images in the `data/cam1` directory.

In [None]:
# Define a simple dataset
class ImageDataset(Dataset):
    """Dataset of the ``.png`` images found directly inside one folder.

    Each item is the image at the requested index, opened with PIL, converted
    to RGB and passed through ``transform`` when one was supplied.
    """

    def __init__(self, image_folder, transform=None):
        """Index the ``.png`` files of ``image_folder``.

        Args:
            image_folder: Directory that holds the images.
            transform: Optional callable applied to every loaded PIL image.
        """
        self.image_folder = image_folder
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes an
        # index refer to the same file on every run and platform.
        self.image_files = sorted(
            f for f in os.listdir(image_folder) if f.endswith('.png')
        )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, convert and (optionally) transform the image at ``idx``.

        Returns:
            The RGB PIL image, or whatever ``transform`` returns for it.
        """
        img_name = os.path.join(self.image_folder, self.image_files[idx])
        # The context manager closes the underlying file; convert() hands back
        # an independent in-memory image, so it stays usable afterwards.
        # Forcing RGB keeps the channel count uniform across files, which
        # batching and the 3-channel segmentation model rely on.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# Preprocessing: scale every image to 256x256, then convert it to a tensor
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

# Wrap the image folder in a dataset and serve it in shuffled batches of four
dataset = ImageDataset('../data/cam1', transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=4)

## Step 3: Load a Pre-trained Model

We will use a pre-trained Fully Convolutional Network (FCN) for segmentation.

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pre-trained FCN (ResNet-50 backbone). The `pretrained=True` flag
# has been deprecated since torchvision 0.13 in favour of `weights=`;
# "DEFAULT" requests the recommended pre-trained weights for this model.
model = fcn_resnet50(weights="DEFAULT")
model = model.to(device)
# Switch to evaluation mode: the model is only used for inference here
model.eval()

## Step 4: Perform Segmentation and Visualize Results

We will iterate over the images, perform segmentation, and visualize the results.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Read the plotting-tool figure from disk and show it without axis decorations
img = mpimg.imread('../data/plot_tool.png')
plt.imshow(img)
plt.gca().axis('off')
plt.show()

In [None]:
import matplotlib.pyplot as plt

path = '../data/'
# Coverage plots for cameras 1 through 8, generated rather than listed by hand
image_files = [f"{path}cam{cam_number}_plot_coverage.png" for cam_number in range(1, 9)]

# Show every coverage plot in its own figure
for image_file in image_files:
    img = plt.imread(image_file)
    plt.imshow(img)

    # Drop the directory prefix so the title is just the file name
    image_file = image_file.replace('../data/', '')
    plt.title(image_file)
    plt.show()


In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import matplotlib.pyplot as plt

# Preprocessing: resize to 928x1280 (the size passed to Resize), then convert
# the PIL image to a tensor
transform = transforms.Compose([transforms.Resize((928, 1280)), transforms.ToTensor()])

class ImageDataset(Dataset):
    """Dataset of every regular file found directly inside one folder.

    Each item is the file at the requested index, opened with PIL, converted
    to RGB and passed through ``transform`` when one was supplied.
    """

    def __init__(self, image_folder, transform=None):
        """Index the regular files of ``image_folder`` (sub-directories are skipped).

        Args:
            image_folder: Directory that holds the images.
            transform: Optional callable applied to every loaded PIL image.
        """
        self.image_folder = image_folder
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes an
        # index refer to the same file on every run and platform.
        self.image_files = sorted(
            f for f in os.listdir(image_folder)
            if os.path.isfile(os.path.join(image_folder, f))
        )

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, convert and (optionally) transform the image at ``idx``.

        Returns:
            The RGB PIL image, or whatever ``transform`` returns for it.
        """
        img_name = os.path.join(self.image_folder, self.image_files[idx])
        # The context manager closes the underlying file; convert() hands back
        # an independent in-memory image, so it stays usable afterwards.
        with Image.open(img_name) as img:
            image = img.convert('RGB')  # Convert RGBA to RGB
        if self.transform:
            image = self.transform(image)
        return image

def split_image(image, split_col=640):
    """Split an image into a left and a right part along its last axis.

    Args:
        image: Object indexable as ``image[:, :, cols]``; in this notebook a
            tensor produced by ``transforms.ToTensor`` with the width last.
        split_col: Column index at which to cut. Defaults to 640, half of the
            1280-pixel width the images are resized to.

    Returns:
        A ``(left_image, right_image)`` tuple holding the columns before
        ``split_col`` and the columns from ``split_col`` onwards.
    """
    left_image = image[:, :, :split_col]
    right_image = image[:, :, split_col:]
    return left_image, right_image

# Source frames and the destination folder for the per-half crops
dataset_folder = '../data/cam1'
output_folder = '../data/cam1_split'
os.makedirs(output_folder, exist_ok=True)
dataset = ImageDataset(image_folder=dataset_folder, transform=transform)

# A single tensor -> PIL converter, reused for every half
to_pil = transforms.ToPILImage()

# Walk the dataset: split each image, write both halves, then show them
for idx, file_name in enumerate(dataset.image_files):
    image = dataset[idx]  # Image as a tensor
    pil_halves = [to_pil(half) for half in split_image(image)]

    # Output names are the source name (extension dropped) plus a side suffix
    base_filename = os.path.splitext(file_name)[0]
    out_names = [f"{base_filename}_{side}.png" for side in ('left', 'right')]

    # Write both halves to disk before drawing anything
    for out_name, pil_half in zip(out_names, pil_halves):
        pil_half.save(os.path.join(output_folder, out_name))

    # Left and right halves side by side, titled with their file names
    plt.figure(figsize=(10, 5))
    for position, (out_name, pil_half) in enumerate(zip(out_names, pil_halves), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(pil_half)
        plt.title(out_name)
        plt.axis('off')

    plt.show()