Notebook for calculating ripser++ (persistent homology) barcode features from the model's attention matrices.
Based on [the code](https://github.com/danchern97/tda4atd/blob/main/features_calculation_ripser_and_templates.ipynb). 

In [None]:
# *GPU is required for ripserplusplus
# !pip install transformers
# !pip install ripserplusplus

In [None]:
# Data split to process; selects ./data/en-cola/<subset>.csv and the attention/barcode file prefixes.
subset = "test_sub"

In [None]:
# Fine-tuned model directory; attentions/, barcodes/ and features/ live under it.
# The trailing slash matters: later cells build paths by plain string concatenation.
model_path = "./la-tda-models/bert-base-cased-en-cola_32_3e-05_lr_0.01_decay_balanced/"

In [None]:
from transformers import AutoModelForSequenceClassification,AutoTokenizer
from multiprocessing import Process, Queue, Pool
from collections import defaultdict
import itertools
import re
import os
from pathlib import Path
import subprocess
from math import ceil
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import json
import gzip
import os.path

In [None]:
# Seed NumPy's global (legacy) RNG for reproducibility.
np.random.seed(42)

In [None]:
import ripserplusplus as rpp
# import numpy as np
# from tqdm import tqdm
import time
# from utils import cutoff_matrix

###################################
# RIPSER FEATURE CALCULATION FORMAT
###################################
# Format: "h{dim}\_{type}\_{args}"

# Dimension: 0, 1, etc.; homology dimension

# Types: 
    
#     1. s: sum of lengths; example: "h1_s".
#     2. m: mean of lengths; example: "h1_m"
#     3. v: standard deviation of lengths (computed with np.std); example "h1_v"
#     4. e: entropy of persistence diagram; example: "h1_e"
#     5. n: number of barcodes with time of birth/death at least/at most the threshold.
#         5.1. b/d: birth or death
#         5.2. m/l: more (>=) or less (<=) than threshold
#         5.3. t: threshold value
#        example: "h0_n_d_m_t0.5", "h1_n_b_l_t0.75"
#     6. t: time of birth/death of the longest barcode (not incl. inf).
#         6.1. b/d: birth or death
#             example: "h0_t_d", "h1_t_b"
#     7. nb: number of barcodes in dim; example: "h0_nb"

####################################

def barcode_pop_inf(barcode):
    """Remove bars whose death time is +inf from every dimension.

    The dict is updated in place and also returned.
    """
    for dim, bars in barcode.items():
        if len(bars) > 0:
            finite = bars['death'] != np.inf
            barcode[dim] = bars[finite]
    return barcode

def barcode_number(barcode, dim=0, bd='death', ml='m', t=0.5):
    """Count h{dim} bars whose `bd` time is >= t (ml='m') or <= t (ml='l').

    Returns 0.0 when the dimension has no bars; raises Exception for any other `ml`.
    """
    bars = barcode[dim]
    if not len(bars):
        return 0.0
    if ml == 'm':
        selected = bars[bd] >= t
    elif ml == 'l':
        selected = bars[bd] <= t
    else:
        raise Exception("Wrong more/less type in barcode_number calculation")
    return np.sum(selected)
        
def barcode_time(barcode, dim=0, bd='birth'):
    """Return the `bd` time ('birth' or 'death') of the longest h{dim} bar, 0.0 if none."""
    bars = barcode[dim]
    if len(bars) == 0:
        return 0.0
    lengths = bars['death'] - bars['birth']
    longest = np.argmax(lengths)
    return bars[bd][longest]
    
def barcode_number_of_barcodes(barcode, dim=0):
    """Number of bars stored for dimension `dim`."""
    bars = barcode[dim]
    return len(bars)

def barcode_entropy(barcode, dim=0):
    """Shannon entropy (natural log) of the normalized h{dim} bar lengths.

    Bar lengths (death - birth) are normalized to sum to 1 and the entropy
    -sum(p * log p) is returned. Zero-length bars contribute nothing
    (0 * log 0 is taken as 0, instead of producing NaN), and 0.0 is returned
    when there are no bars or all bars have zero length (previously a
    division by zero yielding NaN).
    """
    if not len(barcode[dim]):
        return 0.0
    lengths = barcode[dim]['death'] - barcode[dim]['birth']
    lengths = lengths[lengths > 0]  # drop zero-length bars: log(0) would give NaN
    total = np.sum(lengths)
    if total <= 0:
        return 0.0
    probs = lengths / total
    return -np.sum(probs * np.log(probs))
    

# def barcode_lengths(barcode, dim=0):
#     return barcode[dim]['death'] - barcode[dim]['birth']

def barcode_sum(barcode, dim=0):
    """Total persistence: sum of (death - birth) over the h{dim} bars, 0.0 if there are none."""
    bars = barcode[dim]
    if len(bars) == 0:
        return 0.0
    lengths = bars['death'] - bars['birth']
    return lengths.sum()

def barcode_mean(barcode, dim=0):
    """Mean bar length (death - birth) in h{dim}, 0.0 if there are no bars."""
    bars = barcode[dim]
    if len(bars) == 0:
        return 0.0
    lengths = bars['death'] - bars['birth']
    return lengths.mean()

def barcode_std(barcode, dim=0):
    """Standard deviation of bar lengths (death - birth) in h{dim}, 0.0 if there are no bars."""
    bars = barcode[dim]
    if len(bars) == 0:
        return 0.0
    lengths = bars['death'] - bars['birth']
    return lengths.std()

def count_ripser_features(barcodes, feature_list=('h0_m',)):
    """Calculate the requested ripser features for a list of barcodes.

    Feature names have the form "h{dim}_{type}_{args}", e.g. "h1_s",
    "h0_n_d_m_t0.5", "h1_t_b"; supported types are s, m, v, n, t, nb, e.

    Args:
        barcodes: list of barcodes; each is a dict {dim: structured array with
            'birth' and 'death' fields}. Infinite bars are removed from these
            dicts in place before any feature is computed.
        feature_list: iterable of feature names.

    Returns:
        np.ndarray of shape (len(barcodes), len(feature_list)) -- samples x features.

    Raises:
        ValueError: on an unknown feature type (previously the values of the
            preceding feature were silently reused, or an unbound-name error
            was raised for the first feature).
    """
    bd_names = {'b': 'birth', 'd': 'death'}
    # drop infinite bars first; every feature below is defined on finite bars only
    barcodes = [barcode_pop_inf(barcode) for barcode in barcodes]
    features = []
    for feature in feature_list:
        parts = feature.split('_')
        # dimension, feature type and args
        dim, ftype, fargs = int(parts[0][1:]), parts[1], parts[2:]
        if ftype == 's':
            feat = [barcode_sum(barcode, dim) for barcode in barcodes]
        elif ftype == 'm':
            feat = [barcode_mean(barcode, dim) for barcode in barcodes]
        elif ftype == 'v':
            # "v" is computed as the standard deviation of bar lengths
            feat = [barcode_std(barcode, dim) for barcode in barcodes]
        elif ftype == 'n':
            # args: b|d, m|l, t<threshold>
            bd = bd_names.get(fargs[0], fargs[0])
            ml, t = fargs[1], float(fargs[2][1:])
            feat = [barcode_number(barcode, dim, bd, ml, t) for barcode in barcodes]
        elif ftype == 't':
            bd = bd_names.get(fargs[0], fargs[0])
            feat = [barcode_time(barcode, dim, bd) for barcode in barcodes]
        elif ftype == 'nb':
            feat = [barcode_number_of_barcodes(barcode, dim) for barcode in barcodes]
        elif ftype == 'e':
            feat = [barcode_entropy(barcode, dim) for barcode in barcodes]
        else:
            raise ValueError(f"Unknown ripser feature type {ftype!r} in {feature!r}")
        features.append(feat)
    if not features:
        # np.swapaxes would fail on the 1-D empty array produced by np.array([])
        return np.empty((len(barcodes), 0))
    return np.swapaxes(np.array(features), 0, 1)  # samples X n_features

def matrix_to_ripser(matrix, ntokens, lower_bound=0.0):
    """Convert an attention matrix to a ripser++ distance matrix.

    Steps: crop to the first `ntokens` tokens and row-normalize (cutoff_matrix),
    zero out weights <= `lower_bound`, turn weights into distances (1 - w),
    zero the diagonal, and symmetrize with the element-wise minimum so an edge
    appears as soon as attention in either direction is strong enough.
    """
    matrix = cutoff_matrix(matrix, ntokens)
    # `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`
    matrix = (matrix > lower_bound).astype(int) * matrix
    matrix = 1.0 - matrix
    matrix -= np.diag(np.diag(matrix))  # 0 on diagonal
    matrix = np.minimum(matrix.T, matrix)  # symmetrical, edge emerges if at least one direction is working
    return matrix

def run_ripser_on_matrix(matrix, dim):
    """Run ripser++ on a precomputed distance matrix up to homology dimension `dim`.

    Returns whatever `rpp.run` returns; the feature code in this notebook indexes
    the result by dimension and reads its 'birth'/'death' fields.
    """
    ripser_args = f"--format distance --dim {dim}"
    return rpp.run(ripser_args, data=matrix)

def get_barcodes(matricies, ntokens_array, dim=1, lower_bound=0.0, layer_head=(0, 0)):
    """Compute one ripser++ barcode per attention matrix of a single (layer, head).

    Args:
        matricies: sequence of attention matrices, one per sample.
        ntokens_array: number of real tokens for each sample (indexed like `matricies`).
        dim: maximum homology dimension passed to ripser++.
        lower_bound: attention weights <= this value are dropped.
        layer_head: (layer, head) pair; only unpacked, kept for debugging/logging.

    Returns:
        list of barcodes in the same order as `matricies`.
    """
    _layer, _head = layer_head  # not used in the computation
    return [
        run_ripser_on_matrix(matrix_to_ripser(matrix, ntokens_array[idx], lower_bound), dim)
        for idx, matrix in enumerate(matricies)
    ]


In [None]:
def cutoff_matrix(matrix, ntokens):
    """Return the row-normalized top-left `ntokens` x `ntokens` submatrix.

    Each row of the crop is divided by its sum, so rows sum to 1 over the kept
    tokens. A new array is returned and the input is left untouched; the
    previous in-place `/=` on a slice view silently rewrote the caller's
    attention array. Rows that sum to zero are divided by 1 (left unchanged)
    instead of producing NaN/inf.
    """
    sub = matrix[:ntokens, :ntokens]
    row_sums = sub.sum(axis=1, keepdims=True)
    return sub / np.where(row_sums == 0, 1, row_sums)

## Load Data

In [None]:
# Tokenizer of the fine-tuned model; used below to recompute per-sentence token counts.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Maximum number of tokens per sentence; MAX_LEN is the count reported for sentences without padding.
max_tokens_amount  = 64
MAX_LEN = max_tokens_amount

In [None]:
# Path prefixes for this subset's attention dumps and barcode JSON files; the actual
# files append "_part<k>of<n>" plus an extension (see the cells below).
r_file = model_path + 'attentions/' + subset
barcodes_file = model_path + 'barcodes/' +subset
r_file, barcodes_file

('./la-tda-models/bert-base-cased-en-cola_32_3e-05_lr_0.01_decay_balanced/attentions/test_sub',
 './la-tda-models/bert-base-cased-en-cola_32_3e-05_lr_0.01_decay_balanced/barcodes/test_sub')

In [None]:
def get_token_length(batch_texts, text_tokenizer=None, max_length=64):
    """Return the number of non-padding tokens of each text after tokenization.

    Texts are encoded with special tokens, truncated and padded to `max_length`.
    The count is the position of the first pad token, or the padded width when
    a text has no padding (it was truncated or exactly fills `max_length`).

    Args:
        batch_texts: list of strings.
        text_tokenizer: Hugging Face tokenizer; defaults to the notebook-level
            `tokenizer`.
        max_length: length to pad/truncate to (64, as before).

    Returns:
        list of int, one per text.
    """
    tok = tokenizer if text_tokenizer is None else text_tokenizer
    inputs = tok.batch_encode_plus(
        batch_texts,
        return_tensors='pt',
        add_special_tokens=True,
        max_length=max_length,      # Max length to truncate/pad
        padding='max_length',       # replaces the deprecated pad_to_max_length=True
        truncation=True,
    )
    input_ids = inputs['input_ids'].numpy()
    is_pad = input_ids == tok.pad_token_id
    # argmax gives the first pad position per row; rows without padding keep the full width
    first_pad = np.where(is_pad.any(axis=1), is_pad.argmax(axis=1), input_ids.shape[1])
    return first_pad.tolist()

In [None]:
# Load the subset and record each sentence's number of non-padding tokens.
data = pd.read_csv("./data/en-cola/" + subset + '.csv')
data['tokenizer_length'] = get_token_length(list(data['sentence'].values))
sentences = data['sentence']

In [None]:
# Batching layout. Presumably this must match the settings used when the attention
# matrices were dumped: the barcode loop assumes each dump file holds DUMP_SIZE
# batches of batch_size sentences.
batch_size = 10 # batch size
number_of_batches = ceil(len(data['sentence']) / batch_size)
DUMP_SIZE = 100 # number of batches to be dumped
batched_sentences = np.array_split(data['sentence'].values, number_of_batches)
number_of_files = ceil(number_of_batches / DUMP_SIZE)
adj_matricies = []
adj_filenames = []  # overwritten later by order_files(...)
assert number_of_batches == len(batched_sentences) # sanity check

# Ripser features calculation

Format: "h{dim}\_{type}\_{args}"

Dimension: 0, 1, etc.; homology dimension

Types: 
    
    1. s: sum of lengths; example: "h1_s".
    2. m: mean of lengths; example: "h1_m"
    3. v: standard deviation of lengths (computed with np.std); example "h1_v"
    4. n: number of barcodes with time of birth/death at least/at most the threshold.
        4.1. b/d: birth or death
        4.2. m/l: more (>=) or less (<=) than the threshold
        4.3. t: threshold value
       example: "h0_n_d_m_t0.5", "h1_n_b_l_t0.75"
    5. t: time of birth/death of the longest barcode (not incl. inf).
        5.1. b/d: birth or death
        example: "h0_t_d", "h1_t_b"
    6. nb: number of barcodes in dim
       example: h0_nb
    7. e: entropy; example: "h1_e"

In [None]:
def order_files(path, subset):
    """Return the `<subset>_part<k>of<n>.*` files in `path`, sorted by part number k.

    Matching uses the file name with an exact subset prefix. Previously any
    path containing `subset` as a substring matched, so e.g. subset "test" also
    picked up "test_sub_part1of1.npy.gz".

    Args:
        path: directory to scan (non-recursive).
        subset: data subset name used as the file-name prefix.

    Returns:
        list of file paths as strings.
    """
    part_re = re.compile(rf"{re.escape(subset)}_part(\d+)of\d+\..+")
    matched = []
    for entry in Path(path).iterdir():
        match = part_re.fullmatch(entry.name)
        if match and entry.is_file():
            matched.append((int(match.group(1)), str(entry)))
    matched.sort(key=lambda item: item[0])
    return [name for _, name in matched]

In [None]:
# Output root (the model directory) and the directory holding the gzipped attention dumps.
output_dir=model_path
attn_dir = model_path + "/attentions/"
adj_filenames = order_files(path=attn_dir, subset=subset)

In [None]:
# Attention dump files to process, ordered by part number.
adj_filenames

['./la-tda-models/bert-base-cased-en-cola_32_3e-05_lr_0.01_decay_balanced/attentions/test_sub_part1of1.npy.gz']

In [None]:
# Maximum homology dimension computed by ripser++ (passed as --dim).
dim = 1
# Attention weights <= lower_bound are zeroed before building the distance matrix.
lower_bound = 1e-3

## Barcodes calculation

In [None]:
def subprocess_wrap(queue, function, args):
    """Run `function(*args)` and send its result to the parent process via `queue`.

    Intended as a `multiprocessing.Process` target: the result is put on
    `queue`, the queue is closed, and the process exits. If `function` raises,
    the exception object is put on the queue instead. The parent's blocking
    `queue.get()` therefore returns rather than waiting forever; the parent
    should check whether it received an exception.
    """
    import sys  # sys.exit instead of site's `exit()`, which is meant for the interactive shell

    try:
        result = function(*args)
    except Exception as exc:
        result = exc
    queue.put(result)
    queue.close()
    sys.exit()

In [None]:
def get_only_barcodes(adj_matricies, ntokens_array, dim, lower_bound, verbose=False):
    """Compute ripser++ barcodes for every (layer, head) attention map.

    Args:
        adj_matricies: array indexed as [sample, layer, head, :, :].
        ntokens_array: number of real tokens per sample.
        dim: maximum homology dimension passed to ripser++.
        lower_bound: attention weights <= this value are dropped.
        verbose: show a tqdm progress bar over the (layer, head) pairs.

    Returns:
        dict mapping (layer, head) -> list of barcodes, one per sample.
    """
    n_layers, n_heads = adj_matricies.shape[1], adj_matricies.shape[2]
    pairs = itertools.product(range(n_layers), range(n_heads))
    if verbose:
        pairs = tqdm(pairs, 'Layer, Head', leave=False)
    result = {}
    for layer, head in pairs:
        head_matricies = adj_matricies[:, layer, head, :, :]
        result[(layer, head)] = get_barcodes(
            head_matricies, ntokens_array, dim, lower_bound, (layer, head)
        )
    return result

def format_barcodes(barcodes):
    """Convert each barcode's per-dimension NumPy arrays into plain lists for JSON."""
    formatted = []
    for barcode in barcodes:
        formatted.append({dim: bars.tolist() for dim, bars in barcode.items()})
    return formatted

def save_barcodes(barcodes, filename):
    """Save barcodes to a JSON file laid out as {layer: {head: [barcode, ...]}}.

    Args:
        barcodes: dict mapping (layer, head) -> list of barcodes (dicts of arrays).
        filename: output path; overwritten if it exists.

    JSON stores the integer layer/head/dim keys as strings.
    """
    formatted_barcodes = defaultdict(dict)
    for layer, head in barcodes:
        formatted_barcodes[layer][head] = format_barcodes(barcodes[(layer, head)])
    # the context manager flushes and closes the file even if dumping fails
    # (the handle used to be left for the garbage collector)
    with open(filename, 'w') as fp:
        json.dump(formatted_barcodes, fp)
    
def unite_barcodes(barcodes, barcodes_part):
    """Append each (layer, head) barcode list of `barcodes_part` onto `barcodes`.

    `barcodes` is modified in place and returned; it must already hold every key
    of `barcodes_part` or create it on access (e.g. a defaultdict(list)).
    """
    for (layer, head), part_list in barcodes_part.items():
        barcodes[(layer, head)].extend(part_list)
    return barcodes

def split_matricies_and_lengths(adj_matricies, ntokens, number_of_splits):
    """Split the samples into `number_of_splits` nearly equal contiguous chunks.

    Returns a list of (matrices_chunk, ntokens_chunk) pairs; both arrays are
    indexed along their first (sample) axis with the same indices.
    """
    sample_ids = np.arange(ntokens.shape[0])
    return [
        (adj_matricies[chunk], ntokens[chunk])
        for chunk in np.array_split(sample_ids, number_of_splits)
    ]

In [None]:
barcodes_dir = model_path + 'barcodes/'
# Create the output directory portably; exist_ok keeps the cell safe to re-run
# (`!mkdir` errors out when the directory already exists).
os.makedirs(barcodes_dir, exist_ok=True)

In [None]:
# Per-sentence token counts in dataset order, used to crop each attention matrix.
# (`ntokens_array` was used below but never defined -> NameError on a fresh kernel.)
ntokens_array = data['tokenizer_length'].values
queue = Queue()
number_of_splits = 4
# When True, chunks are processed one at a time in a child process, so memory held
# by ripser++ is released when the child exits (there is no real parallelism).
run_in_parallel = False

for i, filename in enumerate(tqdm(adj_filenames, desc='Barcodes calculation')):
    part = filename.split('_')[-1].split('.')[0]  # e.g. "part1of1"
    out_path = barcodes_file + '_' + part + '.json'
    if os.path.isfile(out_path):
        print("file already exists")
        print("passing", out_path)
        continue

    barcodes = defaultdict(list)
    with gzip.GzipFile(filename, 'rb') as f:
        adj_matricies = np.load(f, allow_pickle=True)
    # each dump file holds DUMP_SIZE batches of batch_size sentences
    ntokens = ntokens_array[i*batch_size*DUMP_SIZE : (i+1)*batch_size*DUMP_SIZE]
    assert len(ntokens) == adj_matricies.shape[0], (
        f"{filename}: {adj_matricies.shape[0]} matrices but {len(ntokens)} token counts"
    )
    if not run_in_parallel:
        barcodes = get_only_barcodes(adj_matricies, ntokens, dim, lower_bound, verbose=True)
    else:
        splitted = split_matricies_and_lengths(adj_matricies, ntokens, number_of_splits)
        for matricies, split_ntokens in tqdm(splitted, leave=False):
            p = Process(
                target=subprocess_wrap,
                args=(
                    queue,
                    get_only_barcodes,
                    (matricies, split_ntokens, dim, lower_bound)
                )
            )
            p.start()
            barcodes_part = queue.get()  # block until the child puts its barcodes
            p.join()  # release resources
            p.close()  # releasing resources of ripser
            if isinstance(barcodes_part, BaseException):
                raise barcodes_part  # the child forwarded its failure
            barcodes = unite_barcodes(barcodes, barcodes_part)

    save_barcodes(barcodes, out_path)

In [None]:
# NOTE(review): this reassigns `barcodes_file` to a `features/barcodes/` path, but the
# barcode loop above wrote its JSON files under `model_path + 'barcodes/'`. No later
# cell shown here reads `barcodes_file` -- verify whether this cell is still needed.
barcodes_file=f"{model_path}/features/barcodes/{subset}"
barcodes_file

'./la-tda-models/bert-base-cased-en-cola_32_3e-05_lr_0.01_decay_balanced//features/barcodes/test_sub'

## Barcodes' ripser features

In [None]:
# Directory holding the barcode JSON files written by the barcode loop.
# (Previously `input_dir`, which is never defined in this notebook -> NameError.)
barcodes_file_dir = barcodes_dir

In [None]:
# Features computed for every (layer, head) barcode, in "h{dim}_{type}_{args}" format.
# Infinite bars are dropped before any feature is computed.
ripser_features=[
    'h0_s',             # H0: sum of bar lengths
    'h0_e',             # H0: entropy of bar lengths
    'h0_t_d',           # H0: death time of the longest bar
    'h0_n_d_m_t0.75',   # H0: number of bars with death >= 0.75
    'h0_n_d_m_t0.5',    # H0: number of bars with death >= 0.5
    'h0_n_d_l_t0.25',   # H0: number of bars with death <= 0.25
    'h1_t_b',           # H1: birth time of the longest bar
    'h1_n_b_m_t0.25',   # H1: number of bars with birth >= 0.25
    'h1_n_b_l_t0.95',   # H1: number of bars with birth <= 0.95
    'h1_n_b_l_t0.70',   # H1: number of bars with birth <= 0.70
    'h1_s',             # H1: sum of bar lengths
    'h1_e',             # H1: entropy of bar lengths
    'h1_v',             # H1: standard deviation of bar lengths (np.std)
    'h1_nb'             # H1: number of bars
]

In [None]:
# Barcode JSON parts of this subset ("<subset>_part<k>of<n>.json") from the directory
# the barcode loop wrote to, ordered by part number k. The exact-prefix match keeps
# e.g. subset "test" from also picking up "test_sub_part*" files (the old substring test did).
part_pattern = re.compile(rf"{re.escape(subset)}_part(\d+)of\d+\.json")
json_filenames = sorted(
    (
        os.path.join(barcodes_dir, name)
        for name in os.listdir(barcodes_dir)
        if part_pattern.fullmatch(name)
    ),
    key=lambda path: int(part_pattern.fullmatch(os.path.basename(path)).group(1)),
)
json_filenames

['./la-tda-models/bert-base-cased-en-cola_32_3e-05_lr_0.01_decay_balanced//barcodes/test_sub_part1of1.json']

In [None]:
def reformat_barcodes(barcodes):
    """Turn JSON-loaded barcodes back into {int dim: structured array} dicts.

    Each barcode's string dimension keys become ints, and each list of
    [birth, death] pairs becomes a float32 structured array with 'birth'
    and 'death' fields.
    """
    bar_dtype = [('birth', '<f4'), ('death', '<f4')]

    def _to_array(bars):
        return np.asarray([(birth, death) for birth, death in bars], dtype=bar_dtype)

    return [
        {int(dim_key): _to_array(bars) for dim_key, bars in barcode.items()}
        for barcode in barcodes
    ]

In [None]:
features_array = []

for filename in tqdm(json_filenames, desc='Computing ripser++'):
    # close the file deterministically instead of leaving it to the garbage collector
    with open(filename) as fp:
        barcodes = json.load(fp)
    print(f"Barcodes loaded from: {filename}", flush=True)
    features_part = []
    # JSON keys are strings; order layers/heads numerically so the array axes
    # line up with the layer/head indices regardless of key order in the file
    for layer in sorted(barcodes, key=int):
        features_layer = []
        for head in sorted(barcodes[layer], key=int):
            ref_barcodes = reformat_barcodes(barcodes[layer][head])
            features = count_ripser_features(ref_barcodes, ripser_features)
            features_layer.append(features)
        features_part.append(features_layer)
    # shape: (layers, heads, samples in this part, n_features)
    features_array.append(np.asarray(features_part))

In [None]:
# Join the per-file parts along the sample axis -> (layers, heads, samples, n_features).
features = np.concatenate(features_array, axis=2)
features.shape

(12, 12, 533, 14)

In [None]:
# Output path for this subset's ripser feature array.
ripser_file=f"{model_path}features/{subset}_ripser.npy"
ripser_file

'./la-tda-models/bert-base-cased-en-cola_32_3e-05_lr_0.01_decay_balanced/features/test_sub_ripser.npy'

In [None]:
# Create the features directory without changing the kernel's working directory:
# `%cd $model_path` here broke the relative `ripser_file` path used by `np.save`
# in the next cell. exist_ok keeps the cell re-runnable.
os.makedirs(model_path + 'features', exist_ok=True)

In [None]:
# Persist the (layers, heads, samples, n_features) feature array.
np.save(ripser_file, features)