TODO: Split data into train, validation, and test sets

In [20]:
%load_ext autoreload
%autoreload 2



In [21]:
data_dir = "../nuclear-fusion/data/preprocessed/"
# omega (written `w` below) is the number of timesteps in each time block given to the model
omega = 20
batch_size = 16

In [22]:
# Import dataset
from dataset import SimulationDataset
simulationdataset = SimulationDataset(data_dir, omega)

In [23]:
# Import sampler that uses Time Adjusted Sampling as described in Appendix B
from sampler import TimeAdjustedSampler
sampler = TimeAdjustedSampler(simulationdataset, batch_size=batch_size)

In [24]:
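# Setup dataloader
import torch
from torch.utils.data import DataLoader

def custom_collate(batch):
    """Merge a list of (x, f, y) samples into batched tensors (X, Y)."""
    X, Y = [], []

    for x, f, y in batch:
        # Append the forcing channels to the state channels (last dimension)
        X.append(torch.cat((x, f), dim=2))
        Y.append(y)

    # Stack the samples along a new leading batch dimension
    X = torch.stack(X)
    Y = torch.stack(Y)

    return (X, Y)

# The batch_sampler yields whole batches of indices, so batch_size/shuffle are not passed here
dataloader = DataLoader(simulationdataset, batch_sampler=sampler, collate_fn=custom_collate)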
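A quick sanity check of the collate logic on random dummy tensors. The function is re-declared so the snippet runs on its own; the per-sample shapes (20, 500, 6) for the state and (20, 500, 2) for the forcing are assumptions, the forcing shape being inferred from the 8-channel `X` printed at the end of this notebook rather than read from `SimulationDataset`.

```python
import torch

def custom_collate(batch):
    """Same collate logic as above: concatenate state and forcing, then stack."""
    X, Y = [], []
    for x, f, y in batch:
        X.append(torch.cat((x, f), dim=2))
        Y.append(y)
    return torch.stack(X), torch.stack(Y)

# Hypothetical dummy samples: w=20 timesteps, 500 points, 6 state / 2 forcing channels
w, n_points, n_state, n_forcing = 20, 500, 6, 2
samples = [
    (
        torch.randn(w, n_points, n_state),    # x: state block
        torch.randn(w, n_points, n_forcing),  # f: forcing block (assumed shape)
        torch.randn(w, n_points, n_state),    # y: target state block
    )
    for _ in range(16)
]

X, Y = custom_collate(samples)
print(X.shape)  # torch.Size([16, 20, 500, 8])
print(Y.shape)  # torch.Size([16, 20, 500, 6])
```

Note that `torch.cat` along `dim=2` only works if `x` and `f` agree in their first two dimensions (time and spatial points).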

In [25]:
for batch in dataloader:
    break

Before collation, a batch is a list of **batch_size (n)** samples. \
So, `batch = [sample_1, sample_2, ..., sample_n]`. \
Each sample is a tuple of three tensors, `(X, F, Y)`. \
Where:
- `X` is the state of the tokamak from time `t - w` to time `t`, i.e. `t-w:t`. Size: (`w`, 500, 6)
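- `F` is the forcing applied to the tokamak. Size: (`w`, 500, 2), since `torch.cat(..., dim=2)` requires `F` to match `X` in its first two dimensions, and the collated `X` printed below has 6 + 2 = 8 channels.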
- `Y` is the state of the tokamak from time `t` to time `t + w`, i.e. `t:t+w`. Size: (`w`, 500, 6)

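`w` (`omega` in the code) is the size of the time blocks, i.e. the number of timesteps we input to the model and expect the model to output.

`custom_collate` concatenates each sample's `X` and `F` along the last (channel) dimension and stacks the samples, so every batch produced by the dataloader is a pair `(X, Y)`.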
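To make the index ranges concrete, here is a sketch that slices one simulation into an input block and a target block. It assumes a simulation is stored as a tensor of shape (T, 500, 6) with time on the first axis; `simulation`, `T`, `t` and `make_block` are hypothetical names, not part of `SimulationDataset`.

```python
import torch

def make_block(simulation: torch.Tensor, t: int, w: int):
    """Return (X, Y): the w timesteps before t and the w timesteps starting at t."""
    if t - w < 0 or t + w > simulation.shape[0]:
        raise ValueError("t must satisfy w <= t <= T - w")
    X = simulation[t - w:t]   # timesteps t-w .. t-1
    Y = simulation[t:t + w]   # timesteps t .. t+w-1
    return X, Y

# Hypothetical simulation whose values equal the timestep index, for easy checking
T, w = 100, 20
simulation = torch.arange(T, dtype=torch.float32).view(T, 1, 1).expand(T, 500, 6)

X, Y = make_block(simulation, t=50, w=w)
print(X.shape, Y.shape)                      # torch.Size([20, 500, 6]) torch.Size([20, 500, 6])
print(X[0, 0, 0].item(), Y[0, 0, 0].item())  # 30.0 50.0
```

With Python's half-open slices, `X` and `Y` share no timestep: `X` ends at `t-1` and `Y` starts at `t`.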

**Sampling:** \
The dataset is sampled in time blocks whose size is set by `w`. \
We first sample a simulation length (number of timesteps). \
Then, we sample a simulation uniformly from all simulations with that length.
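The two-stage draw can be sketched in plain Python as follows. This is not the `TimeAdjustedSampler` implementation: `sim_lengths` (timesteps per simulation) and `length_probs` (the probability assigned to each length) are hypothetical inputs, and how Appendix B chooses the length distribution is not reproduced here.

```python
import random
from collections import defaultdict

def sample_simulation(sim_lengths, length_probs, rng=random):
    """Two-stage draw: pick a length, then a simulation with that length uniformly."""
    # Group simulation indices by their number of timesteps
    by_length = defaultdict(list)
    for idx, length in enumerate(sim_lengths):
        by_length[length].append(idx)

    # Stage 1: sample a length according to length_probs
    lengths = sorted(by_length)
    weights = [length_probs[n] for n in lengths]
    length = rng.choices(lengths, weights=weights, k=1)[0]

    # Stage 2: sample uniformly among the simulations of that length
    return rng.choice(by_length[length])

sim_lengths = [100, 100, 250, 400]              # hypothetical dataset
length_probs = {100: 0.2, 250: 0.3, 400: 0.5}   # hypothetical length distribution
rng = random.Random(0)
print(sample_simulation(sim_lengths, length_probs, rng))
```

Grouping by length first means that two simulations of equal length always share the stage-1 probability of their length equally.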

A longer simulation contains more possible time blocks than a shorter one. If simulations were sampled uniformly, any particular block in a long simulation would therefore be less likely to be selected than a block in a short simulation. To compensate, we give longer simulations a higher chance of being sampled than shorter ones.
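One way to make "a higher chance for longer simulations" precise, offered as an assumption rather than the exact rule from Appendix B, is to weight each simulation by its number of valid blocks. With `X` covering `t-w:t` and `Y` covering `t:t+w`, the split point must satisfy `w <= t <= T - w`, so a simulation with `T` timesteps holds `T - 2w + 1` blocks. Sampling simulations in proportion to that count makes every block in the dataset equally likely. The helper names below are hypothetical.

```python
from fractions import Fraction

def block_count(T, w):
    """Number of valid split points t with w <= t <= T - w (0 if the simulation is too short)."""
    return max(T - 2 * w + 1, 0)

def simulation_weights(sim_lengths, w):
    """Probability of drawing each simulation, proportional to its block count."""
    counts = [block_count(T, w) for T in sim_lengths]
    total = sum(counts)
    return [Fraction(c, total) for c in counts]

sim_lengths = [100, 250, 400]   # hypothetical simulation lengths
w = 20
probs = simulation_weights(sim_lengths, w)
for T, p in zip(sim_lengths, probs):
    # P(simulation) * P(block | simulation), with blocks uniform within a simulation
    print(T, p, p / block_count(T, w))
```

Every printed line ends in `1/633`: the three simulations hold 61 + 211 + 361 = 633 blocks, and each is equally likely. In the two-stage scheme, the matching probability for a length would be the summed weight of all simulations of that length; the uniform draw within the length group then reproduces these per-simulation probabilities.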

In [27]:
# X
print("X shape: ", batch[0].shape)

# Y
print("Y shape: ", batch[1].shape)

X shape:  torch.Size([16, 20, 500, 8])
Y shape:  torch.Size([16, 20, 500, 6])
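Because `custom_collate` concatenates the state before the forcing, the first 6 channels of the model input are the tokamak state and the last 2 are the forcing (6 being the channel count of `Y`). A minimal sketch of recovering both parts, using a random stand-in tensor with the same shape as `batch[0]`:

```python
import torch

X = torch.randn(16, 20, 500, 8)  # stand-in for batch[0]
n_state = 6                      # matches the last dimension of Y

state = X[..., :n_state]    # (16, 20, 500, 6): tokamak state
forcing = X[..., n_state:]  # (16, 20, 500, 2): forcing channels
print(state.shape, forcing.shape)  # torch.Size([16, 20, 500, 6]) torch.Size([16, 20, 500, 2])
```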
