# In [13]:
from typing import List, Union, Optional

import torch
from torch import Tensor
import numpy as np

import logging
from os.path import exists


# TODO: allow for loading all jet types ("top" and "qcd") together


class JetNet(torch.utils.data.Dataset):
    """
    PyTorch ``torch.utils.data.Dataset`` class for a jet dataset stored on disk as
    ``{data_dir}/{jet_type}200.pt`` (converted from ``{jet_type}200.hdf5``).

    Features, in order: ``[eta, phi, pt, mask]``.

    Indexing yields either the particle features of one jet, of shape
    ``[num_particles, num_features]``, or, when jet-level features are enabled, a tuple of the
    particle features and the jet-level features of that jet. Currently the only jet feature is
    the per-jet sum of the mask feature divided by the number of particle slots per jet
    (i.e. the fraction of real particles, assuming a 0/1 mask).

    If the pt file is not found in ``data_dir`` (or ``download=True``), the hdf5 file
    ``{jet_type}200.hdf5`` is fetched from Zenodo record 2603256 (unless already present and
    ``download`` is False) and converted to pt.
    NOTE(review): confirm that the Zenodo record actually contains files named this way.

    Args:
        jet_type (str): ``'top'`` or ``'qcd'``.
        data_dir (str): directory which contains (or in which to download) the dataset.
          Defaults to "./" i.e. the working directory.
        download (bool): re-download the hdf5 file and re-convert it to pt, even if the files
          exist already. Defaults to False.
        num_particles (int): number of particle slots per jet, including padded ones. If
          ``num_particles - num_pad_particles`` is <= 0 or >= the number stored per jet, all
          stored particles are kept. Defaults to 30.
        normalize (bool): stored as ``self.normalize``. NOTE(review): no normalization is
          currently applied by this class. Defaults to True.
        feature_norms (List[float]): stored as ``self.feature_norms``; intended max value to
          scale each feature to. Defaults to ``[1.0, 1.0, 1.0, 1.0]``.
        feature_shifts (List[float]): stored as ``self.feature_shifts``; intended value to shift
          each feature by after scaling. Defaults to ``[0.0, 0.0, -0.5, -0.5]``.
        use_mask (bool): keep the mask feature (the last feature). Defaults to True.
        num_pad_particles (int): how many out of ``num_particles`` should be zero-padded.
          Defaults to 0.
        use_num_particles_jet_feature (bool): store the (normalized) number of particles in each
          jet as a jet-level feature. *Only works if using mask* i.e. if ``use_mask=True``.
          Defaults to True.
        noise_padding (bool): stored as ``self.noise_padding``. NOTE(review): noise padding is
          not implemented; padding is always with 0s. Defaults to False.
        train (bool): keep the first ``train_fraction`` of jets (True) or the remainder (False).
          Defaults to True.
        train_fraction (float): fraction of jets used for training - rest is for testing.
          Defaults to 0.7.

    Raises:
        ValueError: if ``jet_type`` is not one of the supported jet types.
    """

    _jet_types = ["top", "qcd"]
    _num_non_mask_features = 3

    # normalization used for ParticleNet training
    _fpnd_feature_maxes = [1.6211985349655151, 0.520724892616272, 0.8934717178344727, 1.0]
    _fpnd_feature_norms = 1.0
    _fpnd_feature_shifts = [0.0, 0.0, -0.5, 0.0]

    def __init__(
        self,
        jet_type: str,
        data_dir: str = "./",
        download: bool = False,
        num_particles: int = 30,
        normalize: bool = True,
        feature_norms: Optional[List[float]] = None,
        feature_shifts: Optional[List[float]] = None,
        use_mask: bool = True,
        num_pad_particles: int = 0,
        use_num_particles_jet_feature: bool = True,
        noise_padding: bool = False,
        train: bool = True,
        train_fraction: float = 0.7,
    ):
        # explicit check instead of ``assert`` so validation survives ``python -O``
        if jet_type not in self._jet_types:
            raise ValueError(f"Invalid jet type {jet_type!r}; expected one of {self._jet_types}")

        # fresh lists per instance (avoids shared mutable default arguments)
        self.feature_norms = [1.0, 1.0, 1.0, 1.0] if feature_norms is None else feature_norms
        self.feature_shifts = [0.0, 0.0, -0.5, -0.5] if feature_shifts is None else feature_shifts
        self.normalize = normalize  # NOTE(review): not applied anywhere yet
        self.noise_padding = noise_padding  # NOTE(review): not implemented yet
        self.use_mask = use_mask
        # in the future there'll be more jet features such as jet pT and eta
        self.use_jet_features = use_num_particles_jet_feature and self.use_mask

        pt_file = f"{data_dir}/{jet_type}200.pt"

        if download or not exists(pt_file):
            self.download_and_convert_to_pt(data_dir, jet_type, force=download)

        logging.info("Loading dataset")
        dataset = self.load_dataset(pt_file, num_particles, num_pad_particles, use_mask)
        # actual number of particle slots per jet after truncation/padding; this is what the
        # jet-level feature is normalized by
        self.num_particles = dataset.shape[1]

        jet_features = (
            self.get_jet_features(dataset, use_num_particles_jet_feature)
            if self.use_jet_features
            else None
        )

        logging.info(f"Loaded dataset {dataset.shape = }")

        # deterministic split: first ``train_fraction`` of jets for training, rest for testing
        tcut = int(len(dataset) * train_fraction)
        split = slice(None, tcut) if train else slice(tcut, None)

        self.data = dataset[split]
        self.jet_features = jet_features[split] if jet_features is not None else None

        logging.info("Dataset processed")

    def download_and_convert_to_pt(self, data_dir: str, jet_type: str, force: bool = False):
        """
        Download the jet dataset (if needed) and convert and save it as a pytorch tensor.

        Args:
            data_dir (str): directory in which to save the files; created if missing.
            jet_type (str): jet type to download, out of ``['top', 'qcd']``.
            force (bool): download the hdf5 file even if it already exists. Defaults to False.

        """
        import os

        if data_dir:
            os.makedirs(data_dir, exist_ok=True)
        hdf5_file = f"{data_dir}/{jet_type}200.hdf5"

        if force or not exists(hdf5_file):
            logging.info(f"Downloading {jet_type} jets hdf5")
            self.download(jet_type, hdf5_file)

        logging.info(f"Converting {jet_type} jets hdf5 to pt")
        self.hdf5_to_pt(data_dir, jet_type, hdf5_file)

    def download(self, jet_type: str, hdf5_file: str):
        """
        Downloads the ``{jet_type}200.hdf5`` file from Zenodo record 2603256 and saves it as
        ``hdf5_file``, printing a progress bar when the content length is known.

        The data is first written to ``{hdf5_file}.part`` and only renamed to ``hdf5_file``
        once complete, so an interrupted download never leaves a file that looks valid.

        Args:
            jet_type (str): jet type to download, out of ``['top', 'qcd']``.
            hdf5_file (str): path to save hdf5 file.

        Raises:
            requests.HTTPError: if a request returns an error status.
            ValueError: if the record does not contain the expected file.

        """
        import os
        import sys

        import requests

        record_id = 2603256
        records_url = f"https://zenodo.org/api/records/{record_id}"
        record = requests.get(records_url, timeout=60)
        record.raise_for_status()
        key = f"{jet_type}200.hdf5"

        # finding the url for the particular jet type dataset
        file_entry = next((item for item in record.json()["files"] if item["key"] == key), None)
        if file_entry is None:
            raise ValueError(f"File {key!r} not found in Zenodo record {record_id}")
        file_url = file_entry["links"]["self"]
        logging.info(f"{file_url = }")

        tmp_file = f"{hdf5_file}.part"

        # modified from https://sumit-ghosh.com/articles/python-download-progress-bar/
        with requests.get(file_url, stream=True, timeout=60) as response:
            response.raise_for_status()
            total = response.headers.get("content-length")

            with open(tmp_file, "wb") as f:
                if total is None:
                    f.write(response.content)
                else:
                    downloaded = 0
                    total = int(total)

                    print("Downloading dataset")
                    for data in response.iter_content(
                        chunk_size=max(int(total / 1000), 1024 * 1024)
                    ):
                        downloaded += len(data)
                        f.write(data)
                        done = int(50 * downloaded / total)
                        sys.stdout.write(
                            "\r[{}{}] {:.0f}%".format(
                                "█" * done, "." * (50 - done), float(downloaded / total) * 100
                            )
                        )
                        sys.stdout.flush()

        sys.stdout.write("\n")
        # only expose the file under its final name once fully written
        os.replace(tmp_file, hdf5_file)

    def hdf5_to_pt(self, data_dir: str, jet_type: str, hdf5_file: str):
        """
        Converts the ``particle_features`` dataset of the hdf5 file to a float tensor and saves
        it as ``{data_dir}/{jet_type}200.pt``.

        Args:
            data_dir (str): directory in which to save file.
            jet_type (str): jet type, out of ``['top', 'qcd']``.
            hdf5_file (str): path to hdf5 file.

        """
        import h5py

        pt_file = f"{data_dir}/{jet_type}200.pt"

        with h5py.File(hdf5_file, "r") as f:
            torch.save(Tensor(np.array(f["particle_features"])), pt_file)

    def load_dataset(
        self, pt_file: str, num_particles: int, num_pad_particles: int = 0, use_mask: bool = True
    ) -> Tensor:
        """
        Load the dataset, optionally truncating and zero-padding the particles.

        Args:
            pt_file (str): path to dataset .pt file.
            num_particles (int): number of particles per jet to load, including padded ones.
              Truncation only happens if ``0 < num_particles - num_pad_particles`` is less than
              the number of particles stored per jet.
            num_pad_particles (int): out of ``num_particles`` how many are to be zero-padded.
              Defaults to 0.
            use_mask (bool): keep or remove the mask feature. Defaults to True.

        Returns:
            Tensor: float dataset tensor of shape ``[num_jets, num_particles, num_features]``.

        """
        dataset = torch.load(pt_file).float()

        # only retain up to ``num_particles``,
        # subtracting ``num_pad_particles`` since they will be padded below
        num_kept = num_particles - num_pad_particles
        if 0 < num_kept < dataset.shape[1]:
            dataset = dataset[:, :num_kept, :]

        # pad with ``num_pad_particles`` zero particles along the particle dimension
        if num_pad_particles > 0:
            dataset = torch.nn.functional.pad(dataset, (0, 0, 0, num_pad_particles), "constant", 0)

        if not use_mask:
            # remove mask feature from dataset if not needed
            dataset = dataset[:, :, : self._num_non_mask_features]

        return dataset

    def get_jet_features(self, dataset: Tensor, use_num_particles_jet_feature: bool) -> Tensor:
        """
        Returns jet-level features. `Will be expanded to include jet pT and eta.`

        Args:
            dataset (Tensor): dataset tensor of shape [N, num_particles, num_features],
              where the last feature is the mask.
            use_num_particles_jet_feature (bool): `Currently does nothing,
              in the future such bools will specify which jet features to use`.

        Returns:
            Tensor: jet features tensor of shape [N, 1]: the per-jet sum of the mask divided by
            ``self.num_particles``.

        """
        jet_num_particles = (torch.sum(dataset[:, :, -1], dim=1) / self.num_particles).unsqueeze(1)
        logging.debug(f"{jet_num_particles.shape = }")
        return jet_num_particles

    def __len__(self):
        """Number of jets in this split."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return the particle features of jet ``idx``, paired with its jet-level features when
        ``self.use_jet_features`` is True.
        """
        if self.use_jet_features:
            return self.data[idx], self.jet_features[idx]
        return self.data[idx]
