In [None]:
# Show the attached GPU, then clone the threestudio fork together with its git
# submodules (--recursive) into ./threestudio; later cells expect it at
# /content/threestudio.
!nvidia-smi
!git clone --recursive https://github.com/taqu/threestudio.git

In [None]:
# Mount Google Drive at /content/drive. Later cells copy checkpoints and the
# training-progress counter to and from /content/drive/MyDrive/Sheep/ckpts so
# progress survives Colab session resets.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!suto apt install -y libaio-dev
!pip install ninja
!pip install lightning==2.0.0 omegaconf==2.3.0 jaxtyping typeguard diffusers transformers accelerate opencv-python tensorboard matplotlib imageio imageio[ffmpeg] trimesh bitsandbytes sentencepiece safetensors huggingface_hub libigl xatlas datasets wandb tomesd
!pip install open3d plotly # mesh visualization
#!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
!pip install git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2
!pip install git+https://github.com/NVlabs/nvdiffrast.git
!pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
!pip install deepspeed


In [None]:
import asyncio
import codecs
import datetime
import glob
import json
import os
import secrets
import shutil
import sys
from typing import Optional

import datasets
import torch

def load_count(filename):
    """Read the saved training-progress counter from ``filename``.

    Returns the first line of the file that parses as an ``int`` (surrounding
    whitespace is ignored). Returns 0 when the file cannot be opened or decoded,
    or when it contains no integer line, so a fresh session starts from 0.

    Only I/O and parse errors are handled. Unlike a bare ``except`` with a
    ``return`` in ``finally``, this does not swallow KeyboardInterrupt or
    programming errors.
    """
    try:
        with open(filename, mode='r') as f:
            for line in f:
                try:
                    return int(line)
                except ValueError:
                    # Blank or non-numeric line: try the next one.
                    continue
    except (OSError, UnicodeDecodeError):
        # Missing or unreadable file: treat it as "no progress yet".
        pass
    return 0

def save_count(filename, count):
    """Persist the training-progress counter ``count`` to ``filename``.

    This is best effort. A write failure is reported on stderr instead of being
    raised, so a transient filesystem problem does not abort the training loop.
    It is also no longer silently ignored.
    """
    try:
        with open(filename, mode='w') as f:
            f.write(str(count))
    except OSError as e:
        print(f'save_count: could not write {filename}: {e}', file=sys.stderr)

async def run_command(command: str, args: list[str]) -> int:
    """Run ``command *args`` and stream its stdout/stderr into this notebook.

    Both pipes are drained concurrently in fixed-size chunks. As a result:

    * A child that writes heavily to one stream while the other is silent never
      blocks on a full pipe. The previous alternate-readline-and-sleep loop
      throttled output to a few lines per second and could stall the child.
    * Output that redraws with carriage returns and never emits a newline
      (progress bars) is shown. It no longer overflows ``readline``'s buffer
      limit, which raised ValueError.

    Args:
        command: executable to run.
        args: arguments passed to the executable (no shell interpretation).

    Returns:
        The child's exit status (0 on success). Callers that ignore the
        return value behave exactly as before.

    If this coroutine is interrupted or fails while the child is still running,
    the child is killed rather than left running unobserved.
    """
    process = await asyncio.create_subprocess_exec(
        command,
        *args,
        env=os.environ,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE)

    async def _pump(stream, out) -> None:
        # Copy one pipe to `out` until EOF.
        # The incremental decoder keeps multi-byte UTF-8 characters intact across
        # chunk boundaries; undecodable bytes become U+FFFD instead of raising.
        decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
        while True:
            chunk = await stream.read(4096)
            text = decoder.decode(chunk, final=not chunk)
            if text:
                print(text, end='', flush=True, file=out)
            if not chunk:  # EOF
                break

    try:
        await asyncio.gather(
            _pump(process.stdout, sys.stdout),
            _pump(process.stderr, sys.stderr))
        return await process.wait()
    finally:
        if process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass  # the child already exited

def find_last_checkpoint(path: str) -> Optional[str]:
    """Return ``<newest run dir>/ckpts/last.ckpt`` under ``path``, or None.

    ``path`` is expected to hold one sub-directory per training run. The
    lexicographically greatest sub-directory is taken as the newest. This
    assumes run-directory names sort chronologically (TODO confirm for the
    launcher's naming scheme).

    Returns None when ``path`` is not a directory, has no sub-directories, or
    the newest run has no ``ckpts/last.ckpt`` entry. That entry may be a file
    or a directory.
    """
    # isdir (not exists) so that a regular file at `path` doesn't crash listdir.
    if not os.path.isdir(path):
        return None
    run_dirs = sorted(
        name for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name)))
    if not run_dirs:
        return None
    check_point = os.path.join(path, run_dirs[-1], 'ckpts', 'last.ckpt')
    return check_point if os.path.exists(check_point) else None

def save_last_checkpoint(path: str, outputs_dir: str = '/content/threestudio/outputs/sheep') -> None:
    """Copy the newest run's ``last.ckpt`` found under ``outputs_dir`` into ``path``.

    The copy is named ``ckpt<YYYYMMDD>.ckpt`` using the current UTC date, so a
    later save on the same day overwrites the earlier one. ``path`` is created
    if missing. The function does nothing when no checkpoint exists.

    Previously ``path`` was ignored in favour of a hard-coded Drive folder. The
    existing caller passes that same folder, so its behaviour is unchanged.
    """
    check_point = find_last_checkpoint(outputs_dir)
    if check_point is None:
        return
    os.makedirs(path, exist_ok=True)
    # Timezone-aware replacement for the deprecated datetime.utcnow().
    date = datetime.datetime.now(datetime.timezone.utc)
    dst_path = os.path.join(path, 'ckpt' + date.strftime('%Y%m%d') + '.ckpt')
    if os.path.isdir(check_point):
        # Lightning's DeepSpeed strategy saves a checkpoint as a directory,
        # which shutil.copy cannot handle.
        shutil.copytree(check_point, dst_path, dirs_exist_ok=True)
    else:
        shutil.copy(check_point, dst_path)

def restore_last_checkpoint(path: str, dst: str = '/content/last.ckpt') -> None:
    """Copy the newest ``ckpt*.ckpt`` saved in ``path`` to ``dst``.

    Saved names embed a YYYYMMDD date, so lexicographic order is chronological.
    A saved checkpoint may be a file or a directory. When ``path`` holds no
    saved checkpoint, a message is printed and nothing is copied. Any other
    error is printed rather than raised, so a missing Drive folder does not
    stop the notebook.
    """
    try:
        # escape() so glob metacharacters in `path` are matched literally.
        pattern = os.path.join(glob.escape(path), 'ckpt*.ckpt')
        files = sorted(glob.glob(pattern))
        if not files:
            # glob returns [] (never None) when nothing matches.
            print(f'no checkpoint found in {path}')
            return
        # glob results already include `path`; re-joining would duplicate a
        # relative `path`.
        latest = files[-1]
        print(latest)
        if os.path.isdir(latest):
            shutil.copytree(latest, dst, dirs_exist_ok=True)
        else:
            shutil.copy(latest, dst)
    except Exception as e:
        print(e)

def save_last_train_count(path: str, src_path: str = '/content/train_count.txt') -> None:
    """Copy the progress-counter file ``src_path`` into ``path``.

    The copy is named ``train_count<YYYYMMDD>.txt`` using the current UTC date,
    so a later save on the same day overwrites the earlier one. ``path`` is
    created if missing. The function does nothing when ``src_path`` does not
    exist.
    """
    if not os.path.exists(src_path):
        return
    os.makedirs(path, exist_ok=True)
    # Timezone-aware replacement for the deprecated datetime.utcnow().
    date = datetime.datetime.now(datetime.timezone.utc)
    dst_path = os.path.join(path, 'train_count' + date.strftime('%Y%m%d') + '.txt')
    shutil.copy(src_path, dst_path)

def restore_last_train_count(path: str, dst: str = '/content/train_count.txt') -> None:
    """Copy the newest ``train_count*.txt`` saved in ``path`` to ``dst``.

    Saved names embed a YYYYMMDD date, so lexicographic order is chronological.
    When ``path`` holds no saved counter, a message is printed and nothing is
    copied. Any other error is printed rather than raised, so a missing Drive
    folder does not stop the notebook.
    """
    try:
        # escape() so glob metacharacters in `path` are matched literally.
        pattern = os.path.join(glob.escape(path), 'train_count*.txt')
        files = sorted(glob.glob(pattern))
        if not files:
            # glob returns [] (never None) when nothing matches.
            print(f'no train count found in {path}')
            return
        # glob results already include `path`; re-joining would duplicate a
        # relative `path`.
        latest = files[-1]
        print(latest)
        shutil.copy(latest, dst)
    except Exception as e:
        print(e)

# Build one flat list of text prompts from Hugging Face prompt datasets.
# The list order defines the prompt index that the saved train_count refers to,
# so do not filter or reorder it.
# `extend` accepts whatever `dataset['train'][column]` returns. That is a plain
# list in older `datasets` releases, and possibly a non-list column object in
# newer ones (`datasets` is not pinned) — verify. `list + column` only works
# for the former.
prompts = []
dataset = datasets.load_dataset("FredZhang7/stable-diffusion-prompts-2.47M")
prompts.extend(dataset['train']['text'])

dataset = datasets.load_dataset('Gustavosta/Stable-Diffusion-Prompts')
prompts.extend(dataset['train']['Prompt'])

# Optional anime prompt sources (disabled):
#dataset = datasets.load_dataset("FredZhang7/anime-prompts-180K")
#prompts.extend(dataset['train']['safebooru_clean'])
#prompts.extend(dataset['train']['danbooru_clean'])
#prompts.extend(dataset['train']['danbooru_raw'])

# Drop the last dataset reference so it can be garbage-collected.
dataset = None
restore_last_checkpoint('/content/drive/MyDrive/Sheep/ckpts')
restore_last_train_count('/content/drive/MyDrive/Sheep/ckpts') 
count_path = '/content/train_count.txt'
train_count = load_count(count_path)
print(min(len(prompts), train_count))
max_train_count = max(len(prompts), train_count)
print(max_train_count)
conf_file = os.path.join(os.getcwd(), 'threestudio/configs/prolificdreamer.yaml')

os.environ['PYTORCH_CUDA_ALLOC_CONF']='128'
os.chdir('/content/threestudio')
torch.set_float32_matmul_precision('medium')
#base_config = ['launch.py','--config', conf_file, '--train', '--gpu', '0', 'name=\"sheep\"','tag=\"sheep\"', 'data.width=512', 'data.height=512']
base_config = ['launch.py','--config', conf_file, '--train', '--gpu', '0', 'name=\"sheep\"','tag=\"sheep\"','data.width=64', 'data.height=64']
extra_config = [
    'trainer.max_steps=10000'
    ,'system.guidance.token_merging=true'
    ,'trainer.strategy=\"deepspeed_stage_2_offload\"'
    ,'trainer.devices=1'
    #,'trainer.precision=\"16-mixed\"'
    ,'trainer.precision=\"32\"'
    ,'system.guidance.enable_attention_slicing=true'
    ,'system.guidance.use_deepspeed=true'
    ,'system.optimizer.name=\"DeepSpeedCPUAdam\"'
    ,'system.guidance.pretrained_model_name_or_path=\"stabilityai/stable-diffusion-2-1-base\"'
    ,'system.guidance.pretrained_model_name_or_path_lora=\"stabilityai/stable-diffusion-2-1\"'
]
#extra_config = ['trainer.max_steps=1', 'trainer.max_epochs=1', 'trainer.fast_dev_run=True', 'data.batch_size=1', 'system.guidance.token_merging=true', 'system.guidance.enable_attention_slicing=true']
base_config = base_config + extra_config
train_count = 0
for i in range(train_count, 1):
    seed = secrets.randbits(64)
    args = base_config + ['seed='+str(seed)] + ['system.prompt_processor.prompt=' + prompts[i]]
    try:
        print(args)
        await run_command('python3', args)
    except Exception as e:
        print(e)
        break
    train_count += 1
    save_count(count_path, train_count)
    save_last_checkpoint('/content/drive/MyDrive/Sheep/ckpts')
    save_last_train_count('/content/drive/MyDrive/Sheep/ckpts')
os.chdir('/content')
