<a href="https://colab.research.google.com/github/sid0nair/Crack-propagation-predictor/blob/main/APL405.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision opencv-python matplotlib

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

class CrackDataset(Dataset):
    """Paired image dataset yielding one input image and one target image per item.

    Files are paired by position after sorting the file names of each
    directory, so both directories must hold the same number of files and
    their sorted orders must correspond to each other.

    Args:
        input_dir: Directory containing the input images.
        output_dir: Directory containing the target images.
        target_size: Size handed to ``cv2.resize`` as ``dsize``; OpenCV reads
            it as ``(width, height)``.

    Raises:
        ValueError: If the two directories contain a different number of files.
    """

    def __init__(self, input_dir, output_dir, target_size=(256, 256)):
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.input_files = self._list_files(input_dir)
        self.output_files = self._list_files(output_dir)
        # Pairing is purely positional, so a count mismatch means at least one
        # sample would be paired with the wrong (or a missing) target.
        if len(self.input_files) != len(self.output_files):
            raise ValueError(
                f"Input/output file count mismatch: {len(self.input_files)} files in "
                f"'{input_dir}' vs {len(self.output_files)} files in '{output_dir}'"
            )
        self.target_size = target_size

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the regular files directly inside ``directory``."""
        # Sub-directories (for example a ``.ipynb_checkpoints`` folder) are not
        # images and would shift the positional input/output pairing.
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.input_files)

    def _load_image(self, path):
        """Read ``path`` and return it as a float32 CHW RGB tensor scaled to [0, 1]."""
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals an unreadable/undecodable file by returning
            # None instead of raising; fail here with the offending path.
            raise ValueError(f"Could not read image file: {path}")

        # OpenCV decodes to BGR; convert to RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Resize to the target dimensions
        img = cv2.resize(img, self.target_size)
        # Scale 8-bit pixel values to [0, 1]
        img = img.astype('float32') / 255.0
        # HWC -> CHW
        img = np.transpose(img, (2, 0, 1))
        return torch.tensor(img, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``{'input': tensor, 'output': tensor}`` for the pair at ``idx``.

        Raises:
            ValueError: If either image of the pair cannot be read.
        """
        input_path = os.path.join(self.input_dir, self.input_files[idx])
        output_path = os.path.join(self.output_dir, self.output_files[idx])

        return {
            'input': self._load_image(input_path),
            'output': self._load_image(output_path)
        }

# Define directories and create the dataset
# Both paths are absolute; CrackDataset lists the two directories as soon as
# it is constructed, so they must already exist when this cell runs.
input_dir = '/content/data/input'
output_dir = '/content/data/output'
# target_size is left at its default of (256, 256)
dataset = CrackDataset(input_dir, output_dir)

FileNotFoundError: [Errno 2] No such file or directory: '/content/data/input'

In [None]:
from torch.utils.data import DataLoader, random_split

# 80/20 train/test split; the test part takes whatever remains after the
# integer truncation so the two sizes always add up to the dataset length.
dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
test_size = dataset_size - train_size
# Seed the split so the same samples land in train/test on every run;
# without a generator random_split draws from the global RNG.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=split_generator
)

# Define DataLoaders
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Test samples are served one at a time and in a fixed order
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F

class UNet(nn.Module):
    """Small U-Net: two down-sampling stages, a bottom stage, two up-sampling stages.

    Every stage is a pair of 3x3 same-padding convolutions with ReLU. Encoder
    outputs are concatenated onto the up-sampled decoder features along the
    channel axis, and a final 1x1 convolution followed by a sigmoid produces
    ``out_channels`` maps with values in [0, 1].

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the returned tensor.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        def double_conv(c_in, c_out):
            # conv3x3 -> ReLU -> conv3x3 -> ReLU, spatial size preserved
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(c_out, c_out, kernel_size=3, padding=1),
                nn.ReLU(),
            )

        # Contracting path
        self.conv1 = double_conv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = double_conv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = double_conv(128, 256)

        # Expanding path; conv4/conv5 take twice the up-sampled channel count
        # because the matching encoder output is concatenated in forward().
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv4 = double_conv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv5 = double_conv(128, 64)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return sigmoid-activated output maps."""
        # Contracting path
        enc1 = self.conv1(x)
        enc2 = self.conv2(self.pool1(enc1))
        bottom = self.conv3(self.pool2(enc2))

        # Expanding path: up-sample, then join with the encoder output of the same level
        dec2 = self.conv4(torch.cat([self.up2(bottom), enc2], dim=1))
        dec1 = self.conv5(torch.cat([self.up1(dec2), enc1], dim=1))

        # Sigmoid keeps the outputs in [0, 1]
        return torch.sigmoid(self.final(dec1))

# Instantiate the model, loss function, and optimizer
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=3, out_channels=3).to(device)

# Use MSE loss since we are predicting full-color images
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Show the layer structure of the instantiated network
print(model)

In [None]:
# Train the UNet on train_loader for a fixed number of epochs and report the
# per-sample training loss after each one.
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for batch in train_loader:
        inputs = batch['input'].to(device)   # shape: [batch_size, 3, H, W]
        targets = batch['output'].to(device)

        # Standard step: clear old gradients, forward, backward, update
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # loss is a batch average, so weight it by the batch size; the last
        # batch may be smaller than the others.
        epoch_loss += loss.item() * inputs.size(0)

    # Turn the size-weighted sum into an average over all training samples
    epoch_loss /= len(train_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    # Optionally, save model checkpoints at intervals:
    # if (epoch + 1) % 10 == 0:
    #     torch.save(model.state_dict(), f"unet_epoch_{epoch+1}.pth")

In [None]:
# Evaluate the trained model on the held-out test split.
model.eval()
test_loss = 0.0
# No gradients are needed for evaluation
with torch.no_grad():
    for batch in test_loader:
        inputs = batch['input'].to(device)
        targets = batch['output'].to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # Weight the batch loss by its size before accumulating
        test_loss += loss.item() * inputs.size(0)
    # Average over all test samples
    test_loss /= len(test_loader.dataset)

print(f"Test Loss (MSE): {test_loss:.4f}")

In [None]:
import matplotlib.pyplot as plt

# Show input, ground truth and model prediction side by side for the first
# sample served by test_loader.
model.eval()
with torch.no_grad():
    # Retrieve one sample from the test DataLoader
    sample_batch = next(iter(test_loader))
    input_img = sample_batch['input'][0].cpu().numpy()   # Shape: [3, H, W]
    target_img = sample_batch['output'][0].cpu().numpy()
    # The model runs on `device`; bring the prediction back to the CPU for plotting
    pred_img = model(sample_batch['input'].to(device))[0].cpu().numpy()

    # Convert images from CHW to HWC for visualization
    input_img = np.transpose(input_img, (1, 2, 0))
    target_img = np.transpose(target_img, (1, 2, 0))
    pred_img = np.transpose(pred_img, (1, 2, 0))

    plt.figure(figsize=(15, 5))

    # Panel 1: network input
    plt.subplot(1, 3, 1)
    plt.imshow(input_img)
    plt.title("Initial Configuration")
    plt.axis("off")

    # Panel 2: target image from the dataset
    plt.subplot(1, 3, 2)
    plt.imshow(target_img)
    plt.title("Ground Truth Final")
    plt.axis("off")

    # Panel 3: model output
    plt.subplot(1, 3, 3)
    plt.imshow(pred_img)
    plt.title("Predicted Final")
    plt.axis("off")

    plt.show()


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionBlock(nn.Module):
    """Spatial gate: scales a feature map by a learned one-channel sigmoid map.

    A 1x1 convolution reduces the input to a single channel, a sigmoid maps it
    into (0, 1), and the input is multiplied by that map on every channel.

    Args:
        in_channels: Number of channels of the feature map to be gated.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its attention map."""
        gate = self.sigmoid(self.conv(x))
        # Repeat the single-channel gate across all channels of x
        gate = gate.expand_as(x)
        return x * gate

class ConvBlock(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Padding of 1 preserves the spatial size.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both stages.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Only the input width differs between the two stages
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x``."""
        return self.double_conv(x)

class EnhancedUNet(nn.Module):
    """U-Net with attention-gated skip connections and metadata injected at the bottleneck.

    The encoder applies one ``ConvBlock`` per entry of ``features``, each
    followed by 2x2 max pooling. The bottleneck doubles the last feature
    width. A small MLP maps the 2-element metadata vector to one value per
    bottleneck channel, which is added (scaled by 0.1) at every spatial
    position. The decoder mirrors the encoder: transposed-convolution
    up-sampling, concatenation with the attention-gated encoder output of the
    same level, then a ``ConvBlock``. A 1x1 convolution plus sigmoid produces
    the output.

    Args:
        in_channels: Channels of the input image tensor.
        out_channels: Channels of the returned tensor.
        features: Feature widths of the encoder levels, shallowest first. Any
            sequence is accepted; it is not modified.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # Work on a private copy so a caller-supplied sequence is never shared
        features = list(features)

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.attention_blocks = nn.ModuleList()

        # Encoder
        for feature in features:
            self.encoder.append(ConvBlock(in_channels, feature))
            in_channels = feature

        # Bottleneck
        bottleneck_size = features[-1] * 2
        self.bottleneck = ConvBlock(features[-1], bottleneck_size)

        # Metadata integration: 2 scalars -> one value per bottleneck channel
        self.metadata_fc = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, bottleneck_size)
        )

        # Decoder: entries alternate (up-sampling, ConvBlock), deepest level
        # first; attention_blocks holds one gate per level in the same order.
        for feature in reversed(features):
            self.decoder.append(
                nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2)
            )
            self.decoder.append(ConvBlock(feature*2, feature))
            self.attention_blocks.append(AttentionBlock(feature))

        # Final output
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x, metadata):
        """Run the network.

        Args:
            x: Image batch, [batch_size, channels, height, width].
            metadata: [batch_size, 2] tensor (notch_length, notch_position).

        Returns:
            Sigmoid-activated tensor with ``out_channels`` channels.
        """
        skip_connections = []

        # Encoder path: keep each level's pre-pooling output for the decoder
        for enc in self.encoder:
            x = enc(x)
            skip_connections.append(x)
            x = self.pool(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Integrate metadata: one value per channel, broadcast over H and W.
        # The 0.1 factor keeps its influence on the bottleneck features small.
        batch_size, channels = x.shape[:2]
        metadata_features = self.metadata_fc(metadata).view(batch_size, channels, 1, 1)
        x = x + 0.1 * metadata_features

        # Decoder path, deepest skip connection first
        for level, skip in enumerate(reversed(skip_connections)):
            x = self.decoder[2 * level](x)  # Upsample

            # Apply attention to skip connection
            attended_skip = self.attention_blocks[level](skip)

            # Pooling floors odd sizes, so the up-sampled map can be smaller
            # than the skip; only the spatial size is relevant here.
            if x.shape[2:] != attended_skip.shape[2:]:
                x = F.interpolate(x, size=attended_skip.shape[2:])

            x = self.decoder[2 * level + 1](torch.cat((attended_skip, x), dim=1))

        # Final output
        return torch.sigmoid(self.final_conv(x))

# Instantiate the model
# This rebinds the notebook-level `model` (earlier bound to the plain UNet) to
# an EnhancedUNet. The training cell further down builds a fresh EnhancedUNet
# again, so this instance is only used for the parameter count below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EnhancedUNet(in_channels=3, out_channels=1).to(device)

# Print model summary to check architecture
def count_parameters(model):
    """Return the total number of elements in the parameters of ``model`` that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

print(f"Model has {count_parameters(model):,} trainable parameters")

In [None]:
#Training

from torch.utils.data import DataLoader, random_split
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Mixed loss function for better crack path prediction
class CrackLoss(nn.Module):
    """Weighted sum of binary cross-entropy and soft Dice loss.

    ``y_pred`` must already be probabilities in [0, 1] (``nn.BCELoss`` is used,
    not the logits variant). ``y_true`` may be a float, integer or boolean
    mask of the same shape; it is cast to the dtype of ``y_pred``.

    Args:
        dice_weight: Multiplier of the Dice loss term.
        bce_weight: Multiplier of the binary cross-entropy term.
    """

    def __init__(self, dice_weight=0.5, bce_weight=0.5):
        super().__init__()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
        self.bce_loss = nn.BCELoss()

    def dice_coef(self, y_pred, y_true, smooth=1.0):
        """Soft Dice coefficient computed over all elements of the batch at once.

        ``smooth`` is added to numerator and denominator so that two empty
        masks give a coefficient of 1 instead of 0/0.
        """
        # reshape (unlike view) also accepts non-contiguous tensors such as
        # transposed or sliced masks.
        y_pred_flat = y_pred.reshape(-1)
        y_true_flat = y_true.reshape(-1).to(y_pred_flat.dtype)
        intersection = (y_pred_flat * y_true_flat).sum()
        return (2. * intersection + smooth) / (y_pred_flat.sum() + y_true_flat.sum() + smooth)

    def dice_loss(self, y_pred, y_true):
        """Return ``1 - dice_coef`` so that perfect overlap gives zero loss."""
        return 1 - self.dice_coef(y_pred, y_true)

    def forward(self, y_pred, y_true):
        """Return ``bce_weight * BCE + dice_weight * Dice`` as a scalar tensor."""
        # BCELoss requires a floating-point target
        y_true = y_true.to(y_pred.dtype)
        dice = self.dice_loss(y_pred, y_true)
        bce = self.bce_loss(y_pred, y_true)
        return self.bce_weight * bce + self.dice_weight * dice

# Setup dataset and dataloader
# Ensure these directories exist
input_dir = '/content/data/input'
output_dir = '/content/data/output'
dataset = CrackDataset(input_dir, output_dir, target_size=(256, 256))

# Print dataset size
print(f"Dataset size: {len(dataset)} samples")

# Try fetching a sample to make sure dataset works
# NOTE(review): this cell reads the sample keys 'crack_mask', 'notch_length'
# and 'notch_position'. The CrackDataset class defined earlier in this
# notebook only returns 'input' and 'output', so the dataset class that
# produces those extra keys must be defined before this cell is run.
try:
    sample = dataset[0]
    print("Successfully loaded a sample from dataset")
    print(f"Input shape: {sample['input'].shape}")
    print(f"Output shape: {sample['output'].shape}")
    print(f"Crack mask shape: {sample['crack_mask'].shape}")
    print(f"Metadata: notch_length={sample['notch_length']}, position={sample['notch_position']}")
except Exception as e:
    print(f"Error loading sample: {e}")

# Split into train and validation sets
dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size
# Seeded generator: the same samples end up in train/validation on every run
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

batch_size = 4  # Reduced from 8 to avoid potential memory issues
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Verify dataloader works
try:
    batch = next(iter(train_loader))
    print("Successfully loaded a batch from dataloader")
    print(f"Batch input shape: {batch['input'].shape}")
    print(f"Batch metadata shape: {torch.stack([batch['notch_length'], batch['notch_position']], dim=1).shape}")
except Exception as e:
    print(f"Error loading batch: {e}")

# Setup model, loss function, optimizer, and scheduler
model = EnhancedUNet(in_channels=3, out_channels=1).to(device)
criterion = CrackLoss(dice_weight=0.7, bce_weight=0.3)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate after 5 epochs without validation improvement.
# The scheduler's `verbose` flag is deprecated in recent PyTorch releases, so
# learning-rate changes are reported explicitly inside the loop instead.
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

# Training loop with additional error handling
num_epochs = 50
best_val_loss = float('inf')

for epoch in range(num_epochs):
    # Training
    model.train()
    train_loss = 0.0
    train_seen = 0  # samples that actually contributed to train_loss

    for batch_idx, batch in enumerate(train_loader):
        try:
            inputs = batch['input'].to(device)
            crack_masks = batch['crack_mask'].to(device)
            metadata = torch.stack([batch['notch_length'], batch['notch_position']], dim=1).to(device)

            # Print shapes for first batch of first epoch for debugging
            if epoch == 0 and batch_idx == 0:
                print(f"Input shape: {inputs.shape}")
                print(f"Crack mask shape: {crack_masks.shape}")
                print(f"Metadata shape: {metadata.shape}")

            optimizer.zero_grad()
            outputs = model(inputs, metadata)

            if epoch == 0 and batch_idx == 0:
                print(f"Output shape: {outputs.shape}")

            loss = criterion(outputs, crack_masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            train_seen += inputs.size(0)

            # Print batch progress every 10 batches
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

        except Exception as e:
            # Best effort: report the failing batch and keep going
            print(f"Error in batch {batch_idx}: {e}")
            continue

    # If every batch failed, nothing was trained; stop instead of reporting a
    # loss of 0.0 as if the epoch had succeeded.
    if train_seen == 0:
        raise RuntimeError("No training batch completed successfully; see the errors printed above.")
    # Average over the samples that were really processed, so skipped batches
    # do not deflate the reported loss.
    train_loss /= train_seen

    # Validation
    model.eval()
    val_loss = 0.0
    val_seen = 0  # samples that actually contributed to val_loss
    with torch.no_grad():
        for batch in val_loader:
            try:
                inputs = batch['input'].to(device)
                crack_masks = batch['crack_mask'].to(device)
                metadata = torch.stack([batch['notch_length'], batch['notch_position']], dim=1).to(device)

                outputs = model(inputs, metadata)
                loss = criterion(outputs, crack_masks)

                val_loss += loss.item() * inputs.size(0)
                val_seen += inputs.size(0)
            except Exception as e:
                print(f"Error in validation: {e}")
                continue

    # A validation pass without a single successful batch would otherwise
    # yield val_loss == 0.0 and be checkpointed as the best model.
    if val_seen == 0:
        raise RuntimeError("No validation batch completed successfully; see the errors printed above.")
    val_loss /= val_seen

    # Update learning rate and report any reduction
    previous_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    if current_lr != previous_lr:
        print(f"Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}")

    # Print status
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Save model if it's the best so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_val_loss,
        }, 'best_crack_model.pth')
        print(f"Model saved at epoch {epoch+1} with val_loss: {val_loss:.4f}")

    # Early stopping condition: the scheduler has cut the learning rate below
    # 1e-5 (never checked during the first 11 epochs)
    if epoch > 10 and current_lr < 1e-5:
        print("Early stopping due to learning rate reduction")
        break

Dataset size: 49 samples
Successfully loaded a sample from dataset
Input shape: torch.Size([3, 256, 256])
Output shape: torch.Size([3, 256, 256])
Crack mask shape: torch.Size([1, 256, 256])
Metadata: notch_length=0.14000000059604645, position=0.0
Successfully loaded a batch from dataloader
Batch input shape: torch.Size([4, 3, 256, 256])
Batch metadata shape: torch.Size([4, 2])
Input shape: torch.Size([4, 3, 256, 256])
Crack mask shape: torch.Size([4, 1, 256, 256])
Metadata shape: torch.Size([4, 2])
Output shape: torch.Size([4, 1, 256, 256])
Epoch 1, Batch 0/10, Loss: 0.6885
Epoch [1/50], Train Loss: 0.3373, Val Loss: 0.6765
Model saved at epoch 1 with val_loss: 0.6765
Epoch 2, Batch 0/10, Loss: 0.2349
Epoch [2/50], Train Loss: 0.2099, Val Loss: 0.7198
Epoch 3, Batch 0/10, Loss: 0.1833
Epoch [3/50], Train Loss: 0.1650, Val Loss: 0.7776
Epoch 4, Batch 0/10, Loss: 0.1438
Epoch [4/50], Train Loss: 0.1305, Val Loss: 0.7895
Epoch 5, Batch 0/10, Loss: 0.1143
Epoch [5/50], Train Loss: 0.1036, 

In [None]:
# Visualization test
def test_dataset_visualization():
    """Plot sample 0 of the notebook-level ``dataset`` as three side-by-side panels.

    Reads the keys 'input', 'output', 'crack_mask', 'notch_length' and
    'notch_position' from the sample. The first two are shown as colour images
    (converted from CHW to HWC), the first channel of the mask with the 'hot'
    colormap. The notch values are multiplied by 100 for the panel title.
    """
    print("Testing dataset visualization:")
    sample = dataset[0]

    import matplotlib.pyplot as plt

    fig, (ax_input, ax_output, ax_mask) = plt.subplots(1, 3, figsize=(15, 5))

    # Panel 1: input image, annotated with the notch metadata
    ax_input.imshow(sample['input'].numpy().transpose(1, 2, 0))
    ax_input.set_title(f"Input (NL: {sample['notch_length'].item()*100:.1f}mm, NP: {sample['notch_position'].item()*100:.1f}mm)")
    ax_input.axis('off')

    # Panel 2: target image
    ax_output.imshow(sample['output'].numpy().transpose(1, 2, 0))
    ax_output.set_title("Ground Truth Output")
    ax_output.axis('off')

    # Panel 3: first channel of the crack mask
    ax_mask.imshow(sample['crack_mask'].numpy()[0], cmap='hot')
    ax_mask.set_title("Crack Mask")
    ax_mask.axis('off')

    fig.tight_layout()
    plt.show()

# Run visualization test
test_dataset_visualization()