## Pushforward

In [1]:
import h5py
import torch
import numpy as np
import matplotlib.pyplot as plt

# with h5py.File(f'data_large/Burgers_train_100000_default.h5', 'r') as f:
#     # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:10000, :131], dtype=torch.float32, device=cfg.device)
#     traj = torch.tensor(f['train']['pde_140-256'][:1000, :131], dtype=torch.float32)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
import hydra
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

hydra.initialize(config_path="cfg_time_batch", version_base=None)
cfg = hydra.compose(config_name="config", overrides=["task=KdV", "nt=131", "initial_datasize=64"])

from utils import set_seed
set_seed(100)

In [3]:
from itertools import islice
from utils import torch_expand
from eval_utils import compute_metrics

with h5py.File(cfg.dataset.train_path, 'r') as f:
    # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:10000, :131], dtype=torch.float32, device=cfg.device)
    traj = torch.tensor(f['train']['pde_140-256'][:1000, :131], dtype=torch.float32)


mean = traj[:32].mean()
std = traj[:32].std()
print(f'Mean: {mean}, Std: {std}')
traj = (traj - mean) / std


# max = traj.max()
# min = traj.min()
# print(f'Max: {max}, Min: {min}')
# traj = (traj - (max+min/2)) / (max - min)

from neuralop.models import FNO

nt = cfg.nt
ensemble_size = cfg.ensemble_size
num_acquire = cfg.num_acquire
device = cfg.device
epochs = cfg.train.epochs
lr = cfg.train.lr
batch_size = cfg.train.batch_size
initial_datasize = cfg.initial_datasize

def train(Y, train_nts, **kwargs):
    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    inputs = []
    outputs = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            inputs.append(Y[b,t])
            outputs.append(Y[b, t+1])
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)

    dataset = torch.utils.data.TensorDataset(inputs, outputs)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in tqdm(range(epochs)):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        for x, y in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # wandb.log({f'train/loss_{acquire_step}': total_loss})
    return model



def train_gaussian(Y, train_nts, noise_std=0.01, **kwargs):
    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    inputs = []
    outputs = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            inputs.append(Y[b,t])
            outputs.append(Y[b, t+1])
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)

    dataset = torch.utils.data.TensorDataset(inputs, outputs)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in tqdm(range(epochs)):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        for x, y in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                x = x + noise_std * torch.randn_like(x)
            
            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # wandb.log({f'train/loss_{acquire_step}': total_loss})
    return model

def evaluate():
    test_Y = traj[initial_datasize:1000, 0::timestep]

    model.eval()

    preds = []
    preds.append(test_Y[:,0:1])

    with torch.no_grad():
        for t in range(nt-1):
            X = preds[-1].to(device)
            pred = model(X) # batch x 1 x 256
            preds.append(pred.cpu())

    preds = torch.cat(preds, dim=1) # batch x nt x 256


    metrics = compute_metrics(test_Y, preds, d=2, device=device, reduction=True)

    print(metrics)

timestep = (traj.shape[1] - 1) // (nt - 1) # 10
Y = traj[:,0::timestep]
train_nts = torch.ones(Y.shape[0], device=device, dtype=torch.int64)
train_nts[:initial_datasize] = nt

# model = train(Y, train_nts)
# evaluate()


Mean: 1.1197198723778001e-10, Std: 0.6562381386756897


In [5]:


def train_pushforward(Y, train_nts, unrolling, **kwargs):
    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    datasize = (train_nts-1).sum().item()
    average_iter = datasize // batch_size
    

    inputs = []
    outputs = []
    unroll = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            for r in range(unrolling):
                if t+(r+1) < train_nts[b].item():
                    inputs.append(Y[b,t])
                    outputs.append(Y[b, t+(r+1)])
                    unroll.append(r)
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)
    unroll = torch.tensor(unroll)

    dataset_0 = torch.utils.data.TensorDataset(inputs[unroll==0], outputs[unroll==0], unroll[unroll==0])
    dataloader_0 = torch.utils.data.DataLoader(dataset_0, batch_size=batch_size, shuffle=True)

    dataset_1 = torch.utils.data.TensorDataset(inputs, outputs, unroll)
    dataloader_1 = torch.utils.data.DataLoader(dataset_1, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in tqdm(range(epochs)):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        # for x, y, unroll in islice(dataloader, 0, average_iter):
        dataloader = dataloader_0 if epoch < 10 else dataloader_1
        for x, y, unroll in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            for _ in range(unroll.max()):
                with torch.no_grad():
                    x[unroll > 0] = model(x[unroll > 0])
                    unroll[unroll > 0] -= 1

            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # wandb.log({f'train/loss_{acquire_step}': total_loss})
    return model

model = train_pushforward(Y, train_nts, 1)
evaluate()


100%|██████████| 100/100 [08:20<00:00,  5.00s/it]


(tensor(0.3103), tensor(0.1765), tensor(6190.8936))


In [6]:


def train_pushforward(Y, train_nts, unrolling, **kwargs):
    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    datasize = (train_nts-1).sum().item()
    average_iter = datasize // batch_size
    

    inputs = []
    outputs = []
    unroll = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            for r in range(unrolling):
                if t+(r+1) < train_nts[b].item():
                    inputs.append(Y[b,t])
                    outputs.append(Y[b, t+(r+1)])
                    unroll.append(r)
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)
    unroll = torch.tensor(unroll)

    dataset_0 = torch.utils.data.TensorDataset(inputs[unroll==0], outputs[unroll==0], unroll[unroll==0])
    dataloader_0 = torch.utils.data.DataLoader(dataset_0, batch_size=batch_size, shuffle=True)

    dataset_1 = torch.utils.data.TensorDataset(inputs, outputs, unroll)
    dataloader_1 = torch.utils.data.DataLoader(dataset_1, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in tqdm(range(epochs)):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        # for x, y, unroll in islice(dataloader, 0, average_iter):
        dataloader = dataloader_0 if epoch < 3 else dataloader_1
        for x, y, unroll in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            for _ in range(unroll.max()):
                with torch.no_grad():
                    x[unroll > 0] = model(x[unroll > 0])
                    unroll[unroll > 0] -= 1

            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # wandb.log({f'train/loss_{acquire_step}': total_loss})
    return model

model = train_pushforward(Y, train_nts, 1)
evaluate()


100%|██████████| 100/100 [08:29<00:00,  5.09s/it]


(tensor(0.3062), tensor(0.1744), tensor(5336.4575))


In [7]:


def train_pushforward(Y, train_nts, unrolling, **kwargs):
    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    datasize = (train_nts-1).sum().item()
    average_iter = datasize // batch_size
    

    inputs = []
    outputs = []
    unroll = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            for r in range(unrolling):
                if t+(r+1) < train_nts[b].item():
                    inputs.append(Y[b,t])
                    outputs.append(Y[b, t+(r+1)])
                    unroll.append(r)
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)
    unroll = torch.tensor(unroll)

    dataset_0 = torch.utils.data.TensorDataset(inputs[unroll==0], outputs[unroll==0], unroll[unroll==0])
    dataloader_0 = torch.utils.data.DataLoader(dataset_0, batch_size=batch_size, shuffle=True)

    dataset_1 = torch.utils.data.TensorDataset(inputs, outputs, unroll)
    dataloader_1 = torch.utils.data.DataLoader(dataset_1, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in tqdm(range(epochs)):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        # for x, y, unroll in islice(dataloader, 0, average_iter):
        dataloader = dataloader_0 if epoch < 50 else dataloader_1
        for x, y, unroll in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            for _ in range(unroll.max()):
                with torch.no_grad():
                    x[unroll > 0] = model(x[unroll > 0])
                    unroll[unroll > 0] -= 1

            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # wandb.log({f'train/loss_{acquire_step}': total_loss})
    return model

model = train_pushforward(Y, train_nts, 1)
evaluate()


100%|██████████| 100/100 [08:25<00:00,  5.06s/it]


(tensor(0.3025), tensor(0.1737), tensor(5067.5610))


In [4]:

def train_pushforward(Y, train_nts, unrolling, **kwargs):
    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    datasize = (train_nts-1).sum().item()
    average_iter = datasize // batch_size
    

    inputs = []
    outputs = []
    unroll = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            for r in range(unrolling):
                if t+(r+1) < train_nts[b].item():
                    inputs.append(Y[b,t])
                    outputs.append(Y[b, t+(r+1)])
                    unroll.append(r)
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)
    unroll = torch.tensor(unroll)

    dataloaders = []
    for i in range(unrolling+1):
        dataset_i = torch.utils.data.TensorDataset(inputs[unroll<=i], outputs[unroll<=i], unroll[unroll<=i])
        dataloader_i = torch.utils.data.DataLoader(dataset_i, batch_size=batch_size, shuffle=True)
        dataloaders.append(dataloader_i)

    model.train()
    for epoch in tqdm(range(epochs)):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        # for x, y, unroll in islice(dataloader, 0, average_iter):
        dataloader = dataloaders[min(epoch, unrolling)]
        for x, y, unroll in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            for _ in range(unroll.max()):
                with torch.no_grad():
                    x[unroll > 0] = model(x[unroll > 0])
                    unroll[unroll > 0] -= 1

            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # wandb.log({f'train/loss_{acquire_step}': total_loss})
    return model

model = train_pushforward(Y, train_nts, 2)
evaluate()


100%|██████████| 100/100 [19:27<00:00, 11.67s/it]


(tensor(0.3321), tensor(0.1858), tensor(6400.8657))


In [5]:

def train_pushforward(Y, train_nts, unrolling, **kwargs):
    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    datasize = (train_nts-1).sum().item()
    average_iter = datasize // batch_size
    

    inputs = []
    outputs = []
    unroll = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            for r in range(unrolling):
                if t+(r+1) < train_nts[b].item():
                    inputs.append(Y[b,t])
                    outputs.append(Y[b, t+(r+1)])
                    unroll.append(r)
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)
    unroll = torch.tensor(unroll)

    dataloaders = []
    for i in range(unrolling+1):
        dataset_i = torch.utils.data.TensorDataset(inputs[unroll<=i], outputs[unroll<=i], unroll[unroll<=i])
        dataloader_i = torch.utils.data.DataLoader(dataset_i, batch_size=batch_size, shuffle=True)
        dataloaders.append(dataloader_i)

    model.train()
    for epoch in tqdm(range(epochs)):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        # for x, y, unroll in islice(dataloader, 0, average_iter):
        dataloader = dataloaders[min(epoch, unrolling)]
        for x, y, unroll in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            for _ in range(unroll.max()):
                with torch.no_grad():
                    x[unroll > 0] = model(x[unroll > 0])
                    unroll[unroll > 0] -= 1

            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # wandb.log({f'train/loss_{acquire_step}': total_loss})
    return model

model = train_pushforward(Y, train_nts, 3)
evaluate()


100%|██████████| 100/100 [31:43<00:00, 19.03s/it]


(tensor(0.4109), tensor(0.2330), tensor(8768.4072))


In [6]:

def train_pushforward(Y, train_nts, unrolling, **kwargs):
    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    datasize = (train_nts-1).sum().item()
    average_iter = datasize // batch_size
    

    inputs = []
    outputs = []
    unroll = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            for r in range(unrolling):
                if t+(r+1) < train_nts[b].item():
                    inputs.append(Y[b,t])
                    outputs.append(Y[b, t+(r+1)])
                    unroll.append(r)
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)
    unroll = torch.tensor(unroll)

    dataloaders = []
    for i in range(unrolling+1):
        dataset_i = torch.utils.data.TensorDataset(inputs[unroll<=i], outputs[unroll<=i], unroll[unroll<=i])
        dataloader_i = torch.utils.data.DataLoader(dataset_i, batch_size=batch_size, shuffle=True)
        dataloaders.append(dataloader_i)

    model.train()
    for epoch in tqdm(range(epochs)):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        # for x, y, unroll in islice(dataloader, 0, average_iter):
        dataloader = dataloaders[min(epoch, unrolling)]
        for x, y, unroll in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            for _ in range(unroll.max()):
                with torch.no_grad():
                    x[unroll > 0] = model(x[unroll > 0])
                    unroll[unroll > 0] -= 1

            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # wandb.log({f'train/loss_{acquire_step}': total_loss})
    return model

model = train_pushforward(Y, train_nts, 4)
evaluate()


100%|██████████| 100/100 [49:57<00:00, 29.98s/it]


(tensor(0.5240), tensor(0.3072), tensor(11924.1826))


In [7]:

def train_pushforward(Y, train_nts, unrolling, **kwargs):
    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    datasize = (train_nts-1).sum().item()
    average_iter = datasize // batch_size
    

    inputs = []
    outputs = []
    unroll = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            for r in range(unrolling):
                if t+(r+1) < train_nts[b].item():
                    inputs.append(Y[b,t])
                    outputs.append(Y[b, t+(r+1)])
                    unroll.append(r)
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)
    unroll = torch.tensor(unroll)

    dataloaders = []
    for i in range(unrolling+1):
        dataset_i = torch.utils.data.TensorDataset(inputs[unroll<=i], outputs[unroll<=i], unroll[unroll<=i])
        dataloader_i = torch.utils.data.DataLoader(dataset_i, batch_size=batch_size, shuffle=True)
        dataloaders.append(dataloader_i)

    model.train()
    for epoch in tqdm(range(epochs)):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        # for x, y, unroll in islice(dataloader, 0, average_iter):
        dataloader = dataloaders[min(epoch, unrolling)]
        for x, y, unroll in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            for _ in range(unroll.max()):
                with torch.no_grad():
                    x[unroll > 0] = model(x[unroll > 0])
                    unroll[unroll > 0] -= 1

            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # wandb.log({f'train/loss_{acquire_step}': total_loss})
    return model

model = train_pushforward(Y, train_nts, 5)
evaluate()


100%|██████████| 100/100 [1:01:26<00:00, 36.86s/it]


(tensor(0.9729), tensor(0.7041), tensor(24612.1680))


In [8]:

def train_pushforward(Y, train_nts, unrolling, **kwargs):
    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    datasize = (train_nts-1).sum().item()
    average_iter = datasize // batch_size
    

    inputs = []
    outputs = []
    unroll = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            for r in range(unrolling):
                if t+(r+1) < train_nts[b].item():
                    inputs.append(Y[b,t])
                    outputs.append(Y[b, t+(r+1)])
                    unroll.append(r)
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)
    unroll = torch.tensor(unroll)

    dataloaders = []
    for i in range(unrolling+1):
        dataset_i = torch.utils.data.TensorDataset(inputs[unroll<=i], outputs[unroll<=i], unroll[unroll<=i])
        dataloader_i = torch.utils.data.DataLoader(dataset_i, batch_size=batch_size, shuffle=True)
        dataloaders.append(dataloader_i)

    model.train()
    for epoch in tqdm(range(epochs)):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        # for x, y, unroll in islice(dataloader, 0, average_iter):
        dataloader = dataloaders[min(epoch, unrolling)]
        for x, y, unroll in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            for _ in range(unroll.max()):
                with torch.no_grad():
                    x[unroll > 0] = model(x[unroll > 0])
                    unroll[unroll > 0] -= 1

            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # wandb.log({f'train/loss_{acquire_step}': total_loss})
    return model

model = train_pushforward(Y, train_nts, 10)
evaluate()


 43%|████▎     | 43/100 [1:07:17<1:42:09, 107.53s/it]

In [None]:
a
def train_pushforward(Y, train_nts, unrolling, **kwargs):
    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    datasize = (train_nts-1).sum().item()
    average_iter = datasize // batch_size
    

    inputs = []
    outputs = []
    unroll = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            for r in range(unrolling):
                if t+(r+1) < train_nts[b].item():
                    inputs.append(Y[b,t])
                    outputs.append(Y[b, t+(r+1)])
                    unroll.append(r)
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)
    unroll = torch.tensor(unroll)

    dataloaders = []
    for i in range(unrolling+1):
        dataset_i = torch.utils.data.TensorDataset(inputs[unroll<=i], outputs[unroll<=i], unroll[unroll<=i])
        dataloader_i = torch.utils.data.DataLoader(dataset_i, batch_size=batch_size, shuffle=True)
        dataloaders.append(dataloader_i)

    model.train()
    for epoch in tqdm(range(epochs)):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        # for x, y, unroll in islice(dataloader, 0, average_iter):
        dataloader = dataloaders[min(epoch, unrolling)]
        for x, y, unroll in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            for _ in range(unroll.max()):
                with torch.no_grad():
                    x[unroll > 0] = model(x[unroll > 0])
                    unroll[unroll > 0] -= 1

            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # wandb.log({f'train/loss_{acquire_step}': total_loss})
    return model

model = train_pushforward(Y, train_nts, 20)
evaluate()
