## Normalization

In [1]:
import h5py
import torch
import numpy as np
import matplotlib.pyplot as plt

# with h5py.File(f'data_large/Burgers_train_100000_default.h5', 'r') as f:
#     # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:10000, :131], dtype=torch.float32, device=cfg.device)
#     traj = torch.tensor(f['train']['pde_140-256'][:1000, :131], dtype=torch.float32)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
import hydra
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

hydra.initialize(config_path="cfg_time_batch", version_base=None)
cfg = hydra.compose(config_name="config", overrides=["task=KdV", "nt=131", "initial_datasize=64", "model.n_modes=[256]"])

from utils import set_seed
set_seed(100)

In [3]:
def get_gradient_norm(model):
    total_norm = 0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5

In [4]:
with h5py.File(cfg.dataset.train_path, 'r') as f:
    # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:10000, :131], dtype=torch.float32, device=cfg.device)
    traj = torch.tensor(f['train']['pde_140-256'][:1000, :131], dtype=torch.float32)

mean = 0
std = 1
print(f'Mean: {mean}, Std: {std}')
traj = (traj - mean) / std

from neuralop.models import FNO

unrolling = cfg.train.unrolling
nt = cfg.nt
ensemble_size = cfg.ensemble_size
num_acquire = cfg.num_acquire
device = cfg.device
epochs = cfg.train.epochs
lr = cfg.train.lr
batch_size = cfg.train.batch_size
initial_datasize = cfg.initial_datasize



def train(Y, train_nts, **kwargs):
    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    inputs = []
    outputs = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            inputs.append(Y[b,t])
            outputs.append(Y[b, t+1])
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)

    dataset = torch.utils.data.TensorDataset(inputs, outputs)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in tqdm(range(epochs)):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        max_grad_norm = 0
        for x, y in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()

            current_grad_norm = get_gradient_norm(model)
            max_grad_norm = max(max_grad_norm, current_grad_norm)

            optimizer.step()
            total_loss += loss.item()
        print(f'Max grad norm: {max_grad_norm}')
        scheduler.step()
        # wandb.log({f'train/loss_{acquire_step}': total_loss})
    return model

timestep = (traj.shape[1] - 1) // (nt - 1) # 10
Y = traj[:,0::timestep]
train_nts = torch.ones(Y.shape[0], device=device, dtype=torch.int64)
train_nts[:initial_datasize] = nt

model = train(Y, train_nts)

from utils import torch_expand
from eval_utils import compute_metrics

test_Y = traj[initial_datasize:1000, 0::timestep]

model.eval()

preds = []
preds.append(test_Y[:,0:1])

with torch.no_grad():
    for t in range(nt-1):
        X = preds[-1].to(device)
        pred = model(X) # batch x 1 x 256
        preds.append(pred.cpu())

preds = torch.cat(preds, dim=1) # batch x nt x 256


metrics = compute_metrics(test_Y, preds, d=2, device=device, reduction=True)

print(metrics)

for scale in [1.0]:
    print(scale)

    with h5py.File(cfg.dataset.train_path, 'r') as f:
        # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:10000, :131], dtype=torch.float32, device=cfg.device)
        traj = torch.tensor(f['train']['pde_140-256'][:1000, :131], dtype=torch.float32)

    mean = traj[:32].mean()
    std = traj[:32].std()
    print(f'Mean: {mean}, Std: {std}')
    traj = (traj - mean) / std * scale

    from neuralop.models import FNO

    unrolling = cfg.train.unrolling
    nt = cfg.nt
    ensemble_size = cfg.ensemble_size
    num_acquire = cfg.num_acquire
    device = cfg.device
    epochs = cfg.train.epochs
    lr = cfg.train.lr
    batch_size = cfg.train.batch_size
    initial_datasize = cfg.initial_datasize


    def train(Y, train_nts, **kwargs):
        model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                    in_channels=1, out_channels=1)

        model = model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
        criterion = torch.nn.MSELoss()

        inputs = []
        outputs = []
        for b in range(Y.shape[0]):
            for t in range(train_nts[b].item()-1):
                inputs.append(Y[b,t])
                outputs.append(Y[b, t+1])
        inputs = torch.stack(inputs, dim=0).unsqueeze(1)
        outputs = torch.stack(outputs, dim=0).unsqueeze(1)

        dataset = torch.utils.data.TensorDataset(inputs, outputs)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        model.train()
        for epoch in tqdm(range(epochs)):
            model.train()
            # max_unrolling = epoch if epoch <= unrolling else unrolling
            # unrolling_list = [r for r in range(max_unrolling + 1)]

            max_grad_norm = 0
            total_loss = 0
            for x, y in dataloader:
                optimizer.zero_grad()
                x, y = x.to(device), y.to(device)
                
                pred = model(x)
                loss = criterion(pred, y)

                # loss = torch.sqrt(loss)
                loss.backward()

                current_grad_norm = get_gradient_norm(model)
                max_grad_norm = max(max_grad_norm, current_grad_norm)

                optimizer.step()
                total_loss += loss.item()
            print(f'Max grad norm: {max_grad_norm}')
            scheduler.step()
            # wandb.log({f'train/loss_{acquire_step}': total_loss})
        return model
    
    timestep = (traj.shape[1] - 1) // (nt - 1) # 10
    Y = traj[:,0::timestep]
    train_nts = torch.ones(Y.shape[0], device=device, dtype=torch.int64)
    train_nts[:initial_datasize] = nt

    model = train(Y, train_nts)

    from utils import torch_expand
    from eval_utils import compute_metrics

    test_Y = traj[initial_datasize:1000, 0::timestep]

    model.eval()

    preds = []
    preds.append(test_Y[:,0:1])

    with torch.no_grad():
        for t in range(nt-1):
            X = preds[-1].to(device)
            pred = model(X) # batch x 1 x 256
            preds.append(pred.cpu())

    preds = torch.cat(preds, dim=1) # batch x nt x 256


    metrics = compute_metrics(test_Y, preds, d=2, device=device, reduction=True)

    print(metrics)

# print('max-min scaling')

# with h5py.File(cfg.dataset.train_path, 'r') as f:
#     # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:10000, :131], dtype=torch.float32, device=cfg.device)
#     traj = torch.tensor(f['train']['pde_140-256'][:1000, :131], dtype=torch.float32)

# max = traj.max()
# min = traj.min()
# print(f'Max: {max}, Min: {min}')
# traj = (traj - (max+min/2)) / (max - min)

# from neuralop.models import FNO

# unrolling = cfg.train.unrolling
# nt = cfg.nt
# ensemble_size = cfg.ensemble_size
# num_acquire = cfg.num_acquire
# device = cfg.device
# epochs = cfg.train.epochs
# lr = cfg.train.lr
# batch_size = cfg.train.batch_size
# initial_datasize = cfg.initial_datasize

# def train(Y, train_nts, **kwargs):
#     model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
#                 in_channels=1, out_channels=1)

#     model = model.to(device)

#     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
#     criterion = torch.nn.MSELoss()

#     inputs = []
#     outputs = []
#     for b in range(Y.shape[0]):
#         for t in range(train_nts[b].item()-1):
#             inputs.append(Y[b,t])
#             outputs.append(Y[b, t+1])
#     inputs = torch.stack(inputs, dim=0).unsqueeze(1)
#     outputs = torch.stack(outputs, dim=0).unsqueeze(1)

#     dataset = torch.utils.data.TensorDataset(inputs, outputs)
#     dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

#     model.train()
#     for epoch in tqdm(range(epochs)):
#         model.train()
#         # max_unrolling = epoch if epoch <= unrolling else unrolling
#         # unrolling_list = [r for r in range(max_unrolling + 1)]

#         total_loss = 0
#         for x, y in dataloader:
#             optimizer.zero_grad()
#             x, y = x.to(device), y.to(device)
            
#             pred = model(x)
#             loss = criterion(pred, y)

#             # loss = torch.sqrt(loss)
#             loss.backward()
#             optimizer.step()
#             total_loss += loss.item()
#         scheduler.step()
#         # wandb.log({f'train/loss_{acquire_step}': total_loss})
#     return model

# timestep = (traj.shape[1] - 1) // (nt - 1) # 10
# Y = traj[:,0::timestep]
# train_nts = torch.ones(Y.shape[0], device=device, dtype=torch.int64)
# train_nts[:initial_datasize] = nt

# model = train(Y, train_nts)

# from utils import torch_expand
# from eval_utils import compute_metrics

# test_Y = traj[initial_datasize:1000, 0::timestep]

# model.eval()

# preds = []
# preds.append(test_Y[:,0:1])

# with torch.no_grad():
#     for t in range(nt-1):
#         X = preds[-1].to(device)
#         pred = model(X) # batch x 1 x 256
#         preds.append(pred.cpu())

# preds = torch.cat(preds, dim=1) # batch x nt x 256


# metrics = compute_metrics(test_Y, preds, d=2, device=device, reduction=True)

# print(metrics)

Mean: 0, Std: 1


  1%|          | 1/100 [00:08<14:18,  8.67s/it]

Max grad norm: 1.9372691859292854


  2%|▏         | 2/100 [00:13<10:38,  6.51s/it]

Max grad norm: 0.01978146744480674


  3%|▎         | 3/100 [00:18<09:29,  5.87s/it]

Max grad norm: 0.014984136894620944


  4%|▍         | 4/100 [00:23<08:49,  5.51s/it]

Max grad norm: 0.05591404830886082


  5%|▌         | 5/100 [00:28<08:26,  5.33s/it]

Max grad norm: 0.05423826927941066


  6%|▌         | 6/100 [00:33<08:04,  5.15s/it]

Max grad norm: 0.214407998725462


  7%|▋         | 7/100 [00:38<07:56,  5.13s/it]

Max grad norm: 0.06653102346369237


  8%|▊         | 8/100 [00:43<07:53,  5.15s/it]

Max grad norm: 0.48311620471868616


  9%|▉         | 9/100 [00:48<07:46,  5.12s/it]

Max grad norm: 0.045839499819597535


 10%|█         | 10/100 [00:53<07:38,  5.09s/it]

Max grad norm: 0.042863223718599686


 11%|█         | 11/100 [00:59<07:33,  5.10s/it]

Max grad norm: 0.031956676202047646


 12%|█▏        | 12/100 [01:03<07:21,  5.02s/it]

Max grad norm: 0.07561787700143595


 13%|█▎        | 13/100 [01:08<07:18,  5.04s/it]

Max grad norm: 0.135049171632864


 14%|█▍        | 14/100 [01:13<07:12,  5.03s/it]

Max grad norm: 0.0888101835038506


 15%|█▌        | 15/100 [01:19<07:08,  5.05s/it]

Max grad norm: 0.16501739697052753


 16%|█▌        | 16/100 [01:23<07:00,  5.01s/it]

Max grad norm: 0.04480373683599405


 17%|█▋        | 17/100 [01:29<06:58,  5.04s/it]

Max grad norm: 0.08452963510006609


 18%|█▊        | 18/100 [01:34<06:51,  5.02s/it]

Max grad norm: 0.27993157279962977


 19%|█▉        | 19/100 [01:39<06:46,  5.02s/it]

Max grad norm: 0.030856712817134497


 20%|██        | 20/100 [01:44<06:44,  5.06s/it]

Max grad norm: 0.37339402292769763


 21%|██        | 21/100 [01:49<06:39,  5.06s/it]

Max grad norm: 0.06765344037179168


 22%|██▏       | 22/100 [01:54<06:34,  5.05s/it]

Max grad norm: 0.019594754519276465


 23%|██▎       | 23/100 [01:59<06:30,  5.07s/it]

Max grad norm: 0.017967959338338394


 24%|██▍       | 24/100 [02:04<06:22,  5.04s/it]

Max grad norm: 0.03556798735702599


 25%|██▌       | 25/100 [02:09<06:16,  5.02s/it]

Max grad norm: 0.15029430050443215


 26%|██▌       | 26/100 [02:14<06:05,  4.93s/it]

Max grad norm: 0.030058526453031375


 27%|██▋       | 27/100 [02:19<06:08,  5.04s/it]

Max grad norm: 0.11841083543583052


 28%|██▊       | 28/100 [02:24<06:04,  5.07s/it]

Max grad norm: 0.29175408345230486


 29%|██▉       | 29/100 [02:29<06:02,  5.11s/it]

Max grad norm: 0.02367720015832825


 30%|███       | 30/100 [02:34<05:56,  5.10s/it]

Max grad norm: 0.032469536884729436


 31%|███       | 31/100 [02:39<05:52,  5.11s/it]

Max grad norm: 0.05292426155414894


 32%|███▏      | 32/100 [02:45<05:50,  5.16s/it]

Max grad norm: 0.072271883157224


 33%|███▎      | 33/100 [02:50<05:45,  5.15s/it]

Max grad norm: 0.06223802248204214


 34%|███▍      | 34/100 [02:55<05:40,  5.16s/it]

Max grad norm: 0.056187681992179876


 35%|███▌      | 35/100 [03:00<05:33,  5.13s/it]

Max grad norm: 0.12317293371787347


 36%|███▌      | 36/100 [03:05<05:21,  5.02s/it]

Max grad norm: 0.06251316852164805


 37%|███▋      | 37/100 [03:19<08:07,  7.74s/it]

Max grad norm: 0.02441159267209756


 38%|███▊      | 38/100 [03:35<10:38, 10.30s/it]

Max grad norm: 0.06557666537413327


 39%|███▉      | 39/100 [03:45<10:20, 10.17s/it]

Max grad norm: 0.04791944108551242


 40%|████      | 40/100 [03:53<09:29,  9.50s/it]

Max grad norm: 0.10340184293260146


 41%|████      | 41/100 [04:00<08:44,  8.89s/it]

Max grad norm: 0.029807421947951215


 42%|████▏     | 42/100 [04:08<08:13,  8.51s/it]

Max grad norm: 0.02564537751695607


 43%|████▎     | 43/100 [04:17<08:18,  8.75s/it]

Max grad norm: 0.07585774947957385


 44%|████▍     | 44/100 [04:29<09:01,  9.67s/it]

Max grad norm: 0.026209226932724122


 45%|████▌     | 45/100 [04:34<07:33,  8.25s/it]

Max grad norm: 0.08056509358099304


 46%|████▌     | 46/100 [04:39<06:29,  7.21s/it]

Max grad norm: 0.025071134990912493


 47%|████▋     | 47/100 [04:44<05:45,  6.52s/it]

Max grad norm: 0.0844296026778149


 48%|████▊     | 48/100 [04:50<05:26,  6.28s/it]

Max grad norm: 0.0565445702516758


 49%|████▉     | 49/100 [04:55<04:59,  5.87s/it]

Max grad norm: 0.015854818184651297


 50%|█████     | 50/100 [05:00<04:43,  5.67s/it]

Max grad norm: 0.030553735138948282


 51%|█████     | 51/100 [05:06<04:48,  5.89s/it]

Max grad norm: 0.024250711123398636


 52%|█████▏    | 52/100 [05:12<04:49,  6.04s/it]

Max grad norm: 0.07918268519686504


 53%|█████▎    | 53/100 [05:18<04:35,  5.87s/it]

Max grad norm: 0.01513739905273436


 54%|█████▍    | 54/100 [05:24<04:28,  5.83s/it]

Max grad norm: 0.016715269024706925


 55%|█████▌    | 55/100 [05:29<04:19,  5.77s/it]

Max grad norm: 0.015077278676680655


 56%|█████▌    | 56/100 [05:35<04:10,  5.70s/it]

Max grad norm: 0.08487657438263696


 57%|█████▋    | 57/100 [05:41<04:05,  5.71s/it]

Max grad norm: 0.007269667390216463


 58%|█████▊    | 58/100 [05:46<03:58,  5.68s/it]

Max grad norm: 0.033310663809657554


 59%|█████▉    | 59/100 [05:52<03:50,  5.63s/it]

Max grad norm: 0.04218976527179012


 60%|██████    | 60/100 [05:57<03:44,  5.61s/it]

Max grad norm: 0.02387364565617268


 61%|██████    | 61/100 [06:03<03:37,  5.59s/it]

Max grad norm: 0.01578277226295186


 62%|██████▏   | 62/100 [06:08<03:32,  5.59s/it]

Max grad norm: 0.06137846642674985


 63%|██████▎   | 63/100 [06:14<03:26,  5.59s/it]

Max grad norm: 0.01485570430023141


 64%|██████▍   | 64/100 [06:20<03:21,  5.60s/it]

Max grad norm: 0.00860958616391531


 65%|██████▌   | 65/100 [06:25<03:15,  5.59s/it]

Max grad norm: 0.026104374761376233


 66%|██████▌   | 66/100 [06:31<03:08,  5.55s/it]

Max grad norm: 0.04175770431744081


 67%|██████▋   | 67/100 [06:36<03:04,  5.58s/it]

Max grad norm: 0.035786214656860625


 68%|██████▊   | 68/100 [06:42<02:57,  5.56s/it]

Max grad norm: 0.008364922512046544


 69%|██████▉   | 69/100 [06:47<02:52,  5.56s/it]

Max grad norm: 0.019579685496106267


 70%|███████   | 70/100 [06:53<02:46,  5.55s/it]

Max grad norm: 0.00848364562331844


 71%|███████   | 71/100 [06:59<02:43,  5.64s/it]

Max grad norm: 0.011675281646013565


 72%|███████▏  | 72/100 [07:04<02:38,  5.65s/it]

Max grad norm: 0.009690139984520825


 73%|███████▎  | 73/100 [07:10<02:31,  5.63s/it]

Max grad norm: 0.019400625968165223


 74%|███████▍  | 74/100 [07:16<02:25,  5.61s/it]

Max grad norm: 0.02562038583822973


 75%|███████▌  | 75/100 [07:23<02:34,  6.19s/it]

Max grad norm: 0.004684328217505104


 76%|███████▌  | 76/100 [07:31<02:40,  6.71s/it]

Max grad norm: 0.0051345978886867824


 77%|███████▋  | 77/100 [07:38<02:39,  6.93s/it]

Max grad norm: 0.008267909147896005


 78%|███████▊  | 78/100 [07:45<02:32,  6.93s/it]

Max grad norm: 0.006767265093877148


 79%|███████▉  | 79/100 [07:50<02:11,  6.28s/it]

Max grad norm: 0.007409679919744401


 80%|████████  | 80/100 [07:55<01:57,  5.85s/it]

Max grad norm: 0.013434458206322624


 81%|████████  | 81/100 [08:00<01:46,  5.63s/it]

Max grad norm: 0.008594680375619137


 82%|████████▏ | 82/100 [08:05<01:39,  5.51s/it]

Max grad norm: 0.0033638477128760114


 83%|████████▎ | 83/100 [08:10<01:30,  5.34s/it]

Max grad norm: 0.0026143614659517337


 84%|████████▍ | 84/100 [08:16<01:28,  5.56s/it]

Max grad norm: 0.005132322380641839


 85%|████████▌ | 85/100 [08:35<02:24,  9.61s/it]

Max grad norm: 0.0033263937989476584


 86%|████████▌ | 86/100 [08:54<02:54, 12.44s/it]

Max grad norm: 0.001837289739669713


 87%|████████▋ | 87/100 [09:14<03:09, 14.54s/it]

Max grad norm: 0.002112024898393931


 88%|████████▊ | 88/100 [09:26<02:44, 13.72s/it]

Max grad norm: 0.002735176515699323


 89%|████████▉ | 89/100 [09:31<02:03, 11.18s/it]

Max grad norm: 0.001607826539267177


 90%|█████████ | 90/100 [09:36<01:33,  9.32s/it]

Max grad norm: 0.0016242750191573494


 91%|█████████ | 91/100 [09:41<01:12,  8.05s/it]

Max grad norm: 0.0013669851218710584


 92%|█████████▏| 92/100 [09:46<00:57,  7.16s/it]

Max grad norm: 0.0014540473715854642


 93%|█████████▎| 93/100 [09:51<00:45,  6.50s/it]

Max grad norm: 0.001132437593780564


 94%|█████████▍| 94/100 [09:56<00:36,  6.08s/it]

Max grad norm: 0.0007836810047572255


 95%|█████████▌| 95/100 [10:01<00:28,  5.74s/it]

Max grad norm: 0.0007409457363985331


 96%|█████████▌| 96/100 [10:06<00:22,  5.56s/it]

Max grad norm: 0.0007060764841936179


 97%|█████████▋| 97/100 [10:11<00:16,  5.44s/it]

Max grad norm: 0.0005933232733243088


 98%|█████████▊| 98/100 [10:16<00:10,  5.21s/it]

Max grad norm: 0.00043793743390348704


 99%|█████████▉| 99/100 [10:21<00:05,  5.13s/it]

Max grad norm: 0.00042571819294049746


100%|██████████| 100/100 [10:26<00:00,  6.27s/it]

Max grad norm: 0.0003505846233313657





(tensor(0.1872), tensor(0.1613), tensor(2295.6729))
1.0
Mean: 1.1197198723778001e-10, Std: 0.6562381386756897


  1%|          | 1/100 [00:05<08:55,  5.40s/it]

Max grad norm: 4.256157914510783


  2%|▏         | 2/100 [00:10<08:32,  5.23s/it]

Max grad norm: 0.04887448605252577


  3%|▎         | 3/100 [00:15<08:11,  5.07s/it]

Max grad norm: 0.04748922185507854


  4%|▍         | 4/100 [00:24<10:53,  6.81s/it]

Max grad norm: 0.08701323253373265


  5%|▌         | 5/100 [00:42<16:42, 10.56s/it]

Max grad norm: 0.29616439623860086


  6%|▌         | 6/100 [00:56<18:37, 11.89s/it]

Max grad norm: 0.4530262841811554


  7%|▋         | 7/100 [01:04<16:38, 10.74s/it]

Max grad norm: 0.10379023430463237


  8%|▊         | 8/100 [01:12<14:58,  9.76s/it]

Max grad norm: 1.1280209147383236


  9%|▉         | 9/100 [01:20<13:44,  9.07s/it]

Max grad norm: 0.08108474282219884


 10%|█         | 10/100 [01:27<12:50,  8.56s/it]

Max grad norm: 0.07616670913054344


 11%|█         | 11/100 [01:36<12:57,  8.74s/it]

Max grad norm: 0.18936788709435357


 12%|█▏        | 12/100 [01:41<11:09,  7.60s/it]

Max grad norm: 0.35742962742152273


 13%|█▎        | 13/100 [01:46<09:49,  6.77s/it]

Max grad norm: 0.3602948899508684


 14%|█▍        | 14/100 [01:51<08:45,  6.12s/it]

Max grad norm: 0.13409822968748958


 15%|█▌        | 15/100 [01:56<08:08,  5.74s/it]

Max grad norm: 0.4444538119444055


 16%|█▌        | 16/100 [02:00<07:37,  5.45s/it]

Max grad norm: 0.10036447678068007


 17%|█▋        | 17/100 [02:07<08:12,  5.93s/it]

Max grad norm: 0.45236968414130735


 18%|█▊        | 18/100 [02:13<07:56,  5.81s/it]

Max grad norm: 0.4490207449460348


 19%|█▉        | 19/100 [02:18<07:34,  5.61s/it]

Max grad norm: 0.07897630425881827


 20%|██        | 20/100 [02:23<07:18,  5.48s/it]

Max grad norm: 0.40705674912011125


 21%|██        | 21/100 [02:28<07:02,  5.35s/it]

Max grad norm: 0.8284339245988158


 22%|██▏       | 22/100 [02:33<06:53,  5.30s/it]

Max grad norm: 0.20115067092627473


 23%|██▎       | 23/100 [02:38<06:41,  5.21s/it]

Max grad norm: 0.06819468787028066


 24%|██▍       | 24/100 [02:44<06:32,  5.17s/it]

Max grad norm: 0.07366759952124406


 25%|██▌       | 25/100 [02:49<06:28,  5.19s/it]

Max grad norm: 0.06144652717207586


 26%|██▌       | 26/100 [02:54<06:21,  5.16s/it]

Max grad norm: 0.4629589064613563


 27%|██▋       | 27/100 [02:59<06:10,  5.08s/it]

Max grad norm: 0.4610233129800665


 28%|██▊       | 28/100 [03:04<06:08,  5.11s/it]

Max grad norm: 0.07784615933959181


 29%|██▉       | 29/100 [03:09<05:57,  5.03s/it]

Max grad norm: 0.05764960874934168


 30%|███       | 30/100 [03:14<05:54,  5.06s/it]

Max grad norm: 0.38827768675407354


 31%|███       | 31/100 [03:19<05:48,  5.05s/it]

Max grad norm: 0.060656654655682894


 32%|███▏      | 32/100 [03:24<05:47,  5.11s/it]

Max grad norm: 0.04516944593862201


 33%|███▎      | 33/100 [03:29<05:39,  5.06s/it]

Max grad norm: 0.08388497115492616


 34%|███▍      | 34/100 [03:34<05:35,  5.09s/it]

Max grad norm: 0.12709243068420967


 35%|███▌      | 35/100 [03:39<05:32,  5.11s/it]

Max grad norm: 0.43460483684064893


 36%|███▌      | 36/100 [03:45<05:29,  5.14s/it]

Max grad norm: 0.06473773531478028


 37%|███▋      | 37/100 [03:50<05:26,  5.18s/it]

Max grad norm: 0.04759032826321938


 38%|███▊      | 38/100 [03:55<05:17,  5.12s/it]

Max grad norm: 0.06084918105309003


 39%|███▉      | 39/100 [04:06<07:07,  7.01s/it]

Max grad norm: 0.1510607614453836


 40%|████      | 40/100 [04:19<08:38,  8.64s/it]

Max grad norm: 0.16525704288220147


 41%|████      | 41/100 [04:28<08:34,  8.73s/it]

Max grad norm: 0.10628991108345426


 42%|████▏     | 42/100 [04:41<09:51, 10.20s/it]

Max grad norm: 0.21619512376600447


 43%|████▎     | 43/100 [04:49<08:54,  9.37s/it]

Max grad norm: 0.05065167697459052


 44%|████▍     | 44/100 [04:54<07:31,  8.06s/it]

Max grad norm: 0.041640625736975925


 45%|████▌     | 45/100 [04:59<06:34,  7.17s/it]

Max grad norm: 0.24138328203161577


 46%|████▌     | 46/100 [05:04<05:55,  6.59s/it]

Max grad norm: 0.10743696896283357


 47%|████▋     | 47/100 [05:09<05:25,  6.15s/it]

Max grad norm: 0.05670951186746934


 48%|████▊     | 48/100 [05:14<05:03,  5.84s/it]

Max grad norm: 0.12525440909074828


 49%|████▉     | 49/100 [05:20<04:48,  5.65s/it]

Max grad norm: 0.10584748344604546


 50%|█████     | 50/100 [05:25<04:35,  5.51s/it]

Max grad norm: 0.2562857377116871


 51%|█████     | 51/100 [05:30<04:23,  5.37s/it]

Max grad norm: 0.1343118261254936


 52%|█████▏    | 52/100 [05:35<04:14,  5.31s/it]

Max grad norm: 0.03258662965451515


 53%|█████▎    | 53/100 [05:40<04:08,  5.28s/it]

Max grad norm: 0.047811346424837406


 54%|█████▍    | 54/100 [05:45<04:00,  5.22s/it]

Max grad norm: 0.09637252012310453


 55%|█████▌    | 55/100 [05:50<03:53,  5.19s/it]

Max grad norm: 0.08664162161834389


 56%|█████▌    | 56/100 [05:56<03:48,  5.18s/it]

Max grad norm: 0.13171159053935624


 57%|█████▋    | 57/100 [06:01<03:41,  5.15s/it]

Max grad norm: 0.11761097239210437


 58%|█████▊    | 58/100 [06:05<03:32,  5.07s/it]

Max grad norm: 0.032710600989275265


 59%|█████▉    | 59/100 [06:10<03:26,  5.03s/it]

Max grad norm: 0.09813424884250205


 60%|██████    | 60/100 [06:16<03:22,  5.06s/it]

Max grad norm: 0.07628071482700584


 61%|██████    | 61/100 [06:21<03:16,  5.05s/it]

Max grad norm: 0.06114265948541582


 62%|██████▏   | 62/100 [06:26<03:14,  5.12s/it]

Max grad norm: 0.05608935165583979


 63%|██████▎   | 63/100 [06:31<03:09,  5.13s/it]

Max grad norm: 0.03612451196551061


 64%|██████▍   | 64/100 [06:36<03:04,  5.12s/it]

Max grad norm: 0.082937061415636


 65%|██████▌   | 65/100 [06:41<03:00,  5.15s/it]

Max grad norm: 0.04739477992967422


 66%|██████▌   | 66/100 [06:47<02:55,  5.16s/it]

Max grad norm: 0.03435327030802991


 67%|██████▋   | 67/100 [06:52<02:50,  5.16s/it]

Max grad norm: 0.05405415086905847


 68%|██████▊   | 68/100 [06:57<02:44,  5.15s/it]

Max grad norm: 0.06478615914186633


 69%|██████▉   | 69/100 [07:02<02:42,  5.25s/it]

Max grad norm: 0.012307267136977995


 70%|███████   | 70/100 [07:17<04:04,  8.13s/it]

Max grad norm: 0.04069130571516368


 71%|███████   | 71/100 [07:25<03:56,  8.15s/it]

Max grad norm: 0.036127104522375324


 72%|███████▏  | 72/100 [07:36<04:11,  8.99s/it]

Max grad norm: 0.04743086507143578


 73%|███████▎  | 73/100 [07:48<04:21,  9.68s/it]

Max grad norm: 0.054633215360061646


 74%|███████▍  | 74/100 [07:53<03:35,  8.29s/it]

Max grad norm: 0.021738305569319688


 75%|███████▌  | 75/100 [07:58<03:04,  7.40s/it]

Max grad norm: 0.00848381334778691


 76%|███████▌  | 76/100 [08:03<02:40,  6.69s/it]

Max grad norm: 0.040986104248217435


 77%|███████▋  | 77/100 [08:08<02:22,  6.20s/it]

Max grad norm: 0.015998665900479467


 78%|███████▊  | 78/100 [08:13<02:07,  5.81s/it]

Max grad norm: 0.02036039789441381


 79%|███████▉  | 79/100 [08:18<01:57,  5.60s/it]

Max grad norm: 0.014133519510734046


 80%|████████  | 80/100 [08:23<01:48,  5.44s/it]

Max grad norm: 0.03798546728793014


 81%|████████  | 81/100 [08:28<01:41,  5.33s/it]

Max grad norm: 0.007864420957177641


 82%|████████▏ | 82/100 [08:33<01:34,  5.27s/it]

Max grad norm: 0.0072016260998511465


 83%|████████▎ | 83/100 [08:38<01:28,  5.20s/it]

Max grad norm: 0.005742670259316333


 84%|████████▍ | 84/100 [08:44<01:23,  5.20s/it]

Max grad norm: 0.007852563873017315


 85%|████████▌ | 85/100 [08:49<01:17,  5.19s/it]

Max grad norm: 0.010744418844424794


 86%|████████▌ | 86/100 [08:54<01:12,  5.21s/it]

Max grad norm: 0.0051002879000330474


 87%|████████▋ | 87/100 [08:59<01:06,  5.12s/it]

Max grad norm: 0.004964074669155993


 88%|████████▊ | 88/100 [09:04<01:00,  5.05s/it]

Max grad norm: 0.0071591086645177715


 89%|████████▉ | 89/100 [09:09<00:55,  5.03s/it]

Max grad norm: 0.005255975712672149


 90%|█████████ | 90/100 [09:14<00:50,  5.01s/it]

Max grad norm: 0.004734929264998897


 91%|█████████ | 91/100 [09:19<00:45,  5.04s/it]

Max grad norm: 0.0041228381865023814


 92%|█████████▏| 92/100 [09:24<00:40,  5.11s/it]

Max grad norm: 0.0030328419874379084


 93%|█████████▎| 93/100 [09:29<00:35,  5.09s/it]

Max grad norm: 0.0028291909194047967


 94%|█████████▍| 94/100 [09:34<00:30,  5.13s/it]

Max grad norm: 0.0028004234499870943


 95%|█████████▌| 95/100 [09:39<00:25,  5.06s/it]

Max grad norm: 0.0018222410894377463


 96%|█████████▌| 96/100 [09:44<00:20,  5.04s/it]

Max grad norm: 0.0018118090309597475


 97%|█████████▋| 97/100 [09:52<00:17,  5.94s/it]

Max grad norm: 0.0014224401410093538


 98%|█████████▊| 98/100 [09:58<00:11,  5.74s/it]

Max grad norm: 0.0010543434166877114


 99%|█████████▉| 99/100 [10:03<00:05,  5.56s/it]

Max grad norm: 0.0009245817269142255


100%|██████████| 100/100 [10:08<00:00,  6.08s/it]

Max grad norm: 0.0006725323826233204





(tensor(0.3570), tensor(0.2066), tensor(7039.2114))
