<a href="https://colab.research.google.com/github/safaabuzaid/segmentation-prompt-generator/blob/main/Prompt_driven.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Prompt Generator for Radiology Segmentation tasks from Synthetic Clinical Notes**

**Note:** This dataset is synthetically generated using ChatGPT for educational and demonstration purposes only. It does not represent real patient data and should not be used for clinical decision-making or real-world applications.  
The goal is to create a prompt generator that turns clinical notes into precise prompts that can later be used for segmentation tasks.

In [None]:
import pandas as pd

# Location of the synthetic notes file as used in the Colab runtime.
CLINICAL_NOTES_CSV = '/content/clinical_notes.csv'

df = pd.read_csv(CLINICAL_NOTES_CSV)
df.info()  # column dtypes and non-null counts
df.head()

In [None]:
# Count missing values per column (isna is pandas' alias of isnull).
missing_per_column = df.isna().sum()
missing_per_column

In [None]:
# Illustrative input/target formats for the model.
# NOTE(review): these template strings are not read by the tokenization step;
# the actual prefixes are built inside preprocess_function.
input_text, target_text = "Clinical Note: [note]", "Prompt: [prompt]"

In [None]:
from datasets import Dataset

# Build a Hugging Face Dataset with one row per (note, prompt) pair.
data_dict = {'note': df['note'], 'prompt': df['prompt']}

dataset = Dataset.from_dict(data_dict)

# Hold out 20% of rows for evaluation. The fixed seed makes the split, and every
# metric computed on it, reproducible across kernel restarts. Without a seed,
# each Run All trains and evaluates on a different random split.
dataset = dataset.train_test_split(test_size=0.2, seed=42)
dataset


In [None]:
# Sanity-check one row: the raw clinical note and the target prompt it maps to.
first_note, first_prompt = df.loc[0, 'note'], df.loc[0, 'prompt']
print(first_note)
print(first_prompt)

Output:

CT scan reveals a 3.2 cm irregular mass in the upper lobe of the left lung; biopsy confirms stage II adenocarcinoma.
Segment tumor in left lung based on stage II adenocarcinoma


# Preprocessing the data



In [None]:
from huggingface_hub import notebook_login
notebook_login()

from transformers import pipeline  # NOTE(review): not used anywhere in the visible notebook
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

max_input_length = 512
max_target_length = 128

def preprocess_function(examples):
    """Tokenize a batch of note/prompt pairs for seq2seq fine-tuning.

    Args:
        examples: batch dict from ``Dataset.map(batched=True)`` with list-valued
            ``"note"`` and ``"prompt"`` columns.

    Returns:
        The tokenizer output for the instruction-prefixed notes (``input_ids``,
        ``attention_mask``), plus a ``"labels"`` entry holding the tokenized
        ``"Prompt: ..."`` targets. In the labels, padding positions are replaced
        by -100 so the loss ignores them.
    """
    inputs = ["Generate a segmentation prompt from the following Clinical Note: " + note for note in examples["note"]]
    targets = ["Prompt: " + prompt for prompt in examples["prompt"]]

    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=max_input_length)
    # `text_target=` tokenizes targets and replaces the deprecated
    # `as_target_tokenizer()` context manager.
    labels = tokenizer(text_target=targets, padding="max_length", truncation=True, max_length=max_target_length)

    # The labels are padded to max_target_length with the tokenizer's pad id, which is
    # an ordinary vocabulary id, so unmasked padding would be counted in the loss.
    # -100 is the ignore index used by the model's cross-entropy loss.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True)
tokenized_dataset


Output:


```
DatasetDict({
    train: Dataset({
        features: ['note', 'prompt', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 12
    })
    test: Dataset({
        features: ['note', 'prompt', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 3
    })
})
```



# Load the Model

In [None]:
from transformers import AutoModelForSeq2SeqLM

# Same checkpoint as the tokenizer, so token ids match the model's vocabulary.
base_checkpoint = "google/flan-t5-base"
model = AutoModelForSeq2SeqLM.from_pretrained(base_checkpoint)

# Set Training Arguments

In [None]:
from transformers import Seq2SeqTrainingArguments

# Fine-tuning hyperparameters for flan-t5-base on the synthetic notes dataset.
# NOTE(review): the train split has ~12 rows, so batch size 4 x 5 epochs gives only
# ~15 optimizer steps. At 2e-5 the model may not learn the target format (the
# sample generation at the end returns just "adenocarcinoma"). T5 is commonly
# fine-tuned around 1e-4 to 3e-4 — TODO tune.
training_args = Seq2SeqTrainingArguments(
    # The string "none" disables every logging integration. Python None is not
    # equivalent: in transformers 4.x it falls back to reporting to all installed
    # integrations.
    report_to="none",
    output_dir="./finetuned-flan-t5",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=1,  # keep only the most recent checkpoint on disk
    num_train_epochs=5,
    predict_with_generate=True,  # evaluation/prediction decode with generate()
    fp16=False,
)

# Fine-Tune the Model

In [None]:


from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer

# Batches tokenized examples for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(model=model, tokenizer=tokenizer)

train_split, eval_split = tokenized_dataset["train"], tokenized_dataset["test"]

# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=` — confirm against the installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_split,
    eval_dataset=eval_split,
    tokenizer=tokenizer,
)

trainer.train()

In [None]:
# Score the fine-tuned model on the held-out test split.
eval_metrics = trainer.evaluate()
eval_metrics

In [None]:
def generate_segmentation_prompt(note, num_beams=4, max_new_tokens=50):
    """Generate a segmentation prompt for one clinical note with the fine-tuned model.

    Args:
        note: raw clinical-note text.
        num_beams: beam width passed to ``model.generate``.
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        ``(input_text, generated_text)``: the instruction-prefixed model input, and
        the decoded output with any leading "Prompt:" label removed.
    """
    # Must match the prefix used in preprocess_function, or the model sees inputs
    # formatted differently from its training data.
    input_text = "Generate a segmentation prompt from the following Clinical Note: " + note
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=512).to(model.device)

    # The model may still be in training mode (dropout active) after trainer.train()
    # if the evaluation cell was skipped. Force eval mode so generation is deterministic.
    model.eval()
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, num_beams=num_beams, early_stopping=True)
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Training targets were "Prompt: <text>"; strip that label from the output.
    label = "prompt:"
    if generated_text.lower().startswith(label):
        generated_text = generated_text[len(label):].strip()
    return input_text, generated_text

input_text, generated_text = generate_segmentation_prompt(df['note'][0])

print(input_text)
print(generated_text)

Output:

Generate a segmentation prompt from the following Clinical Note: CT scan reveals a 3.2 cm irregular mass in the upper lobe of the left lung; biopsy confirms stage II adenocarcinoma.
adenocarcinoma