In [None]:
#%%

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
import numpy as np
import os
import gc

# Runtime tuning applied once, when this cell/module is executed.
# cudnn.benchmark asks cuDNN to auto-tune its convolution algorithms.
# NOTE(review): this is primarily a speed setting rather than a memory
# saving one — confirm it is wanted here.
torch.backends.cudnn.benchmark = True
# Allocator hint for the CUDA caching allocator (limits block splitting to
# reduce fragmentation). NOTE(review): presumably this must be set before the
# first CUDA allocation to take effect — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

def clear_memory():
    """Free unreferenced Python objects, then release cached CUDA memory.

    Runs a full garbage-collection pass first so that tensors which are no
    longer referenced are dropped before the CUDA cache is emptied.

    Returns:
        None.
    """
    gc.collect()
    torch.cuda.empty_cache()

# Image dataset driven by a DataFrame of (filename, label) rows
class FaultDetectionDataset(Dataset):
    """Dataset that loads images listed in a DataFrame and pairs them with labels.

    Rows are addressed positionally: column 0 holds the image file name
    (relative to ``img_dir``) and column 1 holds the label.

    Args:
        dataframe: Table whose first column is the file name and whose second
            column is the label.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the loaded RGB ``PIL`` image.
    """

    def __init__(self, dataframe, img_dir, transform=None):
        self.dataframe = dataframe
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the ``idx``-th sample.

        Args:
            idx: Positional row index (``iloc`` semantics, so a DataFrame with
                a non-contiguous index, e.g. after a split, works unchanged).

        Returns:
            Tuple ``(image, label)`` where ``image`` is the RGB image (after
            ``transform`` when one is set) and ``label`` is a scalar
            ``float32`` tensor.
        """
        img_name = self.dataframe.iloc[idx, 0]
        label = self.dataframe.iloc[idx, 1]
        img_path = os.path.join(self.img_dir, img_name)

        # Close the underlying file deterministically; ``convert`` returns an
        # independent, fully loaded copy that stays valid after the ``with``.
        with Image.open(img_path) as source:
            image = source.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.float32)

# Memory optimized ResUNet
class ResUNet(nn.Module):
    """Small U-Net-style encoder/decoder that emits one score per image.

    Three encoder stages (``base_channels``, 2x and 4x that many channels) are
    separated by 2x2 max-pooling. Two decoder stages upsample bilinearly,
    concatenate the matching encoder feature map as a skip connection and
    convolve. A final head reduces to a single channel which is globally
    average-pooled to a scalar per image.

    Note: despite the class name, the network contains no residual (additive)
    connections; only concatenation skips are used.

    Args:
        base_channels: Channel count of the first encoder stage.
    """

    def __init__(self, base_channels=32):  # Reduced base channels
        super().__init__()

        # Encoder with reduced channels
        self.enc1 = self._double_conv(3, base_channels)
        self.enc2 = self._double_conv(base_channels, base_channels * 2)
        self.enc3 = self._double_conv(base_channels * 2, base_channels * 4)

        # Decoder
        # Kept as a public attribute for backward compatibility; ``forward``
        # upsamples to the skip tensor's exact size instead (see
        # ``_upsample_like``) so inputs of any spatial size are supported.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Input channels = upsampled deeper features + skip features.
        self.dec3 = self._double_conv(base_channels * 6, base_channels * 2)
        self.dec2 = self._double_conv(base_channels * 3, base_channels)

        # Final layers
        self.final = nn.Sequential(
            nn.Conv2d(base_channels, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 1, kernel_size=1)
        )

        self.adaptive_pool = nn.AdaptiveAvgPool2d(1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Build a (3x3 conv -> BatchNorm -> ReLU) x 2 stack.

        The layer order, and therefore the ``state_dict`` keys (indices 0-5),
        matches a hand-written ``nn.Sequential`` of the same six layers.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _upsample_like(x, reference):
        """Bilinearly resize ``x`` to the spatial size of ``reference``.

        Max-pooling floors odd sizes, so a fixed 2x upsample can come out one
        pixel short of the skip tensor. Targeting the skip tensor's size
        avoids that mismatch; when the sizes are an exact factor of two the
        result is the same as a 2x ``align_corners=True`` upsample.
        """
        return F.interpolate(
            x, size=reference.shape[-2:], mode='bilinear', align_corners=True
        )

    def forward(self, x):
        """Compute one unnormalised score per image.

        Args:
            x: Batch of 3-channel images shaped ``(N, 3, H, W)``.

        Returns:
            1-D tensor with ``N`` values; no sigmoid is applied.
        """
        # Encoder path
        e1 = self.enc1(x)
        e2 = self.enc2(F.max_pool2d(e1, 2))
        e3 = self.enc3(F.max_pool2d(e2, 2))

        # Decoder path
        d3 = self._upsample_like(e3, e2)
        d3 = torch.cat([d3, e2], dim=1)
        d3 = self.dec3(d3)

        d2 = self._upsample_like(d3, e1)
        d2 = torch.cat([d2, e1], dim=1)
        d2 = self.dec2(d2)

        out = self.final(d2)
        out = self.adaptive_pool(out)
        return out.view(-1)

def _safe_mean(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero."""
    return total / count if count else float("nan")


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    """Train a binary classifier and evaluate it after every epoch.

    Predictions are obtained by thresholding ``sigmoid(outputs)`` at 0.5, so
    ``model`` is expected to return one unnormalised score per sample. Per-epoch
    metrics are printed, and every fifth epoch a scikit-learn classification
    report for the validation set is printed as well.

    Args:
        model: Network to optimise; called as ``model(images)``.
        train_loader: Iterable yielding ``(images, labels)`` training batches.
        val_loader: Iterable yielding ``(images, labels)`` validation batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(train_losses, val_losses, train_accuracies, val_accuracies)``
        of per-epoch lists. Each loss is the mean of the per-batch loss
        values; entries are NaN for an epoch in which a loader yielded no data.
    """
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        # Batches are counted explicitly so loaders without ``__len__`` work.
        num_batches = 0

        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1

            # Clear memory every few batches
            if batch_idx % 10 == 0:
                clear_memory()

        epoch_loss = _safe_mean(running_loss, num_batches)
        epoch_acc = _safe_mean(correct, total)
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                predicted = (torch.sigmoid(outputs) > 0.5).float()
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                num_batches += 1

                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

                # Clear memory
                clear_memory()

        val_loss = _safe_mean(val_loss, num_batches)
        val_acc = _safe_mean(correct, total)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        # Skip the report when there is nothing to report on.
        if (epoch + 1) % 5 == 0 and all_labels:
            # Pin the label set so the two target names always line up, even
            # when only one class occurs in the labels and predictions.
            report = classification_report(all_labels, all_preds,
                                        labels=[0, 1],
                                        target_names=['No Defect', 'Defect'])
            print(f'\nClassification Report:\n{report}\n')

        # Clear memory at the end of each epoch
        clear_memory()

    return train_losses, val_losses, train_accuracies, val_accuracies

In [None]:
#%%

def main(csv_path='defect_and_no_defect.csv', img_dir="train_images",
         image_size=128, batch_size=8, base_channels=32,
         learning_rate=0.001, num_epochs=20,
         model_path='resunet_model.pth'):
    """Train the defect classifier end to end and save its weights.

    Reads the labelled CSV, makes an 80/20 train/validation split with a fixed
    seed, trains a ``ResUNet`` with Adam and ``BCEWithLogitsLoss``, plots the
    metric curves and writes the model's ``state_dict`` to disk. Every default
    reproduces the previously hard-coded configuration.

    Args:
        csv_path: CSV file read into the DataFrame handed to the datasets.
        img_dir: Directory containing the images named in the CSV.
        image_size: Side length (pixels) images are resized to.
        batch_size: Batch size for both data loaders.
        base_channels: Width of the first ``ResUNet`` encoder stage.
        learning_rate: Adam learning rate.
        num_epochs: Number of training epochs.
        model_path: Destination file for the saved ``state_dict``.
    """
    # Set smaller image size
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),  # Reduced image size
        transforms.ToTensor(),
    ])

    # Load and split data
    df = pd.read_csv(csv_path)
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

    # Create datasets and dataloaders with smaller batch size
    train_dataset = FaultDetectionDataset(train_df, img_dir, transform)
    val_dataset = FaultDetectionDataset(val_df, img_dir, transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Model setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResUNet(base_channels=base_channels).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Clear memory before training
    clear_memory()

    # Train model
    metrics = train_model(model, train_loader, val_loader, criterion,
                         optimizer, num_epochs=num_epochs, device=device)

    # Plot results
    # NOTE(review): plot_metrics is not defined in the visible part of this
    # file — confirm it is defined elsewhere before this call is reached.
    plot_metrics(*metrics)

    # Save model
    torch.save(model.state_dict(), model_path)

# Start training only when this file is run as the main program, not when it
# is imported.
if __name__ == "__main__":
    main()

Epoch [1/20]
Train Loss: 0.5870, Train Acc: 0.6922
Val Loss: 0.5228, Val Acc: 0.7292
Epoch [2/20]
Train Loss: 0.5324, Train Acc: 0.7289
Val Loss: 0.4799, Val Acc: 0.7673
Epoch [3/20]
Train Loss: 0.4859, Train Acc: 0.7588
Val Loss: 0.4751, Val Acc: 0.7496
Epoch [4/20]
Train Loss: 0.4640, Train Acc: 0.7709
Val Loss: 0.3988, Val Acc: 0.8135
Epoch [5/20]
Train Loss: 0.4372, Train Acc: 0.7896
Val Loss: 0.4592, Val Acc: 0.7554

Classification Report:
              precision    recall  f1-score   support

   No Defect       0.90      0.52      0.66      1178
      Defect       0.71      0.95      0.81      1422

    accuracy                           0.76      2600
   macro avg       0.80      0.74      0.73      2600
weighted avg       0.79      0.76      0.74      2600


Epoch [6/20]
Train Loss: 0.4120, Train Acc: 0.8028
Val Loss: 0.4112, Val Acc: 0.8004
