### Data

In [1]:
import sys
import os

path = os.path.abspath(os.path.join('..'))
if path not in sys.path:
    sys.path.append(path)

In [2]:
from diffusion.data_loaders.backflip_dataset import BackflipMotionDataset
dataset = BackflipMotionDataset("/home/kenji/Fyp/DeepMimic_mujoco/diffusion/data/motions/humanoid3d_backflip.txt")
len(dataset), dataset[0].shape

(29, torch.Size([29, 69]))

### Model

In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)

    def forward(self, x):
        # not used in the final model
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)


class TimestepEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        return self.time_embed(self.sequence_pos_encoder.pe[timesteps])
    
class MotionTransformer(nn.Module):
    def __init__(self, nfeats, latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1, activation="gelu"):
        super(MotionTransformer, self).__init__()
        
        self.nfeats = nfeats
        self.latent_dim = latent_dim
        self.ff_size = ff_size  
        self.dropout = dropout

        self.inputEmbedding = nn.Linear(self.nfeats, self.latent_dim)
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)

        # Transformer Encoder
        encoder_layers = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=num_heads, 
                                                    dim_feedforward=ff_size, dropout=dropout, activation=activation, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)

        # Output Linear Layer
        self.outputEmbedding = nn.Linear(self.latent_dim, self.nfeats)

    def forward(self, x: torch.Tensor, timesteps, y=None):
        """
        x: [batch_size, max_frames, n_feats], denoted x_t in the paper
        timesteps: [batch_size] (int)
        """
        # x: [batch_size, seq_len, nfeats]
        emb = self.embed_timestep(timesteps)  # [bs, seq_len, time_embed_dim]
        # print("Emb", emb.shape)

        # Input process
        x = x.float()
        # print(x.dtype, x.shape)
        x = self.inputEmbedding(x)  # [bs, seq_len, d]
        # print("Input Embedding", x.shape)

        # Transformer Encoder
        # adding the timestep embed
        xseq = torch.cat((emb, x), axis=1)  # [bs, n_frames+1, d]
        # print("Concat x and zkx", xseq.shape)

        xseq = self.sequence_pos_encoder(xseq)  # [bs, n_frames+1, d]
        # print("Sequence Pos Encoder", xseq.shape)
        
        output = self.transformer_encoder(xseq)[:, 1:, :]  # , src_key_padding_mask=~maskseq)  # [bs, n_frames, d]
        # print("Transformer Encoder", output.shape)

        # Output Linear
        output = self.outputEmbedding(output)  # [bs, n_frames, n_feats]
        # print("Output Embedding", output.shape)

        return output

In [23]:
len(dataset), dataset[0].shape

(29, torch.Size([29, 69]))

In [29]:
import torch.optim as optim
from torch.utils.data import DataLoader

batch_size = 1
dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        num_workers=8, drop_last=True)

nfeats = dataset[0].shape[1]

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MotionTransformer(nfeats=nfeats, latent_dim=32, ff_size=128, num_layers=8, num_heads=4, dropout=0.1, activation="gelu").to(device)

In [30]:
for it, batch in enumerate(dataloader):
    batch = batch.to(device)
    print(batch.shape)


torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])
torch.Size([1, 29, 69])


In [31]:
from diffusion.diffusion import gaussian_diffusion as gd
from diffusion.diffusion.respace import SpacedDiffusion, space_timesteps

def create_gaussian_diffusion(
        diffusion_steps, # number eg 1000
        noise_schedule, # can be 'linear', 'cosine'
        sigma_small, # default True
        lambda_vel, lambda_rcxyz, lambda_fc # for geometric loss, we don't have fc, default 1 for rest
        ):
    # default params
    predict_xstart = True  # we always predict x_start (a.k.a. x0), that's our deal!
    steps = diffusion_steps
    scale_beta = 1.  # no scaling
    timestep_respacing = ''  # can be used for ddim sampling, we don't use it.
    learn_sigma = False
    rescale_timesteps = False

    betas = gd.get_named_beta_schedule(noise_schedule, steps, scale_beta)
    loss_type = gd.LossType.MSE

    if not timestep_respacing:
        timestep_respacing = [steps]

    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
        lambda_vel=lambda_vel,
        lambda_rcxyz=lambda_rcxyz,
        lambda_fc=lambda_fc,
    )

In [32]:
model = MotionTransformer(nfeats=nfeats, latent_dim=32, ff_size=128, num_layers=8, num_heads=4, dropout=0.1, activation="gelu").to(device)
diffusion = create_gaussian_diffusion(diffusion_steps=1000, noise_schedule="cosine", sigma_small=True, lambda_vel=1, lambda_rcxyz=1, lambda_fc=1)

In [43]:
class DefaultArgs:
    def __init__(self, save_dir, model_path, eval_model_path):
        # Base options
        self.cuda = True
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.seed = 10
        self.batch_size = 1

        # Diffusion options
        self.noise_schedule = 'cosine'
        self.diffusion_steps = 1000
        self.sigma_small = True

        # Model options
        self.arch = 'trans_enc'
        self.emb_trans_dec = False
        self.layers = 8
        self.latent_dim = 512
        self.cond_mask_prob = 0.1
        self.lambda_rcxyz = 0.0
        self.lambda_vel = 0.0
        self.lambda_fc = 0.0
        self.unconstrained = False  # This is inferred from the 'action' parameter

        # Data options
        self.dataset = 'humanml'
        self.data_dir = ""

        # Training options
        self.save_dir = save_dir
        self.overwrite = False
        self.train_platform_type = 'NoPlatform'
        self.lr = 1e-4
        self.weight_decay = 0.0
        self.lr_anneal_steps = 0
        self.eval_batch_size = 16
        self.eval_split = 'test'
        self.eval_during_training = False
        self.eval_rep_times = 3
        self.eval_num_samples = 1000
        self.log_interval = 250
        self.save_interval = 250
        self.num_steps = 2500
        # self.num_frames = 29
        self.resume_checkpoint = ""

        # Sampling options
        self.model_path = model_path
        self.output_dir = ''
        self.num_samples = 10
        self.num_repetitions = 3
        self.guidance_param = 2.5

        # Generate options
        self.motion_length = 6.0
        self.input_text = ''
        self.action_file = ''
        self.text_prompt = ''
        self.action_name = ''

        # Edit options
        self.edit_mode = 'in_between'
        self.text_condition = ''
        self.prefix_end = 0.25
        self.suffix_start = 0.75

        # Evaluation options
        self.eval_model_path = eval_model_path
        self.eval_mode = 'wo_mm'
        self.eval_guidance_param = 2.5


In [44]:
args = DefaultArgs(save_dir="/home/kenji/Fyp/DeepMimic_mujoco/diffusion/logs/", model_path="/home/kenji/Fyp/DeepMimic_mujoco/diffusion/logs/model.pt", eval_model_path="/home/kenji/Fyp/DeepMimic_mujoco/diffusion/logs/model.pt")
args.device

'cuda:0'

In [45]:
len(dataloader)

29

In [46]:
from train.training_loop import TrainLoop
TrainLoop(args, None , model, diffusion, dataloader).run_loop()


Starting epoch 0/87


  0%|          | 0/29 [00:00<?, ?it/s]

-------------------------
| grad_norm  | 3e+03    |
| loss       | 3.68e+03 |
| loss_q0    | 3.3e+03  |
| loss_q1    | 3.32e+03 |
| loss_q2    | 3.44e+03 |
| loss_q3    | 4.61e+03 |
| param_norm | 39.1     |
| samples    | 1        |
| step       | 0        |
-------------------------
step[0]: loss[3684.26591]
saving model...


 10%|█         | 3/29 [00:00<00:03,  7.83it/s]

Skipping evaluation for now.


100%|██████████| 29/29 [00:01<00:00, 17.12it/s]


Starting epoch 1/87


100%|██████████| 29/29 [00:01<00:00, 22.17it/s]


Starting epoch 2/87


100%|██████████| 29/29 [00:01<00:00, 21.38it/s]


Starting epoch 3/87


100%|██████████| 29/29 [00:01<00:00, 20.23it/s]


Starting epoch 4/87


100%|██████████| 29/29 [00:01<00:00, 20.52it/s]


Starting epoch 5/87


100%|██████████| 29/29 [00:01<00:00, 21.64it/s]


Starting epoch 6/87


100%|██████████| 29/29 [00:01<00:00, 23.49it/s]


Starting epoch 7/87


100%|██████████| 29/29 [00:01<00:00, 23.74it/s]


Starting epoch 8/87


 55%|█████▌    | 16/29 [00:00<00:00, 23.92it/s]

-------------------------
| grad_norm  | 3.87e+03 |
| loss       | 4e+03    |
| loss_q0    | 3.08e+03 |
| loss_q1    | 3.12e+03 |
| loss_q2    | 3.4e+03  |
| loss_q3    | 6.19e+03 |
| param_norm | 39.1     |
| samples    | 251      |
| step       | 250      |
-------------------------
step[250]: loss[3997.26521]
saving model...


 76%|███████▌  | 22/29 [00:01<00:00, 23.16it/s]

Skipping evaluation for now.


100%|██████████| 29/29 [00:01<00:00, 21.49it/s]


Starting epoch 9/87


100%|██████████| 29/29 [00:01<00:00, 22.42it/s]


Starting epoch 10/87


100%|██████████| 29/29 [00:01<00:00, 19.90it/s]


Starting epoch 11/87


100%|██████████| 29/29 [00:01<00:00, 21.84it/s]


Starting epoch 12/87


100%|██████████| 29/29 [00:01<00:00, 20.82it/s]


Starting epoch 13/87


100%|██████████| 29/29 [00:01<00:00, 20.79it/s]


Starting epoch 14/87


100%|██████████| 29/29 [00:01<00:00, 22.11it/s]


Starting epoch 15/87


100%|██████████| 29/29 [00:01<00:00, 20.22it/s]


Starting epoch 16/87


100%|██████████| 29/29 [00:01<00:00, 20.98it/s]


Starting epoch 17/87


 21%|██        | 6/29 [00:00<00:01, 15.93it/s]

-------------------------
| grad_norm  | 3.95e+03 |
| loss       | 3.68e+03 |
| loss_q0    | 2.7e+03  |
| loss_q1    | 2.75e+03 |
| loss_q2    | 3.01e+03 |
| loss_q3    | 5.93e+03 |
| param_norm | 39.2     |
| samples    | 501      |
| step       | 500      |
-------------------------
step[500]: loss[3683.80894]
saving model...


 34%|███▍      | 10/29 [00:00<00:01, 16.17it/s]

Skipping evaluation for now.


100%|██████████| 29/29 [00:01<00:00, 18.41it/s]


Starting epoch 18/87


100%|██████████| 29/29 [00:01<00:00, 20.39it/s]


Starting epoch 19/87


100%|██████████| 29/29 [00:01<00:00, 22.22it/s]


Starting epoch 20/87


100%|██████████| 29/29 [00:01<00:00, 21.30it/s]


Starting epoch 21/87


100%|██████████| 29/29 [00:01<00:00, 21.83it/s]


Starting epoch 22/87


100%|██████████| 29/29 [00:01<00:00, 21.03it/s]


Starting epoch 23/87


100%|██████████| 29/29 [00:01<00:00, 21.08it/s]


Starting epoch 24/87


100%|██████████| 29/29 [00:01<00:00, 22.33it/s]


Starting epoch 25/87


 86%|████████▌ | 25/29 [00:01<00:00, 23.92it/s]

-------------------------
| grad_norm  | 3.89e+03 |
| loss       | 3.31e+03 |
| loss_q0    | 2.4e+03  |
| loss_q1    | 2.45e+03 |
| loss_q2    | 2.68e+03 |
| loss_q3    | 5.57e+03 |
| param_norm | 39.2     |
| samples    | 751      |
| step       | 750      |
-------------------------
step[750]: loss[3314.02574]
saving model...


100%|██████████| 29/29 [00:01<00:00, 20.24it/s]


Skipping evaluation for now.
Starting epoch 26/87


100%|██████████| 29/29 [00:01<00:00, 21.18it/s]


Starting epoch 27/87


100%|██████████| 29/29 [00:01<00:00, 19.87it/s]


Starting epoch 28/87


100%|██████████| 29/29 [00:01<00:00, 21.53it/s]


Starting epoch 29/87


100%|██████████| 29/29 [00:01<00:00, 22.07it/s]


Starting epoch 30/87


100%|██████████| 29/29 [00:01<00:00, 20.91it/s]


Starting epoch 31/87


100%|██████████| 29/29 [00:01<00:00, 21.61it/s]


Starting epoch 32/87


100%|██████████| 29/29 [00:01<00:00, 20.33it/s]


Starting epoch 33/87


100%|██████████| 29/29 [00:01<00:00, 21.13it/s]


Starting epoch 34/87


 45%|████▍     | 13/29 [00:00<00:00, 22.21it/s]

-------------------------
| grad_norm  | 3.56e+03 |
| loss       | 2.82e+03 |
| loss_q0    | 2.13e+03 |
| loss_q1    | 2.16e+03 |
| loss_q2    | 2.4e+03  |
| loss_q3    | 4.85e+03 |
| param_norm | 39.3     |
| samples    | 1e+03    |
| step       | 1e+03    |
-------------------------
step[1000]: loss[2824.55610]
saving model...


 66%|██████▌   | 19/29 [00:00<00:00, 22.43it/s]

Skipping evaluation for now.


100%|██████████| 29/29 [00:01<00:00, 20.86it/s]


Starting epoch 35/87


100%|██████████| 29/29 [00:01<00:00, 20.35it/s]


Starting epoch 36/87


100%|██████████| 29/29 [00:01<00:00, 20.88it/s]


Starting epoch 37/87


100%|██████████| 29/29 [00:01<00:00, 21.74it/s]


Starting epoch 38/87


100%|██████████| 29/29 [00:01<00:00, 20.78it/s]


Starting epoch 39/87


100%|██████████| 29/29 [00:01<00:00, 21.93it/s]


Starting epoch 40/87


100%|██████████| 29/29 [00:01<00:00, 20.25it/s]


Starting epoch 41/87


100%|██████████| 29/29 [00:01<00:00, 20.82it/s]


Starting epoch 42/87


100%|██████████| 29/29 [00:01<00:00, 22.19it/s]


Starting epoch 43/87


  3%|▎         | 1/29 [00:00<00:05,  5.30it/s]

-------------------------
| grad_norm  | 4.17e+03 |
| loss       | 2.96e+03 |
| loss_q0    | 1.93e+03 |
| loss_q1    | 1.96e+03 |
| loss_q2    | 2.2e+03  |
| loss_q3    | 5.31e+03 |
| param_norm | 39.4     |
| samples    | 1.25e+03 |
| step       | 1.25e+03 |
-------------------------
step[1250]: loss[2960.12324]
saving model...


 21%|██        | 6/29 [00:00<00:01, 14.12it/s]

Skipping evaluation for now.


100%|██████████| 29/29 [00:01<00:00, 18.87it/s]


Starting epoch 44/87


100%|██████████| 29/29 [00:01<00:00, 21.47it/s]


Starting epoch 45/87


100%|██████████| 29/29 [00:01<00:00, 20.83it/s]


Starting epoch 46/87


100%|██████████| 29/29 [00:01<00:00, 21.27it/s]


Starting epoch 47/87


100%|██████████| 29/29 [00:01<00:00, 19.60it/s]


Starting epoch 48/87


100%|██████████| 29/29 [00:01<00:00, 22.21it/s]


Starting epoch 49/87


100%|██████████| 29/29 [00:01<00:00, 19.92it/s]


Starting epoch 50/87


100%|██████████| 29/29 [00:01<00:00, 21.42it/s]


Starting epoch 51/87


 66%|██████▌   | 19/29 [00:00<00:00, 23.51it/s]

-------------------------
| grad_norm  | 4.46e+03 |
| loss       | 2.88e+03 |
| loss_q0    | 1.74e+03 |
| loss_q1    | 1.76e+03 |
| loss_q2    | 2.03e+03 |
| loss_q3    | 5.46e+03 |
| param_norm | 39.4     |
| samples    | 1.5e+03  |
| step       | 1.5e+03  |
-------------------------
step[1500]: loss[2881.65144]
saving model...


 86%|████████▌ | 25/29 [00:01<00:00, 21.70it/s]

Skipping evaluation for now.


100%|██████████| 29/29 [00:01<00:00, 20.88it/s]


Starting epoch 52/87


100%|██████████| 29/29 [00:01<00:00, 20.62it/s]


Starting epoch 53/87


100%|██████████| 29/29 [00:01<00:00, 20.91it/s]


Starting epoch 54/87


100%|██████████| 29/29 [00:01<00:00, 21.42it/s]


Starting epoch 55/87


100%|██████████| 29/29 [00:01<00:00, 19.93it/s]


Starting epoch 56/87


100%|██████████| 29/29 [00:01<00:00, 21.84it/s]


Starting epoch 57/87


100%|██████████| 29/29 [00:01<00:00, 20.31it/s]


Starting epoch 58/87


100%|██████████| 29/29 [00:01<00:00, 19.95it/s]


Starting epoch 59/87


100%|██████████| 29/29 [00:01<00:00, 21.40it/s]


Starting epoch 60/87


 31%|███       | 9/29 [00:00<00:01, 18.91it/s]

-------------------------
| grad_norm  | 3.94e+03 |
| loss       | 2.45e+03 |
| loss_q0    | 1.57e+03 |
| loss_q1    | 1.6e+03  |
| loss_q2    | 1.84e+03 |
| loss_q3    | 5.34e+03 |
| param_norm | 39.5     |
| samples    | 1.75e+03 |
| step       | 1.75e+03 |
-------------------------
step[1750]: loss[2446.06800]
saving model...


 48%|████▊     | 14/29 [00:00<00:00, 19.42it/s]

Skipping evaluation for now.


100%|██████████| 29/29 [00:01<00:00, 19.28it/s]


Starting epoch 61/87


100%|██████████| 29/29 [00:01<00:00, 21.38it/s]


Starting epoch 62/87


100%|██████████| 29/29 [00:01<00:00, 20.01it/s]


Starting epoch 63/87


100%|██████████| 29/29 [00:01<00:00, 22.01it/s]


Starting epoch 64/87


100%|██████████| 29/29 [00:01<00:00, 20.58it/s]


Starting epoch 65/87


100%|██████████| 29/29 [00:01<00:00, 22.45it/s]


Starting epoch 66/87


100%|██████████| 29/29 [00:01<00:00, 20.57it/s]


Starting epoch 67/87


100%|██████████| 29/29 [00:01<00:00, 21.28it/s]


Starting epoch 68/87


 93%|█████████▎| 27/29 [00:01<00:00, 22.97it/s]

-------------------------
| grad_norm  | 3.8e+03  |
| loss       | 2.18e+03 |
| loss_q0    | 1.4e+03  |
| loss_q1    | 1.43e+03 |
| loss_q2    | 1.67e+03 |
| loss_q3    | 4.87e+03 |
| param_norm | 39.6     |
| samples    | 2e+03    |
| step       | 2e+03    |
-------------------------
step[2000]: loss[2182.12722]
saving model...


100%|██████████| 29/29 [00:01<00:00, 20.06it/s]


Skipping evaluation for now.
Starting epoch 69/87


100%|██████████| 29/29 [00:01<00:00, 21.10it/s]


Starting epoch 70/87


100%|██████████| 29/29 [00:01<00:00, 20.54it/s]


Starting epoch 71/87


100%|██████████| 29/29 [00:01<00:00, 19.36it/s]


Starting epoch 72/87


100%|██████████| 29/29 [00:01<00:00, 21.90it/s]


Starting epoch 73/87


100%|██████████| 29/29 [00:01<00:00, 19.92it/s]


Starting epoch 74/87


100%|██████████| 29/29 [00:01<00:00, 20.35it/s]


Starting epoch 75/87


100%|██████████| 29/29 [00:01<00:00, 21.80it/s]


Starting epoch 76/87


100%|██████████| 29/29 [00:01<00:00, 20.40it/s]


Starting epoch 77/87


 55%|█████▌    | 16/29 [00:00<00:00, 23.98it/s]

-------------------------
| grad_norm  | 4.74e+03 |
| loss       | 2.61e+03 |
| loss_q0    | 1.28e+03 |
| loss_q1    | 1.31e+03 |
| loss_q2    | 1.5e+03  |
| loss_q3    | 5.91e+03 |
| param_norm | 39.6     |
| samples    | 2.25e+03 |
| step       | 2.25e+03 |
-------------------------
step[2250]: loss[2611.15398]
saving model...


 66%|██████▌   | 19/29 [00:00<00:00, 21.48it/s]

Skipping evaluation for now.


100%|██████████| 29/29 [00:01<00:00, 21.15it/s]


Starting epoch 78/87


100%|██████████| 29/29 [00:01<00:00, 19.04it/s]


Starting epoch 79/87


100%|██████████| 29/29 [00:01<00:00, 20.73it/s]


Starting epoch 80/87


100%|██████████| 29/29 [00:01<00:00, 18.81it/s]


Starting epoch 81/87


100%|██████████| 29/29 [00:01<00:00, 19.04it/s]


Starting epoch 82/87


100%|██████████| 29/29 [00:01<00:00, 19.57it/s]


Starting epoch 83/87


100%|██████████| 29/29 [00:01<00:00, 19.23it/s]


Starting epoch 84/87


100%|██████████| 29/29 [00:01<00:00, 19.93it/s]


Starting epoch 85/87


100%|██████████| 29/29 [00:01<00:00, 21.63it/s]


Starting epoch 86/87


 14%|█▍        | 4/29 [00:00<00:01, 13.40it/s]

-------------------------
| grad_norm  | 4.18e+03 |
| loss       | 2.16e+03 |
| loss_q0    | 1.16e+03 |
| loss_q1    | 1.2e+03  |
| loss_q2    | 1.45e+03 |
| loss_q3    | 5.01e+03 |
| param_norm | 39.7     |
| samples    | 2.5e+03  |
| step       | 2.5e+03  |
-------------------------
step[2500]: loss[2160.85123]
saving model...


 34%|███▍      | 10/29 [00:00<00:01, 16.76it/s]

Skipping evaluation for now.


100%|██████████| 29/29 [00:01<00:00, 18.51it/s]

saving model...
Skipping evaluation for now.



