In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from torchvision.datasets import VOCSegmentation
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [2]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    With padding 1 and stride 1 each convolution keeps the spatial size, so
    only the channel count changes: ``in_channels`` -> ``out_channels`` in the
    first stage, ``out_channels`` -> ``out_channels`` in the second.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Kept as a single Sequential named `conv` so state_dict keys
        # (conv.0.*, conv.1.*, conv.3.*, conv.4.*) stay checkpoint-compatible.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

In [3]:
class UNET(nn.Module):
    """U-Net producing per-pixel class logits.

    Encoder: one ``DoubleConv`` per entry of ``features`` (shallowest first),
    each followed by 2x2 max-pooling. Bottleneck: a ``DoubleConv`` that doubles
    the deepest width. Decoder: for each level, deepest first, a 2x2 transposed
    conv (x2 upsample) followed by a ``DoubleConv`` over the concatenation of the
    matching encoder output (skip connection) and the upsampled map. A final
    1x1 conv maps ``features[0]`` channels to ``out_channels`` logits.

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: number of output channels (one logit per class).
        features: encoder widths, shallowest first. Any sequence is accepted;
            the default is an immutable tuple.

    ``forward`` returns logits with the same spatial size as its input. When the
    input size is not divisible by ``2 ** len(features)``, max-pooling floors
    odd sizes, so the upsampled map is resized to its skip connection before
    concatenation.
    """

    def __init__(self, in_channels=3, out_channels=21, features=(64, 128, 256, 512)):
        super(UNET, self).__init__()
        features = list(features)
        # Registration order (ups, downs, maxpool, ...) is kept as before so
        # state_dict layout and parameter-init RNG order are unchanged.
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down path of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up path of UNET: self.ups alternates [ConvTranspose2d, DoubleConv, ...],
        # deepest level first.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)  # full-resolution output of this level
            x = self.maxpool(x)

        x = self.bottleneck(x)

        stages = zip(self.ups[0::2], self.ups[1::2], reversed(skip_connections))
        for upsample, double_conv, skip_connection in stages:
            x = upsample(x)
            # Only the spatial dims can differ (pooling floors odd sizes).
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=skip_connection.shape[2:])
            x = double_conv(torch.cat((skip_connection, x), dim=1))

        return self.final_conv(x)

In [4]:
# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNET(in_channels=3, out_channels=21).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(unet.parameters(), lr=0.001)

In [5]:
# Transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

target_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

In [7]:
# Pascal VOC 2012 segmentation, "train" split (downloaded into ./ if missing)
train_dataset = VOCSegmentation(
    root='.',
    year='2012',
    image_set='train',
    download=True,
    transform=transform,
    target_transform=target_transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)


Using downloaded and verified file: .\VOCtrainval_11-May-2012.tar
Extracting .\VOCtrainval_11-May-2012.tar to .


In [None]:
num_epochs = 5
losses = []  # per-batch training loss in iteration order (plotted below)

for epoch in range(num_epochs):
    # evaluate_model() / visualize_predictions() switch the network to eval
    # mode; re-enable train mode so BatchNorm uses batch statistics and keeps
    # updating its running stats when this cell is re-run after them.
    unet.train()
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        # Drop the singleton channel dim of the mask batch and use int64 class
        # indices, as CrossEntropyLoss expects for index targets.
        targets = targets.squeeze(1).long().to(device)

        optimizer.zero_grad()
        outputs = unet(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # single device->host sync per step
        losses.append(loss_value)
        print(f"Epoch: {epoch}, Loss: {loss_value}")

    # Save a checkpoint (model weights only) after every epoch
    torch.save(unet.state_dict(), f'unet_epoch_{epoch}.pth')

Epoch: 0, Loss: 3.230180501937866
Epoch: 0, Loss: 3.10801100730896
Epoch: 0, Loss: 2.9428341388702393
Epoch: 0, Loss: 2.7719733715057373
Epoch: 0, Loss: 2.620206594467163
Epoch: 0, Loss: 2.5215301513671875
Epoch: 0, Loss: 2.404702663421631
Epoch: 0, Loss: 2.3266263008117676
Epoch: 0, Loss: 2.2532594203948975
Epoch: 0, Loss: 2.187397003173828
Epoch: 0, Loss: 2.1543076038360596
Epoch: 0, Loss: 2.096327066421509
Epoch: 0, Loss: 2.098078966140747
Epoch: 0, Loss: 2.066683769226074
Epoch: 0, Loss: 2.0024545192718506
Epoch: 0, Loss: 1.9914770126342773
Epoch: 0, Loss: 1.9308099746704102
Epoch: 0, Loss: 1.890919804573059
Epoch: 0, Loss: 1.835479497909546
Epoch: 0, Loss: 1.8456300497055054
Epoch: 0, Loss: 1.7869576215744019
Epoch: 0, Loss: 1.7544703483581543


In [None]:
# Training-loss curve, one point per optimizer step
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
plt.show()

In [None]:
# Function to load the model and continue training
def load_and_continue_training(model_path, start_epoch, num_epochs):
    """Load UNET weights from a checkpoint and train for more epochs.

    Args:
        model_path: path to a ``state_dict`` file written by the training cell,
            e.g. ``'unet_epoch_4.pth'``.
        start_epoch: epoch number to resume at; used for log lines and for the
            names of the checkpoints written here.
        num_epochs: number of additional epochs to run.

    Side effects: overwrites the weights of the global ``unet``, steps the
    global ``optimizer``, appends to the global ``losses`` list (created by the
    training cell), writes ``unet_epoch_{epoch}.pth`` after each epoch and
    shows the cumulative loss curve.

    Note: the training cell saves only model weights, so optimizer (Adam)
    state is not restored from disk; in a fresh kernel it starts empty.
    """
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only
    # machine (and vice versa). torch.load unpickles the file, so only load
    # checkpoints you trust.
    unet.load_state_dict(torch.load(model_path, map_location=device))
    unet.train()

    for epoch in range(start_epoch, start_epoch + num_epochs):
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            # (B, 1, H, W) mask batch -> (B, H, W) int64 class indices
            targets = targets.squeeze(1).long().to(device)

            optimizer.zero_grad()
            outputs = unet(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()  # single device->host sync per step
            losses.append(loss_value)
            print(f"Epoch: {epoch}, Loss: {loss_value}")

        # Save the model
        torch.save(unet.state_dict(), f'unet_epoch_{epoch}.pth')

    # Plot the updated (cumulative) loss graph
    fig, ax = plt.subplots()
    ax.plot(losses)
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss')
    plt.show()

# Example usage to continue training
# load_and_continue_training('unet_epoch_4.pth', 5, 5)

### Evaluation and Inference

In [None]:
import numpy as np

def pixel_accuracy(output, target):
    """Fraction of labelled pixels whose argmax prediction equals the target.

    Args:
        output: class scores with the class axis on dim 1, e.g. ``(B, C, H, W)``.
        target: integer class indices without the class axis, e.g. ``(B, H, W)``.

    Pixels whose target is not a valid class index for ``output`` (outside
    ``[0, C)`` with ``C = output.shape[1]``, such as VOC's void label 255) are
    excluded from both numerator and denominator.

    Returns:
        0-dim float tensor in [0, 1]; 0 when no pixel is labelled.
    """
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        num_classes = output.shape[1]
        valid = (target >= 0) & (target < num_classes)
        correct = ((pred == target) & valid).sum()
        # clamp avoids 0/0 when every pixel is void
        acc = correct.float() / valid.sum().clamp(min=1)
    return acc

def intersection_and_union(output, target, num_classes):
    """Per-class intersection and union pixel counts for one batch.

    For each class ``c``: intersection = #pixels where prediction == target == c,
    union = #(prediction == c) + #(target == c) - intersection.
    (The previous version used bitwise ``&``/``|`` on class ids, which does not
    compute set intersection/union of class masks.)

    Args:
        output: class scores with the class axis on dim 1, e.g. ``(B, C, H, W)``.
        target: integer class indices without the class axis, e.g. ``(B, H, W)``.
        num_classes: number of classes to report.

    Pixels whose target lies outside ``[0, num_classes)`` (e.g. VOC void 255)
    are ignored.

    Returns:
        ``(intersection, union)``: float tensors of length ``num_classes`` on
        the same device as the inputs.
    """
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        valid = (target >= 0) & (target < num_classes)
        pred = pred[valid]
        target = target[valid].long()

        hits = pred[pred == target]
        # Slicing guards against predictions >= num_classes when the model
        # emits more channels than num_classes.
        area_intersection = torch.bincount(hits, minlength=num_classes)[:num_classes]
        area_pred = torch.bincount(pred, minlength=num_classes)[:num_classes]
        area_target = torch.bincount(target, minlength=num_classes)[:num_classes]
        area_union = area_pred + area_target - area_intersection

    return area_intersection.float(), area_union.float()

def mean_iou(intersection_hist, union_hist):
    """Mean intersection-over-union over the classes that actually occur.

    Classes with zero union (absent from both prediction and target) have an
    undefined IoU (0/0 = NaN) and are excluded instead of turning the mean
    into NaN.

    Args:
        intersection_hist: per-class intersection counts (1-D tensor).
        union_hist: per-class union counts, same shape.

    Returns:
        0-dim tensor; NaN only if no class has a non-zero union.
    """
    present = union_hist > 0
    iou = intersection_hist[present] / union_hist[present]
    miou = torch.mean(iou)
    return miou


In [None]:
import matplotlib.pyplot as plt

def evaluate_model(model, dataloader, device, num_classes=21):
    """Pixel accuracy and mean IoU of ``model`` over every batch of ``dataloader``.

    Accuracy is pooled over all labelled pixels in the dataset, so a smaller
    final batch is not over-weighted as it was with a mean of per-batch scores.
    Pixels whose target lies outside ``[0, num_classes)`` (e.g. VOC void 255)
    are ignored.

    Args:
        model: segmentation network returning per-class scores on dim 1.
        dataloader: yields ``(inputs, targets)``; targets are squeezed on dim 1,
            so a ``(B, 1, H, W)`` mask batch is assumed.
        device: device to run inference on.
        num_classes: number of classes (21 for Pascal VOC).

    Returns:
        ``(mean_acc, miou)``: ``mean_acc`` is a Python float (0.0 when no pixel
        is labelled); ``miou`` is the 0-dim tensor from ``mean_iou``.

    The model's train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    total_correct = 0
    total_labelled = 0
    total_intersection = torch.zeros(num_classes, device=device)
    total_union = torch.zeros(num_classes, device=device)

    try:
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs = inputs.to(device)
                targets = targets.squeeze(1).long().to(device)

                outputs = model(inputs)

                preds = torch.argmax(outputs, dim=1)
                labelled = (targets >= 0) & (targets < num_classes)
                total_correct += ((preds == targets) & labelled).sum().item()
                total_labelled += labelled.sum().item()

                intersection_hist, union_hist = intersection_and_union(outputs, targets, num_classes)
                total_intersection += intersection_hist
                total_union += union_hist
    finally:
        model.train(was_training)

    mean_acc = total_correct / total_labelled if total_labelled else 0.0
    miou = mean_iou(total_intersection, total_union)

    return mean_acc, miou

def visualize_predictions(model, dataloader, device, num_images=5, num_classes=21):
    """Show input / ground truth / prediction for the first ``num_images`` images.

    Images are taken in order from the batches yielded by ``dataloader``, so
    usually only one forward pass is needed. Ground truth and prediction share
    one colormap with fixed limits ``[0, num_classes - 1]``, so a class has the
    same colour in both panels. Target pixels outside ``[0, num_classes)``
    (e.g. VOC void 255) are set to NaN, which imshow leaves blank.

    Assumes inputs are ``(C, H, W)`` float image tensors as produced by
    ``ToTensor`` and targets are ``(1, H, W)`` masks — TODO confirm for other
    loaders. The model's train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    label_cmap = plt.get_cmap('nipy_spectral', num_classes)
    label_kwargs = dict(cmap=label_cmap, vmin=-0.5, vmax=num_classes - 0.5, interpolation='nearest')
    shown = 0

    try:
        with torch.no_grad():
            for inputs, targets in dataloader:
                if shown >= num_images:
                    break
                preds = torch.argmax(model(inputs.to(device)), dim=1).cpu()

                for image, target, pred in zip(inputs, targets, preds):
                    if shown >= num_images:
                        break
                    label = target.squeeze()
                    labelled = (label >= 0) & (label < num_classes)
                    gt = torch.where(labelled, label.float(),
                                     torch.full_like(label, float('nan'), dtype=torch.float32))

                    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
                    axes[0].imshow(image.cpu().permute(1, 2, 0))
                    axes[0].set_title("Input Image")
                    axes[1].imshow(gt, **label_kwargs)
                    axes[1].set_title("Ground Truth")
                    axes[2].imshow(pred, **label_kwargs)
                    axes[2].set_title("Prediction")
                    plt.show()
                    plt.close(fig)  # avoid accumulating open figures
                    shown += 1
    finally:
        model.train(was_training)

# Evaluate on the held-out VOC 2012 "val" split: scoring the training split
# only measures how well the network fits images it was trained on.
# The val split ships in the same VOCtrainval archive extracted by the
# training-data cell, so no second download/extraction is needed.
val_dataset = VOCSegmentation(root='.', year='2012', image_set='val', download=False,
                              transform=transform, target_transform=target_transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

mean_acc, miou = evaluate_model(unet, val_loader, device)
print(f"Mean Accuracy: {mean_acc:.4f}, Mean IoU: {miou:.4f}")

# Visualize some predictions on unseen images
visualize_predictions(unet, val_loader, device)
