In [None]:
# Mount Google Drive at /content/gdrive so files stored on Drive are
# reachable from this Colab runtime.
from google.colab import drive

drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import os.path
import sys
from PIL import Image
import cv2
import glob
import torch.utils.data as data
import imageio as m
import torch.autograd as autograd
from imageio import imread as imread
import imageio

In [None]:
import random
import torchvision.transforms.functional as tf
from PIL import Image, ImageOps

class Compose(object) :
    """Chain several paired (image, mask) augmentations into one callable.

    Args:
        augmentations: iterable of callables, each taking ``(img, mask)`` and
            returning the transformed ``(img, mask)`` pair.

    Calling the instance:
        * If ``img`` is a ``np.ndarray`` the pair is converted to PIL images
          (``"RGB"`` image, ``"L"`` mask), augmented, and converted back to
          ``float64`` image / ``uint8`` mask arrays.
        * Any other input is passed through the augmentations as-is and the
          result is returned without conversion.

    Attributes:
        PIL2Numpy: True when the most recent call converted ndarray input.
    """
    def __init__(self, augmentations) :
        self.augmentations = augmentations
        self.PIL2Numpy = False

    def __call__(self, img, mask) :
        # Decide per call; a sticky flag would wrongly convert later PIL inputs.
        from_numpy = isinstance(img, np.ndarray)
        self.PIL2Numpy = from_numpy

        if from_numpy :
            img = Image.fromarray(img, mode = "RGB")
            mask = Image.fromarray(mask, mode = "L")

        # Run the augmentations for every input type (previously this loop and
        # the return were nested under the ndarray branch, so PIL input
        # silently produced None).
        for a in self.augmentations :
            img, mask = a(img, mask)

        if from_numpy :
            img, mask = np.array(img, np.float64), np.array(mask, dtype = np.uint8)

        return img, mask

class RandomHorizontalFlip(object) :
    """Mirror an image/mask pair left-to-right with probability ``p``.

    Both members of the pair are flipped together so the mask stays aligned
    with the image.
    """
    def __init__(self, p) :
        self.p = p

    def __call__(self, img, mask) :
        should_flip = random.random() < self.p
        if not should_flip :
            return img, mask

        flipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
        flipped_mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        return (flipped_img, flipped_mask)

class RandomRotate(object) :
    """Rotate an image/mask pair by a random angle in ``[-degree, +degree)``.

    Args:
        degree: maximum absolute rotation angle, in degrees.
        mask_fill: value written into mask pixels exposed by the rotation.
            Defaults to 19, the ignore/void train id used by the label tables
            and by the segmentation loss (``ignore_index=19``). The previous
            hard-coded value 11 is the ``person`` train id, which labelled the
            empty corners as pedestrians.

    The image is resampled bilinearly and padded with black; the mask uses
    nearest-neighbour resampling so no new label values are interpolated.
    """
    def __init__(self, degree, mask_fill = 19) :
        self.degree = degree
        self.mask_fill = mask_fill

    def __call__(self, img, mask) :
        # One angle is drawn and shared so image and mask stay aligned.
        rotate_degree = random.random()*2*self.degree - self.degree

        # NOTE(review): `resample` / `fillcolor` are the keyword names this
        # notebook was written against; newer torchvision releases use
        # `interpolation` / `fill` instead - confirm against the installed version.
        return (
            tf.affine(img, translate=(0,0), scale=1.0, angle=rotate_degree, resample=Image.BILINEAR,
                     fillcolor=(0,0,0), shear=0.0 ,) ,
            tf.affine(mask, translate=(0,0), scale=1.0, angle=rotate_degree, resample=Image.NEAREST,
                      fillcolor=self.mask_fill, shear=0.0, ),
        )

class Resize(object) :
    """Resize an image/mask pair to a fixed ``(height, width)``.

    ``size`` is given as ``(height, width)`` while ``img.size`` is read as
    ``(width, height)``. The pair is returned untouched when it already has
    the requested size; otherwise the image is resized bilinearly and the
    mask with nearest-neighbour sampling.
    """
    def __init__(self, size = (512, 1024)) :
        self.size = size

    def __call__(self, img, mask) :
        cur_w, cur_h = img.size
        out_h, out_w = self.size

        if (cur_w, cur_h) == (out_w, out_h) :
            return img, mask

        new_size = (out_w, out_h)
        resized_img = img.resize(new_size, Image.BILINEAR)
        resized_mask = mask.resize(new_size, Image.NEAREST)
        return (resized_img, resized_mask)


class RandomCrop(object) :
    """Cut the same random ``(height, width)`` window out of image and mask.

    ``size`` is ``(height, width)``; ``img.size`` is read as ``(width, height)``.
    When the input already has the requested size it is returned unchanged.
    The crop offset is drawn once and applied to both inputs.
    """
    def __init__(self, size = (128, 128)) :
        self.size = size

    def __call__(self, img, mask) :
        width, height = img.size
        crop_h, crop_w = self.size

        if (width, height) == (crop_w, crop_h) :
            return img, mask

        # Draw the left offset first, then the top offset.
        left = random.randint(0, width - crop_w)
        top = random.randint(0, height - crop_h)
        box = (left, top, left + crop_w, top + crop_h)

        return (img.crop(box), mask.crop(box))

In [None]:
# Label table taken from cityscapesScripts:
# https://github.com/mcordts/cityscapesScripts/blob/878f1d05b1676c669d977a91831ea800482e36c4/cityscapesscripts/helpers/labels.py
from collections import namedtuple

# One record per Cityscapes class. `id` is the value stored in the labelIds
# images, `trainId` the class index used in this notebook (19 = void/ignored).
cityscapesLabel = namedtuple(
    'Label',
    'name id trainId category categoryId hasInstances ignoreInEval color',
)

cityscapes_labels = [
    cityscapesLabel(*row) for row in (
        # name                   id  trainId category       catId  hasInstances ignoreInEval color
        ('unlabeled',             0, 19, 'void',         0, False, True,  (0, 0, 0)),
        # 'ego vehicle' is the first-person vehicle that carries the dashcam.
        ('ego vehicle',           1, 19, 'void',         0, False, True,  (0, 0, 0)),
        ('rectification border',  2, 19, 'void',         0, False, True,  (0, 0, 0)),
        ('out of roi',            3, 19, 'void',         0, False, True,  (0, 0, 0)),
        ('static',                4, 19, 'void',         0, False, True,  (0, 0, 0)),
        ('dynamic',               5, 19, 'void',         0, False, True,  (111, 74, 0)),
        ('ground',                6, 19, 'void',         0, False, True,  (81, 0, 81)),
        ('road',                  7,  0, 'flat',         1, False, False, (128, 64, 128)),
        ('sidewalk',              8,  1, 'flat',         1, False, False, (244, 35, 232)),
        ('parking',               9, 19, 'flat',         1, False, True,  (250, 170, 160)),
        ('rail track',           10, 19, 'flat',         1, False, True,  (230, 150, 140)),
        ('building',             11,  2, 'construction', 2, False, False, (70, 70, 70)),
        ('wall',                 12,  3, 'construction', 2, False, False, (102, 102, 156)),
        ('fence',                13,  4, 'construction', 2, False, False, (190, 153, 153)),
        ('guard rail',           14, 19, 'construction', 2, False, True,  (180, 165, 180)),
        ('bridge',               15, 19, 'construction', 2, False, True,  (150, 100, 100)),
        ('tunnel',               16, 19, 'construction', 2, False, True,  (150, 120, 90)),
        ('pole',                 17,  5, 'object',       3, False, False, (153, 153, 153)),
        ('polegroup',            18, 19, 'object',       3, False, True,  (153, 153, 153)),
        ('traffic light',        19,  6, 'object',       3, False, False, (250, 170, 30)),
        ('traffic sign',         20,  7, 'object',       3, False, False, (220, 220, 0)),
        ('vegetation',           21,  8, 'nature',       4, False, False, (107, 142, 35)),
        ('terrain',              22,  9, 'nature',       4, False, False, (152, 251, 152)),
        ('sky',                  23, 10, 'sky',          5, False, False, (70, 130, 180)),
        ('person',               24, 11, 'human',        6, True,  False, (220, 20, 60)),
        ('rider',                25, 12, 'human',        6, True,  False, (255, 0, 0)),
        ('car',                  26, 13, 'vehicle',      7, True,  False, (0, 0, 142)),
        ('truck',                27, 14, 'vehicle',      7, True,  False, (0, 0, 70)),
        ('bus',                  28, 15, 'vehicle',      7, True,  False, (0, 60, 100)),
        ('caravan',              29, 19, 'vehicle',      7, True,  True,  (0, 0, 90)),
        ('trailer',              30, 19, 'vehicle',      7, True,  True,  (0, 0, 110)),
        ('train',                31, 16, 'vehicle',      7, True,  False, (0, 80, 100)),
        ('motorcycle',           32, 17, 'vehicle',      7, True,  False, (0, 0, 230)),
        ('bicycle',              33, 18, 'vehicle',      7, True,  False, (119, 11, 32)),
        ('license plate',        -1, 19, 'vehicle',      7, False, True,  (0, 0, 142)),
    )
]


In [None]:
# One record per SYNTHIA class: `id` is the value stored in the SYNTHIA label
# images, `trainId` the class index shared with the Cityscapes table
# (19 = void/ignored).
synthiaLabel = namedtuple('Label', 'name id trainId')

synhia_labels = [
    synthiaLabel(*row) for row in (
        # name            id  trainId
        ('void',           0, 19),
        ('sky',            1, 10),
        ('building',       2,  2),
        ('road',           3,  0),
        ('sidewalk',       4,  1),
        ('fence',          5,  4),
        ('vegetation',     6,  8),
        ('pole',           7,  5),
        ('car',            8, 13),
        ('traffic sign',   9,  7),
        ('pedestrian',    10, 11),
        ('bicycle',       11, 18),
        ('motorcycle',    12, 17),
        ('parking slot',  13, 19),
        ('road work',     14, 19),
        ('traffic light', 15,  6),
        ('terrain',       16,  9),
        ('rider',         17, 12),
        ('truck',         18, 14),
        ('bus',           19, 15),
        ('train',         20, 16),
        ('wall',          21,  3),
        ('lanemarking',   22, 19),
    )
]

In [None]:
# Root folder of the Cityscapes data. The Cityscape dataset class joins
# 'leftImg8bit/<split>' and 'gtFine/<split>' onto it. The path is relative,
# so it is resolved against the current working directory.
CITYSCAPES = 'gdrive/My Drive/Colab Notebooks/dataset/cityscape'

class Cityscape(data.Dataset) :         #1024*2048
    """Cityscapes images paired with their ``gtFine`` labelIds masks.

    Args:
        root: dataset root containing ``leftImg8bit/<split>/<city>/`` and
            ``gtFine/<split>/<city>/``.
        mode: ``'train'`` reads the ``train`` split with rotation/flip
            augmentation; any other value reads the ``val`` split with
            resize + crop only.

    Each item is ``(img, target)`` where ``img`` is a float tensor of shape
    (3, H, W) scaled to [0, 1] with the channel order reversed, and ``target``
    is ``{'segmentation': LongTensor (H, W) of train ids, 'classification': 0}``.
    """
    # Train id given to every pixel that is void or has an unknown raw id.
    IGNORE_TRAIN_ID = 19

    def __init__(self, root, mode = 'train') :
        self.root = root
        self.images = []
        self.targets = []
        # Compare by value: `mode is 'train'` tests object identity and only
        # works when the interpreter happens to intern both strings.
        is_train = (mode == 'train')
        self.img_dir = 'train' if is_train else 'val'
        self.imgs_dir = os.path.join(self.root, 'leftImg8bit', self.img_dir)
        self.targets_dir = os.path.join(self.root, 'gtFine', self.img_dir )

        # Sorted listings make the index -> file mapping reproducible;
        # os.listdir returns entries in arbitrary order.
        for city in sorted(os.listdir(self.imgs_dir)) :
            img_dir = os.path.join(self.imgs_dir, city)
            target_dir = os.path.join(self.targets_dir, city)

            for img in sorted(os.listdir(img_dir)) :
                self.images.append(os.path.join(img_dir, img))
                # '<prefix>_leftImg8bit.png' -> '<prefix>_gtFine_labelIds.png'
                target_name = img.split('_leftImg8bit')[0]+'_gtFine_labelIds.png'
                self.targets.append(os.path.join(target_dir, target_name))

        if is_train :
            self.augmentation = Compose([ Resize(), RandomRotate(15), RandomCrop((480, 768)), RandomHorizontalFlip(p=0.5)])

        else :
            self.augmentation = Compose([Resize(), RandomCrop((480, 768))])

    def __getitem__(self, index) :
        """Load, relabel, augment and tensorize the sample at ``index``."""
        img_path = self.images[index]
        target_path = self.targets[index]

        img = imread(img_path)
        img = np.array(img, dtype=np.uint8)

        lbl = imread(target_path)
        # int8 turns a stored 255 into -1, the id of 'license plate'.
        lbl = np.array(lbl, dtype=np.int8)
        lbl = self.mapId(lbl)
        if self.augmentation is not None :
            img, lbl = self.augmentation(img, lbl)
        img, lbl = self.transform(img, lbl)
        target = {'segmentation' : lbl, 'classification' : 0}
        return img, target

    def mapId(self, lbl) :
        """Translate raw Cityscapes ids into train ids.

        Uses a lookup table instead of a per-pixel Python call. Ids that are
        not listed in ``cityscapes_labels`` map to ``IGNORE_TRAIN_ID`` rather
        than raising.

        Returns:
            ``np.uint8`` array with the same shape as ``lbl``.
        """
        ids = np.asarray(lbl).astype(np.int64)

        lowest = min(label.id for label in cityscapes_labels)
        highest = max(label.id for label in cityscapes_labels)
        # Table is offset by `lowest` because the raw ids include -1.
        lut = np.full(highest - lowest + 1, self.IGNORE_TRAIN_ID, dtype=np.uint8)
        for label in cityscapes_labels :
            lut[label.id - lowest] = label.trainId

        known = (ids >= lowest) & (ids <= highest)
        label_img = np.full(ids.shape, self.IGNORE_TRAIN_ID, dtype=np.uint8)
        label_img[known] = lut[ids[known] - lowest]

        return label_img

    def transform(self, img, lbl) :
        """Convert an (H, W, 3) image and (H, W) mask into network tensors.

        The channel axis is reversed, values are divided by 255 and the image
        is rearranged to (3, H, W) float32; the mask becomes int64.
        """
        img = img[:, :, ::-1]
        img = img.astype(np.float64)/255.0

        img = img.transpose(2,0,1)

        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()

        return img, lbl

    def __len__(self) :
        return len(self.images)

# Training split: shuffled, one image per batch, loaded in the main process
# (num_workers=0).
cityscape_trainSet = Cityscape(root = CITYSCAPES, mode = 'train')
cityscape_trainLoader = data.DataLoader( cityscape_trainSet, batch_size = 1, shuffle = True, num_workers=0)

# Cityscape maps every mode other than 'train' to the 'val' folders, so this
# "test" set is built from the validation split.
cityscape_testSet = Cityscape(root = CITYSCAPES, mode = 'test')
cityscape_testLoader = data.DataLoader( cityscape_testSet,  batch_size = 1, shuffle = False, num_workers=0)

In [None]:
# Root folder of the SYNTHIA data; the Synthia dataset class expects
# '<mode>/image' and '<mode>/target' below it. The path is relative, so it is
# resolved against the current working directory.
SYNTHIA = 'gdrive/My Drive/Colab Notebooks/dataset/synthia'
# Download the FreeImage plugin binary; Synthia.__getitem__ reads the label
# files with format='PNG-FI'.
imageio.plugins.freeimage.download()

class Synthia(data.Dataset) :           #760*1280
    """SYNTHIA images paired with their label maps.

    Args:
        root: dataset root containing ``<mode>/image`` and ``<mode>/target``;
            image and target files share the same file name.
        mode: sub-folder to read. ``'train'`` adds rotation/flip augmentation;
            any other value uses resize + crop only.

    Each item is ``(img, target)`` where ``img`` is a float tensor of shape
    (3, H, W) scaled to [0, 1] with the channel order reversed, and ``target``
    is ``{'segmentation': LongTensor (H, W) of train ids, 'classification': 1}``.
    """
    # Train id given to every pixel that is void or has an unknown raw id.
    IGNORE_TRAIN_ID = 19

    def __init__(self, root, mode='train') :
        self.root = root
        self.images = []
        self.targets = []
        self.mode = mode
        self.img_dir = os.path.join(root, mode, 'image')
        self.target_dir = os.path.join(root, mode, 'target')
        # Sorted listing makes the index -> file mapping reproducible;
        # os.listdir returns entries in arbitrary order.
        for img in sorted(os.listdir(self.img_dir)) :
            self.images.append(os.path.join(self.img_dir, img))
            self.targets.append(os.path.join(self.target_dir, img))

        # Compare by value: `mode is 'train'` tests object identity and only
        # works when the interpreter happens to intern both strings.
        if self.mode == 'train' :
            self.augmentation = Compose([ Resize(), RandomRotate(15), RandomCrop((480, 768)), RandomHorizontalFlip(p=0.5)])
        else :
            self.augmentation = Compose([Resize(), RandomCrop((480, 768))])

    def __getitem__(self, index) :
        """Load, relabel, augment and tensorize the sample at ``index``."""
        img_path = self.images[index]
        target_path = self.targets[index]

        img = imread(img_path)
        img = np.array(img, dtype=np.uint8)

        lbl = imread(target_path, format = 'PNG-FI')
        # Only the first channel of the label file is used as the class id.
        lbl = np.array(lbl, dtype=np.uint8)[:, :, 0]
        lbl = self.mapId(lbl)

        if self.augmentation is not None :
            img, lbl = self.augmentation(img, lbl)
        img, lbl = self.transform(img, lbl)
        target = {'segmentation' : lbl, 'classification' : 1}
        return img, target

    def mapId(self, lbl) :
        """Translate raw SYNTHIA ids into train ids.

        Uses a lookup table instead of a per-pixel Python call. Ids that are
        not listed in ``synhia_labels`` map to ``IGNORE_TRAIN_ID`` rather than
        raising.

        Returns:
            ``np.uint8`` array with the same shape as ``lbl``.
        """
        ids = np.asarray(lbl).astype(np.int64)

        lowest = min(label.id for label in synhia_labels)
        highest = max(label.id for label in synhia_labels)
        lut = np.full(highest - lowest + 1, self.IGNORE_TRAIN_ID, dtype=np.uint8)
        for label in synhia_labels :
            lut[label.id - lowest] = label.trainId

        known = (ids >= lowest) & (ids <= highest)
        label_img = np.full(ids.shape, self.IGNORE_TRAIN_ID, dtype=np.uint8)
        label_img[known] = lut[ids[known] - lowest]

        return label_img

    def transform(self, img, lbl) :
        """Convert an (H, W, 3) image and (H, W) mask into network tensors.

        The channel axis is reversed, values are divided by 255 and the image
        is rearranged to (3, H, W) float32; the mask becomes int64.
        """
        img = img[:, :, ::-1]
        img = img.astype(np.float64)/255.0

        img = img.transpose(2,0,1)

        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()

        return img, lbl

    def __len__(self) :
        return len(self.images)

# Training data: read from '<SYNTHIA>/train', shuffled, one image per batch,
# loaded in the main process (num_workers=0).
synthia_trainSet = Synthia(root = SYNTHIA, mode = 'train')
synthia_trainLoader = data.DataLoader( synthia_trainSet, batch_size = 1, shuffle = True, num_workers=0)

# Evaluation data: read from '<SYNTHIA>/test', fixed order.
synthia_testSet = Synthia(root = SYNTHIA, mode = 'test')
synthia_testLoader = data.DataLoader( synthia_testSet,  batch_size = 1, shuffle = False, num_workers=0)

Imageio: 'libfreeimage-3.16.0-linux64.so' was not found on your computer; downloading it now.
Try 1. Download from https://github.com/imageio/imageio-binaries/raw/master/freeimage/libfreeimage-3.16.0-linux64.so (4.6 MB)
Downloading: 8192/4830080 bytes (0.2%)1802240/4830080 bytes (37.3%)4830080/4830080 bytes (100.0%)
  Done
File saved as /root/.imageio/freeimage/libfreeimage-3.16.0-linux64.so.


In [None]:
class AverageMeter(object) :
    """Computes and stores the average and current value.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: running sum of ``val * n``.
        count: running sum of ``n``.
        avg: ``sum / count`` (0 while nothing has been counted).
    """
    def __init__(self) :
        self.reset()

    def reset(self) :
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1) :
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val*n
        self.count += n
        # True division: floor division truncated the mean (1 and 2 gave 1).
        # Guard the empty case so update(x, n=0) on a fresh meter cannot
        # divide by zero.
        self.avg = self.sum/self.count if self.count else 0

def intersectionAndUnion(output, target, K, ignore_idx = 19) :
    """Per-class intersection, union and target pixel counts.

    Args:
        output: score tensor with the class axis at dim 1 (e.g. N x C x H x W);
            the prediction is the arg-max over that axis.
        target: integer label tensor/array matching the prediction's shape.
        K: number of classes; only labels in ``[0, K)`` are counted.
        ignore_idx: label value excluded from all three counts.

    Returns:
        ``(area_intersection, area_union, area_target)``, each an integer
        ``np.ndarray`` of length ``K``.
    """
    pred = output.max(1, keepdim=False)[1]
    pred = pred.detach().cpu().numpy().reshape(-1)

    if torch.is_tensor(target) :
        target = target.detach().cpu().numpy()
    target = np.asarray(target).reshape(-1)

    # Drop ignored pixels before counting. The previous np.histogram version
    # used bins 0..K whose last bin [K-1, K] is closed on the right, so with
    # ignore_idx == K every ignored pixel was counted as class K-1.
    keep = target != ignore_idx
    pred, target = pred[keep], target[keep]

    def _class_counts(labels) :
        # Count occurrences of each class id, skipping anything outside [0, K).
        labels = labels[(labels >= 0) & (labels < K)].astype(np.int64)
        return np.bincount(labels, minlength=K)

    area_intersection = _class_counts(pred[pred == target])
    area_output = _class_counts(pred)
    area_target = _class_counts(target)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target

In [None]:
class _ASPP(nn.Module) :
    def __init__(self, in_plane, out_plane, kernel_size, padding, dilation) :
        super(_ASPP, self).__init__()
        self.atrous_conv = nn.Conv2d(in_channels=in_plane, out_channels=out_plane, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
        self.bn = nn.BatchNorm2d(out_plane)
        self.relu = nn.ReLU()

    def forward(self, x) :
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

class ASPP(nn.Module) :
    """Atrous spatial pyramid pooling over a 512-channel feature map.

    Five parallel branches are computed from the same input: a 1x1 branch,
    three 3x3 branches with dilation 6 / 12 / 18, and a pooled branch
    (4x4 average pooling -> BN -> ReLU -> 1x1 conv -> BN -> ReLU). Their
    outputs (5 x 256 = 1280 channels) are concatenated and fused back to 256
    channels by a 1x1 convolution.

    The pooled branch is resized to the input's spatial size, so feature maps
    whose sides are not multiples of 4 no longer break the concatenation. For
    sides that are multiples of 4 this matches the former fixed x4 upsampling.
    Parameter names (state_dict keys) are unchanged.
    """
    def __init__(self) :
        super(ASPP, self).__init__()
        dilation = [1, 6, 12, 18]
        in_plane=512
        out_plane=256

        self.aspp1 = _ASPP(in_plane=in_plane, out_plane=out_plane, kernel_size=1, padding=0, dilation=dilation[0])
        # padding == dilation keeps the spatial size for 3x3 kernels.
        self.aspp2 = _ASPP(in_plane=in_plane, out_plane=out_plane, kernel_size=3, padding=dilation[1], dilation=dilation[1])
        self.aspp3 = _ASPP(in_plane=in_plane, out_plane=out_plane, kernel_size=3, padding=dilation[2], dilation=dilation[2])
        self.aspp4 = _ASPP(in_plane=in_plane, out_plane=out_plane, kernel_size=3, padding=dilation[3], dilation=dilation[3])
        # Upsampling back to the input size happens in forward().
        self.gap = nn.Sequential( nn.AvgPool2d(kernel_size=4),
                                 nn.BatchNorm2d(in_plane),
                                 nn.ReLU() ,
                                 nn.Conv2d(in_channels=in_plane, out_channels=out_plane, kernel_size=1),
                                 nn.BatchNorm2d(out_plane),
                                 nn.ReLU() )

        self.conv = nn.Sequential ( nn.Conv2d(in_channels=5*out_plane, out_channels=out_plane, kernel_size=1),
                                   nn.BatchNorm2d(out_plane),
                                   nn.ReLU())

    def forward(self, x) :
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        # align_corners=True reproduces nn.UpsamplingBilinear2d.
        x5 = F.interpolate(self.gap(x), size=x.shape[2:], mode='bilinear', align_corners=True)

        x = torch.cat( (x1, x2, x3, x4, x5), dim=1)
        x = self.conv(x)

        return x

In [None]:
# Template ResNet-18 backbone (randomly initialised, pretrained=False).
# EncoderModule copies it instead of using it directly.
origin = models.resnet18(pretrained=False, progress=True)

class EncoderModule(nn.Module) :
    """ResNet-18 based encoder with dilated last stages followed by ASPP.

    ``forward`` returns ``(features, identity)``: the ASPP output and the
    ``layer1`` activation that is kept as a low-level skip feature.

    Every instance works on its own deep copy of ``origin``. The layers were
    previously taken from ``origin`` itself and modified in place, so several
    encoders shared the same parameter objects and building a second encoder
    replaced convolutions inside the first one.
    """
    def __init__(self) :
        super(EncoderModule, self).__init__()
        import copy

        # Private copy: the in-place edits in _custom_layer stay local.
        child_list = list(copy.deepcopy(origin).children())
        # First four children form the stem; children 4..7 are the stages.
        self.layer0 = nn.Sequential(
            *child_list[:4]
        )
        self.layer1 = nn.Sequential(
            *child_list[4]
        )
        self.layer2 = nn.Sequential(
            *child_list[5]
        )

        # Stages 3 and 4 get stride-1 convolutions with growing dilation
        # (padding == dilation keeps the spatial size of 3x3 kernels).
        self.layer3 = self._custom_layer(child_list[6], 128, 256, 3, [1,1], [1,2], [1,2])

        self.layer4 = self._custom_layer(child_list[7], 256, 512, 3, [1,1], [2,4], [2,4])
        self.aspp = ASPP()

    def _custom_layer(self, origin_layer, inplanes, outplanes, kernel_size, strides, dilation, padding, downsampling = None) :
        """Swap the convolutions of a two-block stage for dilated ones.

        ``origin_layer`` is modified in place: both convs of block 0 use
        ``strides[0]`` / ``padding[0]`` / ``dilation[0]``, both convs of
        block 1 use index 1, and block 0's shortcut becomes a 1x1 conv + BN
        mapping ``inplanes`` to ``outplanes``. The new convolutions are freshly
        initialised. ``downsampling`` is accepted but not used.

        Returns:
            ``nn.Sequential`` wrapping the blocks of the modified stage.
        """
        origin_layer[0].conv1 = nn.Conv2d(in_channels=inplanes, out_channels=outplanes, kernel_size=kernel_size, stride=strides[0],
                                    padding=padding[0], dilation=dilation[0], bias=False)
        origin_layer[0].conv2 = nn.Conv2d(in_channels=outplanes, out_channels=outplanes, kernel_size=kernel_size, stride=strides[0],
                                    padding=padding[0], dilation=dilation[0], bias=False)
        origin_layer[0].downsample = nn.Sequential(
            nn.Conv2d(in_channels=inplanes, out_channels=outplanes, kernel_size=1, bias=False),
            nn.BatchNorm2d(num_features=outplanes)
        )
        origin_layer[1].conv1 = nn.Conv2d(in_channels=outplanes, out_channels=outplanes, kernel_size=kernel_size, stride=strides[1],
                                    padding=padding[1], dilation=dilation[1], bias=False)
        origin_layer[1].conv2 = nn.Conv2d(in_channels=outplanes, out_channels=outplanes, kernel_size=kernel_size, stride=strides[1],
                                    padding=padding[1], dilation=dilation[1], bias=False)

        return nn.Sequential(*origin_layer)

    def forward(self, x) :
        x = self.layer0(x)
        x = self.layer1(x)
        # Keep the layer1 activation as the low-level skip feature.
        identity = x
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.aspp(x)

        return x, identity

In [None]:
class DomainClassifier(nn.Module) :
    """Two-way domain classifier on 256-channel encoder features.

    A 1x1 convolution reduces the channels to 64, the result is flattened and
    fed to a linear layer with ``64*60*96`` inputs, so the feature map must be
    60 x 96 spatially.
    """
    def __init__(self) :
        super().__init__()
        self.conv = nn.Conv2d(256, 64, kernel_size=1)
        self.fc = nn.Linear(64*60*96, 2)

    def forward(self, x) :
        reduced = self.conv(x)
        batch_size = reduced.size(0)
        flattened = reduced.view(batch_size, -1)
        return self.fc(flattened)

In [None]:
class DecoderModule(nn.Module) :
    """Segmentation head that merges encoder features with a skip feature.

    ``forward(x, _x)``: ``_x`` (64 channels) is projected to 48 channels,
    ``x`` (256 channels) is upsampled x2, both are concatenated (304 channels)
    and passed through two 3x3 conv-BN-ReLU stages, a 1x1 conv to 19 class
    scores and a final x4 bilinear upsampling.
    """
    def __init__(self) :
        super().__init__()
        self.upsampling = nn.UpsamplingBilinear2d(scale_factor=2)

        skip_projection = [
            nn.Conv2d(64, 48, kernel_size=1),
            nn.BatchNorm2d(48),
            nn.ReLU(),
        ]
        self.beforeConcat = nn.Sequential(*skip_projection)

        head = [
            nn.Conv2d(304, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 19, kernel_size=1),
            nn.UpsamplingBilinear2d(scale_factor=4),
        ]
        self.decoder = nn.Sequential(*head)

    def forward(self, x, _x) :
        low_level = self.beforeConcat(_x)
        upsampled = self.upsampling(x)
        # Skip feature first, encoder feature second.
        merged = torch.cat((low_level, upsampled), dim=1)
        return self.decoder(merged)

In [None]:
def FDA(src_img, tgt_img, beta=0.001) :
    """Fourier domain adaptation: give ``src_img`` the low-frequency amplitude of ``tgt_img``.

    Args:
        src_img: CPU tensor whose last two axes are (H, W); its phase and
            high-frequency amplitude are kept.
        tgt_img: CPU tensor of the same shape; supplies the low-frequency
            amplitude.
        beta: half-size of the swapped window as a fraction of the half
            height / half width of the spectrum.

    Returns:
        Tensor with the shape of ``src_img``, divided by its maximum and
        clipped to [0, 1].
    """
    src_img, tgt_img = src_img.numpy(), tgt_img.numpy()

    f_src = np.fft.fft2(src_img, axes=(-2,-1))
    f_tgt = np.fft.fft2(tgt_img, axes=(-2,-1))

    src_amp, src_phase = np.abs(f_src), np.angle(f_src)
    tgt_amp = np.abs(f_tgt)

    # Each amplitude spectrum is scaled by its own peak; an all-zero image is
    # left untouched instead of being divided by zero.
    src_peak, tgt_peak = np.max(src_amp), np.max(tgt_amp)
    if src_peak > 0 :
        src_amp = src_amp/src_peak
    if tgt_peak > 0 :
        tgt_amp = tgt_amp/tgt_peak

    # Move the zero frequency to the centre so the low-frequency band is one
    # contiguous window.
    src_amp = np.fft.fftshift(src_amp, axes=(-2,-1))
    tgt_amp = np.fft.fftshift(tgt_amp, axes=(-2,-1))

    h, w = src_amp.shape[-2:]
    ch, cw = h//2, w//2
    bh, bw = int(beta*ch), int(beta*cw)

    # Window [c-b, c+b] inclusive, i.e. symmetric around the DC bin at (ch, cw).
    # The former [c-b, c+b) window swapped frequency -k without +k, which
    # breaks the conjugate symmetry of a real image's spectrum (and was empty
    # for b == 0).
    rows = slice(max(ch-bh, 0), ch+bh+1)
    cols = slice(max(cw-bw, 0), cw+bw+1)

    fda_amp = src_amp.copy()
    fda_amp[..., rows, cols] = tgt_amp[..., rows, cols]

    fda_amp = np.fft.ifftshift(fda_amp, axes=(-2,-1))

    # Recombine the mixed amplitude with the source phase.
    res = fda_amp*np.exp(1j*src_phase)
    res = np.fft.ifft2(res, axes=(-2,-1)).real

    peak = np.max(res)
    if peak > 0 :
        res = res/peak
    res_img = np.clip(res, 0, 1)

    return torch.from_numpy(res_img)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
# def trainDA(epoch) :
#     encoder = EncoderModule()
#     domain_cls = DomainClassifier()
#     decoder = DecoderModule()

#     optimizer_encoder = optim.Adam(encoder.parameters(), lr = 0.00001)

#     optimizer_seg = optim.SGD(list(encoder.parameters())+list(decoder.parameters()), lr = 0.1, weight_decay = 0.0001, momentum = 0.9)
#     criterion_seg = nn.CrossEntropyLoss(ignore_index=19, reduction='mean')

#     optimizer_cls = optim.Adam(domain_cls.parameters(), lr = 0.00001)

#     intersection_meter = AverageMeter()
#     union_meter = AverageMeter()
#     target_meter = AverageMeter()

#     domain_cls.train()
#     encoder.train()
#     decoder.train()

#     correct = 0

#     synthia_trainLoader_iter = iter(synthia_trainLoader)
#     cityscape_trainLoader_iter = iter(cityscape_trainLoader)

#     conf = 0.
#     cls = 0.
#     for i in range(75300, epoch):

#         if i > 0 :
#             encoder.load_state_dict(torch.load("gdrive/My Drive/Colab Notebooks/uda_encoder_fda.pth"))
#             domain_cls.load_state_dict(torch.load("gdrive/My Drive/Colab Notebooks/uda_domain_cls_fda.pth"))
#             # decoder.load_state_dict(torch.load("gdrive/My Drive/Colab Notebooks/uda_decoder.pth"))

#         try :
#             src_data, src_target = next(synthia_trainLoader_iter)
#         except StopIteration :
#             synthia_trainLoader_iter = iter(synthia_trainLoader)
#             src_data, src_target = next(synthia_trainLoader_iter)

#         try :
#             tgt_data, tgt_target = next(cityscape_trainLoader_iter)
#         except StopIteration :
#             cityscape_trainLoader_iter = iter(cityscape_trainLoader)
#             tgt_data, tgt_target = next(cityscape_trainLoader_iter)

#         src_cls_target = src_target['classification']
#         src_seg_target = src_target['segmentation']
#         tgt_cls_target = tgt_target['classification']

#         src_data = FDA(src_data[0], tgt_data[0], 0.09)
#         src_data = src_data.unsqueeze(0).float()

#         if torch.cuda.is_available() :
#             criterion_seg = criterion_seg.cuda()
#             encoder = encoder.cuda()
#             domain_cls = domain_cls.cuda()
#             decoder = decoder.cuda()

#         src_data, src_cls_target, src_seg_target = src_data.to(device), src_cls_target.to(device), src_seg_target.to(device)
#         tgt_data, tgt_cls_target = tgt_data.to(device), tgt_cls_target.to(device)

#         optimizer_cls.zero_grad()
#         optimizer_encoder.zero_grad()
#         optimizer_seg.zero_grad()

#         # segmentation : encoder & decoder

#         # src_feature, src_l1 = encoder(src_data)
#         # tgt_feature, _ = encoder(tgt_data)

#         # src_seg = decoder(src_feature, src_l1)

#         # src_seg_loss = criterion_seg(src_seg, src_seg_target)
#         # src_seg_loss.backward(retain_graph=True)
#         # optimizer_seg.step()

#         # intersection, union, segmentation_target = intersectionAndUnion(output=src_seg.cpu(), target=src_seg_target.cpu(), K=19, ignore_idx=19)
#         # intersection_meter.update(intersection), union_meter.update(union), target_meter.update(segmentation_target)
#         # accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)

#         # classification_loss : fc only
#         src_feature, _ = encoder(src_data)
#         tgt_feature, _ = encoder(tgt_data)
#         src_cls = domain_cls(src_feature.detach())
#         tgt_cls = domain_cls(tgt_feature.detach())
#         src_cls_loss = F.cross_entropy(src_cls, src_cls_target)
#         tgt_cls_loss = F.cross_entropy(tgt_cls, tgt_cls_target)
#         cls_loss = (src_cls_loss + tgt_cls_loss)/2.0
#         cls_loss.backward()
#         optimizer_cls.step()

#         # confusion loss : encoder

#         src_feature, _ = encoder(src_data)
#         tgt_feature, _ = encoder(tgt_data)

#         output_src_domain_cls = domain_cls(src_feature.detach())
#         output_src_domain_cls = F.softmax(output_src_domain_cls, dim=1)

#         output_tgt_domain_cls = domain_cls(tgt_feature.detach())
#         output_tgt_domain_cls = F.softmax(output_tgt_domain_cls, dim = 1)

#         uni_distribution = torch.FloatTensor(output_src_domain_cls.size()).uniform_(0, 1)
#         if torch.cuda.is_available() :
#             uni_distribution = uni_distribution.cuda()
#         conf_src_loss = -0.5 * ( torch.sum(uni_distribution * torch.log(output_src_domain_cls)))/float(output_src_domain_cls.size(0))
#         conf_tgt_loss = -0.5 * (torch.sum(uni_distribution * torch.log(output_tgt_domain_cls))) / float(output_tgt_domain_cls.size(0))
#         conf_loss = 0.5 * (conf_src_loss + conf_tgt_loss)
#         conf_loss.backward()
#         optimizer_encoder.step()

#         #### TODO: the `correct` count still needs to be added
#         conf += conf_loss.data
#         cls += cls_loss.data

#         if i % 100 == 0 :
#             # print('Train Epoch : {} [{}/{} ({:.0f}%)]\tConfusion Loss : {:.6f}\tClassification Loss : {:.6f}\tAccuracy : {:.6f}'.format(
#             #         epoch, batch_idx * len(src_data), len(synthia_trainLoader.dataset),
#             #         100.*batch_idx / len(synthia_trainLoader), conf_loss.data, cls_loss.data, accuracy))

#             print('Train DA Epoch[{}/{} ({:.0f}%)]\tConfusion Loss : {:.6f}\tClassification Loss : {:.6f}'.format(
#                     i, epoch, 100.*i/epoch, conf/100., cls/100.))
#             conf = 0.
#             cls = 0.
#         # iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
#         # accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
#         # mIOU = np.mean(iou_class)
#         # mAcc = np.mean(accuracy_class)
#         # allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

#         # print('Train Epoch : {} mIOU : {:.6f} | mACC : {:.6f}% | ACC : {:.6f}%' .format(epoch, mIOU, mAcc, allAcc))

#         torch.save(encoder.cpu().state_dict(), "gdrive/My Drive/Colab Notebooks/uda_encoder_fda.pth")
#         torch.save(domain_cls.cpu().state_dict(), "gdrive/My Drive/Colab Notebooks/uda_domain_cls_fda.pth")
#         # torch.save(decoder.cpu().state_dict(), "gdrive/My Drive/Colab Notebooks/uda_decoder.pth")

# trainDA(103400)

# ############################ Save LOSS #############################################



In [None]:
def trainSEG(start, max_epoch, save_every=100) :
    """Train the segmentation encoder/decoder on FDA-processed source images.

    Every iteration draws one batch from `synthia_trainLoader` and one from
    `cityscape_trainLoader`, passes the first image of each through `FDA(...)`
    (helper defined elsewhere in this notebook) and takes one SGD step of the
    encoder + decoder against the source segmentation labels.

    Args:
        start: first iteration index. When greater than 0 the weights are resumed
            from the checkpoint files written by an earlier run.
        max_epoch: iteration index to stop at (exclusive). Despite the name, one
            "epoch" here is a single optimizer step.
        save_every: checkpoints are written whenever the iteration index is a
            multiple of this value, and once more after the last iteration.

    Raises:
        ValueError: if `save_every` is smaller than 1.
    """
    encoder_path = "gdrive/My Drive/Colab Notebooks/uda_encoder_fda_seg.pth"
    decoder_path = "gdrive/My Drive/Colab Notebooks/uda_decoder_fda_seg.pth"
    log_every = 100

    if save_every < 1 :
        raise ValueError("save_every must be >= 1")
    if start >= max_epoch :
        # Nothing to train; avoid touching the checkpoint files.
        return

    encoder = EncoderModule()
    decoder = DecoderModule()

    # Resume once, up front. The previous version re-read the checkpoint it had
    # just written on every single iteration, which only cost disk I/O.
    if start > 0 :
        encoder.load_state_dict(torch.load(encoder_path))
        decoder.load_state_dict(torch.load(decoder_path))

    use_cuda = torch.cuda.is_available()
    if use_cuda :
        encoder = encoder.to(device)
        decoder = decoder.to(device)

    seg_optim = optim.SGD(list(encoder.parameters())+list(decoder.parameters()), lr=0.1, weight_decay=0.001, momentum=0.9)
    # Label 19 is excluded from the loss (and from the IoU statistics below).
    criterion_seg = nn.CrossEntropyLoss(ignore_index=19, reduction='mean')

    encoder.train()
    decoder.train()

    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    srcData_iter = iter(synthia_trainLoader)
    tgtData_iter = iter(cityscape_trainLoader)

    def next_batch(iterator, loader) :
        """Return (batch, iterator), restarting the loader once it is exhausted."""
        try :
            return next(iterator), iterator
        except StopIteration :
            iterator = iter(loader)
            return next(iterator), iterator

    def save_checkpoint() :
        """Write both state dicts as CPU tensors, then put the models back on the GPU."""
        torch.save(encoder.cpu().state_dict(), encoder_path)
        torch.save(decoder.cpu().state_dict(), decoder_path)
        if use_cuda :
            encoder.to(device)
            decoder.to(device)

    last_saved = None

    for epoch in range(start, max_epoch) :
        seg_optim.zero_grad()

        #get data and label
        (src_data, src_label), srcData_iter = next_batch(srcData_iter, synthia_trainLoader)
        (tgt_data, _), tgtData_iter = next_batch(tgtData_iter, cityscape_trainLoader)

        src_seg_label = src_label['segmentation']

        # Only the first image of each batch goes through FDA; the result is
        # re-batched to size 1. NOTE(review): this assumes the source loader uses
        # batch_size=1 so that src_seg_label still matches - confirm.
        src_data = FDA(src_data[0], tgt_data[0], 0.09)
        src_data = src_data.unsqueeze(0).float()

        #to GPU
        if use_cuda :
            src_data, src_seg_label = src_data.cuda(), src_seg_label.cuda()

        #train segmentation
        src_feature, shortcut = encoder(src_data)
        result = decoder(src_feature, shortcut)

        loss_seg = criterion_seg(result, src_seg_label)
        # The graph is used for exactly one backward pass, so it is not retained.
        loss_seg.backward()
        seg_optim.step()

        intersection, union, segmentation_target = intersectionAndUnion(output=result.detach().cpu(), target=src_seg_label.cpu(), K=19, ignore_idx=19)
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(segmentation_target)

        #print and reset iou score
        if epoch % log_every == 0 :
            iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
            accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
            mIOU = np.mean(iou_class)
            mAcc = np.mean(accuracy_class)
            allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

            print('Train Epoch : {} mIOU : {:.6f} | mACC : {:.6f}% | ACC : {:.6f}%' .format(epoch, mIOU, mAcc, allAcc))

            intersection_meter.reset()
            union_meter.reset()
            target_meter.reset()

        if epoch % save_every == 0 :
            save_checkpoint()
            last_saved = epoch

    # Make sure the final weights are on disk even if the last step was not a
    # multiple of save_every.
    if last_saved != max_epoch - 1 :
        save_checkpoint()

trainSEG(128100, 150000)

Train Epoch : 128100 mIOU : 0.202136 | mACC : 0.286952% | ACC : 0.620633%
Train Epoch : 128200 mIOU : 0.215855 | mACC : 0.270555% | ACC : 0.643154%
Train Epoch : 128300 mIOU : 0.199514 | mACC : 0.252171% | ACC : 0.621723%
Train Epoch : 128400 mIOU : 0.210214 | mACC : 0.263120% | ACC : 0.643615%
Train Epoch : 128500 mIOU : 0.200234 | mACC : 0.253148% | ACC : 0.611460%
Train Epoch : 128600 mIOU : 0.201650 | mACC : 0.256468% | ACC : 0.613471%
Train Epoch : 128700 mIOU : 0.203854 | mACC : 0.258130% | ACC : 0.615303%
Train Epoch : 128800 mIOU : 0.196764 | mACC : 0.250727% | ACC : 0.611858%
Train Epoch : 128900 mIOU : 0.202617 | mACC : 0.255752% | ACC : 0.627975%
Train Epoch : 129000 mIOU : 0.204788 | mACC : 0.258164% | ACC : 0.629419%
Train Epoch : 129100 mIOU : 0.199570 | mACC : 0.254666% | ACC : 0.598335%
Train Epoch : 129200 mIOU : 0.209159 | mACC : 0.264804% | ACC : 0.634102%
Train Epoch : 129300 mIOU : 0.213801 | mACC : 0.266792% | ACC : 0.639733%
Train Epoch : 129400 mIOU : 0.214962 |

In [None]:
# Domain-classifier accuracy on batches drawn from synthia_testLoader.
N_SYNTHIA_TEST_BATCHES = 1343        # upper bound on the number of batches evaluated

srcData_iter = iter(synthia_testLoader)

encoder = EncoderModule()
domain_cls = DomainClassifier()
decoder = DecoderModule()            # instantiated as before; not used by this check

if torch.cuda.is_available() :
    encoder = encoder.cuda()
    domain_cls = domain_cls.cuda()
    decoder = decoder.cuda()

# Load the trained weights once. They do not change during evaluation, so
# re-reading both files on every loop iteration was pure disk I/O.
encoder.load_state_dict(torch.load("gdrive/My Drive/Colab Notebooks/uda_encoder_fda_once.pth"))
domain_cls.load_state_dict(torch.load("gdrive/My Drive/Colab Notebooks/uda_domain_cls_once.pth"))

encoder.eval()
domain_cls.eval()
decoder.eval()

correct = torch.FloatTensor([0])
wrong = torch.FloatTensor([0])
total = torch.FloatTensor([0])

# no_grad: nothing is back-propagated here, so no autograd graph is needed.
with torch.no_grad() :
    # zip() stops at whichever runs out first, so a loader with fewer batches than
    # the bound ends the loop instead of raising StopIteration.
    for i, (src_data, src_label) in zip(range(N_SYNTHIA_TEST_BATCHES), srcData_iter) :
        src_cls_label = src_label['classification']

        if torch.cuda.is_available() :
            src_data, src_cls_label = src_data.cuda(), src_cls_label.cuda()

        src_feature, _ = encoder(src_data)
        output = domain_cls(src_feature)

        # Predicted domain = index of the largest classifier output.
        values, idx = output.max(dim=1)
        correct += torch.sum(src_cls_label == idx).float().cpu()
        wrong += torch.sum(src_cls_label != idx).float().cpu()

        total += src_cls_label.size(0)
        print(i, '/', str(N_SYNTHIA_TEST_BATCHES), 'correct : ', correct.data, 'wrong : ', wrong.data, 'total : ', total.data)

print("Test Data Accuracy: {}%".format(100*(correct/total).numpy()))

0 / 1343 correct :  tensor([1.]) wrong :  tensor([0.]) total :  tensor([1.])
1 / 1343 correct :  tensor([2.]) wrong :  tensor([0.]) total :  tensor([2.])
2 / 1343 correct :  tensor([3.]) wrong :  tensor([0.]) total :  tensor([3.])
3 / 1343 correct :  tensor([4.]) wrong :  tensor([0.]) total :  tensor([4.])
4 / 1343 correct :  tensor([5.]) wrong :  tensor([0.]) total :  tensor([5.])
5 / 1343 correct :  tensor([6.]) wrong :  tensor([0.]) total :  tensor([6.])
6 / 1343 correct :  tensor([7.]) wrong :  tensor([0.]) total :  tensor([7.])
7 / 1343 correct :  tensor([8.]) wrong :  tensor([0.]) total :  tensor([8.])
8 / 1343 correct :  tensor([9.]) wrong :  tensor([0.]) total :  tensor([9.])
9 / 1343 correct :  tensor([10.]) wrong :  tensor([0.]) total :  tensor([10.])
10 / 1343 correct :  tensor([11.]) wrong :  tensor([0.]) total :  tensor([11.])
11 / 1343 correct :  tensor([12.]) wrong :  tensor([0.]) total :  tensor([12.])
12 / 1343 correct :  tensor([13.]) wrong :  tensor([0.]) total :  te

In [None]:
# Domain-classifier accuracy on batches drawn from cityscape_testLoader.
N_CITYSCAPE_TEST_BATCHES = 500        # upper bound on the number of batches evaluated

tgtData_iter = iter(cityscape_testLoader)

encoder = EncoderModule()
domain_cls = DomainClassifier()
decoder = DecoderModule()             # instantiated as before; not used by this check

if torch.cuda.is_available() :
    encoder = encoder.cuda()
    domain_cls = domain_cls.cuda()
    decoder = decoder.cuda()

# Load the trained weights once. They do not change during evaluation, so
# re-reading both files on every loop iteration was pure disk I/O.
encoder.load_state_dict(torch.load("gdrive/My Drive/Colab Notebooks/uda_encoder_fda_once.pth"))
domain_cls.load_state_dict(torch.load("gdrive/My Drive/Colab Notebooks/uda_domain_cls_once.pth"))

encoder.eval()
domain_cls.eval()
decoder.eval()

correct = torch.FloatTensor([0])
wrong = torch.FloatTensor([0])
total = torch.FloatTensor([0])

# no_grad: nothing is back-propagated here, so no autograd graph is needed.
with torch.no_grad() :
    # zip() stops at whichever runs out first, so a loader with fewer batches than
    # the bound ends the loop instead of raising StopIteration.
    for i, (tgt_data, tgt_label) in zip(range(N_CITYSCAPE_TEST_BATCHES), tgtData_iter) :
        tgt_cls_label = tgt_label['classification']

        if torch.cuda.is_available() :
            tgt_data, tgt_cls_label = tgt_data.cuda(), tgt_cls_label.cuda()

        tgt_feature, _ = encoder(tgt_data)
        output = domain_cls(tgt_feature)

        # Predicted domain = index of the largest classifier output.
        values, idx = output.max(dim=1)
        correct += torch.sum(tgt_cls_label == idx).float().cpu()
        wrong += torch.sum(tgt_cls_label != idx).float().cpu()

        total += tgt_cls_label.size(0)
        print(i, '/', str(N_CITYSCAPE_TEST_BATCHES), 'correct : ', correct.data, 'wrong : ', wrong.data, 'total : ', total.data)

print("Test Data Accuracy: {}%".format(100*(correct/total).numpy()))

0 / 500 correct :  tensor([1.]) wrong :  tensor([0.]) total :  tensor([1.])
1 / 500 correct :  tensor([2.]) wrong :  tensor([0.]) total :  tensor([2.])
2 / 500 correct :  tensor([2.]) wrong :  tensor([1.]) total :  tensor([3.])
3 / 500 correct :  tensor([3.]) wrong :  tensor([1.]) total :  tensor([4.])
4 / 500 correct :  tensor([4.]) wrong :  tensor([1.]) total :  tensor([5.])
5 / 500 correct :  tensor([5.]) wrong :  tensor([1.]) total :  tensor([6.])
6 / 500 correct :  tensor([6.]) wrong :  tensor([1.]) total :  tensor([7.])
7 / 500 correct :  tensor([7.]) wrong :  tensor([1.]) total :  tensor([8.])
8 / 500 correct :  tensor([8.]) wrong :  tensor([1.]) total :  tensor([9.])
9 / 500 correct :  tensor([9.]) wrong :  tensor([1.]) total :  tensor([10.])
10 / 500 correct :  tensor([10.]) wrong :  tensor([1.]) total :  tensor([11.])
11 / 500 correct :  tensor([11.]) wrong :  tensor([1.]) total :  tensor([12.])
12 / 500 correct :  tensor([12.]) wrong :  tensor([1.]) total :  tensor([13.])
13

In [None]:
criterion = nn.CrossEntropyLoss(ignore_index=19, reduction='mean')

def test_synthia() :
    """Evaluate the saved encoder/decoder on `synthia_testLoader` and print IoU/accuracy.

    Loads the `*_fda_once.pth` encoder and decoder checkpoints, runs every batch of
    the test loader through the network without gradient tracking, accumulates
    per-class intersection/union/target counts (19 classes, label 19 ignored) and
    prints mean IoU, mean class accuracy and overall pixel accuracy.
    Returns nothing.
    """
    encoder = EncoderModule()
    decoder = DecoderModule()

    encoder.load_state_dict(torch.load("gdrive/My Drive/Colab Notebooks/uda_encoder_fda_once.pth"))
    decoder.load_state_dict(torch.load("gdrive/My Drive/Colab Notebooks/uda_decoder_fda_once.pth"))

    encoder.eval()
    decoder.eval()

    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    if torch.cuda.is_available() :
        encoder = encoder.cuda()
        decoder = decoder.cuda()

    src_seg_loss = 0

    # Evaluation only: disabling autograd avoids building (and holding in memory)
    # a graph for every forward pass.
    with torch.no_grad() :
        for batch_idx, (src_data, src_target) in enumerate(synthia_testLoader) :
            print(batch_idx)
            src_seg_target = src_target['segmentation']

            src_data, src_seg_target = src_data.to(device),  src_seg_target.to(device)

            # segmentation : encoder & decoder
            feature, sc = encoder(src_data)
            result = decoder(feature, sc)

            src_seg_loss += criterion(result, src_seg_target).item()

            intersection, union, segmentation_target = intersectionAndUnion(output=result.cpu(), target=src_seg_target.cpu(), K=19, ignore_idx=19)
            intersection_meter.update(intersection), union_meter.update(union), target_meter.update(segmentation_target)

    # The criterion already averages within a batch, so the running sum is a sum
    # of per-batch means and must be divided by the number of batches (not the
    # number of samples). The value is currently not reported.
    src_seg_loss /= max(len(synthia_testLoader), 1)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIOU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    print('Test : mIOU {:.6f} | mACC {:.6f}% | ACC {:.6f}% ' .format( mIOU, mAcc, allAcc))

test_synthia()

In [None]:
criterion = nn.CrossEntropyLoss(ignore_index=19, reduction='mean')

def test_cityscape() :
    """Evaluate the saved encoder/decoder on `cityscape_testLoader` and print IoU/accuracy.

    Loads the `*_fda_once.pth` encoder and decoder checkpoints, runs every batch of
    the test loader through the network without gradient tracking, accumulates
    per-class intersection/union/target counts (19 classes, label 19 ignored) and
    prints mean IoU, mean class accuracy and overall pixel accuracy.
    Returns nothing.
    """
    encoder = EncoderModule()
    decoder = DecoderModule()

    encoder.load_state_dict(torch.load("gdrive/My Drive/Colab Notebooks/uda_encoder_fda_once.pth"))
    decoder.load_state_dict(torch.load("gdrive/My Drive/Colab Notebooks/uda_decoder_fda_once.pth"))

    encoder.eval()
    decoder.eval()

    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    if torch.cuda.is_available() :
        encoder = encoder.cuda()
        decoder = decoder.cuda()

    seg_loss = 0

    # Evaluation only: disabling autograd avoids building (and holding in memory)
    # a graph for every forward pass.
    with torch.no_grad() :
        for batch_idx, (tgt_data, tgt_target) in enumerate(cityscape_testLoader) :
            print(batch_idx)
            tgt_seg_target = tgt_target['segmentation']

            tgt_data, tgt_seg_target = tgt_data.to(device),  tgt_seg_target.to(device)

            # segmentation : encoder & decoder
            feature, sc = encoder(tgt_data)
            result = decoder(feature, sc)

            seg_loss += criterion(result, tgt_seg_target).item()

            intersection, union, segmentation_target = intersectionAndUnion(output=result.cpu(), target=tgt_seg_target.cpu(), K=19, ignore_idx=19)
            intersection_meter.update(intersection), union_meter.update(union), target_meter.update(segmentation_target)

    # The criterion already averages within a batch, so the running sum is a sum
    # of per-batch means and must be divided by the number of batches (not the
    # number of samples). The value is currently not reported.
    seg_loss /= max(len(cityscape_testLoader), 1)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIOU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    print('Test : mIOU {:.6f} | mACC {:.6f}% | ACC {:.6f}% ' .format( mIOU, mAcc, allAcc))

test_cityscape()

In [None]:
def trainOnce(start, max_epoch, save_every=100) :
    """Jointly train segmentation, the domain classifier and the encoder's domain confusion.

    Each iteration performs three updates on one source batch (first image passed
    through `FDA(...)`, a helper defined elsewhere in this notebook) and one
    target batch:

      1. segmentation: SGD step of encoder + decoder on the source labels;
      2. domain classification: Adam step of `domain_cls` on features computed
         without gradient tracking (the encoder is not updated here);
      3. domain confusion: Adam step of the encoder that pushes the classifier's
         output towards the uniform distribution over its classes.

    Args:
        start: first iteration index. When greater than 0 the weights are resumed
            from the checkpoint files written by an earlier run.
        max_epoch: iteration index to stop at (exclusive). Despite the name, one
            "epoch" here is a single round of the three updates above.
        save_every: checkpoints are written whenever the iteration index is a
            multiple of this value, and once more after the last iteration.

    Raises:
        ValueError: if `save_every` is smaller than 1.
    """
    encoder_path = "gdrive/My Drive/Colab Notebooks/uda_encoder_fda_once_2.pth"
    decoder_path = "gdrive/My Drive/Colab Notebooks/uda_decoder_fda_once_2.pth"
    domain_cls_path = "gdrive/My Drive/Colab Notebooks/uda_domain_cls_once_2.pth"
    log_every = 100

    if save_every < 1 :
        raise ValueError("save_every must be >= 1")
    if start >= max_epoch :
        # Nothing to train; avoid touching the checkpoint files.
        return

    srcData_iter = iter(synthia_trainLoader)
    tgtData_iter = iter(cityscape_trainLoader)

    encoder = EncoderModule()
    decoder = DecoderModule()
    domain_cls = DomainClassifier()

    # Resume once, up front. The previous version re-read the checkpoint it had
    # just written on every single iteration, which only cost disk I/O.
    if start > 0 :
        encoder.load_state_dict(torch.load(encoder_path))
        decoder.load_state_dict(torch.load(decoder_path))
        domain_cls.load_state_dict(torch.load(domain_cls_path))

    use_cuda = torch.cuda.is_available()
    if use_cuda :
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        domain_cls = domain_cls.cuda()

    conf_optim = optim.Adam(encoder.parameters(), lr=0.00001)
    seg_optim = optim.SGD(list(encoder.parameters())+list(decoder.parameters()), lr=0.1, weight_decay=0.001, momentum=0.9)
    # Label 19 is excluded from the loss (and from the IoU statistics below).
    criterion_seg = nn.CrossEntropyLoss(ignore_index=19, reduction='mean')
    cls_optim = optim.Adam(domain_cls.parameters(), lr=0.00001)

    encoder.train()
    decoder.train()
    domain_cls.train()

    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    def next_batch(iterator, loader) :
        """Return (batch, iterator), restarting the loader once it is exhausted."""
        try :
            return next(iterator), iterator
        except StopIteration :
            iterator = iter(loader)
            return next(iterator), iterator

    def save_checkpoint() :
        """Write all state dicts as CPU tensors, then put the models back on the GPU."""
        torch.save(encoder.cpu().state_dict(), encoder_path)
        torch.save(decoder.cpu().state_dict(), decoder_path)
        torch.save(domain_cls.cpu().state_dict(), domain_cls_path)
        if use_cuda :
            encoder.cuda()
            decoder.cuda()
            domain_cls.cuda()

    # Running sums of the two domain losses since the last log line.
    conf = 0.
    cls = 0.
    steps_since_log = 0
    last_saved = None

    for epoch in range(start, max_epoch) :
        seg_optim.zero_grad()
        conf_optim.zero_grad()
        cls_optim.zero_grad()

        #get data and label
        (src_data, src_label), srcData_iter = next_batch(srcData_iter, synthia_trainLoader)
        (tgt_data, tgt_label), tgtData_iter = next_batch(tgtData_iter, cityscape_trainLoader)

        src_cls_label, src_seg_label = src_label['classification'], src_label['segmentation']
        tgt_cls_label = tgt_label['classification']

        #FDA
        # Only the first image of each batch goes through FDA; the result is
        # re-batched to size 1. NOTE(review): this assumes the source loader uses
        # batch_size=1 so that the source labels still match - confirm.
        src_data = FDA(src_data[0], tgt_data[0], 0.09)
        src_data = src_data.unsqueeze(0).float()

        #data to GPU
        if use_cuda :
            src_data, src_cls_label, src_seg_label = src_data.cuda(), src_cls_label.cuda(), src_seg_label.cuda()
            tgt_data, tgt_cls_label = tgt_data.cuda(), tgt_cls_label.cuda()

        #train segmentation
        src_feature, shortcut = encoder(src_data)
        result = decoder(src_feature, shortcut)

        loss_seg = criterion_seg(result, src_seg_label)
        # Features are recomputed for the later stages, so this graph is used for
        # exactly one backward pass and does not need to be retained.
        loss_seg.backward()
        seg_optim.step()

        intersection, union, segmentation_target = intersectionAndUnion(output=result.detach().cpu(), target=src_seg_label.cpu(), K=19, ignore_idx=19)
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(segmentation_target)

        #train classification
        # The encoder is not updated in this stage, so its forward pass runs
        # without gradient tracking (equivalent to detaching the features).
        with torch.no_grad() :
            src_feature, _ = encoder(src_data)
            tgt_feature, _ = encoder(tgt_data)

        cls_src_loss = F.cross_entropy(domain_cls(src_feature), src_cls_label)
        cls_tgt_loss = F.cross_entropy(domain_cls(tgt_feature), tgt_cls_label)
        cls_loss = 0.5 * (cls_src_loss + cls_tgt_loss)
        cls_loss.backward()
        cls_optim.step()

        cls += cls_loss.item()

        #train confusion
        # The encoder still carries the segmentation gradients from above; clear
        # them so that this step is driven by the confusion loss only.
        conf_optim.zero_grad()

        # Features must NOT be detached here: the confusion loss has to
        # back-propagate through the domain classifier into the encoder, which is
        # the only thing conf_optim updates. (Gradients that land on domain_cls
        # are discarded by cls_optim.zero_grad() at the top of the next iteration.)
        src_feature, _ = encoder(src_data)
        tgt_feature, _ = encoder(tgt_data)

        # log_softmax is the numerically stable form of log(softmax(x)).
        log_prob_src = F.log_softmax(domain_cls(src_feature), dim=1)
        log_prob_tgt = F.log_softmax(domain_cls(tgt_feature), dim=1)

        # Target is the uniform distribution over the classifier's outputs, i.e. a
        # constant 1/K per class (previously this was random U(0,1) noise, whose
        # mean equals this constant when K == 2).
        uniform_prob = 1.0 / log_prob_src.size(1)

        conf_src_loss = -0.5 * torch.sum(uniform_prob * log_prob_src) / float(log_prob_src.size(0))
        conf_tgt_loss = -0.5 * torch.sum(uniform_prob * log_prob_tgt) / float(log_prob_tgt.size(0))
        conf_loss = 0.5 * (conf_src_loss + conf_tgt_loss)

        conf_loss.backward()
        conf_optim.step()

        conf += conf_loss.item()
        steps_since_log += 1

        #print and reset iou score
        if epoch % log_every == 0 :
            iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
            mIOU = np.mean(iou_class)

            # Average over the steps actually accumulated (the first log line after
            # a start/resume covers fewer than log_every steps).
            print('Train Epoch[{}/{} ({:.0f}%)]\tmIOU : {:.6f}\tConfusion Loss : {:.6f}\tClassification Loss : {:.6f}'.format(
                    epoch*len(src_data), max_epoch, 100.*epoch/max_epoch, mIOU, conf/steps_since_log, cls/steps_since_log))
            conf = 0.
            cls = 0.
            steps_since_log = 0

            intersection_meter.reset()
            union_meter.reset()
            target_meter.reset()

        if epoch % save_every == 0 :
            save_checkpoint()
            last_saved = epoch

    # Make sure the final weights are on disk even if the last step was not a
    # multiple of save_every.
    if last_saved != max_epoch - 1 :
        save_checkpoint()

trainOnce(0, 150000)   

