**Requires Google Colab Pro: training needs more than 15 GB of GPU RAM.**

In [26]:
# Clone the project repository; later cells change into this directory
# before importing `unet` and `carvana_dataset`.
!git clone https://github.com/p-kTmm/Unet_Segmentation.git

Cloning into 'Unet_Segmentation'...
remote: Enumerating objects: 8, done.[K
remote: Counting objects: 100% (8/8), done.[K
remote: Compressing objects: 100% (6/6), done.[K
Receiving objects: 100% (8/8), done.
remote: Total 8 (delta 0), reused 0 (delta 0), pack-reused 0[K


In [2]:
# Work from inside the cloned repository so that relative paths used later
# (e.g. kaggle.json, ./models/unet.pth) resolve against it.
%cd Unet_Segmentation

/content/Unet_Segmentation


In [3]:
import torch
from torch import optim, nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from unet import UNet
from carvana_dataset import CarvanaDataset

## Load dataset

In [4]:
import os

# Root of the cloned repository on the Colab runtime.
base_dir = '/content/Unet_Segmentation'

# Sub-directories for the dataset ('data') and for saved weights ('models').
data_dir, models_dir = (
    os.path.join(base_dir, folder) for folder in ('data', 'models')
)

# Make sure both folders are present; exist_ok lets this cell be re-run safely.
os.makedirs(models_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)

This notebook uses the Carvana Image Masking Challenge dataset from Kaggle: https://www.kaggle.com/c/carvana-image-masking-challenge/data

First, upload your **kaggle.json** file to **/content/Unet_Segmentation**; it is required to access Kaggle's API. You can create and download it from the API section of your Kaggle account settings.

In [8]:
# Install the Kaggle command-line client.
!pip install kaggle
# Put the API token (kaggle.json, taken from the current working directory)
# into ~/.kaggle/, the location used by the commands below.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
# Restrict the token file to owner read/write only.
!chmod 600 ~/.kaggle/kaggle.json
# Download the competition archive (about 24 GB) into the current directory.
!kaggle competitions download -c carvana-image-masking-challenge

Downloading carvana-image-masking-challenge.zip to /content/Unet_Segmentation
100% 24.4G/24.4G [04:28<00:00, 91.4MB/s]
100% 24.4G/24.4G [04:28<00:00, 97.5MB/s]


In [9]:
# Quietly (-q) extract the downloaded competition archive into the data/ folder.
!unzip -q /content/Unet_Segmentation/carvana-image-masking-challenge.zip -d /content/Unet_Segmentation/data

In [10]:
# The competition archive contains nested zips; extract train.zip and
# train_masks.zip (presumably the training images and their masks) into data/.
!unzip -q /content/Unet_Segmentation/data/train.zip -d /content/Unet_Segmentation/data
!unzip -q /content/Unet_Segmentation/data/train_masks.zip -d /content/Unet_Segmentation/data

## Training

In [11]:
# --- Hyperparameters and paths ---
LEARNING_RATE = 3e-4
BATCH_SIZE = 32
EPOCHS = 2
DATA_PATH = "/content/Unet_Segmentation/data"
MODEL_SAVE_PATH = "/content/Unet_Segmentation/models/unet.pth"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# Complete dataset rooted at DATA_PATH. It is kept under its own name so the
# unsplit dataset is not shadowed by the training subset below.
full_dataset = CarvanaDataset(DATA_PATH)

# Reproducible 80/20 train/validation split (explicitly seeded generator).
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(full_dataset, [0.8, 0.2], generator=generator)

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True)
# Validation is evaluation only: shuffling brings no benefit and would change
# which samples land in the final (possibly smaller) batch, making the
# batch-averaged validation loss vary from run to run.
val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False)

cuda


In [14]:
model = UNet(in_channels=3, num_classes=1).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# BCEWithLogitsLoss applies the sigmoid internally, so the model output is
# consumed as raw logits.
criterion = nn.BCEWithLogitsLoss()

# Batch counts used to average the per-batch losses. They are taken from the
# loaders rather than from a leftover loop index, which would be undefined
# (or stale from the training loop) if a loader yielded no batches.
num_train_batches = len(train_dataloader)
num_val_batches = len(val_dataloader)

for epoch in tqdm(range(EPOCHS)):
    # ---- Training pass ----
    model.train()
    train_running_loss = 0
    for batch in tqdm(train_dataloader):
        img = batch[0].float().to(device)
        mask = batch[1].float().to(device)

        # Clear gradients left over from the previous step before the new
        # forward/backward pass.
        optimizer.zero_grad()
        y_pred = model(img)

        loss = criterion(y_pred, mask)
        train_running_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Mean of the per-batch losses; NaN signals "no batches" instead of a fake 0.
    train_loss = train_running_loss / num_train_batches if num_train_batches else float("nan")

    # ---- Validation pass (eval-mode layers, no gradient tracking) ----
    model.eval()
    val_running_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            img = batch[0].float().to(device)
            mask = batch[1].float().to(device)

            y_pred = model(img)
            loss = criterion(y_pred, mask)

            val_running_loss += loss.item()

    val_loss = val_running_loss / num_val_batches if num_val_batches else float("nan")

    print("-"*30)
    print(f"Train Loss EPOCH {epoch+1}: {train_loss:.4f}")
    print(f"Valid Loss EPOCH {epoch+1}: {val_loss:.4f}")
    print("-"*30)

# Persist only the weights (state_dict) reached after the final epoch.
torch.save(model.state_dict(), MODEL_SAVE_PATH)

## Inference

In [20]:
# Image file used for the single-image inference demo.
SINGLE_IMG_PATH = "/content/Unet_Segmentation/data/29bb3ece3180_11.jpg"
# Dataset root for the optional multi-image grid visualisation.
DATA_PATH = "/content/Unet_Segmentation/data"
# Weights file to load; the path is relative to the current working directory.
MODEL_PATH = "./models/unet.pth"  # change this if you want to use your own model

In [21]:
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image

from carvana_dataset import CarvanaDataset
from unet import UNet

def pred_show_image_grid(data_path, model_pth, device):
    """Predict masks for every sample of the test dataset and plot them as a grid.

    The figure has three rows (input images, reference masks, predicted
    binary masks) and one column per dataset sample.

    Args:
        data_path: Root directory passed to ``CarvanaDataset(data_path, test=True)``.
        model_pth: Path to a ``state_dict`` file readable by ``torch.load``.
        device: Device (e.g. ``"cuda"`` or ``"cpu"``) on which the model and
            the inputs are placed.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    model = UNet(in_channels=3, num_classes=1).to(device)
    model.load_state_dict(torch.load(model_pth, map_location=torch.device(device)))
    # Switch to inference behaviour so that any BatchNorm/Dropout layers inside
    # the model do not run in training mode.
    model.eval()

    image_dataset = CarvanaDataset(data_path, test=True)
    images = []
    orig_masks = []
    pred_masks = []

    # Gradients are not needed for visualisation; disabling autograd avoids
    # building graphs and lowers memory use.
    with torch.no_grad():
        for img, orig_mask in image_dataset:
            # Add a batch dimension for the forward pass.
            batch = img.float().to(device).unsqueeze(0)
            logits = model(batch)

            # Channel-first -> channel-last, as expected by imshow.
            images.append(batch.squeeze(0).cpu().permute(1, 2, 0))
            orig_masks.append(orig_mask.detach().cpu().permute(1, 2, 0))

            # Binarise: a logit above 0 corresponds to a sigmoid probability above 0.5.
            logits = logits.squeeze(0).cpu().permute(1, 2, 0)
            pred_masks.append((logits > 0).float())

    # Row-major panel order: all inputs, then all reference masks, then all predictions.
    num_samples = len(images)
    panels = images + orig_masks + pred_masks

    fig = plt.figure()
    for position, panel in enumerate(panels, start=1):
        fig.add_subplot(3, num_samples, position)
        plt.imshow(panel, cmap="gray")
    plt.show()


def single_image_inference(image_pth, model_pth, device):
    """Predict the mask of one image file and show it next to the input.

    Args:
        image_pth: Path of the image to segment. It is converted to RGB and
            resized to 512x512 before being fed to the model.
        model_pth: Path to a ``state_dict`` file readable by ``torch.load``.
        device: Device (e.g. ``"cuda"`` or ``"cpu"``) on which the model and
            the input are placed.

    Returns:
        None. A two-panel figure (input, predicted binary mask) is displayed
        with ``plt.show()``.
    """
    model = UNet(in_channels=3, num_classes=1).to(device)
    model.load_state_dict(torch.load(model_pth, map_location=torch.device(device)))
    # Switch to inference behaviour so that any BatchNorm/Dropout layers inside
    # the model do not run in training mode.
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()])

    # The context manager closes the file handle promptly. convert("RGB")
    # guarantees three channels (matching in_channels=3) even for grayscale
    # or RGBA files.
    with Image.open(image_pth) as pil_img:
        img = transform(pil_img.convert("RGB")).float().to(device)
    img = img.unsqueeze(0)  # add batch dimension

    # Gradients are not needed for inference; disabling autograd lowers memory use.
    with torch.no_grad():
        pred_mask = model(img)

    # Drop the batch dimension and go channel-first -> channel-last for imshow.
    img = img.squeeze(0).cpu().permute(1, 2, 0)
    pred_mask = pred_mask.squeeze(0).cpu().permute(1, 2, 0)
    # Binarise: a logit above 0 corresponds to a sigmoid probability above 0.5.
    pred_mask = (pred_mask > 0).float()

    fig = plt.figure()
    for position, panel in enumerate((img, pred_mask), start=1):
        fig.add_subplot(1, 2, position)
        plt.imshow(panel, cmap="gray")
    plt.show()


To run prediction on multiple images, call the `pred_show_image_grid()` function with your data path, model path and device as arguments.
First create the following two folders and put your images (and their corresponding masks) in them:
- /content/Unet_Segmentation/data/manual_test
- /content/Unet_Segmentation/data/manual_test_masks

In [25]:
# Run inference on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Optional: uncomment the next line to visualise a whole set of images as a grid instead.
# pred_show_image_grid(DATA_PATH, MODEL_PATH, device)
single_image_inference(SINGLE_IMG_PATH, MODEL_PATH, device)