## 1. Import Libraries and Setup

### 1.1 Distributed Training Setup

In [None]:
def setup_distributed(rank, world_size):
    """Initialize the distributed training environment for one worker process.

    Args:
        rank: Index of this process within the group. It is also used as the
            CUDA device index this process binds to.
        world_size: Total number of processes taking part in training.

    Side effects:
        Sets ``MASTER_ADDR`` / ``MASTER_PORT`` in ``os.environ`` when they are
        not already set, selects the CUDA device for this process, and creates
        the default process group with the NCCL backend.
    """
    # Only fill in rendezvous defaults when nothing is configured yet, so an
    # address/port exported by a launcher or by the user (e.g. to avoid a
    # clash on port 12355 with another job) is not silently overwritten.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # Bind this process to its own GPU before the process group is created so
    # later CUDA/NCCL work does not fall back to the default device.
    torch.cuda.set_device(rank)

    # Initialize the process group; 'env://' reads MASTER_ADDR / MASTER_PORT.
    dist.init_process_group(
        backend='nccl',  # Use NCCL for GPU training
        init_method='env://',
        world_size=world_size,
        rank=rank
    )
    print(f"Rank {rank} initialized on GPU {rank}")

def cleanup_distributed():
    """Tear down the default process group if one is currently active.

    Safe to call when distributed training was never started: in that case
    the function does nothing and returns ``None``.
    """
    # Check is_available() first so the call stays a no-op on PyTorch builds
    # that ship without distributed support.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

print("✓ Distributed helper functions defined")

In [1]:
# Import required libraries
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

warnings.filterwarnings('ignore')

# GPUs this notebook is meant to run on. The value has to be exported through
# CUDA_VISIBLE_DEVICES *before* CUDA is first queried, otherwise 'cuda' resolves
# to physical GPU 0 regardless of this setting. setdefault() keeps a value that
# was already exported (e.g. by a job scheduler). If the kernel has already
# touched CUDA, restart it for this to take effect.
# NOTE: on a machine that does not have these GPU indices, no CUDA device will
# be visible and the code below falls back to the CPU.
gpu_ids = "1,2"
os.environ.setdefault("CUDA_VISIBLE_DEVICES", gpu_ids)

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check device availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
PyTorch version: 2.1.0+cu118


## 2. Configuration

In [2]:
# Training configuration: every value a reader is likely to tune lives here.
image_size = 224        # side length of the square crop given to the model
batch_size = 512
num_workers = 16        # DataLoader worker processes
data_dir = './data'
num_epochs = 20
learning_rate = 0.001
weight_decay = 1e-4

# Selected attributes (the model predicts one output per entry)
selected_attributes = [
    'Heavy_Makeup',
    'Wearing_Lipstick',
    'Attractive',
    'High_Cheekbones',
    'Rosy_Cheeks',
]
num_attributes = len(selected_attributes)

# Echo the effective configuration so it is recorded with the cell output.
summary_lines = [
    "Configuration:",
    f"  Image size: {image_size}x{image_size}",
    f"  Batch size: {batch_size}",
    f"  Epochs: {num_epochs}",
    f"  Learning rate: {learning_rate}",
    f"  Attributes: {selected_attributes}",
]
print("\n".join(summary_lines))

Configuration:
  Image size: 224x224
  Batch size: 512
  Epochs: 20
  Learning rate: 0.001
  Attributes: ['Heavy_Makeup', 'Wearing_Lipstick', 'Attractive', 'High_Cheekbones', 'Rosy_Cheeks']


## 3. Data Transforms and Loading

In [3]:
# Training pipeline: resize, then random crop/flip/rotation/colour jitter for
# augmentation, then tensor conversion and per-channel normalisation.
train_steps = [
    transforms.Resize(256),
    transforms.RandomCrop(image_size),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
train_transform = transforms.Compose(train_steps)

# Validation pipeline: no random augmentation, only resize + centre crop.
val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
val_transform = transforms.Compose(val_steps)

print("✓ Data transforms defined")

✓ Data transforms defined


In [4]:
# Load the CelebA train/validation splits from the local data directory.
print("Loading CelebA dataset...")

# Arguments shared by both splits; the data must already be on disk because
# download is disabled.
celeba_common = dict(root=data_dir, download=False, target_type='attr')

try:
    train_dataset = datasets.CelebA(split='train', transform=train_transform, **celeba_common)
    val_dataset = datasets.CelebA(split='valid', transform=val_transform, **celeba_common)

    print(f"✓ Training samples: {len(train_dataset)}")
    print(f"✓ Validation samples: {len(val_dataset)}")
except Exception as e:
    # Print a hint for the reader, then re-raise so the notebook stops here.
    print(f"Error loading dataset: {e}")
    print("Please ensure CelebA dataset is in the data directory")
    raise

# Map each selected attribute name to its index in the dataset's attribute
# list (entries that are empty after stripping whitespace are dropped first).
attribute_names = []
for name in train_dataset.attr_names:
    if name.strip():
        attribute_names.append(name)

# list.index raises ValueError if a selected attribute is not present.
attribute_indices = []
for attr in selected_attributes:
    attribute_indices.append(attribute_names.index(attr))
print(f"\nAttribute indices: {attribute_indices}")

Loading CelebA dataset...
✓ Training samples: 162770
✓ Validation samples: 19867

Attribute indices: [18, 36, 2, 19, 29]


In [5]:
# Restrict each sample's target to the selected attributes.
from dataset import AttributeFilterDataset

# Wrap datasets. The isinstance guard keeps this cell idempotent: re-running it
# without reloading the raw CelebA splits must not wrap an already-wrapped
# dataset a second time.
# NOTE(review): assumes AttributeFilterDataset is a class — confirm in dataset.py.
if not isinstance(train_dataset, AttributeFilterDataset):
    train_dataset = AttributeFilterDataset(train_dataset, attribute_indices)
if not isinstance(val_dataset, AttributeFilterDataset):
    val_dataset = AttributeFilterDataset(val_dataset, attribute_indices)

print("✓ Dataset filtering applied")

✓ Dataset filtering applied


In [6]:
# Create DataLoaders
# Settings shared by the training and validation loaders.
# persistent_workers is only accepted when there is at least one worker
# process, so derive it from num_workers; this keeps the cell working when
# num_workers is set to 0 for debugging.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=num_workers > 0,  # Keeps workers alive between epochs
)

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print("✓ DataLoaders created")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")

✓ DataLoaders created
  Training batches: 318
  Validation batches: 39


## 4. Define ResNet18 Model

In [7]:
# Basic block for ResNet18
class BasicBlock(nn.Module):
    """Residual unit made of two 3x3 conv + batch-norm layers and a shortcut.

    The shortcut is an empty ``nn.Sequential`` (identity) when the input and
    output have the same channel count and ``stride == 1``; otherwise it is a
    1x1 convolution followed by batch norm so the two paths can be added.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    # Ratio between a block's output channels and its ``out_channels`` argument.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main path: conv -> bn -> relu -> conv -> bn
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Shortcut path: project only when the shapes of the two paths differ.
        projection = []
        if stride != 1 or in_channels != out_channels:
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # Add the (possibly projected) input to the main path in place.
        main += self.shortcut(x)
        return self.relu(main)

print("✓ BasicBlock defined")

✓ BasicBlock defined


In [8]:
# ResNet18 for multi-label classification
class ResNet18MultiLabel(nn.Module):
    """ResNet18-style network that emits one raw logit per label.

    Layout: 7x7 stem convolution + max-pool, four stages of two ``BasicBlock``
    units each (64, 128, 256, 512 channels), global average pooling and a
    linear classifier. No sigmoid is applied in ``forward``.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=5):
        super().__init__()

        # Running channel count consumed and updated by _make_layer.
        self.in_channels = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages layer1..layer4: (output channels, stride of first block)
        stage_specs = [(64, 1), (128, 2), (256, 2), (512, 2)]
        for number, (width, stride) in enumerate(stage_specs, start=1):
            setattr(self, f"layer{number}",
                    self._make_layer(BasicBlock, width, 2, stride=stride))

        # Classifier head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        # Initialize weights
        self._initialize_weights()

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` units; only the first uses ``stride``.

        Updates ``self.in_channels`` so the following units (and the next
        stage) receive the right input width.
        """
        first = block(self.in_channels, out_channels, stride)
        self.in_channels = out_channels * block.expansion
        remaining = [block(self.in_channels, out_channels) for _ in range(1, blocks)]
        return nn.Sequential(first, *remaining)

    def _initialize_weights(self):
        """Apply Kaiming-normal init to convs, constants to batch norm, and a
        small normal init (std 0.01) with zero bias to linear layers."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return logits with one column per class for a batch of images."""
        # Stem
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Residual stages
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        # Pool to 1x1, flatten everything but the batch dim, classify
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

print("✓ ResNet18MultiLabel defined")

✓ ResNet18MultiLabel defined


## 5. Create Model and Training Setup

In [9]:
# Build the network and move it to the selected device.
model = ResNet18MultiLabel(num_classes=num_attributes).to(device)

# Multi-label objective operating on raw logits, optimised with Adam.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Halve the learning rate when the monitored metric (mode='max', i.e. higher
# is better) has not improved for more than `patience` epochs.
scheduler_options = dict(
    mode='max',
    factor=0.5,
    patience=2,
    verbose=True,
    min_lr=1e-7,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_options)

total_params = sum(param.numel() for param in model.parameters())
print(f"✓ Model created with {total_params:,} parameters")
print("✓ Loss: BCEWithLogitsLoss")
print(f"✓ Optimizer: Adam (lr={learning_rate}, weight_decay={weight_decay})")
print("✓ Scheduler: ReduceLROnPlateau")

✓ Model created with 11,179,077 parameters
✓ Loss: BCEWithLogitsLoss
✓ Optimizer: Adam (lr=0.001, weight_decay=0.0001)
✓ Scheduler: ReduceLROnPlateau


## 6. Training Functions

In [None]:
# Start distributed training: one worker process per visible GPU.
gpu_count = torch.cuda.device_count()  # Number of GPUs visible to this process
world_size = gpu_count
if gpu_count < 2:
    # Report what was actually found; with zero visible GPUs the previous
    # message wrongly claimed that one GPU had been detected.
    found = "Only 1 GPU" if gpu_count == 1 else "No GPUs"
    print(f"Warning: {found} detected. DDP works best with 2+ GPUs.")
    print("Continuing anyway...")
    # Always spawn at least one process.
    world_size = max(1, gpu_count)

print(f"\nStarting distributed training on {world_size} GPUs...")
print("="*80)

# Spawn processes for distributed training. mp.spawn passes the process index
# as the first argument to main_worker, followed by `args`; join=True blocks
# until every worker has exited.
# NOTE(review): `main_worker` is not defined in the cells shown here — it must
# be defined or imported before this cell is run.
mp.spawn(
    main_worker,
    args=(world_size, train_dataset, val_dataset),
    nprocs=world_size,
    join=True
)

print("\n✓ All training processes completed!")

In [10]:
def train_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; switched to training mode.
        train_loader: Iterable of ``(data, target)`` batches supporting ``len()``.
        criterion: Loss called as ``criterion(output, target)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the progress line.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean of the per-batch losses, and the
        fraction of individual label predictions (sigmoid > 0.5) that equal
        the target.
    """
    model.train()
    num_batches = len(train_loader)
    loss_sum = 0.0
    pred_chunks = []
    target_chunks = []

    for step, (inputs, labels) in enumerate(train_loader, start=1):
        # Move the batch to the device; targets are cast to float for the loss.
        inputs = inputs.to(device)
        labels = labels.to(device).float()

        # Forward + backward + parameter update
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Keep thresholded predictions and targets on the CPU for the metric.
        loss_sum += batch_loss.item()
        pred_chunks.append((torch.sigmoid(logits) > 0.5).cpu())
        target_chunks.append(labels.cpu())

        # Single-line progress indicator, rewritten in place via '\r'.
        print(f'\rEpoch {epoch}: [{step}/{num_batches}] Loss: {batch_loss.item():.4f}', end='', flush=True)

    # Aggregate over the whole epoch.
    mean_loss = loss_sum / num_batches
    stacked_preds = torch.cat(pred_chunks).numpy()
    stacked_targets = torch.cat(target_chunks).numpy()
    accuracy = (stacked_preds == stacked_targets).mean()

    return mean_loss, accuracy

print("✓ train_epoch function defined")

✓ train_epoch function defined


In [11]:
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        val_loader: Iterable of ``(data, target)`` batches supporting ``len()``.
        criterion: Loss called as ``criterion(output, target)``.
        device: Device the batches are moved to.

    Returns:
        ``(val_loss, val_acc)``: the mean of the per-batch losses, and the
        fraction of individual label predictions (sigmoid > 0.5) that equal
        the target.
    """
    model.eval()
    loss_sum = 0.0
    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device).float()

            logits = model(inputs)
            loss_sum += criterion(logits, labels).item()

            # Threshold the probabilities and keep everything on the CPU.
            pred_chunks.append((torch.sigmoid(logits) > 0.5).cpu())
            target_chunks.append(labels.cpu())

    # Aggregate over the whole validation set.
    mean_loss = loss_sum / len(val_loader)
    stacked_preds = torch.cat(pred_chunks).numpy()
    stacked_targets = torch.cat(target_chunks).numpy()
    accuracy = (stacked_preds == stacked_targets).mean()

    return mean_loss, accuracy

print("✓ validate_epoch function defined")

✓ validate_epoch function defined


## 7. Train the Model

In [12]:
# Training loop: train, validate, adjust the learning rate, and checkpoint
# the weights whenever validation accuracy improves.
print("\nStarting training...")
print("="*80)

# Per-epoch history, one entry appended per epoch.
train_losses = []
train_accs = []
val_losses = []
val_accs = []
best_val_acc = 0.0

for epoch in range(1, num_epochs + 1):
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, epoch)

    # Validate
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)

    # Update scheduler with the validation accuracy it monitors
    scheduler.step(val_acc)

    # Store metrics
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # Learning rate of the first parameter group. The key is "lr"; the
    # previous lookup used "lr " (trailing space) and raised KeyError.
    current_lr = optimizer.param_groups[0]["lr"]

    # Print epoch summary
    print(f'\n{"-"*80}')
    print(f'Epoch {epoch}/{num_epochs}:')
    print(f'  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}')
    print(f'  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}')
    print(f'  LR: {current_lr:.6f}')
    print(f'{"-"*80}')

    # Save best model (state_dict only) when validation accuracy improves
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_celeba_resnet18.pth')
        print(f'  ✓ New best model saved with Val Acc: {best_val_acc:.4f}')

print("Training complete.")


Starting training...


OutOfMemoryError: CUDA out of memory. Tried to allocate 392.00 MiB. GPU 0 has a total capacty of 21.97 GiB of which 36.50 MiB is free. Process 3374839 has 17.17 GiB memory in use. Including non-PyTorch memory, this process has 4.76 GiB memory in use. Of the allocated memory 4.54 GiB is allocated by PyTorch, and 21.29 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF