In [1]:
from tempfile import NamedTemporaryFile
import numpy as np
import uproot
import awkward as ak
import torch
from torch.utils.data import DataLoader
from saja.dataset import JetPartonAssignmentDataset

# Seed every RNG this notebook draws from (numpy for the toy data, torch for
# good measure) so that "Restart & Run All" reproduces the outputs below.
SEED = 1337
np.random.seed(SEED)
torch.manual_seed(SEED)

# Generate toy dataset

In [2]:
# Toy-dataset configuration.
num_events = 1_000  # number of events (tree entries) to generate

# Per-event jet multiplicity: drawn from Poisson(mode_num_jets) and clipped
# to the closed range [min_num_jets, max_num_jets].
min_num_jets, mode_num_jets, max_num_jets = 6, 8, 16

# Layout of the ROOT tree that is written below and read back by the dataset.
treepath = 'tree'
data_branches = ['pt', 'eta', 'phi', 'mass']  # per-jet input features
target_branch = 'target'  # per-jet jet-parton matching label

In [3]:
# Scratch binary file backing the toy ROOT tree. With the default delete=True
# it is removed as soon as it is closed, so keep this object open for the
# whole notebook. NOTE: on Windows an open NamedTemporaryFile generally cannot
# be reopened by name, which the dataset below relies on.
file = NamedTemporaryFile(mode='w+b')

In [4]:
# Write a toy ROOT tree into the temporary file: one entry per event, each
# event holding a variable number of jets.
root_file = uproot.writing.recreate(file)

# Jet multiplicity per event: Poisson(mode_num_jets), clipped to
# [min_num_jets, max_num_jets].
num_jets = np.random.poisson(lam=mode_num_jets, size=(num_events, )).clip(min_num_jets, max_num_jets)

# Fake jets: for every data branch, one standard-normal value per jet.
# Branch names and jet counts get distinct loop variables for readability;
# the draw order (branch-major, then event) is unchanged.
branches = {
    branch: ak.Array([np.random.randn(n_jets) for n_jets in num_jets])
    for branch in data_branches
}
# Jet-parton matching labels: a random integer in [0, 3] for every jet.
branches[target_branch] = ak.Array([np.random.randint(0, 4, n_jets) for n_jets in num_jets])

root_file[treepath] = branches
del branches

root_file[treepath].show()
# Do not close `root_file`: closing it also closes `file`, which (with the
# default delete=True) deletes the temporary file the dataset reads back by name.

name                 | typename                 | interpretation                
---------------------+--------------------------+-------------------------------
npt                  | int32_t                  | AsDtype('>i4')
pt                   | double[]                 | AsJagged(AsDtype('>f8'))
neta                 | int32_t                  | AsDtype('>i4')
eta                  | double[]                 | AsJagged(AsDtype('>f8'))
nphi                 | int32_t                  | AsDtype('>i4')
phi                  | double[]                 | AsJagged(AsDtype('>f8'))
nmass                | int32_t                  | AsDtype('>i4')
mass                 | double[]                 | AsJagged(AsDtype('>f8'))
ntarget              | int32_t                  | AsDtype('>i4')
target               | int64_t[]                | AsJagged(AsDtype('>i8'))


# Read toy dataset

In [5]:
# Read the toy tree back; every event becomes one (jets, target) example.
dataset = JetPartonAssignmentDataset(
    path=file.name,               # reopen the temporary file by its name
    treepath=treepath,
    target_branch=target_branch,  # per-jet labels
    data_branches=data_branches,  # per-jet input features
)

Total = 1000, Processed: 1000 (100.00 %): : 0it [00:00, ?it/s]


The dataset holds a list of `(jets, target)` tuples, one per event. Events can contain different numbers of jets, so the first dimension of each tensor varies from event to event.

In [6]:
# Peek at the dataset's internal example list (a private attribute, used here
# for illustration only): report its size, then display the first three events.
num_examples = len(dataset._examples)
print(f'len(dataset._examples)={num_examples}')
dataset._examples[:3]

len(dataset._examples)=1000


[(tensor([[-0.6776,  0.5942,  0.7488,  0.0574],
          [-1.4115,  0.1182, -0.3210, -0.1245],
          [ 0.7120, -0.6514, -0.5886,  0.3052],
          [-1.2425,  0.9872,  1.7782,  0.4447],
          [ 2.0905,  0.2118,  1.3702, -1.2335],
          [ 0.2330,  0.4097, -0.3645,  0.3047]]),
  tensor([2, 0, 3, 3, 1, 1])),
 (tensor([[-0.0620, -1.4998,  3.8013, -0.3047],
          [-1.3476,  0.2066,  1.3505, -1.9128],
          [ 1.7316,  0.0252, -0.2951,  0.5309],
          [ 0.0986,  0.4875, -1.2266, -0.4525],
          [-0.4580, -1.1613, -0.8565, -0.0262],
          [-0.1684, -0.1613,  0.9772,  1.9362],
          [ 0.2739,  0.0898, -0.3537,  0.0393],
          [-0.4079,  0.4976, -0.0230, -0.2425],
          [-0.2222, -0.5419, -0.6185, -1.7311],
          [-0.0950, -0.2303, -0.3976,  2.2172]]),
  tensor([2, 2, 0, 1, 2, 3, 2, 0, 0, 1])),
 (tensor([[ 0.3444, -0.2417, -1.5746, -1.0111],
          [-0.4233,  0.7770, -0.4817, -0.7660],
          [ 0.0049, -0.9113, -0.6292,  0.1322],
          

In [7]:
# Inspect the first few events one by one. Each item is a (jets, target) pair
# whose first dimension (number of jets) differs from event to event.
# `jets` is used instead of `input` so the built-in `input` is not shadowed
# for every later cell in this kernel.
for idx in range(3):
    jets, target = dataset[idx]
    print(f'{idx=}: {jets.shape=}, {target.shape=}')
    print(f'{jets=}')
    print(f'{target=}')
    print()

idx=0: input.shape=torch.Size([6, 4]), target.shape=torch.Size([6])
input=tensor([[-0.6776,  0.5942,  0.7488,  0.0574],
        [-1.4115,  0.1182, -0.3210, -0.1245],
        [ 0.7120, -0.6514, -0.5886,  0.3052],
        [-1.2425,  0.9872,  1.7782,  0.4447],
        [ 2.0905,  0.2118,  1.3702, -1.2335],
        [ 0.2330,  0.4097, -0.3645,  0.3047]])
target=tensor([2, 0, 3, 3, 1, 1])

idx=1: input.shape=torch.Size([10, 4]), target.shape=torch.Size([10])
input=tensor([[-0.0620, -1.4998,  3.8013, -0.3047],
        [-1.3476,  0.2066,  1.3505, -1.9128],
        [ 1.7316,  0.0252, -0.2951,  0.5309],
        [ 0.0986,  0.4875, -1.2266, -0.4525],
        [-0.4580, -1.1613, -0.8565, -0.0262],
        [-0.1684, -0.1613,  0.9772,  1.9362],
        [ 0.2739,  0.0898, -0.3537,  0.0393],
        [-0.4079,  0.4976, -0.0230, -0.2425],
        [-0.2222, -0.5419, -0.6185, -1.7311],
        [-0.0950, -0.2303, -0.3976,  2.2172]])
target=tensor([2, 2, 0, 1, 2, 3, 2, 0, 0, 1])

idx=2: input.shape=torch.Size(

# Collate events with different numbers of jets

Since PyTorch's default collate function cannot stack tensors of different sizes, we need to use the `collate` classmethod of `JetPartonAssignmentDataset`. `JetPartonAssignmentDataset.collate` takes a list of events and pads them with zeros so that they all have the same length. `collate` also creates a mask indicating which rows are real jets and which are zero padding.

In [8]:
# Collate the first ten events into a single zero-padded batch.
first_events = [dataset[i] for i in range(10)]
batch = dataset.collate(first_events)

In [9]:
# For the first three events of the batch, show the padded jets, the padded
# targets, and the mask that separates real jets from zero padding.
for idx in range(3):
    data_row, target_row, mask_row = batch.data[idx], batch.target[idx], batch.mask[idx]
    print(f'* {idx=}')
    print(f'    * data (jets): {data_row}')
    print(f'    * target (jet-parton matching): {target_row}')
    print(f'    * mask: {mask_row}')
    print()

* idx=0
    * data (jets): tensor([[-0.6776,  0.5942,  0.7488,  0.0574],
        [-1.4115,  0.1182, -0.3210, -0.1245],
        [ 0.7120, -0.6514, -0.5886,  0.3052],
        [-1.2425,  0.9872,  1.7782,  0.4447],
        [ 2.0905,  0.2118,  1.3702, -1.2335],
        [ 0.2330,  0.4097, -0.3645,  0.3047],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])
    * target (jet-parton matching): tensor([2, 0, 3, 3, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    * mask: tensor([ True,  True,  True,  True,  True,  True, False, False, False, False,
        False, False, False, False, False])

* idx=1
    * data (jets): tensor([[-0.0620, -1.4998,  3.

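We can also pass `collate` to a `DataLoader` as its `collate_fn`. Each batch is padded independently, so the padded length (the number of jets, i.e. the second dimension) differs from batch to batch.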

In [10]:
# `collate` can be used directly as a DataLoader's `collate_fn`; the padded
# length (second dimension of `data`) varies from batch to batch.
# A dedicated loop variable is used so this cell does not overwrite `batch`
# from the cells above (re-running them would otherwise show a different batch).
num_batches_to_show = 12
data_loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate)
# zip() stops after `num_batches_to_show` batches without fetching an extra one.
for _, loader_batch in zip(range(num_batches_to_show), data_loader):
    print(f'{loader_batch.data.shape=}')

batch.data.shape=torch.Size([32, 15, 4])
batch.data.shape=torch.Size([32, 13, 4])
batch.data.shape=torch.Size([32, 16, 4])
batch.data.shape=torch.Size([32, 16, 4])
batch.data.shape=torch.Size([32, 16, 4])
batch.data.shape=torch.Size([32, 12, 4])
batch.data.shape=torch.Size([32, 14, 4])
batch.data.shape=torch.Size([32, 12, 4])
batch.data.shape=torch.Size([32, 14, 4])
batch.data.shape=torch.Size([32, 12, 4])
batch.data.shape=torch.Size([32, 16, 4])
batch.data.shape=torch.Size([32, 13, 4])
