In [None]:
# %load lowlight_train.py
# Adapted from the original script so that it can train models on 3D images.

import torch
import torch.nn as nn
import torchvision
import torch.backends.cudnn as cudnn
import torch.optim
import os
import sys
import argparse
import time
# import dataloader
# import model
# import Myloss
import dataloader_3D
import model_3D
import Myloss_3D
import numpy as np
from torchvision import transforms

# Use float64 as the default floating-point dtype for tensors and module
# parameters created after this call.
torch.set_default_dtype(torch.float64)


def weights_init(m):
    """Initialise convolution and batch-norm parameters of ``m`` in place.

    Written to be passed to ``Module.apply``, which calls it once for every
    sub-module. Modules are selected by class name:

    * name contains ``'Conv'``: ``weight`` is drawn from N(0.0, 0.02);
      ``bias`` is left untouched.
    * name contains ``'BatchNorm'``: ``weight`` is drawn from N(1.0, 0.02)
      and ``bias`` is filled with 0.

    Any other module is left unchanged.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    # A module can match by name yet carry no such parameter (for example a
    # batch-norm layer built with affine=False, or a container class whose
    # name merely contains "Conv"). Skip those instead of raising.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)


def train(config):
    """Train the 3D enhancement network and periodically save its weights.

    The network is built from ``model_3D.enhance_net_nopool``, moved to the
    GPU, initialised with ``weights_init`` and optimised with Adam on the sum
    of three losses from ``Myloss_3D`` (``L_TV``, ``L_spa``, ``L_exp``).

    Args:
        config: Object exposing the attributes read here:
            ``lowlight_images_path``, ``train_batch_size``, ``num_workers``,
            ``lr``, ``weight_decay``, ``num_epochs``, ``grad_clip_norm``,
            ``display_iter``, ``snapshot_iter`` and ``snapshots_folder``.

    Side effects:
        Sets ``CUDA_VISIBLE_DEVICES`` to ``'0'``, prints the losses every
        ``display_iter`` iterations, and writes
        ``<snapshots_folder>/Epoch<epoch>.pth`` every ``snapshot_iter``
        iterations (later saves within an epoch overwrite earlier ones).
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    DCE_net = model_3D.enhance_net_nopool().cuda()
    DCE_net.apply(weights_init)
    # NOTE: config.load_pretrain / config.pretrain_dir are not consulted
    # here; loading of pretrained weights is currently disabled.

    train_dataset = dataloader_3D.lowlight_loader(config.lowlight_images_path)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
    )

    # NOTE: no colour loss is used here; only the three losses below.
    L_spa = Myloss_3D.L_spa()
    L_exp = Myloss_3D.L_exp(16, 0.6)
    L_TV = Myloss_3D.L_TV()

    optimizer = torch.optim.Adam(
        DCE_net.parameters(), lr=config.lr, weight_decay=config.weight_decay
    )

    DCE_net.train()

    for epoch in range(config.num_epochs):
        # Join with os.path so the folder works with or without a trailing
        # path separator.
        snapshot_path = os.path.join(
            config.snapshots_folder, "Epoch" + str(epoch) + '.pth'
        )

        for iteration, img_lowlight in enumerate(train_loader):
            img_lowlight = img_lowlight.cuda()

            # The network returns three values; the second is used as the
            # enhanced image and the third (A) feeds the TV loss.
            enhanced_image_1, enhanced_image, A = DCE_net(img_lowlight)

            Loss_TV = L_TV(A)
            loss_spa = torch.mean(L_spa(enhanced_image, img_lowlight))
            loss_exp = torch.mean(L_exp(enhanced_image))

            loss = Loss_TV + loss_spa + loss_exp

            optimizer.zero_grad()
            loss.backward()
            # clip_grad_norm_ is the in-place replacement for the deprecated
            # clip_grad_norm.
            torch.nn.utils.clip_grad_norm_(
                DCE_net.parameters(), config.grad_clip_norm
            )
            optimizer.step()

            if ((iteration + 1) % config.display_iter) == 0:
                # .item() prints plain numbers rather than tensor reprs.
                print("Loss at iteration", iteration + 1, ":", loss.item(),
                      '  Loss_TV:{} + loss_spa:{} + loss_exp:{}'.format(
                          Loss_TV.item(), loss_spa.item(), loss_exp.item()))
            if ((iteration + 1) % config.snapshot_iter) == 0:
                torch.save(DCE_net.state_dict(), snapshot_path)




def str2bool(value):
    """Convert a command-line string to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True; this converter parses the text instead.

    Args:
        value: A bool, or a string such as ``"true"``, ``"False"``, ``"1"``,
            ``"no"`` (case-insensitive). The empty string maps to False.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value)
    )


def build_parser():
    """Return the argument parser holding every training option."""
    parser = argparse.ArgumentParser()

    # Input Parameters
    parser.add_argument('--lowlight_images_path', type=str, default="D:/zero_dce_data/img/")
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--grad_clip_norm', type=float, default=0.1)
    parser.add_argument('--num_epochs', type=int, default=2)          # 200
    parser.add_argument('--train_batch_size', type=int, default=1)    # 8
    parser.add_argument('--val_batch_size', type=int, default=4)      # 4
    parser.add_argument('--num_workers', type=int, default=0)         # 4
    parser.add_argument('--display_iter', type=int, default=10)       # 10
    parser.add_argument('--snapshot_iter', type=int, default=10)      # 10
    parser.add_argument('--snapshots_folder', type=str, default="snapshots/")
    parser.add_argument('--load_pretrain', type=str2bool, default=False)
    parser.add_argument('--pretrain_dir', type=str, default="snapshots/Epoch99.pth")
    # NOTE(review): the purpose of --h is not evident from this file; it is
    # kept so existing invocations continue to parse.
    parser.add_argument('--h', type=str, default="snapshots/Epoch99.pth")
    parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")
    return parser


if __name__ == "__main__":

    config = build_parser().parse_args()

    # makedirs creates nested folders and does not fail if the folder
    # already exists.
    os.makedirs(config.snapshots_folder, exist_ok=True)

    train(config)




  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


Total training examples: 1433




Loss at iteration 10 : 0.3599997227262327 Loss_TV0.00014452328546255943 + loss_spa7.862492690879738e-09 + loss_exp0.35985519157827744
Loss at iteration 20 : 0.3600454767350298 Loss_TV6.0754592283884015e-05 + loss_spa3.9884973546670875e-11 + loss_exp0.3599847221028609
Loss at iteration 30 : 0.360018249799478 Loss_TV3.124671116860944e-05 + loss_spa1.3423898517359974e-10 + loss_exp0.3599870029540704
Loss at iteration 40 : 0.3599378863103512 Loss_TV2.422330630562198e-05 + loss_spa3.097690355686184e-09 + loss_exp0.3599136599063552
Loss at iteration 50 : 0.36000639128602935 Loss_TV1.42973735520102e-05 + loss_spa9.828009465926428e-11 + loss_exp0.3599920938141972
Loss at iteration 60 : 0.35983823427480843 Loss_TV1.0399228209113683e-05 + loss_spa7.198851338277591e-08 + loss_exp0.35982776305808595
Loss at iteration 70 : 0.35997833339151053 Loss_TV8.279820198167105e-06 + loss_spa2.8806326285170336e-09 + loss_exp0.3599700506906797
Loss at iteration 80 : 0.35951820840616827 Loss_TV7.963829623893794