In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
import logging
from datetime import datetime

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Paths
SYNTHETIC_DATA_DIR = Path(r"C:\IMMAGE ANALYSIS PROJECT DATASET\weed\weeds\output\images")
MASK_DIR = Path(r"C:\IMMAGE ANALYSIS PROJECT DATASET\weed\weeds\output\masks")
VISUALIZATION_DIR = Path(r"C:\IMMAGE ANALYSIS PROJECT DATASET\weed\weeds\output\visualization")
MODELS_DIR = Path("saved_models")
MODELS_DIR.mkdir(exist_ok=True)

class LightweightUNet(nn.Module):
    def __init__(self, n_classes=3):
        super(LightweightUNet, self).__init__()
        
        # Encoder (downsampling)
        self.enc1 = self._conv_block(3, 32)
        self.enc2 = self._conv_block(32, 64)
        self.enc3 = self._conv_block(64, 128)
        self.enc4 = self._conv_block(128, 256)
        
        # Decoder (upsampling)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = self._conv_block(256, 128)
        
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = self._conv_block(128, 64)
        
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = self._conv_block(64, 32)
        
        self.final = nn.Conv2d(32, n_classes, kernel_size=1)
        self.max_pool = nn.MaxPool2d(2)
        
    def _conv_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        # Encoder path with skip connections
        e1 = self.enc1(x)
        p1 = self.max_pool(e1)
        
        e2 = self.enc2(p1)
        p2 = self.max_pool(e2)
        
        e3 = self.enc3(p2)
        p3 = self.max_pool(e3)
        
        # Bridge
        e4 = self.enc4(p3)
        
        # Decoder path
        d3 = self.up3(e4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        
        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        
        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        
        return self.final(d1)

class PlantSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        self.images = [img for img in sorted(os.listdir(image_dir)) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = self.image_dir / img_name
        mask_path = self.mask_dir / img_name.replace('synthetic', 'mask')
        
        try:
            # Load and preprocess image
            image = cv2.imread(str(img_path))
            if image is None:
                raise ValueError(f"Image not found or unable to read at {img_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            
            # Load mask
            mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise ValueError(f"Mask not found or unable to read at {mask_path}")
            
            # One-hot encode mask
            mask_one_hot = np.zeros((3, mask.shape[0], mask.shape[1]), dtype=np.float32)
            for i in range(3):
                mask_one_hot[i, :, :] = (mask == i).astype(np.float32)

            # Apply transformations
            if self.transform:
                augmented = self.transform(image=image, mask=mask_one_hot.transpose(1, 2, 0))
                image = augmented['image']
                mask_one_hot = augmented['mask'].permute(2, 0, 1)

            return image, mask_one_hot
        
        except Exception as e:
            logging.error(f"Error loading file {img_name}: {e}")
            raise

def get_transforms(train=True):
    if train:
        return A.Compose([
            A.RandomRotate90(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.GaussNoise(p=0.2),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
    else:
        return A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

def save_model(model, optimizer, epoch, train_loss, save_dir, filename):
    """Save model checkpoint with additional training information"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
    }
    
    save_path = save_dir / filename
    torch.save(checkpoint, save_path)
    logging.info(f"Model saved to {save_path}")

def load_checkpoint(model, optimizer, checkpoint_path):
    """Load model checkpoint"""
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['train_loss']

def plot_training_history(train_losses, save_dir):
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training History')
    plt.legend()
    
    save_path = save_dir / f'training_history_{datetime.now().strftime("%Y%m%d_%H%M%S")}.png'
    plt.savefig(save_path)
    plt.close()
    logging.info(f"Training history plot saved to {save_path}")

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    
    with tqdm(dataloader, desc='Training') as pbar:
        for images, masks in pbar:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            
            optimizer.zero_grad(set_to_none=True)  # More efficient than zero_grad()
            
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pbar.set_postfix(loss=f"{loss.item():.4f}")
    
    return total_loss / len(dataloader)

def main():
    # Training configuration
    config = {
        'BATCH_SIZE': 4,
        'NUM_EPOCHS': 50,
        'LEARNING_RATE': 0.001,
        'SAVE_FREQUENCY': 5,  # Save model every N epochs
        'NUM_WORKERS': 4 if torch.cuda.is_available() else 0,  # Use multiple workers if GPU available
    }
    
    # Set up device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f"Using device: {device}")
    
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True  # Optimize CUDA performance
    
    # Create datasets and loaders
    train_dataset = PlantSegmentationDataset(
        SYNTHETIC_DATA_DIR,
        MASK_DIR,
        transform=get_transforms(train=True)
    )
    
    train_loader = DataLoader(
        train_dataset, 
        batch_size=config['BATCH_SIZE'],
        shuffle=True,
        num_workers=config['NUM_WORKERS'],
        pin_memory=True,
        persistent_workers=True if config['NUM_WORKERS'] > 0 else False
    )
    
    # Initialize model, criterion, optimizer
    model = LightweightUNet().to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        logging.info(f"Using {torch.cuda.device_count()} GPUs")
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['LEARNING_RATE'])
    
    # Optional: Load checkpoint if exists
    checkpoint_path = MODELS_DIR / 'latest_checkpoint.pth'
    start_epoch = 0
    train_losses = []
    
    if checkpoint_path.exists():
        start_epoch, prev_losses = load_checkpoint(model, optimizer, checkpoint_path)
        train_losses = prev_losses
        logging.info(f"Resumed training from epoch {start_epoch}")
    
    # Training loop
    try:
        for epoch in range(start_epoch, config['NUM_EPOCHS']):
            logging.info(f"\nEpoch {epoch+1}/{config['NUM_EPOCHS']}")
            
            # Train
            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
            train_losses.append(train_loss)
            
            logging.info(f"Train Loss: {train_loss:.4f}")
            
            # Save checkpoint periodically
            if (epoch + 1) % config['SAVE_FREQUENCY'] == 0:
                save_model(
                    model if not isinstance(model, nn.DataParallel) else model.module,
                    optimizer,
                    epoch + 1,
                    train_losses,
                    MODELS_DIR,
                    f'checkpoint_epoch_{epoch+1}.pth'
                )
        
        # Save final model
        save_model(
            model if not isinstance(model, nn.DataParallel) else model.module,
            optimizer,
            config['NUM_EPOCHS'],
            train_losses,
            MODELS_DIR,
            'final_model.pth'
        )
        
        # Plot and save training history
        plot_training_history(train_losses, MODELS_DIR)
        logging.info("Training completed successfully!")
        
    except KeyboardInterrupt:
        logging.info("Training interrupted by user")
        # Save checkpoint on interruption
        save_model(
            model if not isinstance(model, nn.DataParallel) else model.module,
            optimizer,
            epoch + 1,
            train_losses,
            MODELS_DIR,
            'interrupted_checkpoint.pth'
        )
        
    except Exception as e:
        logging.error(f"Training failed with error: {e}")
        raise

if __name__ == "__main__":
    main()

  check_for_updates()
2024-11-02 09:50:38,863 - INFO - Using device: cuda
2024-11-02 09:50:39,277 - INFO - 
Epoch 1/50
Training:   0%|                                                                                | 0/250 [00:05<?, ?it/s]
2024-11-02 09:50:44,847 - ERROR - Training failed with error: DataLoader worker (pid(s) 24700, 9744, 8316, 9100) exited unexpectedly


RuntimeError: DataLoader worker (pid(s) 24700, 9744, 8316, 9100) exited unexpectedly

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
import logging
from datetime import datetime

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Paths
SYNTHETIC_DATA_DIR = Path(r"C:\IMMAGE ANALYSIS PROJECT DATASET\weed\weeds\output\images")
MASK_DIR = Path(r"C:\IMMAGE ANALYSIS PROJECT DATASET\weed\weeds\output\masks")
VISUALIZATION_DIR = Path(r"C:\IMMAGE ANALYSIS PROJECT DATASET\weed\weeds\output\visualization")
MODELS_DIR = Path("saved_models")
MODELS_DIR.mkdir(exist_ok=True)

class LightweightUNet(nn.Module):
    def __init__(self, n_classes=3):
        super(LightweightUNet, self).__init__()
        
        # Encoder (downsampling)
        self.enc1 = self._conv_block(3, 32)
        self.enc2 = self._conv_block(32, 64)
        self.enc3 = self._conv_block(64, 128)
        self.enc4 = self._conv_block(128, 256)
        
        # Decoder (upsampling)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = self._conv_block(256, 128)
        
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = self._conv_block(128, 64)
        
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = self._conv_block(64, 32)
        
        self.final = nn.Conv2d(32, n_classes, kernel_size=1)
        self.max_pool = nn.MaxPool2d(2)
        
    def _conv_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        # Encoder path with skip connections
        e1 = self.enc1(x)
        p1 = self.max_pool(e1)
        
        e2 = self.enc2(p1)
        p2 = self.max_pool(e2)
        
        e3 = self.enc3(p2)
        p3 = self.max_pool(e3)
        
        # Bridge
        e4 = self.enc4(p3)
        
        # Decoder path
        d3 = self.up3(e4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        
        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        
        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        
        return self.final(d1)

class PlantSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        
        # Validate directories exist
        if not self.image_dir.exists():
            raise ValueError(f"Image directory not found: {image_dir}")
        if not self.mask_dir.exists():
            raise ValueError(f"Mask directory not found: {mask_dir}")
            
        # Get valid image files
        self.images = []
        for img in sorted(os.listdir(image_dir)):
            if img.lower().endswith(('.png', '.jpg', '.jpeg')):
                # Check if corresponding mask exists
                mask_name = img.replace('synthetic', 'mask')
                if (self.mask_dir / mask_name).exists():
                    self.images.append(img)
                else:
                    logging.warning(f"Skipping {img} - no corresponding mask found")
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        if idx >= len(self.images):
            raise IndexError(f"Index {idx} out of bounds for dataset of size {len(self.images)}")
            
        img_name = self.images[idx]
        img_path = self.image_dir / img_name
        mask_path = self.mask_dir / img_name.replace('synthetic', 'mask')
        
        try:
            # Load and preprocess image
            image = cv2.imread(str(img_path))
            if image is None:
                raise ValueError(f"Failed to read image at {img_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            
            # Load mask
            mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise ValueError(f"Failed to read mask at {mask_path}")
            
            # Validate image and mask dimensions match
            if image.shape[:2] != mask.shape[:2]:
                raise ValueError(f"Image and mask dimensions don't match for {img_name}")
            
            # One-hot encode mask
            mask_one_hot = np.zeros((3, mask.shape[0], mask.shape[1]), dtype=np.float32)
            for i in range(3):
                mask_one_hot[i, :, :] = (mask == i).astype(np.float32)

            # Apply transformations
            if self.transform:
                try:
                    augmented = self.transform(image=image, mask=mask_one_hot.transpose(1, 2, 0))
                    image = augmented['image']
                    mask_one_hot = augmented['mask'].permute(2, 0, 1)
                except Exception as e:
                    raise ValueError(f"Transform failed for {img_name}: {str(e)}")

            return image, mask_one_hot
        
        except Exception as e:
            logging.error(f"Error processing {img_name}: {str(e)}")
            raise

def get_transforms(train=True):
    if train:
        return A.Compose([
            A.RandomRotate90(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.GaussNoise(p=0.2),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
    else:
        return A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

def save_model(model, optimizer, epoch, train_loss, save_dir, filename):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
    }
    
    save_path = save_dir / filename
    torch.save(checkpoint, save_path)
    logging.info(f"Model saved to {save_path}")

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    
    with tqdm(dataloader, desc='Training') as pbar:
        for images, masks in pbar:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            
            optimizer.zero_grad(set_to_none=True)
            
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pbar.set_postfix(loss=f"{loss.item():.4f}")
    
    return total_loss / len(dataloader)

def main():
    # Training configuration
    config = {
        'BATCH_SIZE': 4,
        'NUM_EPOCHS': 50,
        'LEARNING_RATE': 0.001,
        'SAVE_FREQUENCY': 5,
        'NUM_WORKERS': 0,
    }
    
    # Set up device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f"Using device: {device}")
    
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
    
    try:
        train_dataset = PlantSegmentationDataset(
            SYNTHETIC_DATA_DIR,
            MASK_DIR,
            transform=get_transforms(train=True)
        )
        
        train_loader = DataLoader(
            train_dataset, 
            batch_size=config['BATCH_SIZE'],
            shuffle=True,
            num_workers=0,
            pin_memory=True if torch.cuda.is_available() else False,
            persistent_workers=False,
            prefetch_factor=2,
            timeout=60,
        )
        
        model = LightweightUNet(n_classes=3).to(device)
        optimizer = optim.Adam(model.parameters(), lr=config['LEARNING_RATE'])
        criterion = nn.CrossEntropyLoss()
        
        for epoch in range(config['NUM_EPOCHS']):
            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
            logging.info(f"Epoch [{epoch + 1}/{config['NUM_EPOCHS']}], Loss: {train_loss:.4f}")
            
            # Save model checkpoint
            if (epoch + 1) % config['SAVE_FREQUENCY'] == 0:
                save_model(model, optimizer, epoch + 1, train_loss, MODELS_DIR, f"model_epoch_{epoch + 1}.pth")
    
    except Exception as e:
        logging.error(f"Training failed: {str(e)}")

if __name__ == "__main__":
    main()


2024-11-02 09:55:16,738 - INFO - Using device: cuda
2024-11-02 09:55:17,009 - ERROR - Training failed: prefetch_factor option could only be specified in multiprocessing.let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
import logging
from datetime import datetime

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Paths
SYNTHETIC_DATA_DIR = Path(r"C:\IMMAGE ANALYSIS PROJECT DATASET\weed\weeds\output\images")
MASK_DIR = Path(r"C:\IMMAGE ANALYSIS PROJECT DATASET\weed\weeds\output\masks")
VISUALIZATION_DIR = Path(r"C:\IMMAGE ANALYSIS PROJECT DATASET\weed\weeds\output\visualization")
MODELS_DIR = Path("saved_models")
MODELS_DIR.mkdir(exist_ok=True)

class LightweightUNet(nn.Module):
    def __init__(self, n_classes=3):
        super(LightweightUNet, self).__init__()
        
        # Encoder (downsampling)
        self.enc1 = self._conv_block(3, 32)
        self.enc2 = self._conv_block(32, 64)
        self.enc3 = self._conv_block(64, 128)
        self.enc4 = self._conv_block(128, 256)
        
        # Decoder (upsampling)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = self._conv_block(256, 128)
        
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = self._conv_block(128, 64)
        
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = self._conv_block(64, 32)
        
        self.final = nn.Conv2d(32, n_classes, kernel_size=1)
        self.max_pool = nn.MaxPool2d(2)
        
    def _conv_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        # Encoder path with skip connections
        e1 = self.enc1(x)
        p1 = self.max_pool(e1)
        
        e2 = self.enc2(p1)
        p2 = self.max_pool(e2)
        
        e3 = self.enc3(p2)
        p3 = self.max_pool(e3)
        
        # Bridge
        e4 = self.enc4(p3)
        
        # Decoder path
        d3 = self.up3(e4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        
        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        
        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        
        return self.final(d1)

class PlantSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        
        # Validate directories exist
        if not self.image_dir.exists():
            raise ValueError(f"Image directory not found: {image_dir}")
        if not self.mask_dir.exists():
            raise ValueError(f"Mask directory not found: {mask_dir}")
            
        # Get valid image files
        self.images = []
        for img in sorted(os.listdir(image_dir)):
            if img.lower().endswith(('.png', '.jpg', '.jpeg')):
                # Check if corresponding mask exists
                mask_name = img.replace('synthetic', 'mask')
                if (self.mask_dir / mask_name).exists():
                    self.images.append(img)
                else:
                    logging.warning(f"Skipping {img} - no corresponding mask found")
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        if idx >= len(self.images):
            raise IndexError(f"Index {idx} out of bounds for dataset of size {len(self.images)}")
            
        img_name = self.images[idx]
        img_path = self.image_dir / img_name
        mask_path = self.mask_dir / img_name.replace('synthetic', 'mask')
        
        try:
            # Load and preprocess image
            image = cv2.imread(str(img_path))
            if image is None:
                raise ValueError(f"Failed to read image at {img_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            
            # Load mask
            mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise ValueError(f"Failed to read mask at {mask_path}")
            
            # Validate image and mask dimensions match
            if image.shape[:2] != mask.shape[:2]:
                raise ValueError(f"Image and mask dimensions don't match for {img_name}")
            
            # One-hot encode mask
            mask_one_hot = np.zeros((3, mask.shape[0], mask.shape[1]), dtype=np.float32)
            for i in range(3):
                mask_one_hot[i, :, :] = (mask == i).astype(np.float32)

            # Apply transformations
            if self.transform:
                try:
                    augmented = self.transform(image=image, mask=mask_one_hot.transpose(1, 2, 0))
                    image = augmented['image']
                    mask_one_hot = augmented['mask'].permute(2, 0, 1)
                except Exception as e:
                    raise ValueError(f"Transform failed for {img_name}: {str(e)}")

            return image, mask_one_hot
        
        except Exception as e:
            logging.error(f"Error processing {img_name}: {str(e)}")
            raise

def get_transforms(train=True):
    if train:
        return A.Compose([
            A.RandomRotate90(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.GaussNoise(p=0.2),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
    else:
        return A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

def save_model(model, optimizer, epoch, train_loss, save_dir, filename):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
    }
    
    save_path = save_dir / filename
    torch.save(checkpoint, save_path)
    logging.info(f"Model saved to {save_path}")

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    
    with tqdm(dataloader, desc='Training') as pbar:
        for images, masks in pbar:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            
            optimizer.zero_grad(set_to_none=True)
            
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pbar.set_postfix(loss=f"{loss.item():.4f}")
    
    return total_loss / len(dataloader)

def main():
    # Training configuration
    config = {
        'BATCH_SIZE': 4,
        'NUM_EPOCHS': 50,
        'LEARNING_RATE': 0.001,
        'SAVE_FREQUENCY': 5,
        'NUM_WORKERS': 0,
    }
    
    # Set up device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f"Using device: {device}")
    
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
    
    try:
        train_dataset = PlantSegmentationDataset(
            SYNTHETIC_DATA_DIR,
            MASK_DIR,
            transform=get_transforms(train=True)
        )
        
        train_loader = DataLoader(
            train_dataset, 
            batch_size=config['BATCH_SIZE'],
            shuffle=True,
            num_workers=config['NUM_WORKERS'],
            pin_memory=torch.cuda.is_available(),
            persistent_workers=config['NUM_WORKERS'] > 0,
            prefetch_factor=(2 if config['NUM_WORKERS'] > 0 else None),
            timeout=0,
            drop_last=True  # Ensure batches are complete
        )
        
        model = LightweightUNet(n_classes=3).to(device)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(model.parameters(), lr=config['LEARNING_RATE'])
        
        # Training loop
        for epoch in range(config['NUM_EPOCHS']):
            logging.info(f"Epoch {epoch + 1}/{config['NUM_EPOCHS']}")
            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)

            logging.info(f"Epoch {epoch + 1}/{config['NUM_EPOCHS']} - Loss: {train_loss:.4f}")
            
            if (epoch + 1) % config['SAVE_FREQUENCY'] == 0:
                save_model(model, optimizer, epoch, train_loss, MODELS_DIR, f"model_epoch_{epoch + 1}.pt")

    except Exception as e:
        logging.error(f"Training failed: {str(e)}")

if __name__ == "__main__":
    main()


2024-11-02 10:00:32,369 - INFO - Using device: cuda
2024-11-02 10:00:32,524 - INFO - Epoch 1/50
Training: 100%|█████████████████████████████████████████████████████████| 250/250 [00:58<00:00,  4.27it/s, loss=0.1432]
2024-11-02 10:01:31,137 - INFO - Epoch 1/50 - Loss: 0.2285
2024-11-02 10:01:31,137 - INFO - Epoch 2/50
Training: 100%|█████████████████████████████████████████████████████████| 250/250 [00:48<00:00,  5.14it/s, loss=0.0795]
2024-11-02 10:02:19,754 - INFO - Epoch 2/50 - Loss: 0.1227
2024-11-02 10:02:19,755 - INFO - Epoch 3/50
Training: 100%|█████████████████████████████████████████████████████████| 250/250 [00:48<00:00,  5.10it/s, loss=0.1581]
2024-11-02 10:03:08,755 - INFO - Epoch 3/50 - Loss: 0.1094
2024-11-02 10:03:08,755 - INFO - Epoch 4/50
Training: 100%|█████████████████████████████████████████████████████████| 250/250 [00:48<00:00,  5.15it/s, loss=0.0697]
2024-11-02 10:03:57,273 - INFO - Epoch 4/50 - Loss: 0.1061
2024-11-02 10:03:57,274 - INFO - Epoch 5/50
Training: 10