In [4]:
import os
import cv2
import imageio
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torch
from random import randint
# from utils.loss_utils import l1_loss, ssim
# from gaussian_renderer import render, network_gui
import sys
# from scene import Scene, GaussianModel
# from utils.general_utils import safe_state
import uuid
from tqdm import tqdm
# from utils.image_utils import psnr
from argparse import ArgumentParser, Namespace
# from arguments import ModelParams, PipelineParams, OptimizationParams

from Gaussian2D import Gaussian2DModel

try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_FOUND = True
except ImportError:
    TENSORBOARD_FOUND = False

import time



In [6]:
# Target images for the 2D Gaussian fitting experiments, all stored under ./Pics.
# The resolution of each source image is noted next to its name.
imgPathList = [
    f"./Pics/{pokemonName}.png"
    for pokemonName in (
        "Pikachu",      # 600x600
        "Charmander",   # 600x600
        "Squirtle",     # 600x600
        "Bulbasaur",    # 600x600
        "Lapras",       # 600x600
        "Meowth",       # 600x600
        "Mewtwo",       # 600x600
        "Rayquaza",     # 1000x1000
        "Groudon",      # 1000x1000
    )
]


In [29]:
def createGIF(dirPath, duration=0.25):
    """Stitch the ``.png`` snapshots in ``dirPath`` into an animated GIF.

    Frames are sorted by filename before stitching. ``os.listdir`` returns
    entries in arbitrary order. ``train`` names its snapshots
    ``Epoch_{epoch:04}_...png`` (zero-padded), so lexicographic order matches
    epoch order. The GIF is written into ``dirPath`` as
    ``output_{frameCount}_frames.gif``.

    Args:
        dirPath: Directory containing the PNG frames.
        duration: Per-frame display time passed to ``imageio.mimsave``. The
            default of 0.25 was intended as seconds. NOTE(review): newer
            imageio releases may interpret GIF ``duration`` in milliseconds;
            verify against the installed version.

    Raises:
        ValueError: if ``dirPath`` contains no ``.png`` files.
        OSError: if a frame cannot be decoded by OpenCV.
    """
    imageNameList = sorted(f for f in os.listdir(dirPath) if f.endswith('.png'))
    if not imageNameList:
        raise ValueError(f'No .png frames found in: {dirPath}')

    frames = []
    for imageName in imageNameList:
        imgPath = os.path.join(dirPath, imageName)
        imgBGR = cv2.imread(imgPath, cv2.IMREAD_COLOR)
        if imgBGR is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f'Could not read frame: {imgPath}')
        # OpenCV decodes as BGR; convert so the GIF has correct colours.
        frames.append(cv2.cvtColor(imgBGR, cv2.COLOR_BGR2RGB))

    gifPath = os.path.join(dirPath, f'output_{len(frames)}_frames.gif')
    imageio.mimsave(gifPath, frames, duration=duration)
    print(f'GIF saved at: {gifPath}')


In [25]:
def train(imgPath, Epoch, numGaussian=800, saveRenderInterval=20, densifyInterval=20):
    """Fit a ``Gaussian2DModel`` to one image and save periodic render snapshots.

    The model takes one optimisation step per epoch. At epoch 0 and every
    ``saveRenderInterval`` epochs after that, the current render is saved as
    ``Epoch_{epoch:04}_Loss_{loss:08.5f}.png``. Snapshots go into the run directory
    ``./output/{startTime}_{Epoch}_Epoch_{numGaussian}_Gaussians_{imageStem}``,
    which is created (together with ``./output``) if missing.

    Args:
        imgPath: Path of the target image. It is read with OpenCV, so channels
            are in BGR order.
        Epoch: Number of optimisation steps, i.e. calls to
            ``gaussian.lossAndBackwardAndStep()``.
        numGaussian: Number of Gaussians passed to ``Gaussian2DModel``.
        saveRenderInterval: A snapshot is saved every this many epochs. Must be > 0.
        densifyInterval: Currently unused; kept for interface compatibility.

    Returns:
        The run directory path. It is created even when ``Epoch`` is 0.

    Raises:
        ValueError: if ``saveRenderInterval`` is not positive.
        FileNotFoundError: if ``imgPath`` cannot be read as an image.
        OSError: if a snapshot cannot be written.
    """
    if saveRenderInterval <= 0:
        raise ValueError(f"saveRenderInterval must be positive, got {saveRenderInterval}")

    startTimeInt = int(time.time())

    imgBGR = cv2.imread(imgPath, flags=cv2.IMREAD_COLOR)
    if imgBGR is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {imgPath}")
    imgNormalizedBGR = imgBGR / 255.0
    print(imgNormalizedBGR.shape)

    # The model is fitted to the BGR image. Its render is therefore written back
    # with cv2.imwrite (which expects BGR) without any channel conversion.
    gaussian = Gaussian2DModel(numGaussian, imgNormalizedBGR)

    # Loop-invariant: one output directory per run. It is created up front so
    # the return value is valid even when no snapshot is taken.
    avatarName = os.path.splitext(os.path.basename(imgPath))[0]
    fileDirPath = os.path.join(
        "./output", f"{startTimeInt}_{Epoch}_Epoch_{numGaussian}_Gaussians_{avatarName}"
    )
    os.makedirs(fileDirPath, exist_ok=True)

    for epoch in range(Epoch):
        loss = gaussian.lossAndBackwardAndStep()
        if epoch % saveRenderInterval != 0:
            continue

        print(f'Epoch {epoch}, Loss: {loss}')
        # NOTE(review): assumes renderImgTorch is an HxWx3 float tensor with
        # values nominally in [0, 1]. Confirm this in Gaussian2DModel.
        imgRenderBGR = gaussian.renderImgTorch.detach().cpu().numpy()
        # Clip before casting so out-of-range values saturate instead of
        # wrapping. PNG writing needs an 8-bit (or 16-bit) image.
        imgRenderBGR255 = np.clip(imgRenderBGR * 255, 0, 255).astype(np.uint8)

        fileName = f"Epoch_{epoch:04}_Loss_{loss:08.5f}.png"
        filePath = os.path.join(fileDirPath, fileName)
        print("absolutePath : ", os.path.abspath(filePath))

        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(filePath, imgRenderBGR255):
            raise OSError(f"Failed to write snapshot: {filePath}")

    return fileDirPath


In [30]:
# Fit 100 Gaussians to the first image (Pikachu) for 300 epochs, saving a
# snapshot every 20 epochs, then stitch the snapshots into a GIF.
# To rebuild a GIF from an earlier run, skip training and point fileDirPath at
# that run's directory under ./output instead.
fileDirPath = train(
    imgPath=imgPathList[0],
    Epoch=300,
    numGaussian=100,
    saveRenderInterval=20,
)
createGIF(fileDirPath)


GIF saved at: output/1729555929_300_Epoch_100_Gaussians_Pikachu/output_15_frames.gif
