# 0. 数据读取

In [1]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
batch_size = 16

In [3]:
transform = transforms.Compose(
         [transforms.ToTensor(),
          transforms.Resize([64,64]),
          transforms.Normalize(
              (0.5,0.5,0.5), 
              (0.5,0.5,0.5), 
          )
         ]
) 

In [4]:
dataset = torchvision.datasets.CIFAR10(root='F:/Datasets/',
                                     #train=True,
                                     transform=transform,
                                     download=True)

# 数据加载器
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size, 
                                          shuffle=True)

Files already downloaded and verified


In [32]:
data_ier = iter(data_loader)

In [5]:
for i in data_loader:
    break

i[0].shape

torch.Size([16, 3, 64, 64])

In [6]:
save_image(i[0],"show.jpg")

# 1. 模型定义

In [7]:
from improved_diffusion.unet import UNetModel

In [8]:
def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    learn_sigma,
    class_cond,
    use_checkpoint,
    attention_resolutions,
    num_heads,
    num_heads_upsample,
    use_scale_shift_norm,
    dropout,
):
    if image_size == 256:
        channel_mult = (1, 1, 2, 2, 4, 4)
    elif image_size == 64:
        channel_mult = (1, 2, 3, 4)
    elif image_size == 32:
        channel_mult = (1, 2, 2, 2)
    else:
        raise ValueError(f"unsupported image size: {image_size}")

    attention_ds = []
    if attention_resolutions != "" and attention_resolutions != []:
        for res in attention_resolutions.split(","):
            attention_ds.append(image_size // int(res))

    return UNetModel(
        in_channels=3,
        model_channels=num_channels,
        out_channels=(3 if not learn_sigma else 6),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
    )

In [9]:
# 定义一个小的模型玩一玩
unet = create_model(
    image_size=32,
    num_channels=128,
    num_res_blocks=1,
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions="",#这里是在哪几层添加Attention，为了降低计算量，我们取消这个
    num_heads=2,# 上面已经取消了，这个没什么用
    num_heads_upsample=-1,
    use_scale_shift_norm=True, # 在ResBlock里用到的参数
    dropout=0.0,
)

In [10]:
unet

UNetModel(
  (time_embed): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
  )
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=512, out_features=256, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0.0, inplace=False)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (skip_conne

# 2. Diffusion 定义

In [11]:
from improved_diffusion import gaussian_diffusion as gd
from improved_diffusion.resample import create_named_schedule_sampler

In [12]:
def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
):
    betas = gd.get_named_beta_schedule(noise_schedule, steps)
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE

    return gd.GaussianDiffusion(
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
    )

In [13]:
diffusion = create_gaussian_diffusion(
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=True,
)

In [14]:
schedule_sampler = create_named_schedule_sampler("uniform", diffusion)

### 3.1 测试

In [15]:
diffusion.training_losses(
    model=unet,
    x_start=i[0],
    t = torch.randint(0,1000,(16,))
)

{'mse': tensor([0.9908, 0.9853, 0.9956, 0.9963, 1.0163, 0.9856, 1.0126, 1.0150, 1.0246,
         1.0407, 0.9959, 0.9857, 0.9793, 1.0265, 0.9994, 0.9940],
        grad_fn=<MeanBackward1>),
 'loss': tensor([0.9908, 0.9853, 0.9956, 0.9963, 1.0163, 0.9856, 1.0126, 1.0150, 1.0246,
         1.0407, 0.9959, 0.9857, 0.9793, 1.0265, 0.9994, 0.9940],
        grad_fn=<MeanBackward1>)}

# 4. TrainLoop 自定义版本

In [16]:
import copy
import os

import blobfile as bf
import numpy as np
import torch as th

from torch.optim import AdamW

from improved_diffusion.fp16_util import (
    make_master_params,
    master_params_to_model_params,
    model_grads_to_master_grads,
    unflatten_master_params,
    zero_grad,
)

from improved_diffusion import logger
from improved_diffusion.nn import update_ema
from improved_diffusion.resample import LossAwareSampler, UniformSampler

In [17]:
# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
INITIAL_LOG_LOSS_SCALE = 20.0

In [18]:
# 照抄不动
def parse_resume_step_from_filename(filename):
    """
    Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
    checkpoint's number of steps.
    """
    split = filename.split("model")
    if len(split) < 2:
        return 0
    split1 = split[-1].split(".")[0]
    try:
        return int(split1)
    except ValueError:
        return 0


def get_blob_logdir():
    return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir())


def find_resume_checkpoint():
    # On your infrastructure, you may want to override this to automatically
    # discover the latest checkpoint on your blob storage, etc.
    return None


def find_ema_checkpoint(main_checkpoint, step, rate):
    if main_checkpoint is None:
        return None
    filename = f"ema_{rate}_{(step):06d}.pt"
    path = bf.join(bf.dirname(main_checkpoint), filename)
    if bf.exists(path):
        return path
    return None


def log_loss_dict(diffusion, ts, losses):
    for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())
        # Log the quantiles (four quartiles, in particular).
        for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
            quartile = int(4 * sub_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)

In [19]:
# 小改一下，移除dist相关东西
def load_state_dict(path, **kwargs):
    with bf.BlobFile(path, "rb") as f:
        data = f.read()
    return th.load(io.BytesIO(data), **kwargs)

def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        return th.device("cuda")
    return th.device("cpu")

In [40]:
# 大改，删掉所有Dist相关的，并取消了条件生成
class TrainLoop:
    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    ):
        # 尽可能保持了原始参数
        self.model = model
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size
        
        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
        self.sync_cuda = th.cuda.is_available()
        
        self._load_and_sync_parameters()
        
        if self.use_fp16:
            self._setup_fp16()
        
        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
            ]
    def run_loop(self):
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond = next(self.data)
            cond = {}
            self.run_step(batch, cond)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.save_interval == 0:
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()
            
    def run_step(self, batch, cond):
        self.forward_backward(batch, cond)
        if self.use_fp16:
            self.optimize_fp16()
        else:
            self.optimize_normal()
        self.log_step()
    
    def forward_backward(self, batch, cond):
        zero_grad(self.model_params)
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dev())
            # 这里取消了条件生成
            if cond is not None and cond != {}:
                micro_cond = {
                    k: v[i : i + self.microbatch].to(dev())
                    for k, v in cond.items()
                }
            else:
                micro_cond = None
            t, weights = self.schedule_sampler.sample(micro.shape[0], dev())
            
            # 这里取消了一大堆原始代码中为了协调分布式的一堆东西，直接计算Loss
            losses = self.diffusion.training_losses(
                self.model,micro,
                t,
                model_kwargs=micro_cond
            )

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            if self.use_fp16:
                loss_scale = 2 ** self.lg_loss_scale
                (loss * loss_scale).backward()
            else:
                loss.backward()
    
    def optimize_fp16(self):
        if any(not th.isfinite(p.grad).all() for p in self.model_params):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            return

        model_grads_to_master_grads(self.model_params, self.master_params)
        self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)
        master_params_to_model_params(self.model_params, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth

    def optimize_normal(self):
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)
            
    def _log_grad_norm(self):
        sqsum = 0.0
        for p in self.master_params:
            sqsum += (p.grad ** 2).sum().item()
        logger.logkv_mean("grad_norm", np.sqrt(sqsum))

    def _anneal_lr(self):
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr
    
    def _load_and_sync_parameters(self):
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
            self.model.load_state_dict(
                load_state_dict(
                    resume_checkpoint, map_location=dev()
                )
            )
    
    def save(self):
        def save_checkpoint(rate, params):
            state_dict = self._master_params_to_state_dict(params)
            logger.log(f"saving model {rate}...")
            if not rate:
                filename = f"model{(self.step+self.resume_step):06d}.pt"
            else:
                filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
            with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                th.save(state_dict, f)

        save_checkpoint(0, self.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)


        with bf.BlobFile(
            bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
            "wb",
        ) as f:
            th.save(self.opt.state_dict(), f)

    
    def log_step(self):
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
        if self.use_fp16:
            logger.logkv("lg_loss_scale", self.lg_loss_scale)

    
    def _setup_fp16(self):
        self.master_params = make_master_params(self.model_params)
        self.model.convert_to_fp16()
    
    def _load_optimizer_state(self):
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = load_state_dict(
                opt_checkpoint, map_location=dev()
            )
            self.opt.load_state_dict(state_dict)
    
    def _load_ema_parameters(self, rate):
        ema_params = copy.deepcopy(self.master_params)

        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
            state_dict = load_state_dict(
                ema_checkpoint, map_location=dist_util.dev()
            )
            ema_params = self._state_dict_to_master_params(state_dict)
        return ema_params
    
    def _state_dict_to_master_params(self, state_dict):
        params = [state_dict[name] for name, _ in self.model.named_parameters()]
        if self.use_fp16:
            return make_master_params(params)
        else:
            return params
    
    def _master_params_to_state_dict(self, master_params):
        if self.use_fp16:
            master_params = unflatten_master_params(
                self.model.parameters(), master_params
            )
        state_dict = self.model.state_dict()
        for i, (name, _value) in enumerate(self.model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
        return state_dict

    def _state_dict_to_master_params(self, state_dict):
        params = [state_dict[name] for name, _ in self.model.named_parameters()]
        if self.use_fp16:
            return make_master_params(params)
        else:
            return params

In [43]:
unet.to(dev())

UNetModel(
  (time_embed): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
  )
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=512, out_features=256, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0.0, inplace=False)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (skip_conne

In [44]:
trainer = TrainLoop(
    model=unet,
    diffusion=diffusion,
    data=data_ier,
    batch_size=4,
    microbatch=-1,
    lr=1e-4,
    ema_rate="0.9999",
    log_interval=10,
    save_interval=5000,
    resume_checkpoint="",
    use_fp16=False,
    fp16_scale_growth=1e-3,
    schedule_sampler=schedule_sampler,
    weight_decay=0.0,
    lr_anneal_steps=0,
)

In [None]:
trainer.run_loop()

Logging to C:\Users\WANGYI~1\AppData\Local\Temp\openai-2022-09-10-17-15-07-378194
------------------------
| grad_norm | 13.1     |
| loss      | 1        |
| loss_q0   | 1.01     |
| loss_q1   | 1        |
| loss_q2   | 1.01     |
| loss_q3   | 0.991    |
| mse       | 1        |
| mse_q0    | 1.01     |
| mse_q1    | 1        |
| mse_q2    | 1.01     |
| mse_q3    | 0.991    |
| samples   | 4        |
| step      | 0        |
------------------------
saving model 0...
saving model 0.9999...
------------------------
| grad_norm | 13.2     |
| loss      | 0.904    |
| loss_q0   | 0.924    |
| loss_q1   | 0.888    |
| loss_q2   | 0.901    |
| loss_q3   | 0.903    |
| mse       | 0.904    |
| mse_q0    | 0.924    |
| mse_q1    | 0.888    |
| mse_q2    | 0.901    |
| mse_q3    | 0.903    |
| samples   | 44       |
| step      | 10       |
------------------------
------------------------
| grad_norm | 11.9     |
| loss      | 0.711    |
| loss_q0   | 0.747    |
| loss_q1   | 0.703    |
| 

------------------------
| grad_norm | 0.954    |
| loss      | 0.0182   |
| loss_q0   | 0.0571   |
| loss_q1   | 0.0149   |
| loss_q2   | 0.00608  |
| loss_q3   | 0.00483  |
| mse       | 0.0182   |
| mse_q0    | 0.0571   |
| mse_q1    | 0.0149   |
| mse_q2    | 0.00608  |
| mse_q3    | 0.00483  |
| samples   | 884      |
| step      | 220      |
------------------------
------------------------
| grad_norm | 1.2      |
| loss      | 0.0277   |
| loss_q0   | 0.0753   |
| loss_q1   | 0.0142   |
| loss_q2   | 0.00624  |
| loss_q3   | 0.00418  |
| mse       | 0.0277   |
| mse_q0    | 0.0753   |
| mse_q1    | 0.0142   |
| mse_q2    | 0.00624  |
| mse_q3    | 0.00418  |
| samples   | 924      |
| step      | 230      |
------------------------
------------------------
| grad_norm | 1.07     |
| loss      | 0.0219   |
| loss_q0   | 0.0674   |
| loss_q1   | 0.0139   |
| loss_q2   | 0.00608  |
| loss_q3   | 0.00441  |
| mse       | 0.0219   |
| mse_q0    | 0.0674   |
| mse_q1    | 0.0139   |


------------------------
| grad_norm | 0.754    |
| loss      | 0.0188   |
| loss_q0   | 0.0533   |
| loss_q1   | 0.0115   |
| loss_q2   | 0.00402  |
| loss_q3   | 0.00214  |
| mse       | 0.0188   |
| mse_q0    | 0.0533   |
| mse_q1    | 0.0115   |
| mse_q2    | 0.00402  |
| mse_q3    | 0.00214  |
| samples   | 1.76e+03 |
| step      | 440      |
------------------------
------------------------
| grad_norm | 0.808    |
| loss      | 0.0167   |
| loss_q0   | 0.0461   |
| loss_q1   | 0.0115   |
| loss_q2   | 0.00447  |
| loss_q3   | 0.00224  |
| mse       | 0.0167   |
| mse_q0    | 0.0461   |
| mse_q1    | 0.0115   |
| mse_q2    | 0.00447  |
| mse_q3    | 0.00224  |
| samples   | 1.8e+03  |
| step      | 450      |
------------------------
------------------------
| grad_norm | 0.809    |
| loss      | 0.0177   |
| loss_q0   | 0.0508   |
| loss_q1   | 0.0117   |
| loss_q2   | 0.00397  |
| loss_q3   | 0.00214  |
| mse       | 0.0177   |
| mse_q0    | 0.0508   |
| mse_q1    | 0.0117   |


------------------------
| grad_norm | 0.562    |
| loss      | 0.0142   |
| loss_q0   | 0.0457   |
| loss_q1   | 0.0113   |
| loss_q2   | 0.00337  |
| loss_q3   | 0.00159  |
| mse       | 0.0142   |
| mse_q0    | 0.0457   |
| mse_q1    | 0.0113   |
| mse_q2    | 0.00337  |
| mse_q3    | 0.00159  |
| samples   | 2.64e+03 |
| step      | 660      |
------------------------
------------------------
| grad_norm | 0.611    |
| loss      | 0.0133   |
| loss_q0   | 0.0363   |
| loss_q1   | 0.0114   |
| loss_q2   | 0.00353  |
| loss_q3   | 0.00173  |
| mse       | 0.0133   |
| mse_q0    | 0.0363   |
| mse_q1    | 0.0114   |
| mse_q2    | 0.00353  |
| mse_q3    | 0.00173  |
| samples   | 2.68e+03 |
| step      | 670      |
------------------------
------------------------
| grad_norm | 0.627    |
| loss      | 0.0125   |
| loss_q0   | 0.0417   |
| loss_q1   | 0.0106   |
| loss_q2   | 0.00353  |
| loss_q3   | 0.00163  |
| mse       | 0.0125   |
| mse_q0    | 0.0417   |
| mse_q1    | 0.0106   |


------------------------
| grad_norm | 0.421    |
| loss      | 0.0128   |
| loss_q0   | 0.0417   |
| loss_q1   | 0.0094   |
| loss_q2   | 0.00328  |
| loss_q3   | 0.00139  |
| mse       | 0.0128   |
| mse_q0    | 0.0417   |
| mse_q1    | 0.0094   |
| mse_q2    | 0.00328  |
| mse_q3    | 0.00139  |
| samples   | 3.52e+03 |
| step      | 880      |
------------------------
------------------------
| grad_norm | 0.412    |
| loss      | 0.0133   |
| loss_q0   | 0.0414   |
| loss_q1   | 0.0103   |
| loss_q2   | 0.00331  |
| loss_q3   | 0.00139  |
| mse       | 0.0133   |
| mse_q0    | 0.0414   |
| mse_q1    | 0.0103   |
| mse_q2    | 0.00331  |
| mse_q3    | 0.00139  |
| samples   | 3.56e+03 |
| step      | 890      |
------------------------
------------------------
| grad_norm | 0.469    |
| loss      | 0.013    |
| loss_q0   | 0.0345   |
| loss_q1   | 0.0107   |
| loss_q2   | 0.00277  |
| loss_q3   | 0.0013   |
| mse       | 0.013    |
| mse_q0    | 0.0345   |
| mse_q1    | 0.0107   |


------------------------
| grad_norm | 0.636    |
| loss      | 0.0165   |
| loss_q0   | 0.0484   |
| loss_q1   | 0.0107   |
| loss_q2   | 0.00318  |
| loss_q3   | 0.0014   |
| mse       | 0.0165   |
| mse_q0    | 0.0484   |
| mse_q1    | 0.0107   |
| mse_q2    | 0.00318  |
| mse_q3    | 0.0014   |
| samples   | 4.4e+03  |
| step      | 1.1e+03  |
------------------------
------------------------
| grad_norm | 0.703    |
| loss      | 0.0219   |
| loss_q0   | 0.0606   |
| loss_q1   | 0.0101   |
| loss_q2   | 0.00277  |
| loss_q3   | 0.00128  |
| mse       | 0.0219   |
| mse_q0    | 0.0606   |
| mse_q1    | 0.0101   |
| mse_q2    | 0.00277  |
| mse_q3    | 0.00128  |
| samples   | 4.44e+03 |
| step      | 1.11e+03 |
------------------------
------------------------
| grad_norm | 0.396    |
| loss      | 0.0139   |
| loss_q0   | 0.0394   |
| loss_q1   | 0.0095   |
| loss_q2   | 0.00305  |
| loss_q3   | 0.00127  |
| mse       | 0.0139   |
| mse_q0    | 0.0394   |
| mse_q1    | 0.0095   |


------------------------
| grad_norm | 0.562    |
| loss      | 0.0169   |
| loss_q0   | 0.048    |
| loss_q1   | 0.01     |
| loss_q2   | 0.00318  |
| loss_q3   | 0.00116  |
| mse       | 0.0169   |
| mse_q0    | 0.048    |
| mse_q1    | 0.01     |
| mse_q2    | 0.00318  |
| mse_q3    | 0.00116  |
| samples   | 5.28e+03 |
| step      | 1.32e+03 |
------------------------
------------------------
| grad_norm | 0.592    |
| loss      | 0.0129   |
| loss_q0   | 0.0359   |
| loss_q1   | 0.00971  |
| loss_q2   | 0.00272  |
| loss_q3   | 0.00123  |
| mse       | 0.0129   |
| mse_q0    | 0.0359   |
| mse_q1    | 0.00971  |
| mse_q2    | 0.00272  |
| mse_q3    | 0.00123  |
| samples   | 5.32e+03 |
| step      | 1.33e+03 |
------------------------
------------------------
| grad_norm | 0.499    |
| loss      | 0.0171   |
| loss_q0   | 0.0519   |
| loss_q1   | 0.0106   |
| loss_q2   | 0.00282  |
| loss_q3   | 0.00117  |
| mse       | 0.0171   |
| mse_q0    | 0.0519   |
| mse_q1    | 0.0106   |


------------------------
| grad_norm | 0.588    |
| loss      | 0.0221   |
| loss_q0   | 0.0683   |
| loss_q1   | 0.0097   |
| loss_q2   | 0.00284  |
| loss_q3   | 0.00102  |
| mse       | 0.0221   |
| mse_q0    | 0.0683   |
| mse_q1    | 0.0097   |
| mse_q2    | 0.00284  |
| mse_q3    | 0.00102  |
| samples   | 6.16e+03 |
| step      | 1.54e+03 |
------------------------
------------------------
| grad_norm | 0.505    |
| loss      | 0.0142   |
| loss_q0   | 0.0428   |
| loss_q1   | 0.00969  |
| loss_q2   | 0.00258  |
| loss_q3   | 0.00118  |
| mse       | 0.0142   |
| mse_q0    | 0.0428   |
| mse_q1    | 0.00969  |
| mse_q2    | 0.00258  |
| mse_q3    | 0.00118  |
| samples   | 6.2e+03  |
| step      | 1.55e+03 |
------------------------
------------------------
| grad_norm | 0.486    |
| loss      | 0.0166   |
| loss_q0   | 0.04     |
| loss_q1   | 0.00964  |
| loss_q2   | 0.00328  |
| loss_q3   | 0.00121  |
| mse       | 0.0166   |
| mse_q0    | 0.04     |
| mse_q1    | 0.00964  |


------------------------
| grad_norm | 0.442    |
| loss      | 0.0129   |
| loss_q0   | 0.0355   |
| loss_q1   | 0.00865  |
| loss_q2   | 0.00277  |
| loss_q3   | 0.00091  |
| mse       | 0.0129   |
| mse_q0    | 0.0355   |
| mse_q1    | 0.00865  |
| mse_q2    | 0.00277  |
| mse_q3    | 0.00091  |
| samples   | 7.04e+03 |
| step      | 1.76e+03 |
------------------------
------------------------
| grad_norm | 0.456    |
| loss      | 0.015    |
| loss_q0   | 0.0488   |
| loss_q1   | 0.00928  |
| loss_q2   | 0.00278  |
| loss_q3   | 0.00106  |
| mse       | 0.015    |
| mse_q0    | 0.0488   |
| mse_q1    | 0.00928  |
| mse_q2    | 0.00278  |
| mse_q3    | 0.00106  |
| samples   | 7.08e+03 |
| step      | 1.77e+03 |
------------------------
------------------------
| grad_norm | 0.388    |
| loss      | 0.0134   |
| loss_q0   | 0.0384   |
| loss_q1   | 0.00979  |
| loss_q2   | 0.00263  |
| loss_q3   | 0.000953 |
| mse       | 0.0134   |
| mse_q0    | 0.0384   |
| mse_q1    | 0.00979  |


------------------------
| grad_norm | 0.451    |
| loss      | 0.0116   |
| loss_q0   | 0.0292   |
| loss_q1   | 0.00916  |
| loss_q2   | 0.0025   |
| loss_q3   | 0.000998 |
| mse       | 0.0116   |
| mse_q0    | 0.0292   |
| mse_q1    | 0.00916  |
| mse_q2    | 0.0025   |
| mse_q3    | 0.000998 |
| samples   | 7.92e+03 |
| step      | 1.98e+03 |
------------------------
------------------------
| grad_norm | 0.331    |
| loss      | 0.0103   |
| loss_q0   | 0.033    |
| loss_q1   | 0.00891  |
| loss_q2   | 0.00231  |
| loss_q3   | 0.000823 |
| mse       | 0.0103   |
| mse_q0    | 0.033    |
| mse_q1    | 0.00891  |
| mse_q2    | 0.00231  |
| mse_q3    | 0.000823 |
| samples   | 7.96e+03 |
| step      | 1.99e+03 |
------------------------
------------------------
| grad_norm | 0.393    |
| loss      | 0.0108   |
| loss_q0   | 0.0308   |
| loss_q1   | 0.0106   |
| loss_q2   | 0.00236  |
| loss_q3   | 0.000867 |
| mse       | 0.0108   |
| mse_q0    | 0.0308   |
| mse_q1    | 0.0106   |


------------------------
| grad_norm | 0.318    |
| loss      | 0.0127   |
| loss_q0   | 0.0462   |
| loss_q1   | 0.00937  |
| loss_q2   | 0.00247  |
| loss_q3   | 0.000862 |
| mse       | 0.0127   |
| mse_q0    | 0.0462   |
| mse_q1    | 0.00937  |
| mse_q2    | 0.00247  |
| mse_q3    | 0.000862 |
| samples   | 8.8e+03  |
| step      | 2.2e+03  |
------------------------
------------------------
| grad_norm | 0.261    |
| loss      | 0.0137   |
| loss_q0   | 0.0375   |
| loss_q1   | 0.00922  |
| loss_q2   | 0.00232  |
| loss_q3   | 0.000796 |
| mse       | 0.0137   |
| mse_q0    | 0.0375   |
| mse_q1    | 0.00922  |
| mse_q2    | 0.00232  |
| mse_q3    | 0.000796 |
| samples   | 8.84e+03 |
| step      | 2.21e+03 |
------------------------
------------------------
| grad_norm | 0.359    |
| loss      | 0.0117   |
| loss_q0   | 0.0385   |
| loss_q1   | 0.00783  |
| loss_q2   | 0.00273  |
| loss_q3   | 0.000848 |
| mse       | 0.0117   |
| mse_q0    | 0.0385   |
| mse_q1    | 0.00783  |


------------------------
| grad_norm | 0.356    |
| loss      | 0.0143   |
| loss_q0   | 0.0425   |
| loss_q1   | 0.00775  |
| loss_q2   | 0.00254  |
| loss_q3   | 0.000819 |
| mse       | 0.0143   |
| mse_q0    | 0.0425   |
| mse_q1    | 0.00775  |
| mse_q2    | 0.00254  |
| mse_q3    | 0.000819 |
| samples   | 9.68e+03 |
| step      | 2.42e+03 |
------------------------
------------------------
| grad_norm | 0.372    |
| loss      | 0.0123   |
| loss_q0   | 0.0297   |
| loss_q1   | 0.00943  |
| loss_q2   | 0.00289  |
| loss_q3   | 0.000757 |
| mse       | 0.0123   |
| mse_q0    | 0.0297   |
| mse_q1    | 0.00943  |
| mse_q2    | 0.00289  |
| mse_q3    | 0.000757 |
| samples   | 9.72e+03 |
| step      | 2.43e+03 |
------------------------
------------------------
| grad_norm | 0.259    |
| loss      | 0.0108   |
| loss_q0   | 0.0323   |
| loss_q1   | 0.00914  |
| loss_q2   | 0.00256  |
| loss_q3   | 0.000812 |
| mse       | 0.0108   |
| mse_q0    | 0.0323   |
| mse_q1    | 0.00914  |


------------------------
| grad_norm | 0.246    |
| loss      | 0.0104   |
| loss_q0   | 0.0381   |
| loss_q1   | 0.00832  |
| loss_q2   | 0.00211  |
| loss_q3   | 0.000741 |
| mse       | 0.0104   |
| mse_q0    | 0.0381   |
| mse_q1    | 0.00832  |
| mse_q2    | 0.00211  |
| mse_q3    | 0.000741 |
| samples   | 1.06e+04 |
| step      | 2.64e+03 |
------------------------
------------------------
| grad_norm | 0.343    |
| loss      | 0.01     |
| loss_q0   | 0.0303   |
| loss_q1   | 0.00978  |
| loss_q2   | 0.00248  |
| loss_q3   | 0.000735 |
| mse       | 0.01     |
| mse_q0    | 0.0303   |
| mse_q1    | 0.00978  |
| mse_q2    | 0.00248  |
| mse_q3    | 0.000735 |
| samples   | 1.06e+04 |
| step      | 2.65e+03 |
------------------------
------------------------
| grad_norm | 0.304    |
| loss      | 0.0112   |
| loss_q0   | 0.0344   |
| loss_q1   | 0.00839  |
| loss_q2   | 0.00243  |
| loss_q3   | 0.000746 |
| mse       | 0.0112   |
| mse_q0    | 0.0344   |
| mse_q1    | 0.00839  |


------------------------
| grad_norm | 0.351    |
| loss      | 0.0111   |
| loss_q0   | 0.0364   |
| loss_q1   | 0.00933  |
| loss_q2   | 0.00232  |
| loss_q3   | 0.000752 |
| mse       | 0.0111   |
| mse_q0    | 0.0364   |
| mse_q1    | 0.00933  |
| mse_q2    | 0.00232  |
| mse_q3    | 0.000752 |
| samples   | 1.14e+04 |
| step      | 2.86e+03 |
------------------------
------------------------
| grad_norm | 0.415    |
| loss      | 0.0151   |
| loss_q0   | 0.0468   |
| loss_q1   | 0.0103   |
| loss_q2   | 0.00246  |
| loss_q3   | 0.000721 |
| mse       | 0.0151   |
| mse_q0    | 0.0468   |
| mse_q1    | 0.0103   |
| mse_q2    | 0.00246  |
| mse_q3    | 0.000721 |
| samples   | 1.15e+04 |
| step      | 2.87e+03 |
------------------------
------------------------
| grad_norm | 0.411    |
| loss      | 0.0122   |
| loss_q0   | 0.0405   |
| loss_q1   | 0.00945  |
| loss_q2   | 0.0026   |
| loss_q3   | 0.000876 |
| mse       | 0.0122   |
| mse_q0    | 0.0405   |
| mse_q1    | 0.00945  |
