In [61]:
from torch.backends.cuda import allow_fp16_bf16_reduction_math_sdp

from core.PC_NET import PCNet
import torch
from torch.utils.data import DataLoader
from core.dataset import collate_fn
import yaml
from sklearn.metrics import f1_score, classification_report
from core.config import punct_label2id

model = PCNet.load_from_checkpoint('../logs_all/final_model_2_layer_Stable_adamw_lr_5e-5/PCNet_lr5e-5_layers2/best-checkpoint.ckpt')

with open('../scripts/always_predictable.yml', 'r') as file:
    yml_file = yaml.safe_load(file)

always_capital = yml_file["always_capitalized_tokens"]
always_period = yml_file["always_period_abbreviations"]

dataset = torch.load('../data/pt_datasets/test.pt')
dataloader = DataLoader(dataset, batch_size=512, shuffle=False, collate_fn=collate_fn)

  dataset = torch.load('../data/pt_datasets/test.pt')


In [57]:
def post_processing(cap_label_pred, subword_tokens, punct_labels_pred, input_ids):
    for batch_num in range(len(subword_tokens)):
        for word_num in range(1, len(subword_tokens[batch_num] ) - 1):
            if input_ids[batch_num][word_num] != input_ids[batch_num][word_num -1] and input_ids[batch_num][word_num] != input_ids[batch_num][word_num +1]:
                if subword_tokens[batch_num][word_num] in always_capital:
                    cap_label_pred[batch_num][word_num] = 1
                if subword_tokens[batch_num][word_num] in always_period:
                    punct_labels_pred[batch_num][word_num] = punct_label2id["."]

    return cap_label_pred, punct_labels_pred


In [58]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cap_preds = []
cap_preds_processed = []
cap_labels_all = []

punct_labels_all = []
punct_preds_ = []
punct_preds_processed = []

for batch in dataloader:
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    punct_labels = batch["punct_labels"].to(device)
    cap_labels = batch["cap_labels"].to(device)
    subword_tokens = batch["subword_tokens"]

    punct_logits, cap_logits = model(input_ids, attention_mask)
    cap_labels_pred = torch.argmax(cap_logits, dim=-1)
    punct_labels_pred = torch.argmax(punct_logits, dim=-1)
    valid_mask = (attention_mask.view(-1) == 1) & (cap_labels.view(-1) != -100)
    cap_preds.extend(cap_labels_pred.view(-1)[valid_mask].tolist())
    cap_labels_all.extend(cap_labels.view(-1)[valid_mask].tolist())
    cap_labels_pred_processed,punct_labels_processed = post_processing(cap_labels_pred, subword_tokens,punct_labels_pred, input_ids)
    cap_preds_processed.extend(cap_labels_pred_processed.view(-1)[valid_mask].tolist())

    punct_labels_all.extend(punct_labels.view(-1)[valid_mask].tolist())
    punct_preds_.extend(punct_labels_pred.view(-1)[valid_mask].tolist())
    punct_preds_processed.extend(punct_labels_processed.view(-1)[valid_mask].tolist())




In [59]:
pre_processing_f1 = f1_score(cap_labels_all, cap_preds)
post_processing_f1 = f1_score(cap_labels_all, cap_preds_processed)

print(f"Pre-processing F1: {pre_processing_f1}")
print(f"Post-processing F1: {post_processing_f1}")



Pre-processing F1: 0.7068517585016868
Post-processing F1: 0.7110258132616831


In [63]:
preprocessed_report = classification_report(y_true=punct_labels_all,
                                            y_pred=punct_preds_,
                                            target_names=list(punct_label2id.keys()))
postprocessed_report = classification_report(y_true=punct_labels_all,
                                            y_pred=punct_preds_processed,
                                            target_names=list(punct_label2id.keys()))


In [66]:
print(preprocessed_report)


              precision    recall  f1-score   support

           O       0.95      0.98      0.97    265028
           ,       0.55      0.34      0.42     17004
           .       0.57      0.59      0.58     14518
           ?       0.53      0.27      0.36      1317

    accuracy                           0.92    297867
   macro avg       0.65      0.55      0.58    297867
weighted avg       0.91      0.92      0.91    297867



In [67]:
print(postprocessed_report)

              precision    recall  f1-score   support

           O       0.95      0.98      0.97    265028
           ,       0.55      0.34      0.42     17004
           .       0.57      0.59      0.58     14518
           ?       0.53      0.27      0.36      1317

    accuracy                           0.92    297867
   macro avg       0.65      0.55      0.58    297867
weighted avg       0.91      0.92      0.91    297867



In [28]:
def post_processing(cap_label_pred, subword_tokens, input_ids):
    for batch_num in range(len(subword_tokens)):
        for word_num in range(1, len(subword_tokens[batch_num] - 1)):
            if input_ids[batch_num][word_num] != input_ids[batch_num][word_num -1] and input_ids[batch_num][word_num] != input_ids[batch_num][word_num +1]:
                if subword_tokens[batch_num][word_num] in always_capital:
                    cap_label_pred[batch_num][word_num] = 1

    return cap_label_pred


In [45]:
punct_logits.shape

torch.Size([481, 63, 4])

In [46]:
torch.argmax(punct_logits, dim=-1).shape

torch.Size([481, 63])

In [47]:
def chunk(arr, window_size, stride):
    chunks = [arr[i:i + window_size] for i in range(0, len(arr), stride)]
    # return chunks if len(chunks[-1]) == window_size else chunks[:-1]
    return chunks

def mask_centrals(window_size, stride, include=None):
    mask = [0] * window_size
    mid = window_size // 2
    start = mid - (stride // 2)
    end = start + stride
    if include == 'left':
        mask[:end] = [1] * end
    elif include == 'right':
        mask[start:] = [1] * (window_size - start)
    else:
        mask[start:end] = [1] * stride
    return mask

In [36]:
chunks = chunk(list(range(23)), 7, 3)

In [49]:
def chunk(arr, window_size, stride):
    chunks = [arr[i:i + window_size] for i in range(0, len(arr), stride)]
    last = chunks.pop()
    for it in last:
        if it not in chunks[-1]:
            chunks[-1].append(it)
    return chunks

In [50]:
print(chunk(list(range(23)), 7, 3))

[[0, 1, 2, 3, 4, 5, 6], [3, 4, 5, 6, 7, 8, 9], [6, 7, 8, 9, 10, 11, 12], [9, 10, 11, 12, 13, 14, 15], [12, 13, 14, 15, 16, 17, 18], [15, 16, 17, 18, 19, 20, 21], [18, 19, 20, 21, 22]]


In [51]:
mask_centrals

<function __main__.mask_centrals(window_size, stride, include=None)>

In [52]:
mask_centrals(7, 3, include='right')

[0, 0, 1, 1, 1, 1, 1]

In [164]:
def mask_centrals(window_size, stride, include=None):
    mask = [0] * window_size
    mid = window_size // 2
    start = mid - (stride // 2)
    end = start + stride
    if include == 'left':
        mask[:end] = [1] * end
    elif include == 'right':
        mask[start:] = [1] * (window_size - start)
    else:
        mask[start:end] = [1] * stride
    return mask

def make_chunk(arr, window_size, stride):
    chunks = [arr[i:i + window_size] for i in range(0, len(arr), stride)]
    last = chunks.pop()
    for it in last:
        if it not in chunks[-1]:
            chunks[-1].append(it)
    return chunks

def create_pairs(len_seq, window_size,stride):
    chunks = make_chunk(list(range(len_seq)), window_size, stride)
    masks = [mask_centrals(window_size, stride,include='left')]
    masks.extend([mask_centrals(window_size, stride) for _ in range(len(chunks) - 2)])


    last = chunks.pop()
    chunks[-1] = chunks[-1] + last
    masks.append(mask_centrals(len(chunks[-1]), stride, include='right'))
    return chunks, masks

In [165]:
create_pairs(25, 9, 3)


([[0, 1, 2, 3, 4, 5, 6, 7, 8],
  [3, 4, 5, 6, 7, 8, 9, 10, 11],
  [6, 7, 8, 9, 10, 11, 12, 13, 14],
  [9, 10, 11, 12, 13, 14, 15, 16, 17],
  [12, 13, 14, 15, 16, 17, 18, 19, 20],
  [15, 16, 17, 18, 19, 20, 21, 22, 23],
  [18, 19, 20, 21, 22, 23, 24, 21, 22, 23, 24]],
 [[1, 1, 1, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]])

In [161]:
mask_centrals(9,3, include='right')

[0, 0, 0, 1, 1, 1, 1, 1, 1]

In [135]:
-i - diff

-7

In [111]:
l1 = [1,2,3,4,5,6,7,8,9,10]
l1[-3:-1]

[8, 9]

49


In [154]:
def make_chunks(data, window_size, stride):
    """
    Breaks `data` (a list or array) into overlapping (or potentially partially overlapping) chunks.
    The step between consecutive chunks is given by `stride`.

    Example:
        data = [0, 1, 2, 3, 4, 5, 6]
        window_size = 3
        stride = 2
        -> chunks = [
             [0, 1, 2],
             [2, 3, 4],
             [4, 5, 6]    # Last chunk
           ]
    Then the last chunk is merged with the penultimate chunk if there are
    distinct elements that are not in the penultimate chunk.
    """
    if stride <= 0:
        raise ValueError("`stride` must be a positive integer.")
    if window_size <= 0:
        raise ValueError("`window_size` must be a positive integer.")

    # Create list of chunks
    chunks = [data[i : i + window_size] for i in range(0, len(data), stride)]

    # If there's more than one chunk, merge the last chunk into the penultimate
    # so that duplicates are not repeated.
    if len(chunks) > 1:
        last_chunk = chunks.pop()
        for item in last_chunk:
            if item not in chunks[-1]:
                chunks[-1].append(item)
    return chunks


def create_pairs(sequence_length, window_size, stride):
    """
    Returns masks and chunk indices for a sequence of length `sequence_length`.
    Internally uses `make_chunks` to chunk the range [0, 1, ..., sequence_length - 1].
    It then constructs masks using `mask_centrals`, and handles a short last chunk case
    by taking items from the penultimate chunk.

    Parameters:
        sequence_length (int) : The length of the sequence (e.g. len_seq).
        window_size (int)     : Size of each window or chunk.
        stride (int)          : Step size between consecutive chunks.

    Returns:
        masks (list)  : A list of masks created by `mask_centrals`.
        chunks (list) : A list of chunked indices.
    """
    # Create chunks from 0..sequence_length-1
    chunks = make_chunks(list(range(sequence_length)), window_size, stride)

    # Build masks: first mask with include='left', then middle masks, etc.
    # (Assuming `mask_centrals` is your own function available in the namespace.)
    masks = [mask_centrals(window_size, stride, include='left')]
    masks.extend([mask_centrals(window_size, stride) for _ in range(len(chunks) - 2)])

    # If the last chunk is smaller than window_size, attempt to prepend
    # missing items from the penultimate chunk.
    if len(chunks) >= 2 and len(chunks[-1]) < window_size:
        i = 1
        while chunks[-2][-i] in chunks[-1]:
            i += 1
        diff = window_size - len(chunks[-1])
        # Merge the needed slice from the penultimate chunk with the last chunk
        chunks[-1] = chunks[-2][-i : -i + diff] + chunks[-1]

    return masks, chunks


In [155]:
create_pairs(29, 10, 3)

([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
  [0, 0, 0, 0, 1, 1, 1, 0, 0, 0]],
 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
  [3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
  [6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
  [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
  [12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
  [15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
  [18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
  [21, 22, 23, 24, 25, 26, 27, 28],
  [23, 24, 25, 26, 27, 24, 25, 26, 27, 28]])

In [152]:
def make_chunks(data, window_size, stride):
    """
    Create chunks of length `window_size`, stepping by `stride`.
    If the last chunk is smaller than `window_size`, it will be "topped up"
    by elements from the penultimate chunk (without creating duplicates).

    Example:
        data = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
        window_size = 5
        stride = 3

        # Normal chunking
        #   [10, 11, 12, 13, 14]
        #   [13, 14, 15, 16, 17]
        #   [16, 17, 18, 19, 20]  <-- last chunk is already length 5
        # If last chunk was smaller, we fill it up from the penultimate chunk.
    """
    if window_size <= 0:
        raise ValueError("`window_size` must be a positive integer.")
    if stride <= 0:
        raise ValueError("`stride` must be a positive integer.")

    # Make the initial chunks
    chunks = [data[i : i + window_size] for i in range(0, len(data), stride)]

    # If there's only one chunk or none, nothing more to do
    if len(chunks) < 2:
        return chunks

    # If the last chunk is too short, top it up from the penultimate chunk
    if len(chunks[-1]) < window_size:
        needed = window_size - len(chunks[-1])
        # We'll take items from the end of the penultimate chunk
        penult_chunk = chunks[-2]

        # Gather items from the end of penult_chunk (in reverse),
        # skipping items already in the last chunk, until we have 'needed' items.
        reversed_fill = []
        for item in reversed(penult_chunk):
            # Only add if not already in the last chunk
            if item not in chunks[-1]:
                reversed_fill.append(item)
            # Stop if we have enough
            if len(reversed_fill) == needed:
                break

        # Reverse them to restore the original order
        reversed_fill.reverse()

        # Prepend to the last chunk
        chunks[-1] = reversed_fill + chunks[-1]

    return chunks


def create_pairs(sequence_length, window_size, stride):
    """
    Demonstration wrapper that shows how you might call `make_chunks`
    and do additional operations (like building masks).

    Example usage:
        masks, chunks = create_pairs(25, 5, 3)
    """
    # Chunk up the sequence of indices
    chunks = make_chunks(list(range(sequence_length)), window_size, stride)

    # Build your masks in whatever way `mask_centrals` is defined.
    # Example placeholders:
    masks = []
    if hasattr(globals(), 'mask_centrals'):
        # First mask with include='left' if that’s how you like it
        masks.append(mask_centrals(window_size, stride, include='left'))
        # For each extra chunk (beyond the first and last), do standard masking
        num_extra = max(0, len(chunks) - 2)
        masks.extend(mask_centrals(window_size, stride) for _ in range(num_extra))

    return masks, chunks

Chunks:
[15, 16, 17, 18, 19, 20, 21, 22, 23]
[17, 18, 19, 20, 21, 22, 23]
[19, 20, 21, 22, 23]
[21, 22, 23]
[21, 22, 23]
