In [6]:
# Run from the Kaggle working directory so the relative paths used below
# (input/..., output/...) resolve against it.
%cd /kaggle/working
import glob

import numpy as np
import polars as pl

/kaggle/working


In [65]:
path_list = sorted(glob.glob("input/ClimSim_low-res/train/*/*.npy"))
# Keep every 24th and every 30th file (deduplicated, back in path order).
# Together the two strides repeat every lcm(24, 30) = 120 files, 8 files per period.
path_like_test = sorted({*path_list[::24], *path_list[::30]})
len(path_like_test)

3504

In [66]:
# Load the first 100 sampled files and stack them row-wise into one array.
data_arrays = [np.load(npy_path) for npy_path in path_like_test[:100]]
concatenated_array = np.vstack(data_arrays)

# Polars frame over the stacked rows (columns are auto-named column_0, column_1, ...).
df = pl.DataFrame(concatenated_array)

In [67]:
"""
以下のように8個ごとに1周期
[0,24,30,48,60,72,90,96],[120,...]
"""

df_time = (
    df[:, :1]
    .with_row_index()
    .with_columns(
        [
            (pl.col("index") % 384).alias("location"),
            ((pl.col("index") // 384) % 8).alias("time_mod"),
        ]
    )
    .with_columns(
        pl.when(pl.col("time_mod") == 0)
        .then(0)
        .when(pl.col("time_mod") == 1)
        .then(24)
        .when(pl.col("time_mod") == 2)
        .then(30)
        .when(pl.col("time_mod") == 3)
        .then(48)
        .when(pl.col("time_mod") == 4)
        .then(60)
        .when(pl.col("time_mod") == 5)
        .then(72)
        .when(pl.col("time_mod") == 6)
        .then(90)
        .otherwise(96)
        .alias("time")
    )
).drop(["column_0", "time_mod"])

df_time

index,location,time
u32,u32,i32
0,0,0
1,1,0
2,2,0
3,3,0
4,4,0
…,…,…
38395,379,48
38396,380,48
38397,381,48
38398,382,48


In [87]:
import pickle
from pathlib import Path

import torch
from torch.nn.functional import normalize
from torch.utils.data import DataLoader, TensorDataset
from tqdm.notebook import tqdm


def compute_cosine_similarity(tensor1, tensor2):
    """Pairwise dot products between the rows of two 2-D tensors.

    The result equals cosine similarity only when both inputs have already been
    L2-normalised row-wise; no normalisation is done here.

    Args:
        tensor1: tensor of shape (n, d).
        tensor2: tensor of shape (m, d).

    Returns:
        Tensor of shape (n, m) whose entry [i, j] is tensor1[i] . tensor2[j].
    """
    return tensor1.mm(tensor2.t())


def get_top_k_similar_rows(matrix, k=5, chunk_size=1000, device="cpu"):
    """Return, for every row of ``matrix``, the indices of its ``k`` most
    cosine-similar *other* rows.

    Rows are L2-normalised, then similarities are computed ``chunk_size`` query
    rows at a time against the whole matrix. Only a ``(chunk_size, n_rows)``
    similarity block is materialised at once.

    Args:
        matrix: 2-D floating tensor of shape ``(n_rows, n_features)``.
        k: number of neighbours to return per row.
        chunk_size: number of query rows processed per step.
        device: device on which the similarity computation runs.

    Returns:
        ``torch.long`` tensor of shape ``(n_rows, k)`` on the CPU. Row ``i``
        holds neighbour indices ordered from most to least similar, and never
        contains ``i`` itself.

    Raises:
        ValueError: if the matrix is non-empty but has ``k`` or fewer rows, so
            there are not ``k`` other rows to return.
    """
    matrix = matrix.to(device)
    normalized_matrix = normalize(matrix, p=2, dim=1)
    n_rows = normalized_matrix.size(0)

    if 0 < n_rows <= k:
        raise ValueError(
            f"need more than k={k} rows to find k neighbours per row, got {n_rows}"
        )

    top_k_indices = torch.empty((n_rows, k), dtype=torch.long, device="cpu")

    with torch.no_grad():
        # Chunk by slicing (a cheap view); no per-row collation is needed.
        for start in tqdm(range(0, n_rows, chunk_size)):
            chunk_tensor = normalized_matrix[start : start + chunk_size]
            chunk_size_actual = chunk_tensor.size(0)

            cosine_similarities = compute_cosine_similarity(
                chunk_tensor, normalized_matrix
            )

            # Mask each row's similarity with itself rather than assuming it
            # ranks first. Ties (duplicate or all-zero rows) or rounding can put
            # another row at or above it. Dropping rank 0 would then keep the
            # row itself and lose a genuine neighbour.
            local_rows = torch.arange(
                chunk_size_actual, device=cosine_similarities.device
            )
            cosine_similarities[local_rows, local_rows + start] = float("-inf")

            # topk returns values sorted descending, i.e. most similar first.
            top_k = torch.topk(cosine_similarities, k, dim=1)
            top_k_indices[start : start + chunk_size_actual] = top_k.indices.cpu()

    return top_k_indices


# Similarity features: the first 556 columns of every row.
# NOTE(review): presumably these are the model-input columns — confirm.
data = torch.tensor(concatenated_array[:, :556])

# Standardise with precomputed per-feature mean/std.
# NOTE: unpickling can execute arbitrary code; only load trusted, locally produced files.
# Path.read_bytes() opens the file, reads it, and closes it again.
scale_dir = "output/preprocess/normalize_009_rate_feat/bolton"
feat_mean_dict = pickle.loads((Path(scale_dir) / "x_mean_feat_dict.pkl").read_bytes())
feat_std_dict = pickle.loads((Path(scale_dir) / "x_std_feat_dict.pkl").read_bytes())

# NOTE(review): 1e-60 only avoids division by an exactly-zero std (and underflows
# to 0 if the arrays are float32). A zero-std feature with a non-zero
# (x - mean) will still blow up and dominate the cosine similarity — confirm.
data = (data - feat_mean_dict["base"]) / (feat_std_dict["base"] + 1e-60)


k = 5
chunk_size = 100

# Use CPU or GPU
device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"

top_k_similar_rows = get_top_k_similar_rows(
    data, k=k, chunk_size=chunk_size, device=device
)

  0%|          | 0/384 [00:00<?, ?it/s]

In [88]:
# Display the neighbour-index matrix: row i holds the indices of the top-k rows most similar to row i.
top_k_similar_rows

tensor([[ 9216, 11200,  1920, 19981, 12352],
        [12676, 16516, 20032, 13060, 21953],
        [ 1922,  9218,  3458,   207,  1538],
        ...,
        [38013, 33021, 32637, 26109, 27261],
        [38014, 31102, 36093, 21886, 37630],
        [38015, 25727, 26111, 30719, 26878]])

In [89]:
# For each row, check whether the row 384 positions later is among its top-k
# neighbours. That row is the same `location` in the next sampled file,
# assuming 384 rows per file as in `df_time`. Rows in the last file have no
# such successor, so their `is_in_next` is always False.
df_similar = df_time.with_columns(
    # Depending on the Polars version, a 2-D NumPy array becomes either a List
    # or a fixed-size Array column; cast so the `.list` namespace below works.
    pl.Series("similar", values=top_k_similar_rows.numpy()).cast(pl.List(pl.Int64))
).with_columns(
    pl.col("similar")
    .list.contains((pl.col("index") + 384).cast(pl.Int64))
    .alias("is_in_next")
)

In [90]:
# Number of time-offset-48 rows whose next-file, same-location row is among their top-k neighbours.
df_similar.filter(pl.col("time").eq(48)).get_column("is_in_next").sum()

3718

In [91]:
# Inspect the rows at time offset 48 together with their neighbour lists.
df_similar.filter(pl.col("time").eq(48))

index,location,time,similar,is_in_next
u32,u32,i32,list[i64],bool
1152,0,48,"[3073, 10368, … 1357]",true
1153,1,48,"[3076, 10369, … 20808]",false
1154,2,48,"[3074, 1538, … 27779]",true
1155,3,48,"[16967, 1539, … 24966]",true
1156,4,48,"[10372, 2693, … 20809]",false
…,…,…,…,…
38395,379,48,"[38011, 37627, … 32250]",false
38396,380,48,"[38012, 37244, … 36860]",false
38397,381,48,"[38013, 33021, … 27261]",false
38398,382,48,"[38014, 31102, … 37630]",false
