# Domain Adversarial Networks (DAN / DANN)

## 1. Definition

A **Domain Adversarial Network (DANN)** is a neural network that learns **domain-invariant features**, so that a model trained on a labeled **source domain** generalizes to a **target domain** for which no labels are available.

Given:

- Source data: $(x_s, y_s) \sim p_s(x, y)$  
- Target data: $x_t \sim p_t(x)$ (no labels)  
- Feature extractor: $F(x)$  
- Label predictor: $C(F(x))$  
- Domain discriminator: $D(F(x))$

The idea is:

- $C$ is trained to predict labels on the source domain.  
- $D$ is trained to distinguish whether features come from source or target.  
- $F$ is trained **adversarially** to fool $D$ so that source and target features become indistinguishable.

Thus $F$ becomes **domain-invariant**.

---

## 2. Objective Function

Two goals:

1. **Label classification loss** on source domain  
   $ \mathcal{L}_{cls} = \mathbb{E}_{(x_s,y_s)}\left[ \ell( C(F(x_s)),\, y_s ) \right] $

2. **Domain adversarial loss** using binary classification  
   $ \mathcal{L}_{dom} = \mathbb{E}_{x_s}[\log D(F(x_s))] + \mathbb{E}_{x_t}[\log (1 - D(F(x_t)))] $

Final DANN objective:

$ \min_{F,C} \max_D \; \mathcal{L}_{cls} + \lambda \mathcal{L}_{dom} $

Here:

- $D$ maximizes $\mathcal{L}_{dom}$, the log-likelihood of the correct domain labels  
- $F$ minimizes $\mathcal{L}_{cls}$ and **minimizes** $\mathcal{L}_{dom}$, i.e., it **fools $D$**  
- $C$ minimizes $\mathcal{L}_{cls}$  
- $\lambda$ controls adaptation strength  
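
As a rough numeric illustration of the two loss terms, the sketch below evaluates them on hand-picked outputs. The probabilities and the helper names `cls_loss` and `dom_loss` are invented for this example and are not part of the notebook's code.

```python
import math

def cls_loss(probs, labels):
    """Mean cross-entropy; probs[i][k] is the predicted probability of class k."""
    return -sum(math.log(p[y]) for p, y in zip(probs, labels)) / len(labels)

def dom_loss(d_source, d_target):
    """L_dom = E_s[log D] + E_t[log(1 - D)], where D = P(domain is source)."""
    src = sum(math.log(d) for d in d_source) / len(d_source)
    tgt = sum(math.log(1.0 - d) for d in d_target) / len(d_target)
    return src + tgt

# Hypothetical classifier outputs for two source samples with labels 0 and 1.
l_cls = cls_loss([[0.9, 0.1], [0.2, 0.8]], [0, 1])

# A confident discriminator versus one that is fully confused (D = 0.5 everywhere).
l_dom_confident = dom_loss([0.9, 0.8], [0.1, 0.2])
l_dom_confused = dom_loss([0.5, 0.5], [0.5, 0.5])

print(f"L_cls = {l_cls:.4f}")
print(f"L_dom (confident D) = {l_dom_confident:.4f}")
print(f"L_dom (confused D)  = {l_dom_confused:.4f}")  # equals -log 4
```

The confident discriminator attains the larger value, consistent with $D$ maximizing $\mathcal{L}_{dom}$, while a fully confused discriminator sits at $-\log 4 \approx -1.386$.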

---

## 3. Derivation: Why Adversarial Training Leads to Domain Invariance

### Step 1: Domain classifier objective

The discriminator tries to classify domain labels:

- Domain label $d = 1$ for source  
- Domain label $d = 0$ for target  

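For a fixed $F$, the integrand $p_s(z)\log D(z) + p_t(z)\log(1 - D(z))$, with $z = F(x)$, can be maximized pointwise in $D(z)$, so the optimal discriminator is:

$ D^\*(F(x)) = \frac{p_s(F(x))}{p_s(F(x)) + p_t(F(x))} $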

---

### Step 2: Substitute $D^\*$ into adversarial loss

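Substituting $D^\*$ gives the value of the domain loss at the optimal discriminator, i.e., the negative of the binary cross-entropy that $D^\*$ attains:

$ \mathcal{L}_{dom}(F) = -\log 4 + 2 \cdot \text{JS}(p_s(F(x)) \;\|\; p_t(F(x))) $

Since $\text{JS} \ge 0$, with equality only when the two distributions coincide, minimizing the domain loss w.r.t. $F$ gives:

$ F^\* = \arg\min_F \text{JS}(p_s(F(x)) \;\|\; p_t(F(x))) $

This means the feature extractor learns:

$ p_s(F(x)) \approx p_t(F(x)) $

⇒ **source and target feature distributions align**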
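
The identity in Step 2 can be checked numerically on small discrete distributions. The sketch below is only an illustration: the three-bin "feature distributions" `p_s` and `p_t` and the helper names are assumptions made for this example.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions (natural log); terms with p_i = 0 are skipped."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js(p, q):
    """Jensen-Shannon divergence via the mixture m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def dom_loss_at_optimum(p_s, p_t):
    """E_s[log D*] + E_t[log(1 - D*)] with D* = p_s / (p_s + p_t)."""
    total = 0.0
    for ps, pt in zip(p_s, p_t):
        d_star = ps / (ps + pt)
        total += ps * math.log(d_star) + pt * math.log(1.0 - d_star)
    return total

# Hypothetical feature distributions over three bins (all entries positive).
p_s = [0.7, 0.2, 0.1]
p_t = [0.1, 0.3, 0.6]

lhs = dom_loss_at_optimum(p_s, p_t)
rhs = -math.log(4) + 2 * js(p_s, p_t)
print(f"L_dom at D*     = {lhs:.6f}")
print(f"-log 4 + 2 * JS = {rhs:.6f}")

# With identical distributions JS is zero and the loss drops to -log 4.
aligned = dom_loss_at_optimum(p_s, p_s)
print(f"aligned features = {aligned:.6f}")
```

The two printed values for the mismatched case agree, and the aligned case reaches the lower bound $-\log 4$.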

---

### Step 3: Final objective

Plug back into the joint objective:

$ \min_{F,C} \max_D \left( \mathcal{L}_{cls} + \lambda \mathcal{L}_{dom} \right) $

This yields a saddle point:

- Minimization w.r.t. $F$ reduces classification loss  
- Minimization w.r.t. $F$ also forces domain feature alignment  
- Maximization w.r.t. $D$ improves domain discrimination  

---

## 4. Gradient Reversal Layer (GRL)

To implement adversarial training efficiently, DANN uses a **gradient reversal layer**:

- Forward pass: identity  
  $ \text{GRL}(h) = h $

- Backward pass: multiplies gradient by $- \lambda$  
  $ \frac{\partial \text{GRL}}{\partial h} = -\lambda I $

As a result:

- $D$ receives normal gradients  
- $F$ receives reversed gradients, forcing domain confusion  
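
The reversal can be observed directly on a toy tensor. This is a minimal sketch assuming PyTorch is installed; the class name `ReverseGrad` and the value `0.5` for $\lambda$ are chosen only for this illustration.

```python
import torch
from torch.autograd import Function

class ReverseGrad(Function):
    """Identity in the forward pass, gradient scaled by -lam in the backward pass."""

    @staticmethod
    def forward(ctx, h, lam):
        ctx.lam = lam
        return h.view_as(h)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; lam is a plain float, so it gets None.
        return -ctx.lam * grad_output, None

h = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
out = ReverseGrad.apply(h, 0.5)

# d(sum)/dh would be +1 per entry without the layer.
out.sum().backward()

print(out.detach())  # same values as h
print(h.grad)        # every entry is -0.5
```

Everything placed after the layer (here just the sum, in DANN the discriminator) still sees the ordinary gradient; only what flows back into `h` is flipped and scaled.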

---

## 5. Summary

- DANN learns **domain-invariant features** using adversarial training.  
- Objective:  
  $ \min_{F,C} \max_D \left( \mathcal{L}_{cls} + \lambda \mathcal{L}_{dom} \right) $  
- Domain alignment minimizes the Jensen–Shannon divergence between feature distributions.  
- Implemented using the Gradient Reversal Layer.  
- Widely used in **unsupervised domain adaptation**.



In [None]:
import torch
from torchvision import datasets
from torchvision import transforms
import os


def get_loader(args):
    if args.dset == 's2m':
        svhn_tr = transforms.Compose([transforms.Resize([32, 32]),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.5], [0.5])])
        s_train = datasets.SVHN(os.path.join(args.data_path, 'svhn'), split='train', download=True, transform=svhn_tr)
        s_test = datasets.SVHN(os.path.join(args.data_path, 'svhn'), split='test', download=True, transform=svhn_tr)

        mnist_tr = transforms.Compose([transforms.Resize([32, 32]),
                                       transforms.Grayscale(3),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.5], [0.5])])
        t_train = datasets.MNIST(os.path.join(args.data_path, 'mnist'), train=True, download=True, transform=mnist_tr)
        t_test = datasets.MNIST(os.path.join(args.data_path, 'mnist'), train=False, download=True, transform=mnist_tr)

    elif args.dset == 'u2m':
        tr = transforms.Compose([transforms.Resize([32, 32]),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.5], [0.5])])
        s_train = datasets.USPS(os.path.join(args.data_path, 'usps'), train=True, download=True, transform=tr)
        s_test = datasets.USPS(os.path.join(args.data_path, 'usps'), train=False, download=True, transform=tr)

        t_train = datasets.MNIST(os.path.join(args.data_path, 'mnist'), train=True, download=True, transform=tr)
        t_test = datasets.MNIST(os.path.join(args.data_path, 'mnist'), train=False, download=True, transform=tr)

    elif args.dset == 'm2u':
        tr = transforms.Compose([transforms.Resize([32, 32]),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.5], [0.5])])
        s_train = datasets.MNIST(os.path.join(args.data_path, 'mnist'), train=True, download=True, transform=tr)
        s_test = datasets.MNIST(os.path.join(args.data_path, 'mnist'), train=False, download=True, transform=tr)

        t_train = datasets.USPS(os.path.join(args.data_path, 'usps'), train=True, download=True, transform=tr)
        t_test = datasets.USPS(os.path.join(args.data_path, 'usps'), train=False, download=True, transform=tr)

    elif args.dset == 'm2mm':
        mnist_tr = transforms.Compose([transforms.Resize([32, 32]),
                                       transforms.Grayscale(3),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.5], [0.5])])

        s_train = datasets.MNIST(os.path.join(args.data_path, 'mnist'), train=True, download=True, transform=mnist_tr)
        s_test = datasets.MNIST(os.path.join(args.data_path, 'mnist'), train=False, download=True, transform=mnist_tr)

        mnistm_tr = transforms.Compose([transforms.Resize([32, 32]),
                                        transforms.ToTensor(),
                                        transforms.Normalize([0.5], [0.5])])

        t_train = datasets.ImageFolder(root=os.path.join(args.data_path, 'mnistm', 'trainset'), transform=mnistm_tr)
        t_test = datasets.ImageFolder(root=os.path.join(args.data_path, 'mnistm', 'testset'), transform=mnistm_tr)

    elif args.dset == 'sd2sv':
        tr = transforms.Compose([transforms.Resize([32, 32]),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.5], [0.5])])

        s_train = datasets.ImageFolder(root=os.path.join(args.data_path, 'sydigits', 'trainset'), transform=tr)
        s_test = datasets.ImageFolder(root=os.path.join(args.data_path, 'sydigits', 'trainset'), transform=tr)  # Does not have a testset

        t_train = datasets.SVHN(os.path.join(args.data_path, 'svhn'), split='train', download=True, transform=tr)
        t_test = datasets.SVHN(os.path.join(args.data_path, 'svhn'), split='test', download=True, transform=tr)

    elif args.dset == 'signs':
        tr = transforms.Compose([transforms.Resize([32, 32]),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.5], [0.5])])

        s_train = datasets.ImageFolder(root=os.path.join(args.data_path, 'sysigns', 'trainset'), transform=tr)
        s_test = datasets.ImageFolder(root=os.path.join(args.data_path, 'sysigns', 'trainset'), transform=tr)  # Does not have a testset

        t_train = datasets.ImageFolder(root=os.path.join(args.data_path, 'gtsrb', 'trainset'), transform=tr)
        t_test = datasets.ImageFolder(root=os.path.join(args.data_path, 'gtsrb', 'testset'), transform=tr)

    s_train_loader = torch.utils.data.DataLoader(dataset=s_train,
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 num_workers=args.num_workers,
                                                 drop_last=True)

    s_test_loader = torch.utils.data.DataLoader(dataset=s_test,
                                                batch_size=args.batch_size * 2,
                                                shuffle=False,
                                                num_workers=args.num_workers,
                                                drop_last=False)

    t_train_loader = torch.utils.data.DataLoader(dataset=t_train,
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 num_workers=args.num_workers,
                                                 drop_last=True)

    t_test_loader = torch.utils.data.DataLoader(dataset=t_test,
                                                batch_size=args.batch_size * 2,
                                                shuffle=False,
                                                num_workers=args.num_workers,
                                                drop_last=False)

    return s_train_loader, s_test_loader, t_train_loader, t_test_loader

In [None]:
import math
from torch.autograd import Function


class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, lamda):
        ctx.lamda = lamda
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = (grad_output.neg() * ctx.lamda)
        return output, None


def adjust_alpha(i, epoch, min_len, nepochs):
    p = float(i + epoch * min_len) / nepochs / min_len
    o = -10
    alpha = 2. / (1. + math.exp(o * p)) - 1
    return alpha

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch



class encoder(nn.Module):
    def __init__(self, args):
        super(encoder, self).__init__()
        conv_dim = 32

        self.conv1 = nn.Conv2d(args.channels, conv_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(conv_dim, conv_dim, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(2, stride=2)

        self.conv4 = nn.Conv2d(conv_dim, conv_dim * 2, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(conv_dim * 2, conv_dim * 2, kernel_size=3, padding=1)
        self.pool6 = nn.MaxPool2d(2, stride=2)

        self.conv7 = nn.Conv2d(conv_dim * 2, conv_dim * 4, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(conv_dim * 4, conv_dim * 4, kernel_size=3, padding=1)
        self.pool9 = nn.MaxPool2d(2, stride=2)

        self.flat_dim = 4 * 4 * conv_dim * 4
        self.fc1 = nn.Linear(self.flat_dim, 128)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool3(x)

        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = self.pool6(x)

        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))
        x = self.pool9(x)

        x = x.view(-1, self.flat_dim)
        x = F.relu(self.fc1(x))
        return x


class classifier(nn.Module):
    def __init__(self, args):
        super(classifier, self).__init__()
        self.fc1 = nn.Linear(128, args.num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        x = self.fc1(x)
        return x


class discriminator(nn.Module):
    def __init__(self, args):
        super(discriminator, self).__init__()
        self.args = args

        self.l1 = nn.Linear(128, 500)
        self.l2 = nn.Linear(500, 500)
        self.l3 = nn.Linear(500, 1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0.0, 0.02)

    def forward(self, x, alpha=-1):
        if self.args.method.lower() == 'dann':
            x = GradReverse.apply(x, alpha)
        x = F.leaky_relu(self.l1(x), 0.2)
        x = F.leaky_relu(self.l2(x), 0.2)
        x = torch.sigmoid(self.l3(x))
        return x

In [None]:
import torch
import torch.nn as nn
import os
from torch import optim

from sklearn.metrics import confusion_matrix, accuracy_score


import copy
import torch.nn.functional as F
from torchsummary import summary

class Solver(object):
    def __init__(self, args):
        self.args = args

        self.s_train_loader, self.s_test_loader, self.t_train_loader, self.t_test_loader = get_loader(args)

        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCELoss()

        self.best_acc = 0
        self.time_taken = None

        self.enc = encoder(self.args).cuda()
        self.clf = classifier(self.args).cuda()
        self.fd = discriminator(self.args).cuda()

        print('--------Network--------')
        print(self.enc)
        print(self.clf)

        print('--------Feature Disc--------')
        print(self.fd)

        self.fake_label = torch.FloatTensor(self.args.batch_size, 1).fill_(0).cuda()
        self.real_label = torch.FloatTensor(self.args.batch_size, 1).fill_(1).cuda()

        if not args.method == 'src':
            if os.path.exists(os.path.join(self.args.model_path, 'src_enc.pt')):
                print("Loading Source model...")
                self.enc.load_state_dict(torch.load(os.path.join(self.args.model_path, 'src_enc.pt')))
                self.clf.load_state_dict(torch.load(os.path.join(self.args.model_path, 'src_clf.pt')))
            else:
                print("Training Source model...")
                self.src()
                self.test()

    def test_dataset(self, db='t_test'):
        self.enc.eval()
        self.clf.eval()

        actual = []
        pred = []

        if db.lower() == 's_train':
            loader = self.s_train_loader
        elif db.lower() == 's_test':
            loader = self.s_test_loader
        elif db.lower() == 't_train':
            loader = self.t_train_loader
        else:
            loader = self.t_test_loader

        for data in loader:
            img, label = data

            img = img.cuda()

            with torch.no_grad():
                class_out = self.clf(self.enc(img))
            _, predicted = torch.max(class_out.data, 1)

            actual += label.tolist()
            pred += predicted.tolist()

        acc = accuracy_score(y_true=actual, y_pred=pred) * 100
        cm = confusion_matrix(y_true=actual, y_pred=pred, labels=range(self.args.num_classes))

        return acc, cm

    def test(self):
        s_train_acc, cm = self.test_dataset('s_train')
        print("Source Tr Acc: %.2f" % (s_train_acc))
        if self.args.cm:
            print(cm)

        s_test_acc, cm = self.test_dataset('s_test')
        print("Source Te Acc: %.2f" % (s_test_acc))
        if self.args.cm:
            print(cm)

        t_train_acc, cm = self.test_dataset('t_train')
        print("Target Tr Acc: %.2f" % (t_train_acc))
        if self.args.cm:
            print(cm)

        t_test_acc, cm = self.test_dataset('t_test')
        print("Target Te Acc: %.2f" % (t_test_acc))
        if self.args.cm:
            print(cm)

        return s_train_acc, s_test_acc, t_train_acc, t_test_acc

    def src(self):
        total_iters = 0
        self.best_acc = 0
        s_iter_per_epoch = len(self.s_train_loader)  # number of batches; no need to build an iterator
        self.args.src_test_epoch = max(self.args.src_epochs // 10, 1)

        self.optimizer = optim.Adam(list(self.enc.parameters()) + list(self.clf.parameters()), self.args.lr,
                                    betas=[0.5, 0.999], weight_decay=self.args.weight_decay)

        for epoch in range(self.args.src_epochs):

            self.clf.train()
            self.enc.train()

            for i, (source, s_labels) in enumerate(self.s_train_loader):
                total_iters += 1

                source, s_labels = source.cuda(), s_labels.cuda()

                s_logits = self.clf(self.enc(source))
                s_clf_loss = self.ce(s_logits, s_labels)
                loss = s_clf_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if i % 50 == 0 or i == (s_iter_per_epoch - 1):
                    # .item() gives a Python float and avoids the "tensor.detach()" warning
                    print('Ep: %d/%d, iter: %d/%d, total_iters: %d, s_err: %.4f'
                          % (epoch + 1, self.args.src_epochs, i + 1, s_iter_per_epoch, total_iters, s_clf_loss.item()))

            if (epoch + 1) % self.args.src_test_epoch == 0:
                s_test_acc, cm = self.test_dataset('s_test')
                print("Source test acc: %0.2f" % (s_test_acc))
                if self.args.cm:
                    print(cm)

                if s_test_acc > self.best_acc:
                    self.best_acc = s_test_acc
                    best_enc = copy.deepcopy(self.enc.state_dict())
                    best_clf = copy.deepcopy(self.clf.state_dict())

        torch.save(best_enc, os.path.join(self.args.model_path, 'src_enc.pt'))
        torch.save(best_clf, os.path.join(self.args.model_path, 'src_clf.pt'))

        self.enc.load_state_dict(best_enc)
        self.clf.load_state_dict(best_clf)

    def dann(self):

        s_iter_per_epoch = len(self.s_train_loader)
        t_iter_per_epoch = len(self.t_train_loader)
        min_len = min(s_iter_per_epoch, t_iter_per_epoch)
        total_iters = 0

        print("Source iters per epoch: %d" % (s_iter_per_epoch))
        print("Target iters per epoch: %d" % (t_iter_per_epoch))
        print("iters per epoch: %d" % (min(s_iter_per_epoch, t_iter_per_epoch)))

        self.optimizer = optim.Adam(list(self.enc.parameters()) + list(self.clf.parameters()) + list(self.fd.parameters()), self.args.lr,
                                      betas=[0.5, 0.999], weight_decay=self.args.weight_decay)

        for epoch in range(self.args.adapt_epochs):
            self.clf.train()
            self.enc.train()
            self.fd.train()

            for i, (source_data, target_data) in enumerate(zip(self.s_train_loader, self.t_train_loader)):
                total_iters += 1
                alpha = adjust_alpha(i, epoch, min_len, self.args.adapt_epochs)

                source, s_labels = source_data
                source, s_labels = source.cuda(), s_labels.cuda()

                target, t_labels = target_data
                target, t_labels = target.cuda(), t_labels.cuda()

                s_deep = self.enc(source)
                s_out = self.clf(s_deep)

                t_deep = self.enc(target)
                t_out = self.clf(t_deep)

                s_fd_out = self.fd(s_deep, alpha=alpha)
                t_fd_out = self.fd(t_deep, alpha=alpha)

                s_domain_err = self.bce(s_fd_out, self.real_label)
                t_domain_err = self.bce(t_fd_out, self.fake_label)
                disc_loss = s_domain_err + t_domain_err

                s_clf_loss = self.ce(s_out, s_labels)

                loss = s_clf_loss + disc_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if i % 50 == 0 or i == (min_len - 1):
                    # .item() gives Python floats and avoids the "tensor.detach()" warning
                    print('Ep: %d/%d, iter: %d/%d, total_iters: %d, s_err: %.4f, d_err: %.4f, alpha: %.4f'
                          % (epoch + 1, self.args.adapt_epochs, i + 1, min_len, total_iters, s_clf_loss.item(), disc_loss.item(), alpha))

            if (epoch + 1) % self.args.adapt_test_epoch == 0:
                t_test_acc, cm = self.test_dataset('t_test')
                print("Target test acc: %0.2f" % (t_test_acc))
                if self.args.cm:
                    print(cm)

        torch.save(self.enc.state_dict(), os.path.join(self.args.model_path, 'dann_enc.pt'))
        torch.save(self.clf.state_dict(), os.path.join(self.args.model_path, 'dann_clf.pt'))
        torch.save(self.fd.state_dict(), os.path.join(self.args.model_path, 'dann_disc.pt'))

In [None]:
import os
import torch
import random
import numpy as np
import datetime


# ------------------------------
# ❗ Your Solver should be defined elsewhere or imported.
# from solver import Solver
# ------------------------------


class AttrDict(dict):
    """Make args.x attribute-style accessible"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__


def update_args(args):
    args.adapt_epochs = 200
    args.channels = 3
    args.num_classes = 10
    args.cm = True

    if args.dset == 's2m':
        args.source = 'svhn'
        args.target = 'mnist'

    elif args.dset == 'u2m':
        args.source = 'usps'
        args.target = 'mnist'
        args.channels = 1
        args.adapt_epochs = 1000

    elif args.dset == 'm2u':
        args.source = 'mnist'
        args.target = 'usps'
        args.channels = 1
        args.adapt_epochs = 1000

    elif args.dset == 'm2mm':
        args.source = 'mnist'
        args.target = 'mnistm'

    elif args.dset == 'sd2sv':
        args.source = 'sydigits'
        args.target = 'svhn'

    elif args.dset == 'signs':
        args.source = 'sysigns'
        args.target = 'gtsrb'
        args.num_classes = 43
        args.cm = False

    else:
        raise ValueError("Incorrect dataset combination")

    args.model_path = os.path.join(args.model_path, args.dset)
    args.adapt_test_epoch = args.adapt_epochs // 10

    return args


def print_args(args):
    for k, v in dict(sorted(args.items())).items():
        print(f"{k}: {v}")
    print()


def main(args):
    os.makedirs(args.model_path, exist_ok=True)

    solver = Solver(args)

    if args.method == 'src':
        solver.src()
    elif args.method == 'dann':
        solver.dann()

    solver.test()


# ---------------------------------------
# 🔥 COLAB CONFIG: modify as needed
# ---------------------------------------

args = AttrDict({
    'p_thresh': 0.9,
    'method': 'src',       # or "dann"
    'src_epochs': 100,
    'batch_size': 128,
    'num_workers': 2,
    'lr': 1e-4,
    'weight_decay': 1e-5,
    'log_step': 50,

    'dset': 's2m',  # choose among: s2m, u2m, m2u, m2mm, sd2sv, signs
    'data_path': './data/',
    'model_path': './model',
    'seed': 100,
})

args = update_args(args)

# ---------------------------------------
# 🔧 SEED SETUP (Colab Safe)
# ---------------------------------------

manual_seed = args.seed
random.seed(manual_seed)
torch.manual_seed(manual_seed)
np.random.seed(manual_seed)
os.environ['PYTHONHASHSEED'] = str(manual_seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(manual_seed)
    torch.cuda.manual_seed_all(manual_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ---------------------------------------
# 📌 Print config in Colab
# ---------------------------------------

print("Training configuration:")
print_args(args)

start_time = datetime.datetime.now()
print("Started at:", start_time.strftime('%Y-%m-%d %H:%M:%S'))

# ---------------------------------------
# 🚀 RUN TRAINING
# ---------------------------------------

main(args)

end_time = datetime.datetime.now()
duration = end_time - start_time

print("Ended at:", end_time.strftime('%Y-%m-%d %H:%M:%S'))
print("Duration:", duration)


Training configuration:
adapt_epochs: 200
adapt_test_epoch: 20
batch_size: 128
channels: 3
cm: True
data_path: ./data/
dset: s2m
log_step: 50
lr: 0.0001
method: src
model_path: ./model/s2m
num_classes: 10
num_workers: 2
p_thresh: 0.9
seed: 100
source: svhn
src_epochs: 100
target: mnist
weight_decay: 1e-05

Started at: 2025-12-04 19:36:29


100%|██████████| 182M/182M [00:37<00:00, 4.83MB/s]
100%|██████████| 64.3M/64.3M [00:14<00:00, 4.33MB/s]
100%|██████████| 9.91M/9.91M [00:04<00:00, 2.07MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 131kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.24MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.10MB/s]


--------Network--------
encoder(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=2048, out_features=128, bias=True)
)
classifier(
  (fc1): Linear(in_features=128, out_features=10, bias=True)
)
--------Feature Disc--------
discriminator(
  (l1): Linear(in_features=128, out_features=500, bias=True)


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  print('Ep: %d/%d, iter: %d/%d, total_iters: %d, s_err: %.4f'


Ep: 1/100, iter: 1/572, total_iters: 1, s_err: 3.0055
Ep: 1/100, iter: 51/572, total_iters: 51, s_err: 2.2779
Ep: 1/100, iter: 101/572, total_iters: 101, s_err: 2.1377
Ep: 1/100, iter: 151/572, total_iters: 151, s_err: 1.9433
Ep: 1/100, iter: 201/572, total_iters: 201, s_err: 1.7359
Ep: 1/100, iter: 251/572, total_iters: 251, s_err: 1.3909
Ep: 1/100, iter: 301/572, total_iters: 301, s_err: 1.2097
Ep: 1/100, iter: 351/572, total_iters: 351, s_err: 0.9503
Ep: 1/100, iter: 401/572, total_iters: 401, s_err: 1.1261
Ep: 1/100, iter: 451/572, total_iters: 451, s_err: 0.8767
Ep: 1/100, iter: 501/572, total_iters: 501, s_err: 0.6626
Ep: 1/100, iter: 551/572, total_iters: 551, s_err: 0.6652
Ep: 1/100, iter: 572/572, total_iters: 572, s_err: 0.5190
Ep: 2/100, iter: 1/572, total_iters: 573, s_err: 0.6120
Ep: 2/100, iter: 51/572, total_iters: 623, s_err: 0.5747
Ep: 2/100, iter: 101/572, total_iters: 673, s_err: 0.4228
Ep: 2/100, iter: 151/572, total_iters: 723, s_err: 0.5365
Ep: 2/100, iter: 201/57

In [None]:
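# ----------------------------------------
# ✨ AFTER main(args)
# ----------------------------------------

from torchvision import datasets, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load MNIST test images with the same preprocessing get_loader applies to MNIST
transform = transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.Grayscale(args.channels),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

mnist_test = datasets.MNIST(os.path.join(args.data_path, 'mnist'), train=False, transform=transform, download=True)
loader = torch.utils.data.DataLoader(mnist_test, batch_size=16, shuffle=True)

# Fetch a batch
inputs, labels = next(iter(loader))
inputs = inputs.to(device)

# Solver has no predict() method and main() does not return the solver,
# so rebuild the networks and load the weights saved by src() or dann().
prefix = 'dann' if args.method == 'dann' else 'src'
enc = encoder(args).to(device)
clf = classifier(args).to(device)
enc.load_state_dict(torch.load(os.path.join(args.model_path, prefix + '_enc.pt'), map_location=device))
clf.load_state_dict(torch.load(os.path.join(args.model_path, prefix + '_clf.pt'), map_location=device))
enc.eval()
clf.eval()

# Run prediction
with torch.no_grad():
    pred = clf(enc(inputs)).argmax(dim=1).cpu()

# Undo the normalization on one channel so the images can be shown in grayscale
imgs = inputs[:, :1] * 0.5 + 0.5

print("Ground truth labels:", labels[:10])
print("Predictions:", pred[:10])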

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(8,2))
for i in range(8):
    plt.subplot(1,8,i+1)
    plt.imshow(imgs[i].cpu().squeeze(), cmap="gray")
    plt.title(f"P={pred[i].item()}")
    plt.axis("off")
plt.show()
