# Installs

In [71]:
%load_ext autoreload
%autoreload 2

In [72]:
%%capture
%%bash
rm -rf nse_topic_segmentation
git clone \
    --branch continuous-model-inference \
    https://github.com/tony-pitchblack/NSE-TopicSegmentation.git \
    nse_topic_segmentation

cd nse_topic_segmentation
pip install -r requirements.txt

# Imports

In [116]:
import pandas as pd
import numpy as np
import os
import nltk

os.environ["WANDB_API_KEY"] = "aee284a72205e2d6787bd3ce266c5b9aefefa42c"
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [117]:
import torch
import random
import numpy as np

In [118]:
# Enable deterministic behavior
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [119]:
# # Set seed
SEED = 2025

def reset_all_seeds():
    torch.manual_seed(SEED)
    random.seed(SEED)
    np.random.seed(SEED)

reset_all_seeds()

# Convert HF dataset

In [120]:
#@title HF config
import os
from pathlib import Path

# os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ['HF_TOKEN'] = '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'

from huggingface_hub import create_repo

REPO_NAME = 'news-segmentation-ntv'
REPO_ID = f"tony-pitchblack/{REPO_NAME}"

# REVISION = "dev"
REVISION = None

In [121]:
#@title download_json_dict

import json
from huggingface_hub import hf_hub_download
import os

FILE_NAME_SEGMENTS = "segments_breaks.json"
FILE_NAME_DOWNLOAD_URLS = "download_urls.json"
FILE_NAME_PLAYLIST_METADATA = "playlist_metadata.json"

def download_json_dict(file_name):
    try:
        hf_file_path = hf_hub_download(repo_id=REPO_ID, filename=file_name, repo_type='dataset')
        with open(hf_file_path, "r", encoding='utf-8') as json_file:
            json_dict = json.load(json_file)
    except Exception as e:
        print(f"File {file_name} does not exist.")
        json_dict = {}

    return json_dict

In [122]:
#@title upload_json_dict

def upload_json_dict(json_dict, file_name):
    with open(file_name, "w") as json_file:
        json.dump(json_dict, json_file)

    !huggingface-cli upload {REPO_NAME} {file_name} {file_name} --repo-type=dataset

In [123]:
#@title Check for playlist_metadata in segments

# playlist_metadata_dict = download_json_dict(FILE_NAME_PLAYLIST_METADATA)
# wrong_ids = {id for id, metadata in playlist_metadata_dict.items() if isinstance(metadata, list)}
# print(len(wrong_ids))

# for id in wrong_ids:
#     playlist_metadata_dict.pop(id)

# # upload_json_dict(playlist_metadata_dict, FILE_NAME_PLAYLIST_METADATA)

In [124]:
#@title Check for segments_breaks in playlist_metadata

# segments_breaks_dict = download_json_dict(FILE_NAME_SEGMENTS)
# wrong_ids = {id for id, seg in segments_breaks_dict.items() if isinstance(seg, dict)}
# print(len(wrong_ids))

# for id in wrong_ids:
#     segments_breaks_dict.pop(id)

# # upload_json_dict(segments_breaks_dict, FILE_NAME_SEGMENTS)

In [125]:
%%capture
from huggingface_hub import snapshot_download

# MODEL_NAME = "distil-whisper/distil-large-v2"
# MODEL_NAME = "openai/whisper-large-v2"
MODEL_NAME = "openai/whisper-large-v3"
# MODEL_NAME = "mitchelldehaven/whisper-large-v2-ru"

transcripts_dir = Path('transcripts') / Path(MODEL_NAME).stem
snapshot_download(
    repo_id=REPO_ID, repo_type='dataset',
    allow_patterns=str(transcripts_dir  / "*"),
    local_dir='.',
    revision=REVISION
)

In [126]:
import os
import glob

transcripts_dir = Path('transcripts') / Path(MODEL_NAME).stem
transcribed_files = glob.glob(str(transcripts_dir / '*'))
ids_to_files = {Path(file_name).stem: file_name for file_name in transcribed_files}

len(ids_to_files)

260

In [127]:
segments_breaks_dict = download_json_dict(FILE_NAME_SEGMENTS)

## Load transcripts

In [128]:
#@title Find transcripts with certain text
import re

matched_ids = []
pattern = 'В России установлен абсолютный рекорд по вводу в эксплуатацию нового жилья.'
for transcript_id, transcript_path in ids_to_files.items():
    with open(transcript_path, "r") as json_file:
        transcript = json.load(json_file)
        match = re.search(pattern, transcript['text'])
        if match:
            matched_ids.append(transcript_id)

matched_ids

['752958']

In [129]:
transcript_id = matched_ids[0]
transcript_path = ids_to_files[transcript_id]
with open(transcript_path, "r") as json_file:
    transcript = json.load(json_file)

In [130]:
file_idx = 1

transcript_id, transcript_path = list(ids_to_files.items())[file_idx]
with open(transcript_path, "r") as json_file:
    transcript = json.load(json_file)

low = 10
high = 40

# low = 0
# high = transcript['chunks'][-1]['timestamp'][1]

[
    chunk for chunk in transcript['chunks']
    if chunk['timestamp'][0] >= low and chunk['timestamp'][1] <= high
]

[{'timestamp': [29.14, 29.26],
  'text': ' Здравствуйте на НТВ новости в студии Егор Колыванов.'},
 {'timestamp': [33.94, 34.02],
  'text': ' Сегодня исполняется ровно год с того момента, как в состав России вошли 4 новых региона.'},
 {'timestamp': [39.06, 39.28],
  'text': ' Донецкая и Луганская народные республики, а также Запорожская и Херсонская области.'}]

In [131]:
segments_breaks = segments_breaks_dict[transcript_id]
segments_breaks

[{'start': 29,
  'summary': 'Владимир Путин обратился к россиянам в день принятия новых регионов в состав государства.'},
 {'start': 460,
  'summary': 'Евгению из Новосибирской области, пострадавшему во время украинского обстрела, нужна помощь для оплаты курсов реабилитации.'},
 {'start': 617,
  'summary': 'Госпиталь Пентагона в Германии начал принимать на лечение раненых служащих ВСУ.'},
 {'start': 663,
  'summary': 'США в очередной раз оказались на грани бюджетного шатдауна. Завтра в Штатах начинается очередной финансовый год, однако в Конгрессе не могут согласовать законопроект о финансировании правительства.'},
 {'start': 740,
  'summary': 'Первый центр реабилитации для нерп на Байкале готовятся выпустить подопечных в дикую природу.'}]

In [132]:
#@title generate_transcript_chunk_mask (three chunks)

def generate_transcript_chunk_mask(transcripts, segments):
    mask = [0] * len(transcripts['chunks'])

    for segment in segments:
        segment_start = segment['start']

        # Find the chunk directly spanning the boundary timestamp
        spanning_chunk_index = next(
            (i for i, chunk in enumerate(transcripts['chunks'])
             if chunk['timestamp'][0] <= segment_start <= chunk['timestamp'][1]),
            None
        )

        if spanning_chunk_index is not None:
            # Mark the chunk directly spanning the boundary
            mask[spanning_chunk_index] = 1

            # Mark the chunk before, if it exists
            if spanning_chunk_index > 0:
                mask[spanning_chunk_index - 1] = 1

            # Mark the chunk after, if it exists
            if spanning_chunk_index < len(transcripts['chunks']) - 1:
                mask[spanning_chunk_index + 1] = 1

    return mask

In [133]:
#@title generate_transcript_chunk_mask (two chunks)

def generate_transcript_chunk_mask(transcripts, segments):
    mask = [0] * len(transcripts['chunks'])

    for segment in segments:
        segment_start = segment['start']

        # Calculate the distances to both the start and end of each chunk
        distances = [
            min(
                abs(chunk['timestamp'][0] - segment_start),  # Distance to chunk start
                abs(chunk['timestamp'][1] - segment_start)   # Distance to chunk end
            )
            for chunk in transcripts['chunks']
        ]

        # Find the indices of the two closest chunks
        closest_indices = sorted(range(len(distances)), key=lambda i: distances[i])[:2]

        # Mark these chunks in the mask
        for index in closest_indices:
            mask[index] = 1

    return mask

In [134]:
#@title generate_transcript_chunk_mask

def generate_transcript_chunk_mask(transcripts, segments_breaks, exclude_last_none=False):
    # Exclude last timestamp with Nones from candidates
    if exclude_last_none and transcripts['chunks'][-1]['timestamp'] == [None, None]:
        transcripts['chunks'] = transcripts['chunks'][:-1]
        # print("Excluded last chunk:")
        # print(transcripts['chunks'][-1])

    transcript_timestamps = [chunk['timestamp'][0] for chunk in transcripts['chunks']]
    mask = [0] * len(transcripts['chunks'])

    for seg_idx, segment_break in enumerate(segments_breaks):
        segment_start = segment_break['start']
        closest_timestamp_index = min(
            range(len(transcript_timestamps)),
            key=lambda i: abs(transcript_timestamps[i] - segment_start)
        )

        mask[closest_timestamp_index] = 1

    return np.array(mask, dtype=bool)

In [135]:
seg_mask_chunk = generate_transcript_chunk_mask(transcript, segments_breaks)
len(transcript['chunks']), len(seg_mask_chunk), sum(seg_mask_chunk)

(182, 182, 5)

In [136]:
import numpy as np

np.array(transcript['chunks'])[seg_mask_chunk]

array([{'timestamp': [29.14, 29.26], 'text': ' Здравствуйте на НТВ новости в студии Егор Колыванов.'},
       {'timestamp': [462.54, 468.16], 'text': ' произошло, воссоединение продолжается. На ее передовых рубежах с первых дней был Евгений Корнеев.'},
       {'timestamp': [617.5, 623.38], 'text': ' Госпиталь Пентагона на территории Германии начал принимать на лечение военных, получивших ранения на Украине.'},
       {'timestamp': [667.66, 670.46], 'text': ' Завтра в Штатах начинается очередной финансовый год,'},
       {'timestamp': [742.5, 744.28], 'text': ' Первый реабилитационный'}],
      dtype=object)

In [137]:
#@title get_sentences_and_segments (multiple sentences per chunk start)

def get_sentences_and_segments(transcripts, boundary_mask):
    import re

    sentences = []
    segment_boundaries = []

    transcript_chunks = transcripts['chunks']
    current_sentence = ""
    current_boundary_flag = False
    prev_boundary_flag = False

    for i, chunk in enumerate(transcript_chunks):
        text = chunk['text']

        # Split the chunk into potential sentences
        chunk_sentences = re.split(r'(?<=[.!?])\s+', text.strip())

        for j, partial_sentence in enumerate(chunk_sentences):
            if current_sentence:
                current_sentence += " " + partial_sentence
            else:
                current_sentence = partial_sentence

            # If the chunk is marked as a boundary, set the flag
            if boundary_mask[i] == 1:
                current_boundary_flag = True

            # If the partial sentence ends a sentence
            if re.search(r'[.!?]$', partial_sentence.strip()):
                sentences.append(current_sentence.strip())

                # Mark all finished sentences within a boundary chunk as boundary sentences
                if current_boundary_flag or prev_boundary_flag:
                    boundary_mark = 1
                    prev_boundary_flag = False
                else:
                    boundary_mark = 0

                segment_boundaries.append(boundary_mark)
                current_sentence = ""

        if current_boundary_flag and current_sentence:
            prev_boundary_flag = True

        # Reset the boundary flag after processing a chunk
        current_boundary_flag = False

    return sentences, np.array(segment_boundaries)

In [138]:
#@title get_sentences_and_segments
import re

def get_sentences_and_segments(transcripts, boundary_mask):
    sentences = []
    segment_boundaries = []

    transcript_chunks = transcripts['chunks']
    current_sentence = ""
    current_boundary_flag = False
    prev_boundary_flag = False

    for i, chunk in enumerate(transcript_chunks):
        text = chunk['text']

        # Split the chunk into potential sentences
        chunk_sentences = re.split(r'(?<=[.!?])\s+', text.strip())

        # If the chunk is marked as a boundary, set the flag
        if boundary_mask[i] == 1:
            current_boundary_flag = True


        for j, partial_sentence in enumerate(chunk_sentences):
            if current_sentence:
                current_sentence += " " + partial_sentence
            else:
                current_sentence = partial_sentence

            # If the partial sentence ends a sentence
            if re.search(r'[.!?]$', partial_sentence.strip()):
                sentences.append(current_sentence.strip())
                if current_boundary_flag or prev_boundary_flag:
                    boundary_mark = 1
                    current_boundary_flag = False
                    prev_boundary_flag = False
                else:
                    boundary_mark = 0

                segment_boundaries.append(boundary_mark)
                current_sentence = ""

        if current_boundary_flag and current_sentence:
            prev_boundary_flag = True

    return sentences, np.array(segment_boundaries)

In [139]:
#@title get_sentences_and_segments (w/timestamps)
import re
import numpy as np

def get_sentences_and_segments(transcripts, boundary_mask):
    sentences = []
    segment_boundaries = []

    transcript_chunks = transcripts['chunks']
    current_sentence = ""
    current_boundary_flag = False
    prev_boundary_flag = False
    current_start_time = None

    for i, chunk in enumerate(transcript_chunks):
        text = chunk['text']
        chunk_start_time, chunk_end_time = chunk['timestamp']

        if current_start_time is None:
            current_start_time = chunk_start_time

        # Split the chunk into potential sentences
        chunk_sentences = re.split(r'(?<=[.!?])\s+', text.strip())

        if boundary_mask[i] == 1:
            current_boundary_flag = True

        for j, partial_sentence in enumerate(chunk_sentences):
            if current_sentence:
                current_sentence += " " + partial_sentence
            else:
                current_sentence = partial_sentence

            # If the partial sentence ends a sentence
            if re.search(r'[.!?]$', partial_sentence.strip()):
                sentences.append({
                    "timestamp": (current_start_time, chunk_end_time),
                    "text": current_sentence.strip(),
                })

                # Add boundary information
                if current_boundary_flag or prev_boundary_flag:
                    segment_boundaries.append(1)
                    current_boundary_flag = False
                    prev_boundary_flag = False
                else:
                    segment_boundaries.append(0)

                current_sentence = ""
                current_start_time = None

        if current_boundary_flag and current_sentence:
            prev_boundary_flag = True

    # Handle any remaining unfinished sentence
    if current_sentence and current_start_time is not None:
        sentences.append({
            "timestamp": (current_start_time, chunk_end_time),
            "text": current_sentence.strip(),
        })
        segment_boundaries.append(0)

    return sentences, np.array(segment_boundaries)


In [140]:
sentences, seg_mask_sentences = get_sentences_and_segments(transcript, seg_mask_chunk)
len(sentences), len(seg_mask_sentences), sum(seg_mask_sentences), len(segments_breaks)

(156, 156, 5, 5)

In [141]:
np.array(sentences)[seg_mask_sentences == 1]

array([{'timestamp': (29.14, 29.26), 'text': 'Здравствуйте на НТВ новости в студии Егор Колыванов.'},
       {'timestamp': (462.54, 468.16), 'text': 'Специальная военная операция, благодаря которой произошло, воссоединение продолжается.'},
       {'timestamp': (617.5, 623.38), 'text': 'Госпиталь Пентагона на территории Германии начал принимать на лечение военных, получивших ранения на Украине.'},
       {'timestamp': (667.66, 674.86), 'text': 'Завтра в Штатах начинается очередной финансовый год, однако в Конгрессе не могут согласовать законопроект о финансировании правительства.'},
       {'timestamp': (742.5, 749.06), 'text': 'Первый реабилитационный центр для НЕРП на Байкале готовится выпустить подопечных в дикую природу.'}],
      dtype=object)

## View excessive segments

In [142]:
#@title get_excessive_boundaries
import numpy as np

# Example segmentation mask
# segment_boundaries = np.array([0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0])

def get_excessive_boundaries(segment_boundaries):
    # Step 1: Get indices of boundary sentences
    boundary_indices = np.argwhere(segment_boundaries == 1).flatten()

    # Step 2: Identify and group continuous segments
    gaps = np.diff(boundary_indices) != 1
    group_boundaries = np.where(gaps)[0] + 1
    groups = np.split(boundary_indices, group_boundaries)

    # Step 3: Filter groups with length > 1
    excessive_boundaries = [group for group in groups if len(group) > 1]

    # Output
    return excessive_boundaries

excessive_boundaries = get_excessive_boundaries(seg_mask_sentences)

In [143]:
print(transcript_path)
excessive_boundaries

transcripts/whisper-large-v3/743678.json


[]

In [144]:
from pprint import pprint

for idx, excessive_bdr in enumerate(excessive_boundaries):
    print(f"Excessive boundary: {idx}")
    sentence_list = np.array(sentences)[excessive_bdr].tolist()
    for sentence in sentence_list:
        pprint(sentence)
    print()

In [145]:
import numpy as np

np.array(sentences)[np.array(seg_mask_sentences, dtype=bool)]

array([{'timestamp': (29.14, 29.26), 'text': 'Здравствуйте на НТВ новости в студии Егор Колыванов.'},
       {'timestamp': (462.54, 468.16), 'text': 'Специальная военная операция, благодаря которой произошло, воссоединение продолжается.'},
       {'timestamp': (617.5, 623.38), 'text': 'Госпиталь Пентагона на территории Германии начал принимать на лечение военных, получивших ранения на Украине.'},
       {'timestamp': (667.66, 674.86), 'text': 'Завтра в Штатах начинается очередной финансовый год, однако в Конгрессе не могут согласовать законопроект о финансировании правительства.'},
       {'timestamp': (742.5, 749.06), 'text': 'Первый реабилитационный центр для НЕРП на Байкале готовится выпустить подопечных в дикую природу.'}],
      dtype=object)

In [146]:
segments_breaks

[{'start': 29,
  'summary': 'Владимир Путин обратился к россиянам в день принятия новых регионов в состав государства.'},
 {'start': 460,
  'summary': 'Евгению из Новосибирской области, пострадавшему во время украинского обстрела, нужна помощь для оплаты курсов реабилитации.'},
 {'start': 617,
  'summary': 'Госпиталь Пентагона в Германии начал принимать на лечение раненых служащих ВСУ.'},
 {'start': 663,
  'summary': 'США в очередной раз оказались на грани бюджетного шатдауна. Завтра в Штатах начинается очередной финансовый год, однако в Конгрессе не могут согласовать законопроект о финансировании правительства.'},
 {'start': 740,
  'summary': 'Первый центр реабилитации для нерп на Байкале готовятся выпустить подопечных в дикую природу.'}]

# Inference

### Load dataloader

In [147]:
from scipy.ndimage import shift

BOUNDARY_STARTS_SEGMENT = False
if BOUNDARY_STARTS_SEGMENT:
    target = seg_mask_sentences
else:
    target = shift(seg_mask_sentences, -1, cval=0)

In [148]:
sentences_texts = [sentence['text'] for sentence in sentences]
docs_with_targets = [
    [sentences_texts, target]
]

len(docs_with_targets[0][0]), len(docs_with_targets[0][1])

(156, 156)

In [149]:
from nse_topic_segmentation.models.EncoderDataset import SentenceDataset
import nltk

tag_to_ix = {0:0, 1:1, '<START>':2, '<STOP>':3} # обязательно
dataset = SentenceDataset(docs_with_targets, tag_to_ix, encoder="cointegrated/rubert-tiny2")

Computing embeddings on cpu.


Computing embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.09it/s]

Embeddings computation | total processing time: 0:00:01





In [150]:
#@title gather_dataset
from nse_topic_segmentation.models.EncoderDataset import SentenceDataset
from tqdm.notebook import tqdm

def gather_dataset(transcripts_dir, boundary_starts_segment, max_files=None):
    assert isinstance(boundary_starts_segment, bool)

    segments_breaks_dict = download_json_dict(FILE_NAME_SEGMENTS)

    transcribed_files = glob.glob(str(transcripts_dir / '*'))
    transcribed_files = transcribed_files[:max_files]

    ids_to_files = {Path(file_name).stem: file_name for file_name in transcribed_files}

    docs_with_targets = []
    for transcript_id, transcript_path in tqdm(ids_to_files.items()):
        segments_breaks = segments_breaks_dict[transcript_id]

        with open(transcript_path, "r") as json_file:
            transcript = json.load(json_file)

        seg_mask_chunk = generate_transcript_chunk_mask(transcript, segments_breaks, exclude_last_none=True)
        sentences, seg_mask_sentences = get_sentences_and_segments(transcript, seg_mask_chunk)

        if boundary_starts_segment:
            target = seg_mask_sentences
        else:
            target = shift(seg_mask_sentences, -1, cval=0)

        docs_with_targets.append([sentences, target])

    tag_to_ix = {0:0, 1:1, '<START>':2, '<STOP>':3}
    dataset = SentenceDataset(docs_with_targets, tag_to_ix, encoder="cointegrated/rubert-tiny2")

    return dataset, transcribed_files

In [151]:
#@title [DEPRECATED] fix ids with None in last timestamp

# segments_breaks_dict = download_json_dict(FILE_NAME_SEGMENTS)

# transcribed_files = glob.glob(str(transcripts_dir / '*'))
# ids_to_files = {Path(file_name).stem: file_name for file_name in transcribed_files}

# ids_with_none_ts = list()
# for transcript_id, transcript_path in tqdm(ids_to_files.items()):
#     segments_breaks = segments_breaks_dict[transcript_id]
#     with open(transcript_path, "r") as json_file:
#         transcript = json.load(json_file)

#     if transcript['chunks'][-1]['timestamp'] == [None, None]:
#         ids_with_none_ts.append(transcript_id)

# len(ids_with_none_ts)

# # Replace None timestamp
# id = ids_with_none_ts[-1]
# transcript_path = ids_to_files[id]
# with open(transcript_path, "r") as json_file:
#     transcript = json.load(json_file)

# second_last_chunk = transcript['chunks'][-2]
# last_chunk = transcript['chunks'][-1]

# ts = second_last_chunk['timestamp'][-1]
# last_chunk['timestamp'] = [ts+1, ts+5]

In [152]:
transcripts_dir = Path('transcripts') / Path(MODEL_NAME).stem
_, transcribed_files = gather_dataset(transcripts_dir, boundary_starts_segment=False, max_files=10)

  0%|          | 0/10 [00:00<?, ?it/s]

Computing embeddings on cpu.


Computing embeddings: 100%|██████████| 10/10 [00:04<00:00,  2.06it/s]

Embeddings computation | total processing time: 0:00:05





In [153]:
DOC_IDX = 0

dataset[DOC_IDX].keys()
# dataset[DOC_IDX]['source']
# dataset[DOC_IDX]['embeddings'].shape
# dataset[DOC_IDX]['target']

dict_keys(['id', 'source', 'target', 'embeddings'])

In [154]:
from torch.utils.data import DataLoader

BATCH_SIZE = 8

test_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=dataset.collater)
batch = next(iter(test_dataloader))
batch

{'id': tensor([0]),
 'src_tokens': tensor([[[ 0.1258,  0.0432,  0.1385,  ..., -0.0014,  0.1130,  0.0137],
          [-0.0050,  0.0430, -0.0055,  ...,  0.0251,  0.0383, -0.0382],
          [ 0.1116, -0.0600,  0.0206,  ..., -0.0044,  0.0769, -0.0576],
          ...,
          [ 0.0049,  0.0539,  0.0302,  ...,  0.0671, -0.0066, -0.0681],
          [ 0.0283,  0.0074,  0.0125,  ..., -0.0189,  0.1123, -0.0819],
          [ 0.0248,  0.0352, -0.0146,  ...,  0.0402,  0.0204, -0.0698]]]),
 'src_lengths': tensor([156]),
 'tgt_tokens': tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.,

In [155]:
batch['tgt_tokens'].shape

torch.Size([1, 156])

### Load ria

In [None]:
from nse_topic_segmentation.utils.load_datasets import load_dataset

max_docs_count = 20
ds_corus = load_dataset('ria', corus=True, max_docs_cnt=max_docs_count, segments_per_doc=3)

Loading <utils.corus_dataset.CorusDataset object at 0x79821a0d94b0> dataset...
Restricted to 20 documents


Collecting docs:: 100%|██████████| 20/20 [00:00<00:00, 530.38it/s]

Collecting documents | total processing time: 0:00:00





In [None]:
# TODO: remove docs_with_targets and iterate over ds_corus in SentenceDataset

docs_with_targets = []
test_split = ds_corus[0][1]
for doc, target, _ in test_split:
    docs_with_targets.append([doc, target])

len(docs_with_targets)

4

In [None]:
doc_with_target = docs_with_targets[0]

len(doc_with_target[0]), len(doc_with_target[1])

(35, 35)

In [None]:
from nse_topic_segmentation.models.EncoderDataset import SentenceDataset

tag_to_ix = {0:0, 1:1, '<START>':2, '<STOP>':3}
dataset = SentenceDataset(docs_with_targets, tag_to_ix, encoder="cointegrated/rubert-tiny2")

Computing embeddings on cpu.


Computing embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.93it/s]

Embeddings computation | total processing time: 0:00:01





In [None]:
from torch.utils.data import DataLoader

test_dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collater)
batch = next(iter(test_dataloader))
batch.keys()

dict_keys(['id', 'src_tokens', 'src_lengths', 'tgt_tokens', 'src_sentences', 'src_segments'])

In [None]:
batch['tgt_tokens'].shape

torch.Size([4, 59])

### Load pretrained model

In [4]:
import os
import wandb

api = wandb.Api()
artifact = api.artifact('overfit1010/lenta_BiLSTM_F1/model-k4j7vuo7:v0', type='model')
art_dir = artifact.download()
ckpt_path = os.path.join(art_dir, 'model.ckpt')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [73]:
from nse_topic_segmentation.models.lightning_model import TextSegmenter

text_seg_model = TextSegmenter.load_from_checkpoint(ckpt_path).to('cpu')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [74]:
from nse_topic_segmentation.models.EncoderDataset import Predictor

predictor_model = Predictor(
    text_seg_model,
    sentence_encoder="cointegrated/rubert-tiny2",
    persistent=False
)

predictor_model

<nse_topic_segmentation.models.EncoderDataset.Predictor at 0x7c62a2dc8a90>

In [100]:
%%time
# Run inference

reset_all_seeds()

sentences_texts = [sentence['text'] for sentence in sentences]
doc = sentences_texts
predictions = predictor_model.predict([doc], pretokenized_sents=[doc])
predictions_single_pass = predictions[0]

for k, v in predictions_single_pass.items():
    print(k, len(v), type(v))

print()

init_hidden_states
segments 4 <class 'list'>
boundaries 1 <class 'list'>
scores 1 <class 'torch.Tensor'>
embeddings 3 <class 'list'>

CPU times: user 1.63 s, sys: 6.46 ms, total: 1.63 s
Wall time: 1.66 s


In [101]:
def format_boundary_list(boundary_list):
    boundary_mask = np.array(boundary_list)
    boundary_idx = np.where(boundary_mask)[0]
    return boundary_mask, boundary_idx

boundary_mask, boundary_idx = format_boundary_list(predictions_single_pass['boundaries'][0])
boundary_idx

array([ 59,  97, 107])

In [102]:
persistent_predictor_model = Predictor(
    text_seg_model,
    sentence_encoder="cointegrated/rubert-tiny2",
    persistent=True
)

persistent_predictor_model

<nse_topic_segmentation.models.EncoderDataset.Predictor at 0x7c629d868340>

In [108]:
%%time
from tqdm.notebook import tqdm

reset_all_seeds()

persistent_predictor_model.model.model.model.reset_hidden_states()

single_sentence_docs = [[sentence['text']] for sentence in sentences]
# single_sentence_docs = single_sentence_docs[:40]

predictions_one_by_one_list = []
for doc in tqdm(single_sentence_docs):
    predictions = persistent_predictor_model.predict([doc], pretokenized_sents=[doc])
    predictions_one_by_one_list.append(predictions[0])

reset_hidden_states


  0%|          | 0/156 [00:00<?, ?it/s]

init_hidden_states
CPU times: user 2.52 s, sys: 63.1 ms, total: 2.59 s
Wall time: 3.41 s


In [109]:
def merge_predictions(*dicts):
    merged_dict = {}
    for key in dicts[0]:
        values = [d[key] for d in dicts]
        if key == 'segments' or key == 'embeddings':
            merged_dict[key] = sum(values, [])
        elif key == 'boundaries':
            merged_dict[key] = [
                sum([value[0] for value in values], [])
            ]
        elif key == 'scores':
            merged_dict[key] = torch.cat(values, dim=1)
        else:
            raise KeyError

    return merged_dict

predictions_one_by_one = merge_predictions(*predictions_one_by_one_list)

for k, v in predictions_one_by_one.items():
    print(k, len(v), type(v))

print()

boundary_mask, boundary_idx = format_boundary_list(predictions_one_by_one['boundaries'][0])
boundary_idx

segments 157 <class 'list'>
boundaries 1 <class 'list'>
scores 1 <class 'torch.Tensor'>
embeddings 1 <class 'list'>



array([101])

In [50]:
from datetime import timedelta

for idx, (sentence, boundary_flag) in enumerate(zip(sentences, boundary_mask)):
    text = sentence['text']

    start = sentence['timestamp'][0]
    end = sentence['timestamp'][1]

    if start is None:
        start = prev_start + 1 # HACK (need to estimate)

    if end is None:
        end = start + 1 # HACK (need to estimate)

    prev_start = start

    print(
        f"[{timedelta(seconds=start)} - {timedelta(seconds=end)}] ",
        "BOUNDARY: " if boundary_flag else 'INNER: ',
        text,
        sep=''
    )

[0:00:00 - 0:00:26.100000] INNER: МУЗЫКАЛЬНАЯ ЗАСТАВКА НЕРП Влада Копыловская наблюдала за животными в бассейне и в дикой природе.
[0:00:29.140000 - 0:00:29.260000] INNER: Здравствуйте на НТВ новости в студии Егор Колыванов.
[0:00:33.940000 - 0:00:34.020000] INNER: Сегодня исполняется ровно год с того момента, как в состав России вошли 4 новых региона.
[0:00:39.060000 - 0:00:39.280000] INNER: Донецкая и Луганская народные республики, а также Запорожская и Херсонская области.
[0:00:45.960000 - 0:00:47.880000] INNER: Жители этих регионов высказали свое мнение относительно воссоединения с нашей страной на референдуме.
[0:00:47.880000 - 0:00:49.740000] INNER: Глава государства поздравил граждан России с этой исторической датой.
[0:00:49.740000 - 0:00:56.560000] INNER: По словам Владимира Путина, решение быть со своим отечеством осознанное, долгожданное, выстраданное и подлинно народное.
[0:00:57.560000 - 0:01:01.240000] INNER: Ничто и никто не в силах сломить волю миллионов людей.
[0:01:02

In [64]:
for (sentence, timestamp), boundary_flag in zip(sentences, boundary_mask):
    print(
        "[BOUNDARY]" if boundary_flag else '',
        f"Sentence: {sentence} (Start: {timestamp['start']:.2f}s, End: {timestamp['end']:.2f}s)",
        sep=''
    )

TypeError: string indices must be integers

In [None]:
first_n = 10

boundary_mask = predictions[0]['boundaries'][0]
# boundary_mask = boundary_mask[:first_n]
boundary_mask = np.array(boundary_mask)
boundary_idx = np.where(boundary_mask)[0]
boundary_idx

In [None]:
if len(boundary_idx) > 0:
    print(f"Boundaries detected: {boundary_idx}")

### Eval

In [None]:
from models.CRF import BiLSTM

import warnings
from sklearn.metrics import f1_score
from sklearn.exceptions import UndefinedMetricWarning

warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

In [None]:
from pytorch_lightning import Trainer

trainer = Trainer()
res = trainer.test(text_seg_model, test_dataloader)
res

## Visualize

### Funcs

In [None]:
def recover_segments(sentences, seg_bounds):
    segments = []
    last_bound = 0

    # append segments from 0 to n-1
    for sent_idx, (is_bound) in enumerate(seg_bounds):
        if is_bound:
            segment = sentences[last_bound:sent_idx+1]
            segments.append(segment)
            last_bound = sent_idx+1

    # append last segment
    segment = sentences[last_bound:]
    segments.append(segment)

    return segments

In [None]:
from pprint import pprint

def print_segments(sections, headings=None, scores=None):
    print(f"SECTION COUNT: {len(sections)}")
    for i, section in enumerate(sections):
        print(f'\n-- SECTION {i+1} START --')
        if headings is not None:
            print(f'-- HEADING: {headings[i]}')
        pprint(section)
        print('-- SECTION END --')

### Print sentences

In [None]:
batch_idx = 0

for _ in range(batch_idx + 1):
    batch = next(iter(test_dataloader))

scores, target = text_seg_model.predict_step(batch, 0)

# text_seg_model.threshold = 0.5
# text_seg_model.test_step(batch, 0)

In [None]:
doc_idx = 0

tgt_segments = recover_segments(batch['src_sentences'][doc_idx], batch['tgt_tokens'][doc_idx])
print("tgt_segments:", len(tgt_segments))

pred_segments = recover_segments(batch['src_sentences'][doc_idx], target[doc_idx])
print("pred_segments:", len(pred_segments))

tgt_segments: 11
pred_segments: 11


In [None]:
N = len(transcribed_files)
transcribed_files[batch_idx * BATCH_SIZE + doc_idx]

'transcripts/whisper-large-v3/745878.json'

In [None]:
print_segments(tgt_segments)

SECTION COUNT: 11

-- SECTION 1 START --
['МУЗЫКАЛЬНАЯ ЗАСТАВКА Ну а также незаменимые помощники в зоне СВО.',
 'Илья Ушенин на производстве дронов Камикадзе.',
 'Сторонники левых с огоньком отметили победу на выборах в парламент Франции.',
 'Ну а также это у нас семейные.',
 'Сегодня станут известны имена победителей всероссийского конкурса.',
 'Это программа сегодня в студии.',
 'Дмитрий Завойстый, здравствуйте.']
-- SECTION END --

-- SECTION 2 START --
['Сотрудники ФСБ пресекли попытку угона на Украину стратегического '
 'бомбардировщика Ту-22М3.',
 'По данным ведомства, киевская разведка попыталась завербовать российского '
 'военного летчика.',
 'Обещала ему 3 миллиона долларов и итальянское гражданство, если доставить '
 'самолет на Украину.',
 'Написал мне в телеграм как-то неизвестный.',
 'Ни морали, ни этики.',
 'Сразу начался угроза в адрес моих близких родственников.',
 'Требовал поджечь авиационную технику.',
 'Говорит, дай мне данные по самолетам, бортовые номера, техниче

In [None]:
print_segments(pred_segments)

SECTION COUNT: 15

-- SECTION 1 START --
['МУЗЫКАЛЬНАЯ ЗАСТАВКА Подтяжка самая широкая.',
 'Сегодня Владимир Путин посетил избирательный штаб.',
 'О чем ему рассказали соприседатели штаба?',
 'Репортаж Романа Соболя.',
 'Рекорд по строительству жилья, рост экономики и доходов граждан Владимир '
 'Путин провел первое в этом году совещание с правительством.',
 'Дорожный коллапс.',
 'О том, как метель парализовала движение в Поволжье.',
 'Михаил Чернов.',
 'Все орудия на передовой надежно укрыты маскировочными сетями.',
 'За что костромские артиллеристы-десантники,ны школьникам из города Стаханов '
 'узнал Алексей Ивлеев.',
 '7 миллиардов евро на оружие для Украины по звонку из Вашингтона.',
 'Это на фоне экономического кризиса и неутвержденного бюджета.',
 'Сергей Холошевский о том, как власти Германии собственными руками гробят '
 'свою страну.',
 'МУЗЫКАЛЬНАЯ ЗАСТАВКА Здравствуйте!',
 'Вас приветствует информационная служба телекомпании НТВ.',
 'Это программа «Сегодня», ее ведущие Эльм

# Test segeval metrics

In [None]:
%%capture
!pip install segeval

In [None]:
import segeval
dataset = segeval.HEARST_1997_STARGAZER
seg1 = dataset['stargazer']['1']
seg2 = dataset['stargazer']['2']

segeval.boundary_similarity(seg1, seg2)

Decimal('0.5')

In [None]:
#@title binary_mask_to_segeval_format
def to_segeval(binary_mask):
    """
    Convert a binary mask of boundaries into the segeval dataset format.

    Args:
        binary_mask (list): A binary mask list where 1 indicates a boundary and
                            0 indicates continuation of a segment.

    Returns:
        list: A list of integers representing segment lengths.
    """
    segment_lengths = []
    current_length = 1  # Start with the first segment having a length of at least 1.

    for i in range(1, len(binary_mask)):
        if binary_mask[i] == 1:  # Boundary found
            segment_lengths.append(current_length)
            current_length = 1  # Reset for the next segment
        else:
            current_length += 1  # Continue the current segment

    # Append the final segment's length
    segment_lengths.append(current_length)

    return segment_lengths

In [None]:
#@title generate_offset_binary_masks
import numpy as np

def generate_offset_binary_masks(length, segments):
    """
    Generate two binary masks of the same length with equally spaced segments,
    where the second mask's boundaries are offset by one position compared to the first.

    Args:
        length (int): Length of the binary masks.
        segments (int): Number of segments to divide the masks into.

    Returns:
        tuple: Two binary masks as lists of integers (0 or 1).
    """
    # Ensure there are at least enough positions for the number of segments
    if segments > length:
        raise ValueError("Number of segments cannot exceed the length of the binary mask.")

    # Calculate equally spaced boundaries for the first mask
    segment_size = length // segments
    boundaries1 = [0] * length
    for i in range(segment_size, length, segment_size):
        boundaries1[i] = 1

    # Offset the boundaries in the second mask by one position
    boundaries2 = [0] * length
    for i in range(segment_size - 1, length - 1, segment_size):
        boundaries2[i] = 1

    return boundaries1, boundaries2


In [None]:
length = 700
segments_breaks = 8
mask1, mask2 = generate_offset_binary_masks(length, segments_breaks)
print("Mask 1:", mask1)
print("Mask 2:", mask2)

Mask 1: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [None]:
# mask1 = [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]
# mask2 = [0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]

In [None]:
from segeval import convert_nltk_to_masses

seg1 = convert_nltk_to_masses(''.join([str(i) for i in mask1]))
seg2 = convert_nltk_to_masses(''.join([str(i) for i in mask2]))
seg1, seg2

((88, 87, 87, 87, 87, 87, 87, 87, 4), (87, 87, 87, 87, 87, 87, 87, 87, 5))

In [None]:
metrics = [
    segeval.pk,
    segeval.window_diff,
    segeval.boundary_similarity,
    segeval.segmentation_similarity
]

results = {}
for metric in metrics:
    results.update({metric.__name__: metric(seg1, seg2)})

results

{'pk': Decimal('0.02265861027190332326283987915'),
 'window_diff': Decimal('0.02265861027190332326283987915'),
 'boundary_similarity': Decimal('0.5'),
 'segmentation_similarity': Decimal('0.9942857142857142857142857143')}