# Prior4WeatherDetection Training Test

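This notebook smoke-tests training for Prior-aware Adversarial Domain Adaptation for Object Detection under Adverse Weather: it builds the model, optimizer, and the clean (source) and foggy (target) Cityscapes loaders, runs a single forward pass on one source/target batch pair, and then runs a short training loop.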

In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from models.detection_network import DomainAdaptiveFasterRCNN
from utils.data.cityscapes_clean_dataset import Cityscapes_Clean_Dataset, cityscapes_clean_dataset_collate_fn
from utils.data.cityscapes_foggy_dataset import Cityscapes_Foggy_Dataset, cityscapes_foggy_dataset_collate_fn

In [2]:
# Initialize model
# Pick the compute device first: CUDA when it is available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Build the detector with the settings used throughout this notebook.
model = DomainAdaptiveFasterRCNN(backbone_name='vgg16', num_classes=10)

# Left as the bare last expression on purpose so the notebook displays the
# returned object (the model's module tree) as the cell output.
model.to(device)



DomainAdaptiveFasterRCNN(
  (backbone): CustomVGGBackbone(
    (backbone_c4): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      

In [3]:
# Optimizer (SGD as in paper)
# Hyper-parameters are collected in one place so they are easy to spot and tune.
sgd_settings = {
    "lr": 0.001,
    "momentum": 0.9,
    "weight_decay": 0.0005,
}
optimizer = optim.SGD(model.parameters(), **sgd_settings)

In [None]:
# Datasets
# Both datasets are constructed from the same Cityscapes root directory.
# NOTE(review): "datataset" looks like a typo but presumably matches the folder
# on disk (the recorded cell output reports images found) — confirm before renaming.
cityscapes_root = "/teamspace/studios/this_studio/Prior4WeatherDetection/datataset/cityscapes"

source_dataset = Cityscapes_Clean_Dataset(cityscapes_root)  # source domain: clean images + labels
target_dataset = Cityscapes_Foggy_Dataset(cityscapes_root)  # target domain: foggy images + priors

Found 2975 images in train split.
Found 5676 foggy images in train split for beta levels [0.01, 0.02, 0.05].


In [5]:
# Use dataset-specific collate functions
# The two loaders share their batching settings; only the dataset and the
# collate function differ between the source and target domains.
loader_options = dict(batch_size=2, shuffle=True)

target_loader = DataLoader(
    target_dataset,
    collate_fn=cityscapes_foggy_dataset_collate_fn,
    **loader_options,
)

source_loader = DataLoader(
    source_dataset,
    collate_fn=cityscapes_clean_dataset_collate_fn,
    **loader_options,
)

In [None]:
# Testing a single batch pass
model.train()


def _stage_batch(batch):
    """Move one (images, prior_images, targets) batch to `device` and stack both image lists.

    Tensor values inside each target dict are moved to `device`; any other
    value is passed through unchanged. Returns
    (images_tensor, prior_images_tensor, targets).
    """
    images, prior_images, targets = batch
    images = [image.to(device) for image in images]
    prior_images = [prior.to(device) for prior in prior_images]
    targets = [
        {
            key: (value.to(device) if isinstance(value, torch.Tensor) else value)
            for key, value in target.items()
        }
        for target in targets
    ]
    # Convert lists to tensors for batch processing
    return torch.stack(images), torch.stack(prior_images), targets


# Only the first source/target pair is needed; None means a loader was empty.
first_pair = next(zip(source_loader, target_loader), None)

if first_pair is not None:
    source_batch, target_batch = first_pair

    source_images_tensor, source_prior_images_tensor, source_targets = _stage_batch(source_batch)
    target_images_tensor, target_prior_images_tensor, target_targets = _stage_batch(target_batch)

    # Print batch sizes (the leading dimension of each stacked image tensor)
    print("Source batch size:", len(source_images_tensor))
    print("Target batch size:", len(target_images_tensor))

    # Forward pass - source domain. Failures are reported rather than raised so
    # that the target-domain pass below is still attempted.
    try:
        source_losses = model(source_images_tensor, source_prior_images_tensor, source_targets)
        print("Source domain forward pass successful!")
        print("Source losses:", source_losses)
    except Exception as e:
        print(f"Error in source domain forward pass: {e}")

    # Forward pass - target domain
    try:
        target_losses = model(target_images_tensor, target_prior_images_tensor, target_targets)
        print("Target domain forward pass successful!")
        print("Target losses:", target_losses)
    except Exception as e:
        print(f"Error in target domain forward pass: {e}")

Source batch size: 2
Target batch size: 2


## Complete Training Loop

In [None]:
# Training loop
def _batch_to_device(batch, device):
    """Move one collated batch onto ``device`` and stack its image lists.

    Args:
        batch: An ``(images, prior_images, targets)`` triple: two sequences of
            tensors and a sequence of per-image target dicts.
        device: Destination passed to ``Tensor.to``.

    Returns:
        ``(images_tensor, prior_images_tensor, targets)``. The two tensors are
        ``torch.stack`` of the moved images; in ``targets`` every tensor value
        has been moved to ``device`` and non-tensor values are kept as-is.
    """
    images, prior_images, targets = batch
    images = [img.to(device) for img in images]
    prior_images = [img.to(device) for img in prior_images]
    targets = [
        {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
        for t in targets
    ]
    return torch.stack(images), torch.stack(prior_images), targets


def train_epoch(model, source_loader, target_loader, optimizer, device, epoch):
    """Train for one pass over paired source/target batches.

    For every (source_batch, target_batch) pair the model is updated twice:
    first on the summed detection losses of the source batch, then on the
    adaptation losses (``loss_pal`` and/or ``loss_reg``) of the target batch.

    Args:
        model: Callable as ``model(images, prior_images, targets)`` returning a
            dict of loss tensors.
        source_loader: Iterable of source-domain batches.
        target_loader: Iterable of target-domain batches. Iteration stops when
            the shorter of the two loaders is exhausted (``zip`` semantics).
        optimizer: Optimizer over the model parameters.
        device: Device the batches are moved to.
        epoch: Current epoch index; accepted for the caller's convenience and
            not used inside this function.

    Returns:
        ``(avg_source_loss, avg_target_loss)`` averaged over the number of
        batch pairs processed, or ``0`` for both when no pair was processed.
    """
    model.train()
    source_epoch_loss = 0.0
    target_epoch_loss = 0.0
    num_batches = 0

    # Alternate between source and target batches
    for source_batch, target_batch in zip(source_loader, target_loader):
        # ----- Source Domain Training -----
        src_images, src_priors, src_targets = _batch_to_device(source_batch, device)
        source_losses = model(src_images, src_priors, src_targets)

        # Total supervised detection loss on the source batch
        source_loss = (
            source_losses['loss_classifier'] + source_losses['loss_box_reg']
            + source_losses['loss_objectness'] + source_losses['loss_rpn_box_reg']
        )

        optimizer.zero_grad()
        source_loss.backward()
        optimizer.step()

        source_epoch_loss += source_loss.item()

        # ----- Target Domain Training -----
        tgt_images, tgt_priors, tgt_targets = _batch_to_device(target_batch, device)
        target_losses = model(tgt_images, tgt_priors, tgt_targets)

        # Only the adaptation terms that the model actually returned are used.
        target_terms = [
            target_losses[key] for key in ('loss_pal', 'loss_reg') if key in target_losses
        ]

        # With no adaptation term there is nothing to differentiate: calling
        # .backward()/.item() on a plain 0.0 float would raise AttributeError,
        # so the target update is skipped and the batch contributes 0 loss.
        if target_terms:
            target_loss = sum(target_terms)

            optimizer.zero_grad()
            target_loss.backward()
            optimizer.step()

            target_epoch_loss += target_loss.item()

        num_batches += 1

    avg_source_loss = source_epoch_loss / num_batches if num_batches > 0 else 0
    avg_target_loss = target_epoch_loss / num_batches if num_batches > 0 else 0

    return avg_source_loss, avg_target_loss

In [None]:
# Run training for a small number of epochs to test
num_epochs = 2  # Set a small number for testing

# Number of completed epochs after which the learning rate is divided by 10.
lr_decay_after = num_epochs // 2

for epoch in range(num_epochs):
    # Train for one epoch
    source_loss, target_loss = train_epoch(
        model, source_loader, target_loader, optimizer, device, epoch
    )

    # Print epoch results
    print(f"Epoch [{epoch+1}/{num_epochs}] Source Loss: {source_loss:.4f}, Target Loss: {target_loss:.4f}")

    # Adjust learning rate (after 50K iterations in the original paper).
    # The check counts completed epochs (epoch + 1) so the decay lands at the
    # halfway point and the remaining epochs actually train at the lower rate;
    # comparing the zero-based index instead would, for num_epochs == 2, only
    # trigger once the final epoch had already finished.
    if epoch + 1 == lr_decay_after:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.1
            print(f"Adjusted learning rate to {param_group['lr']}")