Towards Lifelong Self-Supervision For Unpaired Image-to-Image Translation
This repo contains the PyTorch code for the paper Towards Lifelong Self-Supervision For Unpaired Image-to-Image Translation.
It is not very user-friendly (yet?) and some experimental tricks / comments are still lying around.
The codebase comes from CycleGAN's PyTorch repo. The main differences are the use of comet.ml to log experiments and LiSS-specific command-line arguments.
Getting Started
These instructions will help you get the project up and running on your local machine using your own data.
Prerequisites
Install the library requirements from the requirements.txt file:
pip3 install -r requirements.txt
Data
The data should be structured as in CycleGAN. Additionally, if you want to use depth as an auxiliary task, you need sub-folders data/trainA/depths (and similarly for all other data folders) holding inferences from MegaDepth.
Here is a version of the Horse<>Zebra dataset with depth.
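For instance, assuming the standard CycleGAN folder names (trainA / trainB / testA / testB; only data/trainA/depths is named above, the other depths sub-folders mirror it), a dataset with depth targets would look like:

data/
├── trainA/
│   └── depths/   # MegaDepth inferences for the trainA images
├── trainB/
│   └── depths/
├── testA/
│   └── depths/
└── testB/
    └── depths/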
Generated Horse to Zebra
Comparison of models on the horse -> zebra dataset. Rows correspond to the input image (row 0) and translations from CycleGAN (row 1), LiSS CycleGAN (row 2), the parallel schedule (row 3) and the sequential schedule (row 4). Note: the slight discrepancy in cropping is due to data-loading randomness.
Code
If you want to understand the code in liss_model.py, you should follow this outline (see the sketches after this list):

- Training and validation procedures rely on the AuxiliaryTasks class to index, sort and parametrize tasks (see --auxiliary_tasks and models/task.py).
- According to train.py, the typical update step is model.set_input(data) followed by model.optimize_parameters(), which itself calls, in order: model.forward(), model.backward_G(), model.backward_D().
- The model uses self.should_compute(task_key) to know which head/task should be trained / back-propped through at any given time.
- Schedules are updated according to validation metrics with model.update_task_schedule(metrics), which is set by model.init_schedule().
- The reference encoder is updated in model.update_ref_encoder().
- Each generator netG_A or netG_B has an encoder attribute and one head per task; decoder is the translation head, so that decoder(encoder(x)) is CycleGAN's default generator.
- Be careful with PyTorch's nn.DataParallel interface: to access netG_A's encoder, for instance, you should use netG_A.encoder on cpu and single gpu, but netG_A.module.encoder on a nn.DataParallel device (the same goes for all other attributes and tensors).
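As a reference, here is a minimal sketch of that update loop, assuming the upstream CycleGAN plumbing (TrainOptions, create_dataset, create_model and the niter / niter_decay epoch options); evaluate_on_validation_set is a hypothetical stand-in for whatever computes the validation metrics:

from options.train_options import TrainOptions
from data import create_dataset
from models import create_model

opt = TrainOptions().parse()
dataset = create_dataset(opt)  # unaligned A/B dataset, CycleGAN-style
model = create_model(opt)      # the LiSS model when selected via --model
model.setup(opt)

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    for data in dataset:
        model.set_input(data)        # unpack images (and auxiliary targets)
        model.optimize_parameters()  # forward() -> backward_G() -> backward_D(),
                                     # gated per task by should_compute(task_key)
    metrics = evaluate_on_validation_set(model)  # hypothetical helper
    model.update_task_schedule(metrics)          # advance the task schedule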
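And a small helper, not in the repo, that illustrates the nn.DataParallel caveat from the last bullet:

import torch.nn as nn

def get_encoder(netG):
    # nn.DataParallel wraps the network, so attributes live under .module
    if isinstance(netG, nn.DataParallel):
        return netG.module.encoder
    return netG.encoder

# get_encoder(model.netG_A) then works on cpu, single gpu and multi-gpu alike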
LiSS Arguments
Compared to CycleGAN, LiSS uses RAdam to optimize parameters and adds the following arguments (in liss_model.py):
parser.add_argument(
    "--lambda_CA",
    type=float,
    default=10.0,
    help="weight for cycle loss (A -> B -> A)",
)
parser.add_argument(
    "--lambda_CB",
    type=float,
    default=10.0,
    help="weight for cycle loss (B -> A -> B)",
)
parser.add_argument(
    "--lambda_DA",
    type=float,
    default=1.0,
    help="weight for discriminator loss (A -> B -> A)",
)
parser.add_argument(
    "--lambda_DB",
    type=float,
    default=1.0,
    help="weight for discriminator loss (B -> A -> B)",
)
parser.add_argument(
    "--lambda_R", type=float, default=1.0, help="weight for rotation"
)
parser.add_argument(
    "--lambda_J", type=float, default=1.0, help="weight for jigsaw"
)
parser.add_argument(
    "--lambda_D", type=float, default=1.0, help="weight for depth"
)
parser.add_argument(
    "--lambda_G", type=float, default=1.0, help="weight for gray"
)
parser.add_argument(
    "--lambda_DR",
    type=float,
    default=1.0,
    help="weight for rotation loss in discriminator (see SSGAN minimax)",
)
parser.add_argument(
    "--lambda_distillation",
    type=float,
    default=5.0,
    help="weight for distillation loss (when repr_mode == 'distillation')",
)
parser.add_argument(
    "--lambda_I",
    type=float,
    default=0.5,
    help="use identity mapping. Setting lambda_I other than 0 has the "
    + "effect of scaling the weight of the identity mapping loss. For "
    + "example, if the weight of the identity loss should be 10 times "
    + "smaller than the weight of the reconstruction loss, set "
    + "lambda_I = 0.1",
)
parser.add_argument(
    "--task_schedule", type=str, default="parallel", help="task schedule"
)
# sequential       : <rotation> then <depth> then <translation>
#                    without the possibility to come back
#
# parallel         : always <depth, rotation, translation>
#
# additional       : <rotation> then <depth, rotation> then
#                    <depth, rotation, translation>
#
# liss             : sequential with distillation
#
# representational : <rotation, depth, identity> then <translation>
parser.add_argument(
    "--rotation_acc_threshold",
    type=float,
    default=0.2,
    help="minimal rotation classification accuracy to switch tasks",
)
parser.add_argument(
    "--jigsaw_acc_threshold",
    type=float,
    default=0.2,
    help="minimal jigsaw classification accuracy to switch tasks",
)
parser.add_argument(
    "--depth_loss_threshold",
    type=float,
    default=0.5,
    help="depth estimation loss below which to switch tasks",
)
parser.add_argument(
    "--gray_loss_threshold",
    type=float,
    default=0.5,
    help="gray loss below which to switch tasks",
)
parser.add_argument(
    "--i_loss_threshold",
    type=float,
    default=0.5,
    help="identity loss below which to switch tasks (representational only)",
)
parser.add_argument(
    "--lr_rotation",
    type=float,
    default=0,
    help="specific learning rate for the rotation task",
)
parser.add_argument(
    "--lr_depth",
    type=float,
    default=0,
    help="specific learning rate for the depth task",
)
parser.add_argument(
    "--lr_gray", type=float, default=0, help="specific learning rate for the gray task"
)
parser.add_argument(
    "--lr_jigsaw", type=float, default=0, help="specific learning rate for the jigsaw task"
)
parser.add_argument(
    "--encoder_merge_ratio",
    type=float,
    default=1.0,
    help="Exp. moving average coefficient: ref = a * new + (1 - a) * old",
)
parser.add_argument(
    "--auxiliary_tasks", type=str, default="rotation, gray, depth, jigsaw"
)
parser.add_argument(
    "--D_rotation",
    action="store_true",
    default=False,
    help="use rotation self-supervision in discriminators when translating",
)
parser.add_argument(
    "--repr_mode",
    type=str,
    default="freeze",
    help="freeze | flow | distillation: when switching from representation "
    + "to translation, either freeze the encoder or let gradients flow; "
    + "set distillation for flow + a distillation loss",
)
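The --encoder_merge_ratio help string above spells out an exponential moving average over encoder weights (ref = a * new + (1 - a) * old). Here is a minimal sketch of such a merge, as an illustration of what model.update_ref_encoder() computes rather than the repo's actual code:

import torch

@torch.no_grad()
def merge_encoders(ref_encoder, new_encoder, a):
    # ref = a * new + (1 - a) * old, parameter by parameter
    for p_ref, p_new in zip(ref_encoder.parameters(), new_encoder.parameters()):
        p_ref.mul_(1 - a).add_(p_new, alpha=a)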
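Putting it together, a launch command could look like this; the --model value and the dataroot path are assumptions (the dataroot layout is described in the Data section), while the remaining flags are the LiSS options listed above:

python3 train.py \
    --dataroot ./data/horse2zebra \
    --model liss \
    --task_schedule liss \
    --auxiliary_tasks "rotation, gray, depth, jigsaw" \
    --lambda_distillation 5.0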
License
This project is licensed under the GNU License - see the LICENSE.md file for details.