In [None]:
!pip install -q torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
!pip install -Uq bitsandbytes==0.46.0
!pip install -Uq peft==0.15.2
!pip install -Uq trl==0.18.2
!pip install -Uq accelerate==1.7.0
!pip install -Uq datasets==3.6.0
!pip install -Uq transformers==4.52.4

In [1]:
import os
import transformers
import torch
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig, GemmaTokenizer
from dotenv import load_dotenv

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
load_dotenv()
HF_TOKEN = os.environ["HF_TOKEN"]

In [3]:
model_id = "google/gemma-2b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             quantization_config=bnb_config,
                                             device_map={"":0},
                                             token=os.environ['HF_TOKEN'])

<transformers.integrations.tensor_parallel.ParallelInterface object at 0x00000235CE683C10>
<transformers.integrations.tensor_parallel.ParallelInterface object at 0x00000235CE683C10>
<transformers.integrations.tensor_parallel.ParallelInterface object at 0x00000235CE683C10>
<transformers.integrations.tensor_parallel.ParallelInterface object at 0x00000235CE683C10>
<transformers.integrations.tensor_parallel.ParallelInterface object at 0x00000235CE683C10>
<transformers.integrations.tensor_parallel.ParallelInterface object at 0x00000235CE683C10>
<transformers.integrations.tensor_parallel.ParallelInterface object at 0x00000235CE683C10>
<transformers.integrations.tensor_parallel.ParallelInterface object at 0x00000235CE683C10>
<transformers.integrations.tensor_parallel.ParallelInterface object at 0x00000235CE683C10>
<transformers.integrations.tensor_parallel.ParallelInterface object at 0x00000235CE683C10>
<transformers.integrations.tensor_parallel.ParallelInterface object at 0x00000235CE683C10>

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 2/2 [00:05<00:00,  2.53s/it]


In [7]:
from datasets import load_dataset

dataset = load_dataset("mamachang/medical-reasoning")
dataset

DatasetDict({
    train: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 3702
    })
})

In [8]:
dataset

DatasetDict({
    train: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 3702
    })
})

In [9]:
df = dataset["train"].to_pandas()
df.head(10)

Unnamed: 0,input,instruction,output
0,Q:An 8-year-old boy is brought to the pediatri...,Please answer with one of the option in the br...,<analysis>\n\nThis is a clinical vignette desc...
1,Q:A 23-year-old man comes to the physician bec...,Please answer with one of the option in the br...,<analysis>\n\nThis is a clinical vignette desc...
2,Q:A 27-year-old man presents to the emergency ...,Please answer with one of the option in the br...,<analysis>\n\nThis is a question about a 27-ye...
3,Q:A 13-year-old girl presents with a 4-week hi...,Please answer with one of the option in the br...,<analysis>\n\nThis is a patient with signs and...
4,Q:A 53-year-old Asian woman comes to the physi...,Please answer with one of the option in the br...,<analysis>\n\nThis is a patient with symptoms ...
5,Q:A 7-year-old boy is brought to the physician...,Please answer with one of the option in the br...,<analysis>\n\nThis is a clinical vignette desc...
6,Q:A 21-year-old man comes to the military base...,Please answer with one of the option in the br...,<analysis>\n\nThis is a clinical case question...
7,Q:A 48-year-old woman presents to her primary ...,Please answer with one of the option in the br...,<analysis>\n\nThis is a question about determi...
8,Q:A 62-year-old man presents to the emergency ...,Please answer with one of the option in the br...,<analysis>\n\nThis is a patient with a history...
9,Q:A 34-year-old female presents to her primary...,Please answer with one of the option in the br...,<analysis>\n\nThis is a clinical vignette desc...


In [10]:
train_prompt_style = """
Please answer with one of the options in the bracket. Write reasoning in between <analysis></analysis>. Write the answer in between <answer></answer>.
### Question:
{}

### Response:
{}"""

In [11]:
EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN

def formatting_prompts_func(examples):
    inputs = examples["input"]
    outputs = examples["output"]
    texts = []
    for question, response in zip(inputs, outputs):
        # Remove the "Q:" prefix from the question
        question = question.replace("Q:", "")
        
        # Append the EOS token to the response if it's not already there
        if not response.endswith(tokenizer.eos_token):
            response += tokenizer.eos_token
            
        text = train_prompt_style.format(question, response)
        texts.append(text)
    return {"text": texts}

In [12]:
from datasets import load_dataset

dataset = load_dataset(
    "mamachang/medical-reasoning",
    split="train",
    trust_remote_code=True,
)
dataset = dataset.map(
    formatting_prompts_func,
    batched=True,
)
print(dataset["text"][10])


Please answer with one of the options in the bracket. Write reasoning in between <analysis></analysis>. Write the answer in between <answer></answer>.
### Question:
A research group wants to assess the relationship between childhood diet and cardiovascular disease in adulthood. A prospective cohort study of 500 children between 10 to 15 years of age is conducted in which the participants' diets are recorded for 1 year and then the patients are assessed 20 years later for the presence of cardiovascular disease. A statistically significant association is found between childhood consumption of vegetables and decreased risk of hyperlipidemia and exercise tolerance. When these findings are submitted to a scientific journal, a peer reviewer comments that the researchers did not discuss the study's validity. Which of the following additional analyses would most likely address the concerns about this study's design?? 
{'A': 'Blinding', 'B': 'Crossover', 'C': 'Matching', 'D': 'Stratification',

In [None]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

In [13]:
inference_prompt_style = """
Please answer with one of the options in the bracket. Write reasoning in between <analysis></analysis>. Write the answer in between <answer></answer>.

### Question:
{}

### Response:
<analysis>
"""

In [14]:
question = dataset[10]['input']
question = question.replace("Q:", "")

inputs = tokenizer(
    [inference_prompt_style.format(question) + tokenizer.eos_token],
    return_tensors="pt"
).to("cuda")

outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=512,
    eos_token_id=tokenizer.eos_token_id,
    use_cache=True,
)
response = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(response[0].split("### Response:")[1])


<analysis>


### Answer:
<answer>
A. Blinding
</answer>



In [15]:
[inference_prompt_style.format(question) + tokenizer.eos_token]

["\nPlease answer with one of the options in the bracket. Write reasoning in between <analysis></analysis>. Write the answer in between <answer></answer>.\n\n### Question:\nA research group wants to assess the relationship between childhood diet and cardiovascular disease in adulthood. A prospective cohort study of 500 children between 10 to 15 years of age is conducted in which the participants' diets are recorded for 1 year and then the patients are assessed 20 years later for the presence of cardiovascular disease. A statistically significant association is found between childhood consumption of vegetables and decreased risk of hyperlipidemia and exercise tolerance. When these findings are submitted to a scientific journal, a peer reviewer comments that the researchers did not discuss the study's validity. Which of the following additional analyses would most likely address the concerns about this study's design?? \n{'A': 'Blinding', 'B': 'Crossover', 'C': 'Matching', 'D': 'Stratifi

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA config
peft_config = LoraConfig(
    lora_alpha=16,                           # Scaling factor for LoRA
    lora_dropout=0.05,                       # Add slight dropout for regularization
    r=64,                                    # Rank of the LoRA update matrices
    bias="none",                             # No bias reparameterization
    task_type="CAUSAL_LM",                   # Task type: Causal Language Modeling
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # Target modules for LoRA
)

model = get_peft_model(model, peft_config)

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments


# Training Arguments
training_arguments = TrainingArguments(
    output_dir="Magistral-Medical-Reasoning",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    num_train_epochs=5,
    logging_steps=0.2,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="none"
)

In [None]:
dataset

In [None]:
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset,
    peft_config=peft_config,
    data_collator=data_collator,
)

In [None]:
import gc, torch
gc.collect()
torch.cuda.empty_cache()
model.config.use_cache = False
trainer.train()

In [None]:
question = dataset[10]['input']
question = question.replace("Q:", "")

inputs = tokenizer(
    [inference_prompt_style.format(question,) + tokenizer.eos_token],
    return_tensors="pt"
).to("cuda")

outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=512,
    eos_token_id=tokenizer.eos_token_id,
    use_cache=True,
)
response = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(response[0].split("### Response:")[1])


<analysis>
The question is asking about the validity concern raised in the peer reviewer's comment on the study. The study looked at the relationship between childhood diet and cardiovascular disease in adulthood. The peer reviewer says the researchers did not discuss the study's validity. 

To address the validity concern, we need to think of an additional analysis that would help assess whether the results are due to chance (selection bias) or an actual relationship between the two variables.

Choice A, blinding, refers to hiding treatment groups and assessing outcome blindly, which is not relevant here. 

Choice B, crossover, involves matching participants so they serve as their own control, which also does not apply. 

Choice C, matching, involves balancing covariates between groups, which does not necessarily address validity concerns. 

Choice D, stratification, involves grouping participants based on their exposure levels and then analyzing within each stratum, which helps cont

In [None]:
text = "Quote: Imagination is more,"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device) # return_tensors='pt' means torch tensor, 'tf' means tensorflow tensor

outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
text = "Quote: Imagination is more"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
os.environ["WANDB_DISABLED"] = "false"

In [None]:
lora_config = LoraConfig(
    r = 8, # rank
    target_modules = ["q_proj", "o_proj", "k_proj", "v_proj",
                      "gate_proj", "up_proj", "down_proj"],
    task_type = "CAUSAL_LM",
)

In [None]:
from datasets import load_dataset

data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

In [None]:
data['train']['quote']

In [None]:
def formatting_func(example):
    text = f"Quote: {example['quote'][0]}\nAuthor: {example['author'][0]}"
    return [text]

In [None]:
data['train']

In [None]:
data['train']['quote'][0]

In [None]:
trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=5,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)