In [36]:
import os, random, math
from pathlib import Path
from typing import Any, Dict
from datetime import timedelta
import argparse
import json
# ----------------------------------------------------
import torch
import matplotlib.pyplot as plt
from yaml import load, dump, Loader, Dumper
# ----------------------------------------------------
import diffusers
import transformers
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.utils import export_to_video, load_image, load_video
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from tqdm import tqdm

# ----------------------------------------------------
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import (
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    ProjectConfiguration,
    set_seed,
    gather_object,
)
# ----------------------------------------------------
from dataset import BucketSampler, PrecomputedDataset
from ltx_video_lora import *
# ----------------------------------------------------
from utils.file_utils import find_files, delete_files, string_to_filename
from utils.optimizer_utils import get_optimizer, gradient_norm
from utils.memory_utils import get_memory_statistics, free_memory, make_contiguous
from utils.torch_utils import unwrap_model, align_device_and_dtype
import logging

In [40]:
LOG_LEVEL = "DEBUG"
logger = get_logger("ltxtrainer")
logger.setLevel(LOG_LEVEL)

In [73]:
class State:
    # Training state
    seed: int = None
    model_name: str = None
    accelerator: Accelerator = None
    weight_dtype: torch.dtype = None
    train_epochs: int = None
    train_steps: int = None
    overwrote_max_train_steps: bool = False
    num_trainable_parameters: int = 0
    learning_rate: float = None
    train_batch_size: int = None
    generator: torch.Generator = None

    # Hub state
    repo_id: str = None
    # Artifacts state
    output_dir: str = None


class Trainer:
    def __init__(self, config_file) -> None:
        cd = load(open(config_file, "r"), Loader=Loader)
        cd.setdefault("train_steps", None)
        cd.setdefault("logging_dir", "logs")
        cd.setdefault("report_to", "none")
        cd.setdefault("dataset_file", None)
        cd.setdefault("pin_memory", True)
        cd.setdefault("allow_tf32", True)
        cd.setdefault("scale_lr", True)
        cd.setdefault("train_type", "lora") # or full
        cd.setdefault("optimizer_8bit", True)
        cd.setdefault("optimizer_torchao", False)
        cd.setdefault("caption_dropout_technique", "zero")
        # ----------------- optimizer params --------
        cd.setdefault("optimizer" "adamw")
        cd.setdefault("lr", float(1e-4))
        cd.setdefault("scale_lr",  False)
        cd.setdefault("lr_scheduler", "constant_with_warmup")
        cd.setdefault("lr_warmup_steps", 1000)
        cd.setdefault("lr_num_cycles", 1)
        cd.setdefault("lr_power",  1.0)
        cd.setdefault("beta1",  0.9)
        cd.setdefault("beta2",  0.95)
        cd.setdefault("beta3",  0.999)
        cd.setdefault("weight_decay",  0.0001)
        cd.setdefault("epsilon",  float(1e-8))
        cd.setdefault("max_grad_norm",  1.0)
        # ---------------- Diffusion arguments
        cd.setdefault("flow_resolution_shifting", False)
        cd.setdefault("flow_base_image_seq_len", 256)
        cd.setdefault("flow_max_image_seq_len", 4096)
        cd.setdefault("flow_base_shift", 0.5)
        cd.setdefault("flow_max_shift", 1.15)
        cd.setdefault("flow_shift", 1.0)
        cd.setdefault("flow_weighting_scheme", "none")
        cd.setdefault("flow_logit_mean", 0.0)
        cd.setdefault("flow_logit_std", 1.0)
        cd.setdefault("flow_mode_scale", 1.29)
        
        # cd.setdefault("enable_slicing", False)
        # cd.setdefault("enable_tiling", False)
        
        args = argparse.Namespace(**cd)
        args.lr = float(args.lr)
        args.epsilon = float(args.epsilon)
        args.weight_decay = float(args.weight_decay)
        args.target_modules = args.target_modules.split(" ")

        self.args = args
        self.state = State()

        # Tokenizers
        self.tokenizer = None
        # self.tokenizer_2 = None
        # self.tokenizer_3 = None
        # Text encoders
        self.text_encoder = None
        # self.text_encoder_2 = None
        # self.text_encoder_3 = None

        # Denoisers
        self.transformer = None
        self.unet = None

        # Autoencoders
        self.vae = None

        # Scheduler
        self.scheduler = None

        self._init_distributed()
        self._init_logging()
        self._init_directories_and_repositories()

        self.state.model_name = self.args.model_name

        self.model_config = LTX_VIDEO_T2V_LORA_CONFIG
        # self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type)
    
    def _init_distributed(self):
        logging_dir = Path(self.args.output_dir, self.args.logging_dir)
        project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        init_process_group_kwargs = InitProcessGroupKwargs(
            backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
        )
        mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision
        report_to = None if self.args.report_to.lower() == "none" else self.args.report_to

        accelerator = Accelerator(
            project_config=project_config,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            mixed_precision=mixed_precision,
            log_with=report_to,
            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
        )

        # Disable AMP for MPS.
        if torch.backends.mps.is_available():
            accelerator.native_amp = False

        self.state.accelerator = accelerator

        if self.args.seed is not None:
            self.state.seed = self.args.seed
            set_seed(self.args.seed)

        weight_dtype = torch.float32
        if self.state.accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif self.state.accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16
            
        self.state.weight_dtype = weight_dtype
        
    def _init_logging(self):
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=LOG_LEVEL,
        )
        if self.state.accelerator.is_local_main_process:
            transformers.utils.logging.set_verbosity_warning()
            diffusers.utils.logging.set_verbosity_info()
        else:
            transformers.utils.logging.set_verbosity_error()
            diffusers.utils.logging.set_verbosity_error()

        logger.info("Initialized Trainer")
        logger.info(self.state.accelerator.state, main_process_only=False)
        
    def _init_directories_and_repositories(self):
        if self.state.accelerator.is_main_process:
            self.args.output_dir = Path(self.args.output_dir)
            self.args.output_dir.mkdir(parents=True, exist_ok=True)
            self.state.output_dir = self.args.output_dir
    
    def prepare_dataset(self) -> None:
        logger.info("Initializing dataset and dataloader")

        self.dataset = PrecomputedDataset(
            data_dir=self.args.data_root,
            width=768,
            height=512,
            num_frames=49, 
        )
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.args.batch_size,
            # sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
            # collate_fn=self.model_config.get("collate_fn"),
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.pin_memory,
        )
    def prepare_models(self):
        logger.info("Initializing models")
        device = self.state.accelerator.device
        dtype = self.state.weight_dtype
        
        # >> we use precomputation so text encoder is not needed
        # cond_models = load_condition_models()
        # tokenizer, text_encoder = cond_models["tokenizer"], cond_models["text_encoder"]
        # self.text_encoder = text_encoder.to(device, dtype=dtype)
        
        # self.vae = load_latent_models()["vae"].to(device, dtype=dtype)

        # if self.vae is not None:
        #     if self.args.enable_slicing:
        #         self.vae.enable_slicing()
        #     if self.args.enable_tiling:
        #         self.vae.enable_tiling()
        diff_models = load_diffusion_models()
        self.transformer = diff_models["transformer"].to(device, dtype=dtype)
        self.scheduler = diff_models["scheduler"]
        self.transformer_config = self.transformer.config if self.transformer is not None else None

    def prepare_trainable_parameters(self):
        logger.info("Initializing trainable parameters")
        
        if self.args.train_type == "lora":
            components_to_disable_grads = [ self.transformer ] # self.vae 
        else:
            components_to_disable_grads = []
            
        for component in components_to_disable_grads:
            if component is not None:
                component.requires_grad_(False)

        if torch.backends.mps.is_available() and self.state.weight_dtype == torch.bfloat16:
            # due to pytorch#99272, MPS does not yet support bfloat16.
            raise ValueError(
                "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
            )

        if self.args.gradient_checkpointing:
            self.transformer.enable_gradient_checkpointing()

        if self.args.train_type == "lora":
            transformer_lora_config = LoraConfig(
                r=self.args.rank,
                lora_alpha=self.args.lora_alpha,
                init_lora_weights=True,
                target_modules=self.args.target_modules,
            )
            self.transformer.add_adapter(transformer_lora_config)

        # TODO: refactor
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if self.state.accelerator.is_main_process:
                transformer_lora_layers_to_save = None

                for model in models:
                    if isinstance(
                        unwrap_model(self.state.accelerator, model),
                        type(unwrap_model(self.state.accelerator, self.transformer)),
                    ):
                        model = unwrap_model(self.state.accelerator, model)
                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                    else:
                        raise ValueError(f"Unexpected save model: {model.__class__}")

                    # make sure to pop weight so that corresponding model is not saved again
                    if weights:
                        weights.pop()

                self.model_config["pipeline_cls"].save_lora_weights(
                    output_dir,
                    transformer_lora_layers=transformer_lora_layers_to_save,
                )

        def load_model_hook(models, input_dir):
            transformer_ = self.model_config["pipeline_cls"].from_pretrained(
                self.args.pretrained_model_name_or_path, subfolder="transformer"
            )
            transformer_.add_adapter(transformer_lora_config)

            lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)

            transformer_state_dict = {
                f'{k.replace("transformer.", "")}': v
                for k, v in lora_state_dict.items()
                if k.startswith("transformer.")
            }
            incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
            if incompatible_keys is not None:
                # check only for unexpected keys
                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                if unexpected_keys:
                    logger.warning(
                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                        f" {unexpected_keys}. "
                    )

            # Make sure the trainable params are in float32. This is again needed since the base models
            # are in `weight_dtype`. More details:
            # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
            if self.args.mixed_precision == "fp16":
                # only upcast trainable parameters (LoRA) into fp32
                cast_training_params([transformer_])

        self.state.accelerator.register_save_state_pre_hook(save_model_hook)
        self.state.accelerator.register_load_state_pre_hook(load_model_hook)

        # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.args.allow_tf32 and torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True
            
    def prepare_optimizer(self):
        logger.info("Initializing optimizer and lr scheduler")

        self.state.train_epochs = self.args.train_epochs
        self.state.train_steps = self.args.train_steps

        # Make sure the trainable params are in float32
        if self.args.mixed_precision == "fp16":
            # only upcast trainable parameters (LoRA) into fp32
            cast_training_params([self.transformer], dtype=torch.float32)

        self.state.learning_rate = self.args.lr
        if self.args.scale_lr:
            self.state.learning_rate = (
                self.state.learning_rate
                * self.args.gradient_accumulation_steps
                * self.args.batch_size
                * self.state.accelerator.num_processes
            )

        transformer_lora_parameters = list(filter(lambda p: p.requires_grad, self.transformer.parameters()))
        transformer_parameters_with_lr = {
            "params": transformer_lora_parameters,
            "lr": self.state.learning_rate,
        }
        params_to_optimize = [transformer_parameters_with_lr]
        self.state.num_trainable_parameters = sum(p.numel() for p in transformer_lora_parameters)

        # TODO(aryan): add deepspeed support
        optimizer = get_optimizer(
            params_to_optimize=params_to_optimize,
            optimizer_name=self.args.optimizer,
            learning_rate=self.args.lr,
            beta1=self.args.beta1,
            beta2=self.args.beta2,
            beta3=self.args.beta3,
            epsilon=self.args.epsilon,
            weight_decay=self.args.weight_decay,
            use_8bit = self.args.optimizer_8bit,
            use_torchao = self.args.optimizer_torchao,
        )

        num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
        if self.state.train_steps is None:
            self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
            self.state.overwrote_max_train_steps = True

        lr_scheduler = get_scheduler(
            name=self.args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=self.args.lr_warmup_steps * self.state.accelerator.num_processes,
            num_training_steps=self.state.train_steps * self.state.accelerator.num_processes,
            num_cycles=self.args.lr_num_cycles,
            power=self.args.lr_power,
        )

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        
    def prepare_for_training(self):
        self.transformer, self.optimizer, self.dataloader, self.lr_scheduler = self.state.accelerator.prepare(
            self.transformer, self.optimizer, self.dataloader, self.lr_scheduler
        )

        # We need to recalculate our total training steps as the size of the training dataloader may have changed.
        # num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
        # if self.state.overwrote_max_train_steps:
        #     self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
        # Afterwards we recalculate our number of training epochs
        # self.state.train_epochs = math.ceil(self.state.train_steps / num_update_steps_per_epoch)
        
    def prepare_trackers(self):
        logger.info("Initializing trackers")

        tracker_name = self.args.tracker_name or "ltx_train"
        self.state.accelerator.init_trackers(tracker_name, config=self.args.__dict__)
        
    def train(self):
        logger.info("Starting training")

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")

        self.state.train_batch_size = (
            self.args.batch_size * self.state.accelerator.num_processes * self.args.gradient_accumulation_steps
        )
        info = {
            "trainable parameters": self.state.num_trainable_parameters,
            "total samples": len(self.dataset),
            "train epochs": self.state.train_epochs,
            "train steps": self.state.train_steps,
            "batches per device": self.args.batch_size,
            "total batches observed per epoch": len(self.dataloader),
            "train batch size": self.state.train_batch_size,
            "gradient accumulation steps": self.args.gradient_accumulation_steps,
        }
        logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
        
        global_step = 0
        first_epoch = 0
        initial_global_step = 0
        progress_bar = tqdm(
            range(0, self.state.train_steps),
            initial=initial_global_step,
            desc="Training steps",
            disable=not self.state.accelerator.is_local_main_process,
        )

        accelerator = self.state.accelerator
        weight_dtype = self.state.weight_dtype
        scheduler_sigmas = self.scheduler.sigmas.clone().to(device=accelerator.device, dtype=weight_dtype)
        generator = torch.Generator(device=accelerator.device)
        if self.args.seed is not None:
            generator = generator.manual_seed(self.args.seed)
        self.state.generator = generator

        for epoch in range(first_epoch, self.state.train_epochs):
            logger.debug(f"Starting epoch ({epoch + 1}/{self.state.train_epochs})")

            self.transformer.train()

            for step, batch in enumerate(self.dataloader):
                logger.debug(f"Starting step {step + 1}")
                logs = {}

                with accelerator.accumulate([ self.transformer ]):
                    latents, prompt_embeds, prompt_attention_mask, caption, meta_info = batch
                    
                    # latent_conditions = batch["latent_conditions"]
                    # text_conditions = batch["text_conditions"]
                    # latent_conditions["latents"] = DiagonalGaussianDistribution(
                    #     latent_conditions["latents"]
                    # ).sample(generator)
                    # print("--", latents.shape)
                    # latents = DiagonalGaussianDistribution(latents).sample(generator)
                    # print("Diagonal", latents.shape)
                    
                    # if "post_latent_preparation" in self.model_config.keys():
                    #     latent_conditions = self.model_config["post_latent_preparation"](**latent_conditions)
                        
                    # align_device_and_dtype(latents, accelerator.device, weight_dtype)
                    # align_device_and_dtype(text_conditions, accelerator.device, weight_dtype)
                    latents = latents.to(accelerator.device, dtype=weight_dtype).contiguous()
                    prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype).contiguous()
                    prompt_attention_mask = prompt_attention_mask.to(accelerator.device, dtype=weight_dtype)
                    batch_size = latents.shape[0]

                    # latent_conditions = make_contiguous(latent_conditions)
                    # text_conditions = make_contiguous(text_conditions)
                    

                    if self.args.caption_dropout_technique == "zero":
                        if random.random() < self.args.caption_dropout_p:
                            # text_conditions["prompt_embeds"].fill_(0)
                            prompt_embeds.fill_(0)
                            prompt_attention_mask.fill_(False)

                            # if "pooled_prompt_embeds" in text_conditions:
                            #     text_conditions["pooled_prompt_embeds"].fill_(0)

                    # These weighting schemes use a uniform timestep sampling and instead post-weight the loss
                    weights = compute_density_for_timestep_sampling(
                        weighting_scheme=self.args.flow_weighting_scheme,
                        batch_size=batch_size,
                        logit_mean=self.args.flow_logit_mean,
                        logit_std=self.args.flow_logit_std,
                        mode_scale=self.args.flow_mode_scale,
                    )
                    indices = (weights * self.scheduler.config.num_train_timesteps).long()
                    sigmas = scheduler_sigmas[indices]
                    timesteps = (sigmas * 1000.0).long()
                    print("sigmas", sigmas.shape, "weights", weights, "timesteps", timesteps.shape)
                    
                    noise = torch.randn(
                        latents.shape,
                        generator=generator,
                        device=accelerator.device,
                        dtype=weight_dtype,
                    )
                    ss= sigmas.reshape(-1, 1, 1).repeat(1, 1, latents.size(-1))
                    print("ss", ss.shape, latents.shape)
                    noisy_latents = (1.0 - ss) * latents + ss * noise

                    # These weighting schemes use a uniform timestep sampling and instead post-weight the loss
                    weights = compute_loss_weighting_for_sd3(
                        weighting_scheme=self.args.flow_weighting_scheme, sigmas=sigmas
                    ).reshape(-1, 1, 1).repeat(1, 1, latents.size(-1))
                    print("weights", weights)
                    pred = self.model_config["forward_pass"](
                        transformer=self.transformer, 
                        timesteps=timesteps, 
                        latents=latents,
                        noisy_latents=noisy_latents,
                        prompt_embeds=prompt_embeds, 
                        prompt_attention_mask=prompt_attention_mask,
                        num_frames=meta_info["num_frames"][0],
                        height=meta_info["height"][0],
                        width=meta_info["width"][0],
                    )
                    target = noise - latents

                    loss = weights.float() * (pred["latents"].float() - target.float()).pow(2)
                    # Average loss across channel dimension
                    loss = loss.mean(list(range(1, loss.ndim)))
                    # Average loss across batch dimension
                    loss = loss.mean()
                    accelerator.backward(loss)

                    if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
                        grad_norm = accelerator.clip_grad_norm_(self.transformer.parameters(), self.args.max_grad_norm)
                        logs["grad_norm"] = grad_norm

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

                    # Checkpointing
                    if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                        if global_step % self.args.checkpointing_steps == 0:
                            # before saving state, check if this save would set us over the `checkpointing_limit`
                            if self.args.checkpointing_limit is not None:
                                checkpoints = find_files(self.args.output_dir, prefix="checkpoint")

                                # before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints
                                if len(checkpoints) >= self.args.checkpointing_limit:
                                    num_to_remove = len(checkpoints) - self.args.checkpointing_limit + 1
                                    checkpoints_to_remove = checkpoints[0:num_to_remove]
                                    delete_files(checkpoints_to_remove)

                            logger.info(f"Checkpointing at step {global_step}")
                            save_path = os.path.join(self.args.output_dir, f"checkpoint-{global_step}")
                            accelerator.save_state(save_path)
                            logger.info(f"Saved state to {save_path}")

                # Maybe run validation
                should_run_validation = (
                    self.args.validation_steps is not None
                    and global_step % self.args.validation_steps == 0
                )
                if should_run_validation:
                    self.validate(global_step)

                logs = {"loss": loss.detach().item(), "lr": self.lr_scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(logs)
                accelerator.log(logs, step=global_step)

                if global_step >= self.state.train_steps:
                    break

            memory_statistics = get_memory_statistics()
            logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")

            # Maybe run validation
            # should_run_validation = (
            #     self.args.validation_every_n_epochs is not None
            #     and (epoch + 1) % self.args.validation_every_n_epochs == 0
            # )
            # if should_run_validation:
            #     self.validate(global_step)

        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            self.transformer = unwrap_model(accelerator, self.transformer)
            dtype = (
                torch.float16
                if self.args.mixed_precision == "fp16"
                else torch.bfloat16
                if self.args.mixed_precision == "bf16"
                else torch.float32
            )
            self.transformer = self.transformer.to(dtype)
            transformer_lora_layers = get_peft_model_state_dict(self.transformer)

            self.model_config["pipeline_cls"].save_lora_weights(
                save_directory=self.args.output_dir,
                transformer_lora_layers=transformer_lora_layers,
            )

        del self.transformer, self.scheduler
        free_memory()
        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")

        accelerator.end_training()
    
    def evaluate(self):
        logger.info("Starting validation")
        pass
        
# trainer = Trainer("ltx_training/configs/ltx.yaml")

In [None]:
trainer = Trainer("configs/ltx.yaml")

trainer.prepare_dataset()
trainer.prepare_models()
trainer.prepare_trainable_parameters()
trainer.prepare_optimizer()
trainer.prepare_for_training()
trainer.prepare_trackers()
trainer.train()
trainer.evaluate()

In [27]:
aa = torch.sigmoid(torch.randn(10, 1, 1))
bb = torch.rand(10, 1, 1)
print(aa.mean(), aa.std())
print(bb.mean(), bb.std())

tensor(0.4946) tensor(0.1420)
tensor(0.6169) tensor(0.2430)


In [5]:
print(logging.__file__)

/home/eisneim/.conda/envs/_learn/lib/python3.10/logging/__init__.py


In [54]:
dd = next(iter(trainer.dataloader))

In [68]:
trainer.state.learning_rate

6e-05

In [69]:
trainer.args.lr

3e-05

In [29]:
ab.c

0.001

In [82]:
aa = "343"
isinstance(aa, str)

True

In [76]:
aa.setdefault("c", 1e-3)

AttributeError: 'Tensor' object has no attribute 'setdefault'

In [27]:
ab = argparse.Namespace(**aa)
ab

Namespace(a=1, b=1e-05, c=0.001)

In [14]:
file = "/home/eisneim/www/ml/video_gen/ltx_training/data/ltxv_disney_49x768x512-normed/VID_20230606_101912.pt"
data = torch.load(file)
data

  data = torch.load(file)


{'latents': tensor([[[ 0.4414, -0.3730, -0.3945,  ..., -0.5742,  0.0530, -0.2520],
          [ 0.4941, -0.3555,  0.2021,  ..., -0.8242,  0.3711, -0.5742],
          [ 0.4316, -0.4570, -0.0957,  ..., -1.3047,  0.0991, -0.1445],
          ...,
          [-0.7148, -0.4668, -0.1191,  ...,  0.2891, -2.0312, -0.0093],
          [-0.0302, -0.6680, -0.7344,  ..., -0.0815, -1.8281, -0.2090],
          [ 1.0469, -0.8047, -0.5273,  ..., -0.6992, -1.9766, -0.1592]]],
        dtype=torch.bfloat16),
 'embedds': tensor([[[ 0.1602, -0.0073, -0.1445,  ..., -0.1543, -0.1162, -0.0337],
          [ 0.1260,  0.0156,  0.1270,  ..., -0.1709, -0.0042, -0.1816],
          [-0.0280,  0.0708,  0.1074,  ...,  0.0030,  0.0339, -0.0942],
          ...,
          [ 0.0771, -0.0098, -0.0084,  ..., -0.1357, -0.0674, -0.0043],
          [ 0.0388, -0.0135,  0.0410,  ..., -0.1011, -0.2236, -0.0259],
          [ 0.0101,  0.1157, -0.1177,  ..., -0.0334, -0.0569,  0.0130]]],
        dtype=torch.bfloat16),
 'masks': tensor([

In [47]:
aa = torch.tensor([0.1, 0.2])
aa = aa.unsqueeze(dim=1).repeat(1, 4)
print(aa, aa.shape)

tensor([[0.1000, 0.1000, 0.1000, 0.1000],
        [0.2000, 0.2000, 0.2000, 0.2000]]) torch.Size([2, 4])


In [1]:
576/32

18.0

In [2]:
1024 / 32

32.0