In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Set the allocator option *before* torch is imported, so it is already in
# the environment by the time PyTorch reads it.
# NOTE(review): "expandable_segments:True" is presumably here to reduce
# allocator fragmentation — confirm against the PyTorch CUDA memory notes.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
# Ask PyTorch to release cached / IPC-held CUDA memory before any model is loaded.
# NOTE(review): on a freshly started kernel these two calls likely have nothing
# to free; they matter mainly when the cell is re-run in a live kernel.
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [3]:
# Explicitly reload the top-level llm_programs package.
# NOTE(review): importlib.reload() re-executes only the module it is given,
# not its submodules; with %autoreload active this may be redundant — confirm.
from importlib import reload  
import llm_programs
reload(llm_programs)

# NOTE(review): the wildcard imports below hide where names used later in this
# notebook (e.g. Path, DocDir, LocalLM, TemplatedFunction, LMFunction, wrap,
# lines_parser) come from; presumably one of these two modules provides them.
# Consider explicit imports once the origins are confirmed.
from llm_programs.programs.base import *
from llm_programs.utils import *

# Prompt templates for the redaction task used by the cells below.
from llm_programs.tasks.redaction.templates import TEMPLATE_FILTER_ALL, REDACTION_TEMPLATES

  from .autonotebook import tqdm as notebook_tqdm


## Baseline: 1 prompt, 1 call, small window

In [4]:
# Locations are relative to the notebook's working directory.
root = Path('../..')
parent = root.joinpath('data/contracts/confidential/')

# Cleaned contracts directory, wrapped in a DocDir.
docdir_clean = DocDir(parent.joinpath('contracts_2_clean_v5'))

# Path used for the sample directory.
path_sample = parent.joinpath('sample_0_clean')

def fn(path_sample, seed=42):
    """Call ``make_sample_docdir`` on the clean corpus for ``path_sample``.

    Used as the callback handed to ``DocDir.find_or_make``.

    Args:
        path_sample: Path forwarded unchanged to ``make_sample_docdir``.
        seed: Seed forwarded to ``make_sample_docdir``. Defaults to 42, the
            value that was previously hard-coded, so one-argument calls
            behave exactly as before.

    Returns:
        Whatever ``make_sample_docdir`` returns.
    """
    # docdir_clean is read from the notebook namespace at call time.
    return make_sample_docdir(docdir_clean, path_sample, seed=seed)

# NOTE(review): presumably reuses an existing sample directory at path_sample
# and only calls fn to build one when it is missing — confirm against
# DocDir.find_or_make.
docdir_sample = DocDir.find_or_make(path_sample, fn)

In [None]:
def gemma_parser(response):
    """Remove every triple-backtick marker from ``response``, then parse it.

    Args:
        response: Raw text to clean (must support ``str.replace``).

    Returns:
        The result of ``lines_parser`` applied to the cleaned text.
    """
    fence = '```'
    unfenced = response.replace(fence, '')
    return lines_parser(unfenced)

# Instantiate the LocalLM engine.
engine = LocalLM()

# Templated function bound to the TEMPLATE_FILTER_ALL prompt; responses go
# through gemma_parser.
f = TemplatedFunction(TEMPLATE_FILTER_ALL, parser=gemma_parser, engine=engine)

# Arguments handed to doc.windows() below.
window_size, window_stride = 1024, 128

for i_doc, doc in enumerate(docdir_sample.docs()):
    for i_win, window in enumerate(doc.windows(window_size, window_stride)):
        # Header identifying the document and window, followed by the window text.
        print(f"=== {doc} ===", f"--- {i_win} ---", sep="\n")
        print(wrap(window))

        # Apply the templated function to this window.
        window_kws = f(extract=window)

        # Result framed by separator lines, then a blank line.
        print("---------", window_kws, "---------", sep="\n")
        print()

## Prompt decomposition: N prompts, 1 call each

In [None]:
# Arguments handed to doc.windows() in this section.
window_size, window_stride = 100, 50

# LM function with no template bound up front; the template is passed on
# each call instead.
g = LMFunction(parser=gemma_parser, engine=engine)

for i_doc, doc in enumerate(docdir_sample.docs()):
    for i_win, window in enumerate(doc.windows(window_size, window_stride)):
        # Header identifying the document and window, then the window text
        # and a separator line.
        print(f"=== {doc} ===", f"--- {i_win} ---", sep="\n")
        print(wrap(window))
        print("---------")

        # Call g once per entry of REDACTION_TEMPLATES on the same window.
        for key, template in REDACTION_TEMPLATES.items():
            print(key, "-->", end=" ")
            window_kws = g(extract=window, template=template)
            print(window_kws)

        # Closing separator followed by a blank line.
        print("---------", end="\n\n")

## Double-prompting: filter the filtered