### Comparing same model for every time step vs different models

In [2]:
import h5py
import torch
import numpy as np
from neuralop.models import FNO
from tqdm import tqdm
import random

import argparse
import time

from eval_utils import compute_metrics

from utils import set_seed, flatten_configdict
from acquisition.acquirers import select, select_time

from omegaconf import OmegaConf
import hydra
import wandb


class Traj_dataset:
    traj_train = None
    traj_valid = None
    traj_test = None

class cfg:
    equation = 'KdV'


device = 'cuda'
nt = 14

lr = 1e-3
epochs = 100
batch_size = 32

unrolling = 0
initial_time_steps = 1664

print('Loading training data...')
with h5py.File(f'data_large/{cfg.equation}_train_100000_default.h5', 'r') as f:
    # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:10000, :131], dtype=torch.float32, device=cfg.device)
    Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:1000, :131], dtype=torch.float32)
    # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:100, :131], dtype=torch.float32, device=cfg.device)
# print('Loading validation data...')
# with h5py.File(f'data_large/{cfg.equation}_valid_1024_default.h5', 'r') as f:
#     Traj_dataset.traj_valid = torch.tensor(f['valid']['pde_140-256'][:, :131], dtype=torch.float32)
print('Loading test data...')
with h5py.File(f'data_large/{cfg.equation}_test_100000_default.h5', 'r') as f:
    # Traj_dataset.traj_test = torch.tensor(f['test']['pde_140-256'][:, :131], dtype=torch.float32)
    Traj_dataset.traj_test = torch.tensor(f['test']['pde_140-256'][:10000, :131], dtype=torch.float32)


def train(Y, train_nts, **kwargs):
    assert unrolling == 0

    acquire_step = kwargs.get('acquire_step', 0)

    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    inputs = []
    outputs = []
    for b in range(Y.shape[0]):
        for t in range(train_nts[b].item()-1):
            inputs.append(Y[b,t])
            outputs.append(Y[b, t+1])
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)

    dataset = torch.utils.data.TensorDataset(inputs, outputs)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in range(epochs):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        for x, y in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
    return model

def test(model):
    X_test = Traj_dataset.traj_test[:,0,:].unsqueeze(1).to(device)
    Y_test = Traj_dataset.traj_test[:,-1,:].unsqueeze(1).to(device)

    testset = torch.utils.data.TensorDataset(X_test, Y_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    model.eval()

    Y_test_pred = []
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            Y_test_pred.append(model(x))
        Y_test_pred = torch.cat(Y_test_pred, dim=0).to(device)
    
    metrics = compute_metrics(Y_test, Y_test_pred, d=1)

    return metrics

def test_trajectory(model):
    X_test = Traj_dataset.traj_test[:,0].unsqueeze(1).to(device)
    Y_test = Traj_dataset.traj_test[:,timestep::timestep].to(device)

    testset = torch.utils.data.TensorDataset(X_test, Y_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    model.eval()

    Y_test_pred = []
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            # print(y_pred.shape, y.shape)
            assert y_pred.shape == y.shape
            Y_test_pred.append(y_pred)
        Y_test_pred = torch.cat(Y_test_pred, dim=0).to(device)
    
    metrics = compute_metrics(Y_test, Y_test_pred, d=2)

    return metrics

class direct_model(torch.nn.Module):
    def __init__(self, model, unrolling):
        super().__init__()
        self.model = model
        self.unrolling = unrolling
    def forward(self, x):
        for _ in range(self.unrolling):
            x = self.model(x)
        return x
    
class trajectory_model(torch.nn.Module):
    def __init__(self, model, unrolling):
        super().__init__()
        self.model = model
        self.unrolling = unrolling
    def forward(self, x):
        trajectory = []
        for _ in range(self.unrolling):
            x = self.model(x)
            trajectory.append(x)
        return torch.cat(trajectory, dim=1) # [cfg.train.batch_size, unrolling, nx]

timestep = (Traj_dataset.traj_train.shape[1] - 1) // (nt - 1) # 10
assert timestep == 10 # hardcoded for now (130/ (14-1) = 10)

X = Traj_dataset.traj_train[:,0].unsqueeze(1).to(device)
Y = Traj_dataset.traj_train[:,0::timestep].to(device)

train_nts = torch.ones(X.shape[0], device=device, dtype=torch.int64)
# values are between 1 and 14, inclusive
# 1 means only initial data, 14 means all data

train_nts[:initial_time_steps//(nt-1)] = nt
if initial_time_steps % (nt-1) != 0:
    train_nts[initial_time_steps//(nt-1)] = initial_time_steps % (nt-1)

# train_idxs = torch.arange(initial_datasize, device=device)
# pool_idxs = torch.arange(initial_datasize, X.shape[0], device=device)

# X_train = X[train_idxs]
# Y_train = Y[train_idxs]

# X_pool = X[pool_idxs]

model = train(Y, train_nts, acquire_step=0)

results = {'datasize': [], 'l2': [], 'rel_l2': [], 'mse': []}

results['datasize'].append((train_nts-1).sum().item())
# rel_l2_list = [test(direct_model(model, nt-1))[1].mean().item() for model in ensemble]
metrics_list = torch.stack(test_trajectory(trajectory_model(model, nt-1))) # [3, datasize]
results['l2'].append(metrics_list[0, :].mean().item())
results['rel_l2'].append(metrics_list[1, :].mean().item())
results['mse'].append(metrics_list[2, :].mean().item())
print(f'Datasize: {results["datasize"][-1]}, L2: {results["l2"][-1]}, Rel_l2: {results["rel_l2"][-1]}, MSE: {results["mse"][-1]}')


Loading training data...
Loading test data...
Datasize: 1664, L2: 0.5233489871025085, Rel_l2: 0.12047380208969116, MSE: 0.08471187204122543


In [9]:
import h5py
import torch
import numpy as np
from neuralop.models import FNO
from tqdm import tqdm
import random

import argparse
import time

from eval_utils import compute_metrics

from utils import set_seed, flatten_configdict
from acquisition.acquirers import select, select_time

from omegaconf import OmegaConf
import hydra
import wandb


class Traj_dataset:
    traj_train = None
    traj_valid = None
    traj_test = None

class cfg:
    equation = 'KdV'


device = 'cuda'
nt = 14

lr = 1e-3
epochs = 100
batch_size = 32

unrolling = 0
initial_time_steps = 1664

print('Loading training data...')
with h5py.File(f'data_large/{cfg.equation}_train_100000_default.h5', 'r') as f:
    # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:10000, :131], dtype=torch.float32, device=cfg.device)
    Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:1000, :131], dtype=torch.float32)
    # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:100, :131], dtype=torch.float32, device=cfg.device)
# print('Loading validation data...')
# with h5py.File(f'data_large/{cfg.equation}_valid_1024_default.h5', 'r') as f:
#     Traj_dataset.traj_valid = torch.tensor(f['valid']['pde_140-256'][:, :131], dtype=torch.float32)
print('Loading test data...')
with h5py.File(f'data_large/{cfg.equation}_test_100000_default.h5', 'r') as f:
    # Traj_dataset.traj_test = torch.tensor(f['test']['pde_140-256'][:, :131], dtype=torch.float32)
    Traj_dataset.traj_test = torch.tensor(f['test']['pde_140-256'][:10000, :131], dtype=torch.float32)


def train(Y, train_nts, **kwargs):
    assert unrolling == 0

    acquire_step = kwargs.get('acquire_step', 0)

    models = [FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1) for _ in range(nt-1)]

    for t, model in tqdm(enumerate(models), total=nt-1):
        model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
        criterion = torch.nn.MSELoss()

        inputs = []
        outputs = []
        for b in range(Y.shape[0]):
            if t < train_nts[b].item()-1:
                inputs.append(Y[b,t])
                outputs.append(Y[b, t+1])
        inputs = torch.stack(inputs, dim=0).unsqueeze(1)
        outputs = torch.stack(outputs, dim=0).unsqueeze(1)

        dataset = torch.utils.data.TensorDataset(inputs, outputs)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            model.train()
            # max_unrolling = epoch if epoch <= unrolling else unrolling
            # unrolling_list = [r for r in range(max_unrolling + 1)]

            total_loss = 0
            for x, y in dataloader:
                optimizer.zero_grad()
                x, y = x.to(device), y.to(device)
                
                pred = model(x)
                loss = criterion(pred, y)

                # loss = torch.sqrt(loss)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            scheduler.step()
            
    return models

def test_trajectory(model):
    X_test = Traj_dataset.traj_test[:,0].unsqueeze(1).to(device)
    Y_test = Traj_dataset.traj_test[:,timestep::timestep].to(device)

    testset = torch.utils.data.TensorDataset(X_test, Y_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    model.eval()

    Y_test_pred = []
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            # print(y_pred.shape, y.shape)
            assert y_pred.shape == y.shape
            Y_test_pred.append(y_pred)
        Y_test_pred = torch.cat(Y_test_pred, dim=0).to(device)
    
    metrics = compute_metrics(Y_test, Y_test_pred, d=2)

    return metrics

# class direct_model(torch.nn.Module):
#     def __init__(self, model, unrolling):
#         super().__init__()
#         self.model = model
#         self.unrolling = unrolling
#     def forward(self, x):
#         for _ in range(self.unrolling):
#             x = self.model(x)
#         return x
    
class trajectory_model(torch.nn.Module):
    def __init__(self, models):
        super().__init__()
        self.models = models
    def forward(self, x):
        trajectory = []
        for model in self.models:
            x = model(x)
            trajectory.append(x)
        return torch.cat(trajectory, dim=1) # [cfg.train.batch_size, unrolling, nx]
    def eval(self):
        for model in self.models:
            model.eval()

timestep = (Traj_dataset.traj_train.shape[1] - 1) // (nt - 1) # 10
assert timestep == 10 # hardcoded for now (130/ (14-1) = 10)

X = Traj_dataset.traj_train[:,0].unsqueeze(1).to(device)
Y = Traj_dataset.traj_train[:,0::timestep].to(device)

train_nts = torch.ones(X.shape[0], device=device, dtype=torch.int64)
# values are between 1 and 14, inclusive
# 1 means only initial data, 14 means all data

train_nts[:initial_time_steps//(nt-1)] = nt
if initial_time_steps % (nt-1) != 0:
    train_nts[initial_time_steps//(nt-1)] = initial_time_steps % (nt-1)

# train_idxs = torch.arange(initial_datasize, device=device)
# pool_idxs = torch.arange(initial_datasize, X.shape[0], device=device)

# X_train = X[train_idxs]
# Y_train = Y[train_idxs]

# X_pool = X[pool_idxs]

models = train(Y, train_nts, acquire_step=0)

results = {'datasize': [], 'l2': [], 'rel_l2': [], 'mse': []}

results['datasize'].append((train_nts-1).sum().item())
# rel_l2_list = [test(direct_model(model, nt-1))[1].mean().item() for model in ensemble]
metrics_list = torch.stack(test_trajectory(trajectory_model(models))) # [3, datasize]
results['l2'].append(metrics_list[0, :].mean().item())
results['rel_l2'].append(metrics_list[1, :].mean().item())
results['mse'].append(metrics_list[2, :].mean().item())
print(f'Datasize: {results["datasize"][-1]}, L2: {results["l2"][-1]}, Rel_l2: {results["rel_l2"][-1]}, MSE: {results["mse"][-1]}')


Loading training data...
Loading test data...


100%|██████████| 13/13 [01:44<00:00,  8.04s/it]


Datasize: 1664, L2: 1.867620825767517, Rel_l2: 0.5719999074935913, MSE: 0.29869818687438965


In [10]:
import h5py
import torch
import numpy as np
from neuralop.models import FNO
from tqdm import tqdm
import random

import argparse
import time

from eval_utils import compute_metrics

from utils import set_seed, flatten_configdict
from acquisition.acquirers import select, select_time

from omegaconf import OmegaConf
import hydra
import wandb


class Traj_dataset:
    traj_train = None
    traj_valid = None
    traj_test = None

class cfg:
    equation = 'KdV'


device = 'cuda'
nt = 14

lr = 1e-3
epochs = 100
batch_size = 32

unrolling = 0
initial_time_steps = 1664

print('Loading training data...')
with h5py.File(f'data_large/{cfg.equation}_train_100000_default.h5', 'r') as f:
    # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:10000, :131], dtype=torch.float32, device=cfg.device)
    Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:1000, :131], dtype=torch.float32)
    # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:100, :131], dtype=torch.float32, device=cfg.device)
# print('Loading validation data...')
# with h5py.File(f'data_large/{cfg.equation}_valid_1024_default.h5', 'r') as f:
#     Traj_dataset.traj_valid = torch.tensor(f['valid']['pde_140-256'][:, :131], dtype=torch.float32)
print('Loading test data...')
with h5py.File(f'data_large/{cfg.equation}_test_100000_default.h5', 'r') as f:
    # Traj_dataset.traj_test = torch.tensor(f['test']['pde_140-256'][:, :131], dtype=torch.float32)
    Traj_dataset.traj_test = torch.tensor(f['test']['pde_140-256'][:10000, :131], dtype=torch.float32)


def train(Y, train_nts, **kwargs):
    assert unrolling == 0

    acquire_step = kwargs.get('acquire_step', 0)

    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)
    
    model = direct_model(model, nt-1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    inputs = []
    outputs = []
    for b in range(Y.shape[0]):
        if train_nts[b].item() > 1:
            inputs.append(Y[b,0])
            outputs.append(Y[b, -1])
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)

    dataset = torch.utils.data.TensorDataset(inputs, outputs)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in range(epochs):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        for x, y in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
    return model

def test(model):
    X_test = Traj_dataset.traj_test[:,0,:].unsqueeze(1).to(device)
    Y_test = Traj_dataset.traj_test[:,-1,:].unsqueeze(1).to(device)

    testset = torch.utils.data.TensorDataset(X_test, Y_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    model.eval()

    Y_test_pred = []
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            Y_test_pred.append(model(x))
        Y_test_pred = torch.cat(Y_test_pred, dim=0).to(device)
    
    metrics = compute_metrics(Y_test, Y_test_pred, d=1)

    return metrics

def test_trajectory(model):
    X_test = Traj_dataset.traj_test[:,0].unsqueeze(1).to(device)
    Y_test = Traj_dataset.traj_test[:,timestep::timestep].to(device)

    testset = torch.utils.data.TensorDataset(X_test, Y_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    model.eval()

    Y_test_pred = []
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            # print(y_pred.shape, y.shape)
            assert y_pred.shape == y.shape
            Y_test_pred.append(y_pred)
        Y_test_pred = torch.cat(Y_test_pred, dim=0).to(device)
    
    metrics = compute_metrics(Y_test, Y_test_pred, d=2)

    return metrics

class direct_model(torch.nn.Module):
    def __init__(self, model, unrolling):
        super().__init__()
        self.model = model
        self.unrolling = unrolling
    def forward(self, x):
        for _ in range(self.unrolling):
            x = self.model(x)
        return x
    
class trajectory_model(torch.nn.Module):
    def __init__(self, model, unrolling):
        super().__init__()
        self.model = model
        self.unrolling = unrolling
    def forward(self, x):
        trajectory = []
        for _ in range(self.unrolling):
            x = self.model(x)
            trajectory.append(x)
        return torch.cat(trajectory, dim=1) # [cfg.train.batch_size, unrolling, nx]

timestep = (Traj_dataset.traj_train.shape[1] - 1) // (nt - 1) # 10
assert timestep == 10 # hardcoded for now (130/ (14-1) = 10)

X = Traj_dataset.traj_train[:,0].unsqueeze(1).to(device)
Y = Traj_dataset.traj_train[:,0::timestep].to(device)

train_nts = torch.ones(X.shape[0], device=device, dtype=torch.int64)
# values are between 1 and 14, inclusive
# 1 means only initial data, 14 means all data

train_nts[:initial_time_steps//(nt-1)] = nt
if initial_time_steps % (nt-1) != 0:
    train_nts[initial_time_steps//(nt-1)] = initial_time_steps % (nt-1)

# train_idxs = torch.arange(initial_datasize, device=device)
# pool_idxs = torch.arange(initial_datasize, X.shape[0], device=device)

# X_train = X[train_idxs]
# Y_train = Y[train_idxs]

# X_pool = X[pool_idxs]

model = train(Y, train_nts, acquire_step=0)

results = {'datasize': [], 'l2': [], 'rel_l2': [], 'mse': []}

results['datasize'].append((train_nts-1).sum().item())
# rel_l2_list = [test(direct_model(model, nt-1))[1].mean().item() for model in ensemble]
metrics_list = torch.stack(test_trajectory(trajectory_model(model, nt-1))) # [3, datasize]
results['l2'].append(metrics_list[0, :].mean().item())
results['rel_l2'].append(metrics_list[1, :].mean().item())
results['mse'].append(metrics_list[2, :].mean().item())
print(f'Datasize: {results["datasize"][-1]}, L2: {results["l2"][-1]}, Rel_l2: {results["rel_l2"][-1]}, MSE: {results["mse"][-1]}')


Loading training data...
Loading test data...
Datasize: 1664, L2: 2.91375994682312, Rel_l2: 90456912.0, MSE: 0.41966328024864197


In [None]:
import h5py
import torch
import numpy as np
from neuralop.models import FNO
from tqdm import tqdm
import random

import argparse
import time

from eval_utils import compute_metrics

from utils import set_seed, flatten_configdict
from acquisition.acquirers import select, select_time

from omegaconf import OmegaConf
import hydra
import wandb


class Traj_dataset:
    traj_train = None
    traj_valid = None
    traj_test = None

class cfg:
    equation = 'KdV'


device = 'cuda'
nt = 14

lr = 1e-3
epochs = 100
batch_size = 32

unrolling = 0
initial_time_steps = 1664

print('Loading training data...')
with h5py.File(f'data_large/{cfg.equation}_train_100000_default.h5', 'r') as f:
    # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:10000, :131], dtype=torch.float32, device=cfg.device)
    Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:1000, :131], dtype=torch.float32)
    # Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:100, :131], dtype=torch.float32, device=cfg.device)
# print('Loading validation data...')
# with h5py.File(f'data_large/{cfg.equation}_valid_1024_default.h5', 'r') as f:
#     Traj_dataset.traj_valid = torch.tensor(f['valid']['pde_140-256'][:, :131], dtype=torch.float32)
print('Loading test data...')
with h5py.File(f'data_large/{cfg.equation}_test_100000_default.h5', 'r') as f:
    # Traj_dataset.traj_test = torch.tensor(f['test']['pde_140-256'][:, :131], dtype=torch.float32)
    Traj_dataset.traj_test = torch.tensor(f['test']['pde_140-256'][:10000, :131], dtype=torch.float32)


def train(Y, train_nts, **kwargs):
    assert unrolling == 0

    acquire_step = kwargs.get('acquire_step', 0)

    model = FNO(n_modes=cfg.model.n_modes, hidden_channels=64,
                in_channels=1, out_channels=1)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    inputs = []
    outputs = []
    for b in range(Y.shape[0]):
        if train_nts[b].item() > 1:
            inputs.append(Y[b,0])
            outputs.append(Y[b, -1])
    inputs = torch.stack(inputs, dim=0).unsqueeze(1)
    outputs = torch.stack(outputs, dim=0).unsqueeze(1)

    dataset = torch.utils.data.TensorDataset(inputs, outputs)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in range(epochs):
        model.train()
        # max_unrolling = epoch if epoch <= unrolling else unrolling
        # unrolling_list = [r for r in range(max_unrolling + 1)]

        total_loss = 0
        for x, y in dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            
            pred = model(x)
            loss = criterion(pred, y)

            # loss = torch.sqrt(loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
    return model

def test(model):
    X_test = Traj_dataset.traj_test[:,0,:].unsqueeze(1).to(device)
    Y_test = Traj_dataset.traj_test[:,-1,:].unsqueeze(1).to(device)

    testset = torch.utils.data.TensorDataset(X_test, Y_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    model.eval()

    Y_test_pred = []
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            Y_test_pred.append(model(x))
        Y_test_pred = torch.cat(Y_test_pred, dim=0).to(device)
    
    metrics = compute_metrics(Y_test, Y_test_pred, d=1)

    return metrics

def test_trajectory(model):
    X_test = Traj_dataset.traj_test[:,0].unsqueeze(1).to(device)
    Y_test = Traj_dataset.traj_test[:,timestep::timestep].to(device)

    testset = torch.utils.data.TensorDataset(X_test, Y_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    model.eval()

    Y_test_pred = []
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            # print(y_pred.shape, y.shape)
            assert y_pred.shape == y.shape
            Y_test_pred.append(y_pred)
        Y_test_pred = torch.cat(Y_test_pred, dim=0).to(device)
    
    metrics = compute_metrics(Y_test, Y_test_pred, d=2)

    return metrics

class direct_model(torch.nn.Module):
    def __init__(self, model, unrolling):
        super().__init__()
        self.model = model
        self.unrolling = unrolling
    def forward(self, x):
        for _ in range(self.unrolling):
            x = self.model(x)
        return x
    
class trajectory_model(torch.nn.Module):
    def __init__(self, model, unrolling):
        super().__init__()
        self.model = model
        self.unrolling = unrolling
    def forward(self, x):
        trajectory = []
        for _ in range(self.unrolling):
            x = self.model(x)
            trajectory.append(x)
        return torch.cat(trajectory, dim=1) # [cfg.train.batch_size, unrolling, nx]

timestep = (Traj_dataset.traj_train.shape[1] - 1) // (nt - 1) # 10
assert timestep == 10 # hardcoded for now (130/ (14-1) = 10)

X = Traj_dataset.traj_train[:,0].unsqueeze(1).to(device)
Y = Traj_dataset.traj_train[:,0::timestep].to(device)

train_nts = torch.ones(X.shape[0], device=device, dtype=torch.int64)
# values are between 1 and 14, inclusive
# 1 means only initial data, 14 means all data

train_nts[:initial_time_steps//(nt-1)] = nt
if initial_time_steps % (nt-1) != 0:
    train_nts[initial_time_steps//(nt-1)] = initial_time_steps % (nt-1)

# train_idxs = torch.arange(initial_datasize, device=device)
# pool_idxs = torch.arange(initial_datasize, X.shape[0], device=device)

# X_train = X[train_idxs]
# Y_train = Y[train_idxs]

# X_pool = X[pool_idxs]

model = train(Y, train_nts, acquire_step=0)

results = {'datasize': [], 'l2': [], 'rel_l2': [], 'mse': []}

results['datasize'].append((train_nts-1).sum().item())
# rel_l2_list = [test(direct_model(model, nt-1))[1].mean().item() for model in ensemble]
metrics_list = torch.stack(test_trajectory(trajectory_model(model, nt-1))) # [3, datasize]
results['l2'].append(metrics_list[0, :].mean().item())
results['rel_l2'].append(metrics_list[1, :].mean().item())
results['mse'].append(metrics_list[2, :].mean().item())
print(f'Datasize: {results["datasize"][-1]}, L2: {results["l2"][-1]}, Rel_l2: {results["rel_l2"][-1]}, MSE: {results["mse"][-1]}')
