In [None]:
import random
import pandas as pd

In [None]:
# Notebook-wide settings read by later cells.
prompt = True    # when truthy, the extra explanatory sentences are appended to string1
n = 1_000        # number of rows subsampled from the CSV in the data-preparation cell
batch_size = 4   # batch size handed to the DataLoader

In [None]:
import numpy as np 
import re 
# Location of the toy dataset, relative to the notebook's working directory.
data_csv = './../../../toy-data/exp2/data_1.csv'

def remove_newlines(text):
    """Return ``text`` with every newline character removed.

    Each run of one or more ``\\n`` characters is deleted outright; nothing is
    inserted in its place, so the text on either side of a line break is joined
    directly.
    """
    newline_runs = re.compile(r'\n+')
    return newline_runs.sub('', text)

# Read the raw observations and strip newlines out of every description.
df = pd.read_csv(data_csv)
df['Description'] = df['Description'].apply(remove_newlines)


# Subsample Observations: 32 distinct rows, renumbered from 0.
indices = np.random.choice(df.index, size=32, replace=False)
df = df.loc[indices].reset_index(drop=True)

# Placeholder labels: one independent 0/1 draw per remaining row.
df['label'] = [random.randint(0, 1) for _ in range(len(df))]
df.head()

In [None]:
# Conditions that count as covered.
conditions_covered = [
    "Mobility impairments",
    "Visual impairments",
    "Hearing impairments",
    "Chronic illnesses",
    "Respiratory disorders",
    "Cardiovascular conditions",
    "Intellectual disabilities",
    "Learning disabilities",
    "Autism spectrum disorders",
    "Psychiatric disorders",
    "Traumatic brain injuries",
    "Alzheimer's disease and other dementias"
]

# Conditions that do not count as covered.
conditions_not_covered = [
    "Short-term illnesses",
    "Temporary exhaustion",
    "Behavioral Traits and Conditions",
    "Having a quick temper without any underlying medical condition",
    "Certain Addictions",
    "Substance abuse disorders",
    "good health"
]

# Full list: covered conditions first, then the non-covered ones.
disabilities = conditions_covered + conditions_not_covered

# One 0/1 flag per entry of `disabilities`: 1 when the entry is a covered condition, else 0.
severity_indicator = [int(name in conditions_covered) for name in disabilities]

def fstage(var1, var2, var3, var4, var5, var6, var7, var8, var9):
    """Compute the first-stage value for one observation.

    Only ``var5`` and ``var9`` enter the computation; the other seven arguments
    are accepted but ignored.

    Args:
        var5: Position looked up in ``severity_indicator`` (must be usable as a
            list index).
        var9: Numeric value scaled by 0.35.
            NOTE(review): presumably a 0/1 flag, since the result is later used
            as a probability - confirm against the data.

    Returns:
        ``0.35 * severity_indicator[var5] + 0.35 * var9``.
    """
    covered_term = 0.35 * severity_indicator[var5]
    var9_term = 0.35 * var9
    return covered_term + var9_term

In [None]:

def remove_newlines(text):
    """Return ``text`` with every newline character removed.

    This redefines the helper already declared in an earlier cell with the same
    behavior: each run of one or more ``\\n`` characters is deleted and nothing
    is inserted in its place.
    """
    newline_runs = re.compile(r'\n+')
    return newline_runs.sub('', text)

# Instruction placed in front of every tenant description.
string_b = 'Based on the following information, predict (Yes/No) whether the tenant will be represented in court by a lawyer.\n\n Description:'

# Notes appended after the description; which one is used depends on the instrument value.
string1 = "\n\n Note: The tenant has access to a free lawyer."
string0 = "\n\n Note: The tenant does not have access to a free lawyer."

if prompt:
    # Extra context added only to the "has access" note.
    # Spelling fixed relative to the earlier wording ("dissability", "arange",
    # "Therefor", "sugges they they", verb "follow-up"), so that the text the
    # model is trained on is clean. NOTE(review): if results must stay
    # comparable with runs that used the old wording, regenerate those runs.
    string1 += (
        " Access to a lawyer does not mean the tenant will be represented in court by a lawyer."
        " If the tenant is given access to a lawyer, they must apply for representation."
        " Because more tenants apply than can be represented, legal aid providers prioritize tenants with **vouchers** and **disabilities** when reviewing applications."
        " Providers can differ over what they consider to be a disability."
        " If a tenant's application is selected, they must follow up with the provider to arrange for legal representation."
        " Therefore it's possible that tenants with characteristics which suggest they should be prioritized remain without representation."
    )
    

# Re-read the CSV so this cell starts from the full file rather than the
# 32-row sample left in `df` by the earlier cell.
df = pd.read_csv(data_csv)
df['Description'] = df['Description'].apply(remove_newlines)


# Subsample Observations: n distinct rows, renumbered from 0.
indices = np.random.choice(df.index, size=n, replace=False)
df = df.loc[indices].reset_index(drop=True)

# Apply First Stage Function row by row, passing Var1..Var9 in order.
df['FStage_Value'] = df.apply(
    lambda row: fstage(*[row[f'Var{i}'] for i in range(1, 10)]),
    axis=1,
)

# Sample Instrumental Values: an independent fair coin flip per row.
df['Instrument'] = np.random.binomial(n=1, p=0.5, size=n)

# Text + Instrument: prefix + description + the note matching the instrument.
df['FullDescription'] = np.where(
    df['Instrument'] == 1,
    string_b + df['Description'] + string1,
    string_b + df['Description'] + string0,
)

# Draw each label as Bernoulli(FStage_Value * Instrument); rows with
# Instrument == 0 therefore always get label 0 at this point.
df['label'] = np.random.binomial(n=1, p=df['FStage_Value'] * df['Instrument'], size=n)

# sample(frac=1) returns the labels in random order and reset_index(drop=True)
# renumbers them, so the assignment places the shuffled labels on the rows by
# position: labels no longer line up with the row that generated them.
# NOTE(review): looks like a deliberate permutation baseline - confirm.
df['label'] = df['label'].sample(frac=1).reset_index(drop=True)

df.head()

In [None]:
# Frame and text column used by the tokenization cell.
data = df
var = "FullDescription"

In [None]:
from transformers import AutoTokenizer, DataCollatorWithPadding
from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader


# Checkpoint whose tokenizer is loaded. Other ids kept from the original inline
# comment: "meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-8B".
model_id = "microsoft/phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the end-of-sequence token for padding; both the id and the token string are set explicitly.
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

def tokenizer_function(example):
  """Tokenize the ``var`` field of ``example``, truncating to at most 512 tokens.

  ``var`` and ``tokenizer`` are read from the notebook's global scope.
  """
  texts = example[var]
  return tokenizer(texts, truncation=True, max_length=512)

# Build a dataset holding only the text column and its labels.
dataset = Dataset.from_dict(data[[var, 'label']])

# Tokenize in batches, then drop the raw text column so only tokenizer outputs and labels remain.
tokenized_dataset = dataset.map(tokenizer_function, batched=True).remove_columns([var])

# shuffle=False: batches are served in dataset order on every pass.
# NOTE(review): the training loop reads batch['labels'] although the column is
# named 'label' - presumably the collator renames it; verify.
data_loader = DataLoader(
    tokenized_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=DataCollatorWithPadding(tokenizer),
)


In [None]:
import torch
from transformers import Phi3ForSequenceClassification, AdamW
# Checkpoint to load; same id as the one used for the tokenizer.
model_id = "microsoft/phi-3-mini-4k-instruct"

# Load the model with a 2-class classification head, in bfloat16.
model = Phi3ForSequenceClassification.from_pretrained(
    model_id,
    num_labels=2,
    torch_dtype=torch.bfloat16,
    device_map='auto',
    attn_implementation="flash_attention_2",
)

model.gradient_checkpointing_enable()
# The cache doesn't work with gradient checkpointing, so switch it off.
model.config.use_cache = False

# Define the optimizer over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

In [None]:
# Number of full passes over data_loader (the old inline comment said 3; the loop ran 20).
num_epochs = 20

# Mean training loss of each epoch, in order.
losses = []

# The criterion holds no per-batch state, so build it once instead of once per batch.
loss_fct = torch.nn.CrossEntropyLoss()

# Drop any gradients left over from an interrupted earlier run of this cell,
# so they are not applied on the first optimizer step.
optimizer.zero_grad()

model.train()
for epoch in range(num_epochs):
    batch_loss = []
    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # Calculate the loss; .item() already returns a plain Python float.
        loss = loss_fct(logits, labels)
        batch_loss.append(loss.item())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    losses.append(sum(batch_loss)/len(data_loader))

    print(f"Epoch {epoch + 1}, Loss: {losses[-1]}")

In [None]:
import matplotlib.pyplot as plt 
# Plot the mean training loss per epoch. The x axis is 1-based to match the
# "Epoch N" lines printed by the training loop.
fig, ax = plt.subplots()
ax.plot(range(1, len(losses) + 1), losses)
ax.set(xlabel="Epoch", ylabel="Mean training loss", title="Training loss per epoch");