In [14]:
import h5py
import torch
import random

device = torch.device('cuda:6' if torch.cuda.is_available() else 'cpu')

In [2]:
# with h5py.File('data/KdV_train_1024_default.h5', 'r') as f:
#     traj_train = torch.tensor(f['train']['pde_140-256'][:], dtype=torch.float32)
# with h5py.File('data/KdV_valid_1024_default.h5', 'r') as f:
#     traj_valid = torch.tensor(f['valid']['pde_140-256'][:], dtype=torch.float32)
# with h5py.File('data/KdV_test_4096_default.h5', 'r') as f:
#     traj_test = torch.tensor(f['test']['pde_140-256'][:], dtype=torch.float32)

class args:
    equation = 'KdV'

class Traj_dataset:
    traj_train = None
    traj_valid = None
    traj_test = None

with h5py.File(f'data/{args.equation}_train_1024_default.h5', 'r') as f:
    Traj_dataset.traj_train = torch.tensor(f['train']['pde_140-256'][:], dtype=torch.float32)[:, :131]
with h5py.File(f'data/{args.equation}_valid_1024_default.h5', 'r') as f:
    Traj_dataset.traj_valid = torch.tensor(f['valid']['pde_140-256'][:], dtype=torch.float32)[:, :131]
with h5py.File(f'data/{args.equation}_test_4096_default.h5', 'r') as f:
    Traj_dataset.traj_test = torch.tensor(f['test']['pde_140-256'][:], dtype=torch.float32)[:, :131]

In [3]:
from typing import Any, Dict, List, Tuple
import torch.nn.functional as F

class LpLoss(object):
    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        #Dimension and Lp-norm type are postive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def abs(self, x, y):
        num_examples = x.size()[0]

        #Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0) if x.size()[1] > 1 else 1.0

        all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(all_norms)
            else:
                return torch.sum(all_norms)

        return all_norms

    def rel(self, x, y):
        num_examples = x.size()[0]

        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms/y_norms)
            else:
                return torch.sum(diff_norms/y_norms)

        return diff_norms/y_norms

    def __call__(self, x, y):
        return self.rel(x, y)

def compute_metrics(y, y_pred, d=1) :
    L2_func = LpLoss(d=d, p=2, reduction=False)
    if y.shape != y_pred.shape :
        raise NotImplementedError
    l2 = L2_func.abs(y, y_pred) # [bs]
    relative_l2 = L2_func.rel(y, y_pred) # [bs]
    mse = F.mse_loss(y_pred, y, reduction='none') # [bs]
    mse = mse.mean(dim=tuple(range(1, mse.ndim)))
    return l2, relative_l2, mse


In [8]:
epochs = 500
lr = 0.001
batch_size = 32

In [9]:
import torch
import numpy as np
from neuralop.models import FNO
from tqdm import tqdm

from acquisition.acquirers import select


def experiment_direct(initial_datasize=256, batch_acquire=32, num_acquire=1, device='cpu', **cfg):

    def train(X_train, Y_train):
        model = FNO(n_modes=(256, ), hidden_channels=64,
                        in_channels=1, out_channels=1)

        model = model.to(device)

        model.train()

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
        criterion = torch.nn.MSELoss()

        dataset = torch.utils.data.TensorDataset(X_train, Y_train)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            for x, y in dataloader:
                # x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                y_pred = model(x)
                loss = criterion(y_pred, y)
                loss.backward()
                optimizer.step()
            scheduler.step()
        
        return model

    def test(model):
        X_test = Traj_dataset.traj_test[:,0,:].unsqueeze(1).to(device)
        Y_test = Traj_dataset.traj_test[:,-1,:].unsqueeze(1).to(device)

        testset = torch.utils.data.TensorDataset(X_test, Y_test)
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

        model.eval()
        
        Y_test_pred = []
        with torch.no_grad():
            for x, y in testloader:
                # x, y = x.to(device), y.to(device)
                y_pred = model(x)
                Y_test_pred.append(y_pred)
            Y_test_pred = torch.cat(Y_test_pred, dim=0).to(Y_test.device)
        
        metrics = compute_metrics(Y_test, Y_test_pred, d=1)
        return metrics

    def select_var(ensemble, X_pool, batch_acquire):
        for model in ensemble:
            model.eval()
        with torch.no_grad():
            y_preds = [model(X_pool) for model in ensemble]
            y_preds = torch.stack(y_preds, dim=0) # [ensemble_size, pool_size, ...]
            y_preds_var = y_preds.var(dim=0) # [pool_size, ...]
            y_preds_var = y_preds_var.mean(dim=tuple(range(1, y_preds_var.ndim))) # [pool_size]
        new_idxs = torch.argsort(y_preds_var, descending=True)[:batch_acquire]
        return new_idxs

    ensemble_size = cfg.get('ensemble_size', 5)

    results = {'datasize': [], 'rel_l2': []}

    X = Traj_dataset.traj_train[:,0].unsqueeze(1).to(device)
    Y = Traj_dataset.traj_train[:,-1].unsqueeze(1).to(device)

    train_idxs = torch.arange(initial_datasize, device=device)
    pool_idxs = torch.arange(initial_datasize, X.shape[0], device=device)

    X_train = X[train_idxs]
    Y_train = Y[train_idxs]

    X_pool = X[pool_idxs]

    ensemble = [train(X_train, Y_train) for _ in tqdm(range(ensemble_size))]

    results['datasize'].append(train_idxs.shape[0])
    rel_l2_list = [test(model)[1].mean().item() for model in ensemble]
    results['rel_l2'].append(torch.mean(torch.tensor(rel_l2_list)).item())
    print(f'{results["datasize"][-1]}, {results["rel_l2"][-1]}')

    for i in range(num_acquire):
        new_idxs = select(ensemble, X_train, X_pool, batch_acquire, selection_method='greedy', acquisition_function='variance', device=device)
        # new_idxs = select_var(ensemble, X_pool, batch_acquire)

        new_idxs = new_idxs.to(device)
        print(new_idxs)
        # print(f'{len(new_idxs)=}')
        logical_new_idxs = torch.zeros(pool_idxs.shape[-1], dtype=torch.bool, device=device)
        logical_new_idxs[new_idxs] = True
        train_idxs = torch.cat([train_idxs, pool_idxs[logical_new_idxs]], dim=-1)
        pool_idxs = pool_idxs[~logical_new_idxs]

        X_train = X[train_idxs]
        Y_train = Y[train_idxs]

        X_pool = X[pool_idxs]

        ensemble = [train(X_train, Y_train) for _ in tqdm(range(ensemble_size))]

        results['datasize'].append(train_idxs.shape[0])
        rel_l2_list = [test(model)[1].mean().item() for model in ensemble]
        results['rel_l2'].append(torch.mean(torch.tensor(rel_l2_list)).item())
        print(f'{results["datasize"][-1]}, {results["rel_l2"][-1]}')
    
    return results

results = experiment_direct(initial_datasize=256, batch_acquire=32, num_acquire=1, device=device)

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [02:00<00:00, 24.18s/it]


256, 0.9775699377059937


100%|██████████| 32/32 [02:04<00:00,  3.89s/it]


tensor([ 41, 428, 759, 313, 612, 174, 330, 495, 250,  74, 634, 681, 751, 616,
        141, 667, 456, 158, 130, 724, 121, 548,  24, 342, 172, 318,  43, 363,
        209, 734, 524,  11], device='cuda:6')


100%|██████████| 5/5 [02:14<00:00, 26.91s/it]


288, 0.9684969186782837


In [10]:
import torch
import numpy as np
from neuralop.models import FNO
from tqdm import tqdm

from acquisition.acquirers import select


def experiment_direct(initial_datasize=256, batch_acquire=32, num_acquire=1, device='cpu', **cfg):

    def train(X_train, Y_train):
        model = FNO(n_modes=(256, ), hidden_channels=64,
                        in_channels=1, out_channels=1)

        model = model.to(device)

        model.train()

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
        criterion = torch.nn.MSELoss()

        dataset = torch.utils.data.TensorDataset(X_train, Y_train)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            for x, y in dataloader:
                # x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                y_pred = model(x)
                loss = criterion(y_pred, y)
                loss.backward()
                optimizer.step()
            scheduler.step()
        
        return model

    def test(model):
        X_test = Traj_dataset.traj_test[:,0,:].unsqueeze(1).to(device)
        Y_test = Traj_dataset.traj_test[:,-1,:].unsqueeze(1).to(device)

        testset = torch.utils.data.TensorDataset(X_test, Y_test)
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

        model.eval()
        
        Y_test_pred = []
        with torch.no_grad():
            for x, y in testloader:
                # x, y = x.to(device), y.to(device)
                y_pred = model(x)
                Y_test_pred.append(y_pred)
            Y_test_pred = torch.cat(Y_test_pred, dim=0).to(Y_test.device)
        
        metrics = compute_metrics(Y_test, Y_test_pred, d=1)
        return metrics

    def select_var(ensemble, X_pool, batch_acquire):
        for model in ensemble:
            model.eval()
        with torch.no_grad():
            y_preds = [model(X_pool) for model in ensemble]
            y_preds = torch.stack(y_preds, dim=0) # [ensemble_size, pool_size, ...]
            y_preds_var = y_preds.var(dim=0) # [pool_size, ...]
            y_preds_var = y_preds_var.mean(dim=tuple(range(1, y_preds_var.ndim))) # [pool_size]
        new_idxs = torch.argsort(y_preds_var, descending=True)[:batch_acquire]
        return new_idxs

    ensemble_size = cfg.get('ensemble_size', 5)

    results = {'datasize': [], 'rel_l2': []}

    X = Traj_dataset.traj_train[:,0].unsqueeze(1).to(device)
    Y = Traj_dataset.traj_train[:,-1].unsqueeze(1).to(device)

    train_idxs = torch.arange(initial_datasize, device=device)
    pool_idxs = torch.arange(initial_datasize, X.shape[0], device=device)

    X_train = X[train_idxs]
    Y_train = Y[train_idxs]

    X_pool = X[pool_idxs]

    ensemble = [train(X_train, Y_train) for _ in tqdm(range(ensemble_size))]

    results['datasize'].append(train_idxs.shape[0])
    rel_l2_list = [test(model)[1].mean().item() for model in ensemble]
    results['rel_l2'].append(torch.mean(torch.tensor(rel_l2_list)).item())
    print(f'{results["datasize"][-1]}, {results["rel_l2"][-1]}')

    for i in range(num_acquire):
        new_idxs = select(ensemble, X_train, X_pool, batch_acquire, selection_method='random', acquisition_function='variance', device=device)
        # new_idxs = select_var(ensemble, X_pool, batch_acquire)

        new_idxs = new_idxs.to(device)
        print(new_idxs)
        # print(f'{len(new_idxs)=}')
        logical_new_idxs = torch.zeros(pool_idxs.shape[-1], dtype=torch.bool, device=device)
        logical_new_idxs[new_idxs] = True
        train_idxs = torch.cat([train_idxs, pool_idxs[logical_new_idxs]], dim=-1)
        pool_idxs = pool_idxs[~logical_new_idxs]

        X_train = X[train_idxs]
        Y_train = Y[train_idxs]

        X_pool = X[pool_idxs]

        ensemble = [train(X_train, Y_train) for _ in tqdm(range(ensemble_size))]

        results['datasize'].append(train_idxs.shape[0])
        rel_l2_list = [test(model)[1].mean().item() for model in ensemble]
        results['rel_l2'].append(torch.mean(torch.tensor(rel_l2_list)).item())
        print(f'{results["datasize"][-1]}, {results["rel_l2"][-1]}')
    
    return results

results = experiment_direct(initial_datasize=256, batch_acquire=32, num_acquire=1, device=device)

100%|██████████| 5/5 [01:57<00:00, 23.53s/it]


256, 0.9730780720710754
tensor([677, 171, 487, 145, 227, 233, 666, 275, 544, 129, 195, 564, 292, 186,
        760,  95, 594,  67, 571, 229, 588, 574, 272, 526, 484, 757, 766, 416,
        236, 475,  18, 518], device='cuda:6')


100%|██████████| 5/5 [02:15<00:00, 27.13s/it]


288, 0.9202383160591125


In [11]:
import torch
import numpy as np
from neuralop.models import FNO
from tqdm import tqdm

from acquisition.acquirers import select


def experiment_direct(initial_datasize=256, batch_acquire=32, num_acquire=1, device='cpu', **cfg):

    def train(X_train, Y_train):
        model = FNO(n_modes=(256, ), hidden_channels=64,
                        in_channels=1, out_channels=1)

        model = model.to(device)

        model.train()

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
        criterion = torch.nn.MSELoss()

        dataset = torch.utils.data.TensorDataset(X_train, Y_train)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            for x, y in dataloader:
                # x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                y_pred = model(x)
                loss = criterion(y_pred, y)
                loss.backward()
                optimizer.step()
            scheduler.step()
        
        return model

    def test(model):
        X_test = Traj_dataset.traj_test[:,0,:].unsqueeze(1).to(device)
        Y_test = Traj_dataset.traj_test[:,-1,:].unsqueeze(1).to(device)

        testset = torch.utils.data.TensorDataset(X_test, Y_test)
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

        model.eval()
        
        Y_test_pred = []
        with torch.no_grad():
            for x, y in testloader:
                # x, y = x.to(device), y.to(device)
                y_pred = model(x)
                Y_test_pred.append(y_pred)
            Y_test_pred = torch.cat(Y_test_pred, dim=0).to(Y_test.device)
        
        metrics = compute_metrics(Y_test, Y_test_pred, d=1)
        return metrics

    def select_var(ensemble, X_pool, batch_acquire):
        for model in ensemble:
            model.eval()
        with torch.no_grad():
            y_preds = [model(X_pool) for model in ensemble]
            y_preds = torch.stack(y_preds, dim=0) # [ensemble_size, pool_size, ...]
            y_preds_var = y_preds.var(dim=0) # [pool_size, ...]
            y_preds_var = y_preds_var.mean(dim=tuple(range(1, y_preds_var.ndim))) # [pool_size]
        new_idxs = torch.argsort(y_preds_var, descending=True)[:batch_acquire]
        return new_idxs

    ensemble_size = cfg.get('ensemble_size', 5)

    results = {'datasize': [], 'rel_l2': []}

    X = Traj_dataset.traj_train[:,0].unsqueeze(1).to(device)
    Y = Traj_dataset.traj_train[:,-1].unsqueeze(1).to(device)

    train_idxs = torch.arange(initial_datasize, device=device)
    pool_idxs = torch.arange(initial_datasize, X.shape[0], device=device)

    X_train = X[train_idxs]
    Y_train = Y[train_idxs]

    X_pool = X[pool_idxs]

    ensemble = [train(X_train, Y_train) for _ in tqdm(range(ensemble_size))]

    results['datasize'].append(train_idxs.shape[0])
    rel_l2_list = [test(model)[1].mean().item() for model in ensemble]
    results['rel_l2'].append(torch.mean(torch.tensor(rel_l2_list)).item())
    print(f'{results["datasize"][-1]}, {results["rel_l2"][-1]}')

    for i in range(num_acquire):
        new_idxs = select(ensemble, X_train, X_pool, batch_acquire, selection_method='greedy', acquisition_function='bait', device=device)
        # new_idxs = select_var(ensemble, X_pool, batch_acquire)

        new_idxs = new_idxs.to(device)
        print(new_idxs)
        # print(f'{len(new_idxs)=}')
        logical_new_idxs = torch.zeros(pool_idxs.shape[-1], dtype=torch.bool, device=device)
        logical_new_idxs[new_idxs] = True
        train_idxs = torch.cat([train_idxs, pool_idxs[logical_new_idxs]], dim=-1)
        pool_idxs = pool_idxs[~logical_new_idxs]

        X_train = X[train_idxs]
        Y_train = Y[train_idxs]

        X_pool = X[pool_idxs]

        ensemble = [train(X_train, Y_train) for _ in tqdm(range(ensemble_size))]

        results['datasize'].append(train_idxs.shape[0])
        rel_l2_list = [test(model)[1].mean().item() for model in ensemble]
        results['rel_l2'].append(torch.mean(torch.tensor(rel_l2_list)).item())
        print(f'{results["datasize"][-1]}, {results["rel_l2"][-1]}')
    
    return results

results = experiment_direct(initial_datasize=256, batch_acquire=32, num_acquire=1, device=device)

100%|██████████| 5/5 [02:01<00:00, 24.30s/it]


256, 0.974908709526062


100%|██████████| 32/32 [02:13<00:00,  4.17s/it]


tensor([ 41, 428, 759, 612, 330, 634, 250, 751, 313, 681, 174,  74, 616, 141,
        342, 172, 158, 130, 495,  11, 667, 121, 724,  43, 481, 524, 456, 548,
        318, 326, 413, 170], device='cuda:6')


100%|██████████| 5/5 [02:21<00:00, 28.29s/it]


288, 0.9574222564697266


In [12]:
import torch
import numpy as np
from neuralop.models import FNO
from tqdm import tqdm

from acquisition.acquirers import select


def experiment_direct(initial_datasize=256, batch_acquire=32, num_acquire=1, device='cpu', **cfg):

    def train(X_train, Y_train):
        model = FNO(n_modes=(256, ), hidden_channels=64,
                        in_channels=1, out_channels=1)

        model = model.to(device)

        model.train()

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
        criterion = torch.nn.MSELoss()

        dataset = torch.utils.data.TensorDataset(X_train, Y_train)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            for x, y in dataloader:
                # x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                y_pred = model(x)
                loss = criterion(y_pred, y)
                loss.backward()
                optimizer.step()
            scheduler.step()
        
        return model

    def test(model):
        X_test = Traj_dataset.traj_test[:,0,:].unsqueeze(1).to(device)
        Y_test = Traj_dataset.traj_test[:,-1,:].unsqueeze(1).to(device)

        testset = torch.utils.data.TensorDataset(X_test, Y_test)
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

        model.eval()
        
        Y_test_pred = []
        with torch.no_grad():
            for x, y in testloader:
                # x, y = x.to(device), y.to(device)
                y_pred = model(x)
                Y_test_pred.append(y_pred)
            Y_test_pred = torch.cat(Y_test_pred, dim=0).to(Y_test.device)
        
        metrics = compute_metrics(Y_test, Y_test_pred, d=1)
        return metrics

    def select_var(ensemble, X_pool, batch_acquire):
        for model in ensemble:
            model.eval()
        with torch.no_grad():
            y_preds = [model(X_pool) for model in ensemble]
            y_preds = torch.stack(y_preds, dim=0) # [ensemble_size, pool_size, ...]
            y_preds_var = y_preds.var(dim=0) # [pool_size, ...]
            y_preds_var = y_preds_var.mean(dim=tuple(range(1, y_preds_var.ndim))) # [pool_size]
        new_idxs = torch.argsort(y_preds_var, descending=True)[:batch_acquire]
        return new_idxs

    ensemble_size = cfg.get('ensemble_size', 5)

    results = {'datasize': [], 'rel_l2': []}

    X = Traj_dataset.traj_train[:,0].unsqueeze(1).to(device)
    Y = Traj_dataset.traj_train[:,-1].unsqueeze(1).to(device)

    train_idxs = torch.arange(initial_datasize, device=device)
    pool_idxs = torch.arange(initial_datasize, X.shape[0], device=device)

    X_train = X[train_idxs]
    Y_train = Y[train_idxs]

    X_pool = X[pool_idxs]

    ensemble = [train(X_train, Y_train) for _ in tqdm(range(ensemble_size))]

    results['datasize'].append(train_idxs.shape[0])
    rel_l2_list = [test(model)[1].mean().item() for model in ensemble]
    results['rel_l2'].append(torch.mean(torch.tensor(rel_l2_list)).item())
    print(f'{results["datasize"][-1]}, {results["rel_l2"][-1]}')

    for i in range(num_acquire):
        new_idxs = select(ensemble, X_train, X_pool, batch_acquire, selection_method='greedy', acquisition_function='entropy', device=device)
        # new_idxs = select_var(ensemble, X_pool, batch_acquire)

        new_idxs = new_idxs.to(device)
        print(new_idxs)
        # print(f'{len(new_idxs)=}')
        logical_new_idxs = torch.zeros(pool_idxs.shape[-1], dtype=torch.bool, device=device)
        logical_new_idxs[new_idxs] = True
        train_idxs = torch.cat([train_idxs, pool_idxs[logical_new_idxs]], dim=-1)
        pool_idxs = pool_idxs[~logical_new_idxs]

        X_train = X[train_idxs]
        Y_train = Y[train_idxs]

        X_pool = X[pool_idxs]

        ensemble = [train(X_train, Y_train) for _ in tqdm(range(ensemble_size))]

        results['datasize'].append(train_idxs.shape[0])
        rel_l2_list = [test(model)[1].mean().item() for model in ensemble]
        results['rel_l2'].append(torch.mean(torch.tensor(rel_l2_list)).item())
        print(f'{results["datasize"][-1]}, {results["rel_l2"][-1]}')
    
    return results

results = experiment_direct(initial_datasize=256, batch_acquire=32, num_acquire=1, device=device)

100%|██████████| 5/5 [02:01<00:00, 24.25s/it]


256, 0.9778550863265991


100%|██████████| 32/32 [02:03<00:00,  3.86s/it]


tensor([ 41, 428, 250, 759, 330, 612, 313, 681, 751, 634,  74, 616, 174, 495,
        141, 121,  11, 158, 172, 724,  24, 456, 318, 667, 548,  43, 130, 363,
        690, 524, 209, 734], device='cuda:6')


100%|██████████| 5/5 [02:17<00:00, 27.53s/it]


288, 0.9533389210700989


In [13]:
import torch
import numpy as np
from neuralop.models import FNO
from tqdm import tqdm

from acquisition.acquirers import select


def experiment_direct(initial_datasize=256, batch_acquire=32, num_acquire=1, device='cpu', **cfg):

    def train(X_train, Y_train):
        model = FNO(n_modes=(256, ), hidden_channels=64,
                        in_channels=1, out_channels=1)

        model = model.to(device)

        model.train()

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
        criterion = torch.nn.MSELoss()

        dataset = torch.utils.data.TensorDataset(X_train, Y_train)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            for x, y in dataloader:
                # x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                y_pred = model(x)
                loss = criterion(y_pred, y)
                loss.backward()
                optimizer.step()
            scheduler.step()
        
        return model

    def test(model):
        X_test = Traj_dataset.traj_test[:,0,:].unsqueeze(1).to(device)
        Y_test = Traj_dataset.traj_test[:,-1,:].unsqueeze(1).to(device)

        testset = torch.utils.data.TensorDataset(X_test, Y_test)
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

        model.eval()
        
        Y_test_pred = []
        with torch.no_grad():
            for x, y in testloader:
                # x, y = x.to(device), y.to(device)
                y_pred = model(x)
                Y_test_pred.append(y_pred)
            Y_test_pred = torch.cat(Y_test_pred, dim=0).to(Y_test.device)
        
        metrics = compute_metrics(Y_test, Y_test_pred, d=1)
        return metrics

    def select_var(ensemble, X_pool, batch_acquire):
        for model in ensemble:
            model.eval()
        with torch.no_grad():
            y_preds = [model(X_pool) for model in ensemble]
            y_preds = torch.stack(y_preds, dim=0) # [ensemble_size, pool_size, ...]
            y_preds_var = y_preds.var(dim=0) # [pool_size, ...]
            y_preds_var = y_preds_var.mean(dim=tuple(range(1, y_preds_var.ndim))) # [pool_size]
        new_idxs = torch.argsort(y_preds_var, descending=True)[:batch_acquire]
        return new_idxs

    ensemble_size = cfg.get('ensemble_size', 5)

    results = {'datasize': [], 'rel_l2': []}

    X = Traj_dataset.traj_train[:,0].unsqueeze(1).to(device)
    Y = Traj_dataset.traj_train[:,-1].unsqueeze(1).to(device)

    train_idxs = torch.arange(initial_datasize, device=device)
    pool_idxs = torch.arange(initial_datasize, X.shape[0], device=device)

    X_train = X[train_idxs]
    Y_train = Y[train_idxs]

    X_pool = X[pool_idxs]

    ensemble = [train(X_train, Y_train) for _ in tqdm(range(ensemble_size))]

    results['datasize'].append(train_idxs.shape[0])
    rel_l2_list = [test(model)[1].mean().item() for model in ensemble]
    results['rel_l2'].append(torch.mean(torch.tensor(rel_l2_list)).item())
    print(f'{results["datasize"][-1]}, {results["rel_l2"][-1]}')

    for i in range(num_acquire):
        new_idxs = select(ensemble, X_train, X_pool, batch_acquire, selection_method='lcmd', acquisition_function='variance', device=device)
        # new_idxs = select_var(ensemble, X_pool, batch_acquire)

        new_idxs = new_idxs.to(device)
        print(new_idxs)
        # print(f'{len(new_idxs)=}')
        logical_new_idxs = torch.zeros(pool_idxs.shape[-1], dtype=torch.bool, device=device)
        logical_new_idxs[new_idxs] = True
        train_idxs = torch.cat([train_idxs, pool_idxs[logical_new_idxs]], dim=-1)
        pool_idxs = pool_idxs[~logical_new_idxs]

        X_train = X[train_idxs]
        Y_train = Y[train_idxs]

        X_pool = X[pool_idxs]

        ensemble = [train(X_train, Y_train) for _ in tqdm(range(ensemble_size))]

        results['datasize'].append(train_idxs.shape[0])
        rel_l2_list = [test(model)[1].mean().item() for model in ensemble]
        results['rel_l2'].append(torch.mean(torch.tensor(rel_l2_list)).item())
        print(f'{results["datasize"][-1]}, {results["rel_l2"][-1]}')
    
    return results

results = experiment_direct(initial_datasize=256, batch_acquire=32, num_acquire=1, device=device)

100%|██████████| 5/5 [02:00<00:00, 24.10s/it]


256, 0.9731435775756836
tensor([ 43, 428,  58, 612, 634, 391, 413, 342, 551, 753, 616, 121, 174, 396,
        188, 734, 170, 667, 250, 541, 587,  74, 219, 751, 172, 592, 377, 283,
        749,  66, 313,  11], device='cuda:6')


100%|██████████| 5/5 [02:14<00:00, 26.94s/it]


288, 0.9483126401901245
