In [93]:
import os
import torch
from datasets import load_dataset, Dataset
import pandas as pd
from transformers import (
    AutoModelForCausalLM,
    AutoModel,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)

from sentence_transformers import SentenceTransformer, util
from rank_bm25 import BM25Okapi
import random
import numpy as np
from typing import Literal, Optional, TypedDict
from sklearn.model_selection import train_test_split

In [94]:
class URLCiteDataset(torch.utils.data.Dataset):
    '''
    create dataset
    - init
    - len
    - getitem
    '''
    def __init__(self, texts: list[str]):
        self.texts = texts

    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        return self.texts[idx]

In [95]:
!nvidia-smi

Mon Dec 30 00:22:24 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-PCIE-32GB           On  | 00000000:18:00.0 Off |                    0 |
| N/A   31C    P0              23W / 250W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla V100-PCIE-32GB           On  | 00000000:AF:00.0 Off |  

In [96]:
csv_dataset = pd.read_csv("/data/group1/z40436a/ME/URL_Citation_Classification_Intermediate/data/all_data.csv", encoding="utf-8")

seed = 111 # fixed
train_df, eval_df = train_test_split(csv_dataset, test_size = 0.1, random_state=seed)
print("train_data_size:::", len(train_df))
print("test_data_size:::", len(eval_df))

train_data_size::: 2690
test_data_size::: 299


In [97]:
import nltk
import re

CITE_TOKEN = "[URL_CITE]"

def replace_tag(sentences: pd.Series) -> list[str]:
    # replace [Cite_****] to [Cite] token
    rule = re.compile(r'\[Cite[^\[\] ]*\]')
    sentences_replaced:list[str] = list()
    for sentence in sentences:
        sentences_replaced.append(rule.sub(CITE_TOKEN, sentence))

    return sentences_replaced

def get_3sent(paragraphs:list[str]) -> list[str]:
    ret:list[list[str]] = list()
    for paragraph in paragraphs:
        sentences: list[str] = nltk.sent_tokenize(paragraph)
        if not len(sentences):
            print('!!!')
        if len(sentences) < 4:
            ret.append(sentences)
            continue
        else:
            for i in range(len(sentences)):
                if CITE_TOKEN in sentences[i]:
                    if i == 0:
                        ret.append(sentences[i:i+2])
                    elif i == len(sentences)-1:
                        ret.append(sentences[i-1:i+1])
                    else:
                        ret.append(sentences[i-1:i+2])
                    break
                if i == len(sentences)-1:
                    # print(sentences)
                    pass
    cont_3sent = [" ".join(sent) for sent in ret]
    return cont_3sent

In [98]:
def create_icl(train_df:pd.DataFrame, test_df:pd.DataFrame, method:str, k:int=5) -> list[list[str]]:
    icl_idxs: list[list[str]] = []

    train_replaced_sentences = replace_tag(train_df['citation-paragraph'])
    test_replaced_sentences = replace_tag(test_df['citation-paragraph'])

    # check
    print(len(train_replaced_sentences))
    print(len(test_replaced_sentences))

    train_cont_3sent = get_3sent(train_replaced_sentences)
    test_cont_3sent = get_3sent(test_replaced_sentences)

    # check
    print(len(train_cont_3sent))
    print(len(test_cont_3sent))

    # random
    if method == "random":
        for cont, (i, row) in zip(test_cont_3sent, test_df.iterrows()):
            random.seed(i)
            icl_idxs.append(random.sample(range(len(train_cont_3sent)), k))

    #bm25
    elif method == "bm25":
        tokenized_corpus = [cont.split(" ") for cont in train_cont_3sent]
        bm25 = BM25Okapi(tokenized_corpus)

        for cont, (i, row) in zip(test_cont_3sent, test_df.iterrows()):
            bm25_scores = bm25.get_scores(cont.split(" "))
            icl_idxs.append(np.argsort(bm25_scores)[-k:][::-1].tolist())
    elif method == "encoder":
        model = SentenceTransformer("intfloat/multilingual-e5-base")

        tokenized_corpus = model.encode(train_cont_3sent, convert_to_tensor=True)
        for cont, (i, row) in zip(test_cont_3sent, test_df.iterrows()):
            tokenized_query = model.encode(test_cont_3sent)
            cos_scores = util.cos_sim(tokenized_query, tokenized_corpus)[0]

            icl_idxs.append(np.argsort(-cos_scores)[:k].tolist())
    else:
        print("select other method")
    
    return icl_idxs

In [99]:
ICL_METHOD = "encoder"
icls = create_icl(train_df, eval_df, method=ICL_METHOD)

2690
299
2690
299


modules.json:   0%|          | 0.00/387 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/179k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/57.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/418 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/200 [00:00<?, ?B/s]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

In [92]:
print(icls)

[[2091, 2125, 973, 2553, 208], [2378, 1155, 1587, 2322, 2400], [818, 16, 716, 2124, 2150], [1485, 363, 1518, 924, 1400], [905, 1773, 712, 1968, 1067], [765, 498, 1508, 506, 375], [2434, 2020, 1375, 1130, 1064], [1009, 2538, 2380, 1741, 1976], [595, 2220, 1959, 2424, 112], [2333, 2520, 1854, 761, 674], [1537, 628, 719, 1905, 1563], [1506, 1211, 364, 1250, 2448], [1901, 614, 45, 1711, 650], [2011, 1102, 180, 1819, 686], [1441, 1498, 1321, 2375, 1831], [2074, 962, 2208, 1963, 1181], [1745, 1950, 4, 1393, 2605], [1540, 2666, 1326, 2290, 2003], [922, 875, 2261, 2473, 2421], [2060, 1664, 641, 570, 1215], [1272, 1002, 1702, 1574, 446], [703, 342, 368, 2368, 1784], [280, 2402, 56, 2294, 1035], [71, 1677, 837, 530, 2686], [1730, 1427, 2493, 958, 2520], [2398, 1852, 255, 975, 1959], [205, 2074, 2009, 801, 847], [232, 1797, 1087, 910, 1433], [833, 2478, 1954, 2060, 859], [582, 1305, 1947, 340, 2373], [491, 2200, 1667, 884, 1559], [1845, 2474, 1791, 464, 1938], [868, 449, 1493, 1901, 122], [1514, 

In [90]:
import json
output_dir = f"/data/group1/z40436a/ME/URL_Citation_Classification_Intermediate/icl/{ICL_METHOD}"
with open(f"{output_dir}/{str(seed)}.txt", "w") as jsonl_file:
    for icl in icls:
        json.dump(icl, jsonl_file)
        jsonl_file.write("\n")