This repository contains the full TensorFlow implementation of the paper "Reusing the Task-Specific Classifier as a Discriminator: Discriminator-free Adversarial Domain Adaptation".
The implementation includes experiments conducted on the MNIST and USPS datasets, demonstrated in the following notebooks:
- `DALNCustom.ipynb`: Custom experiments and analysis.
- `DALNMNISTtoUSPS.ipynb`: Experiment specifically transferring from MNIST to USPS.
The model is trained on MNIST and evaluated on the USPS dataset, both with and without domain adaptation using DALN (Discriminator-free Adversarial Learning Network).
- DALNModel.py: Implementation of the DALN model.
- grl.py and nwd.py: Implementations of the Gradient Reversal Layer (GRL) and the Nuclear-norm Wasserstein Discrepancy (NWD), adapted from the official implementation repository, which was originally written for PyTorch.
- DALNtrain.py: Utilities for training the DALN model.
- DisplayLogs.py: Utilities for logging and calculating accuracies during training.
- DALNCustom.ipynb and DALNMNISTtoUSPS.ipynb: Jupyter notebooks for conducting experiments and showcasing results.
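
To give a feel for what the discrepancy in nwd.py measures, here is a minimal NumPy sketch of a nuclear-norm discrepancy between source and target predictions. It is an illustration under assumptions, not the code in this repository: the helper names `softmax` and `nuclear_discrepancy`, the sign convention, and the division by the target batch size are all assumptions, and the TensorFlow version in nwd.py may differ.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def nuclear_discrepancy(logits_source, logits_target):
    """Hypothetical helper: difference of the nuclear norms of the two
    prediction matrices, divided by the target batch size (assumed)."""
    pred_s = softmax(logits_source)
    pred_t = softmax(logits_target)
    norm_s = np.linalg.norm(pred_s, ord="nuc")  # sum of singular values
    norm_t = np.linalg.norm(pred_t, ord="nuc")
    return (norm_s - norm_t) / logits_target.shape[0]


rng = np.random.default_rng(0)
# Sharp, nearly one-hot predictions (stand-in for a well-fitted source batch).
confident = 10.0 * np.eye(10)[rng.integers(0, 10, size=64)]
# Nearly uniform predictions (stand-in for an unadapted target batch).
uncertain = rng.normal(scale=0.1, size=(64, 10))

gap = nuclear_discrepancy(confident, uncertain)
same = nuclear_discrepancy(uncertain, uncertain)
```

In this toy setup `gap` is positive because confident, class-diverse predictions have a larger nuclear norm than nearly uniform ones, while `same` is zero because both batches are identical.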
The experimental results show that DALN roughly doubles target accuracy on USPS (from 0.417 to 0.838) compared to training without domain adaptation, at the cost of a drop in source accuracy on MNIST (from 0.9998 to 0.9341):
| Method | Source Accuracy | Target Accuracy |
|---|---|---|
| Source only | 0.9998 | 0.417 |
| DALN | 0.9341 | 0.838 |
- Without DALN: Features from the MNIST and USPS datasets appear clearly separated.
- With DALN: Features from both datasets show similar distributions, indicating successful domain adaptation.
DALN also enhances determinacy (confidence of predictions) and diversity (performance across different classes), as observed in the provided visualizations.
To train the model:
- Prepare Datasets: Import the MNIST and USPS datasets and resize them to 32x32x3 to match the model's input requirements.
- Initialize Model: Import the model from `DALNModel.py` and initialize it.
- Training Setup: Import `train` from `DALNtrain.py` and initialize a training object (`trainer`) with parameters such as `X_source`, `y_source`, `model`, `batch_size`, `X_target`, `y_target`, `epochs`, and the `source_only` boolean.
- Run Training: Set `source_only=True` to train without domain adaptation, or `source_only=False` to train with DALN.
- Predictions: Use the `predict_label` method of the model object to predict labels.
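
The dataset preparation step can be sketched as follows. This is a hedged illustration rather than the code used in the notebooks: the helper name `to_32x32x3` is hypothetical, nearest-neighbour sampling is an assumed resizing method (in TensorFlow, `tf.image.resize` would be a more usual choice), and the random arrays only stand in for MNIST-sized (28x28) and USPS-sized (16x16) images.

```python
import numpy as np


def to_32x32x3(images):
    """Hypothetical helper: resize a batch of grayscale images (N, H, W)
    to (N, 32, 32, 3) with nearest-neighbour sampling, then copy the
    single channel three times."""
    images = np.asarray(images, dtype=np.float32)
    _, height, width = images.shape
    rows = np.arange(32) * height // 32  # source row for each output row
    cols = np.arange(32) * width // 32   # source column for each output column
    resized = images[:, rows[:, None], cols[None, :]]  # shape (N, 32, 32)
    return np.repeat(resized[..., None], 3, axis=-1)   # shape (N, 32, 32, 3)


# Placeholder data; replace with the real MNIST and USPS arrays.
fake_mnist = np.random.rand(4, 28, 28)
fake_usps = np.random.rand(4, 16, 16)

X_source = to_32x32x3(fake_mnist)
X_target = to_32x32x3(fake_usps)
```

Both arrays then share the 32x32x3 shape the model expects, regardless of the original image size.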
- During training, both source and target accuracies are displayed.
- Alternatively, use the `accuracy_score` function or import `display_logs` from `DisplayLogs.py` for automated accuracy calculation and logging.
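
For reference, the accuracy reported here is the fraction of correctly predicted labels. The sketch below shows that calculation in plain NumPy; the helper name `accuracy` and the toy labels are hypothetical, and this is not the code in DisplayLogs.py.

```python
import numpy as np


def accuracy(y_true, y_pred):
    """Hypothetical helper: fraction of predictions that match the labels."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape")
    return float(np.mean(y_true == y_pred))


# Toy labels: four of the five predictions are correct.
y_true = [3, 1, 4, 1, 5]
y_pred = [3, 1, 4, 0, 5]
target_accuracy = accuracy(y_true, y_pred)
```

With real data, `y_pred` would be the output of the model's `predict_label` method on the target images.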
- The code in `grl.py` and `nwd.py` was adapted from the DALN repository, originally written in PyTorch.
- The model architecture has been optimized for computational efficiency while maintaining effectiveness.
- MNIST: Handwritten digits dataset imported from `tensorflow.keras.datasets.mnist`.
- USPS: Handwritten digits dataset imported from `extra_keras.usps`.
Both datasets are small but contain enough samples for training, and the notable differences between them make the pair well suited to domain adaptation experiments.
Feel free to explore and contribute to this repository to enhance domain adaptation techniques using DALN.