In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
#from dataloader import SegmentationDataset
from model import UNet
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import csv
import time
from datetime import datetime
from PIL import Image

In [2]:
SEGMENTATION_COLOURS = {0:[0,0,0],1:[255,0,0],2:[0,253,0],3:[0,0,250], 4:[253,255,0]}

In [3]:
def calculate_iou(pred_mask, target_mask, num_classes=5):
    ious = []
    
    if torch.is_tensor(pred_mask):
        pred_mask = pred_mask.cpu().numpy()
    if torch.is_tensor(target_mask):
        target_mask = target_mask.cpu().numpy()
    
    for cls in range(num_classes):
        pred_inds = pred_mask == cls
        target_inds = target_mask == cls
        
        intersection = np.logical_and(pred_inds, target_inds).sum()
        union = np.logical_or(pred_inds, target_inds).sum()
        
        iou = intersection / (union + 1e-6)  
        ious.append(iou)
    
    return np.mean(ious)

In [4]:
class SegmentationDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        
        self.root_dir = root_dir
        self.transform = transform

        self.image_dir = os.path.join(root_dir, 'Images')
        self.mask_dir = os.path.join(root_dir, 'Labels')
        
        self.image_files = sorted([f for f in os.listdir(self.image_dir) 
                                 if f.endswith(('.png', '.jpg', '.jpeg'))])
        
        for img in self.image_files:
            base_name = os.path.splitext(img)[0]
            mask_name = f"{base_name}_mask.png"
            mask_path = os.path.join(self.mask_dir, mask_name)


    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        base_name = os.path.splitext(self.image_files[idx])[0]
        mask_name = os.path.join(self.mask_dir, f"{base_name}_mask.png")
        
        # Load image and mask
        image = Image.open(img_name).convert('RGB')
        mask = Image.open(mask_name).convert('RGB')  # Keep as RGB for color mapping
        
        if self.transform:
            image = self.transform(image)
            mask = self.transform_mask(mask)
        
        # Convert RGB mask to class indices
        mask_np = np.array(mask)
        class_mask = np.zeros((mask_np.shape[0], mask_np.shape[1]), dtype=np.uint8)
        
        # Map colors to class indices
        for class_idx, color in SEGMENTATION_COLOURS.items():
            # Find all pixels that match this color
            matches = np.all(mask_np == np.array(color), axis=-1)
            class_mask[matches] = class_idx
        
        return image, torch.from_numpy(class_mask).long()
    
    def transform_mask(self, mask):
        """Apply only the spatial transforms to the mask (resize)"""
        if self.transform is None:
            return np.array(mask)
        
        # Get the resize transform if it exists
        for t in self.transform.transforms:
            if isinstance(t, transforms.Resize):
                mask = t(mask)
                break
        
        return np.array(mask)

In [5]:
def print_latest_metrics_table():
    metrics_dir = 'metrics'
    try:
        if not os.path.isdir(metrics_dir):
            print(f"Metrics directory '{metrics_dir}' not found. Run training first.")
            return

        all_files = [
            os.path.join(metrics_dir, f)
            for f in os.listdir(metrics_dir)
            if os.path.isfile(os.path.join(metrics_dir, f)) and
               f.startswith('training_metrics_') and f.endswith('.csv')
        ]
        if not all_files:
            print(f"No 'training_metrics_*.csv' files found in the '{metrics_dir}' directory.")
            return

        latest_file = max(all_files, key=os.path.getmtime)
        print(f"\n--- Training Metrics Summary from: {os.path.basename(latest_file)} ---")

        with open(latest_file, 'r', newline='') as f:
            reader = csv.reader(f)
            try:
                header = next(reader) # Read header row
                # Expected header: ['Epoch', 'Train Loss', 'Val Loss', 'Mean IoU', 'Time (s)']
                print(f"{'Epoch':<7} | {'Train Loss':<12} | {'Val Loss':<10} | {'Mean IoU':<10} | {'Time (s)':<10}")
                print("-" * 60) # Separator line

                for i, row in enumerate(reader):
                    if len(row) == 5:
                        try:
                            epoch = row[0]
                            train_loss = float(row[1])
                            val_loss = float(row[2])
                            mean_iou = float(row[3])
                            time_s = float(row[4])
                            print(f"{epoch:<7} | {train_loss:<12.4f} | {val_loss:<10.4f} | {mean_iou:<10.4f} | {time_s:<10.2f}")
                        except ValueError:
                            print(f"Warning: Could not parse row {i+1} in {os.path.basename(latest_file)}: {row}")
                    else:
                        print(f"Warning: Skipping malformed row {i+1} in {os.path.basename(latest_file)} (expected 5 columns): {row}")
            except StopIteration:
                print(f"Warning: The metrics file {os.path.basename(latest_file)} is empty or has no data rows.")

    except FileNotFoundError:
        print(f"Error: The directory '{metrics_dir}' or a specific metrics file was not found. Ensure training has run.")
    except Exception as e:
        print(f"An error occurred while trying to print metrics: {e}")

In [6]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    best_val_loss = float('inf')
    
    metrics_dir = 'metrics'
    os.makedirs(metrics_dir, exist_ok=True)
    
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    metrics_file = os.path.join(metrics_dir, f'training_metrics_{timestamp}.csv')
    
    with open(metrics_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Epoch', 'Train Loss', 'Val Loss', 'Mean IoU', 'Time (s)'])
    
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        model.train()
        train_loss = 0
        train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        
        for images, masks in train_bar:
            images = images.to(device)
            masks = masks.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            train_bar.set_postfix({'loss': train_loss / len(train_loader)})
        
        avg_train_loss = train_loss / len(train_loader)
        
        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
                
                predictions = torch.argmax(outputs, dim=1)
                all_preds.append(predictions.cpu().numpy())
                all_targets.append(masks.cpu().numpy())
        
        avg_val_loss = val_loss / len(val_loader)
        
        all_preds = np.concatenate(all_preds, axis=0)
        all_targets = np.concatenate(all_targets, axis=0)
        mean_iou = calculate_iou(all_preds, all_targets)
        
        epoch_time = time.time() - epoch_start_time        
        with open(metrics_file, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([epoch+1, avg_train_loss, avg_val_loss, mean_iou, epoch_time])
        
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {avg_train_loss:.4f}')
        print(f'  Val Loss: {avg_val_loss:.4f}')
        print(f'  Mean IoU: {mean_iou:.4f}')
        print(f'  Time: {epoch_time:.2f}s')
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print('Model saved!')


In [7]:
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = SegmentationDataset('dataset/train', transform=transform)
    val_dataset = SegmentationDataset('dataset/val', transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)

    model = UNet(n_classes=5).to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=50, device=device)
    print_latest_metrics_table()

if __name__ == '__main__':
    main() 
    

Using device: cuda


Epoch 1/50: 100%|██████████| 14/14 [00:05<00:00,  2.79it/s, loss=1.26]


Epoch 1/50:
  Train Loss: 1.2630
  Val Loss: 0.7279
  Mean IoU: 0.1710
  Time: 5.55s
Model saved!


Epoch 2/50: 100%|██████████| 14/14 [00:04<00:00,  3.14it/s, loss=0.737]


Epoch 2/50:
  Train Loss: 0.7370
  Val Loss: 0.6514
  Mean IoU: 0.1755
  Time: 4.91s
Model saved!


Epoch 3/50: 100%|██████████| 14/14 [00:04<00:00,  3.19it/s, loss=0.521]


Epoch 3/50:
  Train Loss: 0.5207
  Val Loss: 0.8100
  Mean IoU: 0.2182
  Time: 4.86s


Epoch 4/50: 100%|██████████| 14/14 [00:04<00:00,  3.23it/s, loss=0.401]


Epoch 4/50:
  Train Loss: 0.4013
  Val Loss: 0.4358
  Mean IoU: 0.2527
  Time: 4.77s
Model saved!


Epoch 5/50: 100%|██████████| 14/14 [00:04<00:00,  3.20it/s, loss=0.326]


Epoch 5/50:
  Train Loss: 0.3264
  Val Loss: 0.3859
  Mean IoU: 0.2499
  Time: 4.83s
Model saved!


Epoch 6/50: 100%|██████████| 14/14 [00:04<00:00,  3.25it/s, loss=0.288]


Epoch 6/50:
  Train Loss: 0.2883
  Val Loss: 0.3329
  Mean IoU: 0.2536
  Time: 4.76s
Model saved!


Epoch 7/50: 100%|██████████| 14/14 [00:04<00:00,  3.24it/s, loss=0.259]


Epoch 7/50:
  Train Loss: 0.2593
  Val Loss: 0.3487
  Mean IoU: 0.2110
  Time: 4.82s


Epoch 8/50: 100%|██████████| 14/14 [00:04<00:00,  3.19it/s, loss=0.235]


Epoch 8/50:
  Train Loss: 0.2349
  Val Loss: 0.3341
  Mean IoU: 0.2257
  Time: 4.84s


Epoch 9/50: 100%|██████████| 14/14 [00:04<00:00,  3.13it/s, loss=0.211]


Epoch 9/50:
  Train Loss: 0.2114
  Val Loss: 0.4452
  Mean IoU: 0.1777
  Time: 4.96s


Epoch 10/50: 100%|██████████| 14/14 [00:04<00:00,  3.21it/s, loss=0.207]


Epoch 10/50:
  Train Loss: 0.2066
  Val Loss: 0.2430
  Mean IoU: 0.2757
  Time: 4.84s
Model saved!


Epoch 11/50: 100%|██████████| 14/14 [00:04<00:00,  3.25it/s, loss=0.194]


Epoch 11/50:
  Train Loss: 0.1939
  Val Loss: 0.2322
  Mean IoU: 0.2856
  Time: 4.78s
Model saved!


Epoch 12/50: 100%|██████████| 14/14 [00:04<00:00,  3.27it/s, loss=0.182]


Epoch 12/50:
  Train Loss: 0.1816
  Val Loss: 0.3456
  Mean IoU: 0.1981
  Time: 4.73s


Epoch 13/50: 100%|██████████| 14/14 [00:04<00:00,  3.12it/s, loss=0.171]


Epoch 13/50:
  Train Loss: 0.1706
  Val Loss: 0.4085
  Mean IoU: 0.1805
  Time: 4.98s


Epoch 14/50: 100%|██████████| 14/14 [00:04<00:00,  3.19it/s, loss=0.175]


Epoch 14/50:
  Train Loss: 0.1747
  Val Loss: 0.2984
  Mean IoU: 0.2402
  Time: 4.85s


Epoch 15/50: 100%|██████████| 14/14 [00:04<00:00,  3.20it/s, loss=0.168] 


Epoch 15/50:
  Train Loss: 0.1678
  Val Loss: 0.2509
  Mean IoU: 0.2598
  Time: 4.83s


Epoch 16/50: 100%|██████████| 14/14 [00:04<00:00,  3.20it/s, loss=0.159]


Epoch 16/50:
  Train Loss: 0.1587
  Val Loss: 0.4512
  Mean IoU: 0.1866
  Time: 4.99s


Epoch 17/50: 100%|██████████| 14/14 [00:04<00:00,  3.10it/s, loss=0.152]


Epoch 17/50:
  Train Loss: 0.1521
  Val Loss: 0.4152
  Mean IoU: 0.1902
  Time: 5.05s


Epoch 18/50: 100%|██████████| 14/14 [00:04<00:00,  3.06it/s, loss=0.152]


Epoch 18/50:
  Train Loss: 0.1521
  Val Loss: 0.2266
  Mean IoU: 0.2649
  Time: 5.04s
Model saved!


Epoch 19/50: 100%|██████████| 14/14 [00:04<00:00,  3.15it/s, loss=0.153] 


Epoch 19/50:
  Train Loss: 0.1529
  Val Loss: 0.3239
  Mean IoU: 0.2357
  Time: 4.86s


Epoch 20/50: 100%|██████████| 14/14 [00:04<00:00,  3.19it/s, loss=0.157]


Epoch 20/50:
  Train Loss: 0.1575
  Val Loss: 0.6118
  Mean IoU: 0.1757
  Time: 4.90s


Epoch 21/50: 100%|██████████| 14/14 [00:04<00:00,  3.15it/s, loss=0.148]


Epoch 21/50:
  Train Loss: 0.1476
  Val Loss: 0.2606
  Mean IoU: 0.2365
  Time: 4.92s


Epoch 22/50: 100%|██████████| 14/14 [00:04<00:00,  3.08it/s, loss=0.14]  


Epoch 22/50:
  Train Loss: 0.1403
  Val Loss: 0.3518
  Mean IoU: 0.2540
  Time: 5.11s


Epoch 23/50: 100%|██████████| 14/14 [00:04<00:00,  3.08it/s, loss=0.143]


Epoch 23/50:
  Train Loss: 0.1431
  Val Loss: 0.4480
  Mean IoU: 0.1826
  Time: 5.02s


Epoch 24/50: 100%|██████████| 14/14 [00:04<00:00,  2.99it/s, loss=0.134] 


Epoch 24/50:
  Train Loss: 0.1337
  Val Loss: 0.4373
  Mean IoU: 0.2025
  Time: 5.17s


Epoch 25/50: 100%|██████████| 14/14 [00:04<00:00,  3.03it/s, loss=0.134] 


Epoch 25/50:
  Train Loss: 0.1344
  Val Loss: 0.3269
  Mean IoU: 0.2568
  Time: 5.08s


Epoch 26/50: 100%|██████████| 14/14 [00:04<00:00,  3.07it/s, loss=0.13]  


Epoch 26/50:
  Train Loss: 0.1302
  Val Loss: 0.2564
  Mean IoU: 0.2440
  Time: 5.08s


Epoch 27/50: 100%|██████████| 14/14 [00:04<00:00,  3.04it/s, loss=0.13]  


Epoch 27/50:
  Train Loss: 0.1301
  Val Loss: 0.4206
  Mean IoU: 0.2139
  Time: 5.09s


Epoch 28/50: 100%|██████████| 14/14 [00:04<00:00,  3.08it/s, loss=0.133] 


Epoch 28/50:
  Train Loss: 0.1327
  Val Loss: 0.6177
  Mean IoU: 0.1891
  Time: 5.02s


Epoch 29/50: 100%|██████████| 14/14 [00:04<00:00,  3.09it/s, loss=0.128] 


Epoch 29/50:
  Train Loss: 0.1283
  Val Loss: 0.2570
  Mean IoU: 0.2776
  Time: 5.03s


Epoch 30/50: 100%|██████████| 14/14 [00:04<00:00,  3.09it/s, loss=0.129] 


Epoch 30/50:
  Train Loss: 0.1292
  Val Loss: 0.6196
  Mean IoU: 0.1847
  Time: 5.09s


Epoch 31/50: 100%|██████████| 14/14 [00:04<00:00,  3.08it/s, loss=0.118] 


Epoch 31/50:
  Train Loss: 0.1181
  Val Loss: 0.2066
  Mean IoU: 0.2917
  Time: 5.05s
Model saved!


Epoch 32/50: 100%|██████████| 14/14 [00:04<00:00,  3.12it/s, loss=0.123] 


Epoch 32/50:
  Train Loss: 0.1231
  Val Loss: 0.5070
  Mean IoU: 0.1855
  Time: 5.02s


Epoch 33/50: 100%|██████████| 14/14 [00:04<00:00,  3.07it/s, loss=0.117] 


Epoch 33/50:
  Train Loss: 0.1174
  Val Loss: 0.4753
  Mean IoU: 0.2133
  Time: 5.04s


Epoch 34/50: 100%|██████████| 14/14 [00:04<00:00,  3.18it/s, loss=0.115] 


Epoch 34/50:
  Train Loss: 0.1148
  Val Loss: 0.3474
  Mean IoU: 0.2473
  Time: 4.91s


Epoch 35/50: 100%|██████████| 14/14 [00:04<00:00,  3.08it/s, loss=0.114] 


Epoch 35/50:
  Train Loss: 0.1135
  Val Loss: 0.2925
  Mean IoU: 0.2482
  Time: 4.99s


Epoch 36/50: 100%|██████████| 14/14 [00:04<00:00,  3.07it/s, loss=0.111] 


Epoch 36/50:
  Train Loss: 0.1109
  Val Loss: 0.3336
  Mean IoU: 0.2519
  Time: 5.10s


Epoch 37/50: 100%|██████████| 14/14 [00:04<00:00,  3.15it/s, loss=0.108] 


Epoch 37/50:
  Train Loss: 0.1081
  Val Loss: 0.4002
  Mean IoU: 0.2403
  Time: 4.94s


Epoch 38/50: 100%|██████████| 14/14 [00:04<00:00,  3.07it/s, loss=0.106] 


Epoch 38/50:
  Train Loss: 0.1061
  Val Loss: 0.7646
  Mean IoU: 0.1641
  Time: 5.10s


Epoch 39/50: 100%|██████████| 14/14 [00:04<00:00,  3.09it/s, loss=0.103] 


Epoch 39/50:
  Train Loss: 0.1025
  Val Loss: 0.2841
  Mean IoU: 0.2675
  Time: 5.03s


Epoch 40/50: 100%|██████████| 14/14 [00:04<00:00,  3.06it/s, loss=0.104] 


Epoch 40/50:
  Train Loss: 0.1036
  Val Loss: 0.5589
  Mean IoU: 0.1833
  Time: 5.06s


Epoch 41/50: 100%|██████████| 14/14 [00:04<00:00,  3.08it/s, loss=0.0998]


Epoch 41/50:
  Train Loss: 0.0998
  Val Loss: 0.3703
  Mean IoU: 0.2186
  Time: 5.01s


Epoch 42/50: 100%|██████████| 14/14 [00:04<00:00,  3.16it/s, loss=0.0968]


Epoch 42/50:
  Train Loss: 0.0968
  Val Loss: 0.2827
  Mean IoU: 0.2600
  Time: 4.90s


Epoch 43/50: 100%|██████████| 14/14 [00:04<00:00,  3.19it/s, loss=0.0952]


Epoch 43/50:
  Train Loss: 0.0952
  Val Loss: 0.4705
  Mean IoU: 0.2224
  Time: 4.86s


Epoch 44/50: 100%|██████████| 14/14 [00:04<00:00,  3.11it/s, loss=0.0939]


Epoch 44/50:
  Train Loss: 0.0939
  Val Loss: 0.4828
  Mean IoU: 0.2032
  Time: 4.97s


Epoch 45/50: 100%|██████████| 14/14 [00:04<00:00,  3.06it/s, loss=0.0922]


Epoch 45/50:
  Train Loss: 0.0922
  Val Loss: 0.4971
  Mean IoU: 0.1909
  Time: 5.08s


Epoch 46/50: 100%|██████████| 14/14 [00:04<00:00,  3.14it/s, loss=0.0872]


Epoch 46/50:
  Train Loss: 0.0872
  Val Loss: 0.2884
  Mean IoU: 0.2624
  Time: 4.95s


Epoch 47/50: 100%|██████████| 14/14 [00:04<00:00,  3.12it/s, loss=0.0893]


Epoch 47/50:
  Train Loss: 0.0893
  Val Loss: 0.2414
  Mean IoU: 0.2951
  Time: 4.95s


Epoch 48/50: 100%|██████████| 14/14 [00:04<00:00,  3.10it/s, loss=0.0863]


Epoch 48/50:
  Train Loss: 0.0863
  Val Loss: 0.3488
  Mean IoU: 0.2366
  Time: 4.98s


Epoch 49/50: 100%|██████████| 14/14 [00:04<00:00,  3.17it/s, loss=0.083] 


Epoch 49/50:
  Train Loss: 0.0830
  Val Loss: 0.2293
  Mean IoU: 0.2944
  Time: 4.89s


Epoch 50/50: 100%|██████████| 14/14 [00:04<00:00,  3.17it/s, loss=0.0794]


Epoch 50/50:
  Train Loss: 0.0794
  Val Loss: 0.5721
  Mean IoU: 0.2177
  Time: 4.88s

--- Training Metrics Summary from: training_metrics_20250506_230956.csv ---
Epoch   | Train Loss   | Val Loss   | Mean IoU   | Time (s)  
------------------------------------------------------------
1       | 1.2630       | 0.7279     | 0.1710     | 5.55      
2       | 0.7370       | 0.6514     | 0.1755     | 4.91      
3       | 0.5207       | 0.8100     | 0.2182     | 4.86      
4       | 0.4013       | 0.4358     | 0.2527     | 4.77      
5       | 0.3264       | 0.3859     | 0.2499     | 4.83      
6       | 0.2883       | 0.3329     | 0.2536     | 4.76      
7       | 0.2593       | 0.3487     | 0.2110     | 4.82      
8       | 0.2349       | 0.3341     | 0.2257     | 4.84      
9       | 0.2114       | 0.4452     | 0.1777     | 4.96      
10      | 0.2066       | 0.2430     | 0.2757     | 4.84      
11      | 0.1939       | 0.2322     | 0.2856     | 4.78      
12      | 0.1816       | 0.3456 