In [10]:
import sys
sys.path.append('../../../')

import torch
from torch import nn
from train import val        
import warnings
import numpy as np
import torch.nn.functional as F
warnings.filterwarnings(action='ignore')

from static_definer import *

In [2]:
# One sample batch per domain, used below only to read the spatial input sizes.
# train() zips (train_dataloaderGTA5, train_dataloader) and computes the
# segmentation loss on labels from the first loader, so GTA5 is the labelled
# source domain and train_dataloader is the unlabelled target domain.
# Keep the same pairing here so the per-domain sizes line up with train().
source_sample_image, source_sample_label = next(iter(train_dataloaderGTA5))
target_sample_image, _ = next(iter(train_dataloader))

In [3]:
from models.domain_shift.adversarial.functions import DomainDiscriminator
from models.bisenet.build_bisenet import BiSeNet

In [4]:

# Loss functions:
# - CrossEntropyLoss for segmentation; label 19 is ignored
#   (presumably the void/unlabelled class — confirm against the dataset definition).
# - BCEWithLogitsLoss for source-vs-target domain classification.
generator_loss = nn.CrossEntropyLoss(ignore_index=19)
discriminator_loss = nn.BCEWithLogitsLoss()

# Models. Both are moved to `device` so their parameters live where train()
# and val() send the input batches (train() calls `.to(device)` on every input).
generator = BiSeNet(num_classes=num_classes, context_path='resnet18', with_interpolation=True).to(device)
discriminator = DomainDiscriminator(num_classes=num_classes, with_grl=False).to(device)

# Optimizers: SGD with momentum for the segmentation network, Adam for the discriminator.
generator_optimizer = torch.optim.SGD(generator.parameters(), lr=1e-2, momentum=0.9, weight_decay=5e-4)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-3)

In [5]:
# Spatial size of one batch from each domain.
# Assumes batches are (N, C, H, W) — TODO confirm against the datasets — so
# `size()[2:]` is (H, W). That is exactly the order nn.Upsample's `size`
# expects, so it is passed through unchanged rather than reversed.
source_input_size = source_sample_image.size()[2:]
target_input_size = target_sample_image.size()[2:]
# Bilinear resizers that bring the discriminator's output back to input resolution.
# align_corners=False is the bilinear default; it is stated explicitly to avoid the warning.
source_interp = nn.Upsample(size=tuple(source_input_size), mode='bilinear', align_corners=False)
target_interp = nn.Upsample(size=tuple(target_input_size), mode='bilinear', align_corners=False)

In [7]:
def generator_loss_calculator(criteria, model_output, segment_label):
    '''
    Compute the segmentation loss of the generator.

    If the generator returned a tuple, it is unpacked as
    ``(main_output, aux_output_1, aux_output_2)``. The loss is then the sum of
    ``criteria`` evaluated on each of the three against the same label, in that
    order. Any other output is scored with a single ``criteria`` call.

    :param criteria: callable ``criteria(prediction, target)`` returning a loss
    :param model_output: a single prediction, or a 3-tuple (main, aux1, aux2)
    :param segment_label: the ground-truth segmentation label

    :return: the (summed) loss
    :raises ValueError: if ``model_output`` is a tuple whose length is not 3
    '''
    if not isinstance(model_output, tuple):
        return criteria(model_output, segment_label)

    main_pred, aux_pred_1, aux_pred_2 = model_output
    total = criteria(main_pred, segment_label)
    # The auxiliary heads are supervised with the same label as the main head.
    for aux_pred in (aux_pred_1, aux_pred_2):
        total += criteria(aux_pred, segment_label)
    return total

In [9]:
def _main_output(model_output):
    '''
    Return the primary segmentation prediction from a generator forward pass.

    In training mode the generator may return a tuple whose first element is
    the main output (the rest are auxiliary outputs consumed by
    ``generator_loss_calculator``). Non-tuple outputs are returned unchanged.
    '''
    return model_output[0] if isinstance(model_output, tuple) else model_output


def _discriminate(prediction, output_size):
    '''
    Run the domain discriminator on a segmentation prediction.

    The raw logits are turned into class probabilities with a softmax over
    dim 1. This assumes an (N, C, H, W) tensor with classes on dim 1 — TODO
    confirm against BiSeNet. The discriminator's logits are then bilinearly
    resized to ``output_size``.

    :param prediction: raw generator logits
    :param output_size: (H, W) size to resize the discriminator logits to
    :return: resized discriminator logits
    '''
    domain_logits = discriminator(F.softmax(prediction, dim=1))
    return F.interpolate(domain_logits, size=output_size, mode='bilinear', align_corners=False)


def _set_requires_grad(module, flag):
    '''Enable or disable gradient tracking for every parameter of ``module``.'''
    for param in module.parameters():
        param.requires_grad_(flag)


def train(epoch: int, lambda_=0.1):
    '''
    Run one epoch of adversarial domain-adaptation training.

    Each iteration does two steps:

    1. Generator step, with the discriminator's parameters frozen.
       Minimise the segmentation loss on the labelled source batch, plus
       ``lambda_`` times the adversarial loss. The adversarial loss asks the
       discriminator to label target predictions as source (label 1).
    2. Discriminator step, on detached predictions.
       Classify source predictions as 1 and target predictions as 0.

    The discriminator is built with ``with_grl=False``, so the two objectives
    must be kept apart this way. A single shared backward would push the
    generator to help the discriminator, and train the discriminator on the
    flipped "fool" labels.

    Source batches come from ``train_dataloaderGTA5`` and target batches from
    ``train_dataloader``. The epoch stops at the shorter of the two loaders
    (``zip`` semantics).

    :param epoch: epoch number, used only for the progress-bar label
    :param lambda_: weight of the adversarial term in the generator objective
    '''
    # tqdm.auto picks the notebook widget or the console bar automatically.
    from tqdm.auto import tqdm

    generator.train()
    discriminator.train()

    n_iters = min(len(train_dataloaderGTA5), len(train_dataloader))
    batches = zip(train_dataloaderGTA5, train_dataloader)
    for i, (source_data, target_data) in tqdm(enumerate(batches), total=n_iters, desc=f'Epoch {epoch}'):
        source_image, source_seg_label = source_data
        target_image, _ = target_data

        source_image = source_image.to(device)
        source_seg_label = source_seg_label.to(device)
        target_image = target_image.to(device)
        # Discriminator logits are resized to each batch's own image resolution.
        source_size = source_image.shape[-2:]
        target_size = target_image.shape[-2:]

        generator_optimizer.zero_grad()
        discriminator_optimizer.zero_grad()

        # ! Step 1: update the generator.
        # Freeze D so the adversarial gradient only reaches the generator.
        _set_requires_grad(discriminator, False)

        # * Only the source domain is labelled, so the segmentation loss uses source batches only.
        source_output = generator(source_image)
        gen_source_loss = generator_loss_calculator(generator_loss, source_output, source_seg_label)
        source_pred = _main_output(source_output)
        target_pred = _main_output(generator(target_image))

        # * Fool the discriminator: target predictions should be classified as source (1).
        target_domain_out = _discriminate(target_pred, target_size)
        adver_loss = discriminator_loss(target_domain_out, torch.ones_like(target_domain_out))

        total_loss = gen_source_loss + lambda_ * adver_loss
        total_loss.backward()

        # ! Step 2: update the discriminator on detached predictions.
        # Source = 1, target = 0. Detaching keeps this loss out of the generator's gradients.
        _set_requires_grad(discriminator, True)
        source_domain_out = _discriminate(source_pred.detach(), source_size)
        target_domain_out = _discriminate(target_pred.detach(), target_size)
        disc_loss = (discriminator_loss(source_domain_out, torch.ones_like(source_domain_out))
                     + discriminator_loss(target_domain_out, torch.zeros_like(target_domain_out)))
        disc_loss.backward()

        generator_optimizer.step()
        discriminator_optimizer.step()

        if i % 100 == 0 and i != 0:
            print(f'Iteration {i}, Generator Loss: {gen_source_loss.item()}, Discriminator Loss: {disc_loss.item()} , Adversarial Loss: {adver_loss.item()} , Total Loss: {total_loss.item()}')

        

In [None]:
# Training schedule. Epochs are 1-indexed. Validation runs every VAL_EVERY
# epochs and also after the last epoch, so the saved checkpoint always has a
# reported score.
NUM_EPOCHS = 9       # matches the original range(1, 10)
VAL_EVERY = 5
LAMBDA_ADV = 0.1     # weight of the adversarial term in the generator objective

for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch, lambda_=LAMBDA_ADV)

    if epoch % VAL_EVERY == 0 or epoch == NUM_EPOCHS:
        print('-'*50)
        val(epoch, generator, generator_loss, num_classes, device, val_dataloader, 'GTA5')
        print('-'*50)

# Name checkpoints from the configured epoch count rather than the leaked loop
# variable, which would be undefined if NUM_EPOCHS were 0.
torch.save(generator.state_dict(), f'generator_{NUM_EPOCHS}.pth')
torch.save(discriminator.state_dict(), f'discriminator_{NUM_EPOCHS}.pth')
    