In [4]:
from transformers import BartTokenizer, BartForConditionalGeneration
import torch

model = BartForConditionalGeneration.from_pretrained("./results_taskC_bart")
tokenizer = BartTokenizer.from_pretrained("./results_taskC_bart")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def generate_dialogue(note):
    prompt = "Note: " + note.strip() + "\n\nGenerate a doctor-patient conversation:"
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
    output = model.generate(inputs["input_ids"], max_length=512, num_beams=4, early_stopping=True)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Test example
note = """
The patient is a 45-year-old female presenting with shortness of breath, fatigue, and dizziness. Past medical history includes hypertension and mild asthma. She reports symptoms have worsened over the past 3 days.
"""
print(generate_dialogue(note))


  from .autonotebook import tqdm as notebook_tqdm
  attn_output = torch.nn.functional.scaled_dot_product_attention(


Note: The patient is a 45-year-old female presenting with shortness of breath, fatigue, and dizziness. Past medical history includes hypertension and mild asthma. She reports symptoms have worsened over the past 3 days. She's been having shortness and fatigue , fatigue , dizziness , and fatigue . . . i've been having some shortness in breath and fatigue and fatigue over the last 3 days , and she's been experiencing dizziness and fatigue for about three days . . ." , , , and , , i'm sorry , but i'm not sure if this is a new symptom or if it's just a new one , but it's something that's been going on for a while . , and it's been getting worse over the weekend , and i'm really concerned about it , so i'm going to go ahead and tell you a little bit more about it . . , so , what's going on with you ? ? . . ? ? ? , ? , , ? . , , . , ? ? ... , , ... , . ? , . . or , , or , ? ... ? , or ? , ... ? . ? . or ? ? or , or ... ? ... . ? ... or ? . ... , or . . .. ? , i mean , it's ... it's like , it

In [None]:
import re

def clean_and_format_dialogue(raw_output):
    # 1. Remove extra punctuation and artifacts
    text = raw_output.strip()

    # Remove sequences like , , , ... ? ? ? ... etc.
    text = re.sub(r'([,?.])(\s*\1)+', r'\1', text)
    text = re.sub(r'\.\.+', '.', text)
    text = re.sub(r',\s*,+', ',', text)
    text = re.sub(r'\s{2,}', ' ', text)

    # Remove filler words (optional)
    text = re.sub(r'\b(uh+|um+|like|you know)\b', '', text, flags=re.IGNORECASE)

    # Strip repeated filler at end
    text = re.sub(r'(and it gets worse\s*){2,}', 'and it gets worse.', text, flags=re.IGNORECASE)

    # Remove spacing and punctuation
    text = re.sub(r'\s*([?.!,])\s*', r'\1 ', text)
    text = text.strip()

    # 2. Split into turns
    lines = re.split(r'(?i)\b(doctor|patient):', text)
    dialogue = []
    current_speaker = None

    for chunk in lines:
        chunk = chunk.strip()
        if chunk.lower() in {"doctor", "patient"}:
            current_speaker = chunk.capitalize()
        elif current_speaker and chunk:
            dialogue.append(f"{current_speaker}: {chunk}")
            current_speaker = None  # reset speaker unless repeated

    # 3. Return formatted text
    return "\n".join(dialogue) if dialogue else text


In [None]:
raw = generate_dialogue(note)  
cleaned = clean_and_format_dialogue(raw)
print("Cleaned Dialogue:\n")
print(cleaned)


Cleaned Dialogue:

Note: The patient is a 45-year-old female presenting with shortness of breath, fatigue, and dizziness. Past medical history includes hypertension and mild asthma. She reports symptoms have worsened over the past 3 days. She's been having shortness and fatigue, fatigue, dizziness, and fatigue. i've been having some shortness in breath and fatigue and fatigue over the last 3 days, and she's been experiencing dizziness and fatigue for about three days. ", and, i'm sorry, but i'm not sure if this is a new symptom or if it's just a new one, but it's something that's been going on for a while. , and it's been getting worse over the weekend, and i'm really concerned about it, so i'm going to go ahead and tell you a little bit more about it. , so, what's going on with you? . ? , ? , ? . , . , ? . , . , . ? , . or, or, ? . ? , or? , . ? . ? . or? or, or. ? . ? . or? . , or. ? , i mean, it's. it's, it just keeps getting worse and worse, and then it just gets worse and it gets 

In [None]:
import pandas as pd
import re
from transformers import BartTokenizer, BartForConditionalGeneration
import torch

# Load test dataset
test_df = pd.read_csv("../dataset/task_b+c/data/challenge_data/clinicalnlp_taskC_test2.csv")

# Load fine-tuned BART model
model_dir = "./results_taskC_bart"
tokenizer = BartTokenizer.from_pretrained(model_dir)
model = BartForConditionalGeneration.from_pretrained(model_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Generate dialogue from note
def generate_dialogue(note):
    prompt = "Note: " + note.strip() + "\n\nGenerate a doctor-patient conversation:"
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
    output = model.generate(inputs["input_ids"], max_length=512, num_beams=4, early_stopping=True)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Clean and tag speaker turns with guaranteed newlines
def clean_and_tag_dialogue(raw_output):
    text = raw_output.strip()

    # Normalize repeated punctuation and whitespace
    text = re.sub(r'([,?.])(\s*\1)+', r'\1', text)
    text = re.sub(r'\.\.+', '.', text)
    text = re.sub(r',\s*,+', ',', text)
    text = re.sub(r'\s{2,}', ' ', text)
    text = re.sub(r'\b(uh+|um+|like|you know)\b', '', text, flags=re.IGNORECASE)
    text = re.sub(r'(and it gets worse\s*){2,}', 'and it gets worse.', text, flags=re.IGNORECASE)
    text = re.sub(r'\s*([?.!,])\s*', r'\1 ', text)
    text = text.strip()

    # Force newlines between speaker turns (even if jammed together)
    text = re.sub(r'(?i)(doctor|patient)\s*:', r'\n<\1> ', text)
    text = re.sub(r'\n+', '\n', text)  
    text = re.sub(r' +', ' ', text) 

    return text.strip()

# Run generation for all test notes
generated_dialogues = []
for note in test_df["note"]:
    raw = generate_dialogue(note)
    cleaned = clean_and_tag_dialogue(raw)
    generated_dialogues.append(cleaned)

# Save to CSV
test_df["generated_dialogue"] = generated_dialogues
test_df.to_csv("generated_test_dialogues_bart.csv", index=False)
print("Saved to generated_test_dialogues_bart.csv")


Saved to generated_test_dialogues_bart.csv
