# 

In [None]:
import numpy as np
from utils.timer import Timer
import os
from data import CreateSrcDataLoader
from data import CreateTrgDataLoader
from model import CreateModel
from tensorboardX import SummaryWriter
import torch.backends.cudnn as cudnn
import torch
from torch.autograd import Variable
from utils import FDA_source_to_target
import scipy.io as sio
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [2]:
# Experiment notes (from the original author):
#   - trains only source images, test on target data
#   - no pseudo labeling
#   - only cross entropy loss with source data
#   - source images are in source domain

# Per-channel image mean, stored as a float32 tensor of shape (1, 3, 1, 1).
IMG_MEAN = np.asarray([104.00698793, 116.66876762, 122.67891434], dtype=np.float32)
IMG_MEAN = torch.from_numpy(IMG_MEAN).reshape(1, 3, 1, 1)

# Per-class loss weights for the 19 classes: everything defaults to 1 and a
# handful of class indices are boosted. Per the original note, the boosted
# classes are: fence, light, sign, person, rider, bus, train, motorcycle, bicycle.
CS_weights = np.ones(19, dtype=np.float32)
CS_weights[[4, 6, 7, 11, 12, 15]] = 100.0
CS_weights[[16, 17]] = 500.0
CS_weights[18] = 1000.0

# Rescale so that the largest weight becomes exactly 2.
weight_normalizer = CS_weights.max() / 2
CS_weights = torch.from_numpy(CS_weights / weight_normalizer)

In [3]:
import argparse
import os.path as osp

class TrainOptions():
    """Build and report the training configuration for the FDA experiment."""

    def initialize(self):
        """Declare every training option and return the parsed defaults.

        The parser is invoked with an empty argument list, so the real
        command line is ignored and every option takes its default value
        (this lets the cell run inside a notebook kernel).

        Returns:
            argparse.Namespace holding one attribute per option.
        """
        parser = argparse.ArgumentParser( description="training script for FDA" )
        parser.add_argument("--model", type=str, default='DeepLab', help="available options : DeepLab and VGG")
        parser.add_argument("--LB", type=float, default=0.1, help="beta for FDA")
        parser.add_argument("--GPU", type=str, default='0', help="which GPU to use")
        parser.add_argument("--entW", type=float, default=0.005, help="weight for entropy")
        parser.add_argument("--ita", type=float, default=2.0, help="ita for robust entropy")
        parser.add_argument("--temperature", type=float, default=0.07, help="temperature for contrastive loss")
        parser.add_argument("--switch2entropy", type=int, default=0, help="switch to entropy after this many steps")
        parser.add_argument("--switch2contrast", type=int, default=0, help="switch to contrastive learning  after this many steps")
        parser.add_argument('--threshold', default=0.95, type=float, help='pseudo label threshold')
        parser.add_argument("--source", type=str, default='gta5', help="source dataset : gta5 or synthia")
        parser.add_argument("--target", type=str, default='cityscapes', help="target dataset : cityscapes")
        parser.add_argument("--snapshot-dir", type=str, default='../checkpoints/UDA_ENet_val/new_arch_cont', help="Where to save snapshots of the model.")
        parser.add_argument("--data-dir", type=str, default='../data/GTA5', help="Path to the directory containing the source dataset.")
        parser.add_argument("--data-list", type=str, default='./dataset/gta5_list/train_all.txt', help="Path to the listing of images in the source dataset.")
        parser.add_argument("--data-dir-target", type=str, default='../data/cityscapes', help="Path to the directory containing the target dataset.")
        parser.add_argument("--data-list-target", type=str, default='./dataset/cityscapes_list/train.txt', help="list of images in the target dataset.")
        parser.add_argument("--set", type=str, default='train', help="choose adaptation set.")
        parser.add_argument("--label-folder", type=str, default=None, help="Path to the directory containing the pseudo labels.")

        parser.add_argument("--batch-size", type=int, default=1, help="input batch size.")
        parser.add_argument("--num-steps", type=int, default=100000, help="Number of training steps.")
        parser.add_argument("--num-steps-stop", type=int, default=100000, help="Number of training steps for early stopping.")
        parser.add_argument("--num-workers", type=int, default=4, help="number of threads.")
        parser.add_argument("--learning-rate", type=float, default=2.5e-4, help="initial learning rate for the segmentation network.")
        parser.add_argument("--momentum", type=float, default=0.9, help="Momentum component of the optimiser.")
        parser.add_argument("--weight-decay", type=float, default=0.0005, help="Regularisation parameter for L2-loss.")
        parser.add_argument("--power", type=float, default=0.9, help="Decay parameter to compute the learning rate (only for deeplab).")

        parser.add_argument("--num-classes", type=int, default=19, help="Number of classes for cityscapes.")
        #parser.add_argument("--init-weights", type=str, default=None, help="initial model.")
        parser.add_argument("--init-weights", type=str, default='../checkpoints/DeepLab_init.pth', help="initial model.")
        parser.add_argument("--restore-from", type=str, default=None, help="Where restore model parameters from.")
        # parser.add_argument("--restore-from", type=str, default='./latest', help="Where restore model parameters from.")
        parser.add_argument("--save-pred-every", type=int, default=1000, help="Save summaries and checkpoint every often.")
        parser.add_argument("--print-freq", type=int, default=100, help="print loss and time fequency.")
        parser.add_argument("--matname", type=str, default='loss_log.mat', help="mat name to save loss")
        parser.add_argument("--tempdata", type=str, default='tempdata.mat', help="mat name to save data")

        # Empty argv on purpose: only the defaults declared above are used.
        return parser.parse_args(args=[])

    def print_options(self, args):
        """Print all options, sorted by name, and save them to ``opt.txt``.

        Args:
            args: namespace returned by ``initialize``; ``args.snapshot_dir``
                selects the directory that receives ``opt.txt``.

        The snapshot directory is created when it does not exist yet, so the
        method no longer fails with ``FileNotFoundError`` if it is called
        before the caller has prepared the directory.
        """
        rows = ['----------------- Options ---------------']
        for k, v in sorted(vars(args).items()):
            rows.append('{:>25}: {:<30}'.format(str(k), str(v)))
        rows.append('----------------- End -------------------')
        message = '\n'.join(rows)
        print(message)

        # save to the disk
        os.makedirs(args.snapshot_dir, exist_ok=True)
        file_name = osp.join(args.snapshot_dir, 'opt.txt')
        with open(file_name, 'wt') as args_file:
            args_file.write(message)
            args_file.write('\n')

In [4]:
# Parse the default options and restrict CUDA to the requested GPU.
opt = TrainOptions()
args = opt.initialize()
os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
_t = {'iter time': Timer()}

model_name = args.source + '_to_' + args.target
# Create the snapshot directory and its 'logs' sub-directory. exist_ok=True
# makes this safe to re-run and also creates 'logs' when the snapshot
# directory already exists (previously 'logs' was skipped in that case).
os.makedirs(os.path.join(args.snapshot_dir, 'logs'), exist_ok=True)
opt.print_options(args)

# Data loaders for both domains, plus iterators the training loop pulls from.
sourceloader, targetloader = CreateSrcDataLoader(args), CreateTrgDataLoader(args)
sourceloader_iter, targetloader_iter = iter(sourceloader), iter(targetloader)
model, optimizer = CreateModel(args)
# Report only once the model really exists.
print("model is created")

----------------- Options ---------------
                      GPU: 0                             
                       LB: 0.1                           
               batch_size: 1                             
                 data_dir: ../data/GTA5                  
          data_dir_target: ../data/cityscapes            
                data_list: ./dataset/gta5_list/train_all.txt
         data_list_target: ./dataset/cityscapes_list/train.txt
                     entW: 0.005                         
             init_weights: ../checkpoints/DeepLab_init.pth
                      ita: 2.0                           
             label_folder: None                          
            learning_rate: 0.00025                       
                  matname: loss_log.mat                  
                    model: DeepLab                       
                 momentum: 0.9                           
              num_classes: 19                            
                num_s

  super(SGD, self).__init__(params, defaults)


In [5]:
def accuracy(predict, target):
    """Return the pixel accuracy, in percent, over the labelled pixels.

    Args:
        predict: score tensor of shape (n, c, h, w).
        target: integer label tensor of shape (n, h, w). Pixels whose label
            is negative or equal to 255 are ignored.

    Returns:
        0-dim float tensor with the percentage of non-ignored pixels whose
        highest-scoring class equals the label. When every pixel is ignored
        the result is 0.0 rather than the NaN a 0/0 division would give.
    """
    n, c, h, w = predict.size()
    valid = (target >= 0) & (target != 255)
    target = target[valid]
    if target.numel() == 0:
        # Nothing to score in this batch; avoid dividing by zero.
        return predict.new_zeros(())
    # Move the class axis last so the (n, h, w) mask selects one score vector
    # per valid pixel, in the same order as target[valid].
    scores = predict.permute(0, 2, 3, 1)[valid.view(n, h, w)]
    labels = scores.argmax(dim=1)
    # Average in floating point: dividing integer tensors is floor division
    # on older PyTorch releases, which would collapse the result to 0.
    return 100 * (labels == target).float().mean()

In [6]:
# Iteration the training loop starts from (resuming is currently disabled).
start_iter = 0
#if args.restore_from is not None:
#    start_iter = int(args.restore_from.rsplit('/', 1)[1].rsplit('_')[1])

# cuDNN switches.
cudnn.enabled, cudnn.benchmark = True, True

# Put the network in training mode and move it to the GPU.
model.train()
model.cuda()

# Accumulators for the quantities that get logged.
loss_all = loss_train = loss_ent = 0
loss_val = 0.0

# Lowest target segmentation loss seen so far, and the logged history.
best_loss_trg = float('inf')
loss_val_list = list()

# Placeholder; the training loop swaps in the tiled image mean on first use.
mean_img = torch.zeros((1, 1))
class_weights = Variable(CS_weights).cuda()

# Start timing the first logging window.
_t['iter time'].tic()

In [7]:
# Sum of per-batch target accuracy over the current logging window. It has to
# exist before the loop because it is now accumulated rather than overwritten.
acc_score = 0.0

# NOTE(review): the step count is hard-coded to 30000 while args.num_steps
# defaults to 100000, so the args.num_steps_stop check below cannot fire with
# the default options - confirm this is intended.
for i in range(start_iter, 30000):
    model.adjust_learning_rate(args, optimizer, i)  # adjust learning rate
    optimizer.zero_grad()  # zero grad

    # The builtin next() works on every iterator; the .next() method is a
    # Python 2 leftover that recent DataLoader iterators no longer provide.
    src_img, src_lbl, _, _ = next(sourceloader_iter)  # new batch source
    trg_img, trg_lbl, _, _ = next(targetloader_iter)  # new batch target

    # Tile the image mean to the batch shape once, on the first iteration.
    # assumes source and target batches share this shape - TODO confirm
    if mean_img.shape[-1] < 2:
        B, C, H, W = trg_img.shape
        mean_img = IMG_MEAN.repeat(B, 1, H, W)

    # Spectral transfer between the two batches.
    # NOTE(review): the target batch is passed first, so this presumably
    # renders the target image in the source style (the earlier comment
    # claimed the opposite) - confirm against FDA_source_to_target.
    trg_in_src = FDA_source_to_target(trg_img, src_img, L=args.LB)
    trg_in_src = trg_in_src - mean_img

    # Forward pass on the mean-subtracted source batch.
    src_img = src_img - mean_img
    src_img, src_lbl = Variable(src_img).cuda(), Variable(src_lbl.long()).cuda()  # to gpu
    src_seg_score = model(src_img, lbl=src_lbl, weight=class_weights, ita=args.ita)  # forward pass
    loss_seg_src = model.loss_seg  # get loss

    # Forward pass on the transferred target batch.
    trg_in_src, trg_lbl = Variable(trg_in_src).cuda(), Variable(trg_lbl.long()).cuda()  # to gpu
    trg_seg_score = model(trg_in_src, lbl=trg_lbl, weight=class_weights, ita=args.ita)  # forward pass
    loss_seg_trg = model.loss_seg  # get loss
    loss_ent_trg = model.loss_ent

    # Only the source segmentation loss and the weighted target entropy are
    # optimised; the target segmentation loss is monitored but not trained on.
    loss_all = loss_seg_src + args.entW*loss_ent_trg

    loss_all.backward()
    optimizer.step()

    running_acc = accuracy(trg_seg_score.detach(), trg_lbl)

    # Pull plain Python floats off the GPU once so nothing below keeps the
    # autograd graph alive.
    seg_src_value = loss_seg_src.item()
    seg_trg_value = loss_seg_trg.item()
    ent_trg_value = loss_ent_trg.item()
    acc_value = running_acc.item()

    acc_score += acc_value
    loss_train += seg_src_value
    loss_val += seg_trg_value
    loss_ent += ent_trg_value

    if (i + 1) % args.save_pred_every == 0:
        print('taking snapshot ...')
        torch.save(model.state_dict(), os.path.join(args.snapshot_dir, '%s_' % (args.source) + str(i + 1) + '.pth'))

    # Keep the weights that gave the lowest single-batch target loss so far.
    # Comparing floats avoids holding on to a graph-attached CUDA tensor.
    if best_loss_trg > seg_trg_value:
        best_loss_trg = seg_trg_value
        torch.save(model.state_dict(), os.path.join(args.snapshot_dir, '%s_' % (args.source) +'best'+ '.pth'))

    if (i + 1) % args.print_freq == 0:
        _t['iter time'].toc(average=False)
        # The console line reports the values of the latest batch only.
        print('[it %d][src seg loss %.4f][trgseg loss %.4f][acc %.4f][lr %.4f][%.2fs]' % \
              (i + 1, seg_src_value, seg_trg_value, acc_value, optimizer.param_groups[0]['lr'] * 10000,
               _t['iter time'].diff))

        sio.savemat(args.tempdata, {'trg_img': trg_img.cpu().numpy()})

        # Turn the window sums into window means (the entropy loss used to be
        # logged as a raw sum, unlike the other quantities).
        loss_train /= args.print_freq
        loss_val /= args.print_freq
        loss_ent /= args.print_freq
        acc_score /= args.print_freq
        loss_val_list.append(loss_val)

        writer.add_scalar('Training loss', loss_train, i)
        writer.add_scalar('Validation loss', loss_val, i )
        writer.add_scalar('Accuracy', acc_score, i )
        writer.add_scalar('Entropy loss', loss_ent, i)

        sio.savemat(args.matname, {'loss_val': loss_val_list})
        loss_train = 0.0
        loss_val = 0.0
        acc_score = 0.0
        loss_ent = 0.0
        if i + 1 > args.num_steps_stop:
            print('finish training')
            break
        _t['iter time'].tic()



[it 100][src seg loss 2.7609][trgseg loss 2.5345][acc 47.8223][lr 2.4978][127.21s]
[it 200][src seg loss 0.8415][trgseg loss 1.8088][acc 66.5073][lr 2.4955][126.28s]
[it 300][src seg loss 0.9198][trgseg loss 1.1002][acc 59.9584][lr 2.4933][123.14s]
[it 400][src seg loss 1.1816][trgseg loss 2.4951][acc 67.5541][lr 2.4910][125.41s]
[it 500][src seg loss 0.9077][trgseg loss 1.8028][acc 57.5408][lr 2.4888][122.84s]
[it 600][src seg loss 0.7244][trgseg loss 2.2087][acc 60.6657][lr 2.4865][123.86s]
[it 700][src seg loss 0.9114][trgseg loss 1.4534][acc 56.5001][lr 2.4843][122.86s]
[it 800][src seg loss 0.8377][trgseg loss 1.2562][acc 52.4885][lr 2.4820][122.83s]
[it 900][src seg loss 0.9193][trgseg loss 1.0868][acc 65.0925][lr 2.4798][122.85s]
taking snapshot ...
[it 1000][src seg loss 0.4552][trgseg loss 1.5768][acc 66.1088][lr 2.4775][123.83s]
[it 1100][src seg loss 1.7773][trgseg loss 0.7306][acc 67.6585][lr 2.4753][123.48s]
[it 1200][src seg loss 0.3796][trgseg loss 1.3697][acc 76.5866][l

[it 9700][src seg loss 0.3529][trgseg loss 0.8885][acc 81.4334][lr 2.2807][122.84s]
[it 9800][src seg loss 0.7443][trgseg loss 0.0978][acc 80.1922][lr 2.2784][122.67s]
[it 9900][src seg loss 0.2575][trgseg loss 0.5643][acc 76.7667][lr 2.2761][122.65s]
taking snapshot ...
[it 10000][src seg loss 0.1139][trgseg loss 0.3253][acc 74.2726][lr 2.2739][122.80s]
[it 10100][src seg loss 0.3095][trgseg loss 0.6775][acc 72.8204][lr 2.2716][122.86s]
[it 10200][src seg loss 0.3408][trgseg loss 0.6964][acc 79.8117][lr 2.2693][122.64s]
[it 10300][src seg loss 0.4356][trgseg loss 0.6911][acc 83.6226][lr 2.2670][122.61s]
[it 10400][src seg loss 0.3878][trgseg loss 0.9020][acc 71.8662][lr 2.2648][122.68s]
[it 10500][src seg loss 0.4215][trgseg loss 1.4908][acc 78.4951][lr 2.2625][122.42s]
[it 10600][src seg loss 0.6641][trgseg loss 0.4048][acc 82.5312][lr 2.2602][122.64s]
[it 10700][src seg loss 0.2244][trgseg loss 0.3392][acc 78.7932][lr 2.2579][122.83s]
[it 10800][src seg loss 0.3767][trgseg loss 0.20

[it 19200][src seg loss 0.3207][trgseg loss 0.4996][acc 76.2289][lr 2.0636][122.43s]
[it 19300][src seg loss 0.2899][trgseg loss 0.1845][acc 85.2444][lr 2.0613][122.84s]
[it 19400][src seg loss 0.4371][trgseg loss 0.4107][acc 72.5040][lr 2.0590][122.83s]
[it 19500][src seg loss 0.6645][trgseg loss 0.2667][acc 86.5839][lr 2.0567][122.85s]
[it 19600][src seg loss 0.1363][trgseg loss 0.2912][acc 88.9785][lr 2.0544][122.83s]
[it 19700][src seg loss 0.2895][trgseg loss 0.5153][acc 81.5186][lr 2.0521][122.84s]
[it 19800][src seg loss 0.1646][trgseg loss 1.2415][acc 80.7972][lr 2.0498][122.87s]
[it 19900][src seg loss 0.7355][trgseg loss 0.4577][acc 81.1294][lr 2.0475][122.83s]
taking snapshot ...
[it 20000][src seg loss 0.0757][trgseg loss 0.2400][acc 80.3912][lr 2.0452][122.84s]
[it 20100][src seg loss 0.2761][trgseg loss 0.5848][acc 85.8323][lr 2.0429][122.66s]
[it 20200][src seg loss 0.2541][trgseg loss 0.3325][acc 87.5779][lr 2.0406][122.83s]
[it 20300][src seg loss 0.2119][trgseg loss 0

[it 28700][src seg loss 0.3885][trgseg loss 0.4102][acc 83.5791][lr 1.8439][122.87s]
[it 28800][src seg loss 0.0734][trgseg loss 0.2416][acc 92.3395][lr 1.8415][122.84s]
[it 28900][src seg loss 0.4059][trgseg loss 0.5330][acc 83.5086][lr 1.8392][122.84s]
taking snapshot ...
[it 29000][src seg loss 0.1285][trgseg loss 0.3940][acc 88.4198][lr 1.8369][123.02s]
[it 29100][src seg loss 0.3061][trgseg loss 0.3541][acc 78.6684][lr 1.8345][122.86s]
[it 29200][src seg loss 0.1832][trgseg loss 0.9233][acc 77.9103][lr 1.8322][122.64s]
[it 29300][src seg loss 0.3142][trgseg loss 0.2886][acc 82.1202][lr 1.8299][122.84s]
[it 29400][src seg loss 0.2176][trgseg loss 0.3204][acc 88.2285][lr 1.8276][122.83s]
[it 29500][src seg loss 0.2153][trgseg loss 0.7373][acc 78.6698][lr 1.8252][122.67s]
[it 29600][src seg loss 0.1974][trgseg loss 0.3008][acc 86.4877][lr 1.8229][122.84s]
[it 29700][src seg loss 0.2523][trgseg loss 0.2510][acc 78.9207][lr 1.8206][122.85s]
[it 29800][src seg loss 0.3709][trgseg loss 0