<a href="https://colab.research.google.com/github/parseny/Generative-Models-2024/blob/main/Assignment%203/3_ddpm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Generative Models
***

**Autumn 2024**

## Assignment 3


In this assignment we implement DDPM - Denoising Diffusion Probabilistic Models (2020)

In simple terms, we get an image from data and add noise step by step. Then We train a model to predict that noise at each step and use the model to generate images.

The following definitions and derivations show how this works. For details please refer to the paper https://arxiv.org/abs/2006.11239

## Forward Process

The forward process adds noise to the data $x_0 \sim q(x_0)$, for $T$ timesteps.

\begin{align}
q(x_t | x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{1-  \beta_t} x_{t-1}, \beta_t \mathbf{I}\big) \\
q(x_{1:T} | x_0) = \prod_{t = 1}^{T} q(x_t | x_{t-1})
\end{align}

where $\beta_1, \dots, \beta_T$ is the variance schedule.

We can sample $x_t$ at any timestep $t$ with,

\begin{align}
q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
\end{align}

where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^t \alpha_s$

## Reverse Process

The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
for $T$ time steps.

\begin{align}
p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
 \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)\big) \\
p_\theta(x_{0:T}) &= p_\theta(x_T) \prod_{t = 1}^{T} p_\theta(x_{t-1} | x_t) \\
p_\theta(x_0) &= \int p_\theta(x_{0:T}) dx_{1:T}
\end{align}

$\theta$ are the parameters we train.

## Loss

We optimize the ELBO (from Jenson's inequality) on the negative log likelihood.

\begin{align}
\mathbb{E}[-\log p_\theta(x_0)]
 &\le \mathbb{E}_q [ -\log \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)} ] \\
 &=L
\end{align}

The loss can be rewritten as  follows.

\begin{align}
L
 &= \mathbb{E}_q [ -\log \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)} ] \\
 &= \mathbb{E}_q [ -\log p(x_T) - \sum_{t=1}^T \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})} ] \\
 &= \mathbb{E}_q [
  -\log \frac{p(x_T)}{q(x_T|x_0)}
  -\sum_{t=2}^T \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)}
  -\log p_\theta(x_0|x_1)] \\
 &= \mathbb{E}_q [
   D_{KL}(q(x_T|x_0) \Vert p(x_T))
  +\sum_{t=2}^T D_{KL}(q(x_{t-1}|x_t,x_0) \Vert p_\theta(x_{t-1}|x_t))
  -\log p_\theta(x_0|x_1)]
\end{align}

$D_{KL}(q(x_T|x_0) \Vert p(x_T))$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.

### Computing $L_{t-1} = D_{KL}(q(x_{t-1}|x_t,x_0) \Vert p_\theta(x_{t-1}|x_t))$

The forward process posterior conditioned by $x_0$ is,

\begin{align}
q(x_{t-1}|x_t, x_0) &= \mathcal{N} \Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I} \Big) \\
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
                         + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t
\end{align}

The paper sets $\Sigma_\theta(x_t, t) = \sigma_t^2 \mathbf{I}$ where $\sigma_t^2$ is set to constants
$\beta_t$ or $\tilde\beta_t$.

Then,
$$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 \mathbf{I} \big)$$

For given noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ using $q(x_t|x_0)$

\begin{align}
x_t(x_0, \epsilon) &= \sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon \\
x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t(x_0, \epsilon) -  \sqrt{1-\bar\alpha_t}\epsilon\Big)
\end{align}

This gives,

\begin{align}
L_{t-1}
 &= D_{KL}(q(x_{t-1}|x_t,x_0) \Vert p_\theta(x_{t-1}|x_t)) \\
 &= \mathbb{E}_q \Bigg[ \frac{1}{2\sigma_t^2}
 \Big \Vert \tilde\mu(x_t, x_0) - \mu_\theta(x_t, t) \Big \Vert^2 \Bigg] \\
 &= \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{1}{2\sigma_t^2}
  \bigg\Vert \frac{1}{\sqrt{\alpha_t}} \Big(
  x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \epsilon
  \Big) - \mu_\theta(x_t(x_0, \epsilon), t) \bigg\Vert^2 \Bigg] \\
\end{align}

Re-parameterizing with a model to predict noise

\begin{align}
\mu_\theta(x_t, t) &= \tilde\mu \bigg(x_t,
  \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t -
   \sqrt{1-\bar\alpha_t}\epsilon_\theta(x_t, t) \Big) \bigg) \\
  &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
  \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(x_t, t) \Big)
\end{align}

where $\epsilon_\theta$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.

This gives,

\begin{align}
L_{t-1}
&= \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}
  \Big\Vert
  \epsilon - \epsilon_\theta(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
  \Big\Vert^2 \Bigg]
\end{align}

That is, we are training to predict the noise.

### Simplified loss

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
\epsilon - \epsilon_\theta(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
\bigg\Vert^2 \Bigg]$$

This minimizes $-\log p_\theta(x_0|x_1)$ when $t=1$ and $L_{t-1}$ for $t\gt1$ discarding the
weighting in $L_{t-1}$. Discarding the weights $\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}$
increase the weight given to higher $t$ (which have higher noise levels), therefore increasing the sample quality.

This file implements the loss calculation and a basic sampling method that we use to generate images during
training.

$\epsilon_\theta(x_t, t)$ is a UNet model.

In [7]:
from typing import List, Tuple, Optional

import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn
import torchvision
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

In [8]:
# Load file unet.py to the filesystem of colab
from unet import UNet

In [9]:
def gather(consts: torch.Tensor, t: torch.Tensor):
    """Gather consts for t and reshape to feature map shape"""
    c = consts.gather(-1, t)
    return c.reshape(-1, 1, 1, 1)

### Explanations to the code
[1] $\bar\alpha_t = \prod_{s=1}^t \alpha_s$

[2] \begin{align}
        q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
        \end{align}

[3] gather $\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$

[4] $(1-\bar\alpha_t) \mathbf{I}$

[5] \begin{align}
        q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
        \end{align}

[6] $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

[7] \begin{align}
        p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
        \mu_\theta(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
        \mu_\theta(x_t, t)
          &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
            \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(x_t, t) \Big)
        \end{align}

[8] gather $\bar\alpha_t$

[9] $\frac{\beta}{\sqrt{1-\bar\alpha_t}}$

[10] $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
              \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(x_t, t) \Big)$$

[11] $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

[12] $$L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
        \epsilon - \epsilon_\theta(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
        \bigg\Vert^2 \Bigg]$$

[13] $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

[14] Get $\epsilon_\theta(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)$

In [10]:
beta = torch.linspace(0.0001, 0.02, 10)
alpha = 1 - beta
alpha_bar = torch.cumprod(alpha, dim=0)  # Используем torch.cumprod для кумулятивного произведения
print("Alpha_bar (cumulative product of alpha):", alpha_bar)


gather(alpha_bar, torch.tensor([1,2,3]))

Alpha_bar (cumulative product of alpha): tensor([0.9999, 0.9976, 0.9931, 0.9864, 0.9776, 0.9667, 0.9537, 0.9389, 0.9222,
        0.9037])


tensor([[[[0.9976]]],


        [[[0.9931]]],


        [[[0.9864]]]])

In [14]:
class DenoiseDiffusion:
    """
    ## Denoise Diffusion
    """

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        """
        * eps_model - epsilon_theta(x_t, t) model
        * n_steps - t
        * device - the device to place constants on
        """
        super().__init__()
        self.eps_model = eps_model

        # Create beta_1 ... beta_T linearly increasing variance schedule
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)

        # alpha_t = 1 - beta_t
        self.alpha = 1 - self.beta
        # [1]
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        # T
        self.n_steps = n_steps
        # sigma^2 = beta
        self.sigma2 = self.beta

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get q(x_t|x_0) distribution

        [2]
        """
        # [3]
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
        # [4]
        var = 1 - gather(self.alpha_bar, t)
        return mean, var

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        """
        Sample from q(x_t|x_0)

        [5]
        """

        # [6]
        if eps is None:
            eps = torch.randn_like(x0)

        # get q(x_t|x_0)
        mean, var = self.q_xt_x0(x0, t)
        # Sample from q(x_t|x_0)
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        """
        Sample from p_theta(x_{t-1}|x_t)

        [7]
        """

        # epsilon_theta(x_t, t)
        eps_theta = self.eps_model(xt, t)
        print(f"epsilon_theta(x_t, t)\n{eps_theta}")
        # [8]
        alpha_bar = gather(self.alpha_bar, t)
        print(f"alpha_bar\n{alpha_bar}")
        # alpha_t
        alpha = gather(self.alpha, t)
        print(f"alpha\n{alpha}")
        # [9]
        beta = gather(self.beta, t)
        eps_coef = beta / (1 - alpha_bar) ** 0.5
        print(f"eps_coef\n{eps_coef}")
        # [10]
        mean = (1 / alpha_bar ** 0.5) * (xt - eps_coef * eps_theta)
        print(f"mean\n{mean}")

        # sigma^2
        var = gather(self.sigma2, t)

        # [11]
        eps = eps = torch.randn_like(xt)
        print(f"eps\n{eps}")

        # Sample
        return mean + (var ** .5) * eps

    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
        """
        Simplified Loss

        [12]
        """
        # Get batch size
        batch_size = x0.size(0)
        print(f"batch_size {batch_size}")
        # Get random t for each sample in the batch
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)

        # [13]
        if noise is None:
            noise = torch.randn_like(x0)

        # Sample x_t for q(x_t|x_0)
        xt = self.q_sample(x0, t, noise)
        print(f"xt\n{xt}")
        # [14]
        eps_theta = self.eps_model(xt, t)
        print(f"eps_theta\n{eps_theta}")
        # MSE loss
        return torch.mean((noise - eps_theta) ** 2)

### Explanations to the code
[1] $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$

[2] Sample from $p_\theta(x_{t-1}|x_t)$



In [15]:
class MNISTDataset(torchvision.datasets.MNIST):
    def __init__(self, image_size):
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        super().__init__("data", train=True, download=True, transform=transform)

    def __getitem__(self, item):
        return super().__getitem__(item)[0]

def plot_samples(tensor):
    # Assuming you have a tensor of size torch.Size([16, 1, 32, 32])
    # Convert the tensor to a numpy array
    images = tensor.numpy()

    # Reshape the images to be of size (16, 32, 32)
    images = np.reshape(images, (16, 32, 32))

    # Create a figure with a grid of subplots
    fig, axes = plt.subplots(nrows=4, ncols=4)

    # Iterate over the images and plot them on the subplots
    for i, ax in enumerate(axes.flatten()):
        ax.imshow(images[i], cmap='gray')
        ax.axis('off')

    # Show the plot
    plt.show()


class Configs:
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # U-Net model for epsilon_theta(x_t, t)
    eps_model: UNet
    # DDPM algorithm
    diffusion: DenoiseDiffusion

    # Number of channels in the image. 3 for RGB.
    image_channels: int = 1
    # Image size
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel numbers at each resolution.
    # The number of channels is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[int] = [False, False, False, True]

    # Number of time steps T
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 5

    # Dataset
    dataset: torch.utils.data.Dataset = MNISTDataset(image_size)
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        # Create epsilon_theta(x_t, t) model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create DDPM class
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

    def sample(self):
        with torch.no_grad():
            # [1]
            x = torch.randn(self.n_samples, self.image_channels, self.image_size, self.image_size, device=self.device)
            print(x)
            # Remove noise for T steps
            progress_bar = tqdm(range(self.n_steps))
            for t_ in progress_bar:
                progress_bar.set_description(f"Sampling")
                # t
                t = self.n_steps - t_ - 1
                # [2]
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
            print(x)
            print("-"*50)

            # Log samples
            plot_samples(x.detach().cpu())

    def train(self, epoch):
        # Iterate through the dataset
        progress_bar = tqdm(self.data_loader)
        for data in progress_bar:
            # Increment global step
            progress_bar.set_description(f"Epoch {epoch + 1}")
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()

            # Calculate loss
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()

            # Take an optimization step
            self.optimizer.step()

            # Track the loss
            progress_bar.set_postfix(loss=loss.detach().cpu().numpy())

    def run(self):
        for epoch in range(self.epochs):
            # Train the model
            self.train(epoch)
            # Sample some images
            self.sample()


# Create configurations
configs = Configs()

# Initialize
configs.init()

# Start and run the training loop
configs.run()

Epoch 1:   0%|          | 0/938 [00:00<?, ?it/s]

batch_size 64
xt
tensor([[[[-2.2426e-01,  1.4322e+00, -1.5438e-01,  ...,  1.7611e+00,
           -5.2675e-01, -1.3608e+00],
          [-9.5763e-02,  2.7397e-01, -3.5343e-01,  ..., -5.1428e-01,
            4.3963e-02,  1.0126e+00],
          [ 5.7745e-01, -5.2418e-01, -9.4116e-01,  ...,  8.8333e-01,
            1.2845e+00,  6.4045e-01],
          ...,
          [-1.0916e+00,  3.6844e-01, -1.2685e+00,  ...,  9.1202e-01,
           -3.7307e-01, -1.0829e+00],
          [ 3.5043e-01,  1.6573e+00,  1.0553e+00,  ..., -5.6981e-01,
           -5.7773e-01,  6.5816e-02],
          [-4.4103e-01, -5.7662e-01, -1.0651e-02,  ...,  9.9537e-01,
           -1.3923e+00, -1.7097e+00]]],


        [[[ 4.7867e-02,  6.9165e-01, -6.0698e-01,  ...,  3.3133e-01,
            8.1156e-01,  7.8016e-01],
          [-1.4943e-01,  8.1341e-02, -1.8356e-01,  ..., -3.1091e-01,
            2.0309e-01, -2.2664e-03],
          [-3.4213e-01,  3.1049e-01, -8.0175e-01,  ..., -1.2250e+00,
            1.0830e-01,  1.3493e-02],
 

Epoch 1:   0%|          | 1/938 [00:00<08:09,  1.91it/s, loss=1.0394685]

batch_size 64
xt
tensor([[[[ 3.6344e-01,  2.5900e-01, -1.6759e+00,  ..., -4.8209e-01,
            6.7624e-01,  1.0985e+00],
          [ 3.1010e-01,  8.7285e-01,  1.2693e+00,  ...,  6.8192e-02,
            7.5356e-01,  8.3679e-01],
          [ 2.6263e-01, -5.9613e-01, -2.6634e-01,  ...,  1.2717e+00,
            6.5781e-01, -8.4107e-01],
          ...,
          [-2.7947e-01,  6.4195e-01,  6.8330e-01,  ..., -1.5055e+00,
            6.3110e-01, -7.3850e-01],
          [-1.4872e-01, -4.7804e-01, -2.3663e-01,  ..., -5.9556e-01,
           -6.9530e-01,  6.8199e-01],
          [-1.2049e-01,  6.5579e-01, -5.3643e-01,  ...,  1.0683e+00,
            2.1613e+00, -4.3853e-01]]],


        [[[-8.0363e-01,  3.7641e-01, -4.0223e-01,  ...,  2.0537e-01,
           -1.0750e-01,  3.3387e-01],
          [-3.5284e-01, -3.0426e-01, -4.5934e-01,  ...,  1.3248e-01,
            1.0814e-01,  4.9610e-01],
          [-1.5360e-01,  2.0262e-01,  7.6547e-03,  ...,  1.4681e-02,
            4.7041e-01,  1.1862e-01],
 

Epoch 1:   0%|          | 2/938 [00:00<07:18,  2.14it/s, loss=0.9842684]

batch_size 64
xt
tensor([[[[ 2.6940,  0.2090, -2.0314,  ..., -1.0018,  1.2537,  0.4521],
          [ 0.3682, -0.4824, -2.6811,  ...,  0.8426,  0.0367,  0.3118],
          [ 1.2252,  1.0343, -0.8349,  ...,  1.8402, -0.4314, -1.3110],
          ...,
          [-1.8916,  0.5493, -1.8286,  ...,  0.1842,  1.2633,  0.0923],
          [-1.8855,  0.6036, -1.2936,  ...,  0.8970, -0.0740, -1.4990],
          [-1.4704, -0.1505,  0.7119,  ...,  0.6861,  1.0115, -0.9312]]],


        [[[ 0.4152, -0.2863, -0.6052,  ..., -0.5210, -0.7002,  0.3642],
          [-0.4608, -0.1600,  0.8092,  ...,  0.3399, -0.7800,  0.8625],
          [ 0.0801,  0.0045,  0.1500,  ..., -0.5677, -0.5785,  0.3109],
          ...,
          [ 0.2863,  0.2728,  0.5753,  ...,  0.3835, -0.7812, -1.2966],
          [ 0.2338,  0.6538,  1.0696,  ...,  0.3935, -0.4229, -0.5172],
          [ 0.1373, -0.7222,  0.8668,  ...,  0.5065, -0.7946, -0.3300]]],


        [[[-2.0169,  0.3894, -0.3609,  ...,  0.5776, -2.0707,  1.2284],
         

Epoch 1:   0%|          | 3/938 [00:01<07:02,  2.21it/s, loss=0.9295436]

batch_size 64
xt
tensor([[[[ 0.8621, -0.7455, -1.2855,  ...,  0.2353,  0.4333,  0.2583],
          [-0.8402, -0.4870, -0.2057,  ..., -1.1178,  2.7666,  0.7999],
          [ 0.3000, -0.3230, -1.0932,  ..., -0.8508,  0.3652,  0.1955],
          ...,
          [-0.5112,  1.0945, -0.5406,  ...,  0.5862, -0.6479, -0.8261],
          [-0.7046,  1.4025,  1.6366,  ...,  2.5048,  1.3969, -0.0763],
          [-0.5085, -0.0245, -0.7795,  ..., -1.5083,  0.6285,  0.8989]]],


        [[[-0.1712,  0.1083,  0.8297,  ...,  0.0889,  0.2576,  0.0938],
          [-0.1207, -0.5720,  0.0252,  ...,  0.0842, -0.0786,  0.4568],
          [ 0.0502, -0.2771, -0.2586,  ..., -0.0735, -0.0300, -0.2690],
          ...,
          [ 0.2455, -0.0405, -0.0549,  ...,  0.1833,  0.5753,  0.4399],
          [-0.4197, -0.2786, -0.0681,  ..., -0.2370,  0.0655,  0.5383],
          [ 0.0170, -0.0097, -0.2837,  ...,  0.4375, -0.1999, -0.1579]]],


        [[[ 0.0226, -0.1339, -0.1443,  ...,  0.0887, -0.0048, -0.2080],
         

Epoch 1:   0%|          | 4/938 [00:01<06:55,  2.25it/s, loss=0.91172576]

batch_size 64
xt
tensor([[[[-4.6338e-01, -8.9962e-01, -8.2244e-01,  ..., -7.0805e-01,
           -1.8824e-01,  5.7618e-01],
          [ 4.8284e-01,  2.0008e+00, -2.4519e-02,  ...,  1.0695e+00,
           -1.0179e+00, -6.4348e-01],
          [-8.4323e-01, -7.9959e-01, -7.0586e-01,  ..., -3.6586e-01,
            1.6256e+00, -6.4298e-01],
          ...,
          [ 2.1528e+00, -1.0468e+00,  1.7610e+00,  ..., -1.0872e+00,
           -1.7947e-01, -7.7815e-01],
          [-4.5685e-01,  6.6454e-01,  1.8624e+00,  ...,  1.2488e+00,
            1.2502e-01,  1.8380e-01],
          [-7.8788e-01, -4.5078e-01,  2.8029e-01,  ..., -1.2933e-01,
           -3.8768e-02, -5.1793e-01]]],


        [[[-2.9747e-01,  9.2171e-03, -7.8759e-01,  ...,  1.1810e-02,
            1.4823e+00, -3.4865e-01],
          [ 7.7174e-02,  1.1581e+00,  1.0491e-01,  ...,  1.8008e-01,
            7.2130e-01, -4.2778e-01],
          [-8.5390e-01, -1.7621e-01, -3.9859e-01,  ..., -8.9568e-01,
            1.2956e+00, -8.5277e-01],
 

Epoch 1:   1%|          | 5/938 [00:02<06:50,  2.27it/s, loss=0.84696823]

batch_size 64
xt
tensor([[[[-0.7736,  0.5204, -0.0990,  ..., -0.8854, -0.5168,  0.4297],
          [ 0.7379,  0.7155,  0.4901,  ..., -1.1718, -0.8560,  0.2602],
          [-0.8039, -0.2680,  0.7904,  ..., -0.1825,  0.5211, -0.3084],
          ...,
          [ 0.0151, -0.1241,  0.8548,  ...,  0.2232,  0.2901, -0.3352],
          [ 0.9046, -0.9106, -0.5263,  ..., -0.2787,  1.0451, -0.6503],
          [ 0.3165,  0.9082, -0.3544,  ...,  0.1232,  0.8430,  0.4468]]],


        [[[-0.2553,  0.7354, -0.0852,  ...,  1.3760, -0.0276, -0.9568],
          [ 0.1553,  0.0124, -0.9565,  ...,  1.0701,  0.9970,  0.2960],
          [-0.3140,  0.0372,  0.1011,  ...,  1.1655, -0.0109, -1.5414],
          ...,
          [-1.0913, -1.0783,  1.3387,  ..., -1.5100,  0.9218,  0.9771],
          [-0.5789, -0.1501,  0.8834,  ..., -0.4991, -0.8102, -1.0408],
          [-0.0775,  0.3284,  1.2078,  ..., -0.0088,  0.5522, -0.7352]]],


        [[[ 0.0230,  1.7369, -0.0579,  ...,  0.1145, -0.0212, -0.0840],
         

Epoch 1:   1%|          | 6/938 [00:02<06:47,  2.29it/s, loss=0.817495]

batch_size 64
xt
tensor([[[[-1.7108e+00,  6.8022e-01, -1.3559e+00,  ...,  6.0584e-01,
            7.7435e-01,  1.9239e+00],
          [ 8.4985e-01,  4.0696e-01,  5.0140e-01,  ...,  2.7567e-01,
           -1.2256e+00, -1.2550e+00],
          [ 8.1479e-01, -3.2772e+00,  7.6833e-01,  ..., -5.3927e-01,
            1.1350e+00, -1.0203e+00],
          ...,
          [-3.4271e-01,  3.7840e-01, -3.8941e-01,  ..., -4.8492e-01,
            3.5774e-02,  2.0319e-01],
          [-2.7569e-02, -2.8736e-01,  6.8498e-01,  ..., -7.9063e-01,
           -5.1613e-01, -1.0423e+00],
          [ 1.3258e+00,  2.5149e-01,  1.6700e-02,  ..., -3.2135e-01,
            2.3679e-01,  1.8058e+00]]],


        [[[-9.7616e-01, -2.1087e-02,  3.7466e-01,  ..., -9.4839e-03,
           -7.0518e-01, -4.8590e-01],
          [-2.4544e+00, -1.9689e+00,  1.1783e+00,  ...,  1.0650e+00,
            1.7498e-01,  2.4346e-01],
          [-4.5060e-01, -1.0141e+00, -1.7449e+00,  ..., -1.5300e-01,
            3.0043e-01,  8.3511e-02],
 

Epoch 1:   1%|          | 7/938 [00:03<06:47,  2.29it/s, loss=0.77358234]

batch_size 64
xt
tensor([[[[-0.5220, -0.1766, -0.2914,  ...,  0.4645, -0.3422,  0.4598],
          [-0.9290,  0.0456,  0.4609,  ...,  0.0826,  0.4936,  0.3437],
          [ 0.4812,  0.0691,  0.3336,  ...,  0.5363,  0.5120, -0.2059],
          ...,
          [ 0.0728, -0.0664,  0.1183,  ..., -0.0824,  0.0426,  0.3988],
          [-0.0412,  0.2536,  0.1552,  ..., -0.0058,  0.6041, -0.1021],
          [ 0.1205, -0.1565, -0.3164,  ...,  0.8213, -0.4171, -0.0296]]],


        [[[-0.2328, -0.8058,  0.1625,  ...,  1.0889,  1.4779, -1.8448],
          [-0.4646, -0.4601, -2.1791,  ..., -0.9569, -1.0696, -0.3621],
          [ 1.3021, -0.0104, -0.3962,  ...,  0.6976, -0.0905,  1.6775],
          ...,
          [-0.2872, -0.3832,  0.9870,  ...,  0.4609, -1.0050,  1.2507],
          [ 0.9573,  0.8239,  1.3713,  ...,  0.7523,  0.3970, -0.8323],
          [-1.7041,  0.8339,  0.1501,  ..., -1.3613,  1.1638,  3.2976]]],


        [[[-0.1216,  0.2049, -0.4836,  ...,  0.2910,  0.3374, -2.1444],
         

Epoch 1:   1%|          | 8/938 [00:03<06:46,  2.29it/s, loss=0.73624814]

batch_size 64
xt
tensor([[[[ 1.8442e-01, -1.5026e+00, -4.2877e-01,  ...,  8.0732e-01,
           -4.1574e-01, -3.9699e-01],
          [-5.8152e-01, -8.1071e-01,  1.2437e+00,  ...,  6.0089e-01,
           -3.7945e-01, -1.4415e+00],
          [ 5.0967e-01,  2.2010e-01,  1.5769e+00,  ...,  1.0597e-01,
           -5.4041e-02,  2.0298e-01],
          ...,
          [ 5.7152e-01,  1.0058e+00,  3.5531e-01,  ..., -1.4188e-01,
           -5.2626e-01,  1.6259e-02],
          [ 3.0738e-01, -2.3569e-01, -3.1676e-01,  ..., -2.7570e-01,
           -2.3739e-01, -3.0407e-01],
          [-8.4709e-01,  1.4998e+00, -1.4063e+00,  ..., -3.2398e-01,
            1.0748e+00, -1.0765e+00]]],


        [[[-6.2991e-01,  1.0000e+00,  8.8584e-01,  ...,  7.4717e-01,
           -1.4518e-01, -1.0548e+00],
          [-2.5572e-01,  1.2280e+00,  1.9048e+00,  ..., -1.1369e+00,
           -4.9600e-01,  8.2069e-01],
          [-7.8582e-01,  2.4621e-01, -7.2344e-01,  ...,  1.2562e-01,
            8.7832e-01, -1.6379e+00],
 

Epoch 1:   1%|          | 9/938 [00:04<06:45,  2.29it/s, loss=0.6779096]

batch_size 64
xt
tensor([[[[ 6.2053e-01, -1.3710e+00,  1.3849e+00,  ...,  1.0961e+00,
            2.4933e+00, -1.8861e+00],
          [ 3.3640e-01, -1.1419e+00, -1.2423e+00,  ..., -5.5799e-01,
            3.2217e-01,  6.8686e-01],
          [ 9.9445e-02,  1.8300e+00,  1.6176e+00,  ..., -1.4338e+00,
            6.5866e-01,  7.9069e-01],
          ...,
          [ 8.4515e-01, -6.5539e-01,  1.7936e-01,  ...,  8.9620e-01,
           -1.2181e+00, -4.3759e-01],
          [-5.2563e-01,  2.8417e-01,  1.2428e+00,  ..., -4.5616e-01,
           -1.9408e-01, -5.8741e-01],
          [ 7.8055e-01, -2.9657e-01, -4.9045e-02,  ...,  8.5137e-01,
            3.7353e-01, -1.3955e+00]]],


        [[[ 9.4063e-01, -1.9680e-01, -1.2778e+00,  ...,  9.2755e-01,
            6.0920e-01,  1.4862e+00],
          [ 1.3000e+00,  2.2720e+00, -1.6966e-03,  ..., -6.6386e-01,
           -1.5870e+00,  6.4003e-01],
          [-1.2309e+00, -1.4285e+00, -3.7095e-01,  ..., -1.8629e+00,
            3.0483e-02, -5.6675e-01],
 

Epoch 1:   1%|          | 10/938 [00:04<06:44,  2.29it/s, loss=0.6339325]

batch_size 64
xt
tensor([[[[ 1.9335,  0.7926,  0.3674,  ..., -0.0799,  0.2140,  0.7269],
          [ 1.6455,  0.6304,  0.6377,  ...,  0.0697, -0.9734,  1.3721],
          [ 0.2588,  2.2794, -0.9304,  ...,  1.0110,  0.3363, -0.6975],
          ...,
          [ 0.1685,  0.3511,  2.0532,  ..., -1.1488,  0.5721, -0.2271],
          [ 1.4488, -1.0460,  0.8596,  ...,  2.5442,  0.7639,  1.5538],
          [-0.8433,  0.8845,  0.7325,  ...,  0.6078,  0.7716, -1.3170]]],


        [[[-1.9490, -2.1687,  0.2607,  ...,  0.2087,  1.2030, -0.0303],
          [ 1.3468,  0.2093, -0.7418,  ...,  0.3687, -1.1210,  0.4822],
          [ 0.4076,  0.1925,  0.2491,  ..., -0.0913, -0.3046, -0.8860],
          ...,
          [-0.4357,  1.0498, -0.4735,  ..., -0.9016,  0.1692,  1.1958],
          [ 0.1867,  2.6851, -0.7276,  ...,  0.8909,  0.8243,  0.9993],
          [-2.1498, -0.3016,  0.0152,  ..., -0.7276,  0.0828,  0.7195]]],


        [[[-0.8246, -0.3717,  0.3966,  ...,  0.0825,  1.4114, -2.2626],
         

Epoch 1:   1%|          | 11/938 [00:04<06:46,  2.28it/s, loss=0.6250288]

batch_size 64
xt
tensor([[[[ 4.4013e-01,  2.6758e-01,  1.0002e-01,  ...,  1.4816e+00,
           -7.6475e-02, -2.5709e-01],
          [ 8.1377e-01,  2.1810e+00, -1.9457e+00,  ..., -1.5537e+00,
           -6.4489e-01, -2.8593e+00],
          [-1.5902e-01,  2.5814e-01, -5.0004e-01,  ...,  1.4138e+00,
           -6.9475e-01, -1.2736e+00],
          ...,
          [ 9.6846e-01, -1.6781e-01, -1.5927e+00,  ..., -1.2213e+00,
            1.1631e+00, -7.2447e-01],
          [-5.9173e-02, -5.1935e-01, -8.4014e-01,  ..., -6.3825e-01,
            5.0084e-02, -4.8499e-01],
          [-3.5601e-01, -1.3252e+00, -9.7674e-01,  ..., -1.6695e-01,
           -1.2662e+00,  3.5038e-01]]],


        [[[ 1.0238e+00,  1.6353e-01,  9.8301e-01,  ...,  2.9345e-01,
           -2.6519e-01,  1.1304e-01],
          [ 4.9113e-01, -4.2492e-01,  2.9328e-01,  ...,  2.9452e-01,
            1.6993e+00,  4.9776e-01],
          [ 4.5898e-01,  8.4486e-01, -6.6362e-01,  ...,  4.2599e-02,
            9.3393e-02,  1.1555e+00],
 

Epoch 1:   1%|▏         | 12/938 [00:05<06:44,  2.29it/s, loss=0.5932648]

batch_size 64
xt
tensor([[[[ 0.2926,  0.2898,  0.6088,  ...,  0.4753, -0.8912,  1.5510],
          [ 0.1678, -0.6348, -0.5849,  ...,  0.6428,  1.3202,  1.5068],
          [-2.1482, -1.3552,  0.3000,  ..., -0.2445,  0.7967, -0.1476],
          ...,
          [ 1.1957, -0.7570,  0.6150,  ...,  0.0920, -0.1058, -0.6121],
          [-1.4983, -0.9957, -1.6204,  ...,  1.1183, -0.4831, -1.2442],
          [-0.1875, -0.4845, -0.4167,  ..., -1.1761,  0.2671, -0.2578]]],


        [[[-0.1545, -0.0759, -0.0455,  ...,  0.0856, -0.3782, -0.2092],
          [ 0.2690, -0.5657,  0.0822,  ...,  0.0417, -0.1343, -0.1281],
          [-0.2018,  0.4696, -0.4732,  ...,  0.0414,  0.0211, -0.5022],
          ...,
          [-0.1011,  0.0895, -0.2503,  ...,  0.1304, -0.1326,  0.1145],
          [-0.5326, -0.1364, -0.1088,  ...,  0.2903, -0.4845,  0.1886],
          [ 0.1269, -0.1613,  0.7822,  ..., -0.0768, -0.8410,  0.3711]]],


        [[[-0.1765, -0.9643,  0.4775,  ...,  0.3813, -0.2132, -1.6060],
         

Epoch 1:   1%|▏         | 13/938 [00:05<06:45,  2.28it/s, loss=0.56735486]

batch_size 64
xt
tensor([[[[-0.4706,  0.2320,  1.0896,  ...,  0.1411, -0.8396,  0.3234],
          [-1.1080,  0.5537, -0.7066,  ..., -0.9826, -0.3839, -1.1817],
          [ 0.8776, -0.3682, -0.3140,  ...,  0.7473, -0.2650, -0.3276],
          ...,
          [-0.4091, -0.5802, -0.7312,  ..., -0.3270, -0.8619, -0.9562],
          [ 0.2049,  0.0165, -0.0641,  ...,  0.5063, -0.2482,  1.0824],
          [ 0.2301, -1.4028,  0.5376,  ...,  2.8630,  0.3951,  0.4061]]],


        [[[-1.0078,  0.2776, -1.1656,  ..., -0.4209,  0.3642, -1.6260],
          [ 1.7578,  0.9540,  0.9850,  ...,  0.5711, -0.1852,  1.5002],
          [-0.0885, -0.9838,  0.5671,  ..., -0.6668, -1.0054,  0.3833],
          ...,
          [ 1.2520, -0.1061, -1.4557,  ...,  0.6301, -0.4641, -0.1971],
          [-0.9521,  2.1000,  1.0231,  ..., -0.2991,  1.0140, -0.2817],
          [ 2.9435,  0.7549,  0.0802,  ..., -0.9399, -1.4039,  0.7492]]],


        [[[-0.4198, -0.6881, -0.3376,  ...,  1.9633, -0.8430,  0.9501],
         

Epoch 1:   1%|▏         | 14/938 [00:06<06:43,  2.29it/s, loss=0.52494305]

batch_size 64
xt
tensor([[[[ 9.2957e-01,  2.7616e-01,  7.8286e-01,  ..., -1.9425e-03,
           -6.8126e-01,  1.1411e+00],
          [-4.5549e-01,  4.8030e-01, -5.5938e-01,  ...,  1.7935e-01,
           -8.1029e-01,  2.0524e+00],
          [-2.2162e-01, -7.2767e-01, -8.8907e-02,  ...,  1.1219e+00,
            4.6980e-02, -1.2403e+00],
          ...,
          [ 1.2089e+00,  5.2453e-01, -5.5615e-02,  ..., -2.6462e-01,
            1.0652e+00,  5.7635e-02],
          [ 1.3543e+00,  3.5571e-01, -4.5802e-01,  ..., -5.3633e-01,
           -5.3571e-01,  1.7349e+00],
          [ 1.8319e+00,  1.5033e+00, -1.4279e+00,  ...,  1.5929e-01,
            1.0927e+00,  8.0513e-01]]],


        [[[ 2.9912e-02, -4.0218e-02,  6.3454e-02,  ..., -3.0094e-02,
            6.9846e-02,  4.3688e-02],
          [-3.8309e-02,  3.7434e-02, -1.4330e-03,  ..., -7.7560e-02,
            1.7625e-01, -3.2220e-02],
          [ 9.0567e-02,  1.3376e-02, -1.7957e-02,  ..., -2.4721e-02,
           -6.0872e-02,  2.9167e-02],
 

Epoch 1:   2%|▏         | 15/938 [00:06<06:43,  2.29it/s, loss=0.5216646]

batch_size 64
xt
tensor([[[[-0.6509,  0.1181,  0.7729,  ..., -0.3753,  0.4840, -1.3178],
          [-0.0040, -0.7271, -0.3223,  ..., -0.8084,  0.3901, -0.3844],
          [ 0.2880, -0.8459,  0.1647,  ..., -0.2756,  0.4909,  0.8137],
          ...,
          [-1.1096,  1.4321,  0.4420,  ..., -0.1505, -0.9919,  0.1900],
          [ 0.1010,  0.2804, -1.2773,  ..., -0.7572,  2.0245, -1.2149],
          [-1.6470, -0.0361,  0.3348,  ...,  0.6757, -0.9167,  0.0853]]],


        [[[-0.3430,  1.2554, -0.2426,  ..., -0.3487, -2.7295,  1.3324],
          [ 0.5732, -1.4536,  1.1347,  ...,  0.1782, -0.8564, -0.0200],
          [ 0.9763, -0.4239,  0.8848,  ...,  0.0907, -1.4350,  0.1900],
          ...,
          [ 0.2539, -0.8748,  0.1629,  ...,  0.9383,  0.8381, -1.2087],
          [ 1.0499, -0.6500,  0.5763,  ...,  0.3376,  1.1465, -0.2314],
          [ 0.4049, -0.2654, -1.3411,  ...,  0.0442,  1.1895,  0.7413]]],


        [[[ 0.0862,  0.5290,  1.1010,  ..., -0.0985,  0.3842, -0.8132],
         

Epoch 1:   2%|▏         | 16/938 [00:07<06:43,  2.28it/s, loss=0.48344404]

batch_size 64
xt
tensor([[[[ 3.0503e-01,  1.4946e+00, -4.1225e-01,  ...,  2.2456e+00,
           -6.0749e-01,  9.4013e-01],
          [-1.5443e-01, -1.4679e+00,  1.8894e+00,  ...,  9.2147e-01,
            4.4991e-01,  2.1423e-01],
          [ 7.1702e-01, -8.1578e-01, -3.7135e-01,  ...,  7.0720e-01,
           -8.9522e-01, -9.8337e-01],
          ...,
          [ 9.5434e-01,  1.6368e-02, -2.0863e-01,  ...,  6.2087e-01,
            1.7563e+00,  6.1268e-01],
          [ 2.3609e+00, -8.3954e-01, -1.1331e+00,  ..., -9.6501e-01,
            5.8814e-02, -1.5716e+00],
          [ 6.3922e-02, -7.0817e-01, -5.7774e-01,  ...,  9.7802e-01,
            4.5947e-01, -1.2450e+00]]],


        [[[ 2.3098e+00, -1.8216e-01,  2.2283e-01,  ..., -8.5967e-01,
            1.1885e+00, -1.4247e+00],
          [ 6.9836e-01,  1.4227e+00, -3.7541e-01,  ..., -9.6781e-01,
           -1.4211e+00, -5.8189e-01],
          [ 6.9754e-01,  1.0030e+00,  1.8592e-01,  ..., -6.2950e-01,
            9.4997e-01,  1.2601e-01],
 

Epoch 1:   2%|▏         | 17/938 [00:07<06:44,  2.28it/s, loss=0.4551342]

batch_size 64
xt
tensor([[[[-1.1419e+00,  3.0491e-01,  6.3233e-01,  ...,  5.5298e-01,
            1.3594e+00, -2.9308e-01],
          [ 1.6598e-02,  5.5600e-01,  5.1346e-01,  ...,  1.4350e+00,
            6.3522e-01,  2.7818e-02],
          [ 9.3760e-01, -1.8729e+00,  2.0816e+00,  ..., -1.5512e+00,
            1.4546e+00,  2.6863e-01],
          ...,
          [ 1.5559e+00, -6.1939e-01, -1.3613e-01,  ..., -2.7589e-01,
            1.7533e-01, -7.4700e-01],
          [-1.2429e+00, -1.6337e-01, -3.4240e-01,  ..., -1.3310e+00,
           -1.2049e+00,  1.0332e+00],
          [-1.4864e-01, -1.4242e+00,  7.4951e-01,  ...,  4.5000e-01,
            6.2565e-02,  1.1274e+00]]],


        [[[-1.1083e-02,  8.2167e-01,  6.0709e-01,  ..., -1.0625e+00,
           -7.7167e-01,  9.1760e-01],
          [ 1.4466e+00,  1.8727e+00,  1.0018e+00,  ..., -2.1679e-01,
           -5.6902e-01, -5.9703e-01],
          [-6.2541e-01,  4.5644e-01, -4.2766e-01,  ..., -7.5972e-01,
           -1.0840e+00,  1.1546e+00],
 

Epoch 1:   2%|▏         | 18/938 [00:07<06:43,  2.28it/s, loss=0.45177886]

batch_size 64
xt
tensor([[[[ 0.1535, -0.4002,  0.2820,  ...,  0.1270, -0.2329, -0.0840],
          [ 0.1087, -0.2167,  0.3793,  ..., -0.1739, -0.0162,  0.3364],
          [-0.1500,  0.4704, -0.0852,  ..., -0.0760, -0.0528,  0.4072],
          ...,
          [ 0.0760, -0.0453,  0.3941,  ..., -0.1389,  0.0068,  0.3199],
          [-0.4250, -0.1919,  0.0730,  ..., -0.1713, -0.1071,  0.1145],
          [-0.0240,  0.5013,  0.2930,  ..., -0.2563, -0.1813, -0.2219]]],


        [[[-0.0187, -0.6797,  0.6446,  ..., -0.6249, -0.2824,  0.1232],
          [-0.7708,  1.1049,  0.0631,  ..., -0.1290, -1.1046, -0.2397],
          [ 0.2318, -1.0120, -0.7692,  ..., -0.7918,  1.8457, -0.6985],
          ...,
          [-0.5542, -0.8031, -0.8031,  ...,  0.1356,  1.1102,  0.5065],
          [-1.1439,  0.3274,  0.1905,  ..., -0.3258, -0.3211,  0.8008],
          [ 1.0227,  0.2598, -0.5600,  ..., -0.4707,  0.5221, -0.6480]]],


        [[[-0.2010,  2.2759, -0.4153,  ..., -0.3569, -0.2332,  0.6687],
         

Epoch 1:   2%|▏         | 19/938 [00:08<06:43,  2.28it/s, loss=0.38767987]

batch_size 64
xt
tensor([[[[-4.6876e-01,  2.7140e-01, -1.7543e+00,  ..., -1.0134e+00,
            4.7944e-01, -3.9600e-01],
          [ 3.6264e-02, -4.1258e-01, -3.3910e-01,  ...,  1.6415e+00,
           -5.2149e-01,  2.1778e-01],
          [ 8.4260e-01,  1.3352e+00, -4.7951e-02,  ..., -4.0792e-01,
            2.6452e-01,  3.8486e-01],
          ...,
          [ 7.9010e-02,  1.9888e+00, -1.4566e-01,  ...,  7.8003e-01,
           -1.2391e+00,  1.8702e+00],
          [-7.0815e-01, -2.6393e-01,  2.0116e-02,  ..., -1.7839e-01,
            1.3582e-01,  5.4139e-02],
          [ 1.3655e-01, -8.2511e-01,  1.9208e+00,  ..., -8.5395e-01,
           -8.9398e-01,  1.6580e+00]]],


        [[[-4.1503e-03, -3.5823e-02,  6.1318e-02,  ..., -4.3603e-02,
            6.6110e-02,  8.1450e-02],
          [-7.1401e-03,  3.8532e-02,  2.0786e-02,  ..., -1.9130e-02,
           -4.5576e-02, -4.6264e-02],
          [ 1.9544e-02,  5.1082e-02, -1.4661e-02,  ..., -2.9528e-02,
           -2.7222e-02, -3.0496e-02],
 

Epoch 1:   2%|▏         | 20/938 [00:08<06:47,  2.25it/s, loss=0.3753125]

batch_size 64
xt
tensor([[[[-2.8555e-01, -5.8461e-01, -6.3026e-01,  ...,  4.5765e-02,
           -1.2750e-03, -7.2346e-01],
          [-6.4807e-01,  6.7702e-01, -2.6329e-01,  ..., -6.5347e-01,
           -7.0946e-01, -7.3153e-01],
          [ 2.5992e-01, -1.5007e-02, -3.2779e-01,  ..., -3.3607e-01,
           -1.1098e+00,  1.5729e-02],
          ...,
          [ 2.2075e-01,  8.2330e-02, -1.0733e-01,  ..., -3.1140e-01,
           -7.1942e-01, -5.6305e-01],
          [-1.3245e-01, -7.7529e-02,  5.4025e-01,  ...,  3.3066e-01,
           -5.5211e-01, -3.7786e-01],
          [-3.3419e-01, -1.0456e+00, -6.8173e-01,  ..., -6.5332e-02,
           -1.1128e-01, -1.7380e-01]]],


        [[[-5.2519e-01, -1.4690e+00,  3.0692e-02,  ..., -4.8701e-01,
           -6.5554e-01, -1.0732e+00],
          [ 5.2549e-01, -7.9494e-01,  1.8722e+00,  ...,  6.3911e-01,
            8.1649e-01, -3.7886e-01],
          [ 4.8596e-01, -4.4823e-01,  6.2589e-01,  ..., -9.2237e-03,
           -6.8668e-01, -6.0927e-01],
 

Epoch 1:   2%|▏         | 21/938 [00:09<06:49,  2.24it/s, loss=0.37358695]

batch_size 64
xt
tensor([[[[-1.7272e-01,  5.6909e-01, -7.7000e-01,  ..., -1.5380e+00,
           -7.1788e-01,  9.8149e-01],
          [ 1.5755e+00,  1.6601e+00,  6.8737e-01,  ..., -5.8186e-01,
            1.7485e-01,  5.8667e-01],
          [ 9.9091e-01,  1.6209e+00,  3.1070e-02,  ...,  9.2906e-03,
           -2.8526e-01, -6.9781e-01],
          ...,
          [ 1.6808e+00, -1.0637e+00,  1.1610e+00,  ...,  1.2707e+00,
            5.8284e-01,  7.7196e-02],
          [ 3.5353e-01,  9.9049e-01, -4.1732e-01,  ..., -1.3356e+00,
           -2.9198e-01,  2.0604e-01],
          [-1.1209e-01,  1.1076e+00, -1.5328e-01,  ...,  1.4367e-01,
            8.1905e-01, -1.3682e+00]]],


        [[[-2.5210e-01,  3.9575e-01,  7.4242e-02,  ..., -2.1409e-01,
            2.1156e-01,  6.2468e-03],
          [ 2.6014e-01, -8.0379e-02, -4.8460e-02,  ...,  9.2144e-02,
           -7.0635e-02,  8.3338e-02],
          [-1.0523e-01, -8.4204e-02,  3.2499e-02,  ...,  7.3139e-02,
            1.1411e-01, -3.9389e-01],
 

Epoch 1:   2%|▏         | 22/938 [00:09<06:50,  2.23it/s, loss=0.38419914]

batch_size 64
xt
tensor([[[[-0.2930, -0.6163, -0.2473,  ..., -0.1415, -0.4261, -0.0146],
          [ 0.1226,  0.2911,  0.0840,  ...,  0.3413,  0.3700, -0.3555],
          [ 0.3784,  0.7343,  0.0581,  ...,  0.2098, -0.6009,  0.0681],
          ...,
          [-0.1506, -0.0686, -0.0948,  ..., -0.1235, -0.0800, -0.5938],
          [-0.2771, -0.0336,  0.0996,  ...,  0.2826, -0.4845, -0.1351],
          [ 0.5683, -0.3755, -0.2084,  ...,  0.4042, -0.3272,  0.0349]]],


        [[[ 0.4033, -0.2916, -0.8406,  ...,  1.0861,  0.3386,  0.5288],
          [ 0.0567,  0.1020, -0.2330,  ...,  0.2491,  0.5199, -0.5913],
          [-1.1943,  0.3765,  0.5740,  ...,  0.5389,  1.0873,  1.1644],
          ...,
          [ 1.1302, -0.3234,  0.3079,  ..., -0.6146, -1.3506,  0.6290],
          [-0.3777,  0.3735,  0.4341,  ..., -0.7857, -1.1215, -0.3883],
          [ 0.3093,  1.3664,  0.2416,  ...,  1.1188,  0.1549, -0.0134]]],


        [[[-0.1966,  0.0351,  0.0385,  ..., -2.1480,  1.3854, -0.6083],
         

Epoch 1:   2%|▏         | 23/938 [00:10<06:50,  2.23it/s, loss=0.33860916]

batch_size 64
xt
tensor([[[[ 1.1580e+00, -4.7975e-02,  2.2327e+00,  ...,  9.2377e-01,
           -1.5707e+00, -1.4301e+00],
          [-1.2205e+00, -1.2358e-01,  1.0413e-01,  ...,  8.9254e-02,
            3.6285e-01,  6.7036e-02],
          [-2.6339e-02,  8.0673e-01, -7.6725e-01,  ..., -1.1066e+00,
            1.3541e-02, -5.3202e-01],
          ...,
          [ 8.4616e-02,  2.0923e-01, -1.2817e+00,  ..., -3.1400e-01,
            1.0915e+00, -9.1136e-01],
          [-3.7284e-01,  1.2429e+00,  8.5432e-01,  ..., -1.0784e-02,
            1.5403e+00,  2.2174e+00],
          [ 2.2007e-01,  8.9879e-01,  9.9822e-01,  ...,  6.1671e-01,
           -7.0171e-01,  1.8063e+00]]],


        [[[-1.4131e+00,  2.6009e-02,  3.9378e-01,  ...,  1.1569e+00,
           -3.7412e-02, -1.5233e+00],
          [ 6.3768e-01,  3.9093e-01,  1.8190e-01,  ..., -1.2177e+00,
           -5.3801e-01,  1.0252e+00],
          [ 4.7907e-01, -7.9468e-01,  2.0009e-01,  ...,  2.1059e+00,
           -1.3310e-01, -7.6996e-01],
 

Epoch 1:   3%|▎         | 24/938 [00:10<06:50,  2.23it/s, loss=0.32563418]

batch_size 64
xt
tensor([[[[-1.7817, -1.3500, -3.1615,  ...,  2.6790,  0.9888, -0.3722],
          [ 1.3878,  0.6765, -0.1451,  ..., -3.0861, -1.7837,  0.0845],
          [-0.2181, -0.2960, -1.4475,  ...,  0.3782, -0.1824, -0.1989],
          ...,
          [-0.7151,  0.7264, -0.7611,  ...,  2.4564,  0.3913,  0.6291],
          [-0.7571, -0.3567, -1.3624,  ..., -0.2571, -1.0709,  1.3083],
          [-0.0658, -1.4763,  0.5874,  ..., -1.2456,  0.7047,  1.3608]]],


        [[[-0.0914, -1.6665,  0.2405,  ..., -0.8526,  0.3385, -1.0823],
          [-0.7291,  0.9334, -0.6494,  ...,  2.1323, -1.2603,  1.0823],
          [-0.0147,  0.4355,  0.1964,  ...,  1.1132,  0.0829, -0.2112],
          ...,
          [-0.7036, -1.2432,  1.1986,  ...,  2.0968,  1.1384, -0.0354],
          [ 0.9638, -0.2553, -2.0035,  ...,  1.7101,  1.3913, -0.4648],
          [-2.2884,  0.5996, -0.7445,  ...,  0.1414,  0.9050,  2.0423]]],


        [[[-0.5821,  0.8728, -0.7304,  ...,  1.0005, -0.8743,  0.3975],
         

Epoch 1:   3%|▎         | 25/938 [00:11<06:51,  2.22it/s, loss=0.3157868]

batch_size 64
xt
tensor([[[[-0.8980,  0.2941,  0.6953,  ..., -0.0707,  1.2123, -1.4297],
          [-0.3127, -0.7441,  1.1713,  ..., -1.9805,  0.7912,  1.7578],
          [ 2.8021,  3.0768, -0.8754,  ..., -0.1365,  0.2132, -0.0736],
          ...,
          [-1.2698, -1.2428, -0.5794,  ...,  0.7141, -1.6455, -1.2988],
          [-0.3013, -1.3792,  0.7712,  ..., -1.4764, -0.7256,  0.2174],
          [ 0.0112, -0.3655, -0.8004,  ..., -0.7312, -1.1666,  1.8519]]],


        [[[-0.0336, -0.3581,  0.3708,  ..., -1.4733,  1.1400,  0.5769],
          [ 0.0502, -1.2985, -0.6740,  ..., -0.1634,  0.5749, -0.8164],
          [ 0.0267,  0.3001,  0.3355,  ...,  1.1969,  1.0034,  0.8470],
          ...,
          [-1.0068, -0.5790, -2.0151,  ...,  0.8511, -0.2341,  0.1313],
          [-0.5440, -0.6852, -0.3591,  ...,  3.1717, -0.7083, -0.1011],
          [ 0.9585, -1.0520, -0.7865,  ...,  0.9746, -0.3078, -0.0785]]],


        [[[ 1.3889, -0.5545,  0.3240,  ..., -0.6470,  0.6465, -0.1318],
         

Epoch 1:   3%|▎         | 26/938 [00:11<06:47,  2.24it/s, loss=0.3114695]

batch_size 64
xt
tensor([[[[ 8.6083e-01,  7.2304e-01,  2.3927e-01,  ..., -3.6877e-01,
            8.7266e-01, -2.0938e-01],
          [ 9.8677e-01, -7.1775e-01,  3.7880e-01,  ..., -5.6015e-01,
            2.9746e-01,  2.9247e-01],
          [-3.5480e+00,  1.3551e-01,  1.8990e-01,  ..., -1.6992e+00,
           -1.0435e+00,  1.9609e+00],
          ...,
          [-1.5084e+00,  4.1072e-01, -4.9741e-02,  ...,  1.0325e+00,
           -1.2883e+00, -2.9901e-01],
          [-1.3250e+00, -1.0210e-01, -1.1085e+00,  ..., -1.3782e-02,
           -1.0542e+00,  6.7445e-01],
          [ 7.2246e-01,  5.4169e-01,  1.7122e+00,  ...,  4.3101e-01,
            1.0065e+00, -8.6580e-01]]],


        [[[ 6.2430e-01,  4.9198e-01,  8.7972e-01,  ...,  1.9500e-01,
           -1.2227e+00,  2.5073e-01],
          [ 1.2842e-01,  5.0086e-01,  1.7109e-01,  ..., -2.9719e-01,
           -3.3638e-01, -5.6187e-02],
          [ 1.0369e+00, -1.0853e+00,  4.2407e-01,  ..., -1.2386e+00,
            7.4392e-01, -4.1157e-02],
 

Epoch 1:   3%|▎         | 27/938 [00:11<06:45,  2.25it/s, loss=0.30675554]

batch_size 64
xt
tensor([[[[-0.6309, -0.8932,  0.4915,  ..., -0.6297,  0.4208,  1.0657],
          [ 0.2753,  0.5094,  0.7455,  ..., -0.0716,  0.2508,  0.8621],
          [ 2.2649,  1.4224,  1.2083,  ...,  0.0619,  0.0261, -0.2950],
          ...,
          [-0.7463, -2.8418,  1.3157,  ..., -0.0826, -1.4402,  0.3511],
          [-1.0934,  1.0285,  1.0442,  ...,  1.5476, -0.2716,  0.0440],
          [-0.3785, -1.0732, -0.8142,  ..., -1.2226,  2.5734, -0.6547]]],


        [[[ 0.5557, -0.0852,  1.3993,  ..., -0.4122,  1.0567,  0.1671],
          [ 1.1302, -0.0823, -0.9650,  ...,  0.6922, -0.1457,  1.6140],
          [ 0.3526,  0.7570,  0.5362,  ...,  0.2619,  0.7542, -2.7338],
          ...,
          [ 0.4404,  0.3666,  1.2472,  ..., -1.8202,  1.0451,  0.3842],
          [ 0.3045,  1.0596, -0.2564,  ...,  2.6061, -0.4100, -1.4095],
          [-0.4239, -0.7999,  0.7267,  ..., -0.5107, -0.4845, -2.2993]]],


        [[[ 0.6666,  1.6663,  1.1401,  ...,  0.3454, -0.2810, -0.0921],
         

Epoch 1:   3%|▎         | 28/938 [00:12<06:43,  2.25it/s, loss=0.28381217]

batch_size 64
xt
tensor([[[[-4.4198e-01, -8.5665e-01, -2.3688e-01,  ...,  3.7351e-02,
           -4.9816e-01, -4.7237e-01],
          [-4.7403e-01,  2.7662e-01, -7.1171e-01,  ...,  4.0212e-02,
            4.4197e-01,  6.1439e-01],
          [-3.6909e-01,  6.5473e-01, -3.7221e-01,  ...,  1.0273e+00,
            6.7932e-01,  1.6677e-01],
          ...,
          [ 9.9112e-01,  5.9140e-02, -5.8714e-02,  ...,  7.1829e-01,
            7.4232e-01,  4.8378e-01],
          [-1.3563e-01, -5.1711e-01,  6.0240e-01,  ..., -5.7835e-01,
           -1.0300e-01,  2.0790e-01],
          [-1.9776e-01, -6.1203e-01,  1.0342e+00,  ...,  2.1918e-01,
            3.2407e-01,  3.1104e-02]]],


        [[[ 6.2440e-01,  9.3409e-01,  1.0924e+00,  ...,  7.5384e-01,
            5.9178e-01,  2.3692e-01],
          [ 4.2502e-01,  2.7041e-01, -1.0137e+00,  ..., -5.8263e-01,
            5.8169e-01, -7.6341e-01],
          [ 2.3085e-01,  1.1248e-01, -7.9620e-01,  ...,  9.3352e-01,
            1.9209e+00,  1.3764e+00],
 

Epoch 1:   3%|▎         | 29/938 [00:12<06:43,  2.25it/s, loss=0.22157463]

batch_size 64
xt
tensor([[[[ 2.5070e-01,  1.7398e-01,  4.9856e-01,  ..., -5.7081e-01,
           -7.2030e-01, -5.4310e-02],
          [ 1.1872e+00, -3.0078e-01, -3.4789e-01,  ..., -6.5798e-01,
            5.7805e-01,  4.6427e-01],
          [-6.3516e-01, -6.8977e-01, -9.4323e-01,  ...,  1.5579e+00,
           -8.6770e-01,  3.4558e-01],
          ...,
          [-4.0094e-01,  1.9499e+00,  6.3133e-01,  ..., -2.3009e+00,
            9.1694e-01,  2.2438e-01],
          [ 4.2792e-01, -1.6600e+00,  4.5541e-01,  ..., -4.7650e-01,
            5.8142e-01,  8.3373e-01],
          [ 1.5362e+00,  9.7805e-01, -6.4721e-01,  ...,  1.1667e+00,
           -9.9498e-01, -2.3602e+00]]],


        [[[ 2.2655e-01, -1.3396e-01, -2.0035e-01,  ...,  3.5132e-02,
            1.6872e-01, -8.3434e-02],
          [ 1.2825e-02, -1.5787e-01, -2.7445e-01,  ..., -1.0390e-01,
            2.3059e-02, -5.4716e-03],
          [-4.1898e-02, -1.0209e-01, -2.2030e-01,  ..., -8.1970e-02,
           -8.9110e-02,  1.6361e-02],
 

Epoch 1:   3%|▎         | 30/938 [00:13<06:42,  2.26it/s, loss=0.23642462]

batch_size 64
xt
tensor([[[[ 5.1841e-01,  1.5778e-01,  3.8648e-01,  ..., -1.9706e-01,
           -4.4896e-01, -1.3333e+00],
          [-2.6320e-01, -8.0434e-01,  2.3660e-01,  ...,  2.9630e-01,
           -6.7823e-02,  2.4959e-01],
          [ 1.1615e+00, -3.1143e-01,  3.7137e-01,  ..., -2.0262e-01,
            6.6252e-01, -6.3509e-02],
          ...,
          [ 8.9269e-01, -7.7015e-02, -3.1826e-01,  ..., -4.9585e-01,
            5.4731e-02,  3.3630e-01],
          [ 2.8544e-01,  1.7892e-01,  4.4316e-01,  ..., -7.2038e-01,
            1.3129e-01, -7.3345e-02],
          [ 2.6592e-01, -7.9412e-01,  4.6906e-01,  ..., -5.3999e-02,
            3.0783e-01, -1.1604e-01]]],


        [[[-6.3579e-01, -1.1267e-02, -7.9362e-02,  ..., -1.2321e+00,
            6.5215e-01,  7.8956e-01],
          [-1.1571e+00,  1.2311e+00,  1.0616e+00,  ...,  1.7394e+00,
           -5.7355e-01, -6.8268e-01],
          [-6.2122e-01,  1.2348e+00,  7.4948e-01,  ...,  1.2769e+00,
           -1.1052e-02,  2.4410e-01],
 

Epoch 1:   3%|▎         | 31/938 [00:13<06:42,  2.25it/s, loss=0.22088155]

batch_size 64
xt
tensor([[[[ 0.1184, -0.2421,  0.1584,  ..., -0.1943,  0.1378, -0.3020],
          [-0.1759, -0.0646,  0.1051,  ..., -0.0905,  0.2416,  0.2919],
          [-0.2954,  0.0951, -0.0357,  ..., -0.2872,  0.3034, -0.3042],
          ...,
          [ 0.3011, -0.0043, -0.0923,  ...,  0.3258,  0.4825, -0.6007],
          [-0.2358, -0.1561,  0.1580,  ...,  0.3702,  0.1222, -0.1310],
          [ 0.5020,  0.1695,  0.1549,  ..., -0.2231,  0.2219, -0.8553]]],


        [[[-1.8279, -0.0769, -1.2472,  ..., -0.8862,  0.3655, -0.5269],
          [-1.1606, -2.0846, -0.2761,  ...,  0.5587, -0.9940,  0.7065],
          [-1.0124, -0.9097,  1.2255,  ...,  0.8903, -0.7052, -1.4093],
          ...,
          [ 2.8524,  0.2497, -0.2454,  ..., -0.5188, -1.5399, -0.1621],
          [ 1.0478, -0.2945,  1.1114,  ...,  0.9764,  0.5104, -0.5579],
          [-1.1774, -0.1119,  1.5941,  ...,  0.3562, -0.2551,  0.3147]]],


        [[[-0.3693, -1.5592,  0.8970,  ..., -0.9551, -1.7835, -0.5925],
         

Epoch 1:   3%|▎         | 31/938 [00:13<06:48,  2.22it/s, loss=0.22088155]


KeyboardInterrupt: 