```
Implementation of U-Net (Ronneberger et al., 2015)

Paper: https://arxiv.org/abs/1505.04597

Article: https://medium.com/analytics-vidhya/unet-implementation-in-pytorch-idiot-developer-da40d955f201

Helpful YouTube video: https://www.youtube.com/watch?v=HS3Q_90hnDg

```

In [None]:
## UNET Functions ##

import torch
import torch.nn as nn

class conv_block(nn.Module):
    """Two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages.

    With a 3x3 kernel, stride 1 and ``padding=1`` the spatial size is left
    unchanged; only the channel count changes (``in_c -> out_c -> out_c``).

    Args:
        in_c: number of channels in the input tensor.
        out_c: number of channels produced by both convolutions.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable so
        # existing checkpoints still load.
        self.conv1 = nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=out_c)
        self.conv2 = nn.Conv2d(in_channels=out_c, out_channels=out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=out_c)
        # One stateless ReLU module is shared by both stages.
        self.relu = nn.ReLU()

    def forward(self, inputs):
        out = inputs
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        return out

class encoder_block(nn.Module):
    """One U-Net contracting step: ``conv_block`` followed by 2x2 max-pooling.

    Args:
        in_c: channels entering the block.
        out_c: channels produced by the convolutions.

    ``forward`` returns ``(features, pooled)``. ``features`` is the
    full-resolution conv output, which the U-Net passes to a decoder as its
    skip connection. ``pooled`` is the max-pooled version (height and width
    halved, rounding down), which feeds the next stage.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        # Stride defaults to the kernel size, so this downsamples by 2.
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, inputs):
        features = self.conv(inputs)
        return features, self.pool(features)

class decoder_block(nn.Module):
    """One U-Net expanding step: 2x up-sampling, skip concatenation, ``conv_block``.

    Args:
        in_c: channels of the tensor coming from the previous (coarser) stage.
        out_c: channels produced by the up-sampling. ``self.conv`` expects
            ``out_c + out_c`` input channels, so the ``skip`` tensor must
            also have ``out_c`` channels. This is also the block's output
            channel count.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # kernel_size=2, stride=2, padding=0 exactly doubles height and width.
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(out_c + out_c, out_c)

    def forward(self, inputs, skip):
        x = self.up(inputs)
        # If an encoder stage max-pooled an odd-sized map, pooling rounded down
        # and the up-sampled tensor is now smaller than ``skip``. Pad it
        # (split as evenly as possible between both sides) so the channel
        # concatenation below does not fail. No-op when sizes already match.
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        if diff_h or diff_w:
            x = nn.functional.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)

In [None]:
## Building UNET ##

class build_unet(nn.Module):
    """U-Net with four encoder/decoder stages (64-128-256-512 channels, 1024 bottleneck).

    Args:
        in_channels: channels of the input image (default 3, e.g. RGB).
        out_channels: channels of the predicted map (default 1).

    ``forward`` returns the raw output of the final 1x1 convolution. No
    sigmoid or softmax is applied, so use a loss that expects logits
    (e.g. ``nn.BCEWithLogitsLoss``) or apply the activation yourself.
    Each encoder stage halves height and width and each decoder stage doubles
    them. For inputs whose height and width are divisible by 16, the output
    therefore has the same spatial size as the input.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Encoder (contracting path): 64 -> 128 -> 256 -> 512 channels.
        self.e1 = encoder_block(in_channels, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)
        # Bottleneck at the coarsest resolution (four poolings below the input).
        self.b = conv_block(512, 1024)
        # Decoder (expanding path): each stage up-samples x2 and consumes the
        # encoder features of the matching resolution.
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)
        # Classifier: per-pixel 1x1 conv from 64 features to out_channels scores.
        self.outputs = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, inputs):
        # Encoder: s* are full-resolution features kept for the skip
        # connections; p* are the pooled tensors fed to the next stage.
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)
        # Bottleneck
        b = self.b(p4)
        # Decoder: skips are consumed deepest-first.
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)
        # Classifier (raw logits)
        outputs = self.outputs(d4)
        return outputs

In [None]:
## Custom Dataset ##
import os
from PIL import Image
import torchvision.transforms as transforms

# Preprocessing pipeline applied to every image the dataset loads.
#  1. Resize to exactly 416x416 (aspect ratio is not preserved). 416 is
#     divisible by 16, so the U-Net's four 2x poolings divide it evenly.
#  2. Convert the PIL image to a float tensor (C, H, W) with values in [0, 1].
#  3. Normalize each channel with the widely used ImageNet mean/std. These
#     expect exactly 3 channels, so the input must be RGB.
transform = transforms.Compose(
    [
        transforms.Resize((416, 416)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class ImageLabelDataset(torch.utils.data.Dataset):
    """Pair each image in one folder with a text label file in another.

    Pairing is positional. Both folder listings are sorted by filename, and
    the i-th image goes with the i-th label file. The two folders are
    therefore expected to hold files with matching base names
    (e.g. ``a.jpg`` / ``a.txt``).
    NOTE(review): nothing checks that the counts or names actually match.

    Args:
        image_folder: directory containing the image files.
        label_folder: directory containing one label file per image.
        device: device each image tensor is moved to in ``__getitem__``.

    Each item is ``(image, label)``:
    - ``image`` is the module-level ``transform`` applied to the RGB image.
    - ``label`` is the label file's full contents as a string.
    """

    def __init__(self, image_folder, label_folder, device='cpu'):
        self.image_folder = image_folder
        self.label_folder = label_folder
        # os.listdir returns entries in arbitrary order. Sort both listings
        # so that index i refers to the same sample in each list.
        self.images = sorted(os.listdir(image_folder))
        self.labels = sorted(os.listdir(label_folder))
        self.device = device

    def __getitem__(self, idx):
        # Image --> Tensor
        image_path = os.path.join(self.image_folder, self.images[idx])
        # Use a context manager so the file handle is released. Convert to
        # RGB so grayscale/RGBA/palette files match the 3-channel Normalize
        # step in `transform`.
        with Image.open(image_path) as img:
            rgb = img.convert("RGB")
        image = transform(rgb)

        # Label: the raw text of the label file (not yet parsed).
        label_path = os.path.join(self.label_folder, self.labels[idx])
        with open(label_path, encoding="utf-8") as f:
            label = f.read()

        image = image.to(self.device)
        return image, label

    def __len__(self):
        return len(self.images)

In [None]:
## HYPERPARAMETERS ##

# Number of samples per mini-batch produced by the training DataLoader.
BATCH_SIZE: int = 10
# Intended number of passes over the training set.
# NOTE(review): not referenced by the visible training cell.
EPOCHS: int = 5

In [None]:
## Data Preprocessing ##

# Folders holding the training images and their matching label files.
training_images_path = "data/car/train/images"
training_labels_path = "data/car/train/labels"

# Wrap the two folders in the custom Dataset (no `device` is given, so the
# dataset's default is used).
training_dataset = ImageLabelDataset(
    image_folder=training_images_path,
    label_folder=training_labels_path,
)

# Serve shuffled mini-batches of BATCH_SIZE samples for training.
train_dataloader = torch.utils.data.DataLoader(
    dataset=training_dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

: 

In [None]:
# Run Model on Training
# Smoke test: push every training batch through a freshly initialised U-Net
# and print the output shape. There is no loss, optimizer, or backward pass
# yet, so the model's parameters are never updated here.

model = build_unet()
model.train() # Built in PyTorch function: puts layers such as BatchNorm into training mode (batch statistics, running-stat updates)

# TODO(review): to actually train, the labels must become tensors that a
# loss can compare against `y`.
# - build_unet's final 1x1 conv emits one channel, so `y` has shape
#   (batch, 1, H, W) and holds raw logits (no sigmoid).
# - A segmentation loss such as nn.BCEWithLogitsLoss needs a target mask of
#   that same shape.
# - ImageLabelDataset currently returns each label as the raw text of its
#   file, and the default DataLoader collation turns a batch of them into a
#   list of strings. A label-text -> mask conversion is needed first; the
#   label file format is not visible here, so confirm what those files
#   contain.
# - A training step would then be: optimizer.zero_grad(); loss = criterion(y, mask);
#   loss.backward(); optimizer.step(), repeated for EPOCHS passes.
# - Validation would use a second Dataset/DataLoader over held-out images,
#   evaluated under model.eval() and torch.no_grad().
for img, label in train_dataloader:  
    y = model(img)
    print(y.shape)