In [1]:
!git clone "https://github.com/bytedance/OMGD.git"

fatal: destination path 'OMGD' already exists and is not an empty directory.


In [2]:
!pip install -r /kaggle/working/OMGD/requirements.txt
!pip install thop



**Training Additions**

In [3]:
%%writefile /kaggle/working/OMGD/trainer.py
# %load /kaggle/working/OMGD/trainer.py
import os
import random
import sys
import time
import warnings

import numpy as np
import torch
from torch.backends import cudnn
from tqdm import tqdm, trange

from data import create_dataloader
from utils.logger import Logger
import wandb

def set_seed(seed):
    cudnn.benchmark = False  # if benchmark=True, deterministic will be False
    cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

class Trainer:
    def __init__(self, task):
        self.task = task
        from options.distill_options import DistillOptions as Options
        from distillers import create_distiller as create_model

        opt = Options().parse()
        opt.tensorboard_dir = opt.log_dir if opt.tensorboard_dir is None else opt.tensorboard_dir
        print(' '.join(sys.argv))
        if opt.phase != 'train':
            warnings.warn('You are not using training set for %s!!!' % task)
        with open(os.path.join(opt.log_dir, 'opt.txt'), 'a') as f:
            f.write(' '.join(sys.argv) + '\n')
        set_seed(opt.seed)                      # 设置随机种子

        dataloader = create_dataloader(opt)  # create a dataset given opt.dataset_mode and other options
        dataset_size = len(dataloader.dataset)  # get the number of images in the dataset.
        print('The number of training images = %d' % dataset_size)
        print("PRINTING THE OPTIONS: ", opt)
        model = create_model(opt)  # create a model given opt.model and other options
        print("PRINTING THE MODEL", model)
        print("PRINTING MODEL OVER")
        model.setup(opt)  # regular setup: load and print networks; create schedulers
        logger = Logger(opt)

        self.opt = opt
        self.dataloader = dataloader
        self.model = model
        self.logger = logger

    def evaluate(self, epoch, iter, message):
        start_time = time.time()
        metrics = self.model.evaluate_model(iter)
        self.logger.print_current_metrics(epoch, iter, metrics, time.time() - start_time)
        self.logger.plot(metrics, iter)
        self.logger.print_info(message)
        self.model.save_networks('latest')

    def start(self):
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True

        opt = self.opt
        dataloader = self.dataloader
        model = self.model
        logger = self.logger

        if self.opt.project:
            wandb.init(project=self.opt.project, name=self.opt.name)
            config = wandb.config
            for k, v in sorted(vars(opt).items()):
                setattr(config, k, v)

        start_epoch = opt.epoch_base
        end_epoch = opt.epoch_base + opt.nepochs + opt.nepochs_decay - 1
        total_iter = opt.iter_base
        epoch_tqdm = trange(start_epoch, end_epoch + 1, desc='Epoch      ', position=0, leave=False)
        self.logger.set_progress_bar(epoch_tqdm)
        for epoch in epoch_tqdm:
            display_images = []
            epoch_start_time = time.time()  # timer for entire epoch
            for i, data_i in enumerate(tqdm(dataloader, desc='Batch      ', position=1, leave=False)):
                iter_start_time = time.time()
                total_iter += 1
                model.set_input(data_i)
                if epoch == start_epoch and i == 0:
                    model.profile()
                model.optimize_parameters(total_iter)

                if total_iter % opt.print_freq == 0:
                    losses = model.get_current_losses()
                    logger.print_current_errors(epoch, total_iter, losses, time.time() - iter_start_time)
                    logger.plot(losses, total_iter)

                    if self.opt.project:
                        wandb.log(losses)
                if total_iter % opt.save_latest_freq == 0:
                    if self.opt.project:
                        current_visual_result = model.get_current_visuals()
                        for k, v in current_visual_result.items():
                            display_images.append(wandb.Image(v))
                    self.evaluate(epoch, total_iter,
                                  'Saving the latest model (epoch %d, total_steps %d)' % (epoch, total_iter))
                    if model.is_best:
                        model.save_networks('iter%d' % total_iter)

            if self.opt.project:
                wandb.log({'Image': display_images})

            logger.print_info(
                'End of epoch %d / %d \t Time Taken: %.2f sec' % (epoch, end_epoch, time.time() - epoch_start_time))
            if epoch % opt.save_epoch_freq == 0 or epoch == end_epoch:
                self.evaluate(epoch, total_iter,
                              'Saving the model at the end of epoch %d, iters %d' % (epoch, total_iter))
                if self.task == 'distill' and self.opt.distiller in ['cycleganbest']:
                    model.load_best_teacher()
                    model.optimize_student_parameters()
                    model.load_latest_teacher()
                model.save_networks(epoch)
            model.update_learning_rate(logger)


Overwriting /kaggle/working/OMGD/trainer.py


In [4]:
%%writefile /kaggle/working/OMGD/scripts/unet_pix2pix/edges2shoes/distill.sh
# %load /kaggle/working/OMGD/scripts/unet_pix2pix/edges2shoes/distill.sh
#!/usr/bin/env bash
python /kaggle/working/OMGD/distill.py --dataroot /kaggle/input/afhq-sketch-final-correct/output_dir_afhq \
  --gpu_ids 0 --print_freq 100 --n_share 5 \
  --lambda_CD 1e1 \
  --distiller multiteacher \
  --log_dir logs/unet_pix2pix/edges2shoes-r/distill \
  --batch_size 4 --num_teacher 2 \
  --real_stat_path real_stat/edges2shoes-r_B.npz \
  --teacher_ngf_w 64 --teacher_ngf_d 16 --student_ngf 16  --norm batch \
  --teacher_netG_w unet_256 --teacher_netG_d unet_deepest_256 --netD multi_n_layers \
  --nepochs 19 --nepochs_decay 1 --n_dis 1 \
  --AGD_weights 1e1,1e4,1e1,1e-5

Overwriting /kaggle/working/OMGD/scripts/unet_pix2pix/edges2shoes/distill.sh


In [5]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [6]:
%%writefile /kaggle/working/OMGD/distillers/base_multiteacher_distiller.py
# %load /kaggle/working/OMGD/distillers/base_multiteacher_distiller.py
import itertools
import os

import numpy as np
import torch
from torch import nn
from torch.nn import DataParallel

from torchprofile import profile_macs
from collections import OrderedDict
import models.modules.loss
from data import create_eval_dataloader
from metric import create_metric_models
from models import networks
from models.base_model import BaseModel
from utils import util
from models.modules.discriminators import FLAGS
import math

class BaseMultiTeacherDistiller(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        assert is_train
        parser = super(BaseMultiTeacherDistiller, BaseMultiTeacherDistiller).modify_commandline_options(parser, is_train)
        parser.add_argument('--teacher_netG_w', type=str, default='unet_256',
                            help='specify teacher generator architecture',)
        parser.add_argument('--teacher_netG_d', type=str, default='unet_deepest_256',
                            help='specify teacher generator architecture',)
        parser.add_argument('--student_netG', type=str, default='unet_256',
                            help='specify student generator architecture',)

        parser.add_argument('--num_teacher', type=int, default=2,
                            help='the number of teacher generators')
        parser.add_argument('--teacher_ngf_w', type=int, default=64,
                            help='the base number of filters of the teacher generator')
        parser.add_argument('--teacher_ngf_d', type=int, default=16,
                            help='the base number of filters of the teacher generator')
        parser.add_argument('--student_ngf', type=int, default=16,
                            help='the base number of filters of the student generator')

        parser.add_argument('--restore_teacher_G_w_path', type=str, default=None,
                            help='the path to restore the wider teacher generator')
        parser.add_argument('--restore_teacher_G_d_path', type=str, default=None,
                            help='the path to restore the deeper teacher generator')
        parser.add_argument('--restore_student_G_path', type=str, default=None,
                            help='the path to restore the student generator')
        parser.add_argument('--restore_A_path', type=str, default=None,
                            help='the path to restore the adaptors for distillation')
        parser.add_argument('--restore_D_path', type=str, default=None,
                            help='the path to restore the discriminator')
        parser.add_argument('--restore_O_path', type=str, default=None,
                            help='the path to restore the optimizer')

        parser.add_argument('--recon_loss_type', type=str, default='l1',
                            choices=['l1', 'l2', 'smooth_l1', 'vgg'],
                            help='the type of the reconstruction loss')
        parser.add_argument('--lambda_CD', type=float, default=0,
                            help='weights for the intermediate activation distillation loss')
        parser.add_argument('--lambda_recon', type=float, default=100,
                            help='weights for the reconstruction loss.')
        parser.add_argument('--lambda_gan', type=float, default=1,
                            help='weight for gan loss')


        parser.add_argument('--teacher_dropout_rate', type=float, default=0)
        parser.add_argument('--student_dropout_rate', type=float, default=0)

        parser.add_argument('--n_share', type=int, default=0, help='shared blocks in D')
        parser.add_argument('--project', type=str, default=None, help='the project name of this trail')
        parser.add_argument('--name', type=str, default=None, help='the name of this trail')
        return parser

    def __init__(self, opt):
        assert opt.isTrain
        super(BaseMultiTeacherDistiller, self).__init__(opt)
        self.loss_names = ['G_gan_w',  'G_recon_w', 'G_gan_d', 'G_recon_d',
                           'D_fake_w', 'D_real_w', 'D_fake_d', 'D_real_d',
                           'G_SSIM', 'G_feature', 'G_style', 'G_tv', 'G_CD']
        self.optimizers = []
        self.image_paths = []
        self.visual_names = ['real_A', 'Sfake_B', 'Tfake_B_w', 'Tfake_B_d', 'real_B']
        self.model_names = ['netG_student', 'netG_teacher_w', 'netG_teacher_d', 'netD_teacher','netD_student',]
        self.netG_teacher_w = networks.define_G(opt.input_nc, opt.output_nc, opt.teacher_ngf_w,
                                              opt.teacher_netG_w, opt.norm, opt.teacher_dropout_rate,
                                              opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
        self.netG_teacher_d = networks.define_G(opt.input_nc, opt.output_nc, opt.teacher_ngf_d,
                                                opt.teacher_netG_d, opt.norm, opt.teacher_dropout_rate,
                                                opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
        self.netG_student = networks.define_G(opt.input_nc, opt.output_nc, opt.student_ngf,
                                              opt.student_netG, opt.norm, opt.student_dropout_rate,
                                              opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)

        if opt.dataset_mode == 'aligned':
            self.netD_teacher = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, n_share=self.opt.n_share)
        elif opt.dataset_mode == 'unaligned':
            self.netD_teacher = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, n_share=self.opt.n_share)
        else:
            raise NotImplementedError('Unknown dataset mode [%s]!!!' % opt.dataset_mode)

        self.netG_teacher_w.train()
        self.netG_teacher_d.train()
        self.netG_student.train()
        self.netD_teacher.train()

        self.criterionGAN = models.modules.loss.GANLoss(opt.gan_mode).to(self.device)
        if opt.recon_loss_type == 'l1':
            self.criterionRecon = torch.nn.L1Loss()
        elif opt.recon_loss_type == 'l2':
            self.criterionRecon = torch.nn.MSELoss()
        elif opt.recon_loss_type == 'smooth_l1':
            self.criterionRecon = torch.nn.SmoothL1Loss()
        elif opt.recon_loss_type == 'vgg':
            self.criterionRecon = models.modules.loss.VGGLoss().to(self.device)
        else:
            raise NotImplementedError('Unknown reconstruction loss type [%s]!' % opt.loss_type)

        self.mapping_layers = {'unet_256':['model.model.1.model.3.model.0',     # 2 * ngf
                                            'model.model.1.model.3.model.3.model.3.model.0',      # 8 * ngf
                                            'model.model.1.model.3.model.3.model.4',      # 16 * ngf
                                            'model.model.1.model.4'],     # 4 * ngf
                                'mobile_resnet_9blocks':['model.9',  # 4 * ngf
                                                         'model.12',
                                                         'model.15',
                                                         'model.18'],}
        self.netAs = []
        self.Tacts, self.Sacts = {}, {}
        G_params = [self.netG_student.parameters()]
        if self.opt.lambda_CD:
            for i, n in enumerate(self.mapping_layers[self.opt.teacher_netG_w]):
                ft, fs = self.opt.teacher_ngf_w, self.opt.student_ngf
                if 'resnet' in self.opt.teacher_netG_w:
                    netA = self.build_feature_connector(4 * ft, 4 * fs)
                elif i == 0:
                    netA = self.build_feature_connector(2 * ft, 2 * fs)
                elif i == 1:
                    netA = self.build_feature_connector(8 * ft, 8 * fs)
                elif i == 2:
                    netA = self.build_feature_connector(16 * ft, 16 * fs)
                else:
                    netA = self.build_feature_connector(4 * ft, 4 * fs)

                networks.init_net(netA)
                G_params.append(netA.parameters())
                self.netAs.append(netA)

        self.optimizer_G_student = torch.optim.Adam(itertools.chain(*G_params), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_G_teacher_w = torch.optim.Adam(self.netG_teacher_w.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_G_teacher_d = torch.optim.Adam(self.netG_teacher_d.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D_teacher = torch.optim.Adam(self.netD_teacher.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

        self.optimizers.append(self.optimizer_G_student)
        self.optimizers.append(self.optimizer_G_teacher_w)
        self.optimizers.append(self.optimizer_G_teacher_d)
        self.optimizers.append(self.optimizer_D_teacher)

        self.eval_dataloader = create_eval_dataloader(self.opt, direction=opt.direction)
        self.inception_model, self.drn_model = create_metric_models(opt, device=self.device)
#         self.npz = np.load(opt.real_stat_path)
        self.is_best = False
        self.loss_D_fake, self.loss_D_real = 0, 0

    def build_feature_connector(self, t_channel, s_channel):
        C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False),
             nn.BatchNorm2d(t_channel),
             nn.ReLU(inplace=True)]

        for m in C:
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        return nn.Sequential(*C)

    def setup(self, opt, verbose=True):
        self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        self.load_networks(verbose)
        if verbose:
            self.print_networks()
        if self.opt.lambda_CD > 0:
            def get_activation(mem, name):
                def get_output_hook(module, input, output):
                    mem[name + str(output.device)] = output

                return get_output_hook

            def add_hook(net, mem, mapping_layers):
                for n, m in net.named_modules():
                    if n in mapping_layers:
                        m.register_forward_hook(get_activation(mem, n))

            add_hook(self.netG_teacher_w, self.Tacts, self.mapping_layers[self.opt.teacher_netG_w])
            add_hook(self.netG_student, self.Sacts, self.mapping_layers[self.opt.teacher_netG_w])

    def set_input(self, input):
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def set_single_input(self, input):
        self.real_A = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        raise NotImplementedError

    def backward_D_teacher(self):
        FLAGS.teacher_ids = 1
        fake_AB_w = torch.cat((self.real_A, self.Tfake_B_w), 1).detach()
        real_AB = torch.cat((self.real_A, self.real_B), 1).detach()
        pred_fake_w = self.netD_teacher(fake_AB_w)
        self.loss_D_fake_w = self.criterionGAN(pred_fake_w, False, for_discriminator=True)
        pred_real_w = self.netD_teacher(real_AB)
        self.loss_D_real_w = self.criterionGAN(pred_real_w, True, for_discriminator=True)
        self.loss_D = (self.loss_D_fake_w + self.loss_D_real_w) * 0.5

        FLAGS.teacher_ids = 2
        fake_AB_d = torch.cat((self.real_A, self.Tfake_B_d), 1).detach()
        pred_fake_d = self.netD_teacher(fake_AB_d)
        pred_real_d = self.netD_teacher(real_AB)
        self.loss_D_fake_d = self.criterionGAN(pred_fake_d, False, for_discriminator=True)
        self.loss_D_real_d = self.criterionGAN(pred_real_d, True, for_discriminator=True)
        self.loss_D += (self.loss_D_fake_d + self.loss_D_real_d) * 0.5

        self.loss_D.backward()

    def optimize_parameters(self, steps):
        raise NotImplementedError

    def print_networks(self):
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if hasattr(self, name):
                net = getattr(self, name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
                with open(os.path.join(self.opt.log_dir, name + '.txt'), 'w') as f:
                    f.write(str(net) + '\n')
                    f.write('[Network %s] Total number of parameters : %.3f M\n' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def load_networks(self, verbose=True):
        if self.opt.restore_student_G_path is not None:
            util.load_network(self.netG_student, self.opt.restore_student_G_path, verbose)
        if self.opt.restore_teacher_G_w_path is not None:
            util.load_network(self.netG_teacher_w, self.opt.restore_teacher_G_w_path, verbose)
        if self.opt.restore_teacher_G_d_path is not None:
            util.load_network(self.netG_teacher_d, self.opt.restore_teacher_G_d_path, verbose)
        if self.opt.restore_D_path is not None:
            util.load_network(self.netD_teacher, self.opt.restore_D_path, verbose)
        if self.opt.restore_A_path is not None:
            for i, netA in enumerate(self.netAs):
                path = '%s-%d.pth' % (self.opt.restore_A_path, i)
                util.load_network(netA, path, verbose)
        if self.opt.restore_O_path is not None:
            for i, optimizer in enumerate(self.optimizers):
                path = '%s-%d.pth' % (self.opt.restore_O_path, i)
                util.load_optimizer(optimizer, path, verbose)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = self.opt.lr

    def save_net(self, net, save_path):
        if len(self.gpu_ids) > 0 and torch.cuda.is_available():
            if isinstance(net, DataParallel):
                torch.save(net.module.cpu().state_dict(), save_path)
            else:
                torch.save(net.cpu().state_dict(), save_path)
            net.cuda(self.gpu_ids[0])
        else:
            torch.save(net.cpu().state_dict(), save_path)

    def save_networks(self, epoch):

        save_filename = '%s_net_%s_student.pth' % (epoch, 'G')
        save_path = os.path.join(self.save_dir, save_filename)
        net = getattr(self, 'net%s_student' % 'G')
        self.save_net(net, save_path)

        save_filename = '%s_net%s_teacher_w.pth' % (epoch, 'G')
        save_path = os.path.join(self.save_dir, save_filename)
        net = getattr(self, 'net%s_teacher_w' % 'G')
        self.save_net(net, save_path)

        save_filename = '%s_net_%s_teacher_d.pth' % (epoch, 'G')
        save_path = os.path.join(self.save_dir, save_filename)
        net = getattr(self, 'net%s_teacher_d' % 'G')
        self.save_net(net, save_path)

        save_filename = '%s_net_%s_teacher.pth' % (epoch, 'D')
        save_path = os.path.join(self.save_dir, save_filename)
        net = getattr(self, 'net%s_teacher' % 'D')
        self.save_net(net, save_path)

        for i, optimizer in enumerate(self.optimizers):
            save_filename = '%s_optim-%d.pth' % (epoch, i)
            save_path = os.path.join(self.save_dir, save_filename)
            torch.save(optimizer.state_dict(), save_path)

        if self.opt.lambda_CD:
            for i, net in enumerate(self.netAs):
                save_filename = '%s_net_%s-%d.pth' % (epoch, 'A', i)
                save_path = os.path.join(self.save_dir, save_filename)
                self.save_net(net, save_path)

    def evaluate_model(self, step):
        raise NotImplementedError

    def test(self):
        with torch.no_grad():
            self.forward()

    def get_current_visuals(self):
        """Return visualization images. """
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def profile(self, config=None, verbose=True):
        for name in self.model_names:
            if hasattr(self,name) and 'D' not in name:
                netG = getattr(self,name)
                if isinstance(netG, nn.DataParallel):
                    netG = netG.module
                if config is not None:
                    netG.configs = config
                with torch.no_grad():
                    macs = profile_macs(netG, (self.real_A[:1],))
                    # flops, params = profile(netG, inputs=(self.real_A[:1],))
                params = 0
                for p in netG.parameters():
                    params += p.numel()
                if verbose:
                    print('MACs: %.3fG\tParams: %.3fM' % (macs / 1e9, params / 1e6), flush=True)

        return None

Overwriting /kaggle/working/OMGD/distillers/base_multiteacher_distiller.py


In [7]:
%%writefile /kaggle/working/OMGD/options/distill_options.py
# %load /kaggle/working/OMGD/options/distill_options.py
import argparse

import data
import distillers
from .base_options import BaseOptions


class DistillOptions(BaseOptions):
    """This class defines options used during both training and test time.

    It also implements several helper functions such as parsing, printing, and saving the options.
    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
    """

    def __init__(self, isTrain=True):
        """Reset the class; indicates the class hasn't been initailized"""
        super(DistillOptions, self).__init__()
        self.isTrain = isTrain

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)
        # log parameters
        parser.add_argument('--log_dir', type=str, default='logs/distill',
                            help='specify an experiment directory')
        parser.add_argument('--tensorboard_dir', type=str, default=None,
                            help='tensorboard is saved here')
        parser.add_argument('--print_freq', type=int, default=100,
                            help='frequency of showing training results on console')
        parser.add_argument('--save_latest_freq', type=int, default=20000,
                            help='frequency of evaluating and save the latest model')
        parser.add_argument('--save_epoch_freq', type=int, default=500,
                            help='frequency of saving checkpoints at the end of epoch')
        parser.add_argument('--epoch_base', type=int, default=1,
                            help='the epoch base of the training (used for resuming)')
        parser.add_argument('--iter_base', type=int, default=0,
                            help='the iteration base of the training (used for resuming)')

        # model parameters
        parser.add_argument('--distiller', type=str, default='resnet',
                            help='specify which distiller you want to use [resnet | spade]')
        parser.add_argument('--netD', type=str, default='n_layers',
                            help='specify discriminator architecture [n_layers | pixel]. '
                                 'The basic model is a 70x70 PatchGAN. '
                                 'n_layers allows you to specify the layers in the discriminator')
        parser.add_argument('--ndf', type=int, default=128, help='the base number of discriminator filters')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--gan_mode', type=str, default='hinge', choices=['lsgan', 'vanilla', 'hinge'],
                            help='the type of GAN objective. [vanilla| lsgan | hinge]. '
                                 'vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')

        # training parameters
        parser.add_argument('--nepochs', type=int, default=5,
                            help='number of epochs with the initial learning rate')
        parser.add_argument('--nepochs_decay', type=int, default=15,
                            help='number of epochs to linearly decay learning rate to zero')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        parser.add_argument('--lr_policy', type=str, default='linear',
                            help='learning rate policy. [linear | step | plateau | cosine]')
        parser.add_argument('--lr_decay_iters', type=int, default=50,
                            help='multiply by a gamma every lr_decay_iters iterations')

        parser.add_argument('--eval_batch_size', type=int, default=1, help='the evaluation batch size')
        parser.add_argument('--real_stat_path', type=str,
                            help='the path to load the ground-truth images information to compute FID.')
        return parser

    def gather_options(self):
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        opt, _ = parser.parse_known_args()

        distiller_name = opt.distiller
        distiller_option_setter = distillers.get_option_setter(distiller_name)
        parser = distiller_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()

        # modify dataset-related parser options
        dataset_name = opt.dataset_mode
        dataset_option_setter = data.get_option_setter(dataset_name)
        parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        return parser.parse_args()


Overwriting /kaggle/working/OMGD/options/distill_options.py


In [8]:
%%writefile /kaggle/working/OMGD/distillers/multiteacher_distiller.py
# %load /kaggle/working/OMGD/distillers/multiteacher_distiller.py
import ntpath
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parallel import gather, parallel_apply, replicate
from tqdm import tqdm

from metric import get_fid, get_cityscapes_mIoU
from utils import util
from utils.vgg_feature import VGGFeature
from .base_multiteacher_distiller import BaseMultiTeacherDistiller
from models.modules import pytorch_ssim
from models.modules.discriminators import FLAGS


class MultiTeacherDistiller(BaseMultiTeacherDistiller):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        assert is_train
        parser = super(MultiTeacherDistiller, MultiTeacherDistiller).modify_commandline_options(parser, is_train)
        parser.add_argument('--AGD_weights', type=str, default='1e1, 1e4, 1e1, 1e-5', help='weights for losses in AGD mode')
        parser.add_argument('--n_dis', type=int, default=1, help='iter time for student before update teacher')
        parser.set_defaults(norm='instance', dataset_mode='aligned')

        return parser

    def __init__(self, opt):
        assert opt.isTrain
        super(MultiTeacherDistiller, self).__init__(opt)
        self.best_fid_teachers, self.best_fid_student = [1e9 for _ in range(self.opt.num_teacher)],  1e9
        self.best_mIoU_teachers, self.best_mIoU_student = [-1e9 for _ in range(self.opt.num_teacher)], -1e9
        self.fids_teacher, self.fids_student, self.mIoUs_teacher, self.mIoUs_student = [], [], [], []
#         self.npz = np.load(opt.real_stat_path)
        # weights for AGD mood
        loss_weight = [float(char) for char in opt.AGD_weights.split(',')]
        self.lambda_SSIM = loss_weight[0]
        self.lambda_style = loss_weight[1]
        self.lambda_feature = loss_weight[2]
        self.lambda_tv = loss_weight[3]
        self.vgg = VGGFeature().to(self.device)

    def forward(self):
        self.Tfake_B_w = self.netG_teacher_w(self.real_A)
        self.Tfake_B_d = self.netG_teacher_d(self.real_A)
        self.Tfake_Bs = [self.Tfake_B_w.detach(), self.Tfake_B_d.detach()]
        self.Sfake_B = self.netG_student(self.real_A)

    def calc_CD_loss(self):
        losses = []
        mapping_layers = self.mapping_layers[self.opt.teacher_netG_w]
        for i, netA in enumerate(self.netAs):
            n = mapping_layers[i]
            netA_replicas = replicate(netA.cuda(), self.gpu_ids)
            Sacts = parallel_apply(netA_replicas,
                                       tuple([self.Sacts[key] for key in sorted(self.Sacts.keys()) if n in key]))
            Tacts = [self.Tacts[key] for key in sorted(self.Tacts.keys()) if n in key]
            for Sact, Tact in zip(Sacts, Tacts):
                source, target = Sact, Tact.detach()
                source = source.mean(dim=(2, 3), keepdim=False)
                target = target.mean(dim=(2, 3), keepdim=False)
                loss = torch.mean(torch.pow(source - target, 2))
                losses.append(loss)
        return sum(losses)

    def backward_G_teacher(self):

        fake_AB_w = torch.cat((self.real_A, self.Tfake_B_w), 1)
        FLAGS.teacher_ids = 1
        pred_fake_w = self.netD_teacher(fake_AB_w)
        self.loss_G_gan_w = self.criterionGAN(pred_fake_w, True, for_discriminator=False) * self.opt.lambda_gan
        # Second, G(A) = B
        self.loss_G_recon_w = self.criterionRecon(self.Tfake_B_w, self.real_B) * self.opt.lambda_recon
        # combine loss and calculate gradients
        self.loss_G_w = self.loss_G_gan_w + self.loss_G_recon_w

        fake_AB_d = torch.cat((self.real_A, self.Tfake_B_d), 1)
        FLAGS.teacher_ids = 2
        pred_fake_d = self.netD_teacher(fake_AB_d)
        self.loss_G_gan_d = self.criterionGAN(pred_fake_d, True, for_discriminator=False) * self.opt.lambda_gan
        self.loss_G_recon_d = self.criterionRecon(self.Tfake_B_d, self.real_B) * self.opt.lambda_recon
        self.loss_G_d = self.loss_G_gan_d + self.loss_G_recon_d

        self.loss_G_d.backward()
        self.loss_G_w.backward()


    def backward_G_student(self):
        self.loss_G_student = 0
        for i, teacher_image in enumerate(self.Tfake_Bs):
            ssim_loss = pytorch_ssim.SSIM()
            self.loss_G_SSIM = (1 - ssim_loss(self.Sfake_B, teacher_image)) * self.lambda_SSIM
            Tfeatures = self.vgg(teacher_image)
            Sfeatures = self.vgg(self.Sfake_B)
            Tgram = [self.gram(fmap) for fmap in Tfeatures]
            Sgram = [self.gram(fmap) for fmap in Sfeatures]
            self.loss_G_style = 0
            for i in range(len(Tgram)):
                self.loss_G_style += self.lambda_style * F.l1_loss(Sgram[i], Tgram[i])
            Srecon, Trecon = Sfeatures[1], Tfeatures[1]
            self.loss_G_feature = self.lambda_feature * F.l1_loss(Srecon, Trecon)
            diff_i = torch.sum(torch.abs(self.Sfake_B[:, :, :, 1:] - self.Sfake_B[:, :, :, :-1]))
            diff_j = torch.sum(torch.abs(self.Sfake_B[:, :, 1:, :] - self.Sfake_B[:, :, :-1, :]))
            self.loss_G_tv = self.lambda_tv * (diff_i + diff_j)
            self.loss_G_student += self.loss_G_SSIM + self.loss_G_style + self.loss_G_feature + self.loss_G_tv
        if self.opt.lambda_CD:
            self.loss_G_CD = self.calc_CD_loss() * self.opt.lambda_CD
            self.loss_G_student += self.loss_G_CD
        self.loss_G_student.backward()

    def gram(self, x):
        (bs, ch, h, w) = x.size()
        f = x.view(bs, ch, w*h)
        f_T = f.transpose(1, 2)
        G = f.bmm(f_T) / (ch * h * w)
        return G

    def optimize_parameters(self, steps):
        self.optimizer_D_teacher.zero_grad()
        self.optimizer_G_teacher_w.zero_grad()
        self.optimizer_G_teacher_d.zero_grad()
        self.optimizer_G_student.zero_grad()
        self.forward()
        if steps % self.opt.n_dis == 0:
            util.set_requires_grad(self.netD_teacher, True)
            self.backward_D_teacher()
            util.set_requires_grad(self.netD_teacher, False)
            self.backward_G_teacher()
            self.optimizer_D_teacher.step()
            self.optimizer_G_teacher_w.step()
            self.optimizer_G_teacher_d.step()
        self.backward_G_student()
        self.optimizer_G_student.step()

    def load_networks(self, verbose=True):
        super(MultiTeacherDistiller, self).load_networks()

    def evaluate_model(self, step):
        self.is_best = False
        save_dir = os.path.join(self.opt.log_dir, 'eval', str(step))
        os.makedirs(save_dir, exist_ok=True)
        self.netG_student.eval()
        self.netG_teacher_w.eval()
        self.netG_teacher_d.eval()
        S_fakes, T_fakes, names = [], [[] for _ in range(self.opt.num_teacher)],  []
        cnt = 0
        id_model_dict = {0: 'w', 1: 'd'}
        for i, data_i in enumerate(tqdm(self.eval_dataloader, desc='Eval       ', position=2, leave=False)):
            self.set_input(data_i)
            self.test()
            S_fakes.append(self.Sfake_B.cpu())
            for k in range(len(self.Tfake_Bs)):
                T_fakes[k].append(self.Tfake_Bs[k].cpu())
                for j in range(len(self.image_paths)):
                    short_path = ntpath.basename(self.image_paths[j])
                    name = os.path.splitext(short_path)[0]
                    if k == 0:
                        names.append(name)
                    if cnt < 10 * len(self.Tfake_Bs):
                        Tfake_im = util.tensor2im(self.Tfake_Bs[k][j])
                        if k == 0:
                            input_im = util.tensor2im(self.real_A[j])
                            Sfake_im = util.tensor2im(self.Sfake_B[j])
                            util.save_image(input_im, os.path.join(save_dir, 'input', '%s.png') % name, create_dir=True)
                            util.save_image(Sfake_im, os.path.join(save_dir, 'Sfake', '%s.png' % name), create_dir=True)
                        util.save_image(Tfake_im, os.path.join(save_dir, f'Tfake_{id_model_dict[k]}', '%s.png' %name), create_dir=True)
                        if self.opt.dataset_mode == 'aligned' and k == 0:
                            real_im = util.tensor2im(self.real_B[j])
                            util.save_image(real_im, os.path.join(save_dir, 'real', '%s.png' % name), create_dir=True)
                    cnt += 1
#         fid_teachers = [get_fid(T_fakes[m], self.inception_model, self.npz, device=self.device,
#                       batch_size=self.opt.eval_batch_size, tqdm_position=2) for m in range(self.opt.num_teacher)]
#         fid_student = get_fid(S_fakes, self.inception_model, self.npz, device=self.device,
#                       batch_size=self.opt.eval_batch_size, tqdm_position=2)
        fid_teachers = [0.05, 0.1, 0.9]
        fid_student = 0.4
        if fid_student < self.best_fid_student:
            self.is_best = True
            self.best_fid_student = fid_student

        ret = {}
        for i in range(self.opt.num_teacher):
            ret[f'metric/fid_teacher_{id_model_dict[i]}'] = fid_teachers[i]
            if fid_teachers[i] < self.best_fid_teachers[i]:
                self.best_fid_teachers[i] = fid_teachers[i]
            ret[f'metric/fid-best_teacher_{id_model_dict[i]}'] = self.best_fid_teachers[i]
        ret['metric/fid_student'] = fid_student
        ret['metric/fid-best_student'] = self.best_fid_student
        if 'cityscapes' in self.opt.dataroot and self.opt.direction == 'BtoA':
            mIoU_teachers = [get_cityscapes_mIoU(T_fakes[m], names, self.drn_model, self.device,
                                       table_path=self.opt.table_path,
                                       data_dir=self.opt.cityscapes_path,
                                       batch_size=self.opt.eval_batch_size,
                                       num_workers=self.opt.num_threads, tqdm_position=2) for m in range(self.opt.num_teacher)]
            mIoU_student = get_cityscapes_mIoU(S_fakes, names, self.drn_model, self.device,
                                       table_path=self.opt.table_path,
                                       data_dir=self.opt.cityscapes_path,
                                       batch_size=self.opt.eval_batch_size,
                                       num_workers=self.opt.num_threads, tqdm_position=2)
            if mIoU_student > self.best_mIoU_student:
                self.is_best = True
                self.best_mIoU_student = mIoU_student
            for i in range(self.opt.num_teacher):
                ret[f'metric/mIoU_teacher_{id_model_dict[i]}'] = mIoU_teachers[i]
                if mIoU_teachers[i] > self.best_mIoU_teachers[i]:
                    self.best_mIoU_teachers[i] = mIoU_teachers[i]
                ret[f'metric/mIoU-best_teacher_{id_model_dict[i]}'] = self.best_mIoU_teachers[i]
            ret['metric/mIoU_student'] = mIoU_student
            ret['metric/mIoU-best_student'] = self.best_mIoU_student
        self.netG_teacher_w.train()
        self.netG_teacher_d.train()
        self.netG_student.train()
        return ret


Overwriting /kaggle/working/OMGD/distillers/multiteacher_distiller.py


**Testing Side config**

In [9]:
# %%writefile /kaggle/working/OMGD/scripts/unet_pix2pix/edges2shoes/test.sh
# # %load /kaggle/working/OMGD/scripts/unet_pix2pix/edges2shoes/test.sh
# #!/usr/bin/env bash
# python /kaggle/working/OMGD/test.py --dataroot  /kaggle/input/data-edge/kaggle/working/edges2shoes_changed \
#   --results_dir  results/unet_pix2pix/edges2shoes-r/S16 \
#   --ngf 16 --netG unet_256 --norm batch \
#   --restore_G_path checkpoints/unet_pix2pix/edges2shoes/best_net_G16.pth \
#   --real_stat_path  real_stat/edges2shoes-r_B.npz \
#   --need_profile --num_test 30 --phase val


**Running Commands**

In [10]:
# !bash /kaggle/working/OMGD/scripts/unet_pix2pix/edges2shoes/test.sh

In [None]:
!bash /kaggle/working/OMGD/scripts/unet_pix2pix/edges2shoes/distill.sh

-------------modifying commandline options-----------
----------------- Options ---------------
              AGD_weights: 1e1,1e4,1e1,1e-5              	[default: 1e1, 1e4, 1e1, 1e-5]
             aspect_ratio: 1.0                           
               batch_size: 4                             	[default: 1]
                    beta1: 0.5                           
          cityscapes_path: database/cityscapes-origin    
               config_set: None                          
               config_str: None                          
          cosine_distance: False                         
                crop_size: 256                           
                 dataroot: /kaggle/input/afhq-sketch-final-correct/output_dir_afhq	[default: None]
             dataset_mode: aligned                       
           deeplabv2_path: deeplabv2_resnet101_msc-cocostuff164k-100000.pth
                direction: AtoB                          
          display_winsize: 256                 

In [None]:
!zip -r file.zip /kaggle/input/afhq-sketch-final-correct

In [None]:
from IPython.display import FileLink
FileLink(r'file.zip')

In [None]:
print("hellto")

In [None]:
print(os.listdir('/kaggle/working/logs/unet_pix2pix/edges2shoes-r/distill/eval/15000/Sfake'))

In [15]:
!zip -r models.zip /kaggle/working/logs/unet_pix2pix/edges2shoes-r/distill/checkpoints

  adding: kaggle/working/logs/unet_pix2pix/edges2shoes-r/distill/checkpoints/ (stored 0%)
  adding: kaggle/working/logs/unet_pix2pix/edges2shoes-r/distill/checkpoints/20_optim-1.pth (deflated 18%)
  adding: kaggle/working/logs/unet_pix2pix/edges2shoes-r/distill/checkpoints/20_net_A-3.pth (deflated 8%)
  adding: kaggle/working/logs/unet_pix2pix/edges2shoes-r/distill/checkpoints/latest_optim-1.pth (deflated 18%)
  adding: kaggle/working/logs/unet_pix2pix/edges2shoes-r/distill/checkpoints/20_net_G_teacher_d.pth (deflated 7%)
  adding: kaggle/working/logs/unet_pix2pix/edges2shoes-r/distill/checkpoints/latest_net_A-3.pth (deflated 9%)
  adding: kaggle/working/logs/unet_pix2pix/edges2shoes-r/distill/checkpoints/20_optim-0.pth (deflated 17%)
  adding: kaggle/working/logs/unet_pix2pix/edges2shoes-r/distill/checkpoints/20_net_D_teacher.pth (deflated 7%)
  adding: kaggle/working/logs/unet_pix2pix/edges2shoes-r/distill/checkpoints/latest_optim-3.pth (deflated 8%)
  adding: kaggle/working/logs/une

In [16]:
from IPython.display import FileLink
FileLink(r'models.zip')