## Import the Segmented Corpus

In [None]:
# If these packages are not installed:

# ! pip install git+https://github.com/iinemo/isanlp.git
# ! pip install isanlp_rst
# ! pip install hf_xet

In [None]:
# === Import
# import pandas as pd
import sys
import json
from pathlib import Path

import warnings
warnings.filterwarnings("ignore", message="`encoder_attention_mask` is deprecated")

# === Make the auxiliary modules importable ===
# ROOT is the parent of the current working directory; the helper code is
# looked up in its "src" sub-directory.
ROOT = Path.cwd().parent
SRC = ROOT.joinpath("src").resolve()

# Put SRC at the front of the import search path. The membership check keeps
# repeated runs of this cell from inserting the same entry again.
if not any(entry == str(SRC) for entry in sys.path):
    sys.path.insert(0, str(SRC))

# === import the module for rst work === <--- this is not ready yet
import importlib
import discourse.rst as rst

In [None]:
# === Define the path to the data and the pattern for retrieval ==
# DATA_DIR is resolved with strict=True, so a missing data directory fails
# right here with FileNotFoundError instead of later when the file is opened.
HOME = Path.home()
DATA_DIR = (HOME / "My Drive" / "_VectorData" / "projects" / "identifying_depression_with_rst" / "data").resolve(strict=True)

corpus_path = DATA_DIR / "processed"
corpus_file = corpus_path / "segmented_corpora_320_token_split.json"

# Read the JSON explicitly as UTF-8: without `encoding`, open() falls back to
# the platform's locale encoding, which can mis-decode non-ASCII text.
with open(corpus_file, "r", encoding="utf-8") as file:
    corpora = json.load(file)

In [None]:
# === This is just in case for possible debugging (with verbose output) ===
# import transformers
# transformers.utils.logging.set_verbosity_info()
# transformers.utils.logging.enable_explicit_format()

## Visual Inspection

In [None]:
# === A note on the structure of the resulting segmented corpora: ===
# Each separate corpus in the corpora is the value stored under a key that names this specific corpus (like 'ked' in this case)

corpora.keys()

In [None]:
# === Further down the tree, the structure is as follows: ===
# The value of the key is a list of 2 items,
# where each item is also a list:
#   - the first list is the original text, either as a single list item if it has not been split
#     or as several items, which are the resulting chunks of the splitting pipeline upstream;
#   - the second list is made up of the sentences returned by the sentence tokenizer as list items.
# NOTE(review): the double index used below ([0][0]) suggests the corpus may actually be a list of
# such [chunks, sentences] pairs (one per document) -- confirm against the segmentation step.

# So, pulling the text (or the resulting chunks) for the "ked" corpus looks something like this:

corpora["ked"][0][0]

In [None]:
# Number of top-level items stored under the "ked" key
len(corpora["ked"])

## Prep Everything for Running the RST Parser

In [None]:
# In case we need to reload the module:
# importlib.reload() re-executes the module and returns the module object, which is rebound to `rst` here
rst = importlib.reload(rst)

In [None]:
# === Initialize the Parser
# The two commented-out settings below are kept for reference only;
# nothing is passed to init_parser(), which is called without arguments.
# model = 'tchewik/isanlp_rst_v3'
# version = 'gumrrg'  # Choose from {'gumrrg', 'rstdt', 'rstreebank'}

rst.init_parser()

In [None]:
# Put the target corpus as texts/chunks into a separate variable for easier navigation/iteration logic downstream

CORPUS_NAME = "ked"

# Only element 0 of every entry is kept; the remaining elements of each entry are left out.
corpus = [item[0] for item in corpora[CORPUS_NAME]] # grab only the texts/segments, not the texts as sentences

In [None]:
# Double-check that the structure of the corpus is what the parser would expect
# (shows the first five entries)

corpus[:5]

In [None]:
# Check that the number of segmented texts matches what we got in the previous stage (segmentation / 01_segemnt_documents.ipynb)

# Keep the entries with more than one element -- assuming every entry is a list of chunks,
# these are the texts that were split into several chunks.
find_split_texts = [i for i in corpus if len(i) > 1]
len(find_split_texts)

## Parsing the Selected Corpus

In [None]:
# === This works on the specific corpus from the corpora: ===

#
# The code expects the corpus to be a list of items, where each item is also a list
# containing the whole document as one list item or the document in two or more chunks (as list items)
# e.g. [['doc_1'], ['seg1_of_doc_2', 'seg2_of_doc_2'], ['doc_3'] ... ]
# There is also a safeguard to normalize each item in corpus to a list[str]
# for cases where the structure of the corpus may be like
# ["doc_1", ["seg1_of_doc_2", "seg2_of_doc_2"], "doc_3", ... ]

# NOTE(review): the variable name below is misspelled ("coprus"); it is left as is
# because other cells of this notebook refer to it by exactly this name.
parsed_coprus = rst.parse_corpus(corpus)

## A Bit of Visual Inspection

In [None]:
# Number of items returned by rst.parse_corpus (presumably one per corpus entry -- confirm)
len(parsed_coprus)

In [None]:
# === What does the parser return and how is it structured in our 'segmented' corpus?

# First, each item is a list of at least one item -- a dictionary returned by the parser:
# NOTE(review): despite its name, `parsed_corpus` is bound to the FIRST item of `parsed_coprus` only,
# not to the whole parsed corpus; any other cell that reads `parsed_corpus` sees just this one item
# -- confirm this is intended.
parsed_corpus = parsed_coprus[0]

In [None]:
# Further down the tree, each such dictionary has the key "rst", which stores the results of parsing as a list of one item (in our case)
# This item is the RST object/tree proper
# It can be explored using the 'vars' function

# Show the first element of `parsed_corpus`
parsed_corpus[0]

## Extract the Features

In [None]:
# Extract the RST features with the project helper.
# NOTE(review): the only visible assignments to `parsed_corpus` bind it to `parsed_coprus[0]`
# (a single item) or to a list of placeholders, never to the whole parsed corpus -- confirm that
# the helper is meant to receive this value rather than `parsed_coprus`.
all_features = rst.extract_all_rst_features(parsed_corpus)

In [None]:
# Display the extracted features
all_features

## Follow-up (in the Notebook for Now, in the Module Later)

### Get a List of All Relations --> Module

In [None]:
# Every distinct key that occurs in the "relation_counts" entry of any feature item
all_relations = {
    relation
    for features in all_features
    for relation in features["relation_counts"]
}

In [None]:
# Display the collected set of relation names
all_relations

### Transform the Diagnosis into Labels

In [None]:
# Display everything stored for the selected corpus
corpora[CORPUS_NAME]

In [None]:
# === Transform the diagnoses (now stored as real labels in the 'diagnoses' variable) into ML appropriate labels
# NOTE(review): `diagnoses` and `rst_features` are not assigned in any visible cell of this notebook,
# and `count_relations` is called here while only `_count_relations` is defined (further down in
# this cell) -- confirm where these names are supposed to come from before running.

from sklearn.preprocessing import LabelEncoder

encoder = LabelEncoder()
y_encoded = encoder.fit_transform(diagnoses)

# Split the feature items into two groups by encoded label:
# label 0 goes to the "dep" group, every other label to the "ok" group.
# NOTE(review): which diagnosis ends up encoded as 0 depends on the label values -- check encoder.classes_.
rst_features_dep = []
rst_features_ok = []

for i, features in enumerate(rst_features):
    bucket = rst_features_dep if y_encoded[i] == 0 else rst_features_ok
    bucket.append(features)

# Relation tallies for each of the two groups
relation_counts_dep = count_relations(rst_features_dep)
relation_counts_ok = count_relations(rst_features_ok)

def convert_abs_rel_counts(relation_counts):
    """Convert absolute relation counts into proportions of the overall total.

    Args:
        relation_counts: Iterable of ``(relation, count)`` pairs, e.g. the
            output of ``Counter.most_common()``. It is materialised into a
            list first, so a one-shot iterator is handled correctly too.

    Returns:
        A list of ``(relation, proportion)`` tuples in the input order, where
        ``proportion`` is ``count / total``. An empty input gives an empty
        list; if the counts add up to zero, every proportion is ``0.0``
        instead of the call raising ``ZeroDivisionError``.
    """
    pairs = list(relation_counts)

    # The total is accumulated over int(count), exactly as before.
    total = sum(int(pair[1]) for pair in pairs)

    # Nothing to normalise by: report a zero share for every relation.
    if total == 0:
        return [(pair[0], 0.0) for pair in pairs]

    return [(pair[0], pair[1] / total) for pair in pairs]

# Per-group relation shares: each relation's count divided by the group's total count
relative_relation_counts_dep = convert_abs_rel_counts(relation_counts_dep)
relative_relation_counts_ok = convert_abs_rel_counts(relation_counts_ok)

# Per-item value stored under "causal" in 'relation_proportions' for each group;
# items that have no "causal" key contribute 0.0
causal_rel_props_dep = [item['relation_proportions'].get("causal", 0.0) for item in rst_features_dep]
causal_rel_props_ok = [item['relation_proportions'].get("causal", 0.0) for item in rst_features_ok]

def _count_relations(rst_features_segment):
    relation_counter = Counter()
    for item in rst_features_segment:
        relation_counter.update(item["relation_counts"])

    return relation_counter.most_common()

## Investigate Why A Segment is Reported to be over 512

In [None]:
# Probe every segment with the parser once, with verbose transformers logging switched on,
# and record which (document index, segment index) pairs make it raise.
# NOTE(review): `_as_segments` and `parser` are not defined or imported in the visible cells of this
# notebook -- presumably they live in the rst module; confirm before running.
# NOTE(review): this cell rebinds `parsed_corpus` (assigned in an earlier cell) to a list of None placeholders.
import logging, transformers
transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.enable_explicit_format()

parsed_corpus, offenders = [], []
for di, doc in enumerate(corpus):
    for si, seg in enumerate(_as_segments(doc)):
        # flush=True so the progress line is visible before a potentially slow/failing parse
        print(f"processing doc:{di}, seg:{si}", flush=True)
        try:
            _ = parser(seg)                 # just probe once
            parsed_corpus.append(None)      # we're not keeping results here
        except Exception as e:
            # Best-effort probe: remember the indices and the error text, then carry on
            offenders.append((di, si, str(e)))


In [None]:
from transformers import AutoTokenizer

# Candidate tokenizers; only the uncommented `model_name` is actually loaded
# model_name = "DeepPavlov/rubert-base-cased"
# model_name = "ai-forever/ruBert-base"
model_name = "sberbank-ai/ruBert-large"


tok = AutoTokenizer.from_pretrained(model_name)
# Print the number of input ids (special tokens included) for every normalized segment
# NOTE(review): `normalize_for_parser` is not defined in the visible cells -- confirm its origin.
for item in corpus:
    for segment in item:
        print(len(tok(normalize_for_parser(segment), add_special_tokens=True)["input_ids"]))

In [None]:
# Parse one specific segment (entry 106, element 2 of `corpus`) after normalization
# NOTE(review): `parser` and `normalize_for_parser` are not defined in the visible cells -- confirm their origin.
res_temp = parser(normalize_for_parser(corpus[106][2]))

In [None]:
import re, unicodedata
from transformers import AutoTokenizer

# 1) close-enough tokenizer to what the RST parser likely uses
# (an approximation, so the lengths computed with it are estimates -- TODO confirm the parser's actual tokenizer)
xlmr = AutoTokenizer.from_pretrained("xlm-roberta-base")

def single_len(text: str) -> int:
    """Return how many input ids `xlmr` produces for *text*, special tokens included."""
    encoded = xlmr(text, add_special_tokens=True)
    return len(encoded["input_ids"])

def pair_len(a: str, b: str) -> int:
    """Return the input-id count when *a* and *b* are encoded together as one pair.

    Special tokens are included; this mimics how RoBERTa encodes pairs:
    ``<s>a</s></s>b</s>``.
    """
    encoded_pair = xlmr(a, b, add_special_tokens=True)
    return len(encoded_pair["input_ids"])

def max_pair_len(text: str) -> int:
    """Return the largest `pair_len` found over adjacent sentence pairs of *text*.

    Sentences come from a naive split on whitespace that follows ``.``, ``!``
    or ``?``; blank pieces are dropped. Scanning stops as soon as the running
    maximum reaches 512, so later (possibly longer) pairs are not measured.
    Returns 0 when *text* has fewer than two sentences.
    """
    # super-naive sentence split is fine for triage; plug in your splitter if you want
    pieces = re.split(r'(?<=[.!?])\s+', text)
    sentences = [piece for piece in pieces if piece.strip()]

    longest = 0
    for left, right in zip(sentences, sentences[1:]):
        current = pair_len(left, right)
        if current > longest:
            longest = current
        if longest >= 512:  # early exit
            break
    return longest

# Scan every segment and flag those whose estimated token counts get close to the limit.
# NOTE(review): `_as_segments` and `normalize_for_parser` are not defined in the visible cells -- confirm their origin.
TOKEN_LIMIT = 512     # length from which a single segment is flagged
PAIR_TRIPWIRE = 508   # 508 ≈ practical tripwire for adjacent-sentence pairs

offenders = []
for doc_idx, doc in enumerate(corpus):
    for seg_idx, raw_segment in enumerate(_as_segments(doc)):
        normalized = normalize_for_parser(raw_segment)
        single = single_len(normalized)
        pair = max_pair_len(normalized)
        # A pair length of 512 or more is already caught by the lower 508 tripwire
        if single >= TOKEN_LIMIT or pair >= PAIR_TRIPWIRE:
            offenders.append((doc_idx, seg_idx, single, pair))

print(f"Found {len(offenders)} risky segments (doc_idx, seg_idx, single_len, max_pair_len):")
# Show at most the first 20 flagged segments
for row in offenders[:20]:
    print(row)

## NEXT:
* Walk the RST Trees and Extract the Data (this part is WIP)