In [None]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

# Enable tqdm's pandas integration for this session.
tqdm.pandas()

# --- Configuration -----------------------------------------------------------

# Local path of the sentence-embedding checkpoint handed to SentenceTransformer.
MODEL_NAME = "/kaggle/working/all-MiniLM-L6-v2"

# Default number of equal-width character chunks each text is split into.
K = 5

# Identifier-like columns. Not referenced elsewhere in the visible notebook.
# NOTE(review): "url" is lowercase here but ALL_FIELDS lists "URL" - confirm
# which spelling the dataset actually uses before relying on this list.
ID_LABELS = [
    "dataset_id",
    "article_id",
    "id",
    "DOI",
    "url",
]

# Columns considered when de-duplicating the training frame.
TRAIN_LABELS = [
    "title",
    "segments",
    "extension",
    "abstract",
    "publisher",
    "copyright",
    "issued_year",
    "all_authors",
    "categories",
]

# Presumably the full column set of the prepared dataset - not referenced
# elsewhere in the visible notebook; verify against the parquet schema.
ALL_FIELDS = [
    "article_id",
    "text",
    "extension",
    "source",
    "dataset_id",
    "dataset_id_cited",
    "type",
    "id",
    "categories",
    "abstract",
    "DOI",
    "publisher",
    "title",
    "URL",
    "copyright",
    "issued_year",
    "all_authors",
    "y",
]

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

class Basement(SentenceTransformer):
    """Nearest-group matcher on top of a frozen sentence-embedding model.

    The training frame is turned into per-row embedding stacks (one metadata
    embedding plus one text-chunk embedding). At prediction time a query is
    embedded the same way and compared, by mean cosine similarity, against
    every ``(article_id, dataset_id, source)`` group of the training frame.

    Attributes:
        x: The preprocessed training frame produced by ``preprocess_data``;
            its ``inputs`` column holds the embedding stacks.
    """

    def __init__(self, initial_data: pd.DataFrame):
        """Load the model from ``MODEL_NAME``, freeze it, and embed ``initial_data``.

        Args:
            initial_data: Training frame; must contain a ``text`` column, and
                ``predict`` additionally needs ``article_id``, ``dataset_id``
                and ``source``.
        """
        super().__init__(MODEL_NAME, device=DEVICE)
        # Put the model into its final inference state *before* it is used to
        # embed the training data.
        self.to(DEVICE)
        self.eval()
        for p in self.parameters():
            p.requires_grad = False
        # Lazily built by _grouped_train_tensors(): (source frame, group list).
        self._group_cache = None
        self.x = self.preprocess_data(initial_data)

    @staticmethod
    def _build_prompt(row) -> str:
        """Render the metadata fields of ``row`` into the text prompt that is embedded.

        ``row`` only needs a ``.get(key, default)`` method, so both a pandas
        row (Series) and a plain dict work. Missing keys render as ''.
        """
        return f"""
        Title: {row.get('title', '')}
        Abstract: {row.get('abstract', '')}
        Author: {row.get('all_authors', '')}
        Extension File Type: {row.get('extension', '')}
        Publisher: {row.get('publisher', '')}
        Category: {row.get('categories', '')}
        Issued Year: {row.get('issued_year', '')}
        Copyright: {row.get('copyright', '')}
        """

    @staticmethod
    def _stack_inputs(meta_embedding: np.ndarray, text_embedding) -> np.ndarray:
        """Stack the metadata embedding on top of the text embedding(s).

        Falls back to the metadata embedding alone when ``text_embedding`` is
        not a NumPy array.
        """
        if isinstance(text_embedding, np.ndarray):
            if text_embedding.ndim == 1:
                # A single chunk vector: promote to a one-row matrix.
                text_embedding = np.expand_dims(text_embedding, axis=0)
            return np.vstack((meta_embedding, text_embedding))
        return meta_embedding  # fallback to meta only

    def _meta_conditioner(self, row):
        """Embed one row: metadata prompt embedding stacked with ``row['segments']``.

        Args:
            row: Mapping-like object (pandas row or dict) with the metadata
                fields and, optionally, a ``segments`` array.

        Returns:
            2-D array whose first row is the metadata embedding, followed by
            the segment embedding row(s) when ``segments`` is an ndarray.
        """
        meta_embedding = self.encode(
            [self._build_prompt(row)],
            convert_to_numpy=True,
            output_value="sentence_embedding",
            show_progress_bar=False
        )
        return self._stack_inputs(meta_embedding, row.get('segments', None))

    def _segment_text(self, s: str, k: int = K) -> np.ndarray:
        """Split ``s`` into ``k`` equal-width character chunks and embed each.

        Args:
            s: Text to embed.
            k: Number of chunks.

        Returns:
            float32 array with ``k`` rows. All zeros when ``s`` is not a
            non-empty string.
        """
        if not isinstance(s, str) or not s:
            return np.zeros((k, self.get_sentence_embedding_dimension()), dtype=np.float32)
        n = len(s)
        # k + 1 evenly spaced cut points; when n < k some chunks are empty
        # strings, which are still passed to the encoder.
        idx = np.linspace(0, n, k + 1, dtype=int)
        chunks = [s[idx[i]:idx[i+1]] for i in range(k)]
        emb = self.encode(chunks, convert_to_numpy=True, show_progress_bar=False)
        return np.asarray(emb, dtype=np.float32).reshape(k, -1)

    def _encode_inputs(self, df: pd.DataFrame) -> np.ndarray:
        """Build the ``inputs`` value for every row of ``df``.

        Each distinct metadata prompt is encoded once, in a single batched
        ``encode`` call, instead of one model call per row.

        Returns:
            1-D object array, aligned with the rows of ``df``, whose elements
            are the stacked (metadata + segment) embedding arrays.
        """
        prompts = [self._build_prompt(row) for _, row in df.iterrows()]
        inputs = np.empty(len(prompts), dtype=object)
        if not prompts:
            return inputs

        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique_prompts = list(dict.fromkeys(prompts))
        meta = self.encode(
            unique_prompts,
            convert_to_numpy=True,
            output_value="sentence_embedding",
            show_progress_bar=False
        )
        meta_by_prompt = dict(zip(unique_prompts, meta))

        for i, (prompt, segment) in enumerate(zip(prompts, df["segments"])):
            # [np.newaxis, :] restores the (1, dim) shape a single-prompt
            # encode() call would have produced.
            inputs[i] = self._stack_inputs(meta_by_prompt[prompt][np.newaxis, :], segment)
        return inputs

    def preprocess_data(
        self,
        data: pd.DataFrame,
        k: int = K,
        train_labels: list[str] | None = None
    ) -> pd.DataFrame:
        """Clean ``data`` and attach the ``segments`` and ``inputs`` embedding columns.

        Steps: embed ``text`` into ``k`` chunks, explode ``all_authors`` and
        ``categories``, drop rows duplicated on the metadata columns, fill
        missing values, explode ``segments`` to one chunk per row, and finally
        build ``inputs``.

        Args:
            data: Source frame; must contain a ``text`` column. Not modified.
            k: Number of text chunks per document.
            train_labels: Columns used for de-duplication (``segments`` is
                always excluded). Defaults to ``TRAIN_LABELS``.

        Returns:
            A new frame with one row per (document row, author, category,
            text chunk) and the added ``segments`` / ``inputs`` columns.
        """
        train_labels = train_labels or TRAIN_LABELS
        df = data.copy()

        # 1) create 'segments' (a (k, dim) array per row)
        df["segments"] = df["text"].apply(lambda x: self._segment_text(x, k))

        # 2) explode multi-valued metadata columns
        for col in ("all_authors", "categories"):
            if col in df.columns:
                df = df.explode(col, ignore_index=True)

        # 3) drop duplicated rows.
        # This must run BEFORE 'segments' is exploded: the dedupe subset
        # cannot include 'segments' (arrays are unhashable), so once the
        # chunks are one-per-row they are indistinguishable on the subset and
        # every chunk but the first would be discarded.
        subset_cols = [c for c in train_labels if c in df.columns and c != "segments"]
        if subset_cols:
            df = df.drop_duplicates(subset=subset_cols, keep="first")

        # 4) fills / cleaning
        if "issued_year" in df.columns:
            mode_series = df["issued_year"].dropna().mode()
            if not mode_series.empty:
                df["issued_year"] = df["issued_year"].fillna(mode_series.iloc[0])

        if "copyright" in df.columns:
            # Only empty strings become "Unknown"; missing values stay missing.
            df["copyright"] = (
                df["copyright"].astype("string").str.strip().replace({"": "Unknown"})
            )

        if "categories" in df.columns:
            df["categories"] = df["categories"].astype("string").fillna("Unknown")

        # 5) one text chunk per row, keeping all k chunks of each document
        df = df.explode("segments", ignore_index=True)

        df["inputs"] = self._encode_inputs(df)
        return df

    def _grouped_train_tensors(self):
        """Return ``[(group_key, tensor), ...]`` for the groups of ``self.x``.

        Groups are keyed by ``(article_id, dataset_id, source)``; each tensor
        is the row-wise stack of the group's ``inputs`` arrays on ``DEVICE``.
        The list is built once and reused until ``self.x`` is rebound to a
        different frame (in-place edits of ``self.x`` are not detected). The
        tensors stay resident on ``DEVICE`` for the lifetime of the cache.
        """
        cached = getattr(self, "_group_cache", None)
        if cached is not None and cached[0] is self.x:
            return cached[1]

        groups = []
        for grp, rows in self.x.groupby(['article_id', 'dataset_id', 'source']):
            valid_inputs = [
                x for x in rows['inputs'].values
                if isinstance(x, np.ndarray) and x.size > 0
            ]
            if not valid_inputs:
                continue
            train_sample = np.vstack(valid_inputs)
            groups.append(
                (grp, torch.tensor(train_sample, device=DEVICE, dtype=torch.float32))
            )

        self._group_cache = (self.x, groups)
        return groups

    def predict(self, inputs: dict, threshold: float = 0.5):
        """Find the training group most similar to one query record.

        Args:
            inputs: Query record; ``text`` is required, the metadata fields
                used by the prompt are optional.
            threshold: Minimum mean cosine similarity for the matched group's
                ``source`` to be reported as ``type``; below it ``type`` is
                ``"Unknown"``.

        Returns:
            ``None`` when ``text`` is missing/empty or no group has usable
            embeddings; otherwise a dict with ``article_id``, ``dataset_id``,
            ``type`` and ``score`` of the best-scoring group.

            NOTE(review): ``article_id`` and ``dataset_id`` are those of the
            matched *training* group, not of ``inputs`` - confirm this is what
            the submission expects.
        """
        text = inputs.get("text")
        if not text:
            return None

        # Query embedding: mean over the metadata row and the text-chunk rows.
        segment = self._segment_text(text)
        inputs = dict(inputs, segments=segment)
        input_emb = self._meta_conditioner(inputs)
        if input_emb.ndim == 1:
            input_emb = input_emb.reshape(1, -1)
        q = input_emb.mean(axis=0, keepdims=True)

        q_tensor = torch.tensor(q, device=DEVICE, dtype=torch.float32)

        max_pred = None
        for grp, train_tensor in self._grouped_train_tensors():
            avg = float(cos_sim(train_tensor, q_tensor).mean())
            # Strict '>' keeps the first group on ties.
            if (max_pred is None) or (avg > max_pred['score']):
                source_type = grp[2] if avg >= threshold else "Unknown"
                max_pred = {
                    'article_id': grp[0],
                    'dataset_id': grp[1],
                    'type': source_type,
                    'score': avg,
                }
        return max_pred



In [None]:
# ----------------------
# Load the prepared datasets and build the matcher
# ----------------------
train_data = pd.read_parquet(
    "/kaggle/input/make-data-count-data-preparation/train_dataset.parquet"
)
test_data = pd.read_parquet(
    "/kaggle/input/make-data-count-data-preparation/test_dataset.parquet"
)

# Constructing Basement loads the embedding model and preprocesses train_data.
base = Basement(train_data)

In [None]:
# Collect the best-match record for every test row; predict() returns None
# for rows without text, and those rows are skipped.
prediction = []

for _, row in test_data.iterrows():
    pred = base.predict(row.to_dict())
    if pred:
        prediction.append(pred)

# Explicit columns keep the frame well-formed even when no prediction was
# produced; without them an empty frame has no columns and sort_values()
# below raises KeyError.
submission = pd.DataFrame(prediction, columns=['article_id', 'dataset_id', 'type', 'score'])
submission = submission.sort_values(["article_id", "dataset_id", "type"]).reset_index(drop=True)
submission['row_index'] = submission.index

submission[['row_index', 'article_id', 'dataset_id', 'type']].to_csv('/kaggle/working/submission.csv', index=False)
print('data submitted!', submission.head(5))