<a href="https://colab.research.google.com/github/yiw008/db-cleansing-xgen/blob/main/xgen_7b_data_cleansing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# XGen-7B Data Cleansing

## Set Up

In [None]:
# Show the GPU(s) visible to this runtime before the model is loaded
!nvidia-smi

In [None]:
# Install packages
# NOTE(review): no versions are pinned, so a fresh run may install different
# releases than the ones that produced the outputs recorded in this notebook.
!pip install transformers accelerate bitsandbytes
# presumably required by the XGen tokenizer's remote code — TODO confirm
!pip install tiktoken
# github-clone provides the `ghclone` command used on the next line
!pip install github-clone
# Download the benchmark datasets; later cells read them from
# /content/data_clean_datasets/datasets/...
!ghclone https://github.com/deweydbb/data_clean_datasets/tree/main

In [None]:
import ast
import csv
import random
import warnings
from datetime import datetime

import numpy as np
import torch
from pytz import timezone
from sklearn.model_selection import train_test_split
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# NOTE(review): this silences every Python warning for the rest of the session,
# which can also hide useful deprecation or numerical warnings.
warnings.filterwarnings('ignore')
# Time zone name passed to pytz.timezone() for the progress timestamps printed below
current_timezone = 'America/Chicago'

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
# model_generation() moves the tokenized prompt to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Hugging Face Hub id shared by the tokenizer and model loaded in the next cells
checkpoint = "Salesforce/xgen-7b-8k-inst"

In [None]:
# trust_remote_code=True downloads and executes the tokenizer implementation
# (tokenization_xgen.py) from the Hub repository.
# NOTE(review): consider pinning a `revision` so the executed code cannot change between runs.
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    trust_remote_code=True
)

tokenizer_config.json:   0%|          | 0.00/297 [00:00<?, ?B/s]

tokenization_xgen.py:   0%|          | 0.00/8.85k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/Salesforce/xgen-7b-8k-inst:
- tokenization_xgen.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


In [None]:
# Only request 8-bit quantization when a GPU is present (same condition as before).
# The settings are passed through a BitsAndBytesConfig because the bare
# `load_in_8bit=` keyword is deprecated in transformers.
quantization_config = (
    BitsAndBytesConfig(load_in_8bit=True) if torch.cuda.is_available() else None
)

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    device_map="auto",
)

config.json:   0%|          | 0.00/510 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


pytorch_model.bin.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

pytorch_model-00001-of-00003.bin:   0%|          | 0.00/9.95G [00:00<?, ?B/s]

pytorch_model-00002-of-00003.bin:   0%|          | 0.00/9.96G [00:00<?, ?B/s]

pytorch_model-00003-of-00003.bin:   0%|          | 0.00/7.68G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]

## Helper functions

In [None]:
def transfer_csv_to_dict(csv_file_name):
  """Read a CSV file into a list of row dictionaries.

  Args:
    csv_file_name: Path of the CSV file. The first line is used as the header.

  Returns:
    A list with one dict per data row, mapping column name to the cell's
    string value (as produced by ``csv.DictReader``).
  """
  # newline='' hands newline handling to the csv module, as its documentation
  # requires; otherwise line breaks inside quoted fields are translated.
  with open(csv_file_name, 'r', newline='') as file:
    return list(csv.DictReader(file))

In [None]:
def refine_output(output):
  """Extract the first ``{...}`` span from raw model output.

  The span runs from the first "{" to the first "}" that follows it
  (both inclusive).

  Args:
    output: Text generated by the model.

  Returns:
    ``(True, span)`` when such a span exists, otherwise ``(False, output)``
    with the text unchanged.
  """
  start = output.find("{")
  if start == -1:
    return False, output

  # Search for the closing brace only after the opening one, so a stray "}"
  # earlier in the text cannot produce an empty or reversed span.
  end = output.find("}", start)
  if end == -1:
    return False, output

  return True, output[start:end + 1]

In [None]:
def model_generation(prompt):
  """Generate a completion for `prompt` and return only the assistant's answer.

  Uses the notebook-level `tokenizer`, `model` and `device`. Generation is
  sampled (do_sample=True), so repeated calls can return different text.

  Args:
    prompt: Full prompt text, including the chat header and "### Human:" turn.

  Returns:
    The text after the last "Assistant:" marker in the newly generated tokens,
    cut at the first "<|endoftext|>" if that marker survives decoding.
  """
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
  prompt_length = inputs["input_ids"].shape[-1]

  # https://huggingface.co/docs/transformers/main_classes/text_generation
  generated_ids = model.generate(
      **inputs,
      max_length=4096,  # upper bound on prompt + generated tokens
      pad_token_id=50256,
      do_sample=True,
      top_p=0.95,
      top_k=50,   #100
      temperature=0.7,
  )

  # generate() returns the prompt tokens followed by the continuation. Decode
  # only the continuation: the prompt itself contains the dirty row's "{...}",
  # which would otherwise be parsed as the answer whenever the model fails to
  # emit an "Assistant:" marker.
  new_tokens = generated_ids[0][prompt_length:]
  output = tokenizer.decode(new_tokens, skip_special_tokens=True).lstrip()

  # The prompt ends with "###", so the continuation normally starts with
  # " Assistant:" and ends with <|endoftext|>. Keep just the answer.
  output = output.split("Assistant:")[-1]
  output = output.split("<|endoftext|>")[0]

  return output

In [None]:
def preprocessing(database_name, dirty_row, index):
  """Ask the model to cleanse one whole row and return the parsed result.

  Args:
    database_name: Name of the dataset, inserted into the prompt.
    dirty_row: Dict holding the dirty row (column name -> value).
    index: Row index, used only in the failure message.

  Returns:
    The dict parsed from the model's answer, or `dirty_row` unchanged when
    the answer contains no "{...}" span or does not parse to a dict.
  """

  # Taken from Model card
  header = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed and polite answers to the human's questions.\n\n"
  )

  prompt = (
    "### Human: You are a database engineer. You need to cleanse a row of dirty data. \n\n"
    "The database name is: " + database_name + "\n\n"
    "Try to cleanse the following row of dirty data in a dictionary form: \n\n" + str(dirty_row) + "\n\n"
    "Return the updated dictionary ONLY with possible error fixed in one line! Your response should be in a valid dictionary form. All the values in the dictionary should be in a string form. "
    "Do not add or delete any attribute to the dictionary. There should be no lists in the dictionary. "
    "Your updated dictionary is:\n\n###"
  )

  output = model_generation(header + prompt)
  refine_status, output = refine_output(output)

  if not refine_status:
    print("Fail to return in a dictionary form at index", index, "Output:", output)
    return dirty_row

  # transfer dict str to dict
  try:
    res = ast.literal_eval(output)
  except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
    # Only the errors literal_eval raises on malformed input are caught, so
    # that e.g. KeyboardInterrupt can still stop a long-running loop.
    res = None

  # A "{...}" span can also evaluate to a set (e.g. "{'a'}"); anything that is
  # not a dict is treated as a failed answer so callers always get a dict.
  if not isinstance(res, dict):
    print("Fail to return in a dictionary form at index", index, "Output:", output)
    return dirty_row

  return res

In [None]:
def deduce_sub_dict(dictionary, selected_keys):
  """Return a new dict restricted to the requested keys.

  Keys are taken in the order given by `selected_keys`; keys that are absent
  from `dictionary` are skipped silently.
  """
  return {name: dictionary[name] for name in selected_keys if name in dictionary}

In [None]:
def preprocessing_few_shot(database_name, dirty_row, selected_keys, attribute_name, index):
  """Ask the model to cleanse a single attribute using a subset of the row.

  The row is first reduced to `selected_keys`, and the prompt tells the model
  to fix only `attribute_name` based on the other selected attributes.
  (The prompt built here contains no worked examples.)

  Args:
    database_name: Name of the dataset, inserted into the prompt.
    dirty_row: Dict holding the full dirty row.
    selected_keys: Column names shown to the model.
    attribute_name: The column the model is asked to cleanse.
    index: Row index, used only in the failure message.

  Returns:
    The dict parsed from the model's answer, or the reduced dirty row when
    the answer contains no "{...}" span or does not parse to a dict.
  """

  # Taken from Model card
  header = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed and polite answers to the human's questions.\n\n"
  )

  dirty_row = deduce_sub_dict(dirty_row, selected_keys)

  prompt = (
    "### Human: You are a database engineer. You need to cleanse a row of dirty data. \n\n"
    "The database name is: " + database_name + "\n\n"
    "You will be given a row of dirty data in a dictionary form with attributes " + str(selected_keys) + ". "
    "According to the values of attributes except '" + attribute_name + "', only cleanse the value of '" + attribute_name + "'. \n\n"
    "The row of dirty data is: \n\n" + str(dirty_row) + "\n\n"
    "Return the updated dictionary ONLY with possible error at attribute '" + attribute_name + "' fixed in one line! Your response should be in a valid dictionary form. All the values in the dictionary should be in a string form. "
    "Do not add or delete any attribute to the dictionary. There should be no lists in the dictionary. "
    "Your updated dictionary is:\n\n###"
  )

  output = model_generation(header + prompt)
  refine_status, output = refine_output(output)

  if not refine_status:
    print("Fail to return in a dictionary form at index", index, "Output:", output)
    return dirty_row

  # transfer dict str to dict
  try:
    res = ast.literal_eval(output)
  except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
    # Only the errors literal_eval raises on malformed input are caught, so
    # that e.g. KeyboardInterrupt can still stop a long-running loop.
    res = None

  # A "{...}" span can also evaluate to a set (e.g. "{'a'}"); anything that is
  # not a dict is treated as a failed answer so callers always get a dict.
  if not isinstance(res, dict):
    print("Fail to return in a dictionary form at index", index, "Output:", output)
    return dirty_row

  return res

In [None]:
def evaluate(dirty_dict, clean_dict, preprocessed_dict):
  """Compare cleansed rows against the dirty and clean data and print metrics.

  A cell is an "error" when its dirty value differs from the clean value, a
  "repair" when the preprocessed value differs from the dirty value, and a
  "correct repair" when an error cell's preprocessed value equals the clean
  value. Cells whose key is missing from the preprocessed row still count as
  errors but are otherwise ignored.

  Prints up to 75 correct repairs, the first wrong repair and the first
  error left unrepaired, followed by the counts, precision, recall, detection
  rate and F1 score. Ratios with a zero denominator are reported as 0.0
  instead of raising ZeroDivisionError.

  Args:
    dirty_dict: List of dirty rows (dicts).
    clean_dict: List of ground-truth rows, aligned by index with `dirty_dict`.
    preprocessed_dict: List of model-cleansed rows, aligned by index.
  """
  max_correct_repairs_shown = 75

  number_of_errors = 0
  detected_errors = 0
  repairs_performed = 0
  correct_repairs = 0
  correct_cells = 0
  wrong_repair_found = False
  not_performed_repair_found = False

  def report(title, i, key):
    # Print one example cell together with all three versions of its row.
    print(title, i)
    print("Key:", key)
    print("Dirty row:", dirty_dict[i])
    print("Clean row:", clean_dict[i])
    print("Preprocessed row:", preprocessed_dict[i])

  def ratio(numerator, denominator):
    # Undefined ratios (nothing to divide by) are reported as 0.0.
    return numerator / denominator if denominator else 0.0

  for i in range(len(dirty_dict)):
    for key in dirty_dict[i].keys():
      dirty_value = dirty_dict[i][key]
      clean_value = clean_dict[i][key]
      is_error = dirty_value != clean_value
      if is_error:
        number_of_errors += 1

      if key not in preprocessed_dict[i]:
        continue

      preprocessed_value = preprocessed_dict[i][key]
      was_changed = dirty_value != preprocessed_value
      if was_changed:
        repairs_performed += 1
      if preprocessed_value == clean_value:
        correct_cells += 1

      if is_error:
        if was_changed:
          detected_errors += 1
          if preprocessed_value == clean_value:
            correct_repairs += 1
            if correct_repairs <= max_correct_repairs_shown:
              report("Correct repair found at index", i, key)
          elif not wrong_repair_found:
            report("Wrong repair found at index", i, key)
            wrong_repair_found = True
        elif not not_performed_repair_found:
          report("Not performed repair found at index", i, key)
          not_performed_repair_found = True
      elif was_changed and not wrong_repair_found:
        # A clean cell was modified: also a wrong repair.
        report("Wrong repair found at index", i, key)
        wrong_repair_found = True

  precision = ratio(correct_repairs, repairs_performed)
  recall = ratio(correct_repairs, number_of_errors)
  percentage_of_errors_detected = ratio(detected_errors, number_of_errors)

  print("Number of errors:", number_of_errors)
  print("Repairs performed:", repairs_performed)
  print("Correct repairs:", correct_repairs)
  print("Correct cells:", correct_cells)
  print("Detected errors:", detected_errors)
  print("Precision:", precision)
  print("Recall:", recall)
  print("Percentage of errors detected:", percentage_of_errors_detected)

  f1_score = ratio(2 * (precision * recall), precision + recall)
  print("F1 Score:", f1_score)

## Hospital

In [None]:
# Print the current time in `current_timezone` to mark the start of the Hospital run
print(datetime.now(timezone(current_timezone)))

2024-04-28 16:55:21.334445-05:00


In [None]:
# Load the dirty and ground-truth Hospital rows (each a list of dicts, one per CSV row)
dirty_dict = transfer_csv_to_dict('/content/data_clean_datasets/datasets/hospital/hospital_dirty.csv')
clean_dict = transfer_csv_to_dict('/content/data_clean_datasets/datasets/hospital/hospital_clean_rows.csv')
# Filled by the cleansing loop in the next cell; re-run this cell first to reset it
preprocessed_dict = []

In [None]:
# Cleanse every Hospital row with the model, one generation call per row.
# NOTE(review): this appends to a list created in the previous cell, so re-running
# only this cell adds duplicate entries and misaligns indices with dirty_dict.
for i in range(len(dirty_dict)):
  preprocessed_dict.append(preprocessing('Hospital', dirty_dict[i], i))
  # Progress log with a timestamp every 50 rows
  if i % 50 == 0:
    print(i, datetime.now(timezone(current_timezone)))

0 2024-04-28 16:56:04.607982-05:00
50 2024-04-28 17:27:19.239893-05:00
100 2024-04-28 18:00:19.290709-05:00
150 2024-04-28 18:30:36.247458-05:00
200 2024-04-28 19:03:35.501336-05:00
250 2024-04-28 19:38:51.543606-05:00
Fail to return in a dictionary form at index 283 Output:  {'ProviderNumber': '10016', 'HospitalName': 'shelby baptist medical center', 'Address1': '1000 first street north', 'Address2': '', 'Address3': '', 'City': 'alabaster', 'State': 'al', 'ZipCode': '3x007', 'CountyName': 'shelby', 'PhoneNumber': '205.628.81

300 2024-04-28 20:11:19.767266-05:00
Fail to return in a dictionary form at index 350 Output: {'ProviderNumber': '10056', 'HospitalName': 'St Vincent's Hospital', 'Address1': '810 St Vincent's Drive', 'City': 'Birmingham', 'State': 'AL', 'ZipCode': '35205', 'CountyName': 'Jefferson', 'PhoneNumber': '205-937-7000', 'HospitalType': 'Acute Care Hospitals', 'HospitalOwner': 'Voluntary Non-Profit - Other', 'EmergencyService': 'Yes', 'Condition': 'Heart Attack', 'Measu

In [None]:
# Print the current time once the Hospital cleansing loop has finished
print(datetime.now(timezone(current_timezone)))

2024-04-29 03:52:16.128166-05:00


In [None]:
# Print example repairs and the precision / recall / F1 metrics for the Hospital run
evaluate(dirty_dict, clean_dict, preprocessed_dict)

Wrong repair found at index 0
Key: PhoneNumber
Dirty row: {'ProviderNumber': '10018', 'HospitalName': 'callahan eye foundation hospital', 'Address1': '1720 university blvd', 'Address2': '', 'Address3': '', 'City': 'birmingham', 'State': 'al', 'ZipCode': '35233', 'CountyName': 'jefferson', 'PhoneNumber': '2053258100', 'HospitalType': 'acute care hospitals', 'HospitalOwner': 'voluntary non-profit - private', 'EmergencyService': 'yes', 'Condition': 'surgical infection prevention', 'MeasureCode': 'scip-card-2', 'MeasureName': 'surgery patients who were taking heart drugs caxxed beta bxockers before coming to the hospitax who were kept on the beta bxockers during the period just before and after their surgery', 'Score': '', 'Sample': '', 'Stateavg': 'al_scip-card-2'}
Clean row: {'ProviderNumber': '10018', 'HospitalName': 'callahan eye foundation hospital', 'Address1': '1720 university blvd', 'Address2': '', 'Address3': '', 'City': 'birmingham', 'State': 'al', 'ZipCode': '35233', 'CountyName

In [None]:
# Persist the cleansed rows so the City-focused section can reload them later.
# The list of dicts is reloaded there with np.load(..., allow_pickle=True).
np.save('/content/hospital_preprocessed_dict.npy', preprocessed_dict)

## Hospital (Few-shot on City)

In [None]:
# Only use this code block if you are using Google Colab.
# If you are using Jupyter Notebook, skip this cell and upload the file directly
# to the notebook's file system instead.
from google.colab import files

# Running this cell prompts you to select a local file: click "Choose Files",
# then select and upload the file.
# Wait until the upload reaches 100%; Colab shows the file name once it is done.

# Upload hospital_preprocessed_dict.npy if you need to.
uploaded = files.upload()

In [None]:
# Print the current time to mark the start of the City-focused Hospital run
print(datetime.now(timezone(current_timezone)))

2024-04-29 04:10:57.766952-05:00


In [None]:
# Reload the Hospital data so this section can run independently of the previous one
dirty_dict = transfer_csv_to_dict('/content/data_clean_datasets/datasets/hospital/hospital_dirty.csv')
clean_dict = transfer_csv_to_dict('/content/data_clean_datasets/datasets/hospital/hospital_clean_rows.csv')
# Filled by the loop in the next cell with the model's answer for each reduced row
preprocessed_dict_selected = []
# Columns shown to the model as context for the attribute being cleansed
selected_keys = ['HospitalName', 'Address1', 'City', 'State', 'ZipCode', 'CountyName']
# The single column the model is asked to fix
attribute_name = 'City'

In [None]:
# Cleanse only `attribute_name` for every Hospital row, using the reduced set of columns.
# NOTE(review): this appends to a list created in the previous cell, so re-running
# only this cell adds duplicate entries and misaligns indices with dirty_dict.
for i in range(len(dirty_dict)):
  preprocessed_dict_selected.append(preprocessing_few_shot('Hospital', dirty_dict[i], selected_keys, attribute_name, i))
  # Progress log with a timestamp every 50 rows
  if i % 50 == 0:
    print(i, datetime.now(timezone(current_timezone)))

0 2024-04-29 04:11:16.270671-05:00
50 2024-04-29 04:25:14.275911-05:00
100 2024-04-29 04:38:47.655431-05:00
150 2024-04-29 04:50:37.647026-05:00
200 2024-04-29 05:03:58.340658-05:00
250 2024-04-29 05:17:59.213628-05:00
300 2024-04-29 05:31:53.756526-05:00
350 2024-04-29 05:43:38.269615-05:00
400 2024-04-29 05:57:33.012667-05:00
450 2024-04-29 06:11:24.825664-05:00
500 2024-04-29 06:23:30.671584-05:00
550 2024-04-29 06:36:20.159043-05:00
600 2024-04-29 06:48:50.288443-05:00
650 2024-04-29 07:01:00.729594-05:00
700 2024-04-29 07:13:19.464592-05:00
750 2024-04-29 07:25:19.760747-05:00
800 2024-04-29 07:38:39.895862-05:00
850 2024-04-29 07:50:52.507701-05:00
900 2024-04-29 08:03:16.993523-05:00
950 2024-04-29 08:16:12.745019-05:00


In [None]:
# Print the current time once the City-focused cleansing loop has finished
print(datetime.now(timezone(current_timezone)))

2024-04-29 08:28:18.495951-05:00


In [None]:
# Reload the full-row results saved earlier and overwrite their 'City' value with
# the City-focused answer wherever the model returned one.
# NOTE(review): allow_pickle=True unpickles the file, which can execute arbitrary
# code — only load a .npy file that this notebook produced.
preprocessed_dict = np.load('/content/hospital_preprocessed_dict.npy', allow_pickle=True).tolist()
for i in range(len(preprocessed_dict)):
  if 'City' in preprocessed_dict_selected[i]:
    preprocessed_dict[i]['City'] = preprocessed_dict_selected[i]['City']

In [None]:
# Print example repairs and metrics for the Hospital rows after merging the City answers
evaluate(dirty_dict, clean_dict, preprocessed_dict)

Wrong repair found at index 0
Key: PhoneNumber
Dirty row: {'ProviderNumber': '10018', 'HospitalName': 'callahan eye foundation hospital', 'Address1': '1720 university blvd', 'Address2': '', 'Address3': '', 'City': 'birmingham', 'State': 'al', 'ZipCode': '35233', 'CountyName': 'jefferson', 'PhoneNumber': '2053258100', 'HospitalType': 'acute care hospitals', 'HospitalOwner': 'voluntary non-profit - private', 'EmergencyService': 'yes', 'Condition': 'surgical infection prevention', 'MeasureCode': 'scip-card-2', 'MeasureName': 'surgery patients who were taking heart drugs caxxed beta bxockers before coming to the hospitax who were kept on the beta bxockers during the period just before and after their surgery', 'Score': '', 'Sample': '', 'Stateavg': 'al_scip-card-2'}
Clean row: {'ProviderNumber': '10018', 'HospitalName': 'callahan eye foundation hospital', 'Address1': '1720 university blvd', 'Address2': '', 'Address3': '', 'City': 'birmingham', 'State': 'al', 'ZipCode': '35233', 'CountyName

In [None]:
# Persist the merged Hospital results (full-row answers with the City value replaced)
np.save('/content/hospital_preprocessed_dict_with_few_shot_on_city.npy', preprocessed_dict)

## Adults

In [None]:
# Print the current time in `current_timezone` to mark the start of the Adults run
print(datetime.now(timezone(current_timezone)))

2024-04-29 16:09:34.308221-05:00


In [None]:
# Load the dirty and ground-truth Adults rows (each a list of dicts, one per CSV row)
dirty_dict = transfer_csv_to_dict('/content/data_clean_datasets/datasets/adults/adults_dirty.csv')
clean_dict = transfer_csv_to_dict('/content/data_clean_datasets/datasets/adults/adults_clean.csv')
# Filled by the cleansing loop in the next cell; re-run this cell first to reset it
preprocessed_dict = []

In [None]:
# Cleanse every Adults row with the model, one generation call per row.
# NOTE(review): this appends to a list created in the previous cell, so re-running
# only this cell adds duplicate entries and misaligns indices with dirty_dict.
for i in range(len(dirty_dict)):
  preprocessed_dict.append(preprocessing('Adults', dirty_dict[i], i))
  # Progress log with a timestamp every 50 rows
  if i % 50 == 0:
    print(i, datetime.now(timezone(current_timezone)))

0 2024-04-29 16:10:04.130871-05:00
50 2024-04-29 16:31:23.104498-05:00
100 2024-04-29 16:54:34.655082-05:00
150 2024-04-29 17:15:47.488474-05:00
200 2024-04-29 17:39:27.820875-05:00
250 2024-04-29 18:01:52.583849-05:00
300 2024-04-29 18:26:47.703096-05:00
350 2024-04-29 18:50:39.845744-05:00
400 2024-04-29 19:16:00.762622-05:00
450 2024-04-29 19:37:04.732521-05:00
500 2024-04-29 20:02:35.890639-05:00
550 2024-04-29 20:23:24.587574-05:00
600 2024-04-29 20:46:46.603774-05:00
650 2024-04-29 21:08:30.305839-05:00
700 2024-04-29 21:31:26.884791-05:00
750 2024-04-29 21:53:56.925844-05:00
800 2024-04-29 22:15:03.936593-05:00
850 2024-04-29 22:38:06.459982-05:00
900 2024-04-29 23:01:14.042935-05:00
950 2024-04-29 23:23:39.996718-05:00


In [None]:
# Print the current time once the Adults cleansing loop has finished
print(datetime.now(timezone(current_timezone)))

2024-04-29 23:48:24.863047-05:00


In [None]:
# Print example repairs and the precision / recall / F1 metrics for the Adults run
evaluate(dirty_dict, clean_dict, preprocessed_dict)

Wrong repair found at index 1
Key: age
Dirty row: {'row_id': '1', 'age': '>50', 'workclass': 'Private', 'education': 'HS-grad', 'maritalstatus': 'Married-civ-spouse', 'occupation': 'Craft-repair', 'relationship': 'Husband', 'race': 'White', 'sex': 'Male', 'hoursperweek': '16', 'country': 'United-States', 'income': 'LessThan50K'}
Clean row: {'row_id': '1', 'age': '>50', 'workclass': 'Private', 'education': 'HS-grad', 'maritalstatus': 'Married-civ-spouse', 'occupation': 'Craft-repair', 'relationship': 'Husband', 'race': 'White', 'sex': 'Male', 'hoursperweek': '16', 'country': 'United-States', 'income': 'LessThan50K'}
Preprocessed row: {'row_id': '1', 'age': '50', 'workclass': 'Private', 'education': 'HS-grad', 'maritalstatus': 'Married-civ-spouse', 'occupation': 'Craft-repair', 'relationship': 'Husband', 'race': 'White', 'sex': 'Male', 'hoursperweek': '16', 'country': 'United-States', 'income': 'LessThan50K'}
Not performed repair found at index 2
Key: country
Dirty row: {'row_id': '2', '

In [None]:
# Persist the cleansed rows so the occupation-focused section can reload them later.
# The list of dicts is reloaded there with np.load(..., allow_pickle=True).
np.save('/content/adults_preprocessed_dict.npy', preprocessed_dict)

## Adults (Few-shot on Occupation)

In [None]:
# Only use this code block if you are using Google Colab.
# If you are using Jupyter Notebook, skip this cell and upload the file directly
# to the notebook's file system instead.
from google.colab import files

# Running this cell prompts you to select a local file: click "Choose Files",
# then select and upload the file.
# Wait until the upload reaches 100%; Colab shows the file name once it is done.

# Upload adults_preprocessed_dict.npy if you need to.
uploaded = files.upload()

In [None]:
# Print the current time to mark the start of the occupation-focused Adults run
print(datetime.now(timezone(current_timezone)))

2024-04-30 19:31:18.232240-05:00


In [None]:
# Reload the Adults data so this section can run independently of the previous one
dirty_dict = transfer_csv_to_dict('/content/data_clean_datasets/datasets/adults/adults_dirty.csv')
clean_dict = transfer_csv_to_dict('/content/data_clean_datasets/datasets/adults/adults_clean.csv')
# Filled by the loop in the next cell with the model's answer for each reduced row
preprocessed_dict_selected = []
# Columns shown to the model as context for the attribute being cleansed
selected_keys = ['age', 'workclass', 'education', 'occupation', 'hoursperweek', 'income']
# The single column the model is asked to fix
attribute_name = 'occupation'

In [None]:
# Cleanse only `attribute_name` for every Adults row, using the reduced set of columns.
# NOTE(review): this appends to a list created in the previous cell, so re-running
# only this cell adds duplicate entries and misaligns indices with dirty_dict.
for i in range(len(dirty_dict)):
  preprocessed_dict_selected.append(preprocessing_few_shot('Adults', dirty_dict[i], selected_keys, attribute_name, i))
  # Progress log with a timestamp every 50 rows
  if i % 50 == 0:
    print(i, datetime.now(timezone(current_timezone)))

0 2024-04-30 19:31:29.657368-05:00
50 2024-04-30 19:41:22.087311-05:00
100 2024-04-30 19:53:50.349483-05:00
150 2024-04-30 20:04:45.553638-05:00
200 2024-04-30 20:14:23.926975-05:00
250 2024-04-30 20:24:37.400673-05:00
300 2024-04-30 20:34:42.563272-05:00
350 2024-04-30 20:45:33.255312-05:00
400 2024-04-30 20:56:23.432926-05:00
450 2024-04-30 21:07:00.721603-05:00
500 2024-04-30 21:18:01.126005-05:00
550 2024-04-30 21:29:24.833670-05:00
600 2024-04-30 21:39:40.423125-05:00
650 2024-04-30 21:49:12.454696-05:00
700 2024-04-30 21:59:56.324557-05:00
750 2024-04-30 22:11:11.200042-05:00
800 2024-04-30 22:21:48.185994-05:00
850 2024-04-30 22:33:35.767496-05:00
900 2024-04-30 22:43:15.701518-05:00
950 2024-04-30 22:53:36.094628-05:00


In [None]:
# Print the current time once the occupation-focused cleansing loop has finished
print(datetime.now(timezone(current_timezone)))

2024-04-30 23:03:38.467623-05:00


In [None]:
# Reload the full-row results saved earlier and overwrite their 'occupation' value
# with the occupation-focused answer wherever the model returned one.
# NOTE(review): allow_pickle=True unpickles the file, which can execute arbitrary
# code — only load a .npy file that this notebook produced.
preprocessed_dict = np.load('/content/adults_preprocessed_dict.npy', allow_pickle=True).tolist()
for i in range(len(preprocessed_dict)):
  if 'occupation' in preprocessed_dict_selected[i]:
    preprocessed_dict[i]['occupation'] = preprocessed_dict_selected[i]['occupation']

In [None]:
# Print example repairs and metrics for the Adults rows after merging the occupation answers
evaluate(dirty_dict, clean_dict, preprocessed_dict)

Wrong repair found at index 0
Key: occupation
Dirty row: {'row_id': '0', 'age': '31-50', 'workclass': 'Private', 'education': 'Prof-school', 'maritalstatus': 'Never-married', 'occupation': 'Prof-specialty', 'relationship': 'Not-in-family', 'race': 'White', 'sex': 'Female', 'hoursperweek': '40', 'country': 'United-States', 'income': 'MoreThan50K'}
Clean row: {'row_id': '0', 'age': '31-50', 'workclass': 'Private', 'education': 'Prof-school', 'maritalstatus': 'Never-married', 'occupation': 'Prof-specialty', 'relationship': 'Not-in-family', 'race': 'White', 'sex': 'Female', 'hoursperweek': '40', 'country': 'United-States', 'income': 'MoreThan50K'}
Preprocessed row: {'row_id': '0', 'age': '31-50', 'workclass': 'Private', 'education': 'Prof-school', 'maritalstatus': 'Never-married', 'occupation': 'Surgeon', 'relationship': 'Not-in-family', 'race': 'White', 'sex': 'Female', 'hoursperweek': '40', 'country': 'United-States', 'income': 'MoreThan50K'}
Not performed repair found at index 2
Key: co

In [None]:
# Persist the merged Adults results (full-row answers with the occupation value replaced)
np.save('/content/adults_preprocessed_dict_with_few_shot_on_occupation.npy', preprocessed_dict)