In [1]:
import h5py
import torch
import random

# Preferred GPU for this notebook. The index is only valid on hosts that expose
# at least seven CUDA devices, so fall back to the first GPU when fewer are
# present, and to the CPU when CUDA is unavailable.
_PREFERRED_GPU = 6
if torch.cuda.is_available():
    _gpu_index = _PREFERRED_GPU if torch.cuda.device_count() > _PREFERRED_GPU else 0
    device = torch.device(f'cuda:{_gpu_index}')
else:
    device = torch.device('cpu')

In [2]:
# with h5py.File('data/KdV_train_1024_default.h5', 'r') as f:
#     traj_train = torch.tensor(f['train']['pde_140-256'][:], dtype=torch.float32)
# with h5py.File('data/KdV_valid_1024_default.h5', 'r') as f:
#     traj_valid = torch.tensor(f['valid']['pde_140-256'][:], dtype=torch.float32)
# with h5py.File('data/KdV_test_4096_default.h5', 'r') as f:
#     traj_test = torch.tensor(f['test']['pde_140-256'][:], dtype=torch.float32)

class args:
    # Name of the PDE; interpolated into the HDF5 file names read below.
    equation = 'KdV'

class Traj_dataset:
    # Class-level holders for the trajectory tensors so that later cells can
    # reach them as Traj_dataset.traj_*; they stay None until loaded below.
    traj_train = None
    traj_valid = None
    traj_test = None

# Slice sizes applied to the stored arrays (first axis, second axis).
# experiment() sub-samples the second axis with a stride of
# (N_TIME_STEPS - 1) // (nt - 1) and asserts that this stride equals 10 for its
# nt=14 snapshots, which requires 131 stored steps. It also uses the last kept
# step of the test split as the target of its (nt - 1)-step rollout, so both
# splits must be cut at the same step.
N_TIME_STEPS = 131
N_TRAIN = 1000
N_TEST = 10000

print('Loading training data...')
with h5py.File(f'data_large/{args.equation}_train_100000_default.h5', 'r') as f:
    Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:N_TRAIN, :N_TIME_STEPS], dtype=torch.float32)
# The validation split is not loaded in this notebook; Traj_dataset.traj_valid stays None.
# print('Loading validation data...')
# with h5py.File(f'data_large/{args.equation}_valid_1024_default.h5', 'r') as f:
#     Traj_dataset.traj_valid = torch.tensor(f['valid']['pde_140-256'][:, :N_TIME_STEPS], dtype=torch.float32)
print('Loading test data...')
with h5py.File(f'data_large/{args.equation}_test_100000_default.h5', 'r') as f:
    Traj_dataset.traj_test = torch.tensor(f['test']['pde_140-256'][:N_TEST, :N_TIME_STEPS], dtype=torch.float32)

In [3]:
from typing import Any, Dict, List, Tuple
import torch.nn.functional as F

class LpLoss(object):
    """Absolute and relative p-norm error between two batches of tensors.

    Every sample (index along dimension 0) is flattened to a vector before the
    norm is taken, so the inputs may have any number of trailing dimensions.

    Args:
        d: exponent numerator used to scale the absolute norm by ``h**(d/p)``;
            must be positive.
        p: order of the norm; must be positive.
        size_average: when reducing, return the mean (True) or the sum (False)
            over the batch.
        reduction: when False, return the per-sample values without reducing.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        # Dimension and Lp-norm type must be positive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def _reduce(self, per_sample):
        """Apply the configured batch reduction to a vector of per-sample values."""
        if not self.reduction:
            return per_sample
        if self.size_average:
            return torch.mean(per_sample)
        return torch.sum(per_sample)

    def abs(self, x, y):
        """Return the scaled absolute p-norm of ``x - y`` for each sample.

        The scaling assumes a uniform mesh whose spacing ``h`` is derived from
        ``x.size()[1]``; ``h`` is 1.0 when that dimension has a single entry.
        """
        num_examples = x.size()[0]

        # Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0) if x.size()[1] > 1 else 1.0

        # reshape (not view) so that non-contiguous inputs, e.g. transposed or
        # strided slices, are accepted exactly as they are in rel().
        diff = x.reshape(num_examples, -1) - y.reshape(num_examples, -1)
        all_norms = (h**(self.d/self.p))*torch.norm(diff, self.p, 1)

        return self._reduce(all_norms)

    def rel(self, x, y):
        """Return the p-norm of ``x - y`` divided by the p-norm of ``y`` per sample.

        No guard is applied for samples where ``y`` has zero norm.
        """
        num_examples = x.size()[0]

        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)

        return self._reduce(diff_norms/y_norms)

    def __call__(self, x, y):
        """Calling the loss object evaluates the relative error ``rel(x, y)``."""
        return self.rel(x, y)

def compute_metrics(y, y_pred, d=1) :
    """Per-sample error metrics between a target batch and a prediction batch.

    Args:
        y: target tensor; dimension 0 indexes the samples.
        y_pred: prediction tensor with exactly the same shape as ``y``.
        d: dimension argument forwarded to ``LpLoss``.

    Returns:
        Tuple ``(l2, relative_l2, mse)``; each entry holds one value per sample.

    Raises:
        NotImplementedError: if ``y`` and ``y_pred`` differ in shape.
    """
    lp_norm = LpLoss(d=d, p=2, reduction=False)
    if y_pred.shape != y.shape :
        raise NotImplementedError

    abs_err = lp_norm.abs(y, y_pred) # [bs]
    rel_err = lp_norm.rel(y, y_pred) # [bs]

    # Element-wise squared error first (same shape as y) ...
    squared_err = F.mse_loss(y_pred, y, reduction='none')
    # ... then averaged over every dimension except the batch one.
    non_batch_dims = tuple(range(1, squared_err.ndim))
    per_sample_mse = squared_err.mean(dim=non_batch_dims)

    return abs_err, rel_err, per_sample_mse


In [17]:
# Optimisation hyper-parameters; experiment()'s training loop reads these
# notebook-level names (epoch count, Adam learning rate, DataLoader batch size).
epochs, lr, batch_size = 500, 1e-3, 32

In [18]:
import torch
import numpy as np
from neuralop.models import FNO
from tqdm import tqdm

from acquisition.acquirers import select


def experiment(initial_datasize=256, batch_acquire=32, num_acquire=1, device='cpu', **cfg):
    """Pool-based active-learning loop around an ensemble of FNO one-step models.

    The first ``initial_datasize`` trajectories of ``Traj_dataset.traj_train``
    form the labelled set and the remaining ones the pool. An ensemble is
    trained on the labelled set and scored on ``Traj_dataset.traj_test``. Then,
    ``num_acquire`` times, ``select`` is asked for ``batch_acquire`` pool
    entries, those are moved into the labelled set, and a fresh ensemble is
    trained from scratch and scored again.

    Args:
        initial_datasize: number of trajectories in the initial labelled set.
        batch_acquire: number of pool entries requested from ``select`` per round.
        num_acquire: number of acquisition rounds after the initial fit.
        device: torch device used for the data, the models and the index tensors.
        **cfg: optional settings
            ``unrolling`` (default 1): upper bound on the number of gradient-free
                pushforward steps applied to the input before the trained step;
                the bound in effect at a given epoch is ``min(epoch, unrolling)``.
            ``nt`` (default 14): number of sub-sampled snapshots per trajectory.
            ``ensemble_size`` (default 5): number of independently trained models.
            ``epochs``, ``lr``, ``batch_size``: optimisation hyper-parameters;
                each defaults to the notebook-level global of the same name.

    Returns:
        dict with two parallel lists: ``'datasize'`` (labelled-set size) and
        ``'rel_l2'`` (relative L2 test error averaged over samples and over the
        ensemble), one entry for the initial fit and one per acquisition round.
    """
    unrolling = cfg.get('unrolling', 1)
    nt = cfg.get('nt', 14)
    ensemble_size = cfg.get('ensemble_size', 5)
    # Explicit overrides take precedence; otherwise use the notebook globals.
    n_epochs = cfg.get('epochs', epochs)
    learning_rate = cfg.get('lr', lr)
    loader_batch_size = cfg.get('batch_size', batch_size)

    def train(X_train, Y_train):
        """Train one freshly initialised FNO on random one-step pairs drawn from Y_train."""
        model = FNO(n_modes=(256, ), hidden_channels=64,
                    in_channels=1, out_channels=1)
        model = model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)
        criterion = torch.nn.MSELoss()

        dataset = torch.utils.data.TensorDataset(X_train, Y_train)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=loader_batch_size, shuffle=True)

        for epoch in range(n_epochs):
            model.train()
            # The pushforward depth is ramped up with the epoch index and capped.
            max_unrolling = min(epoch, unrolling)

            # Loop over every epoch as often as the number of timesteps in one trajectory.
            # Since the starting point is randomly drawn, this in expectation has every possible
            # starting point/sample combination of the training data.
            for _ in range(nt):
                # Only the trajectories are needed; model inputs are drawn from them.
                for _, y in dataloader:
                    optimizer.zero_grad()
                    y = y.to(device)  # [batch_size, nt, nx]
                    bs = y.shape[0]

                    unrolled = random.choice(range(max_unrolling + 1))

                    # One random start step per sample, chosen so that the
                    # target step start + unrolled + 1 stays inside the trajectory.
                    random_steps = random.choices(range(nt - 1 - unrolled), k=bs)
                    rows = torch.arange(bs, device=y.device)
                    start = torch.tensor(random_steps, device=y.device)
                    inputs = y[rows, start].unsqueeze(1)
                    outputs = y[rows, start + unrolled + 1].unsqueeze(1)

                    # pushforward: advance the input without tracking gradients
                    with torch.no_grad():
                        model.eval()
                        for _ in range(unrolled):
                            inputs = model(inputs)
                        model.train()

                    pred = model(inputs)
                    loss = criterion(pred, outputs)
                    loss.backward()
                    optimizer.step()
            scheduler.step()
        return model

    # Test tensors and loader are identical for every model, so build them once.
    # NOTE(review): the target is the last stored step; this matches the
    # (nt - 1)-step rollout only if the test trajectories end at step
    # (nt - 1) * timestep - confirm against the data-loading cell.
    X_test = Traj_dataset.traj_test[:,0,:].unsqueeze(1).to(device)
    Y_test = Traj_dataset.traj_test[:,-1,:].unsqueeze(1).to(device)
    testset = torch.utils.data.TensorDataset(X_test, Y_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=loader_batch_size, shuffle=False)

    def test(model):
        """Roll the model out nt-1 steps from the first test snapshot and return compute_metrics."""
        model.eval()
        rollouts = []
        with torch.no_grad():
            for x, _ in testloader:
                for _ in range(nt-1):
                    x = model(x)
                rollouts.append(x)
        Y_test_pred = torch.cat(rollouts, dim=0)
        return compute_metrics(Y_test, Y_test_pred, d=1)

    results = {'datasize': [], 'rel_l2': []}

    # Temporal stride that reduces the stored trajectories to nt snapshots.
    timestep = (Traj_dataset.traj_train.shape[1] - 1) // (nt - 1) # 10
    assert timestep == 10, (
        f'expected a sub-sampling stride of 10 but got {timestep} '
        f'({Traj_dataset.traj_train.shape[1]} stored steps, nt={nt})'
    )

    X = Traj_dataset.traj_train[:,0].unsqueeze(1).to(device)
    Y = Traj_dataset.traj_train[:,0::timestep].to(device)

    train_idxs = torch.arange(initial_datasize, device=device)
    pool_idxs = torch.arange(initial_datasize, X.shape[0], device=device)

    def fit_and_score(train_idxs):
        """Train a new ensemble on the given trajectories, record its test error, return it."""
        X_train = X[train_idxs]
        Y_train = Y[train_idxs]

        ensemble = [train(X_train, Y_train) for _ in tqdm(range(ensemble_size))]

        results['datasize'].append(train_idxs.shape[0])
        rel_l2_list = [test(model)[1].mean().item() for model in ensemble]
        results['rel_l2'].append(torch.mean(torch.tensor(rel_l2_list)).item())
        print(f'{results["datasize"][-1]}, {results["rel_l2"][-1]}')
        return ensemble, X_train

    class unrolled_model(torch.nn.Module):
        """Wrap a one-step model so that a single forward call applies it `unrolling` times."""
        def __init__(self, model, unrolling):
            super().__init__()
            self.model = model
            self.unrolling = unrolling
        def forward(self, x):
            for _ in range(self.unrolling):
                x = self.model(x)
            return x

    ensemble, X_train = fit_and_score(train_idxs)

    for _ in range(num_acquire):
        X_pool = X[pool_idxs]
        unrolled_ensemble = [unrolled_model(model, nt-1) for model in ensemble]
        new_idxs = select(unrolled_ensemble, X_train, X_pool, batch_acquire, selection_method='random', acquisition_function='variance', device=device)

        # The returned indices are used as positions into the current pool.
        new_idxs = new_idxs.to(device)
        print(new_idxs)
        logical_new_idxs = torch.zeros(pool_idxs.shape[-1], dtype=torch.bool, device=device)
        logical_new_idxs[new_idxs] = True
        # Move the selected pool entries into the labelled set.
        train_idxs = torch.cat([train_idxs, pool_idxs[logical_new_idxs]], dim=-1)
        pool_idxs = pool_idxs[~logical_new_idxs]

        ensemble, X_train = fit_and_score(train_idxs)

    return results

# One acquisition round on top of the initial fit, without pushforward steps.
results = experiment(
    initial_datasize=256,
    batch_acquire=32,
    num_acquire=1,
    device=device,
    unrolling=0,
    nt=14,
    ensemble_size=5,
)

100%|██████████| 5/5 [32:54<00:00, 394.97s/it]


256, 0.0693989247083664
tensor([199, 677, 722, 450, 478, 697, 316, 215, 169, 161, 171, 703, 276, 373,
        427, 695, 692, 355, 714, 598, 412, 152, 683, 492, 278, 498, 516, 342,
        716,  32, 352, 749], device='cuda:6')


 40%|████      | 2/5 [14:28<21:44, 434.99s/it]