### Libraries

In [2]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
import copy
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

In [12]:
INPUT_IMAGE_HEIGHT = 1024
INPUT_IMAGE_WIDTH = 672

class Compose:
    def __init__(self, transforms):
        self.transforms = transforms
        
    def __call__(self, image, target):
        for t in self.transforms:
            image = t(image)
            target = t(target)
        target = torch.tensor(np.array(target), dtype=torch.int64)
        image = transforms.ToTensor()(image)
        return image, target

class SegmentationDataset(Dataset):
    def __init__(self, root_dir, phase='Train', transform=None):
        self.root_dir = root_dir
        self.phase = phase
        self.transform = transform

        self.image_paths = []
        self.mask_paths = []

        sub_dirs = ['Brown_Field', 'Main_Trail', 'Power_Line', 'mixed']

        for sub_dir in sub_dirs:
            subdir_path = os.path.join(root_dir, sub_dir)
            if os.path.isdir(subdir_path):
                img_dir = os.path.join(subdir_path, phase, 'imgs')
                mask_dir = os.path.join(subdir_path, phase, 'annos', 'int_maps')

                if os.path.exists(img_dir) and os.path.exists(mask_dir):
                    img_files = sorted([f for f in os.listdir(img_dir) if f.endswith('.jpg') or f.endswith('.png')])
                    mask_files = sorted([f for f in os.listdir(mask_dir) if f.endswith('.jpg') or f.endswith('.png')])

                    if len(img_files) == 0 or len(mask_files) == 0:
                        print(f"No images or masks found in {img_dir} or {mask_dir}")
                    else:
                        for img_file, mask_file in zip(img_files, mask_files):
                            self.image_paths.append(os.path.join(img_dir, img_file))
                            self.mask_paths.append(os.path.join(mask_dir, mask_file))
                else:
                    print(f"Image directory {img_dir} or mask directory {mask_dir} does not exist")

        if len(self.image_paths) == 0:
            print("No images or masks found in any subdirectory")
        else:
            print(f"Found {len(self.image_paths)} images and {len(self.mask_paths)} masks")
        
    def _collect_image_mask_pairs(self):
        image_mask_pairs = []
        for root_dir in self.root_dirs:
            image_dir = os.path.join(root_dir, self.split, 'imgs')
            mask_dir = os.path.join(root_dir, self.split, 'annos', 'int_maps')
            
            images = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswidth('.png')])
            masks = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir) if f.endswidth('.png')])
            
            mask_dict = {os.path.basename(mask).split('_')[1].replace('.png',''): mask for mask in masks}
            for img in images:
                key = os.path.basename(img).split('_')[1].replace('.png', '')
                if key in mask_dict:
                    image_mask_pairs.append((img, mask_dict,[key]))
                else:
                    print(f"No matching mask for image: {img}")
                    
        return image_mask_pairs
        
    def __len__(self):
        return len(self.image_masks_pairs)
        
    def __getitem__(self, idx):
        img_path, mask_path = self.image_mask_pairs[idx]
        image = Image.open(img_path).convert("RGB")
        mask - Image.open(mask_path).convert("L")
        if self.transforms:
            image, mask = self.transforms(image, mask)
        return image, mask
    
#Transformation definition
transform = Compose([transforms.Resize((INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH), interpolation=Image.NEAREST)])

#List of root directories
root_dirs = ['CAT/Brown_Field', 'CAT/Main_Trail', 'CAT/Power_Line', 'CAT/mixed']

#Data Transforms
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

mask_transforms = transforms.Compose([
    
])

#Data loader
root_dir = '/home/rchen2/CAT'

train_dataset = SegmentationDataset(root_dir=root_dir, phase='Train', transform=data_transforms)
test_dataset = SegmentationDataset(root_dir=root_dir, phase='Test', transform=data_transforms)

train_loader = DataLoader(train_dataset, batch_size=16, num_workers=4, shuffle=True,)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False,)

#Debugging
for images, masks in train_loader:
    print(f"Batch of images shape: {images.shape}")
    print(f"Batch of masks shape: {masks.shape}")
    break

#Instantiate the model
model = smp.Unet(encoder_name = "resnet34", encoder_weights="imagenet", in_channels=3, classes=4).cuda()

        

Found 2536 images and 2536 masks
Found 1088 images and 1088 masks


AttributeError: 'SegmentationDataset' object has no attribute 'image_masks_pairs'

In [13]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm

#define criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

NUM_CLASSES = 4

class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, num_classes, device='cuda'):
        self.model = model.to(device)
        self.train_loader = train.loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.num_classes = num_classes
        self.device = device
        self.train_iou_per_class = []
        
    def calculate_iou(self, pred, target):
        iou_per_class = []
        pred = torch.argmax(pred, dim=1)
        
        for cls in range(self.num_classes):
            pred_inds = (pred == cls)
            target_inds = (target == cls)
            
            intersection = (pred_inds & target_inds).sum().item()
            union = (pred_inds | target_inds).sum().item()
            
            if union == 0:
                iou_per_class.append(float('nan'))
            else:
                iou_per_class.append(intersection / union)
                
        return iou_per_class

    def mean_iou(self, iou_list):
        valid_iou = [iou for iou in iou_list if not np.isnan(iou)]
        if len(valid_iou) == 0:
            return float('nan')
        return np.mean(valid_iou)
    
    def train_epoch(self, epoch):
        self.model.train()
        running_loss = 0.0
        epoch_iou_scores = np.zeros((len(self.train_loader), self.num_classes))
        
        for batch_idx, (images, masks) in enumerate(tqdm(self.train_loader, desc=f'Training Epoch {epoch}')):
            images = images.to(self.device)
            masks = masks.to(self.device)
            
            outputs = self.model(images)
            loss = self.criterion(outputs, masks)
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            running_loss += loss.item()
            
            #Calculate IoU
            iou = self.calculate_iou(outputs, masks)
            epoch_iou_scores[batch_idx] = iou
            
        avg_loss = running_loss / len(self.train_loader)
        avg_iou_scores = np.nanmean(epoch_iou_scores, axis=0)
        self.train_iou_per_class.append(avg_iou_scores)
        
        print(f"Epoch [{epoch}] Training Loss: {avg_loss}")
        print(f"Epoch [{epoch}] IoU per class: {avg_iou_scores}")
        return avg_loss
    
    def validate_epoch(self):
        self.model.eval()
        running_loss = 0.0
        iou_scores = []
        
        with torch.no_grad():
            for images, masks in tqdm(self.val_loader, desc='Validating'):
                images = images.to(self.device)
                masks = masks.to(self.device)
                
                outputs = self.model(images)
                loss = self.criterion(outputs, masks)
                
                running_loss += loss.item()
                
                #Calculate IoU
                iou = self.calculate_iou(outputs, masks)
                epoch_iou_scores[batch_idx] = iou
                
            avg_loss = running_loss / len(self.train_loader)
            
            #Compute mean IoU
            iou_scores = np.array(iou_scores)
            miou_per_class = np.nanmean(iou_scores, axis=0)
            miou = self.mean_iou(miou_per_class)
            
            print(f"Validation Loss: {avg_loss}")
            print(f"Mean IoU: {miou}")
            print(f"IoU per class: {miou_per_class}")
            return avg_loss, miou, miou_per_class
        
        def train(self, num_epochs):
            for epoch in range(1, num_epochs + 1):
                self.train_epoch(epoch)
                self.validate_epoch
        


NameError: name 'model' is not defined

In [None]:
#Define class weights based on the dataset
class_weights = torch.tensor([1.0, 2.0, 2.0, 2.0]).cuda()

#Use weighted loss function
criterion = nn.CrossEntropyLoss(weight=class_weights)

#Instantiate the Trainer class with weighted loss
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_classes=NUM_CLASSES,
    device='cuda'
)

#Train the model
num_epochs=4
trainer.train(num_epochs)

In [None]:
def print_class_distribution(loader, num_classes):
    class_counts = [0] * num_classes
    
    for _, masks in loader:
        for mask in masks:
            class_counts[cls] += (mask == cls).sum().item()
            
    total_pixels = sum(class_counts)
    class_distribution = [count / total_pixels for count in class_counts]
    
    print(f"Class distribution: {class_distribution}")
    
print_class_distribution(train_loader, NUM_CLASSES)
print_class_distribution(test_loader, NUM_CLASSES)

In [None]:
import matplotlib.pyplot as plt

#Plotting IoU per class
def plot_iou_per_class(trainer, num_epochs, num_classes):
    epochs = np.arrange(1, num_epochs + 1)
    iou_per_class = np.array(trainer.train_iou_per_class).T
    
    plt.figure(figsize=(12, 8))
    
    classNames = ['Background', 'Sedan', 'Pickup', 'Off-Road']
    for cls in range(num_classes):
        plt.plot(epochs, iou_per_class[cls], label=f'{classNames[cls]} IoU')
        
    plt.xlabel('Epochs')
    plt.ylabel('IoU')
    plt.title('ResNet-34')
    plt.legend()
    plt.grid(True)
    plt.show()
    
plot_iou_per_class(trainer, num_epochs, NUM_CLASSES)