This repository contains a TensorFlow implementation of Domain-Adversarial Training of Neural Networks (DANN). The `DANN.ipynb` notebook demonstrates experiments on the MNIST and USPS datasets: the model is trained on MNIST, and its accuracy is evaluated on USPS both with and without domain adaptation.
- DANNModel.py: Contains the implementation of the DANN model.
- grl.py: Implements the Gradient Reversal Layer, adapted from the DALN repository originally written for PyTorch.
- DALNtrain.py: Provides training utilities for the DANN model.
- DisplayLogs.py: Contains utilities for displaying and calculating accuracies during training.
- DANN.ipynb: Jupyter notebook for conducting experiments and showcasing results.
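
As an illustration of what a Gradient Reversal Layer does, here is a minimal TensorFlow sketch. It is not the code in grl.py: the function name `gradient_reversal` and the `lam` scaling parameter are assumptions made for this example. The layer passes its input through unchanged in the forward pass and multiplies the gradient by `-lam` in the backward pass.

```python
import tensorflow as tf


def gradient_reversal(x, lam=1.0):
    """Identity in the forward pass; scales the gradient by -lam in the backward pass.

    Hypothetical helper for illustration, not the repository's grl.py API.
    """

    @tf.custom_gradient
    def _reverse(inputs):
        def grad(upstream):
            # Flip the sign (and scale) of the gradient flowing back
            # towards the feature extractor.
            return -lam * upstream

        return tf.identity(inputs), grad

    return _reverse(x)


# Quick check on a toy tensor.
x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.reduce_sum(gradient_reversal(x, lam=0.5))

forward = gradient_reversal(x)
grad = tape.gradient(y, x)
print(forward.numpy())  # same values as x
print(grad.numpy())     # each entry is -0.5 instead of +1.0
```

In a DANN-style model, a layer like this would typically sit between the feature extractor and the domain classifier, so that the feature extractor is pushed to make the two domains harder to tell apart.
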
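The experimental results show that DANN substantially improves accuracy on the target domain (USPS), from 0.417 to 0.713, compared with source-only training, at the cost of lower accuracy on the source domain (MNIST), which drops from 0.9998 to 0.9341: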

| Method | Source Accuracy | Target Accuracy |
|---|---|---|
| Source only | 0.9998 | 0.417 |
| DANN | 0.9341 | 0.713 |

To train the model:

- Import Required Modules: Import TensorFlow and load the MNIST dataset, resizing images to 32x32x3 for compatibility.
- Initialize Model: Import the model from `DANNModel.py` and initialize it.
- Training Setup: Import `train` from `DALNtrain.py` and initialize a training object (`trainer`) with parameters including `X_source`, `y_source`, `model`, `batch_size`, `X_target`, `y_target`, `epochs`, and the `source_only` boolean flag.
- Run Training: Set `source_only=True` to train without domain adaptation, or `source_only=False` to train with domain adaptation.
- Prediction: Use the `predict_label` method of the model object to predict labels.
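
The preprocessing in the first step (resizing 28x28 grayscale digits to 32x32x3) could be sketched as follows. This is an assumption-laden example rather than the notebook's actual code: the helper name `to_32x32x3`, the bilinear interpolation, and the scaling to [0, 1] are all choices made here for illustration. A random dummy batch stands in for MNIST so the snippet runs without downloading anything.

```python
import numpy as np
import tensorflow as tf


def to_32x32x3(images):
    """Convert (N, 28, 28) grayscale uint8 images to (N, 32, 32, 3) float32.

    Hypothetical helper; scaling to [0, 1] is an assumption, not a
    documented requirement of the model.
    """
    x = tf.cast(images, tf.float32) / 255.0
    x = x[..., tf.newaxis]             # add a channel axis: (N, 28, 28, 1)
    x = tf.image.resize(x, [32, 32])   # bilinear interpolation by default
    x = tf.image.grayscale_to_rgb(x)   # repeat the channel: (N, 32, 32, 3)
    return x


# Dummy batch with the same shape and dtype as raw MNIST images.
dummy = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)
batch = to_32x32x3(dummy)
print(batch.shape)  # (4, 32, 32, 3)
```

With real data, the array returned by `tf.keras.datasets.mnist.load_data()` would take the place of `dummy`, and the same helper could be applied to the USPS images so that both domains share one input shape.
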
- During training, source and target accuracies are displayed.
- Alternatively, use the `accuracy_score` function from scikit-learn, or import `display_logs` from `DisplayLogs.py` for automated accuracy calculation and logging.
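
A minimal sketch of the scikit-learn route is shown below. The label arrays are toy values invented for the example, and it assumes the predictions are already integer class labels; the exact return format of `predict_label` is not documented here, so check it before plugging in real outputs.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Toy ground-truth and predicted labels (illustrative values only).
y_true = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
y_pred = np.array([0, 1, 2, 3, 4, 5, 6, 1, 8, 4])

# Fraction of predictions that match the ground truth.
acc = accuracy_score(y_true, y_pred)
print(f"target accuracy: {acc:.3f}")  # 8 of 10 correct -> 0.800
```
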
- The code in `grl.py` was adapted from the DALN repository, originally written in PyTorch.
- The architecture of the model has been simplified for computational efficiency while retaining effectiveness.
- MNIST: Handwritten digits dataset imported from `tensorflow.keras.datasets.mnist`.
- USPS: Handwritten digits dataset imported from `extra_keras.usps`.

Both datasets are small yet contain enough samples for training and evaluation, and the notable differences between them make them a suitable pair for this domain adaptation experiment.
Feel free to explore and contribute to this repository to enhance domain adaptation techniques using DANN.