In [None]:
def dic2arg(config:dict) -> str:
  """Render a flag -> value mapping as a command-line argument string.

  Each entry contributes ``"<flag> <value> "``. Booleans act as switches:
  ``True`` contributes the bare flag and ``False`` contributes only the padding
  spaces, so the flag is left out. Every other value -- including ``0`` and a
  string that happens to read ``"False"`` -- is emitted after its flag using
  its ``str()`` form.

  Args:
    config: Mapping of flag names (e.g. ``"--seed"``) to values.

  Returns:
    The concatenated argument string. Entries keep the mapping's iteration
    order; values are inserted as-is (no shell escaping is applied).
  """
  rendered = []
  for flag, setting in config.items():
    if isinstance(setting, bool):
      # Same spacing as a valued entry, with the value (and, for a disabled
      # switch, the flag itself) left blank.
      rendered.append(f'{flag if setting else ""}  ')
    else:
      # Only real booleans toggle a flag; anything else always keeps its flag
      # so a value can never be emitted on its own.
      rendered.append(f'{flag} {setting} ')
  return ''.join(rendered)

def data_config(flip, color_aug, shuffle_caption):
  """Create a dataset configuration skeleton with an empty ``datasets`` list.

  Args:
    flip: Stored under ``general.flip_aug``.
    color_aug: Stored under ``general.color_aug``.
    shuffle_caption: Stored under ``general.shuffle_caption``.

  Returns:
    A new dict with a ``"general"`` section (the three arguments plus fixed
    caption settings) and a fresh, empty ``"datasets"`` list.
  """
  general = dict(
      flip_aug=flip,
      color_aug=color_aug,
      keep_tokens_separator="|||",
      shuffle_caption=shuffle_caption,
      caption_tag_dropout_rate=0,
      caption_extension=".txt",
  )
  return {"general": general, "datasets": []}

# Fixed flag -> value pairs kept separate from the per-run training arguments.
# NOTE(review): both names look like launcher-level options rather than
# training-script options -- confirm against the code that consumes `op`.
op = {
    "--mixed_precision": "bf16",
    "--num_cpu_threads_per_process": 1,
}

def config(TrainType, model_train, flux_path, clip_l_path, t5xxl_fp16_path, vae_path, dim, alpha, train_textencode, optimizer_type, lr, lr_scheduler, epochs, save_every_n_epochs, save_every_n_steps, dataset, output_dir, output_name):
  """Assemble the flag -> value mapping for a LoRA training run.

  ``False`` marks a flag as switched off and ``True`` marks a value-less
  switch (presumably the dict is later rendered with ``dic2arg``, which treats
  booleans that way). Settings that only apply to Flux are ``False`` for every
  other ``TrainType``; with the ``adafactor`` optimizer the scheduler is forced
  to ``constant_with_warmup`` and fixed optimizer args are added.

  NOTE(review): ``model_train`` is returned without surrounding quotes here,
  whereas ``config_cp`` wraps it in double quotes -- confirm which is intended.

  Returns:
    A dict keyed by ``--flag`` names, in a fixed order.
  """
  is_flux = TrainType == 'Flux'
  uses_adafactor = optimizer_type == 'adafactor'

  def flux_only(setting):
    # Flux-specific settings are disabled for other model types.
    return setting if is_flux else False

  def positive_or_off(interval):
    # A non-positive save interval turns the corresponding flag off.
    return False if interval <= 0 else interval

  model = {
    '--pretrained_model_name_or_path': flux_path if is_flux else model_train,
    '--clip_l': flux_only(clip_l_path),
    '--t5xxl': flux_only(t5xxl_fp16_path),
    '--ae': flux_only(vae_path),
  }
  fixed = {
    '--cache_latents_to_disk': True,
    '--save_model_as': 'safetensors',
    '--sdpa': True,
    '--persistent_data_loader_workers': True,
    '--max_data_loader_n_workers': 2,
    '--seed': 42,
    '--gradient_checkpointing': True,
    '--mixed_precision': 'bf16',
    '--save_precision': 'bf16',
  }
  network = {
    '--network_module': '"networks.lora_flux"' if is_flux else '"networks.lora"',
    '--network_dim': dim,
    '--network_alpha': alpha,
    '--network_train_unet_only': not train_textencode,
    '--network_args': "train_t5xxl=True" if train_textencode and is_flux else False,
  }
  optimizer = {
    '--optimizer_type': optimizer_type,
    '--optimizer_args': '"relative_step=False" "scale_parameter=False" "warmup_init=False"' if uses_adafactor else False,
    '--learning_rate': lr,
    '--lr_scheduler': 'constant_with_warmup' if uses_adafactor else lr_scheduler,
    '--unet_lr': False,
    '--text_encoder_lr': False,
  }
  memory = {
    '--cache_text_encoder_outputs': False,
    '--cache_text_encoder_outputs_to_disk': False,
    '--fp8_base': is_flux,
    '--highvram': True,
  }
  schedule = {
    '--max_train_epochs': epochs,
    '--save_every_n_epochs': positive_or_off(save_every_n_epochs),
    '--save_every_n_steps': positive_or_off(save_every_n_steps),
  }
  # Paths and names are wrapped in double quotes inside the value itself.
  locations = {
    '--dataset_config': f'"{dataset}"',
    '--output_dir': f'"{output_dir}"',
    '--output_name': f'"{output_name}"',
  }
  flux_sampling = {
    '--timestep_sampling': flux_only('shift'),
    '--discrete_flow_shift': flux_only(3.1582),
    '--model_prediction_type': flux_only('raw'),
    '--guidance_scale': flux_only(1.0),
  }
  buckets = {
    '--max_bucket_reso': 2048,
    '--min_bucket_reso': 256 if TrainType == 'SD15' else 512,
  }
  # Merge in this exact sequence so the resulting key order stays the same.
  return {
    **model,
    **fixed,
    **network,
    **optimizer,
    **memory,
    **schedule,
    **locations,
    **flux_sampling,
    **buckets,
  }

def config_cp(TrainType, model_train, flux_path, clip_l_path, t5xxl_fp16_path, vae_path, optimizer_type, lr, lr_scheduler, epochs, save_every_n_epochs, save_every_n_steps, dataset, output_dir, output_name):
  """Assemble the flag -> value mapping for a run without ``--network_*`` flags.

  Mirrors ``config`` except that no network (LoRA) settings are produced, the
  save precision is ``fp16``, text-encoder output caching is enabled for Flux,
  and a non-Flux ``model_train`` is wrapped in double quotes.
  NOTE(review): "cp" presumably stands for checkpoint / full fine-tuning --
  confirm with the caller.

  ``False`` marks a flag as switched off and ``True`` marks a value-less
  switch.

  Returns:
    A dict keyed by ``--flag`` names, in a fixed order.
  """
  is_flux = TrainType == 'Flux'
  uses_adafactor = optimizer_type == 'adafactor'

  def flux_only(setting):
    # Flux-specific settings are disabled for other model types.
    return setting if is_flux else False

  def positive_or_off(interval):
    # A non-positive save interval turns the corresponding flag off.
    return False if interval <= 0 else interval

  if is_flux:
    base_model = flux_path
  else:
    base_model = f'"{model_train}"'

  if uses_adafactor:
    optimizer_args = '"relative_step=False" "scale_parameter=False" "warmup_init=False"'
    scheduler = 'constant_with_warmup'
  else:
    optimizer_args = False
    scheduler = lr_scheduler

  args = {
    '--pretrained_model_name_or_path': base_model,
    '--clip_l': flux_only(clip_l_path),
    '--t5xxl': flux_only(t5xxl_fp16_path),
    '--ae': flux_only(vae_path),
  }
  # Settings that never vary between runs.
  args.update({
    '--cache_latents_to_disk': True,
    '--save_model_as': 'safetensors',
    '--sdpa': True,
    '--persistent_data_loader_workers': True,
    '--max_data_loader_n_workers': 2,
    '--seed': 42,
    '--gradient_checkpointing': True,
    '--mixed_precision': 'bf16',
    '--save_precision': 'fp16',
  })
  args.update({
    '--optimizer_type': optimizer_type,
    '--optimizer_args': optimizer_args,
    '--learning_rate': lr,
    '--lr_scheduler': scheduler,
    '--unet_lr': False,
    '--text_encoder_lr': False,
  })
  args.update({
    '--cache_text_encoder_outputs': is_flux,
    '--cache_text_encoder_outputs_to_disk': is_flux,
    '--fp8_base': is_flux,
    '--highvram': True,
    '--max_train_epochs': epochs,
    '--save_every_n_epochs': positive_or_off(save_every_n_epochs),
    '--save_every_n_steps': positive_or_off(save_every_n_steps),
  })
  # Paths and names are wrapped in double quotes inside the value itself.
  args.update({
    '--dataset_config': f'"{dataset}"',
    '--output_dir': f'"{output_dir}"',
    '--output_name': f'"{output_name}"',
  })
  args.update({
    '--timestep_sampling': flux_only('shift'),
    '--discrete_flow_shift': flux_only(3.1582),
    '--model_prediction_type': flux_only('raw'),
    '--guidance_scale': flux_only(1.0),
    '--max_bucket_reso': 2048,
    '--min_bucket_reso': 256 if TrainType == 'SD15' else 512,
  })
  return args

def extra(TrainType, output_name, author,save_state, resume, wandbapi, sample_steps, runcode):
  """Build the metadata, state, logging and sampling flags for a training run.

  ``False`` marks a flag as switched off and ``True`` marks a value-less
  switch.

  Args:
    TrainType: Prefix used in the wandb run name.
    output_name: Used for the metadata title and the wandb run name.
    author: Metadata author.
    save_state: Passed through unchanged as ``--save_state``.
    resume: Value for ``--resume``; an empty string or ``None`` disables the
      flag.
    wandbapi: wandb API key; an empty string or ``None`` disables all wandb
      flags.
    sample_steps: Sampling interval in steps; a value <= 0 disables
      ``--sample_every_n_steps`` and ``--sample_at_first``.
    runcode: When truthy, wandb logging is disabled regardless of ``wandbapi``.

  Returns:
    A dict keyed by ``--flag`` names, in a fixed order.
  """
  # Truthiness (rather than `!= ''`) so that None is treated as "not set"
  # instead of being rendered as the literal text "None".
  wandb_key = '' if runcode else wandbapi
  use_wandb = bool(wandb_key)
  sampling = sample_steps > 0
  return {
      '--metadata_title': f'"{output_name}"',
      '--metadata_author': f'"{author}"',
      '--metadata_description': '"training by trainlora.vn"',
      '--metadata_license': False, # e.g. "MIT"
      '--metadata_tags': False, # e.g. "sdvn.me"
      '--no_metadata': False,
      '--training_comment': '"training by trainlora.vn"',
      '--save_n_epoch_ratio': False,
      '--save_last_n_epochs': False,
      '--save_last_n_epochs_state': False,
      '--save_last_n_steps': False,
      '--save_last_n_steps_state': False,
      '--save_state': save_state,
      '--save_state_on_train_end': True,
      '--resume': f'"{resume}"' if resume else False, # path
      '--log_with': 'wandb' if use_wandb else False, # tensorboard, wandb, all
      '--wandb_run_name': f'"{TrainType}lora_{output_name}"' if use_wandb else False,
      '--wandb_api_key': wandb_key if use_wandb else False,
      '--sample_every_n_steps': sample_steps if sampling else False,
      '--sample_at_first': sampling,
      '--sample_every_n_epochs': False,
      '--sample_prompts': '/content/prompt.txt',
      '--sample_sampler': 'euler',
  }

def extra_cp(TrainType, output_name, author,save_state, resume, wandbapi, sample_steps, runcode):
  """Build the metadata, state, logging and sampling flags (no training comment).

  Produces the same flags as ``extra`` except that ``--training_comment`` is
  not included. ``False`` marks a flag as switched off and ``True`` marks a
  value-less switch.

  Args:
    TrainType: Prefix used in the wandb run name.
    output_name: Used for the metadata title and the wandb run name.
    author: Metadata author.
    save_state: Passed through unchanged as ``--save_state``.
    resume: Value for ``--resume``; an empty string or ``None`` disables the
      flag.
    wandbapi: wandb API key; an empty string or ``None`` disables all wandb
      flags.
    sample_steps: Sampling interval in steps; a value <= 0 disables
      ``--sample_every_n_steps`` and ``--sample_at_first``.
    runcode: When truthy, wandb logging is disabled regardless of ``wandbapi``.

  Returns:
    A dict keyed by ``--flag`` names, in a fixed order.
  """
  # Truthiness (rather than `!= ''`) so that None is treated as "not set"
  # instead of being rendered as the literal text "None".
  wandb_key = '' if runcode else wandbapi
  use_wandb = bool(wandb_key)
  sampling = sample_steps > 0
  return {
      '--metadata_title': f'"{output_name}"',
      '--metadata_author': f'"{author}"',
      '--metadata_description': '"training by trainlora.vn"',
      '--metadata_license': False, # e.g. "MIT"
      '--metadata_tags': False, # e.g. "sdvn.me"
      '--no_metadata': False,
      '--save_n_epoch_ratio': False,
      '--save_last_n_epochs': False,
      '--save_last_n_epochs_state': False,
      '--save_last_n_steps': False,
      '--save_last_n_steps_state': False,
      '--save_state': save_state,
      '--save_state_on_train_end': True,
      '--resume': f'"{resume}"' if resume else False, # path
      '--log_with': 'wandb' if use_wandb else False, # tensorboard, wandb, all
      '--wandb_run_name': f'"{TrainType}lora_{output_name}"' if use_wandb else False,
      '--wandb_api_key': wandb_key if use_wandb else False,
      '--sample_every_n_steps': sample_steps if sampling else False,
      '--sample_at_first': sampling,
      '--sample_every_n_epochs': False,
      '--sample_prompts': '/content/prompt.txt',
      '--sample_sampler': 'euler',
  }

def prompt(sample_prompt, sample_size, TrainFolder):
  """Write three sample prompts to ``/content/prompt.txt``.

  The first prompt is ``sample_prompt`` when it is non-empty; otherwise, like
  the other two, it is whatever ``random_sample(TrainFolder)`` returns. Every
  prompt gets the same ``--w``/``--h`` suffix taken from ``sample_size``. The
  target path is the one ``extra``/``extra_cp`` pass as ``--sample_prompts``.

  Args:
    sample_prompt: Text for the first prompt, or ``""`` to pick one via
      ``random_sample``.
    sample_size: ``"<width>,<height>"``; blanks around either number are
      ignored.
    TrainFolder: Forwarded to ``random_sample``.

  Raises:
    IndexError: If ``sample_size`` contains no comma.
  """
  dims = sample_size.split(",")
  # Strip stray blanks so "512, 768" yields "--w 512 --h 768" rather than a
  # doubled space after --h, which the prompt-file parser may not accept.
  size_args = f'--w {dims[0].strip()} --h {dims[1].strip()}'

  first = sample_prompt if sample_prompt != "" else random_sample(TrainFolder)
  texts = [first, random_sample(TrainFolder), random_sample(TrainFolder)]

  lines = []
  for number, text in enumerate(texts, start=1):
    lines.append(f'#prompt {number}')
    lines.append(f'{text} {size_args}')
  # No trailing newline, matching the layout written previously.
  write_file('/content/prompt.txt', '\n'.join(lines))

def dataset_file(resolution, batch_size, image_dir, num_repeats, data_config, dataset):
  """Append one dataset entry per requested resolution and write the TOML file.

  Args:
    resolution: Comma-separated entries, each either ``"<size>"`` or
      ``"<size>-<batch>"``; the second form overrides ``batch_size`` for that
      entry. Each size is used for both sides of a square resolution.
    batch_size: Batch size for entries without a ``-<batch>`` suffix.
    image_dir: Handed to ``check_sub_dir`` to enumerate the subset folders.
    num_repeats: Handed to ``repeat_dir`` together with each subset folder.
    data_config: Dict with a ``"datasets"`` list; it is extended in place.
    dataset: Path of the TOML file to write.

  Raises:
    ValueError: If a size or batch token is not an integer.
  """
  for entry in resolution.split(','):
    if '-' in entry:
      fields = entry.split('-')
      side, batch = int(fields[0]), int(fields[1])
    else:
      side, batch = int(entry), batch_size

    subsets = [
        {"image_dir": folder, "num_repeats": repeat_dir(folder, num_repeats)}
        for folder in check_sub_dir(image_dir)
    ]
    # Append the finished entry instead of indexing datasets[] from 0, so the
    # subsets land on the right entry even when data_config already holds
    # datasets from an earlier call.
    data_config["datasets"].append({
        "batch_size": batch,
        "enable_bucket": True,
        "resolution": [side, side],
        "subsets": subsets,
    })

  # TOML documents are UTF-8; don't depend on the platform default encoding
  # (image folder names may contain non-ASCII characters).
  with open(dataset, "w", encoding="utf-8") as file:
    toml.dump(data_config, file)