In [1]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as PACK, pad_packed_sequence as PAD
from torch.nn.utils.rnn import pad_sequence
from abc import ABC, abstractmethod
import nltk

In [5]:
# !rm -rf ./text-segmentation
# !git clone https://github.com/koomri/text-segmentation.git
!rm -rf ./data
!mkdir -p ./data/choi
!cp -r ./text-segmentation/data/choi ./data/

In [2]:
from transformers import AutoTokenizer, AutoModel

# SMALL SBERT
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
nse_model = AutoModel.from_pretrained("cointegrated/rubert-tiny2")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def batch_calc_docs_embs(batch_docs):
        list_docs_embs = []
        for doc in batch_docs:
            tokenized_docs = tokenizer(
                doc,
                padding=True,
                truncation=True,
                return_tensors='pt',
                return_token_type_ids=False,
                return_attention_mask=False
                )

            with torch.no_grad():
                tokenized_docs = {k: v.to(nse_model.device) for k, v in tokenized_docs.items()}
                model_output = nse_model(**tokenized_docs)

            docs_embs = 0
            docs_embs = model_output.last_hidden_state[:, 0, :]
            docs_embs = torch.nn.functional.normalize(docs_embs)
            list_docs_embs.append(docs_embs)
        batch_docs_embs = pad_sequence(list_docs_embs, batch_first=True)
        return batch_docs_embs

sample_text = """We use the Pk metric as defined in Beeferman
et al. (1999) to evaluate the performance of our
model. Pk is the probability that when passing a
sliding window of size k over sentences, the sentences at the boundaries of the window will be incorrectly classified as belonging to the same segment (or vice versa). To match the setup of Chen
et al. (2009), we also provide the Pk metric for a
sliding window over words when evaluating on the
datasets from their paper"""
sent_detector = nltk.data.load('tokenizers/punkt/russian.pickle')
sample_sents = sent_detector.tokenize(sample_text)
sample_sents = [sample_sents]
sample_lengths = [len(s) for s in sample_sents]
sample_lengths = torch.LongTensor(sample_lengths)
sample_embs = batch_calc_docs_embs(sample_sents)
sample_targets = torch.zeros(1, len(sample_sents[0]))
sample_targets[:, sample_targets.shape[1]//2] = 1 # split into two segments in the middle

# create dummy batch of copies
batch_size = 2
sample_embs = sample_embs.expand(batch_size, -1, -1)
sample_targets = sample_targets.expand(batch_size, -1)
sample_lengths = sample_lengths.expand(batch_size)

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


In [4]:
%load_ext autoreload
%autoreload 2

In [4]:
filepath = './train_fit.py'
# ! chmod 755 {filepath}
exp_name = 'choi_test'
!rm -rf {exp_name}
# !cd {nse_topseg_path} && python train_fit.py --dataset choi -exp {exp_name}
! python {filepath} \
    --dataset choi -exp {exp_name} \
    --wandb --wandb_key aee284a72205e2d6787bd3ce266c5b9aefefa42c \
    --online_encoding --metric='F1' --verbose \
    --encoder="cointegrated/rubert-tiny2" \
    --hidden_units=256 --num_layers=2 -lr=0.001 -bs=8

540.63s - pydevd: Sending message related to process being replaced timed-out after 5 seconds
545.76s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


2024-06-09 12:55:10.122977: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2024-06-09 12:55:10.123039: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
[nltk_data] Downloading package punkt to /home/brazen/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt to /home/brazen/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt to /home/brazen/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt to /home/brazen/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[34m[1mwandb[0m: Currently logged in as: [33mtony-pitchblack[0m ([33moverfit1010[0m). Use [1m`wandb login --relogin`[0m to force re