In [None]:
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------

import os, random, copy, yaml, pickle
from time import time, sleep
from tqdm import tqdm
from math import floor

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pyarrow.parquet as pq

# torch imports
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import random_split, IterableDataset

import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.utils import homophily
from torch_geometric.loader import DataLoader
from torch_geometric.nn import EdgeConv, knn_graph
from torch_scatter import \
    scatter_add, scatter_mean, scatter_max, scatter_min

import pytorch_lightning as pl

In [None]:
# -----------------------------------------------------------------------------
# Helper function
# -----------------------------------------------------------------------------

import logging
import logging.config
import os, re, gc, psutil

def get_logger(name, msg):
    """
    Configure root logging and return a named logger.

    ``logging.basicConfig`` is a no-op when the root logger already has
    handlers, so only the first call's level/format take effect.

    :param name: string, name of the logger to return
    :param msg: level name, one of DEBUG, INFO, WARNING, ERROR
                (any other value raises KeyError)
    :return: Logger() instance
    """
    levels_by_name = dict(
        DEBUG=logging.DEBUG,
        INFO=logging.INFO,
        WARNING=logging.WARNING,
        ERROR=logging.ERROR,
    )
    log_format = "== %(name)s == %(asctime)s %(levelname)s:\t%(message)s"
    logging.basicConfig(
        level=levels_by_name[msg],
        format=log_format,
        datefmt="%H:%M:%S",
    )
    return logging.getLogger(name)


def walk_dir(dirname, batch_ids):
    """
    Map batch ids to parquet file paths found (recursively) under ``dirname``.

    File names are expected to end in ``_<id>.parquet`` (e.g. ``batch_651.parquet``).
    Files that do not follow that pattern are skipped instead of raising
    IndexError.

    :param dirname: directory searched with os.walk
    :param batch_ids: iterable of int ids to select, or None to collect every id found
    :return: (files, batch_ids) where ``files`` maps int id -> path.
             When ``batch_ids`` is None the returned list holds every id found
             (in os.walk order). Otherwise the argument is returned unchanged,
             even if some ids had no matching file.
    """
    # anchored so e.g. "batch_1.parquet.crc" cannot shadow "batch_1.parquet"
    pattern = re.compile(r"_(\d+)\.parquet$")
    wanted = None if batch_ids is None else set(batch_ids)

    files = dict()
    found_ids = list()
    for base, _, names in os.walk(dirname):
        for name in names:
            match = pattern.search(name)
            if match is None:
                continue  # README, sidecar or hidden files, ...
            batch_id = int(match.group(1))
            if wanted is not None and batch_id not in wanted:
                continue
            found_ids.append(batch_id)
            files[batch_id] = os.path.join(base, name)

    if batch_ids is None:
        return files, found_ids
    return files, batch_ids


def memory_check(logger, msg=""):
    """
    Run a garbage-collection pass, then log system RAM used / total at DEBUG level.

    :param logger: logger that receives the message
    :param msg: optional text appended to the end of the log line
    """
    gc.collect()
    gib = 1024 ** 3
    used_gb = psutil.virtual_memory().used / gib
    total_gb = psutil.virtual_memory().total / gib
    logger.debug(f"memory usage {used_gb:.2f} of {total_gb:.2f} GB {msg}")


In [None]:
# -----------------------------------------------------------------------------
# Basic settings
# -----------------------------------------------------------------------------

# basic
BATCH_SIZE = 200
BATCHES_TEST = [651] #####################
# BATCHES_TEST = list(range(651, 656)) #####################
EVENTS_PER_FILE = 200_000

# paths
# BASE_PATH keeps its original default but can be overridden with the
# ICECUBE_BASE_PATH environment variable, so the notebook runs on other machines.
BASE_PATH = os.environ.get("ICECUBE_BASE_PATH", "/root/autodl-tmp/kaggle/")
PATH = os.path.join(BASE_PATH, "icecube-neutrinos-in-deep-ice")
MODEL_PATH = os.path.join(BASE_PATH, "input", "ice-cube-model")
OUTPUT_PATH = os.path.join(BASE_PATH, "working")
TEST_PATH = os.path.join(PATH, "test")
META_PATH = os.path.join(OUTPUT_PATH, "test_meta")

# files
FILES_TEST, BATCHES_TEST = walk_dir(TEST_PATH, BATCHES_TEST)
FILE_META = os.path.join(PATH, "test_meta.parquet")
FILE_SENSOR_GEO = os.path.join(PATH, "sensor_geometry.csv")
FILE_GNNPre = os.path.join(MODEL_PATH, "official-pretrained.pth")
FILE_GNN = os.path.join(MODEL_PATH, "finetuned.ckpt")
FILE_BDT = os.path.join(MODEL_PATH, "BDT_clf.Baseline.0414.sklearn")

In [None]:
# -----------------------------------------------------------------------------
# Split meta file
# -----------------------------------------------------------------------------

# if not os.path.exists(META_PATH): 
#     os.mkdir(META_PATH)

# meta_test = pd.read_parquet(FILE_META)

# for i, df in meta_test.groupby("batch_id"):
#     print(f"processing {i} -> {df.shape}")
#     splitted = os.path.join(META_PATH, f"meta_{i}.parquet")
#     df.to_parquet(splitted)

In [None]:
# -----------------------------------------------------------------------------
# Some Logging
# -----------------------------------------------------------------------------
LOGGER = get_logger("IceCube", "DEBUG")
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# environment summary
LOGGER.info(f"using {DEVICE}")
LOGGER.info(f"{len(FILES_TEST)} files for testing")
memory_check(LOGGER)

# model files used by this notebook
LOGGER.info(f"GNN model:{FILE_GNN}")
LOGGER.info(f"BDT model:{FILE_BDT}")

In [None]:
# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------

# sensor geometry
def prepare_sensors(scale=None, path=None):
    """
    Load the detector sensor geometry table.

    :param scale: optional factor applied to the x, y and z columns. Any real
                  number (int or float, but not bool) is accepted; previously an
                  int factor was silently ignored.
    :param path: CSV file to read; defaults to FILE_SENSOR_GEO
    :return: DataFrame with sensor_id cast to int16 and x, y, z cast to float32
    """
    import numbers

    if path is None:
        path = FILE_SENSOR_GEO
    sensors = pd.read_csv(path).astype({
        "sensor_id": np.int16,
        "x": np.float32,
        "y": np.float32,
        "z": np.float32,
    })

    # bool is a subclass of int; do not treat True/False as a scale factor
    if isinstance(scale, numbers.Real) and not isinstance(scale, bool):
        for axis in ("x", "y", "z"):
            sensors[axis] *= scale

    return sensors


def angle_to_xyz(angles_b):
    """
    Convert a batch of (azimuth, zenith) rows into Cartesian direction vectors.

    :param angles_b: 2-column tensor, column 0 = azimuth, column 1 = zenith (radians)
    :return: tensor with columns (x, y, z) = (cos az sin zen, sin az sin zen, cos zen)
    """
    azimuth, zenith = angles_b[:, 0], angles_b[:, 1]
    sin_zen = torch.sin(zenith)
    components = (
        torch.cos(azimuth) * sin_zen,
        torch.sin(azimuth) * sin_zen,
        torch.cos(zenith),
    )
    return torch.stack(components, dim=1)


def xyz_to_angle(xyz_b):
    """
    Convert a batch of Cartesian direction vectors to (azimuth, zenith).

    Azimuth comes from atan2(y, x) and lies in (-pi, pi]. The old
    ``arccos(x / sqrt(x^2 + y^2)) * sign(y)`` returned 0 instead of pi for
    y == 0, x < 0, and NaN when x == y == 0. Zenith is arccos(z / |v|), clamped
    to a valid arccos domain so rounding cannot produce NaN; a zero-length
    vector still yields NaN zenith.

    :param xyz_b: 3-column tensor of (x, y, z) rows
    :return: tensor with columns (azimuth, zenith) in radians
    """
    x, y, z = xyz_b.t()
    az = torch.atan2(y, x)
    norm = torch.sqrt(x**2 + y**2 + z**2)
    zen = torch.arccos(torch.clamp(z / norm, -1.0, 1.0))
    return torch.stack([az, zen], dim=1)


def angular_error(xyz_pred_b, xyz_true_b):
    """
    Row-wise angle (radians) between two batches of direction vectors.

    Inputs are expected to be unit vectors (TODO confirm at call sites). The dot
    product is clamped to [-1, 1] before arccos so rounding noise cannot give NaN.
    """
    cos_angle = (xyz_pred_b * xyz_true_b).sum(dim=1)
    return torch.arccos(cos_angle.clamp(-1, 1))


def angles2vector(df):
    """
    Add unit-direction columns nx, ny, nz derived from df.azimuth and df.zenith.

    The frame is modified in place and also returned for chaining.
    """
    sin_zen = np.sin(df.zenith)
    df["nx"] = sin_zen * np.cos(df.azimuth)
    df["ny"] = sin_zen * np.sin(df.azimuth)
    df["nz"] = np.cos(df.zenith)
    return df


def vector2angles(n, eps=1e-8):
    """
    Convert rows of 3-d direction vectors into [azimuth, zenith] pairs (radians).

    Each row is normalised first (``eps`` guards zero-length rows). Negative
    atan2 azimuths are shifted up by 2*pi. Zenith is arccos of the z component
    clipped to [-1, 1].

    :param n: array with 3 columns (x, y, z)
    :return: array with 2 columns (azimuth, zenith)
    """
    unit = n / (np.linalg.norm(n, axis=1, keepdims=True) + eps)
    azimuth = np.arctan2(unit[:, 1], unit[:, 0])
    azimuth = np.where(azimuth < 0, azimuth + 2 * np.pi, azimuth)
    zenith = np.arccos(np.clip(unit[:, 2], -1, 1))
    return np.stack([azimuth, zenith], axis=1)


def series2tensor(series, set_device=None):
    """
    Convert a pandas Series/DataFrame (anything with ``.values``) to a float32 tensor.

    :param series: pandas object whose ``.values`` is a numeric/bool ndarray
    :param set_device: target device.
        None or False: keep the tensor on the CPU.
        True: move it to the global DEVICE.
        Any other value (e.g. "cpu", torch.device("cuda:1")): use it as the
        target device. Previously every non-None value was ignored in favour
        of DEVICE.
    :return: float32 torch tensor
    """
    ret = torch.from_numpy(series.values).float()
    if set_device is None or set_device is False:
        return ret
    if set_device is True:
        return ret.to(DEVICE)
    return ret.to(set_device)



def solve_linear(xw, yw, zw, xxw, yyw, xyw, yzw, zxw):
    """
    Solve the 3x3 system used for the plane estimate z ~ a*x + b*y + c.

        [[xxw, xyw, xw],        [[zxw],
         [xyw, yyw, yw],  @ c =  [yzw],
         [xw,  yw,  1 ]]         [zw ]]

    NOTE(review): this resembles the least-squares normal equations, but the
    bottom-right entry is 1 rather than the number of points. Confirm whether
    that is intended.

    :return: coefficient tensor [a, b, c]. If torch.linalg.solve raises (e.g. a
             singular matrix), a debug message is logged and zeros((3,)) is returned.
    """
    lhs = torch.tensor([
        [xxw, xyw, xw],
        [xyw, yyw, yw],
        [xw, yw, 1],
    ])
    rhs = torch.tensor([zxw, yzw, zw])
    try:
        coeff = torch.linalg.solve(lhs, rhs)
    except Exception:
        LOGGER.debug("linear system not solvable")
        return torch.zeros((3, ))
    return coeff


def feature_extraction(df, fun=None, eps=1e-8):
    """
    Build the hand-crafted per-event feature vector consumed by the BDT.

    :param df: pulses with columns time, charge, x, y, z (as returned by
               prepare_feature). The caller's frame is no longer sorted in
               place; a time-sorted copy is used internally.
    :param fun: unused, kept for interface compatibility
    :param eps: unused, kept for interface compatibility
    :return: float tensor of shape (1, 33) laid out as
             norm_vec(3) | dt(3) | qv(3) | xyzt(16) | unique(3) |
             [hits, error, sumq, meanq, bratio].
             An event with no pulses returns all zeros instead of raising.
    """
    n_groups = 4                                                                                # xyzt -> (n_groups * 4, )
    n_features = 3 + 3 + 3 + n_groups * 4 + 3 + 5

    # prepare_feature drops auxiliary pulses, so df can be empty; torch.quantile
    # and the `/ hits` divisions below would raise on it
    if len(df) == 0:
        return torch.zeros((1, n_features))

    # sort by time (on a copy, so the caller's frame is untouched)
    df = df.sort_values(["time"])

    t = series2tensor(df.time)
    c = series2tensor(df.charge)
    x = series2tensor(df.x)
    y = series2tensor(df.y)
    z = series2tensor(df.z)

    hits = t.numel()                                                                            # hits

    # unweighted moment sums for the plane estimate
    Sx = torch.sum(x); Sxx = torch.sum(x*x); Sxy = torch.sum(x*y)
    Sy = torch.sum(y); Syy = torch.sum(y*y); Syz = torch.sum(y*z)
    Sz = torch.sum(z); Szx = torch.sum(z*x)

    # error of plane estimate
    coeff = solve_linear(Sx, Sy, Sz, Sxx, Syy, Sxy, Syz, Szx)
    # NOTE(review): this squares the *sum* of residuals, not the sum of squared
    # residuals. Left unchanged because the BDT consumes this exact "error" column.
    error = torch.sum((z - coeff[0] * x - coeff[1] * y - coeff[2]))
    error = torch.square(error * 1e3)                                                           # error

    # unit normal of the plane z = a*x + b*y + c
    norm_vec = torch.tensor([coeff[0], coeff[1], -1], dtype=torch.float)
    norm_vec /= torch.sqrt(coeff[0]**2 + coeff[1]**2 + 1)                                       # norm_vec -> (3, )

    # 15 / 50 / 85 % time quantiles
    dt = torch.quantile(t, torch.tensor([0.15, 0.50, 0.85], dtype=torch.float))                 # dt -> (3, )

    # total / mean charge and charge-weighted centre
    sumq = torch.sum(c)                                                                         # sumq
    meanq = sumq / hits                                                                         # meanq
    qv = torch.tensor([torch.sum(x*c), torch.sum(y*c), torch.sum(z*c)], dtype=torch.float)
    qv /= sumq                                                                                  # qv -> (3, )

    # fraction of pulses brighter than 5x the mean charge
    bratio = c[c > 5 * meanq].numel() / hits                                                    # bratio

    # mean (x, y, z, t) of n_groups consecutive time slices (relies on the sort above)
    if hits > n_groups:
        sec_len = floor(hits / n_groups)
        remain_len = hits - (n_groups - 1) * sec_len
        xyzt = series2tensor(df[["x", "y", "z", "time"]])
        xyzt = torch.split(xyzt, [sec_len, sec_len, sec_len, remain_len])
        xyzt = torch.concat([xx.mean(axis=0) for xx in xyzt])
    else:
        # too few pulses to group: lay out each pulse's (x, y, z, t), zero padded
        xyzt = torch.zeros(n_groups * 4)
        _xxxx = list()
        for i in range(hits):
            _xxxx.append(x[i]); _xxxx.append(y[i]); _xxxx.append(z[i]); _xxxx.append(t[i])
        xyzt[: hits * 4] = torch.tensor(_xxxx, dtype=torch.float)

    # number of distinct x / y / z coordinates
    unique = torch.tensor([_x.unique().numel() for _x in [x, y, z]], dtype=torch.float)         # unique -> (3, )

    # global features
    glob_feat = torch.tensor([hits, error, sumq, meanq, bratio, ], dtype=torch.float)

    return torch.concat([norm_vec, dt, qv, xyzt, unique, glob_feat]).unsqueeze(0)


def prepare_feature(df):
    """
    Select and rescale the pulses used by feature_extraction.

    Auxiliary pulses are dropped, x/y/z are multiplied by 1e-3 and time is
    shifted so the earliest remaining pulse is at 0. Works on an explicit copy,
    so the caller's frame is never modified and no SettingWithCopyWarning is
    raised.

    :param df: pulses with columns time, charge, x, y, z and boolean auxiliary
    :return: new DataFrame with columns [time, charge, x, y, z]
    """
    pulses = df.reset_index(drop=True)
    # remove auxiliary
    pulses = pulses.loc[~pulses["auxiliary"]].copy()
    for axis in ("x", "y", "z"):
        pulses[axis] = pulses[axis] * 1e-3
    pulses["time"] = pulses["time"] - pulses["time"].min()
    return pulses[["time", "charge", "x", "y", "z"]]


# Dataset
class IceCube(IterableDataset):
    """
    Iterable dataset streaming events as torch_geometric mini-batches.

    For every chunk id it reads ``batch_<id>.parquet`` (pulses, looked up per
    event with ``.loc[event_id]``) from ``parquet_dir`` and ``meta_<id>.parquet``
    (which provides the ``event_id`` list) from ``meta_dir``. Each yielded item is
    one ``Batch`` of up to ``batch_size`` events. Every event's ``x`` is a float
    tensor with columns [x, y, z, time, charge, auxiliary], sorted by time.

    :param parquet_dir: directory of pulse files
    :param meta_dir: directory of per-chunk meta files
    :param chunk_ids: chunk ids to read, split round-robin across workers / ranks
    :param batch_size: events per yielded Batch
    :param max_pulses: events with more pulses are down-sampled to this many,
        preferring non-auxiliary pulses
    :param extra: also attach ``extra_feat`` computed by feature_extraction
    :param seed: optional int making the down-sampling reproducible; None keeps
        the previous behaviour (pandas draws from the global NumPy RNG)
    """

    def __init__(
        self, parquet_dir, meta_dir, chunk_ids,
        batch_size=200, max_pulses=200, extra=False, seed=None
    ):
        self.parquet_dir = parquet_dir
        self.meta_dir = meta_dir
        self.chunk_ids = chunk_ids
        self.batch_size = batch_size
        self.max_pulses = max_pulses
        self.extra = extra
        self.seed = seed

    def __iter__(self):
        # Handle num_workers > 1 and multi-gpu
        dist = torch.distributed
        is_dist = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if is_dist else 1
        rank_id = dist.get_rank() if is_dist else 0

        info = torch.utils.data.get_worker_info()
        num_worker = info.num_workers if info else 1
        worker_id = info.id if info else 0

        num_replica = world_size * num_worker
        offset = rank_id * num_worker + worker_id
        chunk_ids = self.chunk_ids[offset::num_replica]

        # per-replica RNG so each worker samples differently but reproducibly
        rng = None if self.seed is None else np.random.RandomState(self.seed + offset)

        # Sensor data
        sensor = prepare_sensors()

        # Read each chunk and meta iteratively into memory and build mini-batch
        for chunk_id in chunk_ids:
            data = pd.read_parquet(os.path.join(self.parquet_dir, f"batch_{chunk_id}.parquet"))
            meta = pd.read_parquet(os.path.join(self.meta_dir, f"meta_{chunk_id}.parquet"))

            eids = meta["event_id"].tolist()
            for start in range(0, len(eids), self.batch_size):
                batch_eids = eids[start : start + self.batch_size]
                batch = [self._event_to_data(data, sensor, eid, rng) for eid in batch_eids]
                yield Batch.from_data_list(batch)

            del data
            del meta
            gc.collect()

    def _event_to_data(self, data, sensor, eid, rng):
        """Build the graph Data object for event ``eid`` (pulses joined with sensor xyz)."""
        # double brackets always give a DataFrame; data.loc[eid] returns a Series
        # for single-pulse events, which breaks the merge
        df = pd.merge(data.loc[[eid]], sensor, on="sensor_id")

        # sampling of pulses if number exceeds maximum
        if len(df) > self.max_pulses:
            df_pass = df[~df.auxiliary]
            df_fail = df[df.auxiliary]
            if len(df_pass) >= self.max_pulses:
                df = df_pass.sample(self.max_pulses, random_state=rng)
            else:
                df_fail = df_fail.sample(self.max_pulses - len(df_pass), random_state=rng)
                df = pd.concat([df_fail, df_pass])

        df = df.sort_values(["time"])

        t = series2tensor(df.time)
        c = series2tensor(df.charge)
        a = series2tensor(df.auxiliary)
        x = series2tensor(df.x)
        y = series2tensor(df.y)
        z = series2tensor(df.z)

        feat = torch.stack([x, y, z, t, c, a], dim=1)
        batch_data = Data(x=feat, n_pulses=len(feat), eid=torch.tensor([eid]).long())

        if self.extra:
            feats = feature_extraction(prepare_feature(df))
            setattr(batch_data, "extra_feat", feats)

        return batch_data

In [None]:
# -----------------------------------------------------------------------------
# GraphNet GNN model
# -----------------------------------------------------------------------------

class MLP(nn.Sequential):
    """
    Chain of Linear + LeakyReLU pairs.

    ``feats=[a, b, c]`` builds Linear(a, b), LeakyReLU, Linear(b, c), LeakyReLU;
    a list with fewer than two sizes gives an empty Sequential.
    """

    def __init__(self, feats):
        modules = []
        for n_in, n_out in zip(feats[:-1], feats[1:]):
            modules += [nn.Linear(n_in, n_out), nn.LeakyReLU()]
        super().__init__(*modules)


class Model(pl.LightningModule):
    """
    EdgeConv graph network (GraphNet-style) predicting an event direction.

    Vertices are pulses with 6 input features [x, y, z, t, charge, auxiliary]
    plus 11 per-event global features repeated onto every vertex (17 in total).
    ``forward`` returns a (B, 4) tensor [x, y, z, kappa]: the prediction scaled
    to unit length, and kappa, the norm of the raw 3-d output.

    :param max_lr: peak Adam learning rate
    :param num_warmup_step: steps of linear warm-up from 1% of max_lr to max_lr
    :param remaining_step: steps of linear decay from max_lr to 0.1% of max_lr
    """

    def __init__(
        self, max_lr=1e-3, 
        num_warmup_step=1_000,
        remaining_step=1_000,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.conv0 = EdgeConv(MLP([17 * 2, 128, 256]), aggr="add")
        self.conv1 = EdgeConv(MLP([512, 336, 256]), aggr="add")
        self.conv2 = EdgeConv(MLP([512, 336, 256]), aggr="add")
        self.conv3 = EdgeConv(MLP([512, 336, 256]), aggr="add")
        self.post = MLP([1024 + 17, 336, 256])
        self.readout = MLP([768, 128])
        self.pred = nn.Linear(128, 3)

    def forward(self, data: Batch):
        # clone: the normalisation below is done in place and must not leak into
        # data.x (a second forward on the same batch would re-normalise it)
        vert_feat = data.x.clone()
        batch = data.batch

        # x, y, z, t, c, a
        # 0  1  2  3  4  5
        vert_feat[:, 0] /= 500.0  # x
        vert_feat[:, 1] /= 500.0  # y
        vert_feat[:, 2] /= 500.0  # z
        vert_feat[:, 3] = (vert_feat[:, 3] - 1.0e04) / 3.0e4  # time
        vert_feat[:, 4] = torch.log10(vert_feat[:, 4]) / 3.0  # charge

        edge_index = knn_graph(vert_feat[:, :3], 8, batch)

        # Construct global features
        hx = homophily(edge_index, vert_feat[:, 0], batch).reshape(-1, 1)
        hy = homophily(edge_index, vert_feat[:, 1], batch).reshape(-1, 1)
        hz = homophily(edge_index, vert_feat[:, 2], batch).reshape(-1, 1)
        ht = homophily(edge_index, vert_feat[:, 3], batch).reshape(-1, 1)
        means = scatter_mean(vert_feat, batch, dim=0)
        n_p = torch.log10(data.n_pulses).reshape(-1, 1)
        global_feats = torch.cat([means, hx, hy, hz, ht, n_p], dim=1)  # [B, 11]

        # Distribute global_feats to each vertex
        _, cnts = torch.unique_consecutive(batch, return_counts=True)
        global_feats = torch.repeat_interleave(global_feats, cnts, dim=0)
        vert_feat = torch.cat((vert_feat, global_feats), dim=1)

        # Convolutions (the knn graph is rebuilt on the first 3 learned features)
        feats = [vert_feat]
        # Conv 0
        vert_feat = self.conv0(vert_feat, edge_index)
        feats.append(vert_feat)
        # Conv 1
        edge_index = knn_graph(vert_feat[:, :3], k=8, batch=batch)
        vert_feat = self.conv1(vert_feat, edge_index)
        feats.append(vert_feat)
        # Conv 2
        edge_index = knn_graph(vert_feat[:, :3], k=8, batch=batch)
        vert_feat = self.conv2(vert_feat, edge_index)
        feats.append(vert_feat)
        # Conv 3
        edge_index = knn_graph(vert_feat[:, :3], k=8, batch=batch)
        vert_feat = self.conv3(vert_feat, edge_index)
        feats.append(vert_feat)

        # Postprocessing
        post_inp = torch.cat(feats, dim=1)
        post_out = self.post(post_inp)

        # Readout: per-event min / max / mean pooling
        readout_inp = torch.cat(
            [
                scatter_min(post_out, batch, dim=0)[0],
                scatter_max(post_out, batch, dim=0)[0],
                scatter_mean(post_out, batch, dim=0),
            ],
            dim=1,
        )
        readout_out = self.readout(readout_inp)

        # Predict
        pred = self.pred(readout_out)
        kappa = pred.norm(dim=1, p=2) + 1e-8
        pred_x = pred[:, 0] / kappa
        pred_y = pred[:, 1] / kappa
        pred_z = pred[:, 2] / kappa
        pred = torch.stack([pred_x, pred_y, pred_z, kappa], dim=1)

        return pred

    def train_or_valid_step(self, data, prefix):
        # NOTE(review): VonMisesFisher3DLoss is neither defined nor imported in this
        # notebook (presumably graphnet's loss), and IceCube does not attach
        # `data.gt`. Training from here would fail; TODO wire both up before training.
        pred_xyzk = self.forward(data)  # [B, 4]
        true_xyz = data.gt.view(-1, 3)  # [B, 3]
        loss = VonMisesFisher3DLoss()(pred_xyzk, true_xyz).mean()
        error = angular_error(pred_xyzk[:, :3], true_xyz).mean()
        self.log(f"loss-{prefix}", loss, batch_size=len(true_xyz), 
            on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log(f"error-{prefix}", error, batch_size=len(true_xyz), 
            on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def training_step(self, data, _):
        return self.train_or_valid_step(data, "train")

    def validation_step(self, data, _):
        return self.train_or_valid_step(data, "valid")
    
    def configure_optimizers(self):
        # linear warm-up (1% -> 100% of max_lr), then linear decay (100% -> 0.1%)
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.max_lr)
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[
                torch.optim.lr_scheduler.LinearLR(
                    optimizer, 1e-2, 1, self.hparams.num_warmup_step
                ),
                torch.optim.lr_scheduler.LinearLR(
                    optimizer, 1, 1e-3, self.hparams.remaining_step
                ),
            ],
            milestones=[self.hparams.num_warmup_step],
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }


In [None]:
# -----------------------------------------------------------------------------
# GNN prediction
# -----------------------------------------------------------------------------

def predict_gnn(model):
    """
    Run the GNN over the test chunks and collect predictions plus extra features.

    :param model: network called as ``model(batch)`` returning [x, y, z, kappa] rows
    :return: DataFrame with columns x, y, z, kappa, azimuth, zenith, the hand-crafted
        features (ex .. bratio) and event_id. azimuth is wrapped into [0, 2*pi).
        The frame is empty (but correctly shaped) when no events are read.
    """
    col_xyzk = ["x", "y", "z", "kappa"]
    col_angles = ["azimuth", "zenith"]
    col_norm_vec = ["ex", "ey", "ez"]
    col_dt = ["dt_15", "dt_50", "dt_85"]
    col_qv = ["qx", "qy", "qz"]
    col_xyzt = [
        "x0", "y0", "z0", "t0",
        "x1", "y1", "z1", "t1",
        "x2", "y2", "z2", "t2",
        "x3", "y3", "z3", "t3", ]
    col_unique = ["uniq_x", "uniq_y", "uniq_z"]
    col_glob_feat = ["hits", "error", "sumq", "meanq", "bratio"]
    col_extra = col_norm_vec + col_dt + col_qv + \
        col_xyzt + col_unique + col_glob_feat
    columns = col_xyzk + col_angles + col_extra

    test_set = IceCube(TEST_PATH, META_PATH, BATCHES_TEST, batch_size=400, extra=True)
    test_loader = DataLoader(test_set, batch_size=1)

    # collect per-batch arrays and concatenate once at the end; concatenating
    # inside the loop re-copies everything seen so far (quadratic)
    pred_chunks, eid_chunks = [], []
    with torch.no_grad():
        for data in tqdm(test_loader):
            data = data.to(DEVICE)
            pred_xyzk = model(data)
            pred_chunks.append(np.concatenate([
                # +-------------+------------------------------------+----------------------+
                # | x, y, z, kp |           azimuth, zenith          |    extra features    |
                # +-------------+------------------------------------+----------------------+
                pred_xyzk.cpu().numpy(),
                xyz_to_angle(pred_xyzk[:, :3]).cpu().numpy(),
                data.extra_feat.cpu().numpy(),
                # +-------------+------------------------------------+----------------------+
            ], axis=1))
            eid_chunks.append(data.eid.cpu().numpy())

    if pred_chunks:
        pred = np.concatenate(pred_chunks)
        eid = np.concatenate(eid_chunks)
    else:
        pred = np.empty((0, len(columns)), dtype=np.float32)
        eid = np.empty((0,), dtype=np.int64)

    res = pd.DataFrame(pred, columns=columns)

    res["azimuth"] = np.remainder(res["azimuth"], 2 * np.pi)
    res["zenith"] = np.remainder(res["zenith"], 2 * np.pi)
    res["event_id"] = eid

    return res


# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
model = Model.load_from_checkpoint(FILE_GNN, map_location=DEVICE)
LOGGER.info(f"loaded {FILE_GNN}")

model.eval()
model.freeze()
model.to(DEVICE)

reco_df = predict_gnn(model)

In [None]:
# -----------------------------------------------------------------------------
# GNN + Plane projection prediction
# -----------------------------------------------------------------------------

# Project the GNN direction n_hat onto the plane with unit normal e = (ex, ey, ez)
# (the plane normal from feature_extraction), then re-normalise the result.
n_hat = reco_df[["x", "y", "z"]].to_numpy()
e = reco_df[["ex", "ey", "ez"]].to_numpy()
xe = (n_hat * e).sum(axis=1, keepdims=True)  # component of n_hat along e

proj = n_hat - xe * e
proj = proj / (np.linalg.norm(proj, axis=1, keepdims=True) + 1e-8)

In [None]:
# -----------------------------------------------------------------------------
# GNN + Plane projection + BDT prediction
# -----------------------------------------------------------------------------

# reco_df inputs
# reco_df inputs -- log-scaled on a copy so reco_df keeps the raw values and
# re-running this cell does not apply log10 a second time
bdt_df = reco_df.copy()
bdt_df["error"] = np.log10(bdt_df["error"] + 1e-6)
bdt_df["sumq"] = np.log10(bdt_df["sumq"] + 1e-3)
bdt_df["dt_50"] = np.log10(bdt_df["dt_50"] + 1e-3)
bdt_df["dt_85"] = np.log10(bdt_df["dt_85"] + 1e-3)
bdt_df["kappa"] = np.log10(bdt_df["kappa"] + 1e-3)
columns = ["kappa", "zenith", "error", "sumq", "qz", "dt_50", "dt_85", "ez"]
reco = bdt_df[columns].to_numpy()
# angle between the GNN direction and the plane normal; a new name keeps the
# projection cell's `xe` intact (re-running would otherwise take arccos twice)
xe_angle = np.arccos(xe)

# trajectory display: mean (x, y, z, t) of 4 consecutive time slices per event
col_xyzt = [
    "x0", "y0", "z0", "t0",
    "x1", "y1", "z1", "t1",
    "x2", "y2", "z2", "t2",
    "x3", "y3", "z3", "t3", ]
traj = reco_df[col_xyzt].values
traj = traj.reshape(-1, 4, 4)

# apparent velocities between consecutive slices (+1 avoids division by zero)
v1 = 1e3 * (traj[:, 1, :3] - traj[:, 0, :3]) / (traj[:, 1, 3] - traj[:, 0, 3] + 1)[:, np.newaxis]
v2 = 1e3 * (traj[:, 2, :3] - traj[:, 1, :3]) / (traj[:, 2, 3] - traj[:, 1, 3] + 1)[:, np.newaxis]
v3 = 1e3 * (traj[:, 3, :3] - traj[:, 2, :3]) / (traj[:, 3, 3] - traj[:, 2, 3] + 1)[:, np.newaxis]

v1scale = np.linalg.norm(v1, axis=1, keepdims=True) + 1e-1
v2scale = np.linalg.norm(v2, axis=1, keepdims=True) + 1e-1
v3scale = np.linalg.norm(v3, axis=1, keepdims=True) + 1e-1

# angles between each velocity (reversed) and the plane normal
ev1 = np.sum(-v1 * e / v1scale, axis=1, keepdims=True)
ev2 = np.sum(-v2 * e / v2scale, axis=1, keepdims=True)
ev3 = np.sum(-v3 * e / v3scale, axis=1, keepdims=True)

ev1 = np.arccos(ev1)
ev2 = np.arccos(ev2)
ev3 = np.arccos(ev3)

# pairwise alignment of consecutive velocities
vv12 = np.sum(v1 * v2 / v1scale / v2scale, axis=1, keepdims=True)
vv23 = np.sum(v2 * v3 / v2scale / v3scale, axis=1, keepdims=True)
vv31 = np.sum(v3 * v1 / v3scale / v1scale, axis=1, keepdims=True)

vavg = np.log10(np.mean((v1scale, v2scale, v3scale), axis=0))
evvv = np.mean((ev1, ev2, ev3), axis=0)
vvvv = np.mean((vv12, vv23, vv31), axis=0)

# distance between the mean trajectory position and the charge centre
pos = np.mean(traj[:, :, :3], axis=1)
xyzq = reco_df[["qx", "qy", "qz"]].to_numpy()
distq = pos - xyzq
distq = np.linalg.norm(distq, axis=1, keepdims=True) + 1e-3

# load the model and predict
# NOTE: pickle.load can execute arbitrary code -- only load trusted model files.
LOGGER.info("Loading BDT model...")
with open(FILE_BDT, "rb") as model_file:
    clf = pickle.load(model_file)
LOGGER.info("Predicting...")
X = np.concatenate([reco, xe_angle, ev1, ev2, ev3, vavg, evvv, vvvv, distq], axis=1)
X[np.isnan(X)] = 0
y_hat = clf.predict(X)[:, np.newaxis]

gnn = reco_df[["azimuth", "zenith"]].to_numpy()
fit = vector2angles(proj)
# use the plane-projected angles where the BDT says so, the GNN angles elsewhere.
# `~y_hat` is only a logical NOT for bool arrays (0/1 ints become -1/-2),
# so select explicitly instead of `gnn * ~y_hat + fit * y_hat`.
use_fit = y_hat.astype(bool)
bdt = np.where(use_fit, fit, gnn)

LOGGER.info(f"GNN prediction\n{gnn}")
LOGGER.info(f"Plane prediction\n{fit}")
LOGGER.info(f"BDT prediction\n{bdt}")

submit_df = pd.DataFrame(bdt, columns=["azimuth", "zenith"])

submit_df["event_id"] = reco_df.event_id.values
submit_df = submit_df.set_index("event_id")
submit_df = submit_df.sort_index()  # sort by event_id
submit_df.to_csv(os.path.join(OUTPUT_PATH, "submission.csv"))
submit_df.head()
print(submit_df)

In [None]:
def get_target_angles(batches, meta_file=None):
    """
    Load ground-truth azimuth / zenith for the given training batch ids.

    The meta file is streamed in record batches of EVENTS_PER_FILE rows. Rows are
    selected by ``batch_id`` membership, so record batches that straddle two
    batch ids are handled correctly; the old version only checked each record
    batch's first row. Reading stops at the first record batch with no wanted
    rows once every wanted id has been seen. This assumes rows of one batch_id
    are contiguous in the file (TODO confirm if the meta file changes).

    :param batches: iterable of batch ids to keep
    :param meta_file: parquet with event_id, batch_id, azimuth, zenith;
                      defaults to PATH/train_meta.parquet
    :return: DataFrame [event_id (int64), batch_id, azimuth (float32),
             zenith (float32)] in file order, or None if nothing matched
    """
    if meta_file is None:
        meta_file = os.path.join(PATH, "train_meta.parquet")
    wanted = set(copy.copy(batches))
    if not wanted:
        return None

    columns = ["event_id", "batch_id", "azimuth", "zenith"]
    dtypes = {"event_id": np.int64, "azimuth": np.float32, "zenith": np.float32}

    frames = []
    found = set()
    meta = pq.ParquetFile(meta_file)
    for record_batch in meta.iter_batches(batch_size=EVENTS_PER_FILE, columns=columns):
        chunk = record_batch.to_pandas()
        selected = chunk[chunk["batch_id"].isin(wanted)]
        if selected.empty:
            if found == wanted:
                break
            continue
        frames.append(selected[columns].astype(dtypes))
        found.update(selected["batch_id"].unique().tolist())

    return pd.concat(frames) if frames else None

# ground truth
true_df = get_target_angles(BATCHES_TEST)

In [None]:
def angular_dist_score(az_true, zen_true, az_pred, zen_pred):
    """
    Mean angular distance between true and predicted directions.

    Each (azimuth, zenith) pair is turned into a Cartesian unit vector
    (x = sin(zen) cos(az), y = sin(zen) sin(az), z = cos(zen)). The dot product
    of two such vectors is the cosine of the angle between them, and arccos of
    it recovers that angle.

    Parameters
    ----------
    az_true, zen_true : float or array of float
        True azimuth / zenith, in radians.
    az_pred, zen_pred : float or array of float
        Predicted azimuth / zenith, in radians.

    Returns
    -------
    dist : float
        Average angular distance, in radians.

    Raises
    ------
    ValueError
        If any argument contains a NaN or infinite value.
    """
    arguments = (az_true, zen_true, az_pred, zen_pred)
    if not all(np.all(np.isfinite(arg)) for arg in arguments):
        raise ValueError("All arguments must be finite")

    sin_az_t, cos_az_t = np.sin(az_true), np.cos(az_true)
    sin_zen_t, cos_zen_t = np.sin(zen_true), np.cos(zen_true)
    sin_az_p, cos_az_p = np.sin(az_pred), np.cos(az_pred)
    sin_zen_p, cos_zen_p = np.sin(zen_pred), np.cos(zen_pred)

    # dot product of the two unit vectors
    dot = sin_zen_t*sin_zen_p*(cos_az_t*cos_az_p + sin_az_t*sin_az_p) + (cos_zen_t*cos_zen_p)

    # mathematically |dot| <= 1; clip away floating-point overshoot before arccos
    dot = np.clip(dot, -1, 1)

    return np.average(np.abs(np.arccos(dot)))

In [None]:
# align ground truth and predictions by event_id instead of relying on row order
scored = true_df.merge(submit_df.reset_index(), on="event_id", suffixes=("_true", "_pred"))
assert len(scored) == len(submit_df), "some predicted events have no ground truth"
angular_dist_score(
    scored["azimuth_true"].values, scored["zenith_true"].values,
    scored["azimuth_pred"].values, scored["zenith_pred"].values,
)