# Imports 

In [1]:
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
from timm.models.layers import trunc_normal_
from util.pos_embed import interpolate_pos_embed
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import timm
assert timm.__version__ == "0.3.2"  # version check
import timm.optim.optim_factory as optim_factory
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
import model_s_c_g_loss
from engine_pretrain import train_one_epoch
import os
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import prepare_data
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, Block
from util.pos_embed import get_2d_sincos_pos_embed
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from tqdm import tqdm
import matplotlib
from matplotlib import pyplot as plt
import torchvision.transforms as transforms


ModuleNotFoundError: No module named 'torch._six'

# Input to the model


In [None]:
def get_args_parser():
    """Build the argument parser that carries every training setting.

    The parser is created with ``add_help=False``. Each option below has a
    default, so parsing an empty argument list yields a complete namespace.

    Returns:
        argparse.ArgumentParser: parser declaring the options listed below.
    """
    cli = argparse.ArgumentParser('MAE pre-training', add_help=False)
    add = cli.add_argument

    # Runtime and schedule.
    add('--device', default='cuda:9',
        help='device to use for training / testing')
    add('--batch_size', default=32, type=int,
        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
    add('--epochs', default=100, type=int)

    # Model selection and data locations.
    add('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL',
        help='Name of model to train')
    add('--src_img_path',
        default=r'/media/RTCIN7TBDriveB/home/ihg6kor/BDD_data/clear_daytime',
        type=str,
        help='source domain dataset path')
    add('--target_img_path',
        default=r'/media/RTCIN7TBDriveB/home/ihg6kor/BDD_data/clear_night',
        type=str,
        help='target domain dataset path')
    add('--finetune',
        default=r'/media/RTCIN7TBDriveB/home/ihg6kor/Domain_adaptation/mae_adaptation_idea2_0/preload_model/mae_visualize_vit_large_ganloss.pth',
        help='finetune from checkpoint')

    # Input geometry, masking and learning rate.
    add('--input_size', default=224, type=int,
        help='images input size')
    add('--mask_ratio', default=0.01, type=float,
        help='Masking ratio (percentage of removed patches).')
    add('--lr', type=float, default=None, metavar='LR',
        help='learning rate (absolute lr)')

    # Output locations, seeding and resuming.
    add('--output_dir', default='./output_style_contentv2',
        help='path where to save, empty for no saving')
    add('--log_dir', default='./output_dir',
        help='path where to tensorboard log')
    add('--seed', default=0, type=int)
    add('--resume', default='', help='resume from checkpoint')

    # Loss configuration.
    add('--norm_pix_loss', action='store_true',
        help='Use (per-patch) normalized pixels as targets for computing loss')

    return cli
# parse_known_args() returns the parsed namespace plus the list of
# unrecognised argv entries instead of failing on them (presumably so the
# notebook kernel's own command-line arguments are tolerated -- confirm).
a = get_args_parser()
args, unknown = a.parse_known_args()

# Loading the datasets


In [None]:

device = torch.device(args.device)

# Simple augmentation: centre-crop to 720 px, resize to the model input size,
# convert to a tensor and normalise with the mean/std constants below.
transform_train = transforms.Compose(
    [
        transforms.CenterCrop(720),
        transforms.Resize(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

## Datasets: both domains share the same transform pipeline.
dataset_src_domain = datasets.ImageFolder(
    args.src_img_path, transform=transform_train
)
dataset_target_domain = datasets.ImageFolder(
    args.target_img_path, transform=transform_train
)

## Dataloaders: shuffled, and the trailing partial batch is dropped so every
## batch holds exactly args.batch_size images.
data_loader_src = torch.utils.data.DataLoader(
    dataset_src_domain,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
)
data_loader_target = torch.utils.data.DataLoader(
    dataset_target_domain,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Number of source images visited per epoch: batches per epoch times batch
# size (the loader drops the trailing partial batch).
args.batch_size * len(data_loader_src)


In [None]:
# # Test for decoder
# buf = model_GANLOSS.Discriminator()
# out = buf(next(iter(data_loader_src))[0])
# out.shape

# Define the model


In [None]:
# Generator: the masked-autoencoder ViT from model_s_c_g_loss, configured
# with a 24-block / 1024-dim encoder and an 8-block / 512-dim decoder.
gen_model = model_s_c_g_loss.MaskedAutoencoderViT(
    patch_size=16,
    embed_dim=1024,
    depth=24,
    num_heads=16,
    decoder_embed_dim=512,
    decoder_depth=8,
    decoder_num_heads=16,
    mlp_ratio=4,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    norm_pix_loss=args.norm_pix_loss,
)

# Discriminator, built with its default constructor arguments.
disc_model = model_s_c_g_loss.Discriminator()

def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Load checkpoint weights into the module-level ``gen_model``.

    Args:
        chkpt_dir: path of the checkpoint file handed to ``torch.load``.
        arch: accepted for call compatibility; not used by this function.

    Returns:
        The module-level ``gen_model`` object after ``load_state_dict``.
    """
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    state = torch.load(chkpt_dir, map_location='cpu')
    # strict=False: key mismatches are reported in the printed result
    # instead of raising.
    load_report = gen_model.load_state_dict(state['model'], strict=False)
    print(load_report)
    return gen_model
# Initialise the generator from the checkpoint given by --finetune.
chkpt_dir = args.finetune
gen_model = prepare_model(chkpt_dir, arch='mae_vit_large_patch16')
print('Model loaded.')

## Verify the weights inside the model: name, shape and mean of every
## parameter.
for name, param in gen_model.named_parameters():
    print(name, param.shape, param.mean())

# Freeze the generator and train only its decoder layers

In [None]:


# Step 1: switch off gradients for the whole generator.
for param in gen_model.parameters():
    param.requires_grad = False

# for params in gen_model.transform_block.parameters():
#     params.requires_grad  = True
# for params in gen_model.transform_norm.parameters():
#     params.requires_grad  = True

# Step 2: switch gradients back on for the decoder sub-modules only.
trainable_parts = (
    gen_model.decoder_embed,
    gen_model.decoder_blocks,
    gen_model.decoder_norm,
    gen_model.decoder_pred,
)
for part in trainable_parts:
    for param in part.parameters():
        param.requires_grad = True

gen_model.to(device)
disc_model.to(device)

# Report how many generator parameters remain trainable (in millions).
n_parameters = sum(p.numel() for p in gen_model.parameters() if p.requires_grad)
print('number of params (M): %.2f' % (n_parameters / 1.e6))

# Utility functions

In [None]:
# Feature extractor used by the content and style losses, kept in eval mode.
loaded_model = model_s_c_g_loss.VGG().to(device).eval()
# No optimizer in this file updates it, so turn off its weight gradients.
# Gradients still flow through it to the input images, but backward() no
# longer computes and accumulates gradients for its weights.
for vgg_param in loaded_model.parameters():
    vgg_param.requires_grad = False

def calc_content_loss(gen_image,src_image,loaded_model):
    """Return the content loss between a generated and a source image batch.

    Both batches are passed through ``loaded_model``, which must return an
    indexable sequence of feature tensors. Only the last entry of each
    sequence is compared.

    Args:
        gen_image: generated images fed to ``loaded_model``.
        src_image: source images fed to ``loaded_model``.
        loaded_model: callable returning a sequence of feature tensors.

    Returns:
        Scalar tensor: mean squared difference of the two last feature maps.
    """
    gen_last = loaded_model(gen_image)[-1]
    src_last = loaded_model(src_image)[-1]
    squared_diff = (gen_last - src_last) ** 2
    return torch.mean(squared_diff)

def calc_style_loss(gen_img,style_img,loaded_model):
    """Return the style loss between a generated and a style image batch.

    For every feature map returned by ``loaded_model`` the per-sample Gram
    matrix (channel x channel inner products over all spatial positions,
    divided by ``channel * height * width``) is computed for both batches.
    The loss of one layer is the mean squared difference of the two Gram
    matrices; the result is the mean over all layers.

    The per-layer losses are combined with ``torch.stack`` so the result
    stays attached to the autograd graph. Building a new tensor with
    ``torch.tensor(list_of_losses)`` would copy the values and cut the
    gradient path back to ``gen_img``.

    Args:
        gen_img: generated images fed to ``loaded_model``.
        style_img: style reference images fed to ``loaded_model``.
        loaded_model: callable returning a sequence of 4-D feature tensors
            (unpacked here as batch, channel, height, width).

    Returns:
        Scalar tensor, differentiable with respect to ``gen_img``.
    """
    gen_feat = loaded_model(gen_img)
    style_feat = loaded_model(style_img)

    layer_losses = []
    for gen, style in zip(gen_feat, style_feat):
        batch_size, channel, height, width = gen.shape
        norm = channel * height * width

        gen_flat = gen.reshape(batch_size, channel, height * width)
        style_flat = style.reshape(batch_size, channel, height * width)

        # Normalised Gram matrices of the generated and the style features.
        gram_gen = gen_flat.bmm(gen_flat.transpose(1, 2)) / norm
        gram_style = style_flat.bmm(style_flat.transpose(1, 2)) / norm

        layer_losses.append(torch.mean((gram_gen - gram_style) ** 2))

    return torch.stack(layer_losses).mean()
    
def content_style_loss(src_image,gen_image,trg_image):
    """Compute the content and the style loss of a generated batch.

    The content loss compares ``gen_image`` with ``src_image``; the style
    loss compares ``gen_image`` with ``trg_image``. Both use the
    module-level ``loaded_model`` as feature extractor.

    Args:
        src_image: source-domain images (content reference).
        gen_image: images produced by the generator.
        trg_image: target-domain images (style reference).

    Returns:
        Tuple ``(content_loss, style_loss)``.
    """
    ## transform the input images
    # src_image = torch,intransforms.Resize(size= (224,512))(src_image)
    # gen_image = transforms.Resize(size= (512,512))(gen_image)
    # trg_image = transforms.Resize(size= (512,512))(trg_image)
    content = calc_content_loss(gen_image, src_image, loaded_model)
    style = calc_style_loss(gen_image, trg_image, loaded_model)
    return content, style


In [None]:
##Testing the content and style loss

# src_image = next(iter(data_loader_src))[0].to(device)
# gen_image = next(iter(data_loader_src))[0].to(device)
# trg_image = next(iter(data_loader_target))[0].to(device)
# val = content_style_loss(src_image, gen_image, trg_image)
# print("total",val)

In [None]:
def show_mean_model(model):
    """Return the mean of the per-parameter means of the trainable weights.

    Only parameters with ``requires_grad`` set are considered. The value is
    rebuilt with ``torch.tensor`` and therefore carries no gradient. When no
    parameter is trainable the mean is taken over an empty tensor and the
    result is NaN.

    Args:
        model: object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        Scalar tensor holding the mean of the per-parameter means.
    """
    trainable_means = [
        torch.mean(weight)
        for weight in model.parameters()
        if weight.requires_grad
    ]
    return torch.tensor(trainable_means).mean()

def save_on_master(*args, **kwargs):
    """Forward all positional and keyword arguments to ``torch.save``.

    Note that the ``args`` parameter shadows the module-level ``args``
    namespace inside this function.
    """
    torch.save(*args, **kwargs)
        
def save_model(d_model,g_model,epoch,d_optim,g_optim):
    """Write one checkpoint each for the discriminator and the generator.

    The files ``disc_checkpoint-<epoch>.pth`` and
    ``gen_checkpoint-<epoch>.pth`` are written to ``args.output_dir``
    (module-level ``args``). Each holds the model state dict, the optimizer
    state dict, the epoch and the ``args`` namespace.

    The output directory is created together with any missing parents, and
    an already existing directory is not an error.

    Args:
        d_model: discriminator whose ``state_dict()`` is saved.
        g_model: generator whose ``state_dict()`` is saved.
        epoch: epoch label stored in the checkpoint and used in the file name.
        d_optim: optimizer of the discriminator.
        g_optim: optimizer of the generator.

    Returns:
        None.
    """
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    epoch_name = str(epoch)

    networks = (
        ('disc', d_model, d_optim),
        ('gen', g_model, g_optim),
    )
    for prefix, model, optim in networks:
        to_save = {
            'model': model.state_dict(),
            'optimizer': optim.state_dict(),
            'epoch': epoch,
            'args': args,
        }
        torch.save(to_save, output_dir / ('%s_checkpoint-%s.pth' % (prefix, epoch_name)))

    print("saved models for epoch = " , epoch)
    return


def show_images(img,label,epoch,args,iter_step):
    """Save the first image of a batch as a PNG and return its directory.

    The image is written to
    ``<args.output_dir>/images-epoch<epoch>/<label>_iter_<iter_step>.png``.
    The directory, including missing parents, is created on demand.

    Args:
        img: image batch tensor. The per-channel mean/std below are
            broadcast over the last axis, so channels must be last.
        label: prefix of the file name (converted with ``str``).
        epoch: epoch number used in the directory name.
        args: namespace providing ``output_dir``.
        iter_step: iteration number used in the file name.

    Returns:
        pathlib.Path: directory the image was written to.
    """
    output_dir = Path(args.output_dir) / ("images-epoch%s" % str(epoch))
    output_dir.mkdir(parents=True, exist_ok=True)

    first_image = img.detach().cpu()[0]
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    # Undo the input normalisation and map to integer pixel values in 0..255.
    pixels = torch.clip((first_image * imagenet_std + imagenet_mean) * 255, 0, 255).int()

    fig = plt.figure(figsize = (10,2))
    try:
        plt.imshow(pixels)
        loc = output_dir / (str(label) + "_iter_" + str(iter_step) + '.png')
        fig.savefig(str(loc))
    finally:
        # Close exactly this figure, even when saving fails, so figures do
        # not pile up across iterations.
        plt.close(fig)
    return output_dir


# Training Loop 


In [None]:
## Loss and optimizers
# Binary cross-entropy for the real/fake output of the discriminator.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with a learning rate of 1e-3.
d_optimizer = torch.optim.Adam(params=disc_model.parameters(), lr=1e-3)
g_optimizer = torch.optim.Adam(params=gen_model.parameters(), lr=1e-3)


In [None]:
# d_loss,d_fake,d_real = discriminator_train_step(args.batch_size,disc_model,gen_model,d_optimizer,criterion,target[:int(args.batch_size/2)],sample[:int(args.batch_size/2)])
# 
def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion, real_images,src_images):
    """Run one optimisation step of the discriminator.

    The discriminator is trained to score ``real_images`` as 1 and the
    generator output for ``src_images`` as 0. All generator parameters are
    frozen for the duration of the step (and left frozen afterwards).

    Args:
        batch_size: unused; kept for call compatibility.
        discriminator: network mapping an image batch to scores.
        generator: network called as
            ``generator(src_images, real_images, mask_ratio=...)`` and
            exposing ``unpatchify``.
        d_optimizer: optimizer holding the discriminator parameters.
        criterion: loss taking ``(scores, targets)``.
        real_images: target-domain images presented as "real".
        src_images: source-domain images fed to the generator.

    Returns:
        Tuple ``(d_loss, disc_fake_loss, disc_real_loss)`` where ``d_loss``
        is the average of the other two.
    """
    for params in discriminator.parameters():
        params.requires_grad = True
    for params in generator.parameters():
        params.requires_grad = False

    # Sanity check: the optimizer stepped here must not own any generator
    # parameter. (Comparing means of the frozen generator's trainable
    # parameters cannot detect a change: with nothing trainable that mean is
    # NaN, and NaN never compares equal to anything.)
    generator_param_ids = {id(p) for p in generator.parameters()}
    assert not any(
        id(p) in generator_param_ids
        for group in d_optimizer.param_groups
        for p in group['params']
    ), "generator parameters would be changed by the discriminator step"

    d_optimizer.zero_grad()

    # Real images should be scored as 1. Targets are created on the same
    # device and with the same dtype as the scores.
    real_scores = discriminator(real_images).view(-1)
    disc_real_loss = criterion(real_scores, torch.ones_like(real_scores))

    # The generator only supplies inputs here; no graph is needed for it.
    with torch.no_grad():
        pred_img_latent, _ = generator(src_images, real_images, mask_ratio=args.mask_ratio)
        # Bring the prediction into image layout for the discriminator.
        pred = generator.unpatchify(pred_img_latent)

    # Generated images should be scored as 0.
    fake_scores = discriminator(pred.detach()).view(-1)
    disc_fake_loss = criterion(fake_scores, torch.zeros_like(fake_scores))

    d_loss = (disc_real_loss + disc_fake_loss)/2
    d_loss.backward()
    d_optimizer.step()

    return d_loss,disc_fake_loss,disc_real_loss

In [None]:
def generator_train_step(loss_params,src_images,target_images, discriminator, generator, g_optimizer, criterion):
    """Run one optimisation step of the generator's decoder.

    The discriminator is frozen and the decoder sub-modules of the
    ``generator`` argument are made trainable. The loss is the weighted sum
    of the adversarial loss (generated images should be scored as 1), the
    content loss and the style loss.

    Args:
        loss_params: sequence of three weights for the adversarial, content
            and style terms, in that order.
        src_images: source-domain images (content reference).
        target_images: target-domain images (style reference).
        discriminator: network mapping an image batch to scores.
        generator: network called as
            ``generator(src_images, target_images, mask_ratio=...)`` and
            exposing ``unpatchify`` and the ``decoder_*`` sub-modules.
        g_optimizer: optimizer holding the generator parameters.
        criterion: loss taking ``(scores, targets)``.

    Returns:
        Tuple ``(final_loss, pred, content_loss, style_loss, g_loss)`` where
        ``pred`` is the unpatchified generator output.
    """
    g_optimizer.zero_grad()

    for params in discriminator.parameters():
        params.requires_grad = False
    # Unfreeze the decoder of the generator that was passed in, not of a
    # module-level model.
    decoder_parts = (
        generator.decoder_embed,
        generator.decoder_blocks,
        generator.decoder_norm,
        generator.decoder_pred,
    )
    for part in decoder_parts:
        for params in part.parameters():
            params.requires_grad = True

    # Sanity check: the optimizer stepped here must not own any
    # discriminator parameter. (A mean over the frozen discriminator's
    # trainable parameters is NaN and cannot detect a change.)
    discriminator_param_ids = {id(p) for p in discriminator.parameters()}
    assert not any(
        id(p) in discriminator_param_ids
        for group in g_optimizer.param_groups
        for p in group['params']
    ), "discriminator parameters would be changed by the generator step"

    pred_latent, _ = generator(src_images, target_images, mask_ratio=args.mask_ratio)
    pred = generator.unpatchify(pred_latent)

    # Adversarial term: the generator wants its output to be scored as 1.
    fake_scores = discriminator(pred).view(-1)
    g_loss = criterion(fake_scores, torch.ones_like(fake_scores))
    content_loss, style_loss = content_style_loss(src_images, pred, target_images)

    final_loss = loss_params[0]* g_loss + loss_params[1]*content_loss + loss_params[2] * style_loss
    final_loss.backward()
    g_optimizer.step()

    return final_loss,pred,content_loss,style_loss,g_loss


In [None]:
## Start the training process

# Weights of the adversarial, content and style terms of the generator loss.
alpha = 1
beta = 1
gamma = 100000

print(f"Start training for {args.epochs} epochs")
final_gloss = []  # total generator loss of every iteration
final_dloss = []  # discriminator loss of every iteration

# Curves of the live loss plot, sampled every 20 iterations. They start
# empty so that only measured values are drawn.
x = []
d_fake_list = []
d_real_list = []
final_loss_list = []
g_loss_list = []
c_loss_list = []
s_loss_list = []

plt.ion()
# A dedicated figure keeps the loss plot independent of the figures that
# show_images() opens and closes.
loss_fig, loss_ax = plt.subplots()

half_batch = args.batch_size // 2
# zip() below ends an epoch with the shorter loader, so a target set smaller
# than the source set no longer raises StopIteration mid-epoch.
steps_per_epoch = min(len(data_loader_src), len(data_loader_target))

for epoch in range(0, args.epochs):
    gen_model.train(True)
    paired_batches = zip(data_loader_src, data_loader_target)
    for iter_step, (samples, targets) in enumerate(tqdm(paired_batches, total=steps_per_epoch)):
        sample = samples[0].to(device)
        target = targets[0].to(device)

        # The second half of each batch trains the generator, the first half
        # the discriminator.
        final_loss, pred_img_unpatch, c_loss, s_loss, g_loss = generator_train_step(
            [alpha, beta, gamma],
            sample[half_batch:],
            target[half_batch:],
            disc_model,
            gen_model,
            g_optimizer,
            criterion,
        )
        d_loss, d_fake, d_real = discriminator_train_step(
            args.batch_size,
            disc_model,
            gen_model,
            d_optimizer,
            criterion,
            target[:half_batch],
            sample[:half_batch],
        )

        final_gloss.append(final_loss.item())
        final_dloss.append(d_loss.item())

        if iter_step % 20 != 0:
            continue

        x.append(epoch * steps_per_epoch + iter_step)
        d_fake_list.append(d_fake.item())
        d_real_list.append(d_real.item())
        final_loss_list.append(final_loss.item())
        g_loss_list.append(g_loss.item() * alpha)
        c_loss_list.append(c_loss.item() * beta)
        s_loss_list.append(s_loss.item() * gamma)

        # show_images() needs channel-last batches.
        sample_print = torch.einsum('nchw->nhwc', sample[half_batch:])
        target_print = torch.einsum('nchw->nhwc', target[half_batch:])
        pred_print = torch.einsum('nchw->nhwc', pred_img_unpatch)
        show_images(sample_print, "source", epoch, args, iter_step)
        show_images(target_print, "target", epoch, args, iter_step)
        out_loc = show_images(pred_print, "generated", epoch, args, iter_step)
        loc = Path(out_loc) / ("losses_iter_" + str(iter_step) + '.png')

        ## Plotting the graphs: clear the axes first so the six curves and
        ## the legend are redrawn instead of stacked on older ones.
        loss_ax.cla()
        loss_ax.plot(x, d_fake_list, color='g', label="disc_fake")
        loss_ax.plot(x, d_real_list, color='r', label="disc_real")
        loss_ax.plot(x, final_loss_list, color='k', label="final_loss")
        loss_ax.plot(x, g_loss_list, color='b', label="gen_loss")
        loss_ax.plot(x, c_loss_list, color='c', label="content")
        loss_ax.plot(x, s_loss_list, color='m', label="style")
        loss_ax.set_ylim(0, 5)
        if len(x) > 1:
            # Identical lower and upper limits would be rejected with a
            # warning, so the range is only set once two points exist.
            loss_ax.set_xlim(x[0], x[-1])
        loss_ax.legend()
        loss_fig.savefig(str(loc))
        # Pause for 0.25 seconds so the interactive window can refresh.
        plt.pause(0.25)
            

In [None]:
# Dump the generator and the discriminator loss history, in that order.
for history in (final_gloss, final_dloss):
    print(history)

In [None]:
# Save both networks and their optimizers, labelled with the number of
# epochs trained (100 with the default --epochs).
save_model(disc_model, gen_model, args.epochs, d_optimizer, g_optimizer)
