# Section 4: Advanced Transformers in Practice

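In the previous sections we built a basic Transformer step by step from PyTorch's built-in Transformer modules and applied it to a classification task in high-energy physics (HEP) analysis. We now explore how to design a better Transformer around the format of HEP data and reach stronger performance.

This leads us to the Particle Transformer (ParT, [H. Qu *et al*. ICML 2022](https://arxiv.org/abs/2202.03772)), currently one of the more mature Transformer architectures in HEP for particle-format data. It was also the best-performing jet-tagging model before 2025.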

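This section has two parts:

 - First, we explain why adding pairwise particle masses helps (this is the essence of ParT), and modify the last example of the previous section to implement a simplified version of ParT. This mainly serves as a teaching example.
 - Second, we provide a code sample that calls the ParT model directly (as well as the earlier and better-known ParticleNet and ParticleNeXt models, all available [in this folder](https://github.com/hqucms/weaver-core/tree/main/weaver/nn/model)). Based on feedback collected earlier, this may help when you actually use these models later: we hope it can serve as a template that is easy to adapt to your own task.
 <!-- - At the end of this chapter, we add a more complex task with O(50-100) tokens: jet tagging. We will use the earlier Top Landscape dataset ([G. Kasieczka *et al*. SciPost Phys. 7, 014 (2019)](https://arxiv.org/abs/1902.09914)) as an example and train with the ParticleNet / ParticleNeXt interface built in the previous step. This can also serve as a template for complex tasks and help you use these models later. -->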

**If you are running in Google Colab, first run the command below to install the required Python packages.**

In [None]:
! pip install uproot==4.2.2 pandas==1.3.4 mplhep==0.3.12 weaver-core
! pip install pytorch_lightning

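## From a “plain Transformer” to the Particle Transformer (ParT)

In the previous section we implemented a “plain” Transformer:

 - Each physics object (jet, lepton, and MET) is treated as a token. The tokens “communicate” with each other through the attention mechanism in the Transformer blocks, exchange information, and update their latent representations through the FFN.
 - At the start we define a dedicated trainable class token that interacts with the real tokens. For classification, its final latent vector is used to produce the logits.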

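Our tasks so far are simple classification problems, which cannot show the strong scaling ability of Transformers as the model size and the dataset size grow. In engineering practice, however, when the inputs and the classification problem become extremely complex, the current Transformer still leaves room for improvement. How can we keep modifying the network design to make it “smarter”? One piece of engineering experience is to recognize the intrinsic properties of the data and find a way to build them into the network design.

Here we connect an intrinsic property of the data, symmetry, with inductive bias.

Broadly speaking, a symmetry means that when the data undergo a certain transformation, their properties stay approximately the same. For example (see the figure below):

 - Moving an element of an image from the left to the right leaves the nature of the image approximately unchanged.
 - Swapping the records of any two particles in particle-based data leaves the data strictly unchanged (that is, the particles in the data are intrinsically unordered).
 - Applying a Lorentz transformation to the particle-based data as a whole also leaves its properties approximately unchanged.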

<img src="figures/inherit-symm.png" alt="image" width=600/>

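We can therefore try to teach these intrinsic features of the data to the network as “knowledge”. One approach is to constrain the network by design (give it an inductive bias), so that some or all of its neuron outputs stay unchanged when the data undergo these transformations. This makes the network more robust to such transformations: roughly speaking, the network “knows” that the data have this intrinsic property and remain approximately unchanged under the transformation.

From this engineering viewpoint we can explain why CNNs had such an advantage in early computer vision, and why particle physics moved from CNN/RNN networks to GNNs that are invariant under particle permutation in order to gain performance.

For the last symmetry, we can design networks that enhance Lorentz symmetry to improve performance (this is discussed in [CL *et al*. PRD. 109, 056003 (2024)](https://arxiv.org/abs/2208.07814)). The simplest method is to add many Lorentz scalars to the network inputs: since these scalars are Lorentz invariant, they improve the robustness of the network. For data containing $N$ particle records, one can construct the invariant masses of the $N^2$ particle pairs, which form a basis of all Lorentz scalars that can be constructed. These $N^2$ initial variables can then be embedded (giving a tensor of dimension $(N,N,d)$) and added to the attention scores as an offset, which in turn changes the attention weights.

The above can be regarded as the basic design idea of the Particle Transformer, and it explains where its performance advantage comes from. (Besides the invariant mass, ParT takes three additional pairwise features as input, which are also invariant under certain transformations; ParT also includes many engineering optimizations, all of which help its performance.)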
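
The idea can be written compactly as follows. This is a sketch in notation introduced here for illustration ($p_i = (E_i, \mathbf{p}_i)$ for the 4-momentum of token $i$, $f_\theta$ for a small trainable embedding network, $h$ for the attention-head index, $d_k$ for the per-head dimension); the toy implementation later in this section feeds $\ln(m_{ij}^2 + 1)$ into the embedding rather than $\ln m_{ij}^2$.

$$
m_{ij}^2 = (E_i + E_j)^2 - \lVert \mathbf{p}_i + \mathbf{p}_j \rVert^2, \qquad
U^{(h)}_{ij} = \left[ f_\theta\!\left(\ln m_{ij}^2\right) \right]_h, \qquad
\mathrm{Attn}^{(h)}(Q, K, V) = \mathrm{softmax}\!\left( \frac{Q K^{\top}}{\sqrt{d_k}} + U^{(h)} \right) V
$$

Because $U^{(h)}$ enters before the softmax, a large positive $U^{(h)}_{ij}$ raises the weight that token $i$ puts on token $j$, and a large negative value suppresses it.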

<img src="figures/part-pairwise-feat.png" alt="image" width=600/>

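We make a simple modification to the Transformer from the end of the previous section, doing only three extra things:

 - Compute the invariant masses of the $N^2$ particle pairs.
 - Use an additional embedder to map them into a latent space of dimension $d_{\rm pair}=4$, giving a tensor of shape $(N,N,d_{\rm pair}=4)$.
 - Inject this tensor into the attention computation of every Transformer block by adding it to the attention scores of the 4 heads (each score is an $(N,N)$ matrix).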
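
The first two steps can be sketched on a toy input as below. This is a minimal illustration, not the implementation used later in this section: the helper name `pairwise_m2` and the single `nn.Linear` embedder are assumptions made for this example.

```python
import torch

def pairwise_m2(p4: torch.Tensor) -> torch.Tensor:
    """Squared invariant mass of every particle pair (hypothetical helper).

    p4: (..., N, 4) tensor ordered as (px, py, pz, E).
    Returns a (..., N, N) tensor with m2[i, j] = (p_i + p_j)^2.
    """
    # Broadcasting (..., N, 1, 4) + (..., 1, N, 4) yields all N x N pair sums.
    pair_sum = p4.unsqueeze(-2) + p4.unsqueeze(-3)
    energy2 = pair_sum[..., 3].square()
    momentum2 = pair_sum[..., :3].square().sum(dim=-1)
    return energy2 - momentum2

# Two massless particles flying back to back along z, each with E = 50.
p4 = torch.tensor([[0.0, 0.0, 50.0, 50.0],
                   [0.0, 0.0, -50.0, 50.0]])
m2 = pairwise_m2(p4)
print(m2)  # off-diagonal entries are 100^2, diagonal entries are 0

# Illustrative embedder: one bias value per attention head (4 heads here).
embed = torch.nn.Linear(1, 4)
bias = embed(torch.log1p(m2).unsqueeze(-1)).permute(2, 0, 1)
print(bias.shape)  # (num_heads, N, N)
```

The diagonal entry $(i,i)$ is the squared mass of $2p_i$, so it vanishes for massless particles; the off-diagonal entry is the squared mass of the two-particle system.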


**If you are running in Google Colab, first run the command below to download the required datasets.**

In [None]:
! if [[ "$(hostname)" != *pku.edu.cn* && "$(hostname)" != *lxlogin* ]]; then \
    wget https://coli.web.cern.ch/coli/share/cmschina/ML/dihiggs_ntuples/hh2b2w.root; \
    wget https://coli.web.cern.ch/coli/share/cmschina/ML/dihiggs_ntuples/ttbar.root; \
fi

Load the di-Higgs vs ttbar data from the previous section and standardize the features:

In [1]:
# Load data

import os
# Determine the base directory and the remote git path
hostname = os.uname()[1]
if 'pku.edu.cn' in hostname: # on PKU cluster
    basedir = '/data/pubfs/pku_visitor/public_write/ML/dihiggs_ntuples/'
elif hostname.startswith('lxlogin'): # on IHEP lxlogin
    basedir = '/scratchfs/cms/licq/cmschina/ML/dihiggs_ntuples/'
else:
    basedir = '.'

import uproot
import pandas as pd

dihiggs = uproot.concatenate(f"{basedir}/hh2b2w.root:tree",library="pd")
ttbar = uproot.concatenate(f"{basedir}/ttbar.root:tree",library="pd")

df_raw = pd.concat([dihiggs, ttbar], axis=0)
df_raw['label'] = df_raw['is_sig'].astype(int) # integer label: 1 for signal, 0 for background
df_raw

| | event | is_sig | is_bkg | bjet1_pt | bjet1_eta | bjet1_phi | bjet1_eratio | bjet1_mass | bjet1_ncharged | bjet1_nneutrals | ... | lep1_charge | lep1_type | lep2_pt | lep2_phi | lep2_eta | lep2_charge | lep2_type | met | met_phi | label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | True | False | 191.402328 | -0.291164 | -0.979666 | 0.289940 | 16.695002 | 12 | 6 | ... | 1 | 0 | 17.508558 | 1.165262 | -0.711694 | -1 | 0 | 233.613846 | 2.300718 | 1 |
| 1 | 1 | True | False | 68.452507 | -1.215415 | 2.387005 | 0.492361 | 15.253864 | 10 | 10 | ... | 1 | 1 | 50.567219 | -1.089323 | 0.409210 | -1 | 1 | 102.682526 | -0.004101 | 1 |
| 2 | 2 | True | False | 152.725433 | 0.449958 | -1.470235 | 1.215402 | 12.211457 | 12 | 8 | ... | 1 | 1 | 43.384056 | 0.584961 | 0.624317 | -1 | 0 | 57.796436 | 0.935978 | 1 |
| 3 | 3 | True | False | 107.360390 | 0.124028 | -1.175005 | 0.000000 | 6.060685 | 6 | 4 | ... | 1 | 1 | 26.664654 | 2.175398 | 0.556468 | -1 | 1 | 65.674774 | 1.417736 | 1 |
| 4 | 4 | True | False | 97.040932 | 2.160315 | -2.940341 | 0.000000 | 10.475123 | 10 | 14 | ... | -1 | 1 | 29.825296 | 0.306789 | 0.310836 | 1 | 0 | 94.994255 | -3.043913 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 149995 | 149995 | False | True | 169.980515 | -0.147777 | 2.985398 | 0.209411 | 27.176857 | 12 | 8 | ... | -1 | 1 | 92.388741 | -0.697610 | 0.483596 | 1 | 1 | 131.839859 | -0.000835 | 0 |
| 149996 | 149996 | False | True | 47.205002 | 1.002193 | -1.613720 | 0.000000 | 9.505911 | 10 | 8 | ... | -1 | 0 | 62.515564 | 0.693190 | 0.890577 | 1 | 0 | 75.465065 | 1.878048 | 0 |
| 149997 | 149997 | False | True | 63.940311 | 0.381228 | 0.494209 | 0.166411 | 8.920477 | 10 | 8 | ... | 1 | 0 | 34.591877 | 2.737549 | -0.180914 | -1 | 1 | 59.409447 | 0.176429 | 0 |
| 149998 | 149998 | False | True | 89.012772 | -0.562279 | -0.638000 | 0.648749 | 13.602261 | 11 | 12 | ... | -1 | 1 | 20.882822 | -3.140386 | -1.086141 | 1 | 0 | 44.600636 | -1.177096 | 0 |


In [2]:
# Initialize the StandardScaler
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()

input_columns = ['bjet1_pt', 'bjet1_eta', 'bjet1_phi',
       'bjet1_eratio', 'bjet1_mass', 'bjet1_ncharged', 'bjet1_nneutrals',
       'bjet2_pt', 'bjet2_eta', 'bjet2_phi', 'bjet2_eratio', 'bjet2_mass',
       'bjet2_ncharged', 'bjet2_nneutrals', 'lep1_pt', 'lep1_phi', 'lep1_eta',
       'lep1_charge', 'lep1_type', 'lep2_pt', 'lep2_phi', 'lep2_eta',
       'lep2_charge', 'lep2_type', 'met', 'met_phi']

# Fit and transform the DataFrame
df = pd.DataFrame(scaler.fit_transform(df_raw[input_columns]), columns=[c + '_trans' for c in input_columns], index=df_raw.index)
df = pd.concat([df, df_raw], axis=1)
df

| | bjet1_pt_trans | bjet1_eta_trans | bjet1_phi_trans | bjet1_eratio_trans | bjet1_mass_trans | bjet1_ncharged_trans | bjet1_nneutrals_trans | bjet2_pt_trans | bjet2_eta_trans | bjet2_phi_trans | ... | lep1_charge | lep1_type | lep2_pt | lep2_phi | lep2_eta | lep2_charge | lep2_type | met | met_phi | label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.566009 | -0.216940 | -0.538931 | -0.142020 | 0.165628 | 0.430743 | -0.663944 | -0.127634 | -1.025722 | -0.312706 | ... | 1 | 0 | 17.508558 | 1.165262 | -0.711694 | -1 | 0 | 233.613846 | 2.300718 | 1 |
| 1 | -0.692736 | -0.913453 | 1.316571 | -0.140433 | 0.032240 | 0.037705 | 0.213418 | -0.177743 | -0.332485 | -1.051495 | ... | 1 | 1 | 50.567219 | -1.089323 | 0.409210 | -1 | 1 | 102.682526 | -0.004101 | 1 |
| 2 | 0.170040 | 0.341568 | -0.809302 | -0.134763 | -0.249359 | 0.430743 | -0.225263 | -0.270861 | 0.921192 | -1.382473 | ... | 1 | 1 | 43.384056 | 0.584961 | 0.624317 | -1 | 0 | 57.796436 | 0.935978 | 1 |
| 3 | -0.294402 | 0.095947 | -0.646589 | -0.144293 | -0.818660 | -0.748372 | -1.102625 | -0.959167 | 1.114436 | -0.866333 | ... | 1 | 1 | 26.664654 | 2.175398 | 0.556468 | -1 | 1 | 65.674774 | 1.417736 | 1 |
| 4 | -0.400051 | 1.630488 | -1.619534 | -0.144293 | -0.410070 | 0.037705 | 1.090781 | -0.542826 | 1.505292 | -0.814446 | ... | -1 | 1 | 29.825296 | 0.306789 | 0.310836 | 1 | 0 | 94.994255 | -3.043913 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 149995 | 0.346695 | -0.108884 | 1.646368 | -0.142651 | 1.135805 | 0.430743 | -0.225263 | 0.927557 | 0.108459 | -1.172372 | ... | -1 | 1 | 92.388741 | -0.697610 | 0.483596 | 1 | 1 | 131.839859 | -0.000835 | 0 |
| 149996 | -0.910265 | 0.757730 | -0.888382 | -0.144293 | -0.499778 | 0.037705 | -0.225263 | -0.592341 | 2.543271 | 0.839254 | ... | -1 | 0 | 62.515564 | 0.693190 | 0.890577 | 1 | 0 | 75.465065 | 1.878048 | 0 |
| 149997 | -0.738931 | 0.289773 | 0.273378 | -0.142988 | -0.553964 | 0.037705 | -0.225263 | -0.640201 | -0.014030 | -0.854679 | ... | 1 | 0 | 34.591877 | 2.737549 | -0.180914 | -1 | 1 | 59.409447 | 0.176429 | 0 |
| 149998 | -0.482243 | -0.421251 | -0.350626 | -0.139206 | -0.120629 | 0.234224 | 0.652100 | -0.176126 | -0.564892 | 0.026914 | ... | -1 | 1 | 20.882822 | -3.140386 | -1.086141 | 1 | 0 | 44.600636 | -1.177096 | 0 |


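We first build a new PyTorch Dataset. When a sample is fetched from the Dataset, we directly obtain 5 tokens (2 b-jets, 2 leptons, 1 MET), together with their features (7/5/2-dimensional respectively, zero-padded to 7 in the stored tensor) and their vectors (4-momenta, all 4-dimensional).

> In addition, the Dataset outputs a `points` variable (the position in the η-φ plane). Defining this new PyTorch Dataset has one more benefit: its output matches the input expected by the standard ParT and ParticleNet model files, so we can reuse it directly in the next subsection. The output format of this Dataset is `(points, features, vectors)`.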

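The Dataset below relies on the `vector` package to turn $(p_T, \eta, \phi, m)$ into Cartesian 4-momenta. As a rough sketch of what that conversion amounts to, the standard formulas can be written in plain NumPy; the helper name `p4_from_ptetaphim_numpy` and the numerical values are made up for this illustration.

```python
import numpy as np

def p4_from_ptetaphim_numpy(pt, eta, phi, mass):
    """Convert (pt, eta, phi, mass) to Cartesian (px, py, pz, E)."""
    px = pt * np.cos(phi)
    py = pt * np.sin(phi)
    pz = pt * np.sinh(eta)
    energy = np.sqrt(px**2 + py**2 + pz**2 + mass**2)
    return np.stack([px, py, pz, energy], axis=-1)

# One event with 5 tokens: bjet1, bjet2, lep1, lep2, MET (illustrative values).
# MET has no measured eta or mass, so both are set to zero.
pt = np.array([[120.0, 80.0, 40.0, 25.0, 60.0]])
eta = np.array([[0.5, -1.2, 0.3, 1.0, 0.0]])
phi = np.array([[0.1, 2.5, -1.0, 3.0, -2.0]])
mass = np.array([[15.0, 10.0, 0.0, 0.0, 0.0]])

vectors = p4_from_ptetaphim_numpy(pt, eta, phi, mass)  # (n_events, 5, 4)
points = np.stack([eta, phi], axis=-1)                 # (n_events, 5, 2)
print(vectors.shape, points.shape)
```
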
In [3]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch import nn
import torchmetrics

def _p4_from_ptetaphim(pt, eta, phi, mass):
    import vector
    vector.register_awkward()
    return vector.zip({'pt': pt, 'eta': eta, 'phi': phi, 'm': mass})

class DataFrameTokenizedDataset(Dataset):
    def __init__(self, dataframe, input_columns, input_columns_raw, target_column):
        print('Initializing DataFrameTokenizedDataset...')
        self.dataframe = dataframe
        self.targets = dataframe[target_column].values

        # With inputs, we define three sets of tokenized data: features, vectors, and points
        inputs = dataframe[input_columns].values
        inputs_raw = dataframe[input_columns_raw].values
        N = len(dataframe)

        # Define tokenized input "features" with dim (N, L=5, d=7)
        features = [
            inputs[:,0:7], # dim (N, 7)
            inputs[:,7:14],
            np.concatenate([inputs[:,14:19], np.zeros([N, 2])], axis=1),
            np.concatenate([inputs[:,19:24], np.zeros([N, 2])], axis=1),
            np.concatenate([inputs[:,24:26], np.zeros([N, 5])], axis=1),
        ]
        self.features = torch.tensor(np.stack(features, axis=1), dtype=torch.float32)
        print('  features:', self.features.shape)

        # Define tokenized "vectors" with dim (N, L=5, d=4), d=4 for (px, py, pz, energy)
        zeros = np.zeros(N, dtype=np.float32)
        pt = np.stack([inputs_raw[:,0], inputs_raw[:,7], inputs_raw[:,14], inputs_raw[:,19], inputs_raw[:,24]], axis=-1) # dim (N, 5)
        eta = np.stack([inputs_raw[:,1], inputs_raw[:,8], inputs_raw[:,16], inputs_raw[:,21], zeros], axis=-1)
        phi = np.stack([inputs_raw[:,2], inputs_raw[:,9], inputs_raw[:,15], inputs_raw[:,20], inputs_raw[:,25]], axis=-1)
        mass = np.stack([inputs_raw[:,4], inputs_raw[:,11], zeros, zeros, zeros], axis=-1)
        p4 = _p4_from_ptetaphim(pt, eta, phi, mass)
        self.vectors = torch.tensor(np.stack([p4.px, p4.py, p4.pz, p4.energy], axis=-1), dtype=torch.float32)
        print('  vectors:', self.vectors.shape)

        # Define tokenized "points" with dim (N, L=5, d=2), d=2 for (eta, phi)
        self.points = torch.tensor(np.stack([eta, phi], axis=-1), dtype=torch.float32)
        print('  points:', self.points.shape)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        points = self.points[idx]
        features = self.features[idx]
        vectors = self.vectors[idx]
        y = torch.tensor(self.targets[idx], dtype=torch.long) # scalar class label
        return (points, features, vectors), y

In [4]:
input_columns = ['bjet1_pt', 'bjet1_eta', 'bjet1_phi',
       'bjet1_eratio', 'bjet1_mass', 'bjet1_ncharged', 'bjet1_nneutrals',
       'bjet2_pt', 'bjet2_eta', 'bjet2_phi', 'bjet2_eratio', 'bjet2_mass',
       'bjet2_ncharged', 'bjet2_nneutrals', 'lep1_pt', 'lep1_phi', 'lep1_eta',
       'lep1_charge', 'lep1_type', 'lep2_pt', 'lep2_phi', 'lep2_eta',
       'lep2_charge', 'lep2_type', 'met', 'met_phi']
target_column = 'label'

# Split the data into training/validation/testing datasets following 80/10/10%
train_df, test_df = train_test_split(df, test_size=0.2)
val_df, test_df = train_test_split(test_df, test_size=0.5)

# Create datasets
train_dataset = DataFrameTokenizedDataset(train_df, [c + '_trans' for c in input_columns], input_columns, target_column)
val_dataset = DataFrameTokenizedDataset(val_df, [c + '_trans' for c in input_columns], input_columns, target_column)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1024)


Initializing DataFrameTokenizedDataset...
  features: torch.Size([240000, 5, 7])
  vectors: torch.Size([240000, 5, 4])
  points: torch.Size([240000, 5, 2])
Initializing DataFrameTokenizedDataset...
  features: torch.Size([30000, 5, 7])
  vectors: torch.Size([30000, 5, 4])
  points: torch.Size([30000, 5, 2])


Based on the `features` output by this Dataset, we can build the token embedder for the new Transformer.

For a batched sample from the Dataset, `features` should have shape `(batch_size, num_input_token=5, input_size=7)`.

In [5]:
class TokenEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embeddings = nn.ModuleDict()
        for token_name, in_dim in zip(['bjet', 'lep', 'met'], [7, 5, 2]):
            # embedding layers for three types of tokens: bjet, lep, met, with input vector dim 7, 5, 2
            self.token_embeddings[token_name] = nn.Sequential(
                nn.Linear(in_dim, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )

        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # add a trainable class token
        self.class_token = nn.Parameter(torch.randn(1, config.hidden_size))

    def forward(self, features):
        # features have size (batch_size, num_input_token=5, input_size=7)
        feat_bjet1, feat_bjet2, feat_lep1, feat_lep2, feat_met = features.unbind(dim=1) # each with dim (batch_size, 7)
        embeddings = []
        embeddings.append(self.class_token.expand(features.size(0), -1)) # first token is the [CLS] token
        embeddings.append(self.token_embeddings['bjet'](feat_bjet1)) # bjet1 token embed by bjet embedding layer
        embeddings.append(self.token_embeddings['bjet'](feat_bjet2)) # bjet2 token embed by bjet embedding layer
        embeddings.append(self.token_embeddings['lep'](feat_lep1[:, :5])) # lep1 token embed by lep embedding layer
        embeddings.append(self.token_embeddings['lep'](feat_lep2[:, :5])) # lep2 token embed by lep embedding layer
        embeddings.append(self.token_embeddings['met'](feat_met[:, :2])) # met token embed by met embedding layer
        embeddings = torch.stack(embeddings, dim=1)
        # now embeddings has dim (batch_size, num_tokens=6, hidden_size)
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

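Next comes our modified Transformer.

 - First, for the basic Transformer block, we use a built-in feature of `nn.MultiheadAttention`: its forward pass accepts an additional `attn_mask` argument of shape `(batch_size * num_heads, num_tokens, num_tokens)`. When it is given as a float tensor, it is added to the attention scores of each of the `num_heads` heads, which is exactly the extra bias we need.

   > See https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html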

     - *`attn_mask`*: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
        $(L, S)$ or $(N\cdot\text{num\_heads}, L, S)$, where $N$  is the batch size,
        $L$ is the target sequence length, and $S$ is the source sequence length. A 2D mask will be
        broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
        Binary and float masks are supported. For a binary mask, a `True` value indicates that the
        corresponding position is not allowed to attend. For a float mask, the mask values will be added to
        the attention weight.

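 - Then, the `get_attention_bias(vectors)` function uses the 4-momentum of each input token to compute the invariant mass of the summed 4-momentum for each of the $N^2$ particle pairs. This relies on broadcasting.
   > Implementation: `to_m2(v.unsqueeze(-2) + v.unsqueeze(-3), eps=1e-8)`

    After embedding, the result is simply passed to the Transformer block as `attn_mask`.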
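
To check that a float `attn_mask` really acts as an additive bias, a small stand-alone experiment can help. The sketch below uses random inputs and arbitrary sizes chosen for illustration; it only relies on the documented `nn.MultiheadAttention` arguments quoted above.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, num_heads, num_tokens, embed_dim = 2, 4, 6, 16

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
mha.eval()
x = torch.randn(batch_size, num_tokens, embed_dim)

# A zero float mask adds nothing to the scores, so the output is unchanged.
zero_bias = torch.zeros(batch_size * num_heads, num_tokens, num_tokens)
out_ref, _ = mha(x, x, x, need_weights=False)
out_zero, _ = mha(x, x, x, attn_mask=zero_bias, need_weights=False)
print(torch.allclose(out_ref, out_zero, atol=1e-5))

# A large negative bias on every key except token 0 pushes essentially all
# attention weight onto token 0, for every query and every head.
bias = torch.full((batch_size * num_heads, num_tokens, num_tokens), -1e4)
bias[:, :, 0] = 0.0
_, weights = mha(x, x, x, attn_mask=bias, need_weights=True)  # averaged over heads
print(weights[0, :, 0])  # weight each query puts on token 0
```

The first dimension of the mask is ordered batch-major (all heads of event 0, then all heads of event 1, and so on), which is why a per-head bias of shape `(batch_size, num_heads, N, N)` can be flattened with a plain `reshape`.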


In [10]:
class BasicTransformerWithAttnBiasBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_hidden_dim, attn_dropout=0.1, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=attn_dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.GELU(),
            nn.Linear(ff_hidden_dim, embed_dim)
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # inputs: 
        #   x: (batch_size, num_tokens, embed_dim)
        #   attn_mask: (batch_size * num_heads, num_tokens, num_tokens)

        # Self-attention
        x = self.norm1(x) # the first layernorm
        attn_output, _ = self.attention(x, x, x, need_weights=False, attn_mask=attn_mask) # input Q, K, V are the same
        x = x + self.dropout(attn_output) # the first residual connection

        # Feed-forward
        x = self.norm2(x) # the second layernorm
        ff_output = self.ff(x)
        x = x + self.dropout(ff_output) # the second residual connection

        return x


def to_m2(x, eps=1e-8):
    m2 = x[..., 3:4].square() - x[..., :3].square().sum(dim=-1, keepdim=True)
    return m2.clamp(min=eps)

class BasicTransformerWithAttnBiasForClassification(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = TokenEmbeddings(config) # the same Embeddings module as before
        self.blocks = nn.ModuleList(
            [BasicTransformerWithAttnBiasBlock(config.hidden_size, config.num_attention_heads, config.intermediate_size, attn_dropout=0., dropout=config.hidden_dropout_prob) 
             for _ in range(config.num_hidden_layers)]
            )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.use_attention_bias = config.use_attention_bias

        if self.use_attention_bias:
            self.pair_embedding = nn.Sequential(
                nn.Linear(1, config.pair_hidden_size),
                nn.GELU(),
                nn.Linear(config.pair_hidden_size, config.num_attention_heads),
            )

    def get_attention_bias(self, vectors):
        # Calculate pairwise masses and embed them, from vectors: (batch_size, num_tokens, 4)

        v = torch.cat([torch.zeros_like(vectors[:, :1]), vectors], dim=1) # add zero vector for the [CLS] token
        # broadcast the vectors to all pairs
        # note: 
        #   v.unsqueeze(-2) has shape (batch_size, num_tokens, 1, 4)
        #   v.unsqueeze(-3) has shape (batch_size, 1, num_tokens, 4)
        lnm2 = torch.log(to_m2(v.unsqueeze(-2) + v.unsqueeze(-3), eps=1e-8) + 1) # dim: (batch_size, num_tokens, num_tokens, 1)

        output = self.pair_embedding(lnm2).permute(0, 3, 1, 2).reshape(-1, v.size(1), v.size(1)) # dim: (batch_size * num_attention_heads, num_tokens, num_tokens)
        return output

    def forward(self, points, features, vectors):
        x = self.embeddings(features)
        attn_mask = self.get_attention_bias(vectors) if self.use_attention_bias else None
        for block in self.blocks:
            x = block(x, attn_mask=attn_mask)
        x = x[:, 0, :] # select hidden state of [CLS] token
        x = self.dropout(x)
        x = self.classifier(x)
        return x

Now let's define the model and see how it performs!

In [11]:
from types import SimpleNamespace
config = SimpleNamespace(
    hidden_size = 16,  # Hidden size for embeddings and Transformer
    num_attention_heads = 4,  # Number of attention heads
    intermediate_size = 32,  # Feed forward intermediate size
    num_hidden_layers = 4,  # Number of Transformer layers
    hidden_dropout_prob = 0.,  # Dropout probability
    num_labels = 2,  # Number of output classes

    # new config for pairwise mass embedding
    use_attention_bias = True,  # Whether to use attention bias
    pair_hidden_size = 8,  # Hidden size for pairwise mass embedding
)

# Define the PyTorch Lightning model
class BasicTransformerWithAttnBias(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.mod = BasicTransformerWithAttnBiasForClassification(config)

        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=config.num_labels)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=config.num_labels)

    def forward(self, *x):
        return self.mod(*x)

    def training_step(self, batch, batch_idx):
        x, y = batch # x: a tuple of (points, features, vectors), y: label
        logits = self(*x)
        loss = F.cross_entropy(logits, y)
        self.train_acc(logits, y)
        self.log('train_loss', loss)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True)
        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'])
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch # x: a tuple of (points, features, vectors), y: label
        logits = self(*x)
        loss = F.cross_entropy(logits, y)
        self.val_acc(logits, y)
        self.log('val_loss', loss)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        
        def lr_lambda(epoch):
            if epoch < 0.7 * self.trainer.max_epochs:
                return 1.0
            else:
                decay_factor = (epoch - 0.7 * self.trainer.max_epochs) / (0.3 * self.trainer.max_epochs)
                return 0.01 ** decay_factor

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return [optimizer], [scheduler]

# Create the model instance
model_transformer_attnbias = BasicTransformerWithAttnBias(config)

# Print the model for inspection
print(model_transformer_attnbias)


BasicTransformerWithAttnBias(
  (mod): BasicTransformerWithAttnBiasForClassification(
    (embeddings): TokenEmbeddings(
      (token_embeddings): ModuleDict(
        (bjet): Sequential(
          (0): Linear(in_features=7, out_features=16, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=16, out_features=16, bias=True)
        )
        (lep): Sequential(
          (0): Linear(in_features=5, out_features=16, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=16, out_features=16, bias=True)
        )
        (met): Sequential(
          (0): Linear(in_features=2, out_features=16, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=16, out_features=16, bias=True)
        )
      )
      (layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (blocks): ModuleList(
      (0-3): 4 x BasicTransformerWithAttnBiasBlock(
    

In [12]:
# Trainer with a logger
trainer = pl.Trainer(max_epochs=50, logger=pl.loggers.TensorBoardLogger('tb_logs', name='simple_transformer_attnbias'))

# Fit the model
trainer.fit(model_transformer_attnbias, train_loader, val_loader)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type                                          | Params | Mode 
------------------------------------------------------------------------------------
0 | mod       | BasicTransformerWithAttnBiasForClassification | 10.1 K | train
1 | train_acc | MulticlassAccuracy                            | 0      | train
2 | val_acc   | MulticlassAccuracy                            | 0      | train
------------------------------------------------------------------------------------
10.1 K    Trainable params
0         Non-trainable params
10.1 K    T

`Trainer.fit` stopped: `max_epochs=50` reached.


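The training metrics in TensorBoard show a significant improvement!
This is the effect of taking the plain Transformer and modifying it only slightly in the direction of the Particle Transformer.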

<img src="figures/tensorboard_output2.png" alt="image" width=900/>

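## Plugging in the ParT / ParticleNet models

With the Dataset defined above, we can plug in the standard ParT and ParticleNet models directly. The original models can be downloaded here:
https://github.com/hqucms/weaver-core/tree/main/weaver/nn/model

**If you are running in Google Colab, first run the command below to download these model files.**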

In [None]:
! if [[ "$(hostname)" != *pku.edu.cn* && "$(hostname)" != *lxlogin* ]]; then \
    wget https://raw.githubusercontent.com/colizz/ml-tutorial/refs/heads/v2025-01-nku/models/ParticleTransformer.py; \
    wget https://raw.githubusercontent.com/colizz/ml-tutorial/refs/heads/v2025-01-nku/models/ParticleNet.py; \
    mkdir -p models && mv {ParticleTransformer.py,ParticleNet.py} models/; \
fi

We can load the model directly from the file and create a model instance.

```python
model_part = ParticleTransformerTaggerMultiGroups(
    input_dims,
    num_classes=num_labels, 
    embed_dims=[32, 32],
    pair_embed_dims=[8, 8],
    num_heads=4,
    num_layers=4,
)
```

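A few remarks:

 - Our tokens are special in that they come in 3 groups of different types (`bjet`, `lep`, `met`), whose `features` have different lengths (7/5/2). We therefore need the ParT variant that splits the initial tokens into several groups and embeds each group separately; the logic is the same as in our implementation above. In the ParT file, we use the additionally provided `ParticleTransformerTaggerMultiGroups` module, which supports an arbitrary number of token groups.

 - `ParticleTransformerTaggerMultiGroups` takes three inputs for each token group: `features`, `vectors`, and `mask`, so 9 arguments are needed in total. In our example every event has a fixed set of 5 tokens, so the masks can be set to all ones. Starting from the output of the Dataset written above, the `process_model_input(self, x)` function produces the inputs that the ParT model expects.

 - The ParT we initialize is very small because our task is not complex. You can try changing the initialization parameters above and observe the effect on training.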
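
The regrouping into nine inputs can be previewed on dummy tensors. The sketch below mirrors the slicing done by `process_model_input` in this section, but uses random values and a loop over a hypothetical `groups` list, so it runs without the model files.

```python
import torch

batch_size = 3
features = torch.randn(batch_size, 5, 7)  # (batch, tokens, zero-padded features)
vectors = torch.randn(batch_size, 5, 4)   # (batch, tokens, 4-momentum)

# Channel-first layout: (batch, channels, tokens).
features = features.permute(0, 2, 1)
vectors = vectors.permute(0, 2, 1)
masks = torch.ones(batch_size, 1, 5)      # every token is real, so no padding

# (group name, token slice, number of real features before zero-padding)
groups = [("bjet", slice(0, 2), 7), ("lep", slice(2, 4), 5), ("met", slice(4, 5), 2)]

model_inputs = []
for name, tok, n_feat in groups:
    model_inputs += [features[:, :n_feat, tok], vectors[:, :, tok], masks[:, :, tok]]

for t in model_inputs:
    print(tuple(t.shape))
```

Each group contributes a `(features, vectors, mask)` triple, and slicing the feature axis to `n_feat` drops the zero-padding that the Dataset added to the lepton and MET tokens.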


In [5]:
# make the model files under models/ (ParticleTransformer.py, ParticleNet.py) importable
import sys
sys.path.append("models/")

In [24]:
from ParticleTransformer import ParticleTransformerTaggerMultiGroups

# Define a model instance
# (a small configuration; arguments not listed here keep the module defaults)
input_dims = (7, 5, 2)
num_labels = 2

model_part = ParticleTransformerTaggerMultiGroups(
    input_dims,
    num_classes=num_labels, 
    embed_dims=[32, 32],
    pair_embed_dims=[8, 8],
    num_heads=4,
    num_layers=4,
)


import torch
import pytorch_lightning as pl
import torch.nn.functional as F
import torchmetrics

# Define the PyTorch Lightning module
class ParTLightningModel(pl.LightningModule):
    def __init__(self, model, start_lr=0.001):
        super().__init__()
        self.mod = model # model is an instance of ParticleNet, ParticleNeXt, or ParticleTransformer
        self.start_lr = start_lr

        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_labels)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_labels)

    def forward(self, *x):
        return self.mod(*x)

    def process_model_input(self, x):
        _, features, vectors = x
        features = features.permute(0, 2, 1) # permute the features to (batch_size, num_features, num_tokens)
        vectors = vectors.permute(0, 2, 1)
        masks = torch.ones((features.size(0), 1, features.size(2)), dtype=torch.float32, device=features.device)

        # form separate groups of tokens
        feat_bjet, v_bjet, m_bjet = features[:,:,:2], vectors[:,:,:2], masks[:,:,:2]
        feat_lep, v_lep, m_lep = features[:,:5,2:4], vectors[:,:,2:4], masks[:,:,2:4]
        feat_met, v_met, m_met = features[:,:2,4:], vectors[:,:,4:], masks[:,:,4:]
        return feat_bjet, v_bjet, m_bjet, feat_lep, v_lep, m_lep, feat_met, v_met, m_met

    def training_step(self, batch, batch_idx):
        x, y = batch # x: a tuple of (points, features, vectors), y: label
        logits = self(*self.process_model_input(x))
        loss = F.cross_entropy(logits, y)
        self.train_acc(logits, y)
        self.log('train_loss', loss)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True)
        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'])
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch # x: a tuple of (points, features, vectors), y: label
        logits = self(*self.process_model_input(x))
        loss = F.cross_entropy(logits, y)
        self.val_acc(logits, y)
        self.log('val_loss', loss)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.start_lr)
        
        def lr_lambda(epoch):
            if epoch < 0.7 * self.trainer.max_epochs:
                return 1.0
            else:
                decay_factor = (epoch - 0.7 * self.trainer.max_epochs) / (0.3 * self.trainer.max_epochs)
                return 0.01 ** decay_factor

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) # LR stays constant for the first 70% of epochs, then decays exponentially to 1%
        return [optimizer], [scheduler]

# Create the model instance
model_part_pl = ParTLightningModel(model_part)

# Print the model for inspection
print(model_part_pl)


ParTLightningModel(
  (mod): ParticleTransformerTaggerMultiGroups(
    (trimmers): ModuleList(
      (0-2): 3 x SequenceTrimmer()
    )
    (input_embeds): ModuleList(
      (0): Embed(
        (input_bn): BatchNorm1d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (embed): Sequential(
          (0): LayerNorm((7,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=7, out_features=32, bias=True)
          (2): GELU(approximate='none')
          (3): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (4): Linear(in_features=32, out_features=32, bias=True)
          (5): GELU(approximate='none')
        )
      )
      (1): Embed(
        (input_bn): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (embed): Sequential(
          (0): LayerNorm((5,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=5, out_features=32, bias=True)
          (2): GELU(approximate='none')
    

In [25]:
# Trainer with a logger
trainer = pl.Trainer(max_epochs=50, logger=pl.loggers.TensorBoardLogger('tb_logs', name='particle_transformer'))

# Fit the model
trainer.fit(model_part_pl, train_loader, val_loader)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type                                 | Params
-------------------------------------------------------------------
0 | mod       | ParticleTransformerTaggerMultiGroups | 82.7 K
1 | train_acc | MulticlassAccuracy                   | 0     
2 | val_acc   | MulticlassAccuracy                   | 0     
-------------------------------------------------------------------
82.7 K    Trainable params
0         Non-trainable params
82.7 K    Total params
0.331     Total estimated model params size (MB)


/home/olympus/licq/utils/miniconda3/envs/weaver/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=71` in the `DataLoader` to improve performance.
/home/olympus/licq/utils/miniconda3/envs/weaver/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=71` in the `DataLoader` to improve performance.


/home/olympus/licq/utils/miniconda3/envs/weaver/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


### Think about it

> How well does this standard ParT model perform once it is trained? Follow the training metrics in TensorBoard.

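Finally, we use the same interface to plug in the standard ParticleNet file, so that it is ready for tasks where you may need it later.

Compared with the ParT setup above, the differences are:

 - The ParticleNet model defined here takes different initialization arguments. As before, we use a very small ParticleNet, with only 20.1k trainable parameters.

 - For each group of tokens, the ParticleNet model also takes three inputs: `(points, features, mask)`. The `points` output by our new Dataset can be regarded as prepared for this purpose.

 - ParticleNet is trained with a relatively large initial learning rate. Below we set `start_lr=0.004`.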
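To make the learning-rate setting concrete, here is a minimal sketch of the schedule used in `configure_optimizers`, written as a standalone function so that it can be inspected without a `Trainer`. The helper name `lr_multiplier` is hypothetical; the formula is copied from `lr_lambda`. It assumes the scheduler is stepped once per epoch, which is how Lightning treats a scheduler returned this way by default.

```python
def lr_multiplier(epoch, max_epochs):
    """Multiplier applied to start_lr at a given epoch (same formula as lr_lambda)."""
    decay_start = 0.7 * max_epochs   # the LR stays flat before this epoch
    decay_span = 0.3 * max_epochs    # length of the exponential-decay phase
    if epoch < decay_start:
        return 1.0
    decay_factor = (epoch - decay_start) / decay_span
    return 0.01 ** decay_factor      # would reach 0.01 at epoch == max_epochs


start_lr = 0.004   # value used for ParticleNet in this notebook
max_epochs = 50    # value passed to pl.Trainer in this notebook

for epoch in (0, 34, 35, 42, 49):
    lr = start_lr * lr_multiplier(epoch, max_epochs)
    print(f"epoch {epoch:2d}: lr = {lr:.2e}")
```

Because epochs are counted from 0, the last training epoch (49) still sits slightly above 1% of `start_lr`; the multiplier would only reach exactly 0.01 at `epoch == max_epochs`.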

Look at the ParticleNet training results: how do they differ from ParT and from all the previous Transformer models?
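Before running the training, it can help to check the `(points, features, mask)` layout per token group with dummy tensors. The sketch below mirrors the slicing in `process_model_input`. The function name `split_token_groups`, the batch size, the choice of 2 point coordinates, and the token count of 5 (2 b-jets, 2 leptons, 1 MET) are assumptions made for illustration, not values taken from the Dataset.

```python
import torch


def split_token_groups(points, features):
    """Split (batch, tokens, channels) inputs into per-group (points, features, mask)."""
    points = points.permute(0, 2, 1)      # -> (batch, point_dims, tokens)
    features = features.permute(0, 2, 1)  # -> (batch, feature_dims, tokens)
    # All tokens are treated as valid, so the mask is all ones
    masks = torch.ones((features.size(0), 1, features.size(2)),
                       dtype=torch.float32, device=features.device)
    return {
        # b-jets: tokens 0-1, all 7 feature channels
        "bjet": (points[:, :, :2], features[:, :, :2], masks[:, :, :2]),
        # leptons: tokens 2-3, first 5 feature channels
        "lep": (points[:, :, 2:4], features[:, :5, 2:4], masks[:, :, 2:4]),
        # MET: remaining tokens, first 2 feature channels
        "met": (points[:, :, 4:], features[:, :2, 4:], masks[:, :, 4:]),
    }


batch_size, num_tokens = 8, 5                      # assumed sizes
points = torch.randn(batch_size, num_tokens, 2)    # assumed: 2 coordinates per token
features = torch.randn(batch_size, num_tokens, 7)  # 7 channels, matching input_dims[0]

groups = split_token_groups(points, features)
for name, (p, f, m) in groups.items():
    print(name, tuple(p.shape), tuple(f.shape), tuple(m.shape))
```

The feature dimension of each group (7, 5, 2) matches the `input_dims = (7, 5, 2)` passed to the model in the next cell.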


In [18]:
from ParticleNet import ParticleNetTaggerMultiGroups

# Define a model instance
# use the default parameters in the corresponding modules
input_dims = (7, 5, 2)
num_labels = 2

model_pnet = ParticleNetTaggerMultiGroups(
    input_dims,
    num_classes=num_labels,
    embed_dim=16,
    conv_params=[(4, (16, 16, 16)), (4, (32, 32, 32))], # (kernel_size, (conv1_channels, conv2_channels, conv3_channels)). should have kernel_size <= num_tokens - 1
    fc_params=[(64, 0.1)],
)


import torch
import pytorch_lightning as pl
import torch.nn.functional as F
import torchmetrics

# Define the PyTorch Lightning module
class ParticleNetLightningModel(pl.LightningModule):
    def __init__(self, model, start_lr=0.001):
        super().__init__()
        self.mod = model # model is an instance of ParticleNet, ParticleNeXt, or ParticleTransformer
        self.start_lr = start_lr

        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_labels)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_labels)

    def forward(self, *x):
        return self.mod(*x)

    def process_model_input(self, x):
        points, features, _ = x
        points = points.permute(0, 2, 1) # permute the features to (batch_size, num_features, num_tokens)
        features = features.permute(0, 2, 1)
        masks = torch.ones((features.size(0), 1, features.size(2)), dtype=torch.float32, device=features.device)

        # form separate groups of tokens
        point_bjet, feat_bjet, m_bjet = points[:,:,:2], features[:,:,:2], masks[:,:,:2]
        point_lep, feat_lep, m_lep = points[:,:,2:4], features[:,:5,2:4], masks[:,:,2:4]
        point_met, feat_met, m_met = points[:,:,4:], features[:,:2,4:], masks[:,:,4:]
        return point_bjet, feat_bjet, m_bjet, point_lep, feat_lep, m_lep, point_met, feat_met, m_met

    def training_step(self, batch, batch_idx):
        x, y = batch # x: a tuple of (points, features, vectors), y: label
        logits = self(*self.process_model_input(x))
        loss = F.cross_entropy(logits, y)
        self.train_acc(logits, y)
        self.log('train_loss', loss)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True)
        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'])
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch # x: a tuple of (points, features, vectors), y: label
        logits = self(*self.process_model_input(x))
        loss = F.cross_entropy(logits, y)
        self.val_acc(logits, y)
        self.log('val_loss', loss)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.start_lr)
        
        def lr_lambda(epoch):
            if epoch < 0.7 * self.trainer.max_epochs:
                return 1.0
            else:
                decay_factor = (epoch - 0.7 * self.trainer.max_epochs) / (0.3 * self.trainer.max_epochs)
                return 0.01 ** decay_factor

        # LR schedule: constant for the first 70% of epochs, then exponential decay toward 1% of the initial LR
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return [optimizer], [scheduler]

# Instantiate the Lightning model
model_pnet_pl = ParticleNetLightningModel(model_pnet, start_lr=0.004)

# Print the model as a sanity check
print(model_pnet_pl)


ParticleNetLightningModel(
  (mod): ParticleNetTaggerMultiGroups(
    (convs): ModuleList(
      (0): FeatureConv(
        (conv): Sequential(
          (0): BatchNorm1d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): Conv1d(7, 16, kernel_size=(1,), stride=(1,), bias=False)
          (2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU()
        )
      )
      (1): FeatureConv(
        (conv): Sequential(
          (0): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): Conv1d(5, 16, kernel_size=(1,), stride=(1,), bias=False)
          (2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU()
        )
      )
      (2): FeatureConv(
        (conv): Sequential(
          (0): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): Conv1d(2, 16, kernel_size=(1,), stride=(1,

In [19]:
# Trainer with a logger
trainer = pl.Trainer(max_epochs=50, logger=pl.loggers.TensorBoardLogger('tb_logs', name='particlenet'))

# Fit the model
trainer.fit(model_pnet_pl, train_loader, val_loader)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type                         | Params
-----------------------------------------------------------
0 | mod       | ParticleNetTaggerMultiGroups | 20.1 K
1 | train_acc | MulticlassAccuracy           | 0     
2 | val_acc   | MulticlassAccuracy           | 0     
-----------------------------------------------------------
20.1 K    Trainable params
0         Non-trainable params
20.1 K    Total params
0.081     Total estimated model params size (MB)


/home/olympus/licq/utils/miniconda3/envs/weaver/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=71` in the `DataLoader` to improve performance.
/home/olympus/licq/utils/miniconda3/envs/weaver/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=71` in the `DataLoader` to improve performance.


`Trainer.fit` stopped: `max_epochs=50` reached.
