# Setup #

In [None]:
import os

# Restrict CUDA to device index 0. This is written to the environment in the
# first cell, before torch is imported in the next one.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

In [None]:
import os
import matplotlib.pyplot as plt
import torch
from torch_geometric.data import DataLoader
from torch.utils.data import ConcatDataset
from pathlib import Path as Path
from dataset import *
from utils import *
from cpt import *
from functools import partial
from tqdm import tqdm
import wandb
# Rebind `tqdm` to a partial with position=0 and leave=True pre-filled, so every
# later tqdm(...) call made through this name gets those keyword arguments by default.
tqdm = partial(tqdm, position=0, leave=True)
import matplotlib.pylab as pylab
%matplotlib inline
%load_ext autoreload
%autoreload 2

USE_GPU = True  # set to False to stay on the CPU even when CUDA is available
dtype = torch.float32  # we will be using float
# CUDA is chosen only when it is both requested and actually available; otherwise CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')
print('using device:', device)

from seaborn import heatmap
import seaborn as sns
# Apply the "white" axes style globally for the figures drawn in this notebook.
# The former bare `sns.axes_style("white")` call was dropped: outside a `with`
# block it only returns the style parameters, so its sole effect was to echo a
# dict as this cell's output.
sns.set_style("white")

# Load datasets #

In [None]:
# --- Dataset configuration ---
dataset_names = ['data/final/ttH']
max_num_output = 2
detector_x = True
detector_y = True

# Per-sample transforms built by get_to_torch.
# NOTE(review): to_torch_test is created here but not used anywhere in this cell;
# every dataset below is read through `to_torch`. Confirm this is intended.
to_torch = get_to_torch(max_num_output, detector_x=detector_x, detector_y=detector_y)
to_torch_test = get_to_torch(max_num_output, detector_x=detector_x, detector_y=detector_y, test=True)

# One ConcatDataset per entry of dataset_names, stitched from shards data_0 .. data_9.
Ds = [
    ConcatDataset([
        LMDBDataset(f'{dataset_name}/data_{i}', transform=to_torch, use_cache=False, readahead=False)
        for i in range(10)
    ])
    for dataset_name in dataset_names
]
# Replace each dataset with the result of split_dataset; entries 0 / 1 / 2 of each
# result are used below as the train / validation / test parts.
Ds = [
    split_dataset(full_set, name=set_name, max_train_event=None, max_test_event=None)
    for full_set, set_name in zip(Ds, dataset_names)
]
D_train = ConcatDataset([parts[0] for parts in Ds])
D_val = ConcatDataset([parts[1] for parts in Ds])
D_test = ConcatDataset([parts[2] for parts in Ds])

batch_size = 256
train_loader = DataLoader(D_train, batch_size, num_workers=0, shuffle=True)
# A falsy (empty) validation / test set gets no loader at all.
val_loader = None if not D_val else DataLoader(D_val, batch_size, num_workers=0)
test_loader = None if not D_test else DataLoader(D_test, batch_size, num_workers=0)

# Train model #

In [None]:
max_num_output = 2 # assert that there are at most ${max_num_output} tops
base_dir = './trained'
arch = 'CovariantTopFormer'  # name of the model class to instantiate (looked up in the notebook namespace)
# Keyword arguments forwarded verbatim to the model constructor below; the same
# dict is also handed to wandb as the run config in the training cell.
model_params = {
    'geometric': True,
    'break_eta_covariance': False,
    'in_dim': 9,
    'out_dim': 4,
    'max_num_output': max_num_output,
    'hidden_dim': 256,
    'num_convs': (6, 6),
    'heads': 4,
    'mass': 173,
    'match_scale_factor': torch.FloatTensor([0, 1, 1, 0]), # used for matching dR = (dy, dphi)
    'p_norm': 2, # used in matching and loss
    'beta': 0.8, # loss weight for predicting number of tops
    'dropout': 0.,
    'schedule_lr': False,
    'use_gpu': USE_GPU,
    'uniform_attention': False,
}
tag = 'TEST' # CHANGE THIS
# Output directory: <base_dir>/<arch>_<num_convs>_<hidden_dim>_<tag>
output_dir = f"{arch}_{model_params['num_convs']}_{model_params['hidden_dim']}_{tag}"
output_dir = os.path.join(base_dir, output_dir)
model_params['output_dir'] = output_dir
# exist_ok=True makes this safe to re-run and avoids the check-then-create race
# of a separate os.path.exists() test.
os.makedirs(output_dir, exist_ok=True)

# Resolve the class by name with a plain namespace lookup instead of eval():
# same result for a simple identifier, without executing an arbitrary expression.
model = globals()[arch](**model_params).to(device)

In [None]:
# To resume from the most recent checkpoint, uncomment the line below.
# model.load('most_recent_epoch_model.pt')
# The training call runs inside a wandb run (project 'CPT', named after `tag`,
# with the model hyper-parameters as its config) opened as a context manager.
with wandb.init(
    project='CPT',
    name=tag,
    config=model_params,
    resume=False,
):
    # NOTE(review): 100 is presumably the number of epochs and the trailing None
    # a test loader; confirm against train_model's signature.
    model.train_model(100, train_loader, val_loader, None)

# Run inference #

In [None]:
# Note: this cell rebinds `batch_size` and `test_loader` set in the dataset cell.
batch_size = 1024
max_num_batch = 100 # set a large value (1e7) to run on all test data
version_name = 'FIXME' # choose a name for saving result to disk

# Build a dedicated test loader for inference; a falsy (empty) test set gets None.
if D_test:
    test_loader = DataLoader(
        D_test,
        batch_size,
        num_workers=4,
        prefetch_factor=1,
        follow_batch=['x_in'],
    )
else:
    test_loader = None

test_result = model.run_inference(
    test_loader,
    max_num_batch=max_num_batch,
    force_correct_num_pred=True,
    version=version_name,
)

# Make table and plots #

In [None]:
from make_stats import run as make_plots_and_tables

### DON'T CHANGE ###
bins = 2 # bins - 1 = actual number of bins...
entries_per_bin = 100_000_000 # for binning the result in phase space. set a large number to skip this.
### DON'T CHANGE ###

# Reload a stored result from disk by name.
# NOTE(review): presumably this is the file written by run_inference(version=...)
# in the previous section, so version_name should match the one used there; confirm.
version_name = 'FIXME'
test_result = torch.load(f'{output_dir}/test_result_{version_name}.pt')
make_plots_and_tables(
    test_result,
    f'{output_dir}/test_result_{version_name}',
    max_num_output,
    bins,
    entries_per_bin,
)
# Once the call above has run, a summary table named result.csv can be found inside the output directory (2nd arg).