In [None]:
# Pinned DGL wheel for CUDA 12.1; %pip installs into the running kernel's environment.
#!pip install dgl==1.1.1
%pip install dgl==2.0.0 -f https://data.dgl.ai/wheels/cu121/repo.html
#!pip install faiss_cpu
# NOTE(review): faiss-gpu-cu12 is unpinned — pin a version for reproducible runs.
%pip install faiss-gpu-cu12

In [None]:
!python -c "import dgl;print(dgl.__version__)"

In [None]:
# -p: succeed if the directory already exists, so the cell can be re-run.
!mkdir -p ./p

In [None]:
!mkdir -p /kaggle/working/p/utils /kaggle/working/p/scripts /kaggle/working/p/models /kaggle/working/p/datasetml

In [None]:
# Clone only once: `git clone` into an existing directory fails on re-run.
!test -d p/clustering-benchmark || git clone https://github.com/yjxiong/clustering-benchmark.git p/clustering-benchmark

## CARTELLA UTILS (utils folder: helper modules written to `p/utils`)

In [None]:
%%writefile ./p/utils/misc.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""

import json
import os
import pickle
import random
import time

import numpy as np


class TextColors:
    """ANSI escape sequences used to colour terminal / notebook output."""

    # Text attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    # Foreground colours
    HEADER = "\033[35m"  # magenta
    OKBLUE = "\033[34m"  # blue
    OKGREEN = "\033[32m"  # green
    WARNING = "\033[33m"  # yellow
    FATAL = "\033[31m"  # red
    # Reset every attribute back to the terminal default
    ENDC = "\033[0m"


class Timer:
    """Context manager that reports how long its block took.

    Args:
        name: label used in the printed timing message.
        verbose: when False nothing is printed; the duration is still
            recorded in ``elapsed``.

    Exceptions raised inside the block always propagate.
    """

    def __init__(self, name="task", verbose=True):
        self.name = name
        self.verbose = verbose
        self.elapsed = None

    def __enter__(self):
        # perf_counter is monotonic, unlike time.time(): the measured duration
        # cannot jump or go negative if the system clock is adjusted.
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed = time.perf_counter() - self.start
        if self.verbose:
            print("[Time] {} consumes {:.4f} s".format(self.name, self.elapsed))
        # A falsy return value never suppresses an in-flight exception.
        return False


def set_random_seed(seed, cuda=False):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed``.

    Args:
        seed: integer seed shared by every generator.
        cuda: when True, also seed the generators of all CUDA devices.
    """
    import torch

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)


def l2norm(vec):
    """L2-normalise each row of a 2-D float array, in place.

    Rows whose norm is zero are left as all-zeros. Previously the division
    by zero turned them into NaN and raised a RuntimeWarning.

    Args:
        vec: 2-D floating-point ndarray; modified in place.

    Returns:
        The same array object, normalised row by row.
    """
    norms = np.linalg.norm(vec, axis=1).reshape(-1, 1)
    # Dividing a zero row by 1 keeps it zero instead of producing NaN.
    norms[norms == 0] = 1
    vec /= norms
    return vec


def is_l2norm(features, size):
    """Spot-check whether the rows of ``features`` have unit L2 norm.

    One row among the first ``size`` rows is chosen at random. Its squared
    norm must be within 1e-6 of 1.
    """
    row = features[random.randrange(size), :]
    squared_norm = np.dot(row, row)
    return abs(squared_norm - 1) < 1e-6


def is_spmat_eq(a, b):
    """Return True when sparse matrices ``a`` and ``b`` have identical entries."""
    differing = a != b  # sparse boolean matrix, non-zero where entries differ
    return differing.nnz == 0


def aggregate(features, adj, times):
    """Apply ``adj * features`` repeatedly, ``times`` times.

    For a scipy ``spmatrix`` adjacency, ``*`` is a matrix product. The result
    is cast back to the dtype of the input ``features``.
    """
    original_dtype = features.dtype
    result = features
    for _ in range(times):
        result = adj * result
    return result.astype(original_dtype)


def mkdir_if_no_exists(path, subdirs=("",), is_folder=False):
    """Create the parent directories needed for ``path``.

    Args:
        path: a file path, or a folder path when ``is_folder`` is True. An
            empty string is a no-op.
        subdirs: sub-paths joined onto ``path``. For each one, the directory
            part of ``os.path.join(path, sd)`` is created.
        is_folder: when True and ``sd == ""``, ``path`` itself is created
            (``os.path.join(path, "")`` ends with a separator, so its
            dirname is ``path``).
    """
    if path == "":
        return
    for sd in subdirs:
        if sd != "" or is_folder:
            d = os.path.dirname(os.path.join(path, sd))
        else:
            d = os.path.dirname(path)
        # A bare file name has no directory part; os.makedirs("") would raise
        # FileNotFoundError, so there is nothing to do.
        if d:
            # exist_ok avoids the race between an exists() check and makedirs().
            os.makedirs(d, exist_ok=True)


def stop_iterating(
    current_l,
    total_l,
    early_stop,
    num_edges_add_this_level,
    num_edges_add_last_level,
    knn_k,
):
    """Decide whether hierarchical level construction should stop.

    Stops (returns True) when any of these holds:
      1. the last level (``total_l - 1``) has been reached;
      2. the current level added no edges;
      3. ``early_stop`` is set and the previous level's edge count divided by
         this level's edge count is below ``knn_k - 1``, meaning consecutive
         levels now produce similar numbers of edges.
    """
    reached_last_level = current_l == total_l - 1
    if reached_last_level or num_edges_add_this_level == 0:
        return True
    if not early_stop:
        return False
    # Rule 2 above guarantees the denominator is non-zero here.
    edge_ratio = float(num_edges_add_last_level) / num_edges_add_this_level
    return bool(edge_ratio < knn_k - 1)

In [None]:
%%writefile ./p/utils/metrics.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""

from __future__ import division

import numpy as np
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics.cluster import (
    contingency_matrix,
    normalized_mutual_info_score,
)

__all__ = ["pairwise", "bcubed", "nmi", "precision", "recall", "accuracy"]


def _check(gt_labels, pred_labels):
    if gt_labels.ndim != 1:
        raise ValueError(
            "gt_labels must be 1D: shape is %r" % (gt_labels.shape,)
        )
    if pred_labels.ndim != 1:
        raise ValueError(
            "pred_labels must be 1D: shape is %r" % (pred_labels.shape,)
        )
    if gt_labels.shape != pred_labels.shape:
        raise ValueError(
            "gt_labels and pred_labels must have same size, got %d and %d"
            % (gt_labels.shape[0], pred_labels.shape[0])
        )
    return gt_labels, pred_labels


def _get_lb2idxs(labels):
    lb2idxs = {}
    for idx, lb in enumerate(labels):
        if lb not in lb2idxs:
            lb2idxs[lb] = []
        lb2idxs[lb].append(idx)
    return lb2idxs


def _compute_fscore(pre, rec):
    return 2.0 * pre * rec / (pre + rec)


def fowlkes_mallows_score(gt_labels, pred_labels, sparse=True):
    """Pairwise precision, recall and F-score between two labelings.

    Adapted from ``sklearn.metrics.fowlkes_mallows_score``. Instead of the
    geometric mean, it returns pairwise precision, pairwise recall and their
    F-measure.

    Args:
        gt_labels: 1-D array of ground-truth labels.
        pred_labels: 1-D array of predicted labels of the same length.
        sparse: build the contingency matrix as a scipy sparse matrix (True)
            or as a dense ndarray (False).

    Returns:
        ``(avg_pre, avg_rec, fscore)``.
    """
    (n_samples,) = gt_labels.shape

    c = contingency_matrix(gt_labels, pred_labels, sparse=sparse)
    # Cell counts of the contingency matrix. For a dense ndarray, `.data` is
    # a raw memoryview of the 2-D buffer, and np.dot on it would be a matrix
    # product. Flatten it instead.
    cells = c.data if sparse else np.asarray(c).ravel()
    # tk: ordered pairs of distinct samples grouped together in both labelings.
    tk = np.dot(cells, cells) - n_samples
    # pk / qk: ordered pairs grouped together by the prediction / ground truth.
    pk = np.sum(np.asarray(c.sum(axis=0)).ravel() ** 2) - n_samples
    qk = np.sum(np.asarray(c.sum(axis=1)).ravel() ** 2) - n_samples

    avg_pre = tk / pk
    avg_rec = tk / qk
    fscore = _compute_fscore(avg_pre, avg_rec)

    return avg_pre, avg_rec, fscore


def pairwise(gt_labels, pred_labels, sparse=True):
    """Validate the label arrays, then return pairwise (precision, recall, F)."""
    gt_labels, pred_labels = _check(gt_labels, pred_labels)
    return fowlkes_mallows_score(gt_labels, pred_labels, sparse=sparse)


def bcubed(gt_labels, pred_labels):
    """BCubed precision, recall and F-score.

    Args:
        gt_labels: 1-D ndarray of ground-truth labels.
        pred_labels: 1-D ndarray of predicted labels of the same length.

    Returns:
        ``(avg_pre, avg_rec, fscore)``, each averaged over all samples.
    """
    _check(gt_labels, pred_labels)

    gt_lb2idxs = _get_lb2idxs(gt_labels)
    pred_lb2idxs = _get_lb2idxs(pred_labels)

    num_lbs = len(gt_lb2idxs)
    pre = np.zeros(num_lbs)
    rec = np.zeros(num_lbs)
    gt_num = np.zeros(num_lbs)

    for i, gt_idxs in enumerate(gt_lb2idxs.values()):
        # Every sample index appears once, so counting a predicted label inside
        # this ground-truth cluster gives exactly |gt cluster ∩ pred cluster|.
        # This avoids one sort-based np.intersect1d per (gt, pred) pair.
        all_pred_lbs, overlaps = np.unique(
            pred_labels[gt_idxs], return_counts=True
        )
        gt_num[i] = len(gt_idxs)
        for pred_lb, overlap in zip(all_pred_lbs, overlaps):
            n = float(overlap)
            pre[i] += n**2 / len(pred_lb2idxs[pred_lb])
            rec[i] += n**2 / gt_num[i]

    gt_num = gt_num.sum()
    avg_pre = pre.sum() / gt_num
    avg_rec = rec.sum() / gt_num
    fscore = _compute_fscore(avg_pre, avg_rec)

    return avg_pre, avg_rec, fscore


def nmi(gt_labels, pred_labels):
    """Normalized mutual information between the two labelings.

    The prediction is passed as ``labels_true`` (argument order kept from the
    original implementation).
    """
    return normalized_mutual_info_score(
        labels_true=pred_labels, labels_pred=gt_labels
    )


def precision(gt_labels, pred_labels):
    """sklearn ``precision_score`` with the ground truth as ``y_true``."""
    return precision_score(y_true=gt_labels, y_pred=pred_labels)


def recall(gt_labels, pred_labels):
    """sklearn ``recall_score`` with the ground truth as ``y_true``."""
    return recall_score(y_true=gt_labels, y_pred=pred_labels)


def accuracy(gt_labels, pred_labels):
    """Fraction of positions where ``pred_labels`` equals ``gt_labels``."""
    matches = gt_labels == pred_labels
    return np.mean(matches)

In [None]:
%%writefile ./p/utils/knn.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""

import math
import multiprocessing as mp
import os

import numpy as np
from tqdm import tqdm
#from utils import Timer

from .faiss_search import faiss_search_knn

__all__ = [
    "knn_faiss",
    "knn_faiss_gpu",
    "fast_knns2spmat",
    "build_knns",
    "knns2ordered_nbrs",
]


def knns2ordered_nbrs(knns, sort=True):
    """Split a kNN array into ``(dists, nbrs)``.

    Args:
        knns: array, or list convertible to one, indexed as
            ``knns[i, 0, :]`` = neighbour ids and ``knns[i, 1, :]`` = distances.
        sort: when True, reorder every row by ascending distance.

    Returns:
        ``(dists, nbrs)`` with ``nbrs`` cast to int32.
    """
    arr = np.array(knns) if isinstance(knns, list) else knns
    nbrs = arr[:, 0, :].astype(np.int32)
    dists = arr[:, 1, :]
    if not sort:
        return dists, nbrs
    order = np.argsort(dists, axis=1)
    return (
        np.take_along_axis(dists, order, axis=1),
        np.take_along_axis(nbrs, order, axis=1),
    )


def fast_knns2spmat(knns, k, th_sim=0, use_sim=True, fill_value=None):
    """Convert kNN results into a sparse ``(n, n)`` affinity matrix.

    Args:
        knns: either an array indexed as ``knns[i, 0, :]`` = neighbour ids and
            ``knns[i, 1, :]`` = distances, or a 2-D (object) array of
            ``(nbrs, dists)`` pairs whose rows may be shorter than ``k``.
            The code attributes this second layout to hnsw. Short rows are
            padded with id -1 and distance 1.
        k: number of neighbour slots per row, used when padding.
        th_sim: keep only entries whose weight is >= ``th_sim``.
        use_sim: when True, the stored weight is ``1 - dist``; otherwise it is
            the raw distance.
        fill_value: if given, every kept entry gets this weight instead.

    Returns:
        ``scipy.sparse.csr_matrix`` of shape ``(n, n)`` without self-loops.
    """
    from scipy.sparse import csr_matrix

    eps = 1e-5
    n = len(knns)
    if isinstance(knns, list):
        knns = np.array(knns)
    if len(knns.shape) == 2:
        # knns saved by hnsw has different shape
        n = len(knns)
        ndarr = np.ones([n, 2, k])
        ndarr[:, 0, :] = -1  # assign unknown dist to 1 and nbr to -1
        for i, (nbr, dist) in enumerate(knns):
            size = len(nbr)
            assert size == len(dist)
            ndarr[i, 0, :size] = nbr[:size]
            ndarr[i, 1, :size] = dist[:size]
        knns = ndarr
    nbrs = knns[:, 0, :]
    dists = knns[:, 1, :]
    assert (
        -eps <= dists.min() <= dists.max() <= 1 + eps
    ), "min: {}, max: {}".format(dists.min(), dists.max())
    # `dists` is a view into the caller's array; copy it so that `fill`
    # below cannot overwrite their data.
    sims = 1.0 - dists if use_sim else dists.copy()
    if fill_value is not None:
        print("[fast_knns2spmat] edge fill value:", fill_value)
        sims.fill(fill_value)
    # Keep entries above the threshold that point at a real neighbour.
    # Padded slots carry id -1, which would become a negative column index
    # and make csr_matrix construction fail.
    row, col = np.where((sims >= th_sim) & (nbrs >= 0))
    # remove the self-loop
    idxs = np.where(row != nbrs[row, col])
    row = row[idxs]
    col = col[idxs]
    data = sims[row, col]
    col = nbrs[row, col]  # convert to absolute column
    assert len(row) == len(col) == len(data)
    spmat = csr_matrix((data, (row, col)), shape=(n, n))
    return spmat


def build_knns(feats, k, knn_method, dump=True):
    """Build the k-nearest-neighbour list of ``feats``.

    Args:
        feats: feature matrix, one row per sample.
        k: number of neighbours per sample.
        knn_method: ``"faiss"`` (CPU) or ``"faiss_gpu"``.
        dump: accepted for API compatibility; not used here.

    Raises:
        KeyError: for an unsupported ``knn_method``.

    Returns:
        The result of ``get_knns()`` on the constructed index.
    """
    from utils import Timer

    builders = {
        "faiss": lambda: knn_faiss(feats, k, omp_num_threads=None),
        "faiss_gpu": lambda: knn_faiss_gpu(feats, k),
    }
    with Timer("build index"):
        if knn_method not in builders:
            raise KeyError(
                "Only support faiss and faiss_gpu currently ({}).".format(
                    knn_method
                )
            )
        index = builders[knn_method]()
        knns = index.get_knns()
    return knns


class knn:
    """Base class for kNN indexes.

    Subclasses are expected to set ``self.knns``, a list with one
    ``(nbrs, dists)`` pair per sample, and ``self.verbose``.
    """

    def __init__(self, feats, k, index_path="", verbose=True):
        pass

    def filter_by_th(self, i):
        """Filter row ``i`` of ``self.knns`` by similarity.

        Keeps the neighbours whose similarity ``1 - dist`` is not below
        ``self.th``.

        Returns:
            ``(th_nbrs, th_dists)`` as ndarrays, in their original order.
        """
        nbrs, dists = self.knns[i]
        nbrs = np.asarray(nbrs)
        dists = np.asarray(dists)
        # Vectorised replacement for the former per-element loop. It is
        # written as the negation of "< th" so NaN distances are still kept.
        keep = ~((1 - dists) < self.th)
        return nbrs[keep], dists[keep]

    def get_knns(self, th=None):
        """Return the kNN list, optionally filtered by similarity ``th``.

        Args:
            th: when ``None`` or ``<= 0``, the stored list is returned as is.
                Otherwise each row is filtered with ``filter_by_th``.
        """
        if th is None or th <= 0.0:
            return self.knns
        # Imported lazily: only the filtering path needs it.
        from utils import Timer

        # nproc = mp.cpu_count()
        nproc = 1
        with Timer(
            "filter edges by th {} (CPU={})".format(th, nproc), self.verbose
        ):
            self.th = th
            self.th_knns = []
            tot = len(self.knns)
            if nproc > 1:
                # The context manager guarantees the worker pool is shut down.
                with mp.Pool(nproc) as pool:
                    th_knns = list(
                        tqdm(pool.imap(self.filter_by_th, range(tot)), total=tot)
                    )
            else:
                th_knns = [self.filter_by_th(i) for i in range(tot)]
            return th_knns


class knn_faiss(knn):
    """Exact inner-product kNN on CPU with ``faiss.IndexFlatIP``.

    ``self.knns`` holds one ``(nbrs int32, dists float32)`` pair per sample,
    with ``dist = 1 - inner_product``. ``nprobe``, ``rebuild_index`` and extra
    keyword arguments are accepted for API compatibility but not used.
    """

    def __init__(
        self,
        feats,
        k,
        nprobe=128,
        omp_num_threads=None,
        rebuild_index=True,
        verbose=True,
        **kwargs
    ):
        import faiss
        from utils import Timer

        if omp_num_threads is not None:
            faiss.omp_set_num_threads(omp_num_threads)
        self.verbose = verbose

        with Timer("[faiss] build index", verbose):
            feats = feats.astype("float32")
            _, dim = feats.shape
            flat_index = faiss.IndexFlatIP(dim)
            flat_index.add(feats)

        with Timer("[faiss] query topk {}".format(k), verbose):
            sims, nbrs = flat_index.search(feats, k=k)
            pairs = []
            for row_nbrs, row_sims in zip(nbrs, sims):
                pairs.append(
                    (
                        np.array(row_nbrs, dtype=np.int32),
                        1 - np.array(row_sims, dtype=np.float32),
                    )
                )
            self.knns = pairs


class knn_faiss_gpu(knn):
    """GPU kNN through ``faiss_search_knn``.

    The search is approximate and can optionally be refined with exact
    distances. After construction, ``self.knns`` holds one
    ``(nbrs int32, dists float32)`` pair per sample, and ``self.verbose``
    holds the verbosity flag.
    """

    def __init__(
        self,
        feats,
        k,
        nprobe=128,
        num_process=4,
        is_precise=True,
        sort=True,
        verbose=True,
        **kwargs
    ):
        from utils import Timer

        # knn.get_knns(th > 0) reads self.verbose. It was never set for the
        # GPU index, so threshold filtering raised AttributeError.
        self.verbose = verbose
        with Timer("[faiss_gpu] query topk {}".format(k), verbose):
            dists, nbrs = faiss_search_knn(
                feats,
                k=k,
                nprobe=nprobe,
                num_process=num_process,
                is_precise=is_precise,
                sort=sort,
                verbose=verbose,
            )

            self.knns = [
                (
                    np.array(nbr, dtype=np.int32),
                    np.array(dist, dtype=np.float32),
                )
                for nbr, dist in zip(nbrs, dists)
            ]

In [None]:
%%writefile ./p/utils/faiss_search.py
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
import gc

from tqdm import tqdm

from .faiss_gpu import faiss_search_approx_knn

__all__ = ["faiss_search_knn"]


def precise_dist(feat, nbrs, num_process=4, sort=True, verbose=False):
    """Recompute distances for the given neighbour ids.

    ``feat`` and ``nbrs`` are wrapped as torch tensors in shared memory and
    handed to ``precise_dist_share_mem``.

    Args:
        feat: feature ndarray, one row per sample.
        nbrs: neighbour-id ndarray, one row per sample.
        num_process, sort, verbose: forwarded to ``precise_dist_share_mem``.

    Returns:
        ``(dists, nbrs)`` as ndarrays. ``nbrs`` may have been reordered when
        ``sort`` is set.
    """
    import torch

    feat_t = torch.from_numpy(feat).share_memory_()
    nbrs_t = torch.from_numpy(nbrs).share_memory_()
    dist_t = torch.zeros_like(nbrs_t).float().share_memory_()

    precise_dist_share_mem(
        feat_t,
        nbrs_t,
        dist_t,
        num_process=num_process,
        sort=sort,
        verbose=verbose,
    )

    # Release the feature tensor before returning so its memory can be freed.
    del feat_t
    gc.collect()
    return dist_t.numpy(), nbrs_t.numpy()


def precise_dist_share_mem(
    feat,
    nbrs,
    dist,
    num_process=16,
    sort=True,
    process_unit=4000,
    verbose=False,
):
    """Fill ``dist`` (and optionally re-sort ``nbrs``) chunk by chunk.

    The rows are split into ``num_process`` contiguous chunks. Each chunk is
    handed to ``bmm`` in turn, inside this process; despite the parameter
    name, no worker processes are spawned.

    Args:
        feat: feature tensor with one row per sample (2-D, see ``feat.shape``).
        nbrs: neighbour-id tensor; rows are reordered in place when ``sort``.
        dist: float tensor written in place.
        num_process: number of row chunks.
        sort: sort each row by descending similarity.
        process_unit: rows per ``torch.bmm`` call inside ``bmm``.
        verbose: show a progress bar.
    """
    num, _ = feat.shape
    num_per_proc = int(num / num_process) + 1

    for pi in range(num_process):
        sid = pi * num_per_proc
        eid = min(sid + num_per_proc, num)
        # With few rows relative to num_process, the trailing chunks start at
        # or past the end. Passing them on made bmm allocate a tensor with a
        # negative dimension and crash. Later chunks only start further out.
        if sid >= eid:
            break
        bmm(
            feat,
            nbrs,
            dist,
            sid,
            eid,
            sort=sort,
            process_unit=process_unit,
            verbose=verbose,
        )


def bmm(
    feat, nbrs, dist, sid, eid, sort=True, process_unit=4000, verbose=False
):
    """Compute neighbour similarities for rows ``sid:eid`` with batched matmul.

    For row i and neighbour ``nbrs[i, j]``, the similarity is the inner
    product of their features clamped to [0, 1]. ``dist`` receives
    ``1 - similarity``. When ``sort`` is set, each row of ``nbrs`` and
    ``dist`` is reordered by descending similarity (ascending distance).
    """
    import torch

    _, cols = dist.shape
    num_rows = eid - sid
    batch_sim = torch.zeros((num_rows, cols), dtype=torch.float32)
    for s in tqdm(
        range(sid, eid, process_unit), desc="bmm", disable=not verbose
    ):
        e = min(eid, s + process_unit)
        query = feat[s:e].unsqueeze(1)  # (b, 1, d), assuming 2-D feat
        gallery = feat[nbrs[s:e]].permute(0, 2, 1)  # (b, d, cols)
        batch_sim[s - sid : e - sid] = torch.clamp(
            torch.bmm(query, gallery).view(-1, cols), 0.0, 1.0
        )

    if sort:
        sort_unit = int(1e6)
        batch_nbr = nbrs[sid:eid]
        # s/e index batch_sim, whose rows are relative to sid, so the upper
        # bound is the batch size. The old bound used the absolute row eid;
        # it only worked because slicing clamps out-of-range ends.
        for s in range(0, num_rows, sort_unit):
            e = min(s + sort_unit, num_rows)
            batch_sim[s:e], indices = torch.sort(
                batch_sim[s:e], descending=True
            )
            batch_nbr[s:e] = torch.gather(batch_nbr[s:e], 1, indices)
        nbrs[sid:eid] = batch_nbr
    dist[sid:eid] = 1.0 - batch_sim


def faiss_search_knn(
    feat,
    k,
    nprobe=128,
    num_process=4,
    is_precise=True,
    sort=True,
    verbose=False,
):
    """k-NN of ``feat`` against itself.

    Runs an approximate faiss search. When ``is_precise`` is set, the
    distances are then recomputed with ``precise_dist``.

    Returns:
        ``(dists, nbrs)``.
    """
    dists, nbrs = faiss_search_approx_knn(
        query=feat, target=feat, k=k, nprobe=nprobe, verbose=verbose
    )
    if not is_precise:
        return dists, nbrs

    print("compute precise dist among k={} nearest neighbors".format(k))
    return precise_dist(
        feat, nbrs, num_process=num_process, sort=sort, verbose=verbose
    )

In [None]:
%%writefile ./p/utils/faiss_gpu.py
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
import gc
import os

import faiss
import numpy as np
from tqdm import tqdm

__all__ = ["faiss_search_approx_knn"]


class faiss_index_wrapper:
    """Build, train and fill a faiss index over ``target``.

    Args:
        target: ``(size, dim)`` array of vectors. Row i is added with id i.
        nprobe: number of inverted lists probed per query.
        index_factory_str: faiss factory string. The default is
            ``"IVF{nlist},PQ32"`` with
            ``nlist = min(8192, 16 * round(sqrt(size)))``. It must contain
            an ``IVF<n>`` component, which is parsed to size the training
            sample.
        verbose: passed to the faiss index.
        mode: ``"proxy"`` (a ``faiss.IndexProxy`` over one sub-index per
            GPU) or ``"shard"`` (``index_cpu_to_all_gpus`` with
            ``co.shard = True``).
        using_gpu: in proxy mode, clone the CPU index onto each GPU when True.

    Raises:
        KeyError: for an unknown ``mode``.
    """

    def __init__(
        self,
        target,
        nprobe=128,
        index_factory_str=None,
        verbose=False,
        mode="proxy",
        using_gpu=True,
    ):
        # Keep GPU resources referenced for as long as the GPU indexes exist.
        self._res_list = []

        num_gpu = faiss.get_num_gpus()
        print("[faiss gpu] #GPU: {}".format(num_gpu))

        size, dim = target.shape
        assert size > 0, "size: {}".format(size)
        index_factory_str = (
            "IVF{},PQ{}".format(min(8192, 16 * round(np.sqrt(size))), 32)
            if index_factory_str is None
            else index_factory_str
        )
        cpu_index = faiss.index_factory(dim, index_factory_str)
        cpu_index.nprobe = nprobe

        if mode == "proxy":
            co = faiss.GpuClonerOptions()
            co.useFloat16 = True
            co.usePrecomputed = False

            index = faiss.IndexProxy()
            for i in range(num_gpu):
                res = faiss.StandardGpuResources()
                self._res_list.append(res)
                sub_index = (
                    faiss.index_cpu_to_gpu(res, i, cpu_index, co)
                    if using_gpu
                    else cpu_index
                )
                index.addIndex(sub_index)
        elif mode == "shard":
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.usePrecomputed = False
            co.shard = True
            index = faiss.index_cpu_to_all_gpus(cpu_index, co, ngpu=num_gpu)
        else:
            raise KeyError("Unknown index mode")

        # IndexIDMap lets vectors be added with explicit ids (their row numbers).
        index = faiss.IndexIDMap(index)
        index.verbose = verbose

        # get nlist to decide how many samples used for training
        nlist = int(
            float(
                [
                    item
                    for item in index_factory_str.split(",")
                    if "IVF" in item
                ][0].replace("IVF", "")
            )
        )

        # training: nlist * 256 rows drawn uniformly with replacement
        if not index.is_trained:
            indexes_sample_for_train = np.random.randint(0, size, nlist * 256)
            index.train(target[indexes_sample_for_train])

        # add with ids
        target_ids = np.arange(0, size)
        index.add_with_ids(target, target_ids)
        self.index = index

    def search(self, *args, **kargs):
        """Forward to the underlying faiss index's ``search``."""
        return self.index.search(*args, **kargs)

    def __del__(self):
        # __init__ may have raised before self.index was assigned. Don't raise
        # a second AttributeError from the finaliser in that case.
        index = getattr(self, "index", None)
        if index is not None:
            index.reset()
            del self.index
        # Drop the GPU resources only after the index that uses them is gone.
        # The former `for res in ...: del res` unbound only the loop variable,
        # so it released nothing.
        self._res_list = []


def batch_search(index, query, k, bs, verbose=False):
    """Search ``index`` for the ``k`` nearest neighbours of every query row.

    Queries are sent in batches of ``bs`` rows.

    Returns:
        ``(dists, nbrs)``: float32 and int64 arrays of shape ``(len(query), k)``.
    """
    num_queries = len(query)
    dists = np.zeros((num_queries, k), dtype=np.float32)
    nbrs = np.zeros((num_queries, k), dtype=np.int64)

    starts = range(0, num_queries, bs)
    for start in tqdm(starts, desc="faiss searching...", disable=not verbose):
        stop = min(num_queries, start + bs)
        batch_dists, batch_nbrs = index.search(query[start:stop], k)
        dists[start:stop] = batch_dists
        nbrs[start:stop] = batch_nbrs
    return dists, nbrs


def faiss_search_approx_knn(
    query,
    target,
    k,
    nprobe=128,
    bs=int(1e6),
    index_factory_str=None,
    verbose=False,
):
    """Approximate kNN of ``query`` rows against ``target``.

    A ``faiss_index_wrapper`` is built for this call only and released before
    returning.

    Returns:
        ``(dists, nbrs)`` as produced by ``batch_search``.
    """
    wrapper = faiss_index_wrapper(
        target,
        nprobe=nprobe,
        index_factory_str=index_factory_str,
        verbose=verbose,
    )
    result = batch_search(wrapper, query, k=k, bs=bs, verbose=verbose)

    # Drop the index now so its (GPU) memory is reclaimed before returning.
    del wrapper
    gc.collect()
    return result

In [None]:
%%writefile ./p/utils/density.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""

from itertools import groupby

import numpy as np
import torch
from tqdm import tqdm

__all__ = [
    "density_estimation",
    "density_to_peaks",
    "density_to_peaks_vectorize",
]


def density_estimation(dists, nbrs, labels, **kwargs):
    """Supervised density of each node, defined on its neighbourhood.

    For every node, sums the similarities ``1 - dist`` to the non-first
    neighbours that share its label, subtracts those to neighbours with a
    different label, and divides by ``k - 1``.
    """
    num, k_knn = dists.shape
    same_label = labels[nbrs] == labels[:, None]
    sims = 1 - dists[:, 1:]
    same = same_label[:, 1:]
    pos = (sims * same).sum(1)
    neg = (sims * (1 - same)).sum(1)
    conf = (pos - neg) * np.ones((num,), dtype=np.float32)
    conf /= k_knn - 1
    return conf


def density_to_peaks_vectorize(dists, nbrs, density, max_conn=1, name=""):
    # just calculate 1 connectivity
    assert dists.shape[0] == density.shape[0]
    assert dists.shape == nbrs.shape

    num, k = dists.shape

    if name == "gcn_feat":
        include_mask = nbrs != np.arange(0, num).reshape(-1, 1)
        secondary_mask = (
            np.sum(include_mask, axis=1) == k
        )  # TODO: the condition == k should not happen as distance to the node self should be smallest, check for numerical stability; TODO: make top M instead of only supporting top 1
        include_mask[secondary_mask, -1] = False
        nbrs_exclude_self = nbrs[include_mask].reshape(-1, k - 1)  # (V, 79)
        dists_exclude_self = dists[include_mask].reshape(-1, k - 1)  # (V, 79)
    else:
        include_mask = nbrs != np.arange(0, num).reshape(-1, 1)
        nbrs_exclude_self = nbrs[include_mask].reshape(-1, k - 1)  # (V, 79)
        dists_exclude_self = dists[include_mask].reshape(-1, k - 1)  # (V, 79)

    compare_map = density[nbrs_exclude_self] > density.reshape(-1, 1)
    peak_index = np.argmax(np.where(compare_map, 1, 0), axis=1)  # (V,)
    compare_map_sum = np.sum(compare_map.cpu().data.numpy(), axis=1)  # (V,)

    dist2peak = {
        i: []
        if compare_map_sum[i] == 0
        else [dists_exclude_self[i, peak_index[i]]]
        for i in range(num)
    }
    peaks = {
        i: []
        if compare_map_sum[i] == 0
        else [nbrs_exclude_self[i, peak_index[i]]]
        for i in range(num)
    }

    return dist2peak, peaks


def density_to_peaks(dists, nbrs, density, max_conn=1, sort="dist"):
    """Link each node to up to ``max_conn`` neighbours of higher density.

    Neighbours are scanned in row order, and the node itself is skipped.
    ``dists`` is expected to be sorted ascending, so nearer neighbours come
    first. ``sort`` is accepted but unused.

    Returns:
        ``(dist2peak, peaks)``: dicts mapping every node id to the lists of
        distances to, and ids of, its selected higher-density neighbours.
    """
    assert dists.shape[0] == density.shape[0]
    assert dists.shape == nbrs.shape

    num, _ = dists.shape
    dist2peak = {i: [] for i in range(num)}
    peaks = {i: [] for i in range(num)}

    for i, row in tqdm(enumerate(nbrs)):
        own_density = density[i]
        for j, (nbr_idx, nbr_density) in enumerate(zip(row, density[row])):
            if nbr_idx == i or nbr_density <= own_density:
                continue
            dist2peak[i].append(dists[i, j])
            peaks[i].append(nbr_idx)
            if len(dist2peak[i]) >= max_conn:
                break

    return dist2peak, peaks

In [None]:
%%writefile ./p/utils/deduce.py
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
import dgl
import numpy as np
import torch
from sklearn import mixture

from .density import density_to_peaks, density_to_peaks_vectorize

__all__ = [
    "peaks_to_labels",
    "edge_to_connected_graph",
    "decode",
    "build_next_level",
    "build_next_level_mod6",
    "build_next_level_mod7",
]


def _find_parent(parent, u):
    idx = []
    # parent is a fixed point
    while u != parent[u]:
        idx.append(u)
        u = parent[u]
    for i in idx:
        parent[i] = u
    return u


def edge_to_connected_graph(edges, num):
    """Label the connected components induced by ``edges`` over ``num`` nodes.

    Args:
        edges: iterable of ``(u, v)`` node pairs.
        num: number of nodes (ids ``0 .. num - 1``).

    Returns:
        ndarray of length ``num`` with component ids 0, 1, ..., numbered in
        ascending order of each component's root id.
    """
    parent = list(range(num))
    for u, v in edges:
        root_u = _find_parent(parent, u)
        parent[root_u] = _find_parent(parent, v)

    # Resolve every node straight to its root.
    parent = [_find_parent(parent, node) for node in range(num)]
    roots = np.unique(np.array(parent))
    root_to_cluster = {root: cid for cid, root in enumerate(roots)}
    return np.array([root_to_cluster[root] for root in parent])


def peaks_to_edges(peaks, dist2peak, tau):
    """Collect ``[src, dst]`` edges from every node to its peaks.

    An edge is dropped when ``dst == src`` or its distance is
    ``>= 1 - tau``. Edges keep the order of ``peaks``.
    """
    max_dist = 1 - tau
    return [
        [src, dst]
        for src in peaks
        for dst, dist in zip(peaks[src], dist2peak[src])
        if not (src == dst or dist >= max_dist)
    ]


def peaks_to_labels(peaks, dist2peak, tau, inst_num):
    """Cluster ``inst_num`` nodes by linking each node to its peaks.

    Returns:
        ``(pred_labels, edges)``: component id per node, and the kept edges.
    """
    edges = peaks_to_edges(peaks, dist2peak, tau)
    return edge_to_connected_graph(edges, inst_num), edges


def get_dists(g, nbrs, use_gt):
    """Re-rank each node's kNN list by the graph's edge distances.

    Args:
        g: DGL graph. Only ``edge_ids``, ``edata`` and ``device`` are used.
            For every row i and every j >= 1, it must contain the edge
            ``nbrs[i, j] -> nbrs[i, 0]``.
        nbrs: ``(num, k)`` integer ndarray. Column 0 must be the row's own
            node id (``nbrs[i, 0] == i``), because it is used below as the
            row offset for ``torch.take``. NOTE(review): confirm callers
            guarantee this.
        use_gt: when True, distances are ``1 - edata["labels_edge"]``;
            otherwise they come from column 0 of ``edata["prob_conn"]``.

    Returns:
        ``(new_nbrs, new_dists)`` ndarrays of shape ``(num, k)``. Column 0 is
        the row index with distance 0. Columns 1.. hold the neighbours sorted
        by ascending edge distance, and ``new_dists[i, j]`` is the distance
        of ``new_nbrs[i, j]``.
    """
    k = nbrs.shape[1]
    # One (neighbour -> centre) edge per non-self neighbour, row-major.
    src_id = nbrs[:, 1:].reshape(-1)
    dst_id = nbrs[:, 0].repeat(k - 1)
    eids = g.edge_ids(src_id, dst_id)
    if use_gt:
        new_dists = (
            (1 - g.edata["labels_edge"][eids]).reshape(-1, k - 1).float()
        )
    else:
        new_dists = g.edata["prob_conn"][eids, 0].reshape(-1, k - 1)
    # Sort the distances together with their indices. Previously only the
    # neighbour ids were reordered, while the distances came back in their
    # original order, so new_dists[i, j] did not belong to new_nbrs[i, j].
    new_dists, ind = torch.sort(new_dists, 1)
    # torch.take indexes its input as if flattened, so add each row's start
    # offset (row id * (k - 1)) to the per-row column indices.
    offset = torch.LongTensor(
        (nbrs[:, 0] * (k - 1)).repeat(k - 1).reshape(-1, k - 1)
    ).to(g.device)
    ind = ind + offset
    nbrs = torch.LongTensor(nbrs).to(g.device)
    new_nbrs = torch.take(nbrs[:, 1:], ind)
    new_dists = torch.cat(
        [torch.zeros((new_dists.shape[0], 1)).to(g.device), new_dists], dim=1
    )
    new_nbrs = torch.cat(
        [torch.arange(new_nbrs.shape[0]).view(-1, 1).to(g.device), new_nbrs],
        dim=1,
    )
    return new_nbrs.cpu().detach().numpy(), new_dists.cpu().detach().numpy()


def get_edge_dist(g, threshold):
    """Per-edge distance used for edge filtering.

    ``threshold == "prob"`` selects column 0 of ``g.edata["prob_conn"]``;
    any other value gives ``1 - g.edata["raw_affine"]``.
    """
    edata = g.edata
    if threshold != "prob":
        return 1 - edata["raw_affine"]
    return edata["prob_conn"][:, 0]


def tree_generation(ng):
    """Keep, for each node, only its incoming edge with the smallest ``edge_dist``.

    Messages are propagated in ``dgl.traversal.topological_nodes_generator``
    order. Each node that receives messages records the id
    (``edges.data[dgl.EID]``) of its minimum-``edge_dist`` in-edge in
    ``ng.ndata["keep_eid"]``; this is a side effect on ``ng``.

    Args:
        ng: DGL graph with ``edata["edge_dist"]`` and ``edata[dgl.EID]``.
            Presumably a DAG, as topological traversal requires; TODO confirm.

    Returns:
        A new DGL graph over the same nodes containing only the kept edges,
        so every node has in-degree at most 1.
    """
    # -1 marks nodes that never receive a message (no incoming edge kept).
    ng.ndata["keep_eid"] = torch.zeros(ng.num_nodes()).long() - 1

    def message_func(edges):
        return {"mval": edges.data["edge_dist"], "meid": edges.data[dgl.EID]}

    def reduce_func(nodes):
        # Position of the smallest incoming distance in each node's mailbox,
        # used to pick the matching edge id.
        ind = torch.min(nodes.mailbox["mval"], dim=1)[1]
        keep_eid = nodes.mailbox["meid"].gather(1, ind.view(-1, 1))
        return {"keep_eid": keep_eid[:, 0]}

    node_order = dgl.traversal.topological_nodes_generator(ng)
    ng.prop_nodes(node_order, message_func, reduce_func)
    eids = ng.ndata["keep_eid"]
    eids = eids[eids > -1]
    edges = ng.find_edges(eids)
    treeg = dgl.graph(edges, num_nodes=ng.num_nodes())
    return treeg


def peak_propogation(treeg):
    """Propagate cluster labels from root nodes down a forest.

    Each root (a node with in-degree 0) gets a distinct label
    ``0 .. n_roots - 1``. Every other node takes the first label in its
    mailbox, following topological order. The labels are also stored in
    ``treeg.ndata["pred_labels"]``.

    Returns:
        ``(peaks, pred_labels)``: ids of the root nodes and a label per node,
        both as NumPy arrays.
    """
    treeg.ndata["pred_labels"] = torch.zeros(treeg.num_nodes()).long() - 1
    roots = torch.where(treeg.in_degrees() == 0)[0].cpu().numpy()
    treeg.ndata["pred_labels"][roots] = torch.arange(roots.shape[0])

    def send_label(edges):
        return {"mlb": edges.src["pred_labels"]}

    def take_first_label(nodes):
        return {"pred_labels": nodes.mailbox["mlb"][:, 0]}

    treeg.prop_nodes(
        dgl.traversal.topological_nodes_generator(treeg),
        send_label,
        take_first_label,
    )
    return roots, treeg.ndata["pred_labels"].cpu().numpy()


def decode(
    g,
    tau,
    threshold,
    use_gt,
    ids=None,
    global_edges=None,
    global_num_nodes=None,
    global_peaks=None,
):
    """Cluster the nodes of ``g`` by linking each node towards denser neighbours.

    Steps:
    1. Keep an edge only if its source has a higher density than its
       destination and its distance is below ``1 - tau``.
    2. Reduce the kept edges to a forest (``tree_generation``).
    3. Label every tree by its root (``peak_propogation``).

    Args:
        g: DGL graph with a node density (``ndata["density"]`` when
            ``use_gt``, otherwise ``ndata["pred_den"]``) and the edge data
            read by ``get_edge_dist``.
        tau: edges with distance ``>= 1 - tau`` are dropped.
        threshold: ``"prob"`` to use ``prob_conn``; otherwise ``raw_affine``.
        use_gt: use the ground-truth density instead of the predicted one.
        ids: optional array mapping local node ids to global ids. When given,
            the kept tree edges are appended to ``global_edges`` and labels
            are recomputed on the global graph.
        global_edges: ``(src_list, dst_list)`` of global edges accumulated so
            far; both are Python lists.
        global_num_nodes: number of nodes in the global graph.
        global_peaks: accepted but not read; it is recomputed and returned.

    Returns:
        ``(pred_labels, peaks)`` when ``ids`` is None, otherwise
        ``(pred_labels, peaks, new_global_edges, global_pred_labels,
        global_peaks)``.
    """
    # Edge filtering with tau and density
    den_key = "density" if use_gt else "pred_den"
    # Local scope: features assigned below are not written back to the caller's graph.
    g = g.local_var()
    g.edata["edge_dist"] = get_edge_dist(g, threshold)
    g.apply_edges(
        lambda edges: {
            "keep": (edges.src[den_key] > edges.dst[den_key]).long()
            * (edges.data["edge_dist"] < 1 - tau).long()
        }
    )
    # Edges failing either condition are removed.
    eids = torch.where(g.edata["keep"] == 0)[0]
    ng = dgl.remove_edges(g, eids)

    # Tree generation
    # Number the remaining edges so tree_generation can report which one it kept.
    ng.edata[dgl.EID] = torch.arange(ng.num_edges())
    treeg = tree_generation(ng)
    # Label propagation
    peaks, pred_labels = peak_propogation(treeg)

    if ids is None:
        return pred_labels, peaks

    # Merge with previous layers
    src, dst = treeg.edges()
    new_global_edges = (
        global_edges[0] + ids[src.numpy()].tolist(),
        global_edges[1] + ids[dst.numpy()].tolist(),
    )
    global_treeg = dgl.graph(new_global_edges, num_nodes=global_num_nodes)
    global_peaks, global_pred_labels = peak_propogation(global_treeg)
    return (
        pred_labels,
        peaks,
        new_global_edges,
        global_pred_labels,
        global_peaks,
    )


def build_next_level(
    features, labels, peaks, global_features, global_pred_labels, global_peaks
):
    """Build the inputs of the next hierarchy level, one node per peak.

    Args:
        features, labels: per-node arrays of the current level.
        peaks: current-level ids of the nodes kept for the next level.
        global_features: ``(N, d)`` features of all original nodes.
        global_pred_labels: cluster label of every original node.
        global_peaks: original-node id of each cluster's peak.

    Returns:
        ``(features[peaks], labels[peaks], cluster_features)``. The r-th
        sorted distinct label's cluster mean goes to the row whose
        ``global_peaks`` entry carries that label. This assumes labels are
        ``0 .. len(peaks) - 1``.
    """
    # Invert "position in global_peaks -> label" into "label -> position".
    peak_labels = global_pred_labels[global_peaks]
    label_to_peak_pos = np.zeros_like(peak_labels)
    for pos, label in enumerate(peak_labels):
        label_to_peak_pos[label] = pos

    # Member indices of every cluster, in ascending label order.
    order = np.argsort(global_pred_labels)
    group_starts = np.unique(np.sort(global_pred_labels), return_index=True)[1]
    members_by_label = np.split(order, group_starts[1:])

    cluster_features = np.zeros((len(peaks), global_features.shape[1]))
    for label in range(len(peaks)):
        row = label_to_peak_pos[label]
        cluster_features[row, :] = np.mean(
            global_features[members_by_label[label], :], axis=0
        )

    return features[peaks], labels[peaks], cluster_features

################################ MODIFICA 6 (variant 6): each cluster is represented by its peak node's features
def build_next_level_mod6(
    features, labels, peaks, global_features, global_pred_labels, global_peaks
):
    """Variant of ``build_next_level`` that uses peak features, not means.

    Each cluster is represented by the global features of its peak node
    instead of the mean of its members' features.

    Returns:
        ``(features[peaks], labels[peaks], cluster_features)``.
    """
    peak_labels = global_pred_labels[global_peaks]
    label_to_peak_pos = np.zeros_like(peak_labels)
    for pos, label in enumerate(peak_labels):
        label_to_peak_pos[label] = pos

    cluster_features = np.zeros((len(peaks), global_features.shape[1]))
    for label in range(len(peaks)):
        row = label_to_peak_pos[label]
        cluster_features[row, :] = global_features[global_peaks[row], :]

    return features[peaks], labels[peaks], cluster_features

############################### MODIFICA 7 (variant 7): peak features for small clusters, mean features for large ones
def build_next_level_mod7(
    features, labels, peaks, global_features, global_pred_labels, global_peaks, size_threshold=10000
):
    """Hybrid of the mean-based and peak-based next-level builders.

    Clusters with at most ``size_threshold`` members are represented by their
    peak node's global features. Larger clusters use the mean of their
    members' global features.

    Returns:
        ``(features[peaks], labels[peaks], cluster_features)``.
    """
    peak_labels = global_pred_labels[global_peaks]
    label_to_peak_pos = np.zeros_like(peak_labels)
    for pos, label in enumerate(peak_labels):
        label_to_peak_pos[label] = pos

    # Member indices of every cluster, in ascending label order.
    order = np.argsort(global_pred_labels)
    group_starts = np.unique(np.sort(global_pred_labels), return_index=True)[1]
    members_by_label = np.split(order, group_starts[1:])

    cluster_features = np.zeros((len(peaks), global_features.shape[1]))
    for label in range(len(peaks)):
        members = members_by_label[label]
        row = label_to_peak_pos[label]
        if len(members) <= size_threshold:
            # Small cluster: use only the peak node's features.
            cluster_features[row, :] = global_features[global_peaks[row], :]
        else:
            # Large cluster: average the members' features.
            cluster_features[row, :] = np.mean(
                global_features[members, :], axis=0
            )

    return features[peaks], labels[peaks], cluster_features

In [None]:
%%writefile ./p/utils/adjacency.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""

import numpy as np
import scipy.sparse as sp
from scipy.sparse import coo_matrix


def row_normalize(mx):
    """Scale every row of sparse matrix ``mx`` so that it sums to 1.

    Rows whose sum is <= 0 are divided by 1, i.e. left unchanged.

    Returns:
        ``(normalized_mx, r_inv)``, where ``r_inv`` holds the per-row factors.
    """
    row_sums = np.array(mx.sum(1))
    # Non-positive sums are replaced by 1 so those rows keep their values.
    row_sums[row_sums <= 0] = 1
    inv = np.power(row_sums, -1).flatten()
    inv[np.isinf(inv)] = 0.0
    normalized = sp.diags(inv).dot(mx)
    return normalized, inv


def sparse_mx_to_indices_values(sparse_mx):
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
    values = sparse_mx.data
    shape = np.array(sparse_mx.shape)
    return indices, values, shape

In [None]:
%%writefile ./p/utils/evaluate.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import inspect

import sys
sys.path.append('/kaggle/working/p')  # root del progetto
sys.path.append('/kaggle/working/p/clustering-benchmark')

import numpy as np
from clustering_benchmark import ClusteringBenchmark
#from utils import metrics, TextColors, Timer

def _read_meta(fn):
    labels = list()
    lb_set = set()
    with open(fn) as f:
        for lb in f.readlines():
            lb = int(lb.strip())
            labels.append(lb)
            lb_set.add(lb)
    return np.array(labels), lb_set


def evaluate(gt_labels, pred_labels, metric="pairwise"):
    """Score predicted labels against the ground truth with one metric.

    Args:
        gt_labels, pred_labels: label sequences, or (when BOTH are strings)
            paths to meta files with one integer label per line.
        metric: name of a function in ``utils.metrics``.

    Returns:
        A CSV fragment. Scalar metrics give a single ``metric,value`` line.
        Metrics returning a (precision, recall, fscore) triple give three
        ``<metric>_ave_pre`` / ``_ave_rec`` / ``_fscore`` lines joined by "\n".
    """
    # Imported here rather than at module level: utils/__init__.py
    # star-imports this module, so a top-level import would be circular.
    from utils import metrics, TextColors, Timer

    both_are_paths = isinstance(gt_labels, str) and isinstance(pred_labels, str)
    if both_are_paths:
        print("[gt_labels] {}".format(gt_labels))
        print("[pred_labels] {}".format(pred_labels))
        gt_labels, gt_lb_set = _read_meta(gt_labels)
        pred_labels, pred_lb_set = _read_meta(pred_labels)
        print("#inst: gt({}) vs pred({})".format(len(gt_labels), len(pred_labels)))
        print("#cls: gt({}) vs pred({})".format(len(gt_lb_set), len(pred_lb_set)))

    metric_func = metrics.__dict__[metric]
    timer_label = "evaluate with {}{}{}".format(TextColors.FATAL, metric, TextColors.ENDC)
    with Timer(timer_label):
        result = metric_func(gt_labels, pred_labels)

    if not isinstance(result, float):
        ave_pre, ave_rec, fscore = result
        print(
            "{}ave_pre: {:.4f}, ave_rec: {:.4f}, fscore: {:.4f}{}".format(
                TextColors.OKGREEN, ave_pre, ave_rec, fscore, TextColors.ENDC
            )
        )
        return "\n".join(
            [
                f"{metric}_ave_pre,{ave_pre:.4f}",
                f"{metric}_ave_rec,{ave_rec:.4f}",
                f"{metric}_fscore,{fscore:.4f}",
            ]
        )

    print("{}{}: {:.4f}{}".format(TextColors.OKGREEN, metric, result, TextColors.ENDC))
    return f"{metric},{result:.4f}"


def evaluation(pred_labels, labels, metrics,output_csv_path="evaluation_metrics.csv"):
    """Run the requested metrics plus the clustering-benchmark scores.

    The results are written to a two-column ``Metric,Value`` CSV.

    Args:
        pred_labels: predicted cluster label per instance.
        labels: ground-truth label per instance, in the same order.
        metrics: comma-separated metric names understood by ``evaluate``
            (e.g. "pairwise,bcubed,nmi").
        output_csv_path: destination CSV; it is overwritten.

    Side effects: prints every metric, the V-measure score dict and the
    Fowlkes-Mallows result, and writes ``output_csv_path``.
    """
    print("==> evaluation")
    gt_labels_all = labels
    pred_labels_all = pred_labels

    csv_lines = ["Metric,Value"]
    for metric in metrics.split(","):
        # evaluate() returns one or more "name,value" lines joined by "\n".
        csv_lines.extend(evaluate(gt_labels_all, pred_labels_all, metric).split("\n"))

    # H-, C- and V-measure: ClusteringBenchmark takes {instance_id: label}
    # dicts keyed by the stringified instance index.
    gt_dict = {str(i): gt_labels_all[i] for i in range(len(gt_labels_all))}
    pred_dict = {str(i): pred_labels_all[i] for i in range(len(gt_labels_all))}
    bm = ClusteringBenchmark(gt_dict)
    scores = bm.evaluate_vmeasure(pred_dict)
    fmi_scores = bm.evaluate_fowlkes_mallows_score(pred_dict)

    # Example of `scores` ('v-meansure' is the key spelling the library uses):
    # {'#gt clusters': 2452, '#pred clusters': 39732, 'h-score': 0.8615836636450002,
    #  'c-score': 0.6993937747345866, 'v-meansure': 0.7720627293522723}
    csv_lines.append(f"scores_h_score,{scores['h-score']:.4f}")
    csv_lines.append(f"scores_c_score,{scores['c-score']:.4f}")
    csv_lines.append(f"scores_v_score,{scores['v-meansure']:.4f}")
    csv_lines.append(f"gt clusters,{scores['#gt clusters']:.4f}")
    csv_lines.append(f"pred clusters,{scores['#pred clusters']:.4f}")

    with open(output_csv_path, 'w', newline='') as csvfile:
        for line in csv_lines:
            csvfile.write(line + '\n')

    print(scores)
    # The Fowlkes-Mallows result was previously computed and thrown away;
    # report it (printed as-is since its structure is library-defined).
    print(fmi_scores)

In [None]:
%%writefile ./p/utils/__init__.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from .adjacency import *
from .deduce import *
from .density import *
from .evaluate import *
from .faiss_gpu import faiss_search_approx_knn
from .faiss_search import faiss_search_knn
from .knn import *
from .metrics import *
from .misc import *

## CARTELLA MODELS

In [None]:
%%writefile ./p/models/lander.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .focal_loss import FocalLoss
from .graphconv import GraphConv


class LANDER(nn.Module):
    """Graph model that predicts edge connectivity and node density.

    A stack of GraphConv layers produces per-node ``conv_features``. Each
    edge then gets 2-class logits ``pred_conn`` from the sum of the projected
    source and destination features. Each destination node gets ``pred_den``,
    the mean over its incoming edges of
    ``raw_affine * (P(class 1) - P(class 0))``.

    Args:
        feature_dim: width of the ``features`` node field.
        nhid: hidden width. Conv layers 3 and 4 (when present) output
            ``nhid // 2``.
        num_conv: number of GraphConv layers (1-4; bounded by the
            ``input_dim`` / ``output_dim`` tables).
        dropout: dropout rate passed to every GraphConv.
        use_GAT: use GAT attention layers inside GraphConv.
        K: number of GAT heads.
        balance: in ``compute_loss``, randomly drop masked edges of the
            majority class (labels_conn 0 vs 1) so both classes count equally.
        use_cluster_feat: concatenate ``cluster_features`` to ``features``
            as the first layer's input.
        use_focal_loss: FocalLoss instead of CrossEntropyLoss for edges.
    """

    def __init__(
        self,
        feature_dim,
        nhid,
        num_conv=4,
        dropout=0,
        use_GAT=True,
        K=1,
        balance=False,
        use_cluster_feat=True,
        use_focal_loss=True,
        **kwargs
    ):
        super(LANDER, self).__init__()
        nhid_half = nhid // 2
        self.use_cluster_feat = use_cluster_feat
        self.use_focal_loss = use_focal_loss

        # forward() concatenates cluster features to node features, which
        # doubles the first layer's input width.
        if self.use_cluster_feat:
            self.feature_dim = feature_dim * 2
        else:
            self.feature_dim = feature_dim

        # Per-layer widths. input_dim[0] is unused: layer 0 takes self.feature_dim.
        input_dim = (feature_dim, nhid, nhid, nhid_half)
        output_dim = (nhid, nhid, nhid_half, nhid_half)
        self.conv = nn.ModuleList()
        self.conv.append(GraphConv(self.feature_dim, nhid, dropout, use_GAT, K))
        for i in range(1, num_conv):
            self.conv.append(
                GraphConv(input_dim[i], output_dim[i], dropout, use_GAT, K)
            )

        self.src_mlp = nn.Linear(output_dim[num_conv - 1], nhid_half)
        self.dst_mlp = nn.Linear(output_dim[num_conv - 1], nhid_half)

        # Edge-connectivity classifier (original architecture).
        self.classifier_conn = nn.Sequential(
            nn.PReLU(nhid_half),
            nn.Linear(nhid_half, nhid_half),
            nn.PReLU(nhid_half),
            nn.Linear(nhid_half, 2),
        )
        # Disabled experimental heads. To try one, build it as
        # self.classifier_conn in place of the module above.
        #
        # Modification 4: deeper head with BatchNorm.
        # self.classifier_conn_4 = nn.Sequential(
        #     nn.PReLU(nhid_half),
        #     nn.Linear(nhid_half, nhid_half),
        #     nn.BatchNorm1d(nhid_half),
        #     nn.PReLU(nhid_half),
        #     nn.Linear(nhid_half, nhid_half),
        #     nn.BatchNorm1d(nhid_half),
        #     nn.PReLU(nhid_half),
        #     nn.Linear(nhid_half, 2),
        # )
        #
        # Modification 5: consumes concatenated src/dst features
        # (2 * nhid_half); pair it with the torch.cat variant in pred_conn.
        # self.classifier_conn_5 = nn.Sequential(
        #     nn.PReLU(2 * nhid_half),
        #     nn.Linear(2 * nhid_half, nhid_half),
        #     nn.BatchNorm1d(nhid_half),
        #     nn.PReLU(nhid_half),
        #     nn.Linear(nhid_half, 2),
        # )

        if self.use_focal_loss:
            self.loss_conn = FocalLoss(2)
        else:
            self.loss_conn = nn.CrossEntropyLoss()
        self.loss_den = nn.MSELoss()

        self.balance = balance

    def pred_conn(self, edges):
        """Edge UDF: 2-class connectivity logits from projected src + dst features."""
        src_feat = self.src_mlp(edges.src["conv_features"])
        dst_feat = self.dst_mlp(edges.dst["conv_features"])
        pred_conn = self.classifier_conn(src_feat + dst_feat)

        # Modification 5 (use together with the concatenating classifier head):
        # combined_feat = torch.cat([src_feat, dst_feat], dim=-1)
        # pred_conn = self.classifier_conn(combined_feat)

        return {"pred_conn": pred_conn}

    def pred_den_msg(self, edges):
        """Edge message: raw affinity scaled by P(class 1) - P(class 0)."""
        prob = edges.data["prob_conn"]
        res = edges.data["raw_affine"] * (prob[:, 1] - prob[:, 0])
        return {"pred_den_msg": res}

    def forward(self, bipartites):
        """Run the conv stack and write conn/density predictions onto a graph.

        Args:
            bipartites: either a single DGLGraph (full-graph mode, reused for
                every layer) or a list of ``len(self.conv) + 1`` blocks. In
                block mode the "neighbor" chain runs over blocks
                ``[0, n)`` and the "center" chain over ``[1, n]``. The last
                block receives them as src / dst ``conv_features``.

        Returns:
            The last graph/block, with edata ``pred_conn`` and ``prob_conn``
            and dst node data ``pred_den`` filled in.
        """
        if isinstance(bipartites, dgl.DGLGraph):
            bipartites = [bipartites] * len(self.conv)
            if self.use_cluster_feat:
                neighbor_x = torch.cat(
                    [
                        bipartites[0].ndata["features"],
                        bipartites[0].ndata["cluster_features"],
                    ],
                    axis=1,
                )
            else:
                neighbor_x = bipartites[0].ndata["features"]

            for i in range(len(self.conv)):
                neighbor_x = self.conv[i](bipartites[i], neighbor_x)

            output_bipartite = bipartites[-1]
            output_bipartite.ndata["conv_features"] = neighbor_x
        else:
            if self.use_cluster_feat:
                neighbor_x_src = torch.cat(
                    [
                        bipartites[0].srcdata["features"],
                        bipartites[0].srcdata["cluster_features"],
                    ],
                    axis=1,
                )
                center_x_src = torch.cat(
                    [
                        bipartites[1].srcdata["features"],
                        bipartites[1].srcdata["cluster_features"],
                    ],
                    axis=1,
                )
            else:
                neighbor_x_src = bipartites[0].srcdata["features"]
                center_x_src = bipartites[1].srcdata["features"]

            for i in range(len(self.conv)):
                # Destination nodes are the first num_dst_nodes() source nodes.
                neighbor_x_dst = neighbor_x_src[: bipartites[i].num_dst_nodes()]
                neighbor_x_src = self.conv[i](
                    bipartites[i], (neighbor_x_src, neighbor_x_dst)
                )
                center_x_dst = center_x_src[: bipartites[i + 1].num_dst_nodes()]
                center_x_src = self.conv[i](
                    bipartites[i + 1], (center_x_src, center_x_dst)
                )

            output_bipartite = bipartites[-1]
            output_bipartite.srcdata["conv_features"] = neighbor_x_src
            output_bipartite.dstdata["conv_features"] = center_x_src

        output_bipartite.apply_edges(self.pred_conn)
        output_bipartite.edata["prob_conn"] = F.softmax(
            output_bipartite.edata["pred_conn"], dim=1
        )
        output_bipartite.update_all(
            self.pred_den_msg, fn.mean("pred_den_msg", "pred_den")
        )
        return output_bipartite

    def compute_loss(self, bipartite):
        """Density MSE plus connectivity loss over ``mask_conn`` edges.

        Returns:
            (loss, loss_den_val, loss_conn_val). The last two are Python
            numbers; loss_conn_val is 0 when no edge survives the mask.
        """
        pred_den = bipartite.dstdata["pred_den"]
        loss_den = self.loss_den(pred_den, bipartite.dstdata["density"])

        labels_conn = bipartite.edata["labels_conn"]
        # Work on a copy: balancing below zeroes entries, and writing into the
        # graph's stored mask would make those drops permanent when the same
        # graph is reused (full-graph mode in forward()).
        mask_conn = bipartite.edata["mask_conn"].clone()

        if self.balance:
            neg_check = torch.logical_and(labels_conn == 0, mask_conn)
            num_neg = torch.sum(neg_check).item()
            neg_indices = torch.where(neg_check)[0]
            pos_check = torch.logical_and(labels_conn == 1, mask_conn)
            num_pos = torch.sum(pos_check).item()
            pos_indices = torch.where(pos_check)[0]
            # Randomly un-mask the surplus of the majority class.
            if num_pos > num_neg:
                drop = np.random.choice(num_pos, num_pos - num_neg, replace=False)
                mask_conn[pos_indices[drop]] = 0
            elif num_pos < num_neg:
                drop = np.random.choice(num_neg, num_neg - num_pos, replace=False)
                mask_conn[neg_indices[drop]] = 0

        loss_den_val = loss_den.item()
        # In subgraph training, it may happen that all edges are masked in a batch
        if mask_conn.sum() > 0:
            loss_conn = self.loss_conn(
                bipartite.edata["pred_conn"][mask_conn], labels_conn[mask_conn]
            )
            loss = loss_den + loss_conn
            loss_conn_val = loss_conn.item()
        else:
            loss = loss_den
            loss_conn_val = 0

        return loss, loss_den_val, loss_conn_val


In [None]:
%%writefile ./p/models/graphconv.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
from torch.nn import init


class GraphConvLayer(nn.Module):
    """Affinity-weighted neighbour sum followed by a linear layer.

    Each destination node sums its source neighbours' features weighted by
    the edge field ``affine``. It concatenates its own features with that sum
    and maps the result through ``Linear(2 * in_feats, out_feats)``.
    """

    def __init__(self, in_feats, out_feats, bias=True):
        super(GraphConvLayer, self).__init__()
        self.mlp = nn.Linear(2 * in_feats, out_feats, bias=bias)

    def forward(self, bipartite, feat):
        # feat is a (src, dst) pair, or one tensor whose first
        # num_dst_nodes() rows are the destination nodes.
        if isinstance(feat, tuple):
            src_h, dst_h = feat
        else:
            src_h, dst_h = feat, feat[: bipartite.num_dst_nodes()]

        # local_var(): feature writes below stay local to this call.
        g = bipartite.local_var()
        g.srcdata["h"] = src_h

        # Original aggregation: h_dst = sum over in-edges of affine * h_src.
        g.update_all(fn.u_mul_e("h", "affine", "m"), fn.sum(msg="m", out="h"))

        # Modification 1 (testable only on DeepGlint, which does not use the
        # GAT layer): weight messages by raw_affine instead of affine.
        # g.update_all(fn.u_mul_e("h", "raw_affine", "m"), fn.sum(msg="m", out="h"))

        # Modification 2 (also without GAT): additionally scale each message
        # by the source node's density.
        # g.srcdata["h_density"] = g.srcdata["h"] * g.srcdata["density"].unsqueeze(1)
        # g.update_all(fn.u_mul_e("h_density", "affine", "m"), fn.sum(msg="m", out="h"))

        return self.mlp(torch.cat([dst_h, g.dstdata["h"]], dim=-1))


class GraphConv(nn.Module):
    """One message-passing step followed by ReLU and optional dropout.

    With ``use_GAT`` the inner layer is a K-head GATConv. Its per-head outputs
    get a learned per-head bias and are then averaged. Otherwise the inner
    layer is a GraphConvLayer.
    """

    def __init__(self, in_dim, out_dim, dropout=0, use_GAT=False, K=1):
        super(GraphConv, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        if not use_GAT:
            self.gcn_layer = GraphConvLayer(in_dim, out_dim, bias=True)
        else:
            self.gcn_layer = GATConv(in_dim, out_dim, K, allow_zero_in_degree=True)
            # Per-head bias, zero-initialised.
            self.bias = nn.Parameter(torch.Tensor(K, out_dim))
            init.constant_(self.bias, 0)

        self.dropout = dropout
        self.use_GAT = use_GAT

    def forward(self, bipartite, features):
        h = self.gcn_layer(bipartite, features)
        if self.use_GAT:
            # Heads are on dim 1: add the (K, out_dim) bias, then average them.
            h = (h + self.bias).mean(dim=1)
        h = F.relu(h.reshape(h.shape[0], -1))
        return F.dropout(h, self.dropout, training=self.training) if self.dropout > 0 else h

###################### MODIFICATION 3
# Learned softmax weighting of the GAT heads instead of a plain mean.
# To use it, remove the trailing "1" from the class name and put this class
# in place of the GraphConv definition above.
class GraphConv1(nn.Module):
    """GraphConv variant whose GAT heads are combined with learned weights.

    The (K, out_dim) head weights are softmax-normalised across heads,
    independently for each output feature.
    """

    def __init__(self, in_dim, out_dim, dropout=0, use_GAT=False, K=1):
        # Zero-argument super() works both under the name GraphConv1 and after
        # the rename described above. The explicit super(GraphConv, self)
        # raised TypeError while the class was still named GraphConv1.
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        if use_GAT:
            self.gcn_layer = GATConv(
                in_dim, out_dim, K, allow_zero_in_degree=True
            )
            self.bias = nn.Parameter(torch.Tensor(K, out_dim))
            # Head-weight logits. Constant init gives uniform weights after softmax.
            self.head_weights = nn.Parameter(torch.Tensor(K, out_dim))
            init.constant_(self.head_weights, 1)
            init.constant_(self.bias, 0)
        else:
            self.gcn_layer = GraphConvLayer(in_dim, out_dim, bias=True)

        self.dropout = dropout
        self.use_GAT = use_GAT

    def forward(self, bipartite, features):
        out = self.gcn_layer(bipartite, features)

        if self.use_GAT:
            # out: (N, K, out_dim). Bias and weights are (K, out_dim) and
            # broadcast over N. The weighted sum over heads replaces the
            # original per-head Python loop.
            weights = F.softmax(self.head_weights, dim=0)
            out = ((out + self.bias) * weights).sum(dim=1)

        out = out.reshape(out.shape[0], -1)
        out = F.relu(out)
        if self.dropout > 0:
            out = F.dropout(out, self.dropout, training=self.training)

        return out

In [None]:
%%writefile ./p/models/focal_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# The code below is based on
# https://zhuanlan.zhihu.com/p/28527749


class FocalLoss(nn.Module):
    r"""
    Focal Loss, as proposed in "Focal Loss for Dense Object Detection".

        Loss(x, class) = - \alpha_{class} (1-softmax(x)[class])^gamma \log(softmax(x)[class])

    Args:
        class_num(int): number of classes C.
        alpha(Tensor or sequence, optional): per-class weight of shape (C,)
            or (C, 1). Defaults to all ones.
        gamma(float, double): gamma >= 0; reduces the relative loss for
            well-classified examples (p > .5), putting more focus on hard,
            misclassified examples.
        size_average(bool): by default the losses are averaged over the
            minibatch; if False they are summed instead.

    Shape:
        inputs: (N, C) unnormalised logits. targets: (N,) integer class ids.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            alpha = torch.ones(class_num, 1)
        # Keep alpha as a column so alpha[targets] is (N, 1). A 1-D alpha would
        # otherwise broadcast against the (N, 1) terms into an (N, N) loss.
        self.alpha = torch.as_tensor(alpha).reshape(-1, 1)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        ids = targets.view(-1, 1)
        # log_softmax + gather: log-probability of the target class per row.
        # This stays finite even when the softmax probability underflows to 0,
        # where softmax(...).log() would give -inf.
        log_p = F.log_softmax(inputs, dim=1).gather(1, ids)
        probs = log_p.exp()

        # Follow the inputs to whatever device they live on.
        if self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)
        alpha = self.alpha[ids.view(-1)]

        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()

In [None]:
%%writefile ./p/models/__init__.py
from .graphconv import GraphConv
from .lander import LANDER

## CARTELLA DATASET

In [None]:
%%writefile ./p/datasetml/datasetml.py
import pickle

import numpy as np
import torch
from utils import (
    build_knns,
    build_next_level,
    decode,
    density_estimation,
    fast_knns2spmat,
    knns2ordered_nbrs,
    l2norm,
    row_normalize,
    sparse_mx_to_indices_values,
)

import dgl


class LanderDataset(object):
    """Builds the multi-level k-NN graph hierarchy used by LANDER.

    Level 0 is a k-NN graph over the L2-normalised input features. Each
    further level decodes the current graph (``decode``), keeps its peak
    nodes, and derives the next level's features with ``build_next_level``.
    Construction stops early once a level has ``k`` or fewer nodes.

    Attributes:
        gs: list of DGL graphs, one per level built.
        nbrs, dists: per-level ordered k-NN neighbour ids and distances.
        levels: number of levels actually built (lowered on early stop).
    """

    def __init__(
        self,
        features,
        labels,
        cluster_features=None,
        k=10,
        levels=1,
        faiss_gpu=False,
    ):
        self.k = k
        self.gs = []
        self.nbrs = []
        self.dists = []
        self.levels = levels

        # l2norm divides in place; astype() already produced a float32 copy,
        # so the caller's array is untouched.
        features = l2norm(features.astype("float32"))
        global_features = features.copy()
        if cluster_features is None:
            cluster_features = features
        global_num_nodes = features.shape[0]
        global_edges = ([], [])
        global_peaks = np.array([], dtype=np.longlong)
        ids = np.arange(global_num_nodes)
        knn_method = "faiss_gpu" if faiss_gpu else "faiss"

        for lvl in range(self.levels):
            # Not enough nodes left for a k-NN graph: record how many levels
            # were built and stop.
            if features.shape[0] <= self.k:
                self.levels = lvl
                break

            knns = build_knns(features, self.k, knn_method)
            dists, nbrs = knns2ordered_nbrs(knns)
            self.nbrs.append(nbrs)
            self.dists.append(dists)
            density = density_estimation(dists, nbrs, labels)

            level_graph = self._build_graph(
                features, cluster_features, labels, density, knns
            )
            self.gs.append(level_graph)

            # Last requested level: nothing to coarsen.
            if lvl >= self.levels - 1:
                break

            # Decode this level and carry only its peak nodes to the next one.
            # The global_* accumulators are threaded through every call.
            _, peaks, global_edges, global_pred_labels, global_peaks = decode(
                level_graph,
                0,
                "sim",
                True,
                ids,
                global_edges,
                global_num_nodes,
                global_peaks,
            )
            ids = ids[peaks]
            features, labels, cluster_features = build_next_level(
                features,
                labels,
                peaks,
                global_features,
                global_pred_labels,
                global_peaks,
            )

    def _build_graph(self, features, cluster_features, labels, density, knns):
        """Build one level's DGL graph from its k-NN lists.

        Each stored entry (row, col) of the row-normalised k-NN adjacency
        becomes an edge col -> row.

        Node data: features, cluster_features, labels, density, and norm
        (the row scale factor returned by row_normalize).

        Edge data:
            affine: the normalised weight.
            global_eid: explicit edge id.
            raw_affine: affine / norm[dst], which undoes the row
                normalisation.
            labels_conn: 1 when both endpoints share a label.
            mask_conn: True when the source is denser than the destination.
        """
        adj, adj_row_sum = row_normalize(fast_knns2spmat(knns, self.k))
        indices, values, _ = sparse_mx_to_indices_values(adj)

        g = dgl.graph((indices[1], indices[0]))
        g.ndata["features"] = torch.FloatTensor(features)
        g.ndata["cluster_features"] = torch.FloatTensor(cluster_features)
        g.ndata["labels"] = torch.LongTensor(labels)
        g.ndata["density"] = torch.FloatTensor(density)
        g.edata["affine"] = torch.FloatTensor(values)
        # A Bipartite from DGL sampler will not store global eid, so we explicitly save it here
        g.edata["global_eid"] = g.edges(form="eid")
        g.ndata["norm"] = torch.FloatTensor(adj_row_sum)

        def _edge_fields(edges):
            return {
                "raw_affine": edges.data["affine"] / edges.dst["norm"],
                "labels_conn": (edges.src["labels"] == edges.dst["labels"]).long(),
                "mask_conn": edges.src["density"] > edges.dst["density"],
            }

        g.apply_edges(_edge_fields)
        return g

    def __getitem__(self, index):
        assert index < len(self.gs)
        return self.gs[index]

    def __len__(self):
        return len(self.gs)



In [None]:
%%writefile ./p/datasetml/__init__.py
from .datasetml import LanderDataset

## TRAIN

In [None]:
%%writefile ./p/train_subg.py
import argparse

import os
os.environ["DGLBACKEND"] = "pytorch"

import pickle
import time

import matplotlib.pyplot as plt
import csv 

import dgl

import numpy as np
import torch
import torch.optim as optim

import sys
sys.path.append('/kaggle/working/p') 
sys.path.append('/kaggle/working/p/clustering-benchmark')


from datasetml import LanderDataset
from models import LANDER

###########
# ArgParser
parser = argparse.ArgumentParser()
add_arg = parser.add_argument

# Dataset
add_arg("--data_path", type=str, required=True)
# Comma-separated list, paired element-wise with --knn_k.
add_arg("--levels", type=str, default="1")
add_arg("--faiss_gpu", action="store_true")
add_arg("--model_filename", type=str, default="lander.pth")

# KNN (--knn_k is a comma-separated list of k values)
add_arg("--knn_k", type=str, default="10")
add_arg("--num_workers", type=int, default=0)

# Model
add_arg("--hidden", type=int, default=512)
add_arg("--num_conv", type=int, default=1)
add_arg("--dropout", type=float, default=0.0)
add_arg("--gat", action="store_true")
add_arg("--gat_k", type=int, default=1)
add_arg("--balance", action="store_true")
add_arg("--use_cluster_feat", action="store_true")
add_arg("--use_focal_loss", action="store_true")

# Training
add_arg("--epochs", type=int, default=100)
add_arg("--batch_size", type=int, default=1024)
add_arg("--lr", type=float, default=0.01)
add_arg("--momentum", type=float, default=0.9)
add_arg("--weight_decay", type=float, default=1e-5)

#===========================================
# Loss logging / plotting options
add_arg("--save_interval", type=int, default=50)  # extra checkpoint every N epochs
add_arg("--loss_file", type=str, default="losses.csv")  # CSV with per-epoch mean losses
add_arg("--plot_interval", type=int, default=50)  # save the loss plot every N epochs
#===========================================

args = parser.parse_args()
print(args)

###########################
# Environment Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

##################
# Data Preparation
# NOTE: pickle.load can execute arbitrary code; only load trusted dataset files.
with open(args.data_path, "rb") as f:
    features, labels = pickle.load(f)

# --knn_k and --levels are parallel comma-separated lists. One LanderDataset
# (with its own graph hierarchy) is built per (k, levels) pair, and all of
# their level graphs are pooled for training.
k_list = list(map(int, args.knn_k.split(",")))
lvl_list = list(map(int, args.levels.split(",")))
gs, nbrs, ks = [], [], []
for k, n_levels in zip(k_list, lvl_list):
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=k,
        levels=n_levels,
        faiss_gpu=args.faiss_gpu,
    )
    gs.extend(dataset.gs)
    ks.extend([k] * len(dataset.gs))
    nbrs.extend(dataset.nbrs)

print("Dataset Prepared.")

def set_train_sampler_loader(g, k):
    """Return a shuffled mini-batch DataLoader over every node of ``g``.

    Samples ``k - 1`` neighbours per node at each of ``args.num_conv + 1``
    hops. LANDER's mini-batch forward consumes num_conv + 1 blocks.
    """
    hops = args.num_conv + 1
    sampler = dgl.dataloading.MultiLayerNeighborSampler([k - 1] * hops)
    seed_nodes = torch.arange(g.num_nodes())
    # fix the number of edges
    return dgl.dataloading.DataLoader(
        g,
        seed_nodes,
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
    )


# One mini-batch loader per pooled graph (every level of every k).
train_loaders = [set_train_sampler_loader(g, k) for g, k in zip(gs, ks)]

##################
# Model Definition
feature_dim = gs[0].ndata["features"].shape[1]
model = LANDER(
    feature_dim=feature_dim,
    nhid=args.hidden,
    num_conv=args.num_conv,
    dropout=args.dropout,
    use_GAT=args.gat,
    K=args.gat_k,
    balance=args.balance,
    use_cluster_feat=args.use_cluster_feat,
    use_focal_loss=args.use_focal_loss,
)
model = model.to(device)
model.train()

#################
# Hyperparameters
opt = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

# keep num_batch_per_loader the same for every sub_dataloader: each epoch
# draws len(loader 0) batches from every loader.
num_batch_per_loader = len(train_loaders[0])
train_loaders = list(map(iter, train_loaders))
num_loaders = len(train_loaders)
# The scheduler is stepped once per mini-batch, so T_max counts every batch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.epochs * num_batch_per_loader * num_loaders, eta_min=1e-5
)

print("Start Training.")

#==============================================#
# Training code unchanged; only the loss-plotting
# function was added.
#==============================================#
losses = []

def plot_loss(losses, epoch, out_dir="/kaggle/working"):
    """Plot per-epoch total, density and connection losses and save the figure.

    The figure is written to ``<out_dir>/loss_plot_epoch_<epoch>.png``.

    Args:
        losses: list of dicts with keys 'epoch', 'loss', 'loss_den', 'loss_conn'.
        epoch: epoch number used in the output filename.
        out_dir: output directory (the default keeps the original Kaggle path).
    """
    epochs = [entry["epoch"] for entry in losses]
    series = (
        ("loss", "Loss (D+C)", "tab:blue"),
        ("loss_den", "Density Loss", "tab:green"),
        ("loss_conn", "Connection Loss", "tab:red"),
    )

    # Explicit figure/axes instead of the pyplot state machine.
    fig, ax = plt.subplots(figsize=(12, 8))
    for key, label, color in series:
        ax.plot(epochs, [entry[key] for entry in losses], label=label, color=color, linewidth=2)

    ax.set_xlabel("Epoch", fontsize=14)
    ax.set_ylabel("Loss", fontsize=14)
    ax.set_title("Training Losses", fontsize=16)
    ax.legend(fontsize=12)
    ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)

    fig.savefig(os.path.join(out_dir, f"loss_plot_epoch_{epoch}.png"), bbox_inches="tight")
    # Close this exact figure so repeated calls during training don't pile up
    # open figures, regardless of which figure pyplot considers "current".
    plt.close(fig)

# Training Loop
for epoch in range(args.epochs):
    loss_den_val_total = []
    loss_conn_val_total = []
    loss_val_total = []
    for batch in range(num_batch_per_loader):
        for loader_id in range(num_loaders):
            try:
                minibatch = next(train_loaders[loader_id])
            except StopIteration:
                # This loader is exhausted (loaders can hold fewer batches than
                # loader 0, which sets num_batch_per_loader): restart it with a
                # freshly shuffled pass. Any other error now propagates instead
                # of being swallowed.
                train_loaders[loader_id] = iter(
                    set_train_sampler_loader(gs[loader_id], ks[loader_id])
                )
                minibatch = next(train_loaders[loader_id])
            # The sampler yields (input_nodes, output_nodes, blocks); only the
            # blocks are used by the model.
            _input_nodes, _output_nodes, bipartites = minibatch
            bipartites = [b.to(device) for b in bipartites]
            opt.zero_grad()
            output_bipartite = model(bipartites)
            loss, loss_den_val, loss_conn_val = model.compute_loss(
                output_bipartite
            )
            loss_val = loss.item()
            loss_den_val_total.append(loss_den_val)
            loss_conn_val_total.append(loss_conn_val)
            loss_val_total.append(loss_val)
            loss.backward()
            opt.step()
            if (batch + 1) % 10 == 0:
                print(
                    "epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
                    % (
                        epoch,
                        batch,
                        num_batch_per_loader,
                        loader_id,
                        num_loaders,
                        loss_val,
                        loss_den_val,
                        loss_conn_val,
                    )
                )
            # The cosine schedule is defined per mini-batch.
            scheduler.step()

    # Per-epoch means, computed once and reused for logging and the CSV.
    epoch_stats = {
        'epoch': epoch,
        'loss': np.mean(loss_val_total),
        'loss_den': np.mean(loss_den_val_total),
        'loss_conn': np.mean(loss_conn_val_total),
    }
    losses.append(epoch_stats)

    print(
        "epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f"
        % (
            epoch,
            epoch_stats['loss'],
            epoch_stats['loss_den'],
            epoch_stats['loss_conn'],
        )
    )

    # Extra numbered checkpoint every save_interval epochs.
    if (epoch + 1) % args.save_interval == 0:
        torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")

    # Overwrite the main checkpoint every epoch: the latest weights survive an
    # interrupted session, and after the last epoch this is the final model.
    torch.save(model.state_dict(), args.model_filename)

    # Loss plot every plot_interval epochs, and always after the final epoch.
    if (epoch + 1) % args.plot_interval == 0 or epoch == args.epochs - 1:
        plot_loss(losses, epoch)

# Per-epoch mean losses for later analysis.
with open(args.loss_file, mode='w', newline='') as file:
    writer = csv.DictWriter(file, fieldnames=['epoch', 'loss', 'loss_den', 'loss_conn'])
    writer.writeheader()
    writer.writerows(losses)

# The earlier version of this loop (without loss logging, periodic checkpoints
# and plotting) used to be kept here as a dead string literal; it was removed.

In [None]:
#========== TRAIN INAT2018
#%run ./p/train_subg.py --data_path /kaggle/input/dataset-ml/inat2018_train_dedup_inter_intra.pkl --knn_k 10,5,3 --levels 2,3,4 --faiss_gpu --hidden 512 --epochs 250 --lr 0.01 --batch_size 4096 --num_conv 1 --gat --gat_k 2 --balance

#========== TRAIN INAT PIC
#%run ./p/train_subg.py --data_path /kaggle/input/dataset-ml/inat2018_train_dedup_inter_intra_1_in_6_per_class.pkl --knn_k 10,5,3 --levels 2,3,4 --faiss_gpu --hidden 512 --epochs 250 --lr 0.01 --batch_size 4096 --num_conv 1 --gat --balance

#========== TRAIN DEEPGLINT
#%run ./p/train_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_train_1_in_10_recreated.pkl --knn_k 10,5,3 --levels 2,3,4 --faiss_gpu --hidden 512 --epochs 250 --lr 0.01 --batch_size 4096 --num_conv 1 --balance --use_cluster_feat

In [None]:
%%writefile ./p/test_subg.py
import argparse

import os
os.environ["DGLBACKEND"] = "pytorch"
import pickle
import time

import sys
sys.path.append('/kaggle/working/p')  # root del progetto
sys.path.append('/kaggle/working/p/clustering-benchmark')

import dgl

import numpy as np
import torch
import torch.optim as optim
from datasetml import LanderDataset
from models import LANDER
from utils import build_next_level_mod7, decode, evaluation, stop_iterating

###########
# ArgParser
# Command-line interface of the inference script. Each entry is
# (flag, keyword arguments forwarded to ArgumentParser.add_argument);
# the table order is the registration order, so --help is unchanged.
_CLI_OPTIONS = (
    # Dataset
    ("--data_path", dict(type=str, required=True)),
    ("--model_filename", dict(type=str, default="lander.pth")),
    ("--faiss_gpu", dict(action="store_true")),
    ("--num_workers", dict(type=int, default=0)),
    # HyperParam
    ("--knn_k", dict(type=int, default=10)),
    ("--levels", dict(type=int, default=1)),
    ("--tau", dict(type=float, default=0.5)),
    ("--threshold", dict(type=str, default="prob")),
    ("--metrics", dict(type=str, default="pairwise,bcubed,nmi")),
    ("--early_stop", dict(action="store_true")),
    # Model
    ("--hidden", dict(type=int, default=512)),
    ("--num_conv", dict(type=int, default=4)),
    ("--dropout", dict(type=float, default=0.0)),
    ("--gat", dict(action="store_true")),
    ("--gat_k", dict(type=int, default=1)),
    ("--balance", dict(action="store_true")),
    ("--use_cluster_feat", dict(action="store_true")),
    ("--use_focal_loss", dict(action="store_true")),
    ("--use_gt", dict(action="store_true")),
    # Subgraph
    ("--batch_size", dict(type=int, default=4096)),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

# Note: argparse accepts unambiguous prefixes, so `--level 10` sets --levels.
args = parser.parse_args()
print(args)

###########################
# Environment Configuration
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

##################
# Data Preparation
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only point --data_path at trusted inputs. The file must unpickle to a
# (features, labels) pair.
with open(args.data_path, "rb") as f:
    features, labels = pickle.load(f)
# Copy of the input features, later handed to build_next_level_mod7.
global_features = features.copy()

# Level-0 graph built over the raw input features.
dataset = LanderDataset(
    features=features,
    labels=labels,
    k=args.knn_k,
    levels=1,
    faiss_gpu=args.faiss_gpu,
)
g = dataset.gs[0]
global_num_nodes = g.num_nodes()

# Prediction buffers, filled mini-batch by mini-batch during inference.
g.ndata["pred_den"] = torch.zeros(global_num_nodes)
g.edata["prob_conn"] = torch.zeros((g.num_edges(), 2))

# State carried across the levels of the hierarchy.
global_labels = labels.copy()
ids = np.arange(global_num_nodes)
global_edges = ([], [])
global_peaks = np.array([], dtype=np.longlong)
global_edges_len = len(global_edges[0])

# fix the number of edges: every one of the num_conv + 1 sampling hops keeps
# knn_k - 1 neighbours per node.
fanouts = [args.knn_k - 1] * (args.num_conv + 1)
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
test_loader = dgl.dataloading.DataLoader(
    g,
    torch.arange(global_num_nodes),
    sampler,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)

##################
# Model Definition
# With --use_gt no network is built: inference is skipped and decode() is
# called with use_gt=True (presumably decoding from ground-truth labels).
if not args.use_gt:
    feature_dim = g.ndata["features"].shape[1]
    model = LANDER(
        feature_dim=feature_dim,
        nhid=args.hidden,
        num_conv=args.num_conv,
        dropout=args.dropout,
        use_GAT=args.gat,
        K=args.gat_k,
        balance=args.balance,
        use_cluster_feat=args.use_cluster_feat,
        use_focal_loss=args.use_focal_loss,
    )
    # map_location="cpu": a checkpoint saved from a GPU run would otherwise
    # fail to deserialize on a CPU-only machine (the script explicitly supports
    # CPU fallback). The weights are moved to `device` right below.
    model.load_state_dict(
        torch.load(args.model_filename, map_location="cpu", weights_only=True)
    )
    model = model.to(device)
    model.eval()

# number of edges added is the indicator for early stopping
if args.levels < 1:
    # Without at least one level global_pred_labels is never produced and the
    # final evaluation would fail with an opaque NameError.
    raise ValueError("--levels must be >= 1, got %d" % args.levels)
# np.inf (lowercase): the np.Inf alias was removed in NumPy 2.0.
num_edges_add_last_level = np.inf
##################################
# Predict connectivity and density
for level in range(args.levels):
    if not args.use_gt:
        # Run the model over every mini-batch and scatter its outputs back
        # into the full graph `g`, addressed by global node / edge ids.
        total_batches = len(test_loader)
        for batch, minibatch in enumerate(test_loader):
            input_nodes, sub_g, bipartites = minibatch
            sub_g = sub_g.to(device)
            bipartites = [b.to(device) for b in bipartites]
            with torch.no_grad():
                output_bipartite = model(bipartites)
            global_nid = output_bipartite.dstdata[dgl.NID]
            global_eid = output_bipartite.edata["global_eid"]
            g.ndata["pred_den"][global_nid] = output_bipartite.dstdata[
                "pred_den"
            ].to("cpu")
            g.edata["prob_conn"][global_eid] = output_bipartite.edata[
                "prob_conn"
            ].to("cpu")
            torch.cuda.empty_cache()
            if (batch + 1) % 10 == 0:
                print("Batch %d / %d for inference" % (batch, total_batches))

    # Decode this level from the filled-in predictions and fold the result
    # into the cross-level state (global_edges, global_pred_labels,
    # global_peaks).
    (
        new_pred_labels,
        peaks,
        global_edges,
        global_pred_labels,
        global_peaks,
    ) = decode(
        g,
        args.tau,
        args.threshold,
        args.use_gt,
        ids,
        global_edges,
        global_num_nodes,
        global_peaks,
    )
    # `ids` started as arange over level-0 nodes; keep the level-0 ids of the
    # peaks that survive into the next level.
    ids = ids[peaks]
    new_global_edges_len = len(global_edges[0])
    num_edges_add_this_level = new_global_edges_len - global_edges_len
    if stop_iterating(
        level,
        args.levels,
        args.early_stop,
        num_edges_add_this_level,
        num_edges_add_last_level,
        args.knn_k,
    ):
        break
    global_edges_len = new_global_edges_len
    num_edges_add_last_level = num_edges_add_this_level

    # build new dataset over the peaks of this level
    features, labels, cluster_features = build_next_level_mod7(
        features,
        labels,
        peaks,
        global_features,
        global_pred_labels,
        global_peaks,
    )
    # After the first level, the number of nodes reduce a lot. Using cpu faiss is faster.
    dataset = LanderDataset(
        features=features,
        labels=labels,
        k=args.knn_k,
        levels=1,
        faiss_gpu=False,
        cluster_features=cluster_features,
    )
    g = dataset.gs[0]
    g.ndata["pred_den"] = torch.zeros((g.num_nodes()))
    g.edata["prob_conn"] = torch.zeros((g.num_edges(), 2))
    test_loader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.num_nodes()),
        sampler,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers,
    )
evaluation(global_pred_labels, global_labels, args.metrics)

In [None]:
# TEST INAT FULL
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/inat2018_test.pkl --model_filename /kaggle/input/mlreplica/og/inat2018_1/lander_inat.pth --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop

#TEST DEEPGLINT
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_imdb_features_sampled_as_deepglint_1_in_10.pkl --model_filename /kaggle/input/mlreplica/og/deepglint/lander_deeplight.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

# TEST HANNAH
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_hannah.pkl --model_filename /kaggle/input/mlreplica/og/deepglint/lander_deeplight.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

#IMDB
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_imdb_features.pkl --model_filename /kaggle/input/mlreplica/og/deepglint/lander_deeplight.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

#inat_2018
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/inat2018_test.pkl --model_filename /kaggle/working/model_epoch_3.pth --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop

In [None]:
######### MODIFICA 1
# TESTING CAN ONLY BE DONE ON DEEPGLINT TEST DATA

#TEST DEEPGLINT SEEN DATA
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_imdb_features_sampled_as_deepglint_1_in_10.pkl --model_filename /kaggle/input/mlmodifica1/1mod/lander_inat_2.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

#TEST HANNAH UNSEEN
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_hannah.pkl --model_filename /kaggle/input/mlmodifica1/1mod/lander_inat_2.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

#TEST IMDB UNSEEN
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_imdb_features.pkl --model_filename /kaggle/input/mlmodifica1/1mod/lander_inat_2.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

In [None]:
#################### MODIFICA 2
# TESTING CAN ONLY BE DONE ON DEEPGLINT TEST DATA

#TEST DEEPGLINT SEEN DATA
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_imdb_features_sampled_as_deepglint_1_in_10.pkl --model_filename /kaggle/input/mlmod2/2mod/lander_deepglint.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

#TEST HANNAH UNSEEN
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_hannah.pkl --model_filename /kaggle/input/mlmod2/2mod/lander_deepglint.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

#TEST IMDB UNSEEN
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_imdb_features.pkl --model_filename /kaggle/input/mlmod2/2mod/lander_deepglint.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

In [None]:
############## MODIFICA 3
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/inat2018_test.pkl --model_filename /kaggle/input/mod3gat/3mod/lander_inat_piccolo.pth --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --gat_k 4 --batch_size 4096 --early_stop

In [None]:
###################### MODIFICA 4
# TEST INAT FULL
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/inat2018_test.pkl --model_filename /kaggle/input/mlmod4/4mod/inat_full/lander_inat.pth --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop

#TEST DEEPGLINT
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_imdb_features_sampled_as_deepglint_1_in_10.pkl --model_filename /kaggle/input/mlmod4/4mod/deepglint/lander_deepglint.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

# TEST HANNAH
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_hannah.pkl --model_filename /kaggle/input/mlmod4/4mod/deepglint/lander_deepglint.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

#IMDB
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_imdb_features.pkl --model_filename /kaggle/input/mlmod4/4mod/deepglint/lander_deepglint.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

#inat_2018
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/inat2018_test.pkl --model_filename /kaggle/input/mlmod4/4mod/inat_piccolo/lander_inat_2.pth --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop

In [None]:
############################# MODIFICA 5
# TEST INAT FULL
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/inat2018_test.pkl --model_filename /kaggle/input/mlmod6/5mod/inat_full/lander_inat.pth --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop

#TEST DEEPGLINT
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_imdb_features_sampled_as_deepglint_1_in_10.pkl --model_filename /kaggle/input/mlmod6/5mod/deepglit/lander_deepglint.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

# TEST HANNAH
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_hannah.pkl --model_filename /kaggle/input/mlmod6/5mod/deepglit/lander_deepglint.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

#IMDB
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_imdb_features.pkl --model_filename /kaggle/input/mlmod6/5mod/deepglit/lander_deepglint.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

#inat_2018
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/inat2018_test.pkl --model_filename /kaggle/input/mlmod6/5mod/inat_piccolo/lander_inat_piccolo.pth --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop

In [None]:
############################# MODIFICA 6
# TEST INAT FULL
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/inat2018_test.pkl --model_filename /kaggle/input/mlmod6-1/6mod/inat_full/lander_inat_pic.pth --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop

#TEST DEEPGLINT
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_imdb_features_sampled_as_deepglint_1_in_10.pkl --model_filename /kaggle/input/mlmod6-1/6mod/deepglit/lander_deepglint.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

# TEST HANNAH
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_hannah.pkl --model_filename /kaggle/input/mlmod6-1/6mod/deepglit/lander_deepglint.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

#IMDB
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/subcenter_arcface_deepglint_imdb_features.pkl --model_filename /kaggle/input/mlmod6-1/6mod/deepglit/lander_deepglint.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat

#inat_2018
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/inat2018_test.pkl --model_filename /kaggle/input/mlmod6-1/6mod/inat_piccolo/lander_inat_2.pth --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop

In [None]:
############################# MODIFICA 7
#inat_2018
#%run ./p/test_subg.py --data_path /kaggle/input/dataset-ml/inat2018_test.pkl --model_filename /kaggle/input/mlmod7/7mod/s10000/lander.pth --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop