<a href="https://colab.research.google.com/github/mkoko22/AMP_Generation_using_LDM/blob/main/generation/Phase_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import wandb
from google.colab import drive

drive.mount('/content/drive')

# Run on the GPU when the runtime provides one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyperparameters for the latent diffusion run; the whole dict is also passed
# to wandb.init(config=...) below.
diff_config = dict(
    latent_dim=64,    # width of each latent vector fed to the denoiser
    n_steps=500,      # number of diffusion timesteps in the noise schedule
    lr=1e-4,          # AdamW learning rate
    batch_size=512,   # DataLoader batch size for train and validation
    epochs=100,       # total number of training epochs
    gamma=0.99,       # ExponentialLR decay factor, applied once per epoch
    device=device,
)


def get_paper_params(n_steps, device, beta_start=0.0001, beta_end=0.02):
    """Build a linear DDPM noise schedule.

    Args:
        n_steps: number of diffusion timesteps.
        device: torch device (or device string) the schedule tensors live on.
        beta_start: beta at the first timestep (default matches the original).
        beta_end: beta at the last timestep (default matches the original).

    Returns:
        (betas, alphas, alphas_cumprod), each a 1-D tensor of length n_steps:
        betas linearly spaced from beta_start to beta_end, alphas = 1 - betas,
        and alphas_cumprod the running product of alphas (a_bar_t).
    """
    # Allocate directly on the target device instead of building on CPU and copying.
    betas = torch.linspace(beta_start, beta_end, n_steps, device=device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return betas, alphas, alphas_cumprod

# Precompute the noise schedule once; the training loop and the sampler both index it.
betas, alphas, alphas_cumprod = get_paper_params(n_steps=diff_config["n_steps"], device=device)


class BERTDiffusion(nn.Module):
    """Transformer-encoder denoiser that maps (x_t, t) to a prediction of the clean latent.

    Each latent vector is projected to a single d_model-wide token. A learned
    timestep embedding and a learned bias vector (``pos_embed``) are added to it.
    The token is then normalised, passed through dropout and a TransformerEncoder
    stack, and finally projected back to ``latent_dim``.

    Args:
        latent_dim: size of the latent vectors being denoised.
        n_steps: number of diffusion timesteps, i.e. the size of the timestep
            embedding table. If None, falls back to ``diff_config["n_steps"]``
            (the original behaviour).
        d_model: token width inside the transformer (default 512).
        nhead: attention heads per encoder layer (default 8).
        num_layers: number of encoder layers (default 4).
        dim_feedforward: hidden width of each encoder feed-forward block (default 2048).
        dropout: dropout probability for the input dropout and the encoder layers (default 0.1).

    The defaults reproduce the original hard-coded architecture. Submodule names
    are unchanged, so existing state_dict checkpoints still load.
    """

    def __init__(self, latent_dim, n_steps=None, d_model=512, nhead=8,
                 num_layers=4, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        if n_steps is None:
            # Backward compatibility: the original constructor read the global config.
            n_steps = diff_config["n_steps"]

        # Timestep embedding table (an Embedding lookup, not an MLP) plus the
        # input/output projections between latent space and d_model.
        # Construction order is kept identical to preserve parameter-init RNG order.
        self.time_mlp = nn.Embedding(n_steps, d_model)
        self.input_proj = nn.Linear(latent_dim, d_model)
        self.output_proj = nn.Linear(d_model, latent_dim)

        # Learned bias added to the single token (broadcast over the batch).
        self.pos_embed = nn.Parameter(torch.randn(1, 1, d_model))

        # BERT-style encoder stack.
        # NOTE(review): the sequence length is always 1, so self-attention is
        # trivial (softmax over one key) and the stack behaves like a deep
        # residual MLP.
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.bert_encoding_layer = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)

        # Normalization and dropout applied to the summed input token.
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, t):
        """Predict the clean latent from a noisy latent and its timestep.

        Args:
            x: noisy latents of shape (batch, latent_dim).
            t: integer timestep indices of shape (batch,), each in [0, n_steps).

        Returns:
            Tensor of shape (batch, latent_dim).
        """
        # Lift each latent and its timestep to a single token: (batch, 1, d_model).
        x_feat = self.input_proj(x).unsqueeze(1)
        t_emb = self.time_mlp(t).unsqueeze(1)

        # Feature addition: input + time + position.
        h = x_feat + t_emb + self.pos_embed

        h = self.layer_norm(h)
        h = self.dropout(h)

        h = self.bert_encoding_layer(h)

        # Drop the length-1 sequence axis and map back to latent space.
        return self.output_proj(h.squeeze(1))

# Latent datasets. map_location="cpu" lets tensors saved from a GPU session load
# on any runtime; batches are moved to `device` inside the loops.
train_latents = torch.load("/content/drive/MyDrive/AMP-Generation/data/latent_uncond_train.pth", map_location="cpu")
val_latents = torch.load("/content/drive/MyDrive/AMP-Generation/data/latent_uncond_val.pth", map_location="cpu")

train_loader = DataLoader(TensorDataset(train_latents), batch_size=diff_config["batch_size"], shuffle=True)
val_loader = DataLoader(TensorDataset(val_latents), batch_size=diff_config["batch_size"])

model = BERTDiffusion(diff_config["latent_dim"]).to(device)
optimizer = optim.AdamW(model.parameters(), lr=diff_config["lr"])
# Decays the LR by `gamma` once per epoch (stepped at the end of each epoch).
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=diff_config["gamma"])
criterion = nn.MSELoss()

# Set to a checkpoint file to resume training; None starts from scratch.
checkpoint_path = None
start_epoch = 0
# Epoch assumed for legacy checkpoints that are a bare model state_dict and
# therefore carry no epoch/optimizer information.
LEGACY_RESUME_EPOCH = 50

if checkpoint_path is not None and not os.path.exists(checkpoint_path):
    # Surface the mistake instead of silently training from scratch.
    print(f"Warning: checkpoint {checkpoint_path} not found; starting from epoch 0")
elif checkpoint_path is not None:
    print(f"--- Loading Checkpoint: {checkpoint_path} ---")
    ckpt = torch.load(checkpoint_path, map_location=device)
    if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
        # Full training checkpoint written by the loop below.
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        # Older checkpoints lack scheduler state; restore it when present.
        if 'scheduler_state_dict' in ckpt:
            scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        start_epoch = ckpt['epoch']
    else:
        # Bare state_dict: only weights are restored; optimizer and LR start fresh.
        model.load_state_dict(ckpt)
        start_epoch = LEGACY_RESUME_EPOCH
        print(f"Warning: checkpoint has no epoch info; assuming epoch {LEGACY_RESUME_EPOCH}")
    print(f"Resuming at Epoch {start_epoch}")

wandb.init(project="AMP-Latent-Diffusion-Full-Paper", config=diff_config, resume="allow")


@torch.no_grad()
def generate_samples(model, num_samples=512):
    """Generate latents by running the full ancestral DDPM reverse process.

    Starts from x_T ~ N(0, I). At each step t the model predicts x0 from
    (x_t, t), and x_{t-1} is sampled from the Gaussian posterior
    q(x_{t-1} | x_t, x0):

        mean = sqrt(a_bar_{t-1}) * beta_t / (1 - a_bar_t) * x0
             + sqrt(alpha_t) * (1 - a_bar_{t-1}) / (1 - a_bar_t) * x_t
        var  = (1 - a_bar_{t-1}) / (1 - a_bar_t) * beta_t

    At t = 0 the model's x0 prediction is returned directly, with no added noise.

    Args:
        model: denoiser called as ``model(x_t, t)`` that returns an x0 prediction.
        num_samples: number of latent vectors to generate.

    Returns:
        The final x0 prediction, a tensor on ``device``. It has shape
        (num_samples, diff_config["latent_dim"]) assuming the model preserves
        its input shape.

    The model's previous train/eval mode is restored on exit, so calling this
    mid-training does not leave dropout disabled.
    """
    was_training = model.training
    model.eval()
    try:
        # Sample the starting noise directly on the target device.
        xt = torch.randn(num_samples, diff_config["latent_dim"], device=device)

        for t_idx in reversed(range(diff_config["n_steps"])):
            t_batch = torch.full((num_samples,), t_idx, device=device, dtype=torch.long)

            x0_pred = model(xt, t_batch)

            if t_idx > 0:
                a_bar = alphas_cumprod[t_idx]
                a_bar_prev = alphas_cumprod[t_idx - 1]
                beta_t = betas[t_idx]
                alpha_t = alphas[t_idx]
                mu = ((torch.sqrt(a_bar_prev) * beta_t) / (1 - a_bar)) * x0_pred + \
                     ((torch.sqrt(alpha_t) * (1 - a_bar_prev)) / (1 - a_bar)) * xt

                var = ((1 - a_bar_prev) / (1 - a_bar)) * beta_t
                xt = mu + torch.sqrt(var) * torch.randn_like(xt)
            else:
                xt = x0_pred
    finally:
        model.train(was_training)

    return xt


# Training loop. Each step draws t ~ U{0, ..., n_steps-1} and eps ~ N(0, I),
# forms x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps, and regresses the
# model output onto the clean latent x0 with MSE (x0-parameterisation).
checkpoint_dir = "/content/drive/MyDrive/AMP-Generation/checkpoints"
# Create the folder up front so the first save (epoch 5) cannot fail after
# minutes of training because the directory is missing.
os.makedirs(checkpoint_dir, exist_ok=True)
# Fixed seed for validation timesteps/noise. Every epoch is scored on the same
# (t, eps) draws, so val_loss changes reflect the model rather than resampling.
val_noise_seed = 0

try:
    for epoch in range(start_epoch, diff_config["epochs"]):
        model.train()
        train_loss = 0.0
        # LR used for this epoch's updates; scheduler.step() below changes it,
        # so read it now for accurate logging.
        epoch_lr = optimizer.param_groups[0]["lr"]

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            x0 = batch[0].to(device)
            t = torch.randint(0, diff_config["n_steps"], (x0.shape[0],), device=device)

            epsilon = torch.randn_like(x0)
            a_bar = alphas_cumprod[t].view(-1, 1)
            xt = torch.sqrt(a_bar) * x0 + torch.sqrt(1 - a_bar) * epsilon

            x0_pred = model(xt, t)
            loss = criterion(x0_pred, x0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        scheduler.step()

        model.eval()
        val_loss = 0.0
        val_gen = torch.Generator(device=device)
        val_gen.manual_seed(val_noise_seed)
        with torch.no_grad():
            for batch in val_loader:
                x0_v = batch[0].to(device)
                t_v = torch.randint(
                    0, diff_config["n_steps"], (x0_v.shape[0],),
                    device=device, generator=val_gen,
                )
                a_bar_v = alphas_cumprod[t_v].view(-1, 1)
                eps_v = torch.randn(
                    x0_v.shape, device=device, dtype=x0_v.dtype, generator=val_gen
                )
                xt_v = torch.sqrt(a_bar_v) * x0_v + torch.sqrt(1 - a_bar_v) * eps_v
                v_pred = model(xt_v, t_v)
                val_loss += criterion(v_pred, x0_v).item()

        avg_train = train_loss / len(train_loader)
        avg_val = val_loss / len(val_loader)

        wandb.log({
            "train_loss": avg_train,
            "val_loss": avg_val,
            "learning_rate": epoch_lr,
            "epoch": epoch + 1
        })

        print(f"Epoch {epoch+1} | Loss: {avg_train:.6f} | Val: {avg_val:.6f} | LR: {epoch_lr:.6f}")

        if (epoch + 1) % 5 == 0:
            save_path = os.path.join(checkpoint_dir, f"diffusion_paper_final_ep{epoch+1}.pth")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Saved so a resumed run continues the exact LR schedule.
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': avg_val,
            }, save_path)
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mtvani22[0m ([33mAMP-Generation[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1: 100%|██████████| 892/892 [00:30<00:00, 28.78it/s]


Epoch 1 | Loss: 1.590168 | Val: 1.335600 | LR: 0.000099


Epoch 2: 100%|██████████| 892/892 [00:30<00:00, 29.32it/s]


Epoch 2 | Loss: 1.383861 | Val: 1.273472 | LR: 0.000098


Epoch 3: 100%|██████████| 892/892 [00:30<00:00, 29.13it/s]


Epoch 3 | Loss: 1.341043 | Val: 1.260573 | LR: 0.000097


Epoch 4: 100%|██████████| 892/892 [00:30<00:00, 29.11it/s]


Epoch 4 | Loss: 1.316444 | Val: 1.250381 | LR: 0.000096


Epoch 5: 100%|██████████| 892/892 [00:30<00:00, 29.17it/s]


Epoch 5 | Loss: 1.299712 | Val: 1.242009 | LR: 0.000095


Epoch 6: 100%|██████████| 892/892 [00:30<00:00, 28.93it/s]


Epoch 6 | Loss: 1.288795 | Val: 1.227174 | LR: 0.000094


Epoch 7: 100%|██████████| 892/892 [00:30<00:00, 29.07it/s]


Epoch 7 | Loss: 1.282215 | Val: 1.212868 | LR: 0.000093


Epoch 8: 100%|██████████| 892/892 [00:30<00:00, 29.14it/s]


Epoch 8 | Loss: 1.274317 | Val: 1.218134 | LR: 0.000092


Epoch 9: 100%|██████████| 892/892 [00:30<00:00, 29.21it/s]


Epoch 9 | Loss: 1.269448 | Val: 1.215447 | LR: 0.000091


Epoch 10: 100%|██████████| 892/892 [00:30<00:00, 28.99it/s]


Epoch 10 | Loss: 1.264168 | Val: 1.216370 | LR: 0.000090


Epoch 11: 100%|██████████| 892/892 [00:30<00:00, 29.21it/s]


Epoch 11 | Loss: 1.259741 | Val: 1.213639 | LR: 0.000090


Epoch 12: 100%|██████████| 892/892 [00:30<00:00, 29.21it/s]


Epoch 12 | Loss: 1.258318 | Val: 1.204791 | LR: 0.000089


Epoch 13: 100%|██████████| 892/892 [00:30<00:00, 29.21it/s]


Epoch 13 | Loss: 1.250329 | Val: 1.201869 | LR: 0.000088


Epoch 14: 100%|██████████| 892/892 [00:30<00:00, 29.12it/s]


Epoch 14 | Loss: 1.252489 | Val: 1.209403 | LR: 0.000087


Epoch 15: 100%|██████████| 892/892 [00:30<00:00, 29.12it/s]


Epoch 15 | Loss: 1.247773 | Val: 1.210218 | LR: 0.000086


Epoch 16: 100%|██████████| 892/892 [00:30<00:00, 28.79it/s]


Epoch 16 | Loss: 1.246590 | Val: 1.211509 | LR: 0.000085


Epoch 17: 100%|██████████| 892/892 [00:30<00:00, 29.24it/s]


Epoch 17 | Loss: 1.244488 | Val: 1.213680 | LR: 0.000084


Epoch 18: 100%|██████████| 892/892 [00:30<00:00, 29.12it/s]


Epoch 18 | Loss: 1.246029 | Val: 1.197468 | LR: 0.000083


Epoch 19: 100%|██████████| 892/892 [00:30<00:00, 29.15it/s]


Epoch 19 | Loss: 1.240106 | Val: 1.193400 | LR: 0.000083


Epoch 20: 100%|██████████| 892/892 [00:30<00:00, 29.10it/s]


Epoch 20 | Loss: 1.241619 | Val: 1.205098 | LR: 0.000082


Epoch 21: 100%|██████████| 892/892 [00:31<00:00, 28.74it/s]


Epoch 21 | Loss: 1.238208 | Val: 1.210457 | LR: 0.000081


Epoch 22: 100%|██████████| 892/892 [00:30<00:00, 29.14it/s]


Epoch 22 | Loss: 1.237922 | Val: 1.193507 | LR: 0.000080


Epoch 23: 100%|██████████| 892/892 [00:30<00:00, 29.15it/s]


Epoch 23 | Loss: 1.239139 | Val: 1.200585 | LR: 0.000079


Epoch 24: 100%|██████████| 892/892 [00:30<00:00, 29.17it/s]


Epoch 24 | Loss: 1.237134 | Val: 1.198729 | LR: 0.000079


Epoch 25: 100%|██████████| 892/892 [00:30<00:00, 29.07it/s]


Epoch 25 | Loss: 1.235024 | Val: 1.202413 | LR: 0.000078


Epoch 26: 100%|██████████| 892/892 [00:31<00:00, 28.72it/s]


Epoch 26 | Loss: 1.230760 | Val: 1.212338 | LR: 0.000077


Epoch 27: 100%|██████████| 892/892 [00:30<00:00, 29.18it/s]


Epoch 27 | Loss: 1.233744 | Val: 1.192159 | LR: 0.000076


Epoch 28: 100%|██████████| 892/892 [00:30<00:00, 29.14it/s]


Epoch 28 | Loss: 1.231489 | Val: 1.200116 | LR: 0.000075


Epoch 29: 100%|██████████| 892/892 [00:30<00:00, 29.08it/s]


Epoch 29 | Loss: 1.232834 | Val: 1.204593 | LR: 0.000075


Epoch 30: 100%|██████████| 892/892 [00:30<00:00, 29.22it/s]


Epoch 30 | Loss: 1.229411 | Val: 1.198373 | LR: 0.000074


Epoch 31: 100%|██████████| 892/892 [00:30<00:00, 28.95it/s]


Epoch 31 | Loss: 1.228420 | Val: 1.197571 | LR: 0.000073


Epoch 32: 100%|██████████| 892/892 [00:30<00:00, 29.10it/s]


Epoch 32 | Loss: 1.229894 | Val: 1.214724 | LR: 0.000072


Epoch 33: 100%|██████████| 892/892 [00:30<00:00, 29.22it/s]


Epoch 33 | Loss: 1.224527 | Val: 1.200153 | LR: 0.000072


Epoch 34: 100%|██████████| 892/892 [00:30<00:00, 29.19it/s]


Epoch 34 | Loss: 1.227150 | Val: 1.189345 | LR: 0.000071


Epoch 35: 100%|██████████| 892/892 [00:30<00:00, 29.16it/s]


Epoch 35 | Loss: 1.227853 | Val: 1.189725 | LR: 0.000070


Epoch 36: 100%|██████████| 892/892 [00:30<00:00, 28.85it/s]


Epoch 36 | Loss: 1.223941 | Val: 1.197539 | LR: 0.000070


Epoch 37: 100%|██████████| 892/892 [00:30<00:00, 29.22it/s]


Epoch 37 | Loss: 1.225388 | Val: 1.206625 | LR: 0.000069


Epoch 38: 100%|██████████| 892/892 [00:30<00:00, 29.24it/s]


Epoch 38 | Loss: 1.229422 | Val: 1.194855 | LR: 0.000068


Epoch 39: 100%|██████████| 892/892 [00:30<00:00, 29.19it/s]


Epoch 39 | Loss: 1.226607 | Val: 1.193610 | LR: 0.000068


Epoch 40: 100%|██████████| 892/892 [00:30<00:00, 29.23it/s]


Epoch 40 | Loss: 1.226256 | Val: 1.189902 | LR: 0.000067


Epoch 41: 100%|██████████| 892/892 [00:30<00:00, 28.98it/s]


Epoch 41 | Loss: 1.226022 | Val: 1.196896 | LR: 0.000066


Epoch 42: 100%|██████████| 892/892 [00:30<00:00, 29.28it/s]


Epoch 42 | Loss: 1.224958 | Val: 1.192921 | LR: 0.000066


Epoch 43: 100%|██████████| 892/892 [00:30<00:00, 29.12it/s]


Epoch 43 | Loss: 1.223093 | Val: 1.192423 | LR: 0.000065


Epoch 44: 100%|██████████| 892/892 [00:30<00:00, 29.24it/s]


Epoch 44 | Loss: 1.222859 | Val: 1.205527 | LR: 0.000064


Epoch 45: 100%|██████████| 892/892 [00:30<00:00, 29.24it/s]


Epoch 45 | Loss: 1.223750 | Val: 1.202102 | LR: 0.000064


Epoch 46: 100%|██████████| 892/892 [00:30<00:00, 29.09it/s]


Epoch 46 | Loss: 1.223325 | Val: 1.190545 | LR: 0.000063


Epoch 47: 100%|██████████| 892/892 [00:30<00:00, 29.17it/s]


Epoch 47 | Loss: 1.220697 | Val: 1.188741 | LR: 0.000062


Epoch 48: 100%|██████████| 892/892 [00:30<00:00, 29.03it/s]


Epoch 48 | Loss: 1.220397 | Val: 1.191617 | LR: 0.000062


Epoch 49: 100%|██████████| 892/892 [00:30<00:00, 29.20it/s]


Epoch 49 | Loss: 1.220018 | Val: 1.190576 | LR: 0.000061


Epoch 50: 100%|██████████| 892/892 [00:30<00:00, 29.16it/s]


Epoch 50 | Loss: 1.217561 | Val: 1.189931 | LR: 0.000061


Epoch 51: 100%|██████████| 892/892 [00:30<00:00, 29.13it/s]


Epoch 51 | Loss: 1.220100 | Val: 1.184729 | LR: 0.000060


Epoch 52: 100%|██████████| 892/892 [00:30<00:00, 29.10it/s]


Epoch 52 | Loss: 1.221158 | Val: 1.186331 | LR: 0.000059


Epoch 53: 100%|██████████| 892/892 [00:30<00:00, 29.22it/s]


Epoch 53 | Loss: 1.218844 | Val: 1.193002 | LR: 0.000059


Epoch 54: 100%|██████████| 892/892 [00:30<00:00, 29.18it/s]


Epoch 54 | Loss: 1.218723 | Val: 1.200842 | LR: 0.000058


Epoch 55: 100%|██████████| 892/892 [00:30<00:00, 29.24it/s]


Epoch 55 | Loss: 1.219781 | Val: 1.195293 | LR: 0.000058


Epoch 56: 100%|██████████| 892/892 [00:30<00:00, 29.01it/s]


Epoch 56 | Loss: 1.218922 | Val: 1.193942 | LR: 0.000057


Epoch 57: 100%|██████████| 892/892 [00:30<00:00, 29.12it/s]


Epoch 57 | Loss: 1.219968 | Val: 1.190979 | LR: 0.000056


Epoch 58: 100%|██████████| 892/892 [00:30<00:00, 29.24it/s]


Epoch 58 | Loss: 1.218881 | Val: 1.192285 | LR: 0.000056


Epoch 59: 100%|██████████| 892/892 [00:30<00:00, 29.24it/s]


Epoch 59 | Loss: 1.218429 | Val: 1.199720 | LR: 0.000055


Epoch 60: 100%|██████████| 892/892 [00:30<00:00, 29.16it/s]


Epoch 60 | Loss: 1.217005 | Val: 1.182746 | LR: 0.000055


Epoch 61: 100%|██████████| 892/892 [00:30<00:00, 28.90it/s]


Epoch 61 | Loss: 1.215297 | Val: 1.190921 | LR: 0.000054


Epoch 62: 100%|██████████| 892/892 [00:30<00:00, 29.27it/s]


Epoch 62 | Loss: 1.216231 | Val: 1.194819 | LR: 0.000054


Epoch 63: 100%|██████████| 892/892 [00:30<00:00, 29.25it/s]


Epoch 63 | Loss: 1.215023 | Val: 1.192305 | LR: 0.000053


Epoch 64: 100%|██████████| 892/892 [00:30<00:00, 29.20it/s]


Epoch 64 | Loss: 1.216119 | Val: 1.188431 | LR: 0.000053


Epoch 65: 100%|██████████| 892/892 [00:30<00:00, 29.25it/s]


Epoch 65 | Loss: 1.213942 | Val: 1.189169 | LR: 0.000052


Epoch 66: 100%|██████████| 892/892 [00:30<00:00, 28.97it/s]


Epoch 66 | Loss: 1.217462 | Val: 1.186184 | LR: 0.000052


Epoch 67: 100%|██████████| 892/892 [00:30<00:00, 29.24it/s]


Epoch 67 | Loss: 1.217120 | Val: 1.175888 | LR: 0.000051


Epoch 68: 100%|██████████| 892/892 [00:30<00:00, 29.12it/s]


Epoch 68 | Loss: 1.215895 | Val: 1.187073 | LR: 0.000050


Epoch 69: 100%|██████████| 892/892 [00:30<00:00, 29.24it/s]


Epoch 69 | Loss: 1.216642 | Val: 1.194863 | LR: 0.000050


Epoch 70: 100%|██████████| 892/892 [00:30<00:00, 29.21it/s]


Epoch 70 | Loss: 1.213929 | Val: 1.193103 | LR: 0.000049


Epoch 71: 100%|██████████| 892/892 [00:30<00:00, 28.84it/s]


Epoch 71 | Loss: 1.215580 | Val: 1.181505 | LR: 0.000049


Epoch 72: 100%|██████████| 892/892 [00:30<00:00, 29.22it/s]


Epoch 72 | Loss: 1.215783 | Val: 1.199764 | LR: 0.000048


Epoch 73: 100%|██████████| 892/892 [00:30<00:00, 29.22it/s]


Epoch 73 | Loss: 1.215100 | Val: 1.187707 | LR: 0.000048


Epoch 74: 100%|██████████| 892/892 [00:30<00:00, 29.23it/s]


Epoch 74 | Loss: 1.214365 | Val: 1.192067 | LR: 0.000048


Epoch 75: 100%|██████████| 892/892 [00:30<00:00, 29.10it/s]


Epoch 75 | Loss: 1.212438 | Val: 1.182851 | LR: 0.000047


Epoch 76: 100%|██████████| 892/892 [00:30<00:00, 28.88it/s]


Epoch 76 | Loss: 1.213646 | Val: 1.187753 | LR: 0.000047


Epoch 77: 100%|██████████| 892/892 [00:30<00:00, 29.26it/s]


Epoch 77 | Loss: 1.213135 | Val: 1.196573 | LR: 0.000046


Epoch 78: 100%|██████████| 892/892 [00:30<00:00, 29.22it/s]


Epoch 78 | Loss: 1.214365 | Val: 1.200424 | LR: 0.000046


Epoch 79: 100%|██████████| 892/892 [00:30<00:00, 29.18it/s]


Epoch 79 | Loss: 1.213120 | Val: 1.190120 | LR: 0.000045


Epoch 80: 100%|██████████| 892/892 [00:30<00:00, 29.31it/s]


Epoch 80 | Loss: 1.212652 | Val: 1.190268 | LR: 0.000045


Epoch 81: 100%|██████████| 892/892 [00:30<00:00, 29.03it/s]


Epoch 81 | Loss: 1.213359 | Val: 1.188295 | LR: 0.000044


Epoch 82: 100%|██████████| 892/892 [00:30<00:00, 29.13it/s]


Epoch 82 | Loss: 1.214936 | Val: 1.194127 | LR: 0.000044


Epoch 83: 100%|██████████| 892/892 [00:30<00:00, 29.22it/s]


Epoch 83 | Loss: 1.210658 | Val: 1.181614 | LR: 0.000043


Epoch 84: 100%|██████████| 892/892 [00:30<00:00, 29.24it/s]


Epoch 84 | Loss: 1.212745 | Val: 1.190595 | LR: 0.000043


Epoch 85: 100%|██████████| 892/892 [00:30<00:00, 29.18it/s]


Epoch 85 | Loss: 1.210459 | Val: 1.194729 | LR: 0.000043


Epoch 86: 100%|██████████| 892/892 [00:31<00:00, 28.62it/s]


Epoch 86 | Loss: 1.212757 | Val: 1.195976 | LR: 0.000042


Epoch 87: 100%|██████████| 892/892 [00:30<00:00, 28.99it/s]


Epoch 87 | Loss: 1.210088 | Val: 1.187220 | LR: 0.000042


Epoch 88: 100%|██████████| 892/892 [00:30<00:00, 29.08it/s]


Epoch 88 | Loss: 1.213235 | Val: 1.173456 | LR: 0.000041


Epoch 89: 100%|██████████| 892/892 [00:30<00:00, 28.99it/s]


Epoch 89 | Loss: 1.212570 | Val: 1.184217 | LR: 0.000041


Epoch 90: 100%|██████████| 892/892 [00:30<00:00, 29.00it/s]


Epoch 90 | Loss: 1.212962 | Val: 1.184105 | LR: 0.000040


Epoch 91: 100%|██████████| 892/892 [00:31<00:00, 28.72it/s]


Epoch 91 | Loss: 1.213886 | Val: 1.186185 | LR: 0.000040


Epoch 92: 100%|██████████| 892/892 [00:30<00:00, 29.08it/s]


Epoch 92 | Loss: 1.212179 | Val: 1.185874 | LR: 0.000040


Epoch 93: 100%|██████████| 892/892 [00:30<00:00, 28.86it/s]


Epoch 93 | Loss: 1.211259 | Val: 1.188422 | LR: 0.000039


Epoch 94: 100%|██████████| 892/892 [00:30<00:00, 28.91it/s]


Epoch 94 | Loss: 1.215013 | Val: 1.186665 | LR: 0.000039


Epoch 95: 100%|██████████| 892/892 [00:30<00:00, 28.80it/s]


Epoch 95 | Loss: 1.209037 | Val: 1.183390 | LR: 0.000038


Epoch 96: 100%|██████████| 892/892 [00:31<00:00, 28.54it/s]


Epoch 96 | Loss: 1.208278 | Val: 1.185150 | LR: 0.000038


Epoch 97: 100%|██████████| 892/892 [00:30<00:00, 28.97it/s]


Epoch 97 | Loss: 1.209459 | Val: 1.181749 | LR: 0.000038


Epoch 98: 100%|██████████| 892/892 [00:30<00:00, 28.92it/s]


Epoch 98 | Loss: 1.211965 | Val: 1.183165 | LR: 0.000037


Epoch 99: 100%|██████████| 892/892 [00:30<00:00, 29.00it/s]


Epoch 99 | Loss: 1.212304 | Val: 1.187907 | LR: 0.000037


Epoch 100: 100%|██████████| 892/892 [00:30<00:00, 28.84it/s]


Epoch 100 | Loss: 1.211141 | Val: 1.188702 | LR: 0.000037


0,1
epoch,▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇█████
learning_rate,█████▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁
train_loss,█▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▇▅▄▃▃▃▃▄▃▃▃▂▃▂▃▂▂▂▂▃▂▂▃▃▂▁▃▂▂▃▂▃▁▃▂▂▂▂▂

0,1
epoch,100.0
learning_rate,4e-05
train_loss,1.21114
val_loss,1.1887
