# In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

class SegmentationDataset(Dataset):
    """Segmentation dataset built on a class-per-subdirectory image folder.

    Every subdirectory of ``image_dir`` is treated as one class. Classes are
    numbered 0..N-1 in sorted directory order, and the ``.png``/``.jpg``/
    ``.jpeg`` files (case-insensitive) inside each directory become samples.

    Masks come from one of two sources:

    * ``create_dummy_masks=True``: a synthetic mask is drawn per sample in
      which class ``c`` is painted with pixel value ``c + 1`` on a
      background of 0.
    * otherwise: ``<mask_dir>/<image stem>.png`` is loaded as 8-bit
      grayscale; a mask that cannot be read is replaced by an all-zero mask.

    Args:
        image_dir: Root directory containing one subdirectory per class.
        mask_dir: Directory holding the real masks. Required when
            ``create_dummy_masks`` is False.
        transform: Optional dict with ``'image'`` and ``'mask'`` callables
            applied to the PIL image and PIL mask respectively.
        create_dummy_masks: Generate synthetic masks instead of loading them.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, mask_dir=None, transform=None, create_dummy_masks=False):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.create_dummy_masks = create_dummy_masks

        self.image_paths = []
        self.labels = []
        self.class_names = {}

        # Filter to directories BEFORE enumerating: stray files in image_dir
        # (e.g. ".DS_Store") must not consume a label id, otherwise labels
        # become non-contiguous and can exceed len(class_names).
        class_dirs = [
            entry for entry in sorted(os.listdir(image_dir))
            if os.path.isdir(os.path.join(image_dir, entry))
        ]

        for label, class_dir in enumerate(class_dirs):
            self.class_names[label] = class_dir
            class_path = os.path.join(image_dir, class_dir)

            for img_name in sorted(os.listdir(class_path)):
                if img_name.lower().endswith(self._IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_path, img_name))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of image samples found under ``image_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair for sample ``idx``.

        Without a transform both items are PIL images. With a transform the
        image is whatever ``transform['image']`` returns and the mask is the
        output of ``transform['mask']`` converted to integer class ids (see
        ``_to_class_indices``).

        Raises:
            ValueError: If real masks are requested but ``mask_dir`` is None.
        """
        image_path = self.image_paths[idx]
        class_label = self.labels[idx]

        # Best effort: an unreadable image is reported and replaced by a
        # black placeholder so one bad file does not abort a whole epoch.
        try:
            with Image.open(image_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            print(f"Error loading {image_path}: {e}")
            image = Image.new('RGB', (256, 256), (0, 0, 0))

        if self.create_dummy_masks:
            mask = self._create_dummy_mask(image.size, class_label)
        else:
            if self.mask_dir is None:
                raise ValueError(
                    "mask_dir must be provided when create_dummy_masks is False"
                )
            img_name = os.path.basename(image_path)
            mask_name = os.path.splitext(img_name)[0] + '.png'
            mask_path = os.path.join(self.mask_dir, mask_name)
            try:
                with Image.open(mask_path) as raw_mask:
                    mask = raw_mask.convert('L')
            except Exception as e:
                print(f"Error loading {mask_path}: {e}")
                mask = Image.new('L', image.size, 0)

        if self.transform:
            # Re-seed NumPy and torch with the same value before each
            # pipeline; this is intended to make random operations in the
            # image and mask transforms draw identical parameters.
            seed = np.random.randint(2147483647)

            np.random.seed(seed)
            torch.manual_seed(seed)
            image = self.transform['image'](image)

            np.random.seed(seed)
            torch.manual_seed(seed)
            mask = self.transform['mask'](mask)
            mask = self._to_class_indices(mask)

        return image, mask

    @staticmethod
    def _to_class_indices(mask):
        """Convert a transformed mask tensor to int64 class ids.

        Floating-point masks are assumed to come from ``ToTensor``, which
        rescales 8-bit pixel values into [0, 1]. They are scaled back to
        0..255 before the integer cast; casting them directly would truncate
        every class id below 255 to 0. Integer masks are cast unchanged.
        A leading dimension of size 1 is squeezed away.
        """
        if torch.is_floating_point(mask):
            mask = torch.round(mask * 255)
        return mask.squeeze(0).long()

    def _create_dummy_mask(self, image_size, class_label):
        """Create a dummy segmentation mask for demonstration purposes.

        Args:
            image_size: ``(width, height)`` as reported by ``PIL.Image.size``.
            class_label: Zero-based class id of the sample.

        Returns:
            An 8-bit grayscale PIL image in which the pattern pixels carry the
            value ``class_label + 1`` and all other pixels are 0.
        """
        width, height = image_size
        mask = np.zeros((height, width), dtype=np.uint8)

        if class_label == 0:
            # Horizontal band across the middle half of the rows
            mask[height//4:3*height//4, :] = 1
        elif class_label == 1:
            # Vertical band across the middle half of the columns
            mask[:, width//4:3*width//4] = 2
        elif class_label == 2:
            # Filled circle in the centre
            center_x, center_y = width//2, height//2
            radius = min(width, height) // 4
            y, x = np.ogrid[:height, :width]
            mask_circle = (x - center_x)**2 + (y - center_y)**2 <= radius**2
            mask[mask_circle] = 3
        else:
            # Sparse grid for every other class. Labelled class_label + 1
            # like the branches above, so classes past the fifth do not wrap
            # around onto another class's id.
            mask[::2, ::2] = class_label + 1

        return Image.fromarray(mask, mode='L')

# Custom CNN for Semantic Segmentation (UNet-like architecture)
class DoubleConv(nn.Module):
    """Two consecutive Conv3x3 -> BatchNorm -> ReLU stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Both convolutions use ``padding=1``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage 1 consumes in_channels, stage 2 consumes stage 1's output.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder step: 2x2 max-pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        halve = nn.MaxPool2d(2)
        refine = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(halve, refine)

    def forward(self, x):
        """Pool ``x`` and run the convolution block on the result."""
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder step: upsample, pad to the skip tensor's size, concatenate, convolve.

    ``in_channels`` is the channel count after concatenation; the transposed
    convolution halves the channels of the incoming feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with skip tensor ``x2`` and fuse both."""
        upsampled = self.up(x1)

        # Difference along dim 2 and dim 3 between the skip tensor and the
        # upsampled map; it is split between the two sides, with the extra
        # unit (if any) going to the trailing side.
        gap_rows = x2.size(2) - upsampled.size(2)
        gap_cols = x2.size(3) - upsampled.size(3)
        lead_cols, lead_rows = gap_cols // 2, gap_rows // 2
        padding = [lead_cols, gap_cols - lead_cols, lead_rows, gap_rows - lead_rows]
        upsampled = F.pad(upsampled, padding)

        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)

class CustomSegmentationCNN(nn.Module):
    """UNet-like encoder/decoder producing per-pixel class logits.

    Args:
        n_channels: Number of input image channels.
        n_classes: Number of output classes (21 for PASCAL VOC, adjust as needed).
    """

    def __init__(self, n_channels=3, n_classes=21):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        widths = (64, 128, 256, 512, 1024)

        # Encoder (downsampling path): inc, then down1..down4
        self.inc = DoubleConv(n_channels, widths[0])
        for step, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"down{step}", Down(c_in, c_out))

        # Decoder (upsampling path): up1..up4 walk the widths in reverse
        reversed_widths = widths[::-1]
        for step, (c_in, c_out) in enumerate(zip(reversed_widths, reversed_widths[1:]), start=1):
            setattr(self, f"up{step}", Up(c_in, c_out))

        # 1x1 convolution mapping the final features to class logits
        self.outc = nn.Conv2d(widths[0], n_classes, kernel_size=1)

    def forward(self, x):
        """Return the logits computed from input batch ``x``."""
        # Encoder: keep every intermediate output for the skip connections.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # Decoder: start from the deepest features and consume the skips
        # from deepest to shallowest.
        features = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            features = up(features, skips.pop())

        return self.outc(features)

# Data transforms for segmentation
# Training images: resize, colour augmentation, then ImageNet normalisation.
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Masks hold integer class ids, so they are resized with nearest-neighbour
# interpolation and converted with PILToTensor, which keeps the raw 0..255
# values. ToTensor would divide them by 255, and the dataset's integer cast
# would then truncate every class id to 0.
mask_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=Image.NEAREST),  # Use nearest for masks
    transforms.PILToTensor()
])

train_transform = {
    'image': image_transform,
    'mask': mask_transform
}

# Evaluation images: same resize and normalisation, no augmentation.
test_image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Same value-preserving mask pipeline as for training.
test_mask_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=Image.NEAREST),
    transforms.PILToTensor()
])

test_transform = {
    'image': test_image_transform,
    'mask': test_mask_transform
}

# Image roots reused from the original classification code
train_image_dir = '/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/train'
test_image_dir = '/Workspace/sid-v2/computervision1/Classification_dataset_v3/images/test'

# Both splits use synthetic (dummy) masks for demonstration
print("Creating datasets with dummy masks for demonstration...")
training_dataset = SegmentationDataset(
    image_dir=train_image_dir,
    transform=train_transform,
    create_dummy_masks=True,
)
test_dataset = SegmentationDataset(
    image_dir=test_image_dir,
    transform=test_transform,
    create_dummy_masks=True,
)

# One segmentation class per classification class, plus one for background
num_classes = 1 + len(training_dataset.class_names)
print(f"Number of classes: {num_classes}")
print("Class names:", training_dataset.class_names)
print("Note: Using dummy masks for demonstration. Class 0 = background, Classes 1+ = your original classes")

# Only the training loader shuffles; batch size and worker count are shared
train_loader = DataLoader(
    dataset=training_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
)

# Pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Build the network and move it to the selected device
model = CustomSegmentationCNN(n_channels=3, n_classes=num_classes).to(device)

# Loss, optimizer and learning-rate schedule
criterion = nn.CrossEntropyLoss(ignore_index=255)  # pixels labelled 255 do not contribute to the loss
optimizer = optim.AdamW(
    model.parameters(),
    lr=0.001,
    weight_decay=0.01,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

# Metrics
def calculate_iou(pred, target, num_classes):
    """Calculate Intersection over Union (IoU) for each class.

    Args:
        pred: Tensor of predicted class ids.
        target: Tensor of ground-truth class ids with the same number of
            elements as ``pred``.
        num_classes: Class ids ``0 .. num_classes - 1`` are evaluated.

    Returns:
        A list of ``num_classes`` floats. The entry for a class that does not
        occur in ``target`` is ``nan``; otherwise it is the number of
        positions where both tensors equal the class divided by the number of
        positions where either does.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed or sliced tensors.
    pred = pred.reshape(-1)
    target = target.reshape(-1)

    ious = []
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls

        if not target_inds.any().item():
            ious.append(float('nan'))
            continue

        # The class occurs in target, so the union is at least 1 and the
        # division below cannot hit zero.
        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        ious.append(intersection / union)

    return ious

def train_model(model, train_loader, criterion, optimizer, device, ignore_index=255):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Network called as ``model(images)``; its output is reduced
            with ``argmax`` over dim 1 to obtain predictions.
        train_loader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        ignore_index: Mask value excluded from the pixel accuracy. This should
            match the value the loss ignores.

    Returns:
        ``(train_loss, pixel_acc)``: the mean of the per-batch losses and the
        pixel accuracy in percent. Both are 0 when no batch (respectively no
        countable pixel) was seen.
    """
    model.train()
    running_loss = 0.0
    pixel_correct = 0
    pixel_total = 0
    num_batches = 0  # counted explicitly so an empty loader cannot divide by zero

    for images, masks in tqdm(train_loader, desc="Training"):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

        # Pixel accuracy over non-ignored pixels; detach() keeps the metric
        # computation out of the autograd graph.
        predicted = outputs.detach().argmax(dim=1)
        mask_pixels = (masks != ignore_index)
        pixel_total += mask_pixels.sum().item()
        pixel_correct += ((predicted == masks) & mask_pixels).sum().item()

    train_loss = running_loss / num_batches if num_batches > 0 else 0.0
    pixel_acc = 100 * pixel_correct / pixel_total if pixel_total > 0 else 0

    return train_loss, pixel_acc

def validate_model(model, test_loader, criterion, device, num_classes, ignore_index=255):
    """Evaluate ``model`` on ``test_loader`` without updating it.

    Args:
        model: Network called as ``model(images)``; its output is reduced
            with ``argmax`` over dim 1 to obtain predictions.
        test_loader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(outputs, masks)``.
        device: Device the batches are moved to.
        num_classes: Number of classes passed to ``calculate_iou``.
        ignore_index: Mask value excluded from pixel accuracy and IoU. This
            should match the value the loss ignores.

    Returns:
        ``(val_loss, pixel_acc, mean_iou)``. ``val_loss`` is the mean of the
        per-batch losses and ``pixel_acc`` is in percent. ``mean_iou`` is
        computed per sample, averaged per class over the samples where the
        class has a defined IoU, then averaged over those classes. Each value
        is 0 when there is nothing to average.
    """
    model.eval()
    running_loss = 0.0
    pixel_correct = 0
    pixel_total = 0
    num_batches = 0  # counted explicitly so an empty loader cannot divide by zero
    all_ious = []

    with torch.no_grad():
        for images, masks in tqdm(test_loader, desc="Validating"):
            images, masks = images.to(device), masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            running_loss += loss.item()
            num_batches += 1

            # Pixel accuracy over non-ignored pixels
            predicted = outputs.argmax(dim=1)
            mask_pixels = (masks != ignore_index)
            pixel_total += mask_pixels.sum().item()
            pixel_correct += ((predicted == masks) & mask_pixels).sum().item()

            # Per-sample IoU, restricted to the non-ignored pixels
            for i in range(predicted.size(0)):
                pred_mask = predicted[i][mask_pixels[i]]
                true_mask = masks[i][mask_pixels[i]]
                if pred_mask.numel() > 0:
                    all_ious.append(calculate_iou(pred_mask, true_mask, num_classes))

    val_loss = running_loss / num_batches if num_batches > 0 else 0.0
    pixel_acc = 100 * pixel_correct / pixel_total if pixel_total > 0 else 0

    # Average each class over the samples where its IoU is defined (not NaN),
    # then average over the classes that had at least one such sample.
    per_class_means = []
    for cls in range(num_classes):
        class_ious = [ious[cls] for ious in all_ious if not np.isnan(ious[cls])]
        if class_ious:
            per_class_means.append(np.mean(class_ious))
    mean_iou = np.mean(per_class_means) if per_class_means else 0

    return val_loss, pixel_acc, mean_iou

# Training configuration and per-epoch metric histories
num_epochs = 20
train_losses, train_accuracies = [], []
val_losses, val_accuracies, val_ious = [], [], []

print("Starting Custom CNN Segmentation training...")
print("Model: Custom UNet-like CNN")
print(
    "Number of parameters: "
    f"{sum(p.numel() for p in model.parameters() if p.requires_grad):,}"
)

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 30)

    # One training pass followed by one evaluation pass
    train_loss, train_acc = train_model(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, mean_iou = validate_model(model, test_loader, criterion, device, num_classes)

    # Advance the learning-rate schedule once per epoch
    scheduler.step()

    # Record this epoch's metrics
    for history, value in (
        (train_losses, train_loss),
        (train_accuracies, train_acc),
        (val_losses, val_loss),
        (val_accuracies, val_acc),
        (val_ious, mean_iou),
    ):
        history.append(value)

    print(f"Train Loss: {train_loss:.4f}, Train Pixel Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Pixel Acc: {val_acc:.2f}%, Mean IoU: {mean_iou:.4f}")
    # Printed after scheduler.step(), so this is the rate for the next epoch
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

print("Training completed!")

# Plot training results: loss, pixel accuracy and mean IoU side by side
plt.figure(figsize=(15, 5))

# Each panel: (title, y-axis label, [(series, extra plot keyword arguments), ...])
_panels = [
    ('Training and Validation Loss', 'Loss', [
        (train_losses, {'label': 'Train Loss'}),
        (val_losses, {'label': 'Validation Loss'}),
    ]),
    ('Training and Validation Pixel Accuracy', 'Accuracy (%)', [
        (train_accuracies, {'label': 'Train Pixel Accuracy'}),
        (val_accuracies, {'label': 'Validation Pixel Accuracy'}),
    ]),
    ('Validation Mean IoU', 'Mean IoU', [
        (val_ious, {'label': 'Validation Mean IoU', 'color': 'green'}),
    ]),
]

for _position, (_title, _y_label, _curves) in enumerate(_panels, start=1):
    plt.subplot(1, 3, _position)
    for _series, _line_kwargs in _curves:
        plt.plot(_series, linewidth=2, **_line_kwargs)
    plt.title(_title)
    plt.xlabel('Epoch')
    plt.ylabel(_y_label)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

# Persist only the learned weights (state_dict), not the whole module object
torch.save(model.state_dict(), 'custom_segmentation_cnn.pth')
print("Model saved as 'custom_segmentation_cnn.pth'")

def predict_segmentation(model, image_path, transform, device, num_classes):
    """Predict a segmentation mask for one image file.

    Args:
        model: Segmentation network; it is switched to eval mode.
        image_path: Path of the image to segment.
        transform: Dict whose ``'image'`` entry converts the PIL image to a
            tensor.
        device: Device used for inference.
        num_classes: Accepted for interface compatibility; not used by the
            current implementation.

    Returns:
        ``(mask, probabilities)``: the predicted mask as an 8-bit grayscale
        PIL image resized to the input image's size, and the softmax output
        of the model on the CPU (at the model's output resolution).
    """
    model.eval()

    source_image = Image.open(image_path).convert('RGB')
    # Add a batch dimension and move the tensor to the inference device
    batch = transform['image'](source_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        probabilities = F.softmax(logits, dim=1)
        class_map = probabilities.argmax(dim=1).squeeze(0)

    # Back to a PIL image; nearest-neighbour resizing keeps class ids intact
    class_array = class_map.cpu().numpy().astype(np.uint8)
    mask_image = Image.fromarray(class_array, mode='L').resize(source_image.size, Image.NEAREST)

    return mask_image, probabilities.cpu()

def visualize_segmentation(image_path, predicted_mask, num_classes):
    """Show the original image next to its colour-coded predicted mask.

    Args:
        image_path: Path of the image that was segmented.
        predicted_mask: Mask of class ids (anything ``np.array`` accepts).
        num_classes: Divisor that scales class ids before they are passed to
            the ``tab20`` colormap.
    """
    original_image = Image.open(image_path).convert('RGB')

    # Scale class ids by num_classes and map them through tab20
    class_ids = np.array(predicted_mask)
    colored_mask = plt.cm.tab20(class_ids / num_classes)

    plt.figure(figsize=(12, 6))

    panels = (
        (original_image, 'Original Image'),
        (colored_mask, 'Predicted Segmentation'),
    )
    for position, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Usage notes printed once training has finished
for _message in (
    "\nModel ready for inference!",
    "Note: This demo uses dummy masks created from your classification dataset",
    "For real segmentation, you would need:",
    "1. Actual pixel-level annotated masks",
    "2. Images and masks in the proper segmentation dataset structure",
    "\nUse predict_segmentation() for single image prediction",
    "Use visualize_segmentation() to visualize results",
):
    print(_message)

# Example of how to use:
# predicted_mask, probabilities = predict_segmentation(model, 'path/to/image.jpg', test_transform, device, num_classes)
# visualize_segmentation('path/to/image.jpg', predicted_mask, num_classes)

# Summary of the datasets loaded above
print("\nDataset loaded successfully:")
print(f"Training samples: {len(training_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Classes: {training_dataset.class_names}")