In [None]:
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
import hydra
from hydra import compose, initialize
from omegaconf import OmegaConf, ListConfig

import numpy as np
import pandas as pd
import wandb
import os
import seaborn as sns

from typing import Optional, List, Tuple

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
from vllm.lora.request import LoRARequest

from spacy.tokens import Span, Doc
from spacy import displacy
from spacy.lang.en import English

import numpy as np
from IPython.core.display import display, HTML

from utils import *

In [None]:
# Initialise Hydra's compose API and load the generation config.
from hydra.core.global_hydra import GlobalHydra

# `initialize` refuses to run while a global Hydra instance already exists,
# which is the case whenever this cell is re-executed in a live kernel.
# Clearing first makes the cell re-runnable; on a fresh kernel it is a no-op.
GlobalHydra.instance().clear()

initialize(version_base=None, config_path="../conf/generation", job_name="test_app")
cfg = compose(config_name="generation_conf_quantized")

In [None]:
# Export the run settings as environment variables: wandb options come from
# the config, the remaining values are fixed for this notebook.
wandb_cfg = cfg.wandb
os.environ.update(
    {
        "WANDB_PROJECT": wandb_cfg.project,
        "WANDB_ENTITY": wandb_cfg.entity,
        "WANDB_JOB_TYPE": wandb_cfg.job_type,
        "WANDB_LOG_MODEL": "false",
        "WANDB_WATCH": "all",
        # Ask Hydra for complete stack traces.
        "HYDRA_FULL_ERROR": "1",
        # GPUs exposed to this process.
        "CUDA_VISIBLE_DEVICES": "0,1",
    }
)

# Data generation

Below is the prompt we will use for rewriting essays.

In [None]:
# Load the prompt template from the path given in the config.
with open(cfg.prompt_path, "r") as prompt_file:
    prompt_format = prompt_file.read()

# Preview the template with placeholder text in both slots.
print(prompt_format.format("\'original essay\'", "\'PII entities\'"))

Create requests for the engine with variable **entity types** and **sampling parameters**

In [None]:
from itertools import combinations
import random

# Each entry is a (description, label, example text) triple.
PII_ENTS = [
    # NOTE: check whether a full name is a single entity or several!
    ("name", "NAME_STUDENT", "James Brown"),
    ("email", "EMAIL", "example@email.com"),
    ("personal_url", "URL_PERSONAL", "https://example.com"),
    ("username", "USERNAME", "john42"),
    ("address", "STREET_ADDRESS", "221B, Baker Street, London"),
    ("phone_num", "PHONE_NUM", "+1 212 555 0188"),
    ("userid", "ID_NUM", "123456789"),
]

# Entity sets a request may draw from, in this order:
#   1. every entity on its own,
#   2. the name paired with each other entity,
#   3. every 3-element combination of the first four entities.
ENT_COMBINATIONS = (
    [[single] for single in PII_ENTS]
    + [[PII_ENTS[0], other] for other in PII_ENTS[1:]]
    + [list(triple) for triple in combinations(PII_ENTS[:4], 3)]
)

# Tokenizer of the engine's model; build_request uses it to apply the chat template.
tokenizer = AutoTokenizer.from_pretrained(cfg.engine.model)

def sample_ent_combination():
    """Return one randomly chosen entry of ENT_COMBINATIONS."""
    picked = random.choice(ENT_COMBINATIONS)
    return picked

def dict2str(d):
    """Render a mapping as newline-separated ``key=value`` lines."""
    rendered = []
    for key, value in d.items():
        rendered.append(f"{key}={value}")
    return "\n".join(rendered)

def build_request(prompt_format, ent_combination, sampling_params, essay=None):
    """Assemble one generation request for the engine.

    Args:
        prompt_format: Template string. It is formatted with
            ``(essay, pii_str)`` when an essay is given and with
            ``(pii_str,)`` alone otherwise.
        ent_combination: Sequence of ``(description, label, text)`` triples.
        sampling_params: Stored unchanged under ``"sampling_params"``.
        essay: Optional essay text to insert into the prompt.

    Returns:
        A dict with the keys ``"prompt"`` (chat-templated prompt string),
        ``"true_ents_dict"`` (label -> one-element list with the entity text),
        ``"sampling_params"`` and ``"lora_params"`` (always ``None``).
    """
    true_ents_dict = {ent_type: [ent_text] for _, ent_type, ent_text in ent_combination}
    pii_str = "\n".join(
        f"{ent_description}={ent_text}" for ent_description, _, ent_text in ent_combination
    )

    if essay is None:
        prompt = prompt_format.format(pii_str)
    else:
        prompt = prompt_format.format(essay, pii_str)

    chat = [{"role": "user", "content": prompt}]
    # The string is fed to the engine for generation, so ask the template to
    # append the marker that opens the assistant turn. Templates without such
    # a marker produce the same text as before.
    prompt_with_chat_template = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    return {
        "prompt": prompt_with_chat_template,
        "true_ents_dict": true_ents_dict,
        "sampling_params": sampling_params,
        "lora_params": None,
    }

def create_requests(essays, cfg) -> List[dict]:
    """Build one generation request per essay.

    Args:
        essays: Iterable of essay texts. ``None`` yields a single request
            that is built without an essay.
        cfg: Config object; ``cfg.sampling_params`` is converted to a plain
            container for every request, so requests do not share one object.

    Returns:
        A list of request dicts as produced by ``build_request``, each with a
        randomly sampled entity combination.
    """
    if essays is None:
        essays = [None]

    return [
        build_request(
            prompt_format,
            sample_ent_combination(),
            sampling_params=OmegaConf.to_container(cfg.sampling_params),
            essay=essay,
        )
        for essay in essays
    ]

In [None]:
# Read essays for rewriting, keeping only rows whose `has_ents` flag is false.
MAX_REQUESTS = 5  # number of essays turned into generation requests

orig_essays_df = pd.read_json(cfg.original_essays_path)
orig_essays_df = orig_essays_df[~orig_essays_df["has_ents"]]
essays = orig_essays_df["full_text"].tolist()

# Truncate before building the requests: each request applies the chat
# template, so building one per essay only to drop all but the first few
# is wasted work.
generation_requests = create_requests(essays[:MAX_REQUESTS], cfg)

In [None]:
# Number of requests that will be sent to the engine.
len(generation_requests)

Feed requests to the engine

In [None]:
def initialize_engine(cfg) -> LLMEngine:
    """Create an LLMEngine from the ``engine`` section of the config."""
    engine_kwargs = cfg.engine
    return LLMEngine.from_engine_args(EngineArgs(**engine_kwargs))

def process_requests(engine, generation_requests):
    """Submit all requests to the engine and collect the finished generations.

    Args:
        engine: Object exposing ``add_request``, ``has_unfinished_requests``
            and ``step``.
        generation_requests: List of request dicts with the keys ``"prompt"``,
            ``"sampling_params"`` (kwargs for ``SamplingParams``) and
            ``"lora_params"`` (kwargs for ``LoRARequest``, or a falsy value).

    Returns:
        A list with one dict per generated output: the originating request's
        fields plus ``"generated_text"``. Entries are appended as requests
        finish, so the order may differ from the order of
        ``generation_requests``.
    """
    for request_id, request_data in enumerate(generation_requests):
        sampling_params = SamplingParams(**request_data["sampling_params"])
        lora_params = request_data["lora_params"]
        lora_request = LoRARequest(**lora_params) if lora_params else None

        # Pass the LoRA request by keyword. NOTE(review): in the vLLM
        # signatures I am aware of, the positional slot after the sampling
        # params is not `lora_request`, so a positional LoRA request would be
        # bound to the wrong parameter — confirm against the installed version.
        engine.add_request(
            str(request_id),
            request_data["prompt"],
            sampling_params,
            lora_request=lora_request,
        )

    generated_examples = []
    while engine.has_unfinished_requests():
        for request_output in engine.step():
            if not request_output.finished:
                continue
            # The request id is the index into generation_requests (set above).
            request_data = generation_requests[int(request_output.request_id)]
            for output in request_output.outputs:
                generated_examples.append({"generated_text": output.text, **request_data})
    return generated_examples

In [None]:
# Build the engine from the config and run every prepared request through it.
engine = initialize_engine(cfg)    
generated_examples = process_requests(engine, generation_requests)

In [None]:
# One row per generated example; `from_records` is called on the class itself.
generation_df = pd.DataFrame.from_records(generated_examples)

# Entity detection

In [None]:
from utils import *

In [None]:
# Row-wise post-processing, in order:
#   1. replace generated entities with their corresponding labels,
#   2. tokenize the texts (with labels instead of entities),
#   3. find the entities and mark their positions.
# The final "replace labels with new entities" step
# (replace_labels_with_ents) is currently disabled.
generation_df = (
    generation_df
    .agg(replace_ents_with_labels, axis=1)
    .agg(tokenize_df_with_spacy, axis=1)
    .agg(mark_ent_label_tokens, axis=1)
)

generation_df.head(2)

### Sanity checks

In [None]:
# Every row must have equally many tokens, whitespace flags and labels.
len_df = generation_df[["tokens", "trailing_whitespace", "labels"]].applymap(len)
n_tokens = len_df["tokens"]
mask = (n_tokens == len_df["trailing_whitespace"]) & (n_tokens == len_df["labels"])
# The product of the 0/1 flags is 1 only when every row passes.
assert mask.astype(int).agg("prod") == 1

### Analysis of generated data

In [None]:
# Inspect the first generated example: its label set and a rendered view.
row = generation_df.iloc[0]
tokens, trailing_whitespace, labels = row["tokens"], row["trailing_whitespace"], row["labels"]
print(set(labels))
html = visualize_ents(tokens, trailing_whitespace, labels)
display(HTML(html))

In [None]:
print("Number of entities present in generated text")

ents_presence = generation_df["ents_present_in_generated_text"]

# How often each label occurs as a key across all rows.
total_ents = (
    ents_presence
    .apply(lambda flags: [label for label, _ in flags.items()])
    .explode()
    .value_counts()
    .to_dict()
)
# How often each label's flag is True.
present_ents = (
    ents_presence
    .apply(lambda flags: [label for label, found in flags.items() if found == True])
    .explode()
    .value_counts()
    .to_dict()
)

for label, total in total_ents.items():
    print(f"{label} -> {present_ents.get(label, 0)} out of {total} present")

In [None]:
print("Most popular enity combinations")
# Count how often each set of present labels occurs. The labels are collected
# into tuples rather than lists: value_counts groups values through hashing
# and lists are unhashable. NOTE(review): whether list cells actually raise
# depends on the pandas version — tuples are safe everywhere.
generation_df["ents_present_in_generated_text"].apply(
    lambda flags: tuple(label for label, found in flags.items() if found == True)
).value_counts()

In [None]:
# Flatten the per-row {label: [positions]} mappings into one record per
# (label, position) pair.
all_ent_pos = [
    {"ent": label, "pos": pos}
    for label2pos in generation_df["label2position"]
    for label, positions in label2pos.items()
    for pos in positions
]

ents_pos_distr_df = pd.DataFrame.from_records(all_ent_pos)

# Alternative view, one panel per entity:
# sns.displot(ents_pos_distr_df, x="pos", col="ent")
sns.displot(ents_pos_distr_df, x="pos", hue="ent", multiple="stack")

# Log generated data

In [None]:
def add_visualization(row):
    """Attach a wandb HTML rendering of the row's entities.

    Reads ``tokens``, ``trailing_whitespace`` and ``labels`` from ``row``,
    stores the result under ``"vizualization"`` and returns the same row.
    """
    rendered = visualize_ents(row["tokens"], row["trailing_whitespace"], row["labels"])
    row["vizualization"] = wandb.Html(rendered)
    return row

# Add the HTML visualization to every row and drop the raw presence dicts,
# which are re-added below as one column per key.
log_df = generation_df.apply(add_visualization, axis=1)
log_df = log_df.drop(columns=["ents_present_in_generated_text"])

ents_present_in_generated_text = pd.json_normalize(generation_df["ents_present_in_generated_text"])
# concat(axis=1) aligns on index labels, while json_normalize may hand back a
# fresh 0..n-1 index. Reuse generation_df's index so the rows line up even if
# the frame was filtered or re-indexed upstream.
ents_present_in_generated_text.index = generation_df.index

log_df = pd.concat([log_df, ents_present_in_generated_text], axis=1)
log_df.head(2)

In [None]:
# Store the generated examples as a table in the summary of a new wandb run.
wandb_run_kwargs = {"name": cfg.wandb.run_name, "job_type": cfg.wandb.job_type}
with wandb.init(**wandb_run_kwargs) as run:
    run.summary["generated_texts"] = wandb.Table(dataframe=log_df)