In [1]:
import torch 
import pandas as pd, numpy as np
import os
import sys
sys.path.append('/global/cfs/cdirs/m3443/usr/pmtuan/hadsim')
import yaml
from data.utils import *
from data.datamodule import CartesianDataModule

DATA_PATH = "/global/cfs/cdirs/m3443/usr/pmtuan/HadronicMCData/train_data_2_particles_processed/"
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so
# that indexing into data_files (e.g. data_files[0]) is reproducible across runs.
data_files = sorted(os.listdir(DATA_PATH))

In [2]:
# Shared hyperparameters for the datamodule and the generator/discriminator.
# Keyword-constructed dict: keys and insertion order match the original literal.
hparams = dict(
    # --- data selection ---
    n_particle=1,
    max_etot=100000,
    min_etot=10000,
    # --- generator / discriminator architecture ---
    gen_hidden_activation='LeakyReLU',
    dis_hidden_activation='LeakyReLU',
    gen_output_activation='LeakyReLU',
    dis_output_activation='Sigmoid',
    gen_batchnorm=True,
    dis_batchnorm=True,
    gen_dropout_rate=0.5,
    dis_dropout_rate=0.,
    nb_gen_layer=10,
    nb_dis_layer=10,
    gen_lr=0.001,
    dis_lr=0.001,
    # --- data loading ---
    sort_by=0,
    batch_size=1024,
    input_dir='/global/cfs/cdirs/m3443/usr/pmtuan/HadronicMCData/2_particle_fstate',
    hidden=128,
    # --- dimensions ---
    noise_dim=4,
    cond_dim=1,
    gen_in=4,
    gen_dim=4,
)

In [4]:
sample_path = os.path.join(DATA_PATH, data_files[0])
p12, m12, outcome = read_data(sample_path)
# NOTE(review): the literal 0 presumably corresponds to hparams['sort_by'] — confirm
# it is the same sort key before replacing it with the config value.
outcome = sort_particle(outcome, 0)

In [21]:
import os
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split

import logging

class MCDataModule(LightningDataModule):
    """LightningDataModule serving ``[etot, m, outcome]`` tensors per split.

    ``hparams['input_dir']`` must contain ``train``, ``val`` and ``test``
    sub-directories. Every file in a split directory is passed (as a list of
    paths) to ``read_data``, whose outputs are converted to tensors and then
    normalised by :meth:`transform`.

    Hyperparameters read here: ``input_dir``, ``batch_size``, ``n_particle``,
    and optionally ``max_etot`` (forwarded to ``read_data``) and
    ``shuffle_train`` (default ``True``).
    """

    def __init__(self, hparams) -> None:
        super().__init__()

        self.save_hyperparameters(hparams)
        self.trainset, self.valset, self.testset = None, None, None

    def convert_dataset(self, dataset):
        """Wrap a list of equally long tensors in a ``TensorDataset``.

        Raises:
            RuntimeError: if ``dataset`` is None, i.e. the split was never loaded.
        """
        if dataset is None:
            raise RuntimeError(
                'Split not loaded; call setup() with the matching stage first.')
        return torch.utils.data.TensorDataset(*dataset)

    def _split_files(self, name):
        """Return the file paths inside split directory ``name``, sorted by name."""
        split_dir = os.path.join(self.hparams['input_dir'], name)
        # os.listdir order is filesystem-dependent; sorting keeps data order reproducible.
        return [os.path.join(split_dir, f) for f in sorted(os.listdir(split_dir))]

    def _load_split(self, name):
        """Read all files of split ``name`` and return the transformed tensor list."""
        arrays = read_data(
            self._split_files(name),
            max_etot=self.hparams.get('max_etot'),
            prog_bar=True,
        )
        return self.transform([torch.tensor(a) for a in arrays])

    def setup(self, stage='fit'):
        """Load and transform the splits required by ``stage``.

        ``'fit'`` loads train and val, ``'validate'`` loads val, ``'test'``
        loads test, and ``None`` (Lightning types ``stage`` as optional)
        loads all three. Any other stage loads nothing.
        """
        logging.info('Reading data')

        if stage in (None, 'fit'):
            self.trainset = self._load_split('train')
        if stage in (None, 'fit', 'validate'):
            self.valset = self._load_split('val')
        if stage in (None, 'test'):
            self.testset = self._load_split('test')

    def train_dataloader(self):
        return DataLoader(
            self.convert_dataset(self.trainset),
            batch_size=self.hparams['batch_size'],
            # Reshuffle each epoch so batches are not tied to on-disk file order.
            shuffle=self.hparams.get('shuffle_train', True))

    def val_dataloader(self):
        return DataLoader(
            self.convert_dataset(self.valset),
            batch_size=self.hparams['batch_size'])

    def test_dataloader(self):
        return DataLoader(
            self.convert_dataset(self.testset),
            batch_size=self.hparams['batch_size'])

    def transform(self, dataset):
        """Normalise one split given as ``[etot, m, outcome]`` tensors.

        Keeps the first ``n_particle`` entries along dim 1 of ``outcome``,
        flattens them to one row per sample, and divides by ``etot`` (using
        the un-rescaled ``etot``). Then divides ``etot`` and ``m`` by 1000.

        Assumes ``outcome`` is (N, particles, features) and ``etot`` is (N, 1)
        so the division broadcasts per row; TODO confirm against read_data.

        Out-of-place ops are used so the caller's tensors (the sliced
        ``outcome`` may be a view) are not mutated. They also let integer
        inputs promote to floating point instead of raising.
        """
        etot, m, outcome = dataset
        outcome = outcome[:, :self.hparams['n_particle']].reshape((outcome.shape[0], -1))
        outcome = outcome / etot
        # NOTE(review): presumably a MeV -> GeV conversion; confirm units.
        etot = etot / 1000.
        m = m / 1000.
        return [etot, m, outcome]

In [3]:
# Build the project's CartesianDataModule (imported from data.datamodule, not the
# MCDataModule class defined in this notebook) from the shared hparams and call
# setup() with its default stage.
# NOTE(review): the two progress bars in the output suggest setup() reads two
# splits (presumably train and val) — confirm against data/datamodule.py.
data_module = CartesianDataModule(hparams)
data_module.setup()

100%|██████████| 68/68 [00:40<00:00,  1.67it/s]
100%|██████████| 33/33 [00:21<00:00,  1.56it/s]


In [5]:
# Fetch a single training batch to sanity-check tensor shapes, dtypes and scaling.
# next(iter(...)) is the idiomatic replacement for a for/print/break loop; the
# bare last expression lets Jupyter render the batch.
first_batch = next(iter(data_module.train_dataloader()))
first_batch

[tensor([[20.0000],
        [70.0000],
        [50.0000],
        ...,
        [50.0000],
        [20.0000],
        [60.0000]], dtype=torch.float64), tensor([[1.0778],
        [1.0778],
        [1.0779],
        ...,
        [1.0778],
        [1.0779],
        [1.0778]], dtype=torch.float64), tensor([[ 5.0108e-01, -8.4955e-03,  5.0653e-03, -4.9878e-01],
        [ 5.0009e-01, -1.6370e-04,  2.3869e-03, -4.9990e-01],
        [ 5.0017e-01,  8.6553e-04, -2.2275e-03, -4.9981e-01],
        ...,
        [ 5.0017e-01, -3.1510e-03, -1.7474e-03, -4.9981e-01],
        [ 5.0108e-01, -1.7072e-02,  2.4805e-03, -4.9858e-01],
        [ 5.0012e-01, -1.5018e-03,  4.1806e-03, -4.9986e-01]],
       dtype=torch.float64)]
