### Libraries

In [2]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
import copy
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

### Dataset

In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.cuda.amp import GradScaler, autocast
from PIL import Image
import numpy as np
import time
from tqdm import tqdm

class CustomDataset(Dataset):
    def __init__(self, root_dir, phase='Train', transform=None):
        self.root_dir = root_dir
        self.phase = phase
        self.transform = transform

        self.image_paths = []
        self.mask_paths = []

        sub_dirs = ['Brown_Field', 'Main_Trail', 'Power_Line', 'mixed']

        for sub_dir in sub_dirs:
            subdir_path = os.path.join(root_dir, sub_dir)
            if os.path.isdir(subdir_path):
                img_dir = os.path.join(subdir_path, phase, 'imgs')
                mask_dir = os.path.join(subdir_path, phase, 'annos', 'int_maps')

                if os.path.exists(img_dir) and os.path.exists(mask_dir):
                    img_files = sorted([f for f in os.listdir(img_dir) if f.endswith('.jpg') or f.endswith('.png')])
                    mask_files = sorted([f for f in os.listdir(mask_dir) if f.endswith('.jpg') or f.endswith('.png')])

                    if len(img_files) == 0 or len(mask_files) == 0:
                        print(f"No images or masks found in {img_dir} or {mask_dir}")
                    else:
                        for img_file, mask_file in zip(img_files, mask_files):
                            self.image_paths.append(os.path.join(img_dir, img_file))
                            self.mask_paths.append(os.path.join(mask_dir, mask_file))
                else:
                    print(f"Image directory {img_dir} or mask directory {mask_dir} does not exist")

        if len(self.image_paths) == 0:
            print("No images or masks found in any subdirectory")
        else:
            print(f"Found {len(self.image_paths)} images and {len(self.mask_paths)} masks")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert('L').resize((224,224), resample=Image.NEAREST)
        
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        # Convert mask to binary (0 and 1)
        mask = (mask > 0).float()
        
        transforms.Resize((224, 224)),
        mask = np.array(mask)
        mask = torch.tensor(mask, dtype=torch.int64)
        
        # squeeze the tensor
        mask = torch.squeeze(mask)
        # print the squeezed tensor

        return image, mask



# Accuracy calculation function
def calculate_accuracy(outputs, masks):
    preds = torch.sigmoid(outputs) > 0.5  # Threshold predictions at 0.5
    correct = (preds == masks).float()
    accuracy = correct.sum() / correct.numel()
    return accuracy


### Data Transforms

In [4]:
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

mask_transforms = transforms.Compose([
    
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

### DataLoader

In [5]:
root_dir = '/home/rchen2/CAT'

train_dataset = CustomDataset(root_dir=root_dir, phase='Train', transform=data_transforms)
test_dataset = CustomDataset(root_dir=root_dir, phase='Test', transform=data_transforms)
val_dataset = CustomDataset(root_dir=root_dir, phase='Train', transform=val_transforms)


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)


# Accuracy calculation function
def calculate_accuracy(predictions, ground_truth):
    # Assuming predictions and ground_truth are binary masks with values 0 or 1
    correct = np.sum(predictions == ground_truth)
    total = predictions.size
    accuracy = correct / total
    return accuracy

Found 2536 images and 2536 masks
Found 1088 images and 1088 masks
Found 2536 images and 2536 masks


### Loss Function & Optimizer

In [6]:
import segmentation_models_pytorch as smp
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = smp.Unet(encoder_name = "resnet34", encoder_weights="imagenet", in_channels=3, classes=4,).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

### Training Loop

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast

class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, num_epochs, device, num_classes):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.num_epochs = num_epochs
        self.device = device
        self.num_classes = num_classes
        self.train_iou_per_class = []
        self.scaler = GradScaler()  # For mixed precision training

    def calculate_iou(self, pred, target):
        iou_per_class = []
        pred = torch.argmax(pred, dim=1)

        for cls in range(self.num_classes):
            pred_inds = (pred == cls)
            target_inds = (target == cls)

            intersection = (pred_inds & target_inds).sum().item()
            union = (pred_inds | target_inds).sum().item()

            if union == 0:
                iou_per_class.append(float('nan'))
            else:
                iou_per_class.append(intersection / union)

        return iou_per_class

    def mean_iou(self, iou_list):
        valid_iou = [iou for iou in iou_list if not np.isnan(iou)]
        if len(valid_iou) == 0:
            return float('nan')
        return np.mean(valid_iou)

    def train_epoch(self, epoch):
        self.model.train()
        running_loss = 0.0
        iou_scores = []

        for images, masks in tqdm(self.train_loader, desc=f'Training Epoch {epoch}'):
            images = images.to(self.device)
            masks = masks.to(self.device)

            self.optimizer.zero_grad()
            with autocast():  # Mixed precision context
                outputs = self.model(images)
                loss = self.criterion(outputs, masks)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            running_loss += loss.item()

            # Calculate IoU
            iou = self.calculate_iou(outputs, masks)
            iou_scores.append(iou)

        avg_loss = running_loss / len(self.train_loader)
        avg_iou_scores = np.nanmean(iou_scores, axis=0)
        self.train_iou_per_class.append(avg_iou_scores)

        print(f"Epoch [{epoch}] Training Loss: {avg_loss}")
        print(f"Epoch [{epoch}] IoU per class: {avg_iou_scores.tolist()}")  # Convert to list for safe printing
        return avg_loss

    def validate_epoch(self):
        self.model.eval()
        running_loss = 0.0
        iou_scores = []

        with torch.no_grad():
            for images, masks in tqdm(self.val_loader, desc='Validating'):
                images = images.to(self.device)
                masks = masks.to(self.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, masks)

                running_loss += loss.item()

                # Calculate IoU
                iou = self.calculate_iou(outputs, masks)
                iou_scores.append(iou)

            avg_loss = running_loss / len(self.val_loader)

            # Compute mean IoU
            iou_scores = np.array(iou_scores)
            miou_per_class = np.nanmean(iou_scores, axis=0)
            miou = self.mean_iou(miou_per_class)

            print(f"Validation Loss: {avg_loss}")
            print(f"Mean IoU: {miou}")
            print(f"IoU per class: {miou_per_class.tolist()}")  # Convert to list for safe printing
            return avg_loss, miou, miou_per_class

    def train(self):
        for epoch in range(1, self.num_epochs + 1):
            train_loss = self.train_epoch(epoch)
            val_loss, val_miou, val_miou_per_class = self.validate_epoch()

# Example usage:
# Define the number of epochs
num_epochs = 4

# Define the number of classes
num_classes = 4

# Define class weights based on the dataset
class_weights = torch.tensor([1.0, 2.0, 2.0, 2.0]).to(device)

# Use weighted loss function
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Create a trainer instance
trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, num_classes)

# Start training
trainer.train()

Training Epoch 1: 100%|██████████| 80/80 [19:26<00:00, 14.58s/it]


Epoch [1] Training Loss: 0.6237573485821486
Epoch [1] IoU per class: [0.38296619405795673, 0.6628771161735837, 0.0, 0.0]


Validating: 100%|██████████| 80/80 [06:28<00:00,  4.85s/it]


Validation Loss: 0.5287993714213371
Mean IoU: 0.5537591658528216
IoU per class: [0.4088495663423434, 0.6986687653632997, nan, nan]


Training Epoch 2: 100%|██████████| 80/80 [18:45<00:00, 14.06s/it]


Epoch [2] Training Loss: 0.36221774891018865
Epoch [2] IoU per class: [0.40503623090097635, 0.7376561658433549, nan, nan]


Validating: 100%|██████████| 80/80 [06:27<00:00,  4.84s/it]


Validation Loss: 0.4738102236762643
Mean IoU: 0.5956204516944662
IoU per class: [0.48307979323398753, 0.7081611101549449, nan, nan]


Training Epoch 3: 100%|██████████| 80/80 [18:45<00:00, 14.07s/it]


Epoch [3] Training Loss: 0.33280202820897103
Epoch [3] IoU per class: [0.4285358151115994, 0.7481885097980725, nan, nan]


Validating: 100%|██████████| 80/80 [06:28<00:00,  4.86s/it]


Validation Loss: 0.3783735426142812
Mean IoU: 0.5865965558673774
IoU per class: [0.44439337578582094, 0.7287997359489337, nan, nan]


Training Epoch 4:  90%|█████████ | 72/80 [17:04<01:53, 14.16s/it]

In [None]:
def plot_iou_per_class(trainer, num_epochs):
    epochs = np.arrange(1, num_epochs+1)
    iou_per_class == np.array(trainer.train_iou_per_class).T
    
    plt.figure(figsize=(12, 8))
    
    classNames = ['Background', 'Sedan', 'Pickup', 'Off-Road']
    for cls in range(num_classes):
        plt.plot(epochs, iou_per_class[cls], label=f'{classNames[cls]} IoU')
    
    plt.xlabel('Epochs')
    plt.ylabel('IoU')
    plt.title('ResNet-34')
    plt.legend()
    plt.grid(True)
    plt.show()

plot_iou_per_class(trainer, num_epochs)