<a href="https://colab.research.google.com/github/ssudhanshu488/DANN_Implementation/blob/main/DANN_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
import numpy as np

In [2]:
# Gradient Reversal Layer
class GradientReversal(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambda_ * grad_output, None

class GradientReversalLayer(nn.Module):
    def __init__(self, lambda_=1.0):
        super(GradientReversalLayer, self).__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        return GradientReversal.apply(x, self.lambda_)

In [3]:
# DANN Model
class DANN(nn.Module):
    def __init__(self, num_classes=65):
        super(DANN, self).__init__()
        # Feature extractor (ResNet-50 without the final layer)
        self.feature_extractor = torchvision.models.resnet50(weights='IMAGENET1K_V2')
        self.feature_dim = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Identity()  # Remove the final FC layer

        # Task classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

        # Domain discriminator
        self.domain_discriminator = nn.Sequential(
            GradientReversalLayer(),
            nn.Linear(self.feature_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        class_output = self.classifier(features)
        domain_output = self.domain_discriminator(features)
        return class_output, domain_output

In [4]:
# Custom Dataset for Office-Home
class OfficeHomeDataset(Dataset):
    def __init__(self, root_dir, domain, transform=None):
        self.root_dir = os.path.join(root_dir, domain)
        self.transform = transform
        self.classes = sorted(os.listdir(self.root_dir))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.images = []
        self.labels = []

        for cls_name in self.classes:
            cls_dir = os.path.join(self.root_dir, cls_name)
            for img_name in os.listdir(cls_dir):
                self.images.append(os.path.join(cls_dir, img_name))
                self.labels.append(self.class_to_idx[cls_name])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

In [5]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("karntiwari/home-office-dataset")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/karntiwari/home-office-dataset?dataset_version_number=1...


100%|██████████| 982M/982M [00:48<00:00, 21.4MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/karntiwari/home-office-dataset/versions/1


In [6]:
ls "/root/.cache/kagglehub/datasets/karntiwari/home-office-dataset/versions/1/OfficeHomeDataset_10072016"

 [0m[01;34mArt[0m/   [01;34mClipart[0m/   ImageInfo.csv   imagelist.txt  [01;34m'Real World'[0m/


In [7]:
# Data loading
def get_dataloaders(source_domain, target_domain, data_dir, batch_size=32):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    source_dataset = OfficeHomeDataset(data_dir, source_domain, transform=transform)
    target_dataset = OfficeHomeDataset(data_dir, target_domain, transform=transform)

    source_loader = DataLoader(source_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    target_loader = DataLoader(target_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    return source_loader, target_loader

In [8]:
# Training function
def train_dann(model, source_loader, target_loader, optimizer, num_epochs, device):
    class_criterion = nn.CrossEntropyLoss()
    domain_criterion = nn.BCELoss()

    for epoch in range(num_epochs):
        model.train()
        len_dataloader = min(len(source_loader), len(target_loader))
        data_zip = zip(source_loader, target_loader)

        total_class_loss = 0
        total_domain_loss = 0
        total_correct = 0
        total_samples = 0

        # Dynamic lambda for gradient reversal
        p = float(epoch) / num_epochs
        lambda_ = 2. / (1. + np.exp(-10 * p)) - 1

        for i, ((source_data, source_labels), (target_data, _)) in enumerate(data_zip):
            source_data, source_labels = source_data.to(device), source_labels.to(device)
            target_data = target_data.to(device)

            # Set domain labels (0 for source, 1 for target)
            source_domain_labels = torch.zeros(source_data.size(0), 1).to(device)
            target_domain_labels = torch.ones(target_data.size(0), 1).to(device)

            # Update gradient reversal lambda
            model.domain_discriminator[0].lambda_ = lambda_

            # Forward pass
            optimizer.zero_grad()

            # Source data
            class_output, domain_output = model(source_data)
            class_loss = class_criterion(class_output, source_labels)
            domain_loss_source = domain_criterion(domain_output, source_domain_labels)

            # Target data
            _, domain_output = model(target_data)
            domain_loss_target = domain_criterion(domain_output, target_domain_labels)

            # Total loss
            total_loss = class_loss + domain_loss_source + domain_loss_target

            # Backward and optimize
            total_loss.backward()
            optimizer.step()

            # Statistics
            total_class_loss += class_loss.item()
            total_domain_loss += (domain_loss_source.item() + domain_loss_target.item())

            _, predicted = torch.max(class_output, 1)
            total_correct += (predicted == source_labels).sum().item()
            total_samples += source_labels.size(0)

        # Print epoch statistics
        avg_class_loss = total_class_loss / len_dataloader
        avg_domain_loss = total_domain_loss / (2 * len_dataloader)
        accuracy = 100. * total_correct / total_samples

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Class Loss: {avg_class_loss:.4f}, Domain Loss: {avg_domain_loss:.4f}, Source Accuracy: {accuracy:.2f}%')

In [9]:
# Evaluation function
def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in loader:
            data, labels = data.to(device), labels.to(device)
            class_output, _ = model(data)
            _, predicted = torch.max(class_output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100. * correct / total
    return accuracy

In [10]:
# Main execution
def main():
    # Configuration
    data_dir = '/root/.cache/kagglehub/datasets/karntiwari/home-office-dataset/versions/1/OfficeHomeDataset_10072016'  # Update with actual path
    source_domain = 'Clipart'
    target_domain = 'Real World'
    batch_size = 32
    num_epochs = 50
    lr = 0.001
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data loaders
    source_loader, target_loader = get_dataloaders(source_domain, target_domain, data_dir, batch_size)

    # Model, optimizer
    model = DANN(num_classes=65).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Train
    print(f'Training DANN: {source_domain} -> {target_domain}')
    train_dann(model, source_loader, target_loader, optimizer, num_epochs, device)

    # Evaluate
    print('\nEvaluating on source domain...')
    source_acc = evaluate(model, source_loader, device)
    print(f'Source ({source_domain}) Accuracy: {source_acc:.2f}%')

    print('Evaluating on target domain...')
    target_acc = evaluate(model, target_loader, device)
    print(f'Target ({target_domain}) Accuracy: {target_acc:.2f}%')

if __name__ == '__main__':
    main()

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 188MB/s]


Training DANN: Clipart -> Real World
Epoch [1/50]
Class Loss: 2.9525, Domain Loss: 0.6589, Source Accuracy: 28.22%
Epoch [2/50]
Class Loss: 1.7517, Domain Loss: 0.6577, Source Accuracy: 54.50%
Epoch [3/50]
Class Loss: 1.1531, Domain Loss: 0.6158, Source Accuracy: 67.93%
Epoch [4/50]
Class Loss: 0.8442, Domain Loss: 0.6181, Source Accuracy: 76.43%
Epoch [5/50]
Class Loss: 0.6722, Domain Loss: 0.6164, Source Accuracy: 80.48%
Epoch [6/50]
Class Loss: 0.5402, Domain Loss: 0.6274, Source Accuracy: 84.22%
Epoch [7/50]
Class Loss: 0.4616, Domain Loss: 0.6537, Source Accuracy: 86.48%
Epoch [8/50]
Class Loss: 0.4893, Domain Loss: 0.6753, Source Accuracy: 85.54%
Epoch [9/50]
Class Loss: 0.3679, Domain Loss: 0.6874, Source Accuracy: 89.12%
Epoch [10/50]
Class Loss: 0.3408, Domain Loss: 0.6679, Source Accuracy: 90.03%
Epoch [11/50]
Class Loss: 0.3823, Domain Loss: 0.6798, Source Accuracy: 89.58%
Epoch [12/50]
Class Loss: 0.2962, Domain Loss: 0.6841, Source Accuracy: 92.12%
Epoch [13/50]
Class Loss