In [None]:
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------

import os, random, copy, yaml, pickle
from time import time, sleep
from tqdm import tqdm

import numpy as np
import scipy
import pandas as pd
import matplotlib.pyplot as plt
import pyarrow.parquet as pq

# torch imports
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import random_split, IterableDataset

import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.utils import homophily
from torch_geometric.loader import DataLoader
from torch_geometric.nn import EdgeConv, knn_graph
from torch_scatter import \
    scatter_add, scatter_mean, scatter_max, scatter_min

import pytorch_lightning as pl

In [None]:
# -----------------------------------------------------------------------------
# Helper function
# -----------------------------------------------------------------------------

import logging
import logging.config
import os, re, gc, psutil

def get_logger(name, msg):
    """
    Configure root logging and hand back a named logger.

    :param name: string, name of the logger to fetch
    :param msg: verbosity given as a string: DEBUG, INFO, WARNING or ERROR
    :return: Logger() instance
    :raises KeyError: if ``msg`` is not one of the four names listed above
    """
    # Resolve the textual level first so that an unknown name fails before
    # any logging configuration is attempted.
    verbosity = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
    }[msg]

    logging.basicConfig(
        level=verbosity,
        format="== %(name)s == %(asctime)s %(levelname)s:\t%(message)s",
        datefmt="%H:%M:%S",
    )
    return logging.getLogger(name)


def walk_dir(dirname, batch_ids):
    """
    Recursively collect the ``*_<id>.parquet`` files found below ``dirname``.

    :param dirname: directory to walk
    :param batch_ids: iterable of integer ids to keep, or None to keep every
        file whose name carries an id
    :return: ``(files, batch_ids)`` where ``files`` maps id -> file path.
        When ``batch_ids`` is None the second element is the list of ids that
        were discovered; otherwise the argument is returned unchanged (ids
        without a matching file are NOT removed from it).
    """
    pattern = re.compile(r"_(\d+)\.parquet")
    wanted = None if batch_ids is None else set(batch_ids)

    files = dict()
    found_ids = list()
    for base, _, names in os.walk(dirname):
        for name in names:
            match = pattern.search(name)
            if match is None:
                # Unrelated file (checkpoint, README, ...): skip it instead of
                # failing with an IndexError on the missing id.
                continue
            batch_id = int(match.group(1))
            if wanted is not None and batch_id not in wanted:
                continue
            files[batch_id] = os.path.join(base, name)
            found_ids.append(batch_id)

    if batch_ids is None:
        return files, found_ids
    return files, batch_ids


def memory_check(logger, msg=""):
    """
    Run a garbage collection and log the system memory usage at DEBUG level.

    :param logger: object exposing ``debug(str)``
    :param msg: free text appended to the log line
    """
    gib = 1024**3
    gc.collect()
    used_gb = psutil.virtual_memory().used / gib
    total_gb = psutil.virtual_memory().total / gib
    logger.debug(f"memory usage {used_gb:.2f} of {total_gb:.2f} GB {msg}")


In [None]:
# -----------------------------------------------------------------------------
# Basic settings
# -----------------------------------------------------------------------------

# basic
BATCH_SIZE = 200  # events per mini-batch (passed to IceCube in predict_gnn)
BATCHES_TEST = list(range(11, 16))  # ids of the batch_<id>.parquet files to process
EVENTS_PER_FILE = 200_000  # not referenced by the visible cells

# paths
# The data root can be overridden with the ICECUBE_BASE_PATH environment
# variable so the notebook runs on other machines without editing this cell;
# the default is the original hard-coded location.
BASE_PATH = os.environ.get("ICECUBE_BASE_PATH", "/root/autodl-tmp/kaggle/")
PATH = os.path.join(BASE_PATH, "icecube-neutrinos-in-deep-ice")
MODEL_PATH = os.path.join(BASE_PATH, "input", "ice-cube-model")
OUTPUT_PATH = os.path.join(BASE_PATH, "working")
TEST_PATH = os.path.join(PATH, "test")
META_PATH = os.path.join(OUTPUT_PATH, "test_meta")  # per-batch meta_<id>.parquet files

# files
# walk_dir returns the id -> path mapping and hands BATCHES_TEST back unchanged
# when an explicit id list is supplied.
FILES_TEST, BATCHES_TEST = walk_dir(TEST_PATH, BATCHES_TEST)
FILE_META = os.path.join(PATH, "test_meta.parquet")
FILE_SENSOR_GEO = os.path.join(PATH, "sensor_geometry.csv")
FILE_GNN = os.path.join(MODEL_PATH, "finetuned.ckpt")
FILE_BDT = os.path.join(MODEL_PATH, "BDT_clf.sklearn")
FILE_PARAM = os.path.join(MODEL_PATH, "parameters_local.yaml")

In [None]:
# -----------------------------------------------------------------------------
# Split meta file
# -----------------------------------------------------------------------------

# NOTE(review): one-off preprocessing step, kept disabled. It splits
# FILE_META into one meta_<batch_id>.parquet per batch inside META_PATH.
# The IceCube dataset below reads exactly those files, so this must have been
# run once before a fresh "Restart & Run All" can succeed.
# if not os.path.exists(META_PATH): 
#     os.mkdir(META_PATH)

# meta_test = pd.read_parquet(FILE_META)

# for i, df in meta_test.groupby("batch_id"):
#     print(f"processing {i} -> {df.shape}")
#     splitted = os.path.join(META_PATH, f"meta_{i}.parquet")
#     df.to_parquet(splitted)

In [None]:
# -----------------------------------------------------------------------------
# Some Logging
# -----------------------------------------------------------------------------
# Module-wide logger and compute device used by every later cell.
LOGGER = get_logger("IceCube", "DEBUG")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LOGGER.info(f"using {DEVICE}")
LOGGER.info(f"{len(FILES_TEST)} files for testing")
memory_check(LOGGER)

LOGGER.info(f"GNN model:{FILE_GNN}")
LOGGER.info(f"BDT model:{FILE_BDT}")

# BEST_FIT: keyword arguments later forwarded to plane_fit(**BEST_FIT_VALUES).
# NOTE(review): yaml.full_load can construct more than plain scalars/containers;
# if the file only holds plain numbers, yaml.safe_load would be the safer
# loader — confirm the file contents before switching.
BEST_FIT_VALUES = None
with open(FILE_PARAM, "r") as f:
    BEST_FIT_VALUES = yaml.full_load(f)
LOGGER.info(f"best fit values:{BEST_FIT_VALUES}")

In [None]:
# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------

# sensor geometry
def prepare_sensors(scale=None):
    """
    Load the sensor geometry table from FILE_SENSOR_GEO.

    :param scale: optional numeric factor applied to the x, y and z columns.
        Python ints/floats and numpy scalars are accepted; None (the default)
        or any non-numeric value leaves the coordinates untouched.
    :return: DataFrame with ``sensor_id`` as int16 and ``x``/``y``/``z`` as
        float32
    """
    sensors = pd.read_csv(FILE_SENSOR_GEO).astype({
        "sensor_id": np.int16,
        "x": np.float32,
        "y": np.float32,
        "z": np.float32,
    })

    # Previously only exact Python floats were honoured, so scale=2 or a
    # numpy scalar was silently ignored. bool is excluded on purpose because
    # it is a subclass of int.
    is_numeric = (
        isinstance(scale, (int, float, np.integer, np.floating))
        and not isinstance(scale, (bool, np.bool_))
    )
    if is_numeric:
        factor = float(scale)  # plain float keeps the columns in float32
        for axis in ("x", "y", "z"):
            sensors[axis] = sensors[axis] * factor

    return sensors


def angle_to_xyz(angles_b):
    """
    Convert (azimuth, zenith) pairs into Cartesian direction vectors.

    :param angles_b: 2-D tensor whose two columns are azimuth and zenith
    :return: tensor with columns x, y, z stacked along dim 1
    """
    azimuth, zenith = angles_b.t()
    sin_zenith = torch.sin(zenith)
    components = [
        torch.cos(azimuth) * sin_zenith,
        torch.sin(azimuth) * sin_zenith,
        torch.cos(zenith),
    ]
    return torch.stack(components, dim=1)


def xyz_to_angle(xyz_b):
    """
    Convert Cartesian direction vectors into (azimuth, zenith) pairs.

    :param xyz_b: 2-D tensor whose three columns are x, y, z (need not be
        normalised)
    :return: tensor with columns azimuth in [-pi, pi] and zenith in [0, pi]
    """
    x, y, z = xyz_b.t()
    # atan2 keeps the signed-azimuth convention of the previous
    # arccos(x / r) * sign(y) formula but is also correct when y == 0
    # (sign(0) == 0 mapped the -x axis to azimuth 0) and finite when
    # x == y == 0 (previously 0 / 0 -> NaN).
    az = torch.atan2(y, x)
    # Same idea for the zenith: no division, and no arccos argument that can
    # leave [-1, 1] through rounding.
    zen = torch.atan2(torch.sqrt(x**2 + y**2), z)
    return torch.stack([az, zen], dim=1)


def angular_error(xyz_pred_b, xyz_true_b):
    """
    Opening angle between corresponding rows of two batches of vectors.

    The row-wise dot product is clamped to [-1, 1] before the arccos, so the
    result equals the true angle only when both inputs hold unit vectors.
    """
    cos_angle = (xyz_pred_b * xyz_true_b).sum(dim=1)
    return torch.acos(cos_angle.clamp(min=-1, max=1))


def angles2vector(df):
    """
    Add direction-vector columns ``nx``, ``ny``, ``nz`` to ``df`` in place.

    :param df: DataFrame with ``zenith`` and ``azimuth`` columns
    :return: the same DataFrame object, now carrying the three new columns
    """
    sin_zenith = np.sin(df.zenith)
    df["nx"] = sin_zenith * np.cos(df.azimuth)
    df["ny"] = sin_zenith * np.sin(df.azimuth)
    df["nz"] = np.cos(df.zenith)
    return df


def vector2angles(n, eps=1e-8):
    """
    Convert an (N, 3) array of vectors into an (N, 2) array of angles.

    :param n: array of x, y, z components, one vector per row
    :param eps: added to each row norm before dividing, to avoid 0 / 0
    :return: array with columns azimuth in [0, 2*pi) and zenith in [0, pi]
    """
    unit = n / (np.linalg.norm(n, axis=1, keepdims=True) + eps)

    azimuth = np.arctan2(unit[:, 1], unit[:, 0])
    # arctan2 yields (-pi, pi]; shift the negative half onto (pi, 2*pi)
    azimuth = np.where(azimuth < 0, azimuth + 2 * np.pi, azimuth)

    zenith = np.arccos(np.clip(unit[:, 2], -1, 1))
    return np.stack([azimuth, zenith], axis=1)


def series2tensor(series, set_device=None):
    """
    Convert a pandas Series into a float32 torch tensor.

    :param series: Series with a numeric or boolean numpy dtype
    :param set_device: None keeps the tensor on the CPU. A device (string or
        ``torch.device``) moves the tensor there. For backward compatibility a
        bool still means "move to the module-level DEVICE", which is what any
        non-None value used to do.
    :return: 1-D float tensor
    """
    ret = torch.from_numpy(series.values).float()
    if set_device is None:
        return ret
    # The argument used to be treated as a mere flag and the global DEVICE was
    # always used, so an explicitly requested device was ignored.
    target = DEVICE if isinstance(set_device, bool) else set_device
    return ret.to(target)


def solve_linear(xw, yw, zw, xxw, yyw, zzw, xyw, yzw, zxw):
    """
    Solve the 3x3 system for the plane ``z = a*x + b*y + c``.

    The arguments are the (weighted) first and second moments of x, y, z.
    ``zzw`` is accepted but not used by the system.

    :return: tensor ``[a, b, c]``; a zero tensor of shape (3,) when
        ``torch.linalg.solve`` raises
    """
    lhs = torch.tensor([
        [xxw, xyw, xw],
        [xyw, yyw, yw],
        [xw,  yw,  1 ],
    ])
    rhs = torch.tensor([zxw, yzw, zw])

    try:
        solution = torch.linalg.solve(lhs, rhs)
    except Exception:
        LOGGER.debug("linear system not solvable")
        return torch.zeros((3, ))
    return solution


def plane_fit(df, k=0, kt=0, kq=0, fun=None, eps=1e-8):
    """
    Weighted fit of the plane ``z = a*x + b*y + c`` to the pulses of one event.

    :param df: DataFrame with columns z_avg, time, charge, x, y, z
    :param k: strength of the exp(-k * (z - z_avg)^2) weight factor
    :param kt: strength of the exp(-kt * time) weight factor
    :param kq: exponent of the charge^kq weight factor
    :param fun: unused
    :param eps: added to the weight sum before it is used as a divisor
    :return: float tensor of shape (1, 10) holding, in order: the plane normal
        (a, b, -1) / sqrt(a^2 + b^2 + 1), the squared residual sum, the number
        of hits, the weighted charge sum, the median time, and the number of
        distinct x, y and z values
    """
    string_z = series2tensor(df.z_avg)
    hit_time = series2tensor(df.time)
    charge = series2tensor(df.charge)
    pos_x = series2tensor(df.x)
    pos_y = series2tensor(df.y)
    pos_z = series2tensor(df.z)

    # weighted by distance to the z_avg value, time and charge
    weight = (
        torch.exp(-k * torch.square(pos_z - string_z))
        * torch.exp(-kt * hit_time)
        * torch.pow(charge, kq)
    )
    norm = torch.sum(weight) + eps

    def weighted_mean(values):
        """Weighted sum of ``values`` divided by the regularised weight sum."""
        return torch.sum(values * weight) / norm

    coeff = solve_linear(
        weighted_mean(pos_x), weighted_mean(pos_y), weighted_mean(pos_z),
        weighted_mean(pos_x * pos_x), weighted_mean(pos_y * pos_y), weighted_mean(pos_z * pos_z),
        weighted_mean(pos_x * pos_y), weighted_mean(pos_y * pos_z), weighted_mean(pos_z * pos_x),
    )
    slope_x, slope_y, offset = coeff[0], coeff[1], coeff[2]

    # NOTE: this is the SIGNED sum of the residuals (scaled by 1e3) which is
    # squared afterwards, not a sum of squared residuals. It is kept exactly
    # as is because the value is used downstream as the "fit_error" feature.
    residual_sum = torch.sum(pos_z - slope_x * pos_x - slope_y * pos_y - offset) * 1e3

    features = torch.tensor([[
        slope_x,
        slope_y,
        -1,
        torch.square(residual_sum),
        weight.shape[0],
        torch.sum(weight * charge),
        torch.median(hit_time),
        torch.unique(pos_x).shape[0],
        torch.unique(pos_y).shape[0],
        torch.unique(pos_z).shape[0],
    ]])
    # turn (a, b, -1) into a unit normal vector
    features[:, :3] /= torch.sqrt(slope_x**2 + slope_y**2 + 1)

    return features


def prepare_df_for_plane(df):
    """
    Build the per-pulse table consumed by ``plane_fit`` for one event.

    Steps: drop auxiliary pulses, clip the charge to [0, 4], shift the time so
    the earliest kept pulse is at 0 and multiply it by 0.299792458e-3, multiply
    x/y/z by 1e-3, and attach ``z_avg``, the charge-weighted mean z of all
    kept pulses sharing the same (x, y).

    :param df: DataFrame with columns auxiliary (bool), charge, time, x, y, z;
        it is not modified
    :return: DataFrame with columns z_avg, time, charge, x, y, z
    """
    hits = df.reset_index(drop=True)

    # remove auxiliary; the explicit copy makes the following column writes
    # unambiguous (no chained assignment on a filtered slice)
    hits = hits[~hits.auxiliary].copy()

    hits["charge"] = hits["charge"].astype(np.float32).clip(0, 4)
    t_min = hits["time"].min()
    hits["time"] = ((hits["time"] - t_min) * 0.299792458e-3).astype(np.float32)
    for axis in ("x", "y", "z"):
        hits[axis] = hits[axis] * 1e-3

    hits["qz"] = hits["charge"] * hits["z"]
    # string names instead of np.sum callables: same result, no FutureWarning
    centre = hits.groupby(["x", "y"]).agg(
        qsum=("charge", "sum"),
        qzsum=("qz", "sum"),
    )

    centre["z_avg"] = centre.qzsum / centre.qsum
    hits = pd.merge(hits, centre[["z_avg"]], on=["x", "y"])

    return hits[["z_avg", "time", "charge", "x", "y", "z"]]


# Dataset
class IceCube(IterableDataset):
    """
    Iterable dataset that streams IceCube events chunk by chunk.

    For every chunk id it reads ``batch_<id>.parquet`` from ``parquet_dir``
    and ``meta_<id>.parquet`` from ``meta_dir`` and yields one ready-made
    ``Batch`` per ``batch_size`` events, in the order of the meta file.

    :param parquet_dir: directory holding the pulse files
    :param meta_dir: directory holding the per-chunk meta files
    :param chunk_ids: iterable of chunk ids to process
    :param batch_size: number of events per yielded batch
    :param max_pulses: events with more pulses are randomly sub-sampled to
        this size (unseeded, so results vary between runs for such events)
    """

    def __init__(
        self, parquet_dir, meta_dir, chunk_ids,
        batch_size=200, max_pulses=200
    ):
        self.parquet_dir = parquet_dir
        self.meta_dir = meta_dir
        self.chunk_ids = chunk_ids
        self.batch_size = batch_size
        self.max_pulses = max_pulses

    def _limit_pulses(self, df):
        """Sub-sample ``df`` to at most ``max_pulses`` rows, preferring non-auxiliary pulses."""
        if len(df) <= self.max_pulses:
            return df

        df_pass = df[~df.auxiliary]
        if len(df_pass) >= self.max_pulses:
            return df_pass.sample(self.max_pulses)

        # keep every non-auxiliary pulse and fill up with auxiliary ones
        df_fail = df[df.auxiliary].sample(self.max_pulses - len(df_pass))
        return pd.concat([df_fail, df_pass])

    def _build_event(self, eid, df):
        """Turn the pulse table of one event into a ``Data`` object with its plane fit attached."""
        df = df.sort_values(["time"])

        # feature order: x, y, z, t, c, a (the model indexes columns by position)
        feat = torch.stack(
            [series2tensor(df[col]) for col in ("x", "y", "z", "time", "charge", "auxiliary")],
            dim=1,
        )

        event = Data(x=feat, n_pulses=len(feat), eid=torch.tensor([eid]).long())
        event.plane = plane_fit(prepare_df_for_plane(df), **BEST_FIT_VALUES)
        return event

    def __iter__(self):
        # Sensor data
        sensor = prepare_sensors()

        # Read each chunk and meta iteratively into memory and build mini-batch
        for chunk_id in self.chunk_ids:
            data = pd.read_parquet(os.path.join(self.parquet_dir, f"batch_{chunk_id}.parquet"))
            meta = pd.read_parquet(os.path.join(self.meta_dir, f"meta_{chunk_id}.parquet"))

            # Take all event_ids and split them into batches
            eids = meta["event_id"].tolist()

            for start in range(0, len(eids), self.batch_size):
                batch = []

                # For each sample, extract features
                for eid in eids[start : start + self.batch_size]:
                    df = data.loc[eid]
                    if isinstance(df, pd.Series):
                        # An event with a single pulse comes back as a Series,
                        # which pd.merge cannot handle; re-select as a frame.
                        df = data.loc[[eid]]
                    df = pd.merge(df, sensor, on="sensor_id")
                    batch.append(self._build_event(eid, self._limit_pulses(df)))

                yield Batch.from_data_list(batch)

            del data
            del meta
            gc.collect()

In [None]:
# -----------------------------------------------------------------------------
# GraphNet GNN model
# -----------------------------------------------------------------------------

class MLP(nn.Sequential):
    """
    Stack of ``Linear`` + ``LeakyReLU`` pairs.

    ``feats`` lists the layer widths: ``feats[0]`` is the input size and every
    following entry adds one Linear layer followed by a LeakyReLU (including
    after the last layer).
    """

    def __init__(self, feats):
        stack = []
        for n_in, n_out in zip(feats[:-1], feats[1:]):
            stack += [nn.Linear(n_in, n_out), nn.LeakyReLU()]
        super().__init__(*stack)


class Model(pl.LightningModule):
    """
    Dynamic EdgeConv network that predicts a direction per event.

    The forward pass returns, for every graph in the batch, a unit vector
    (x, y, z) plus ``kappa``, the norm of the raw 3-vector prediction.

    :param max_lr: peak learning rate of the Adam optimizer
    :param num_warmup_step: number of steps of the linear warm-up schedule
    :param remaining_step: number of steps of the linear decay schedule
    """

    def __init__(
        self, max_lr=1e-3, 
        num_warmup_step=1_000,
        remaining_step=1_000,
    ):
        super().__init__()
        self.save_hyperparameters()
        # 17 = 6 raw pulse features + 11 broadcast global features; the
        # EdgeConv MLPs take twice the node feature width as input.
        self.conv0 = EdgeConv(MLP([17 * 2, 128, 256]), aggr="add")
        self.conv1 = EdgeConv(MLP([512, 336, 256]), aggr="add")
        self.conv2 = EdgeConv(MLP([512, 336, 256]), aggr="add")
        self.conv3 = EdgeConv(MLP([512, 336, 256]), aggr="add")
        # concatenation of the input features and the four conv outputs
        self.post = MLP([1024 + 17, 336, 256])
        # min, max and mean pooling of the 256 post-processed features
        self.readout = MLP([768, 128])
        self.pred = nn.Linear(128, 3)

    def forward(self, data: Batch):
        """
        :param data: batch with ``x`` (pulse features in the column order
            x, y, z, t, c, a), ``batch`` and ``n_pulses``
        :return: tensor [B, 4] with columns unit x, unit y, unit z, kappa
        """
        # Work on a copy: the normalisation below is done in place and used to
        # overwrite ``data.x`` itself, so the caller's batch was left
        # normalised and a second forward pass normalised it twice.
        vert_feat = data.x.clone()
        batch = data.batch

        # x, y, z, t, c, a
        # 0  1  2  3  4  5
        vert_feat[:, 0] /= 500.0  # x
        vert_feat[:, 1] /= 500.0  # y
        vert_feat[:, 2] /= 500.0  # z
        vert_feat[:, 3] = (vert_feat[:, 3] - 1.0e04) / 3.0e4  # time
        vert_feat[:, 4] = torch.log10(vert_feat[:, 4]) / 3.0  # charge

        edge_index = knn_graph(vert_feat[:, :3], 8, batch)

        # Construct global features
        hx = homophily(edge_index, vert_feat[:, 0], batch).reshape(-1, 1)
        hy = homophily(edge_index, vert_feat[:, 1], batch).reshape(-1, 1)
        hz = homophily(edge_index, vert_feat[:, 2], batch).reshape(-1, 1)
        ht = homophily(edge_index, vert_feat[:, 3], batch).reshape(-1, 1)
        means = scatter_mean(vert_feat, batch, dim=0)
        n_p = torch.log10(data.n_pulses).reshape(-1, 1)
        global_feats = torch.cat([means, hx, hy, hz, ht, n_p], dim=1)  # [B, 11]

        # Distribute global_feats to each vertex
        _, cnts = torch.unique_consecutive(batch, return_counts=True)
        global_feats = torch.repeat_interleave(global_feats, cnts, dim=0)
        vert_feat = torch.cat((vert_feat, global_feats), dim=1)

        # Convolutions
        feats = [vert_feat]
        # Conv 0
        vert_feat = self.conv0(vert_feat, edge_index)
        feats.append(vert_feat)
        # Conv 1 (the graph is rebuilt from the first three learned features)
        edge_index = knn_graph(vert_feat[:, :3], k=8, batch=batch)
        vert_feat = self.conv1(vert_feat, edge_index)
        feats.append(vert_feat)
        # Conv 2
        edge_index = knn_graph(vert_feat[:, :3], k=8, batch=batch)
        vert_feat = self.conv2(vert_feat, edge_index)
        feats.append(vert_feat)
        # Conv 3
        edge_index = knn_graph(vert_feat[:, :3], k=8, batch=batch)
        vert_feat = self.conv3(vert_feat, edge_index)
        feats.append(vert_feat)

        # Postprocessing
        post_inp = torch.cat(feats, dim=1)
        post_out = self.post(post_inp)

        # Readout
        readout_inp = torch.cat(
            [
                scatter_min(post_out, batch, dim=0)[0],
                scatter_max(post_out, batch, dim=0)[0],
                scatter_mean(post_out, batch, dim=0),
            ],
            dim=1,
        )
        readout_out = self.readout(readout_inp)

        # Predict: split the raw 3-vector into direction and length
        pred = self.pred(readout_out)
        kappa = pred.norm(dim=1, p=2) + 1e-8
        pred_x = pred[:, 0] / kappa
        pred_y = pred[:, 1] / kappa
        pred_z = pred[:, 2] / kappa
        pred = torch.stack([pred_x, pred_y, pred_z, kappa], dim=1)

        return pred

    def train_or_valid_step(self, data, prefix):
        """
        Shared body of the training and validation steps.

        NOTE(review): ``VonMisesFisher3DLoss`` is not defined in the visible
        part of this notebook; it has to be provided before training.

        :param data: batch that additionally carries the true direction ``gt``
        :param prefix: suffix used in the logged metric names
        :return: mean loss of the batch
        """
        pred_xyzk = self.forward(data)  # [B, 4]
        true_xyz = data.gt.view(-1, 3)  # [B, 3]
        loss = VonMisesFisher3DLoss()(pred_xyzk, true_xyz).mean()
        error = angular_error(pred_xyzk[:, :3], true_xyz).mean()
        self.log(f"loss-{prefix}", loss, batch_size=len(true_xyz), 
            on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log(f"error-{prefix}", error, batch_size=len(true_xyz), 
            on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def training_step(self, data, _):
        return self.train_or_valid_step(data, "train")

    def validation_step(self, data, _):
        return self.train_or_valid_step(data, "valid")
    
    def configure_optimizers(self):
        """Adam with a linear warm-up followed by a linear decay, stepped per batch."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.max_lr)
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[
                torch.optim.lr_scheduler.LinearLR(
                    optimizer, 1e-2, 1, self.hparams.num_warmup_step
                ),
                torch.optim.lr_scheduler.LinearLR(
                    optimizer, 1, 1e-3, self.hparams.remaining_step
                ),
            ],
            milestones=[self.hparams.num_warmup_step],
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }


In [None]:
# -----------------------------------------------------------------------------
# GNN prediction
# -----------------------------------------------------------------------------

def predict_gnn(model):
    """
    Run ``model`` over all test batches and collect its output per event.

    :param model: callable mapping a batch to a [B, 4] tensor whose first
        three columns are the predicted direction
    :return: DataFrame with the predicted direction (x, y, z), the derived
        azimuth/zenith (both wrapped into [0, 2*pi)) and the ten plane-fit
        values stored on each event by the dataset
    """
    columns = ["x", "y", "z", "azimuth", "zenith",
               "ex", "ey", "ez", "fit_error", "hits", "sumc", "dt",
               "unique_x", "unique_y", "unique_z"]

    test_set = IceCube(TEST_PATH, META_PATH, BATCHES_TEST, batch_size=BATCH_SIZE)
    test_loader = DataLoader(test_set, batch_size=1)

    # Collect the per-batch arrays and concatenate once at the end; growing
    # the result inside the loop copied everything again for every batch.
    chunks = []
    with torch.no_grad():
        for _, data in tqdm(enumerate(test_loader)):
            data = data.to(DEVICE)
            direction = model(data)[:, :3]
            chunks.append(np.concatenate([
                direction.cpu().numpy(),
                xyz_to_angle(direction).cpu().numpy(),
                data.plane.cpu().numpy(),
            ], axis=1))

    # None yields an empty frame with the expected columns when nothing was read
    pred = np.concatenate(chunks) if chunks else None
    pred = pd.DataFrame(pred, columns=columns)

    pred["azimuth"] = np.remainder(pred["azimuth"], 2 * np.pi)
    pred["zenith"] = np.remainder(pred["zenith"], 2 * np.pi)

    return pred


# Restore the fine-tuned GNN from its Lightning checkpoint.
model = Model.load_from_checkpoint(FILE_GNN)
LOGGER.info(f"loaded {FILE_GNN}")

# Inference only: evaluation mode, frozen, on the selected device.
model.eval()
model.freeze()
model.to(DEVICE)

# One row per event: GNN direction, derived angles and plane-fit values.
reco_df = predict_gnn(model)

In [None]:
# -----------------------------------------------------------------------------
# GNN + Plane fit prediction
# -----------------------------------------------------------------------------

# n_hat: direction predicted by the GNN; e: normalised plane vector produced
# by plane_fit (columns ex, ey, ez); xe: their row-wise dot product.
n_hat = reco_df[["x", "y", "z"]].to_numpy()
e = reco_df[["ex", "ey", "ez"]].to_numpy()
xe = np.sum(n_hat * e, axis=1)

# Remove the component of n_hat along e, i.e. project the GNN direction onto
# the fitted plane, then renormalise (1e-8 guards against a zero-length row).
proj = n_hat - xe[:, np.newaxis] * e
proj /= (np.linalg.norm(proj, axis=1, keepdims=True) + 1e-8)

In [None]:
# -----------------------------------------------------------------------------
# GNN + Plane fit + BDT prediction
# -----------------------------------------------------------------------------

# reco_df inputs
# reco_df inputs (to_numpy on a column selection returns a fresh array, so the
# in-place log transforms below do not touch reco_df)
reco = reco_df[["fit_error", "sumc", "hits", "zenith", "ez", "dt", "unique_x", "unique_z"]].to_numpy()
reco[:, 0] = np.log10(reco[:, 0] + 1e-8)
reco[:, 1] = np.log10(reco[:, 1] + 1e-8)

# load the model and predict
LOGGER.info("Loading BDT model...")
# NOTE(review): unpickling executes arbitrary code; only load a classifier
# file from a trusted source.
with open(FILE_BDT, "rb") as bdt_file:
    clf = pickle.load(bdt_file)
LOGGER.info("Predicting...")
X = np.concatenate([reco, np.abs(xe[:, np.newaxis])], axis=1)
# Cast to bool so the selection below is well defined even if the classifier
# returns 0/1 integer labels (``~`` on integers is a bitwise NOT, not a
# logical one). The former ``y_hat[0][0] = True`` override, which forced the
# plane-fit result for the first event regardless of the classifier, is gone.
y_hat = np.asarray(clf.predict(X)).astype(bool)[:, np.newaxis]

gnn = reco_df[["azimuth", "zenith"]].to_numpy()
fit = vector2angles(proj)
# Per event: take the plane-projected angles where the classifier says so,
# the plain GNN angles otherwise. np.where (instead of multiplying by the
# mask) keeps a NaN in the unselected candidate from leaking into the result.
bdt = np.where(y_hat, fit, gnn)
LOGGER.info(f"GNN prediction\n{gnn}")
LOGGER.info(f"Plane fit prediction\n{fit}")
LOGGER.info(f"BDT prediction\n{bdt}")