# U-Net Image Segmentation From Scratch (PyTorch)

A from-scratch U-Net implementation in PyTorch for image segmentation, demonstrated on synthetic images in which the model must segment a randomly placed circle.

## Features

- U-Net architecture
- Encoder-decoder
- Skip connections
- PyTorch implementation
- Custom dummy data generation

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim  # training model
from torch.utils.data import DataLoader  # batching our data
from torchvision import transforms  # images to tensor
from PIL import Image  # image creation and processing
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Create dummy data: by default 50 images, each 128x128 pixels.

def make_dummy_data(num: int = 50, size: int = 128):
    """Build synthetic circle samples: a blank image, a blank mask and a circle region.

    Args:
        num: Number of samples to generate (one loop iteration per sample).
        size: Height and width, in pixels, of each square image and mask.

    NOTE(review): the centre and radius ranges below are hard-coded; they keep
    the circle fully inside the frame for the default ``size=128`` but are not
    scaled for other sizes.
    """
    # Accumulator for the generated samples.
    data = []
    # ToTensor turns an HxWxC uint8 array (or PIL image) into a CxHxW float
    # tensor with pixel values scaled to the range 0..1.
    tf = transforms.ToTensor()

    for _ in range(num):
        # Start from an all-black RGB image: uint8 values span 0-255 and the
        # trailing 3 is the channel axis.
        image = np.zeros((size, size, 3), dtype=np.uint8)
        # Single-channel float32 mask, initialised to 0 (background) everywhere.
        mask = np.zeros((size, size), np.float32)

        # Random circle centre: both coordinates are integers in 32..95
        # (randint's upper bound 96 is exclusive), which keeps the centre away
        # from the edges of a 128-pixel image.
        cx, cy = np.random.randint(32, 96, size=2)
        # Random integer radius in 10..29 (upper bound 30 is exclusive).
        r = np.random.randint(10, 30)

        # Open index grids: rr has shape (size, 1) and holds the row indices,
        # cc has shape (1, size) and holds the column indices; they broadcast
        # to a full (size, size) grid in the expression below.
        rr, cc = np.ogrid[:size, :size]
        # Circle inequality evaluated per pixel -> boolean (size, size) array
        # that is True inside or on the circle. Note cx is subtracted from the
        # row index and cy from the column index.
        circle = (rr - cx) ** 2 + (cc - cy) ** 2 <= r ** 2

In [None]:
# Visualization: run one batch through the U-Net layers and plot, for each
# sample, the input image, the ground-truth mask and the predicted mask.
images, masks = next(iter(loader))  # grab a single batch from the loader
# The input tensors must live on the same device as the layers applied to them.
images, masks = images.to(device), masks.to(device)

# Inference only: no_grad() disables gradient tracking for the forward pass.
with torch.no_grad():
    # Encoder path. NOTE(review): the layers are defined elsewhere; `pool` is
    # presumably the downsampling step between stages (e.g. 128 -> 64) — confirm.
    e1 = encoder1(images)
    e2 = encoder2(pool(e1))
    e3 = encoder3(pool(e2))

    # Decoder path: upsample, then concatenate the matching encoder output
    # along dim=1 (the skip connection) before the decoder stage.
    d1 = up1(e3)
    d1 = torch.cat([d1, e2], dim=1)
    d1 = decoder1(d1)

    d2 = up2(d1)
    d2 = torch.cat([d2, e1], dim=1)
    d2 = decoder2(d2)

    # Squash the output of the final convolution through the sigmoid.
    preds = sigmoid(out_conv(d2))

# Move each batch to the CPU once instead of once per plotted sample.
# `preds` was produced under no_grad(), so it needs no .detach() before .numpy().
images_cpu, masks_cpu, preds_cpu = images.cpu(), masks.cpu(), preds.cpu()

# Show at most 4 samples; clamp to the batch size so a batch with fewer than
# 4 samples does not raise IndexError.
num_to_show = min(4, images_cpu.size(0))

for i in range(num_to_show):
    plt.figure(figsize=(8, 2))

    # Input image: permute from (channel, height, width) to
    # (height, width, channel), the layout matplotlib expects.
    plt.subplot(1, 3, 1)
    plt.imshow(images_cpu[i].permute(1, 2, 0).numpy())

    # Ground-truth mask (first channel).
    plt.subplot(1, 3, 2)
    plt.imshow(masks_cpu[i][0].numpy(), cmap='gray')

    # Predicted mask (first channel).
    plt.subplot(1, 3, 3)
    plt.imshow(preds_cpu[i][0].numpy(), cmap='gray')