# Insulator Segmentation with U-Net + ResNet34

This notebook implements semantic segmentation for insulator detection in power line images.

**Architecture**: U-Net with pretrained ResNet34 encoder  
**Expected Performance**: 0.88-0.92 Dice coefficient  
**Training Time**: ~13 hours on 2× Tesla T4 GPUs

## 1. Imports and Setup

In [None]:
import os
import gc
import cv2
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

import albumentations as A
from albumentations.pytorch import ToTensorV2

def set_seed(seed=42):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), then puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import keeps the function self-contained.
    import random

    # Python's own RNG used to be left unseeded; any library drawing from
    # `random` would then differ from run to run.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer repeatable cuDNN kernels over the auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

# Pick the compute device once: CUDA when available, otherwise the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
print(f'Device: {device}')
if has_cuda:
    # Report the first GPU's name and how many GPUs are visible.
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'GPU Count: {torch.cuda.device_count()}')

## 2. Configuration

In [None]:
class Config:
    """Paths and hyperparameters shared by every cell of the notebook.

    The output directories are created as a side effect of defining the class.
    """

    # --- Filesystem layout -------------------------------------------------
    DATA_DIR = Path('data/train')    # expects `images/` and `masks/` inside
    TEST_DIR = Path('data/test')     # expects `images/` inside
    WORK_DIR = Path('outputs')
    CHECKPOINT_DIR = WORK_DIR / 'checkpoints'
    PREDICTION_DIR = WORK_DIR / 'predictions'

    # Create both output folders up front (no error if they already exist).
    for _out_dir in (CHECKPOINT_DIR, PREDICTION_DIR):
        _out_dir.mkdir(parents=True, exist_ok=True)
    del _out_dir  # keep the loop variable out of the class namespace

    # --- Model and data ----------------------------------------------------
    ENCODER = 'resnet34'             # backbone name handed to the model
    IMG_SIZE = 512                   # square side length images are resized to
    BATCH_SIZE = 16
    NUM_WORKERS = 2                  # DataLoader worker processes
    VAL_SPLIT = 0.15                 # fraction of training pairs held out

    # --- Optimisation ------------------------------------------------------
    NUM_EPOCHS = 40
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    EARLY_STOPPING_PATIENCE = 7      # epochs without a new best val Dice
    SAVE_EVERY_N_EPOCHS = 5          # periodic checkpoint interval

    # --- Inference ---------------------------------------------------------
    USE_TTA = True                   # average flipped views at test time
    OPTIMIZE_THRESHOLD = True        # tune the binarisation threshold on val


cfg = Config()
print(f'Configuration: {cfg.NUM_EPOCHS} epochs, batch size {cfg.BATCH_SIZE}, TTA={cfg.USE_TTA}')

## 3. Model Architecture

In [None]:
class ConvBlock(nn.Module):
    """Two 3x3 convolutions in a row, each followed by BatchNorm and ReLU.

    Spatial size is preserved (padding=1); channels go from ``in_ch`` to
    ``out_ch`` in the first convolution and stay at ``out_ch`` afterwards.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv-BN-ReLU stages to ``x``."""
        return self.conv(x)

class UNetResNet(nn.Module):
    """U-Net style decoder on top of a torchvision ResNet encoder.

    The encoder supplies five feature maps (stem, layer1..layer4). The decoder
    upsamples with transposed convolutions, concatenates the matching encoder
    feature map as a skip connection and refines it with a ``ConvBlock``. The
    output is a single-channel logit map (no sigmoid applied).

    Args:
        encoder_name: One of the keys of ``_ENCODER_CHANNELS``.
        pretrained: Forwarded to the torchvision constructor.

    Raises:
        ValueError: If ``encoder_name`` is not a supported encoder.
    """

    # Output channels of (stem, layer1, layer2, layer3, layer4) per encoder.
    # The decoder widths are derived from these. They used to be hard-coded to
    # the resnet34 values, which broke the resnet50 fallback at forward time.
    _ENCODER_CHANNELS = {
        'resnet18': (64, 64, 128, 256, 512),
        'resnet34': (64, 64, 128, 256, 512),
        'resnet50': (64, 256, 512, 1024, 2048),
        'resnet101': (64, 256, 512, 1024, 2048),
        'resnet152': (64, 256, 512, 1024, 2048),
    }

    def __init__(self, encoder_name='resnet34', pretrained=True):
        super().__init__()

        import torchvision.models as models

        if encoder_name not in self._ENCODER_CHANNELS:
            supported = ', '.join(sorted(self._ENCODER_CHANNELS))
            raise ValueError(
                f'Unsupported encoder {encoder_name!r}; expected one of: {supported}'
            )

        # NOTE: newer torchvision releases prefer `weights=` over `pretrained=`;
        # the older keyword is kept so existing environments keep working.
        resnet = getattr(models, encoder_name)(pretrained=pretrained)
        c0, c1, c2, c3, c4 = self._ENCODER_CHANNELS[encoder_name]

        # Encoder stages, reusing the ResNet modules directly.
        self.enc0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        self.enc1 = nn.Sequential(resnet.maxpool, resnet.layer1)
        self.enc2 = resnet.layer2
        self.enc3 = resnet.layer3
        self.enc4 = resnet.layer4

        # Decoder: each `up` halves the channels down to the skip's width, so
        # the concatenation doubles it and the ConvBlock reduces it again.
        self.up1 = nn.ConvTranspose2d(c4, c3, 2, stride=2)
        self.conv1 = ConvBlock(2 * c3, c3)

        self.up2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.conv2 = ConvBlock(2 * c2, c2)

        self.up3 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.conv3 = ConvBlock(2 * c1, c1)

        self.up4 = nn.ConvTranspose2d(c1, c0, 2, stride=2)
        self.conv4 = ConvBlock(2 * c0, c0)

        # Last upsampling step has no skip connection.
        self.up5 = nn.ConvTranspose2d(c0, 32, 2, stride=2)
        self.conv5 = ConvBlock(32, 32)

        self.final = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        """Return a one-channel logit map for the image batch ``x``.

        NOTE(review): the skip concatenations need the upsampled and encoder
        tensors to have equal spatial size; presumably H and W must be
        multiples of 32 (the notebook resizes to 512) - confirm before feeding
        other sizes.
        """
        e0 = self.enc0(x)
        e1 = self.enc1(e0)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)

        x = self.up1(e4)
        x = torch.cat([x, e3], dim=1)
        x = self.conv1(x)

        x = self.up2(x)
        x = torch.cat([x, e2], dim=1)
        x = self.conv2(x)

        x = self.up3(x)
        x = torch.cat([x, e1], dim=1)
        x = self.conv3(x)

        x = self.up4(x)
        x = torch.cat([x, e0], dim=1)
        x = self.conv4(x)

        x = self.up5(x)
        x = self.conv5(x)

        return self.final(x)

print('Model architecture defined')

## 4. Data Loading

In [None]:
def get_file_paths(data_dir, test_dir):
    """Collect sorted training image/mask paths and test image paths.

    Args:
        data_dir: Directory containing ``images/`` (``*.jpg``) and ``masks/``
            (``*.png``) sub-folders.
        test_dir: Directory containing an ``images/`` (``*.jpg``) sub-folder.

    Returns:
        Tuple ``(train_images, train_masks, test_images)`` of sorted path lists.

    Raises:
        AssertionError: If the number of training images and masks differs.
    """
    def _sorted_matches(folder, pattern):
        # Sorting makes image i and mask i correspond by file-name order.
        return sorted(folder.glob(pattern))

    train_images = _sorted_matches(data_dir / 'images', '*.jpg')
    train_masks = _sorted_matches(data_dir / 'masks', '*.png')
    test_images = _sorted_matches(test_dir / 'images', '*.jpg')

    print(f'Train: {len(train_images)} images, {len(train_masks)} masks')
    print(f'Test: {len(test_images)} images')

    # Only the counts are compared; pairing relies on the sort order above.
    assert len(train_images) == len(train_masks), 'Mismatch between images and masks'
    return train_images, train_masks, test_images

# Gather every path up front, then hold out a validation subset.
train_images, train_masks, test_images = get_file_paths(
    data_dir=cfg.DATA_DIR,
    test_dir=cfg.TEST_DIR,
)

# Images and masks are split together so each pair stays aligned; the fixed
# random_state keeps the split identical across runs.
split_parts = train_test_split(
    train_images,
    train_masks,
    test_size=cfg.VAL_SPLIT,
    random_state=42,
)
train_imgs, val_imgs, train_msks, val_msks = split_parts

print(f'Split: {len(train_imgs)} train, {len(val_imgs)} val')

## 5. Data Augmentation

In [None]:
def get_train_transforms(img_size=512):
    """Build the training augmentation pipeline.

    Steps, in order: resize to ``img_size`` x ``img_size``; random flips and
    90-degree rotations; shift/scale/rotate; one colour jitter; one of noise
    or blur; normalisation; conversion to a tensor.

    Args:
        img_size: Side length of the square output.

    Returns:
        An ``A.Compose`` pipeline.
    """
    geometric = [
        A.Resize(img_size, img_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.15, rotate_limit=45, p=0.6),
    ]
    # Each OneOf picks exactly one of its children when it fires.
    photometric = [
        A.OneOf([A.RandomBrightnessContrast(p=1), A.HueSaturationValue(p=1)], p=0.5),
        A.OneOf([A.GaussNoise(p=1), A.GaussianBlur(p=1)], p=0.3),
    ]
    to_model_input = [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(geometric + photometric + to_model_input)

def get_val_transforms(img_size=512):
    """Build the deterministic pipeline used for validation and test images.

    Only resizes to ``img_size`` x ``img_size``, normalises with the same
    mean/std as the training pipeline and converts to a tensor.

    Args:
        img_size: Side length of the square output.

    Returns:
        An ``A.Compose`` pipeline.
    """
    steps = [
        A.Resize(img_size, img_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

print('Augmentation pipelines defined')

## 6. Dataset and DataLoaders

In [None]:
class InsulatorDataset(Dataset):
    """Dataset of RGB images with optional binary segmentation masks.

    With ``mask_paths`` it yields ``(image, mask)`` pairs, where the mask is a
    float tensor with a leading channel dimension. Without ``mask_paths``
    (test mode) it yields ``(image, file_name)``.

    Args:
        image_paths: Sequence of image paths (``pathlib.Path`` objects).
        mask_paths: Sequence of mask paths aligned with ``image_paths``, or
            ``None`` for test mode.
        transforms: Optional callable invoked as ``transforms(image=...)`` or
            ``transforms(image=..., mask=...)`` that returns a dict.

    Raises:
        ValueError: If ``mask_paths`` is given with a different length than
            ``image_paths``.
    """

    def __init__(self, image_paths, mask_paths=None, transforms=None):
        # Fail early rather than with an IndexError deep inside a worker.
        if mask_paths is not None and len(mask_paths) != len(image_paths):
            raise ValueError(
                f'Got {len(image_paths)} images but {len(mask_paths)} masks'
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transforms = transforms
        self.is_test = mask_paths is None

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = cv2.imread(str(img_path))
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise OSError(f'Failed to read image (missing or corrupt): {img_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.is_test:
            if self.transforms:
                image = self.transforms(image=image)['image']
            return image, str(img_path.name)

        mask_path = self.mask_paths[idx]
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f'Failed to read mask (missing or corrupt): {mask_path}')
        # Binarise: pixels brighter than 127 are foreground.
        mask = (mask > 127).astype(np.float32)

        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Without transforms the mask is still a NumPy array, which has no
        # `unsqueeze`; convert so the channel dimension can always be added.
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(mask)
        mask = mask.unsqueeze(0)
        return image, mask

# One dataset per split; only the training split gets random augmentation and
# the test split has no masks.
train_dataset = InsulatorDataset(train_imgs, train_msks, get_train_transforms(cfg.IMG_SIZE))
val_dataset = InsulatorDataset(val_imgs, val_msks, get_val_transforms(cfg.IMG_SIZE))
test_dataset = InsulatorDataset(test_images, None, get_val_transforms(cfg.IMG_SIZE))

# Settings common to all three loaders; only the training loader shuffles.
_loader_kwargs = dict(
    batch_size=cfg.BATCH_SIZE,
    num_workers=cfg.NUM_WORKERS,
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print(f'DataLoaders created: {len(train_loader)} train batches, {len(val_loader)} val batches, {len(test_loader)} test batches')

## 7. Model, Loss Function and Metrics

In [None]:
# Build the network with the encoder named in the config.
model = UNetResNet(encoder_name=cfg.ENCODER, pretrained=True)

# Replicate across GPUs only when more than one is visible.
multi_gpu = torch.cuda.device_count() > 1
if multi_gpu:
    print(f'Using {torch.cuda.device_count()} GPUs')
    model = nn.DataParallel(model)

model = model.to(device)

class DiceBCELoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft Dice loss.

    Args:
        bce_weight: Weight of the BCE term in ``[0, 1]``; the Dice term gets
            ``1 - bce_weight``. The default reproduces the equal 0.5 / 0.5 mix.
        smooth: Constant added to the Dice numerator and denominator so that
            an empty prediction on an empty target does not divide by zero.

    Raises:
        ValueError: If ``bce_weight`` is outside ``[0, 1]``.
    """

    def __init__(self, bce_weight=0.5, smooth=1e-6):
        super().__init__()
        if not 0.0 <= bce_weight <= 1.0:
            raise ValueError(f'bce_weight must be in [0, 1], got {bce_weight}')
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return the scalar loss for raw ``logits`` against binary ``targets``.

        Both tensors are reduced over dims 2 and 3, so they are expected to be
        4-D with the spatial axes last.
        """
        # Do the arithmetic in fp32: half-precision logits would otherwise
        # make the large spatial sums and the tiny `smooth` term unreliable.
        logits = logits.float()
        targets = targets.float()

        bce = self.bce(logits, targets)
        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum(dim=(2, 3))
        union = probs.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return self.bce_weight * bce + (1.0 - self.bce_weight) * (1 - dice.mean())

def dice_coefficient(pred, target, threshold=0.5):
    """Return the mean Dice score of thresholded predictions.

    ``pred`` is binarised with ``pred > threshold``, Dice is computed by
    reducing over dims 2 and 3, and the resulting scores are averaged into a
    scalar tensor. A 1e-6 smoothing term keeps an empty prediction on an empty
    target at a score of 1.

    Args:
        pred: Probability tensor (reduced over dims 2 and 3).
        target: Binary ground-truth tensor with the same layout as ``pred``.
        threshold: Cut-off above which a probability counts as foreground.
    """
    eps = 1e-6
    spatial_dims = (2, 3)

    binary = (pred > threshold).float()
    overlap = torch.sum(binary * target, dim=spatial_dims)
    total_area = torch.sum(binary, dim=spatial_dims) + torch.sum(target, dim=spatial_dims)

    scores = (2. * overlap + eps) / (total_area + eps)
    return scores.mean()

# Loss object passed to the training and validation functions.
criterion = DiceBCELoss()
print('Model and loss function initialized')

## 8. Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, scaler, device):
    """Run one mixed-precision training epoch.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(images, masks)`` batches.
        criterion: Callable ``criterion(logits, masks)`` returning a scalar loss.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        scaler: Gradient scaler used for the backward pass and the step.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, mean_dice)`` averaged over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum, dice_sum, n_samples = 0.0, 0.0, 0

    for images, masks in tqdm(loader, desc='Training'):
        # non_blocking only has an effect when the source batch is in pinned
        # memory; otherwise it is an ordinary copy.
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        optimizer.zero_grad()

        with autocast():
            logits = model(images)
            loss = criterion(logits, masks)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        with torch.no_grad():
            dice = dice_coefficient(torch.sigmoid(logits), masks)

        # Weight each batch by its size so a smaller final batch does not count
        # as much as a full one (assumes loss and dice are per-batch means).
        batch_size = images.size(0)
        loss_sum += loss.item() * batch_size
        dice_sum += dice.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError('train_epoch received an empty loader')
    return loss_sum / n_samples, dice_sum / n_samples

def validate_epoch(model, loader, criterion, device):
    """Evaluate the model on a validation loader without gradient tracking.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable of ``(images, masks)`` batches.
        criterion: Callable ``criterion(logits, masks)`` returning a scalar loss.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, mean_dice)`` averaged over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, dice_sum, n_samples = 0.0, 0.0, 0

    with torch.no_grad():
        for images, masks in tqdm(loader, desc='Validation'):
            # non_blocking only has an effect for pinned-memory batches.
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            logits = model(images)
            loss = criterion(logits, masks)
            dice = dice_coefficient(torch.sigmoid(logits), masks)

            # Weight by batch size so a smaller final batch is not
            # over-represented (assumes loss and dice are per-batch means).
            batch_size = images.size(0)
            loss_sum += loss.item() * batch_size
            dice_sum += dice.item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError('validate_epoch received an empty loader')
    return loss_sum / n_samples, dice_sum / n_samples

print('Training functions defined')

## 9. Checkpoint Manager

In [None]:
class CheckpointManager:
    """Save and restore training checkpoints inside one directory.

    Every ``save`` writes ``latest.pth``; when ``is_best`` is set the same
    payload is also written to ``best.pth``. Model weights are stored without
    a ``DataParallel`` wrapper prefix, so they load into wrapped and unwrapped
    models alike.

    Args:
        checkpoint_dir: Directory for the checkpoint files; created if missing.
    """

    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(exist_ok=True, parents=True)
        self.best_score = -np.inf

    def _atomic_save(self, checkpoint, filename):
        """Write to a temp file, then rename, so an interrupted save never
        leaves a truncated checkpoint under the final name."""
        target = self.checkpoint_dir / filename
        tmp_path = target.with_name(target.name + '.tmp')
        torch.save(checkpoint, tmp_path)
        tmp_path.replace(target)

    def save(self, model, optimizer, scaler, epoch, metrics, is_best=False):
        """Persist the current training state.

        Args:
            model: Model (optionally wrapped so that ``model.module`` exists).
            optimizer: Optimizer whose state dict is stored.
            scaler: Gradient scaler whose state dict is stored.
            epoch: Epoch number recorded in the checkpoint.
            metrics: Dict of metrics; must contain ``'val_dice'`` when
                ``is_best`` is true.
            is_best: Also write ``best.pth`` and update ``best_score``.
        """
        net = model.module if hasattr(model, 'module') else model
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'metrics': metrics,
        }
        self._atomic_save(checkpoint, 'latest.pth')
        if is_best:
            self._atomic_save(checkpoint, 'best.pth')
            self.best_score = metrics['val_dice']
            print(f"Saved best model: {metrics['val_dice']:.4f}")

    def load(self, model, optimizer=None, scaler=None, name='best.pth'):
        """Restore state from ``name`` into the given objects.

        Args:
            model: Model to receive the weights (wrapped or unwrapped).
            optimizer: Optional optimizer to restore.
            scaler: Optional gradient scaler to restore.
            name: Checkpoint file name inside the checkpoint directory.

        Returns:
            The loaded checkpoint dict, or ``None`` if the file does not exist.
        """
        path = self.checkpoint_dir / name
        if not path.exists():
            return None

        # Map tensors onto the model's own device instead of relying on a
        # module-level `device` global; fall back to CPU for parameterless models.
        target_device = next(
            (p.device for p in model.parameters()), torch.device('cpu')
        )
        checkpoint = torch.load(path, map_location=target_device)

        net = model.module if hasattr(model, 'module') else model
        net.load_state_dict(checkpoint['model_state_dict'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scaler is not None:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
        return checkpoint

# Single manager instance writing into the configured checkpoint directory.
ckpt_mgr = CheckpointManager(cfg.CHECKPOINT_DIR)
print('Checkpoint manager initialized')

## 10. Training Loop

In [None]:
# Optimiser, cosine schedule with warm restarts (stepped once per epoch) and
# the gradient scaler used for mixed-precision training.
optimizer = optim.AdamW(model.parameters(), lr=cfg.LEARNING_RATE, weight_decay=cfg.WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=15, T_mult=2, eta_min=1e-6)
scaler = GradScaler()

# Per-epoch metric history plus the early-stopping bookkeeping.
history = {'train_loss': [], 'train_dice': [], 'val_loss': [], 'val_dice': []}
best_val_dice = -np.inf
patience = 0

print('Starting training...')

for epoch in range(1, cfg.NUM_EPOCHS + 1):
    print(f'\nEpoch {epoch}/{cfg.NUM_EPOCHS}')

    train_loss, train_dice = train_epoch(model, train_loader, criterion, optimizer, scaler, device)
    val_loss, val_dice = validate_epoch(model, val_loader, criterion, device)
    scheduler.step()

    # Collect this epoch's numbers once; they feed the history and checkpoint.
    epoch_metrics = {
        'train_loss': train_loss,
        'train_dice': train_dice,
        'val_loss': val_loss,
        'val_dice': val_dice,
    }
    for metric_name, metric_value in epoch_metrics.items():
        history[metric_name].append(metric_value)

    print(f'Train - Loss: {train_loss:.4f}, Dice: {train_dice:.4f}')
    print(f'Val   - Loss: {val_loss:.4f}, Dice: {val_dice:.4f}')

    is_best = val_dice > best_val_dice
    periodic_save = epoch % cfg.SAVE_EVERY_N_EPOCHS == 0

    # Checkpoint on a new best score and on every N-th epoch.
    if is_best or periodic_save:
        ckpt_mgr.save(model, optimizer, scaler, epoch, epoch_metrics, is_best)

    if not is_best:
        # No improvement: count towards early stopping.
        patience += 1
        if patience >= cfg.EARLY_STOPPING_PATIENCE:
            print(f'Early stopping triggered after {epoch} epochs')
            break
    else:
        best_val_dice = val_dice
        patience = 0

    # Free Python garbage and cached GPU blocks between epochs.
    gc.collect()
    torch.cuda.empty_cache()

print(f'\nTraining complete. Best validation Dice: {best_val_dice:.4f}')

## 11. Threshold Optimization

In [None]:
# Pick the binarisation threshold that maximises validation Dice.
# NOTE(review): the threshold is tuned on single-pass probabilities, whereas
# test prediction averages flipped views when cfg.USE_TTA is set - confirm the
# tuned value transfers.
if cfg.OPTIMIZE_THRESHOLD:
    print('Optimizing threshold...')
    ckpt_mgr.load(model, name='best.pth')
    model.eval()

    thresholds = np.arange(0.3, 0.7, 0.05)
    dice_totals = [0.0] * len(thresholds)

    # Run the model once per batch and score every candidate threshold on the
    # same probabilities, instead of re-inferring the whole validation set for
    # each threshold. The accumulated sums are the same as before.
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            probs = torch.sigmoid(model(images))
            for i, thresh in enumerate(thresholds):
                dice_totals[i] += dice_coefficient(probs, masks, thresh).item()

    best_threshold, best_score = 0.5, 0

    for thresh, total in zip(thresholds, dice_totals):
        avg = total / len(val_loader)
        print(f'Threshold {thresh:.2f}: Dice = {avg:.4f}')

        if avg > best_score:
            best_score = avg
            best_threshold = thresh

    print(f'\nBest threshold: {best_threshold:.2f} (Dice: {best_score:.4f})')
else:
    best_threshold = 0.5
    print(f'Using default threshold: {best_threshold}')

## 12. Generate Test Predictions with TTA

In [None]:
def predict_with_tta(model, image, device):
    """Average sigmoid predictions over the original and three flipped views.

    The input is flipped along dim 3, dim 2 and both; each prediction is
    flipped back before averaging, so the result is aligned with ``image``.

    Args:
        model: Network returning logits; switched to evaluation mode.
        image: Input batch (flipped along dims 2 and 3).
        device: Device each view is moved to before the forward pass.

    Returns:
        Tensor of averaged probabilities on ``device``.
    """
    model.eval()

    flip_dims = ([3], [2], [2, 3])

    with torch.no_grad():
        # Start with the unflipped view, then add each un-flipped TTA view.
        accumulated = torch.sigmoid(model(image.to(device)))
        for dims in flip_dims:
            view_pred = torch.sigmoid(model(torch.flip(image, dims).to(device)))
            accumulated = accumulated + torch.flip(view_pred, dims)

    return accumulated / 4

# Restore the best weights; `load` returns None when the file is missing, so
# say so instead of silently predicting with whatever is in memory.
if ckpt_mgr.load(model, name='best.pth') is None:
    print('Warning: best.pth not found; predicting with the weights currently in memory')
model.eval()

print('Generating test predictions...')

# NOTE(review): masks are written at the resolution produced by the test
# transforms (resized to cfg.IMG_SIZE), not at each image's original size -
# confirm this is what the evaluation expects.
for images, filenames in tqdm(test_loader, desc='Predicting'):
    if cfg.USE_TTA:
        preds = predict_with_tta(model, images, device)
    else:
        with torch.no_grad():
            preds = torch.sigmoid(model(images.to(device)))

    # Binarise with the tuned threshold.
    preds = (preds.cpu().numpy() > best_threshold).astype(np.uint8)

    for pred, filename in zip(preds, filenames):
        mask = pred[0]
        # with_suffix swaps only the extension; str.replace would also rewrite
        # a '.jpg' that occurs inside the file stem.
        output_path = cfg.PREDICTION_DIR / Path(filename).with_suffix('.png')
        # Scale {0, 1} to {0, 255} for the saved PNG.
        cv2.imwrite(str(output_path), mask * 255)

print('Predictions saved')

## 13. Create Submission Archive

In [None]:
import shutil

# Zip the contents of the prediction folder into <WORK_DIR>/submission.zip.
archive_base = cfg.WORK_DIR / 'submission'

print('Creating submission archive...')
shutil.make_archive(
    str(archive_base),
    'zip',
    cfg.PREDICTION_DIR,
)

# Final recap of the run.
print('\nTraining Summary:')
print(f'Best Validation Dice: {best_val_dice:.4f}')
print(f'Optimal Threshold: {best_threshold:.2f}')
print(f'Submission file: {cfg.WORK_DIR / "submission.zip"}')

## 14. Visualization

In [None]:
# Show the first test image next to its probability map and binary mask.
ckpt_mgr.load(model, name='best.pth')
model.eval()

test_img = test_images[0]
image = cv2.cvtColor(cv2.imread(str(test_img)), cv2.COLOR_BGR2RGB)
aug = get_val_transforms(cfg.IMG_SIZE)(image=image)
img_t = aug['image'].unsqueeze(0).to(device)  # add the batch dimension

with torch.no_grad():
    # Single forward pass (no TTA); keep the only sample and channel.
    pred = torch.sigmoid(model(img_t)).cpu().numpy()[0, 0]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# (data, title, imshow keyword arguments) for each panel, left to right.
panels = [
    (image, 'Original Image', {}),
    (pred, 'Prediction Probabilities', {'cmap': 'hot', 'vmin': 0, 'vmax': 1}),
    (pred > best_threshold, f'Binary Mask (threshold={best_threshold:.2f})', {'cmap': 'gray'}),
]
for ax, (panel_data, panel_title, imshow_kwargs) in zip(axes, panels):
    ax.imshow(panel_data, **imshow_kwargs)
    ax.set_title(panel_title)
    ax.axis('off')

plt.tight_layout()
plt.savefig(cfg.WORK_DIR / 'prediction_example.png', dpi=150, bbox_inches='tight')
plt.show()