In [1]:
import h5py
import torch
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# with h5py.File('data/KdV_train_1024_default.h5', 'r') as f:
#     traj_train = torch.tensor(f['train']['pde_140-256'][:], dtype=torch.float32)
# with h5py.File('data/KdV_valid_1024_default.h5', 'r') as f:
#     traj_valid = torch.tensor(f['valid']['pde_140-256'][:], dtype=torch.float32)
# with h5py.File('data/KdV_test_4096_default.h5', 'r') as f:
#     traj_test = torch.tensor(f['test']['pde_140-256'][:], dtype=torch.float32)

class args:
    equation = 'KdV'

class Traj_dataset:
    traj_train = None
    traj_valid = None
    traj_test = None

with h5py.File(f'data/{args.equation}_train_1024_default.h5', 'r') as f:
    Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:], dtype=torch.float32)[:, :131]
with h5py.File(f'data/{args.equation}_valid_1024_default.h5', 'r') as f:
    Traj_dataset.traj_valid = torch.tensor(f['valid']['pde_140-256'][:], dtype=torch.float32)[:, :131]
with h5py.File(f'data/{args.equation}_test_4096_default.h5', 'r') as f:
    Traj_dataset.traj_test = torch.tensor(f['test']['pde_140-256'][:], dtype=torch.float32)[:, :131]

In [3]:
from typing import Any, Dict, List, Tuple
import torch.nn.functional as F

class LpLoss(object):
    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        #Dimension and Lp-norm type are postive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def abs(self, x, y):
        num_examples = x.size()[0]

        #Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0) if x.size()[1] > 1 else 1.0

        all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(all_norms)
            else:
                return torch.sum(all_norms)

        return all_norms

    def rel(self, x, y):
        num_examples = x.size()[0]

        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms/y_norms)
            else:
                return torch.sum(diff_norms/y_norms)

        return diff_norms/y_norms

    def __call__(self, x, y):
        return self.rel(x, y)

def compute_metrics(y, y_pred, d=1) :
    L2_func = LpLoss(d=d, p=2, reduction=False)
    if y.shape != y_pred.shape :
        raise NotImplementedError
    l2 = L2_func.abs(y, y_pred) # [bs]
    relative_l2 = L2_func.rel(y, y_pred) # [bs]
    mse = F.mse_loss(y_pred, y, reduction='none') # [bs]
    mse = mse.mean(dim=tuple(range(1, mse.ndim)))
    return l2, relative_l2, mse


In [4]:
epochs = 1
lr = 0.001
batch_size = 32

In [5]:
import torch
import numpy as np
from neuralop.models import FNO
from tqdm import tqdm

from acquisition.acquirers import select

def train(X_train, Y_train):
    model = FNO(n_modes=(256, ), hidden_channels=64,
                    in_channels=1, out_channels=1)

    model = model.to(device)

    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = torch.nn.MSELoss()

    dataset = torch.utils.data.TensorDataset(X_train, Y_train)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        for x, y in dataloader:
            # x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
        scheduler.step()
    
    return model

def test(model):
    X_test = Traj_dataset.traj_test[:,0,:].unsqueeze(1).to(device)
    Y_test = Traj_dataset.traj_test[:,-1,:].unsqueeze(1).to(device)

    testset = torch.utils.data.TensorDataset(X_test, Y_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    model.eval()
    
    Y_test_pred = []
    with torch.no_grad():
        for x, y in testloader:
            # x, y = x.to(device), y.to(device)
            y_pred = model(x)
            Y_test_pred.append(y_pred)
        Y_test_pred = torch.cat(Y_test_pred, dim=0).to(Y_test.device)
    
    metrics = compute_metrics(Y_test, Y_test_pred, d=1)
    return metrics


initial_datasize=256
batch_acquire=32
num_acquire=1
ensemble_size = 5

results = {'datasize': [], 'rel_l2': []}

X = Traj_dataset.traj_train[:,0].unsqueeze(1).to(device)
Y = Traj_dataset.traj_train[:,-1].unsqueeze(1).to(device)

train_idxs = torch.arange(initial_datasize, device=device)
pool_idxs = torch.arange(initial_datasize, X.shape[0], device=device)

X_train = X[train_idxs]
Y_train = Y[train_idxs]

X_pool = X[pool_idxs]

ensemble = [train(X_train, Y_train) for _ in tqdm(range(ensemble_size))]

results['datasize'].append(train_idxs.shape[0])
rel_l2_list = [test(model)[1].mean().item() for model in ensemble]
results['rel_l2'].append(torch.mean(torch.tensor(rel_l2_list)).item())
print(f'{results["datasize"][-1]}, {results["rel_l2"][-1]}')


100%|██████████| 5/5 [00:05<00:00,  1.17s/it]


256, 6.253170013427734


In [6]:
import torch.nn as nn

class MeanModel(nn.Module):
    def __init__(self, model):
        super(MeanModel, self).__init__()
        self.model = model
    def forward(self, x):
        y = self.model(x)
        y = torch.mean(y, dim=tuple(range(1, y.ndim)))
        y = y.unsqueeze(1)
        return y
    
Y_mean_train = Y_train.mean(dim=tuple(range(1, Y_train.ndim))).unsqueeze(1)


ensemble_mean = [MeanModel(model) for model in ensemble]

In [40]:
from bmdal_reg.bmdal.feature_data import TensorFeatureData
from bmdal_reg.bmdal.algorithms import select_batch

train_data = TensorFeatureData(X_train)
pool_data = TensorFeatureData(X_pool)
for bait_sigma in [1e-4, 1e-3, 1e-2, 1e-1, 1]:
    new_idxs, _ = select_batch(batch_size=32, models=ensemble_mean[:2], 
                               data={'train': train_data, 'pool': pool_data}, y_train=Y_mean_train,
                               selection_method='bait', sel_with_train=False, bait_sigma=bait_sigma,
                               base_kernel='predictions', kernel_transforms=[]) #[('rp', [5])])
    print(new_idxs)
# new_idxs, _ = select_batch(batch_size=32, models=ensemble_mean[:2], 
#                            data={'train': train_data, 'pool': pool_data}, y_train=Y_mean_train,
#                            selection_method='maxdiag', sel_with_train=False,
#                            base_kernel='predictions', kernel_transforms=[]) #[('rp', [5])])
# print(new_idxs)

print('---')

for maxdet_sigma in [1e-4, 1e-3, 1e-2, 1e-1, 1]:
    new_idxs, _ = select_batch(batch_size=32, models=ensemble_mean[:2], 
                            data={'train': train_data, 'pool': pool_data}, y_train=Y_mean_train,
                            selection_method='maxdet', sel_with_train=False, maxdet_sigma=maxdet_sigma,
                            base_kernel='predictions', kernel_transforms=[]) #[('rp', [5])])
    print(new_idxs)
# new_idxs, _ = select_batch(batch_size=32, models=ensemble_mean[:2], 
#                            data={'train': train_data, 'pool': pool_data}, y_train=Y_mean_train,
#                            selection_method='lcmd', sel_with_train=False,
#                            base_kernel='predictions', kernel_transforms=[]) #[('rp', [5])])
# print(new_idxs)


tensor([250, 612,  41,  24, 154, 233, 420, 592, 500, 644, 201, 184,  79, 681,
        244, 106,  63, 360, 397, 552, 571, 604, 547, 219, 731, 246, 171, 524,
        415, 323, 633, 725], device='cuda:0')
tensor([250,  41, 612, 121,  24, 681, 244, 753, 154, 428, 233, 600, 749, 500,
        313, 739, 441, 420, 592, 548, 662, 731, 219, 363, 342, 339, 482, 360,
        725, 524, 567, 558], device='cuda:0')
tensor([250,  41, 612, 121,  24, 681, 244, 753, 428, 154, 548, 749, 441, 233,
        600, 313, 549, 464, 739, 461, 308, 420, 647, 311, 397, 727, 592, 717,
        500, 662, 532,  74], device='cuda:0')
tensor([250,  41, 612, 121,  24, 681, 244, 753, 428, 154, 548, 749, 441, 233,
        600, 313, 549, 464, 739, 461, 308, 420, 647, 311, 397, 727, 592, 717,
        500, 662, 532,  74], device='cuda:0')
tensor([250,  41, 612, 121,  24, 681, 244, 753, 428, 154, 548, 749, 441, 233,
        600, 313, 549, 464, 739, 461, 308, 420, 647, 311, 397, 727, 592, 717,
        500, 662, 532,  74], device=

In [33]:
from bmdal_reg.bmdal.feature_data import TensorFeatureData
from bmdal_reg.bmdal.algorithms import select_batch

train_data = TensorFeatureData(X_train)
pool_data = TensorFeatureData(X_pool)
new_idxs, _ = select_batch(batch_size=50, models=ensemble[:2], 
                           data={'train': train_data, 'pool': pool_data}, y_train=Y_train,
                           selection_method='bait', sel_with_train=False, bait_sigma=1e-3,
                           base_kernel='predictions', kernel_transforms=[]) #[('rp', [5])])
print(new_idxs)


AssertionError: 