# Training the second model: Step 1
After obtaining the pseudo labels from the first model, we train the second model. This time, we do not need grayscale images.
This model is trained with three loss terms: a contrastive loss on target data (using the pseudo labels), a cross-entropy loss on source images translated to the target style, and entropy minimization on target data.

In [1]:
import numpy as np
from utils.timer import Timer
import os
from data import CreateSrcDataLoader
from data import CreateTrgDataLoaderPseudo
from data import CreateTrgDataLoader
from model import CreateModel
import torch.backends.cudnn as cudnn
import torch
from torch.autograd import Variable
from utils import FDA_source_to_target
import scipy.io as sio
from torch.utils.tensorboard import SummaryWriter
from contrastiveloss import *
writer = SummaryWriter()

In [2]:
# Mean that the training loop subtracts from its inputs, stored as a float32
# tensor of shape (1, 3, 1, 1) so it can be tiled over batch, height and width.
IMG_MEAN = torch.from_numpy(
    np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
).reshape(1, 3, 1, 1)

# Per-class loss weights (19 entries).
# weighting fence, light, sign, person, rider, bus, train, motorcycle, bicycle
CS_weights = np.array(
    (1.0, 1.0, 1.0, 1.0, 100.0, 1.0, 100.0, 100.0, 1.0, 1.0, 1.0,
     100.0, 100.0, 1.0, 1.0, 100.0, 500.0, 500.0, 1000.0),
    dtype=np.float32,
)
# Rescale so that the largest weight ends up equal to 2.
weight_normalizer = CS_weights.max() / 2
CS_weights = torch.from_numpy(CS_weights / weight_normalizer)

## Train options
Training parameters such as beta (`--LB`), ita (`--ita`) and the contrastive temperature (`--temperature`), as well as the directories used to load and save model weights, can be adjusted here. The directory in which this model's weights are saved (`--snapshot-dir`) is used again in the second step for further training.

In [3]:
import argparse
import os.path as osp

class TrainOptions():
    """Options for step 1 of training the second model.

    ``initialize`` builds an ``argparse`` parser holding every training option
    and returns the parsed namespace; ``print_options`` prints that namespace
    and stores the same listing next to the model snapshots.
    """

    def initialize(self, argv=None):
        """Build the option parser and return the parsed options.

        Args:
            argv: optional sequence of option strings, e.g.
                ``['--LB', '0.01']``. When omitted, an empty list is parsed, so
                every option takes its default and ``sys.argv`` is never read
                (inside a notebook it does not hold training options).

        Returns:
            argparse.Namespace with one attribute per option (dashes in option
            names become underscores).
        """
        parser = argparse.ArgumentParser( description="training script for FDA" )
        parser.add_argument("--model", type=str, default='DeepLab', help="available options : DeepLab and VGG")
        parser.add_argument("--LB", type=float, default=0.09, help="beta for FDA")
        parser.add_argument("--GPU", type=str, default='0', help="which GPU to use")
        parser.add_argument("--entW", type=float, default=0.005, help="weight for entropy")
        parser.add_argument("--ita", type=float, default=2.0, help="ita for robust entropy")
        parser.add_argument("--temperature", type=float, default=0.07, help="temperature for contrastive loss")
        parser.add_argument("--switch2entropy", type=int, default=0, help="switch to entropy after this many steps")
        parser.add_argument("--switch2contrast", type=int, default=50000, help="switch to contrastive learning  after this many steps")
        parser.add_argument('--threshold', default=0.95, type=float, help='pseudo label threshold')
        parser.add_argument("--source", type=str, default='gta5', help="source dataset : gta5 or synthia")
        parser.add_argument("--target", type=str, default='cityscapes', help="target dataset : cityscapes")
        parser.add_argument("--snapshot-dir", type=str, default='../checkpoints/UDA/step1', help="Where to save snapshots of the model.")
        parser.add_argument("--data-dir", type=str, default='../data/GTA5', help="Path to the directory containing the source dataset.")
        parser.add_argument("--data-list", type=str, default='./dataset/gta5_list/train_all.txt', help="Path to the listing of images in the source dataset.")
        parser.add_argument("--data-dir-target", type=str, default='../data/cityscapes', help="Path to the directory containing the target dataset.")
        parser.add_argument("--data-list-target", type=str, default='./dataset/cityscapes_list/train.txt', help="list of images in the target dataset.")
        parser.add_argument("--data-list-val", type=str, default='./dataset/cityscapes_list/val.txt', help="list of val images in the target dataset.")
        parser.add_argument("--set", type=str, default='train', help="choose adaptation set.")
        parser.add_argument("--label-folder", type=str, default=None, help="Path to the directory containing the pseudo labels.")

        parser.add_argument("--batch-size", type=int, default=1, help="input batch size.")
        parser.add_argument("--num-steps", type=int, default=60000, help="Number of training steps.")
        parser.add_argument("--num-steps-stop", type=int, default=100000, help="Number of training steps for early stopping.")
        parser.add_argument("--num-workers", type=int, default=4, help="number of threads.")
        parser.add_argument("--learning-rate", type=float, default=2.5e-4, help="initial learning rate for the segmentation network.")
        parser.add_argument("--momentum", type=float, default=0.9, help="Momentum component of the optimiser.")
        parser.add_argument("--weight-decay", type=float, default=0.0005, help="Regularisation parameter for L2-loss.")
        parser.add_argument("--power", type=float, default=0.9, help="Decay parameter to compute the learning rate (only for deeplab).")

        parser.add_argument("--num-classes", type=int, default=19, help="Number of classes for cityscapes.")
        parser.add_argument("--init-weights", type=str, default='../checkpoints/DeepLab_init.pth', help="initial model.")
        parser.add_argument("--restore-from", type=str, default=None, help="Where restore model parameters from.")
        parser.add_argument("--save-pred-every", type=int, default=1000, help="Save summaries and checkpoint every often.")
        parser.add_argument("--print-freq", type=int, default=100, help="print loss and time fequency.")
        parser.add_argument("--matname", type=str, default='loss_log.mat', help="mat name to save loss")
        parser.add_argument("--tempdata", type=str, default='tempdata.mat', help="mat name to save data")

        # Parse an explicit list rather than the process arguments; the default
        # (no overrides) reproduces the original "all defaults" behaviour.
        return parser.parse_args(args=[] if argv is None else list(argv))

    def print_options(self, args):
        """Print all options and save the listing to ``<snapshot_dir>/opt.txt``.

        Args:
            args: namespace returned by ``initialize``; its ``snapshot_dir``
                attribute selects where ``opt.txt`` is written. The directory
                is created if it does not exist yet.
        """
        lines = ['----------------- Options ---------------']
        for k, v in sorted(vars(args).items()):
            lines.append('{:>25}: {:<30}'.format(str(k), str(v)))
        lines.append('----------------- End -------------------')
        message = '\n'.join(lines)
        print(message)

        # save to the disk; make sure the target directory is there first so
        # this method does not depend on the caller having created it
        os.makedirs(args.snapshot_dir, exist_ok=True)
        file_name = osp.join(args.snapshot_dir, 'opt.txt')
        with open(file_name, 'wt') as args_file:
            args_file.write(message)
            args_file.write('\n')

In [4]:
# Parse the options and select the visible GPU before the model is created.
opt = TrainOptions()
args = opt.initialize()
os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
_t = {'iter time': Timer()}

model_name = args.source + '_to_' + args.target
# Create the snapshot directory together with its 'logs' sub-directory.
# exist_ok=True also covers the case where the snapshot directory is already
# present but 'logs' is missing, which the previous existence check skipped.
os.makedirs(os.path.join(args.snapshot_dir, 'logs'), exist_ok=True)
opt.print_options(args)

# NOTE(review): CreateValDataLoader is not among the names this notebook imports
# explicitly from `data` (only CreateSrcDataLoader, CreateTrgDataLoaderPseudo and
# CreateTrgDataLoader are). Unless the star-import from `contrastiveloss` provides
# it, this line raises NameError -- confirm where it lives and import it.
sourceloader = CreateSrcDataLoader(args)
targetloader = CreateTrgDataLoaderPseudo(args)
valloader = CreateValDataLoader(args, True)

sourceloader_iter, targetloader_iter, valloader_iter = iter(sourceloader), iter(targetloader), iter(valloader)

# Report success only once the model and optimizer actually exist.
model, optimizer = CreateModel(args)
print("model is created")

----------------- Options ---------------
                      GPU: 0                             
                       LB: 0.01                          
               batch_size: 1                             
                 data_dir: ../data/GTA5                  
          data_dir_target: ../data/cityscapes            
                data_list: ./dataset/gta5_list/train_all.txt
         data_list_target: ./dataset/cityscapes_list/train.txt
            data_list_val: ./dataset/cityscapes_list/train.txt
                     entW: 0.005                         
             init_weights: ../checkpoints/DeepLab_init.pth
                      ita: 2.0                           
             label_folder: None                          
            learning_rate: 0.00025                       
                  matname: loss_log.mat                  
                    model: DeepLab                       
                 momentum: 0.9                           
              nu

  super(SGD, self).__init__(params, defaults)


In [5]:
# Resume support: recover the step counter from the checkpoint file name.
start_iter = 0
if args.restore_from is not None:
    # The training loop saves snapshots as '<source>_<step>.pth', so the step is
    # the text after the last underscore of the base name, without the
    # extension. (Splitting on '/' and calling int() on '<step>.pth' raised
    # ValueError, and IndexError for a path without any '/'.)
    ckpt_stem = os.path.splitext(os.path.basename(args.restore_from))[0]
    start_iter = int(ckpt_stem.rsplit('_', 1)[-1])

cudnn.enabled = True
cudnn.benchmark = True

model.train()
model.cuda()

# losses to log
loss_all = 0
loss_train = 0
loss_val = 0.0
loss_contrastive_trg = 0.0
loss_ent = 0



best_loss_trg = float('inf')
loss_val_list = []

# Placeholder with a last dimension of 1; the training loop replaces it with the
# tiled IMG_MEAN once the first target batch reveals the image size.
mean_img = torch.zeros(1, 1)
class_weights = Variable(CS_weights).cuda()
_t['iter time'].tic()

## Training the model

In [None]:
# Main optimisation loop: one source batch and one target batch per step.
for i in range(start_iter, args.num_steps):
    model.adjust_learning_rate(args, optimizer, i)  # adjust learning rate
    optimizer.zero_grad()  # zero grad

    # Use the builtin next(): the iterator protocol only guarantees __next__,
    # and newer PyTorch releases no longer expose a Python-2 style .next() alias.
    src_img, src_lbl, _, _ = next(sourceloader_iter)  # new batch source
    trg_img, trg_lbl, trg_lbl_pseudo, _, _ = next(targetloader_iter)  # new batch target

    # Build the tiled mean image once, from the shape of the first target batch.
    # NOTE(review): it is subtracted from source-derived images too, so source and
    # target batches are assumed to share the same shape -- confirm in the loaders.
    if mean_img.shape[-1] < 2:
        B, C, H, W = trg_img.shape
        mean_img = IMG_MEAN.repeat(B, 1, H, W)

    # 1. perform spectral transfer on source to get image in target style
    src_in_trg = FDA_source_to_target(src_img, trg_img, L=args.LB)  # src_lbl
    src_in_trg = src_in_trg - mean_img  # trg_img

    # 2. forward pass source image in trg
    src_in_trg, src_lbl = Variable(src_in_trg).cuda(), Variable(src_lbl.long()).cuda()  # to gpu
    src_seg_score = model(src_in_trg, lbl=src_lbl, weight=class_weights, ita=args.ita)  # forward pass
    loss_seg_src = model.loss_seg  # get loss

    # 3. perform spectral transfer on target to get image in source style
    trg_in_src = FDA_source_to_target(trg_img, src_img, L=args.LB)  # src_lbl
    trg_in_src = trg_in_src - mean_img  # trg_img

    # 4. forward pass target image in source style
    trg_in_src, trg_lbl_pseudo = Variable(trg_in_src).cuda(), Variable(trg_lbl_pseudo.long()).cuda()  # to gpu
    # Pass the pseudo labels that were just cast to long and moved to the GPU.
    # The previous code passed trg_lbl here, which at this point is still the raw
    # loader output (it is only converted in step 5), while the converted
    # trg_lbl_pseudo was never used. NOTE(review): confirm pseudo labels are the
    # intended supervision for this pass.
    trg_seg_score = model(trg_in_src, lbl=trg_lbl_pseudo, weight=class_weights, ita=args.ita)  # forward pass
    trg_in_src_cont = model.cont # get the embeddings to compute contrastive loss

    # 5. forward pass target image
    trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(trg_lbl.long()).cuda()  # to gpu
    trg_seg_score = model(trg_img, lbl=trg_lbl, weight=class_weights, ita=args.ita)  # forward pass
    loss_seg_trg = model.loss_seg  # get loss (not part of loss_all below)
    loss_ent_trg = model.loss_ent
    trg_cont = model.cont # get the embeddings to compute contrastive loss

    # 6. sum all the losses and backpropagate
    # NOTE(review): `loss_cont_trg` is never assigned anywhere in the visible
    # notebook (the setup cell initialises `loss_contrastive_trg` instead), so the
    # next statement raises NameError as written. The contrastive loss presumably
    # has to be computed here from trg_cont / trg_in_src_cont and the pseudo
    # labels with a helper from `contrastiveloss`; its signature is not visible
    # in this file -- TODO restore that call.
    # NOTE(review): the entropy term is gated by --switch2contrast as well, which
    # leaves --switch2entropy unused; confirm which switch is intended.
    target_terms_on = int(i >= args.switch2contrast)
    loss_all = loss_seg_src + target_terms_on*loss_cont_trg + target_terms_on*args.entW*loss_ent_trg
    loss_all.backward()
    optimizer.step()

    # Accumulate the running total that is averaged and logged every print_freq
    # steps; without this the logged 'Training loss' was always 0.
    loss_train += loss_all.item()

    # save the model weights for each save_pred_every steps
    if (i + 1) % args.save_pred_every == 0:
        print('taking snapshot ...')
        torch.save(model.state_dict(), os.path.join(args.snapshot_dir, '%s_' % (args.source) + str(i + 1) + '.pth'))

    if (i + 1) % args.print_freq == 0:
        _t['iter time'].toc(average=False)
        print('[it %d][src seg loss %.4f][lr %.4f][%.2fs]' % \
              (i + 1, loss_seg_src.item(), optimizer.param_groups[0]['lr'] * 10000,
               _t['iter time'].diff))

        sio.savemat(args.tempdata, {'trg_img': trg_img.cpu().numpy()})

        # Average over the window and reset the accumulator for the next one.
        loss_train /= args.print_freq
        writer.add_scalar('Training loss', loss_train, i)
        loss_train = 0.0

        # Early stopping is only checked on logging steps.
        if i + 1 > args.num_steps_stop:
            print('finish training')
            break
        _t['iter time'].tic()



[it 100][src seg loss 4.2514][trgseg loss 2.5642][contrastive 7.1698][lr 2.4978][195.30s]
[it 200][src seg loss 2.2253][trgseg loss 2.0283][contrastive 7.4524][lr 2.4955][182.07s]
[it 300][src seg loss 2.6076][trgseg loss 2.1357][contrastive 7.0374][lr 2.4933][181.58s]
[it 400][src seg loss 2.4821][trgseg loss 2.3677][contrastive 6.8540][lr 2.4910][179.73s]
[it 500][src seg loss 1.4140][trgseg loss 2.0239][contrastive 6.6706][lr 2.4888][182.08s]
[it 600][src seg loss 1.4141][trgseg loss 1.6458][contrastive 7.1323][lr 2.4865][180.42s]
[it 700][src seg loss 1.1480][trgseg loss 1.1847][contrastive 7.1742][lr 2.4843][180.41s]
[it 800][src seg loss 1.5631][trgseg loss 0.9847][contrastive 7.2798][lr 2.4820][181.38s]
[it 900][src seg loss 1.5201][trgseg loss 1.1829][contrastive 6.4742][lr 2.4798][179.78s]
taking snapshot ...
[it 1000][src seg loss 3.4957][trgseg loss 4.7001][contrastive 7.8671][lr 2.4775][180.91s]
[it 1100][src seg loss 0.7406][trgseg loss 0.9231][contrastive 8.1684][lr 2.475