#### **Import Libraries**

In [None]:
# Standard library imports
import os

# Deep learning libraries
import torch
from torch.utils.data import DataLoader
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding

# Text processing libraries
import tiktoken
import csv
import pandas as pd

# Utility libraries
import numpy as np
import random
import math
from tqdm import tqdm
from itertools import chain
from IPython.display import display, Markdown
import textwrap

# Custom libraries  
from llmft.train import DecoderTrainer, EarlyStopping
from llmft.metrics import compute_recall
from llmft.losses import FocalLoss
from llmft.utils import predict

# Visualization libraries
import seaborn as sns  # Assuming seaborn is installed

# NLP utility (assuming trics is a library/module)
from trics.nlp.utils import to_markdown

# Configure GPU usage and tokenizer parallelism
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Suppress warnings
import warnings
warnings.filterwarnings("ignore")

# Dataset libraries (can be grouped together)
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict

#### **Parameters**

In [None]:
# Reproducibility: seed every RNG the notebook relies on. numpy drives the
# subsampling and simulation cells; torch drives DataLoader shuffling
# (shuffle=True on the training loader) and, presumably, adapter
# initialisation; `random` is seeded for any stdlib use.
seed = 2
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

noise = False        # if True, the first-stage values are shuffled across rows
verbose = True       # NOTE(review): not read in the visible cells
version = 5          # selects toy-data/exp2/data_{version}.csv
lr = 1e-4            # AdamW learning rate
sample_size = 300    # rows subsampled from the CSV
test_size = 0.25     # held-out fraction for train_test_split
warmup_ratio = 0.25  # fraction of all optimisation steps used for LR warmup
batch_size = 8
epochs = 30
patience = 3         # passed to EarlyStopping
gamma = 0.0          # NOTE(review): not read in the visible cells (FocalLoss is imported but unused)

#### **Set Up Paths**

In [None]:
# Path to the synthetic eviction data set selected by `version`.
data_csv = f'./../../../toy-data/exp2/data_{version}.csv'
df = pd.read_csv(data_csv)
# Cell output: (rows, columns) of the full, not-yet-subsampled data.
df.shape

#### **Set Up Plotting**

In [None]:
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
plt.style.use('seaborn-v0_8-dark-palette')

from matplotlib import font_manager 
locations = './../../../styles/Newsreader'
font_files = font_manager.findSystemFonts(fontpaths=locations)
print(locations)
print(font_files[0])
for f in font_files: 
    font_manager.fontManager.addfont(f)
plt.rcParams["font.family"] = "Newsreader"

#### **First Stage Function**

In [None]:
def fstage(var1, var2, var3, var4, var5, var6, var7, var8):
    """Return the first-stage value for one observation.

    Only ``var5`` is used: the result is ``1.0 - var5``. The other seven
    covariates are accepted so the signature matches the row-wise call site.
    NOTE(review): the result is later used as a Bernoulli probability, which
    assumes var5 lies in [0, 1] — confirm against the data.
    """
    complement = 1.0 - var5
    return complement

#### **Read in Data**

In [None]:
# Read in the full data set (re-read so this cell can be re-run from a clean state).
df = pd.read_csv(data_csv)

# Subsample `sample_size` rows without replacement (numpy global RNG).
indices = np.random.choice(df.index, size=sample_size, replace=False)
df = df.loc[indices].reset_index(drop=True)
n_obs = len(df)

# First stage: per-row value from fstage, used below as P(Treatment=1 | Instrument=1).
df['FStage_Value'] = df.apply(lambda row: fstage(
                                                 row['Var1'],
                                                 row['Var2'],
                                                 row['Var3'],
                                                 row['Var4'],
                                                 row['Var5'],
                                                 row['Var6'],
                                                 row['Var7'],
                                                 row['Var8']), axis=1)

# If noise: permute the first-stage values across rows, breaking their link
# to the covariates. Kept before the draws below so the numpy RNG call order
# is unchanged.
if noise:
    df['FStage_Value'] = df['FStage_Value'].sample(frac=1).reset_index(drop=True)

# Instrument: fair coin flip per row.
df['Instrument'] = np.random.binomial(n=1, p=0.5, size=n_obs)

# Both text variants, built with vectorized string concatenation rather than
# a row-wise df.apply.
treated_text = df['Description'] + " The tenant has access to a free lawyer"
control_text = df['Description'] + " The tenant does not have access to a free lawyer"

# Text + observed instrument
df['FullDescription'] = np.where(df['Instrument'] == 1, treated_text, control_text)

# Text + Instrument == 1
df['Treated_FullDescription'] = treated_text

# Text + Instrument == 0
df['Control_FullDescription'] = control_text

# Treatment: Bernoulli(FStage_Value) when the instrument is on, otherwise 0.
df['Treatment'] = np.random.binomial(n=1, p=df['FStage_Value'] * df['Instrument'], size=n_obs)

# Outcome: treatment plus Gaussian noise with standard deviation 0.1.
df['Outcome'] = df['Treatment'] + 0.1 * np.random.normal(size=n_obs)

#### **Plot**

In [None]:
# Histogram of the first-stage values (probability of treatment given the instrument).
fig = plt.figure(dpi=300, tight_layout=True, figsize=(7, 4.5))
ax = plt.axes(facecolor=(.95, .96, .97))

# Minimal frame: keep only the bottom spine and draw white horizontal gridlines
# behind the bars.
for spine_name in ('left', 'right', 'top'):
    ax.spines[spine_name].set_visible(False)
ax.text(0., 1.02, s='Count', transform=ax.transAxes, size=14)
ax.yaxis.set_tick_params(length=0)
ax.yaxis.grid(True, color='white', linewidth=2)
ax.set_axisbelow(True)

ax.hist(df['FStage_Value'], color='#36454F')
ax.set_xlim(0, 1)
ax.set_xlabel('Probability of Treatment Given Instrument', size=14)
plt.show()

#### **Set Up Device**

In [None]:
# Use the GPU when CUDA is available (CUDA_VISIBLE_DEVICES is pinned to "0"
# in the import cell), otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

#### **QLoRA**

In [None]:
from peft import LoraConfig, get_peft_model
from transformers import BitsAndBytesConfig

# ----- QUANTIZATION -------#
# Load the base model weights in 4-bit precision
use_4bit = True

# Dtype used for computation on top of the 4-bit weights
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Nested (double) quantization for 4-bit base models
use_nested_quant = True

# Resolve the dtype name to the torch dtype object (torch.float16 here)
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

# Mixed-precision training flags (set bf16 to True with an A100).
# NOTE(review): neither flag is read in the visible cells, and bf16=True does
# not match the float16 compute dtype above — confirm which precision is intended.
fp16 = False
bf16 = True

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

# Hint that bfloat16 is available (compute capability >= 8) while float16 is
# configured. The CUDA check keeps this cell from raising on a machine
# without a GPU, where get_device_capability() would fail.
if compute_dtype == torch.float16 and use_4bit and torch.cuda.is_available():
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)

# ----- LORA -------#
# Rank-8 adapters on the listed projection modules (Llama-style attention and
# MLP projection names); other LoraConfig settings are left at peft's defaults.
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

#### **Model**

In [None]:
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # alternative: "microsoft/phi-3-mini-4k-instruct"
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             device_map="auto",
                                             quantization_config=bnb_config,
                                             trust_remote_code=True)
# KV cache on; this is compatible with training because gradient
# checkpointing is switched off just below.
model.config.use_cache = True
model.config.pretraining_tp = 1
model.config.gradient_checkpointing = False

# Deterministic (greedy) decoding by default. Setting temperature=0.0 alone is
# not enough: if the checkpoint's generation config enables sampling
# (NOTE(review): Llama-3-Instruct's generation_config.json does — verify),
# generate() rejects a non-positive temperature. Turning sampling off and
# resetting the sampling knobs to neutral values gives the intended "temperature 0".
model.generation_config.do_sample = False
model.generation_config.temperature = 1.0
model.generation_config.top_p = 1.0
#model.enable_input_require_grads()
print(model.generation_config)

model = get_peft_model(model, lora_config)
# print_trainable_parameters() prints its own summary and returns None,
# so it is called directly rather than wrapped in print().
model.print_trainable_parameters()

#### **Tokenizer**

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Pad with the EOS token (assumes the checkpoint defines no dedicated pad token — verify).
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token


def tokenizer_function(example):
    """Tokenize a batch of chat conversations with the model's chat template.

    Intended for ``Dataset.map(..., batched=True)``, so ``example["messages"]``
    is a list of conversations, each a list of role/content dicts that already
    includes the assistant answer. Returns the tokenizer's dict output
    (``input_ids`` / ``attention_mask``). ``truncation=True`` with no
    ``max_length`` falls back to the tokenizer's default length limit.
    """
    return tokenizer.apply_chat_template(example["messages"],
                                         tokenize=True,
                                         add_generation_prompt=False,  # answer turn is already present
                                         return_dict=True,
                                         truncation=True)


# Ids of the first sub-token of the bare answers 'Yes' and 'No'; the logits
# inspection cell checks predictions against no_token_id, so it must be defined.
yes_token_id = tokenizer.encode('Yes', add_special_tokens=False)[0]
print("Token ID for 'Yes':", yes_token_id)
no_token_id = tokenizer.encode('No', add_special_tokens=False)[0]
print("Token ID for 'No':", no_token_id)

#### **Create Text Dataset**


In [None]:
# Build one system/user/assistant chat per row: the user turn carries the case
# text plus the instrument statement, the assistant turn is the observed
# treatment as 'Yes'/'No'.
messages = []
rows = zip(df['FullDescription'].values, df['Instrument'].values, df['Treatment'].values)
for description, instrument, treatment in rows:
    if instrument == 1:
        availability = 'Free legal representation is available.'
    else:
        availability = 'Free legal representation is not available.'
    user_content = f"Task: The following is a description of an eviction case. Predict whether the tenant has legal represenation (yes or no, and then explain your reasoning.)\n {description}\n {availability}"
    answer = 'Yes' if treatment == 1 else 'No'
    messages.append([
        {'role': 'system', 'content': 'You are a housing court clerk'},
        {'role': 'user', 'content': user_content},
        {'role': 'assistant', 'content': answer},
    ])

dataset = Dataset.from_dict({'messages': messages})

def get_input_id(example):
    """Set ``example['type_indicator']`` to 1 if the user turn says free legal
    representation is available (i.e. the instrument is on), else 0.

    Mutates and returns ``example`` so it can be used with ``Dataset.map``.
    """
    user_turn = example['messages'][1]['content']
    example['type_indicator'] = int('Free legal representation is available' in user_turn)
    return example

def get_target_id(example):
    """Set ``example['label']`` to 1 when the assistant answer is exactly 'Yes', else 0.

    Mutates and returns ``example`` so it can be used with ``Dataset.map``.
    """
    assistant_answer = example['messages'][2]['content']
    example['label'] = int(assistant_answer == 'Yes')
    return example

# Attach the instrument indicator and the binary label to every example.
dataset = dataset.map(get_input_id).map(get_target_id)

# Treatment rate split by instrument status. NOTE(review): type_indicator marks
# whether free representation was *offered* (the instrument), so the first line
# is the treated fraction among those offered aid, the second among those not.
label_arr = np.array(dataset['label'])
offered_arr = np.array(dataset['type_indicator'])
print(f"Fraction treated who receive legal aid: {np.mean(label_arr, where=offered_arr == 1):.3f}")
print(f"Fraction treated who don't receive legal aid: {np.mean(label_arr, where=offered_arr == 0):.3f}")

#### **Prepare Datasets for Training**

In [None]:
# Reproducible train/test split holding out `test_size` of the examples.
split_dataset = dataset.train_test_split(test_size=test_size, seed=seed)

# Tokenize the full dataset (for all_loader) and each split. The raw 'messages'
# column is dropped so only tokenizer outputs and the numeric columns remain.
tokenized_dataset = dataset.map(tokenizer_function, batched=True, remove_columns=['messages'])
tokenized_split_dataset = split_dataset.map(tokenizer_function, batched=True, remove_columns=['messages'])

#### **DataLoaders**

In [None]:
# One padding collator shared by all three loaders (it pads each batch to its
# longest sequence).
collator = DataCollatorWithPadding(tokenizer)
loader_kwargs = dict(batch_size=batch_size, collate_fn=collator)

train_loader = DataLoader(tokenized_split_dataset['train'], shuffle=True, **loader_kwargs)
test_loader = DataLoader(tokenized_split_dataset['test'], shuffle=False, **loader_kwargs)
# Full dataset in original row order (used for predictions over every example).
all_loader = DataLoader(tokenized_dataset, shuffle=False, **loader_kwargs)

#### **Optimizer and Scheduler**

In [None]:
# AdamW over all model parameters; parameters without gradients are skipped by
# the optimizer step.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# Linear warmup over the first `warmup_ratio` of all optimisation steps, then
# linear decay to zero at the final step.
steps_per_epoch = len(train_loader)
warmup_steps = int(warmup_ratio * steps_per_epoch * epochs)
total_steps = steps_per_epoch * epochs
scheduler = transformers.optimization.get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

#### **Trainer**

In [None]:
# Cross-entropy loss handed to the trainer (FocalLoss / gamma are not used here).
loss_fn = torch.nn.CrossEntropyLoss()
yuri = DecoderTrainer(
    model,
    tokenizer,
    optimizer,
    scheduler,
    loss_fn,
    device,
    verbose=False,
    threshold=10,  # NOTE(review): meaning defined in llmft.train — confirm
)

# NOTE(review): constructed here but not referenced by the training cell as written.
early_stopping = EarlyStopping(patience)

In [None]:
# Epoch-0 baselines, measured before any training step; index k of each
# history then holds the value after epoch k.
evaluation_losses = [yuri.evaluate(test_loader)]
training_losses = [yuri.evaluate(train_loader)]
lr_history = [yuri.optimizer.param_groups[0]['lr']]
recall_history = [yuri.compute_recall(test_loader)]
pbar = tqdm(range(epochs), desc=f'Epoch: 0, Train Loss: {training_losses[0]:.3f}, Val Loss: {evaluation_losses[0]:.3f}')

for epoch in pbar:
    train_loss = yuri.train(train_loader)
    training_losses.append(train_loss)
    val_loss = yuri.evaluate(test_loader)
    evaluation_losses.append(val_loss)
    recall = yuri.compute_recall(test_loader)
    recall_history.append(recall)
    # Record the current learning rate every epoch so lr_history has one entry
    # per epoch, aligned with the loss and recall histories.
    lr_history.append(yuri.optimizer.param_groups[0]['lr'])

    pbar.set_description(f'Epoch: {epoch + 1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    # TODO(review): the EarlyStopping object built from `patience` in the
    # trainer cell is never consulted, so all `epochs` epochs always run.
    # Wire it in here once the llmft EarlyStopping API is confirmed.

In [None]:
# Learning curves: index 0 is the pre-training baseline, index k is after epoch k.
for curve, curve_name in ((evaluation_losses, 'Val'), (training_losses, 'Train')):
    plt.plot(curve, label=curve_name)
plt.legend()
plt.show()

In [None]:
# Run llmft's `predict` over every example in dataset order (all_loader is not
# shuffled). NOTE(review): Dhat presumably holds predicted treatment values and
# `labels` the true labels — confirm against llmft.utils.predict.
Dhat , labels = predict(model, all_loader, device)

In [None]:
def generate_text(model, tokenizer, prompt_text, device, max_length=50, max_new_tokens=5):
    """Greedily generate a short answer for one chat example.

    Args:
        model: causal LM with a ``generate`` method (here the PEFT-wrapped model).
        tokenizer: tokenizer providing ``apply_chat_template`` and ``decode``.
        prompt_text: conversation as a list of role/content dicts. Only the
            first two messages (system + user) are used, so a stored assistant
            answer in position 2 is never shown to the model.
        device: device the encoded prompt is moved to before generation.
        max_length: unused; kept only for backward compatibility with
            existing callers. Use ``max_new_tokens`` instead.
        max_new_tokens: number of tokens to generate (default 5).

    Returns:
        The decoded first output sequence (prompt plus completion) with
        special tokens removed.

    Side effect: puts ``model`` in eval mode and leaves it there.
    """
    model.eval()

    # System and user turns only.
    conversation = prompt_text[:2]
    # add_generation_prompt=True appends the assistant header, so the model's
    # first generated token is the answer itself. The inputs must live on the
    # same device as the model.
    tokens = tokenizer.apply_chat_template(conversation,
                                           tokenize=True,
                                           add_generation_prompt=True,
                                           return_dict=True,
                                           truncation=True,
                                           return_tensors='pt').to(device)

    # Greedy decoding; sampling knobs (temperature/top_k/top_p) would be
    # ignored with do_sample=False, so they are not passed.
    output_sequences = model.generate(
        input_ids=tokens['input_ids'],
        attention_mask=tokens['attention_mask'],
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_return_sequences=1,
    )

    return tokenizer.decode(output_sequences[0], skip_special_tokens=True)

In [None]:
# Inspect the logits of the first batch of all_loader (not shuffled).
batch = next(iter(all_loader))
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
# Kept under a separate name so the `labels` returned by predict() is not overwritten.
batch_labels = batch['labels'].to(device)

# Inference only: no autograd graph is needed for a forward pass we just look at.
with torch.no_grad():
    logits = model(input_ids, attention_mask).logits
print(logits.shape)

In [None]:
# For each example in the batch, report whether the argmax prediction over the
# last 20 positions contains the 'No' token. The id is computed here so this
# cell does not rely on a variable defined elsewhere.
no_token_id = tokenizer.encode('No', add_special_tokens=False)[0]
for predicted_ids in torch.argmax(logits, dim=-1)[:, -20:]:
    print(no_token_id in predicted_ids)

In [None]:
# Decode a single token id. NOTE(review): 128009 is presumably Llama-3's
# end-of-turn token (<|eot_id|>) — verify for other checkpoints.
tokenizer.decode(128009)

In [None]:
# Greedy generation for the first example, prompted like generate_text: system
# and user turns only, followed by the assistant header. The inputs are moved
# to the model's device and the output length is bounded.
text = dataset['messages'][0]
tokens = tokenizer.apply_chat_template(text[:2], tokenize=True,
                                       add_generation_prompt=True,
                                       truncation=True,
                                       return_dict=True,
                                       return_tensors="pt").to(device)

output = model.generate(**tokens, max_new_tokens=5, do_sample=False)

In [None]:
# Display the generated text for example 0 (to_markdown presumably renders it as Markdown).
to_markdown(generate_text(model, tokenizer, dataset['messages'][0], device))

In [None]:
# Display the generated text for example 1.
to_markdown(generate_text(model, tokenizer, dataset['messages'][1], device))

In [None]:
# Display the generated text for example 2.
to_markdown(generate_text(model, tokenizer, dataset['messages'][2], device))

In [None]:
# Boolean mask over every row marking examples with label == 1. It is derived
# from the label column so its length always matches the dataset; a fixed
# range(1000) overruns a dataset of `sample_size` rows.
idx = [label == 1 for label in dataset['label']]

In [None]:
# Subset containing only the examples labelled 1 (assistant answer 'Yes').
rep = dataset.filter(lambda example: example['label'] == 1)

In [None]:
# Display the generated text for the first label == 1 example.
to_markdown(generate_text(model, tokenizer, rep['messages'][0], device))

In [None]:
# Display the generated text for the second label == 1 example.
to_markdown(generate_text(model, tokenizer, rep['messages'][1], device))

In [None]:
# Conversation (list of role/content dicts) of the first example.
x = dataset[0]['messages']