In [12]:
import pandas as pd
import torch
from transformers import BioGptForCausalLM, BioGptTokenizer, DataCollatorForLanguageModeling
from datasets import Dataset
import os

# Hugging Face model identifier loaded by OperativeReportTrainer.__init__
# for both the tokenizer and the causal language model.
MODEL_NAME = "microsoft/biogpt"
# Maximum number of tokens per training example; longer prompts are truncated
# by the tokenizer call in OperativeReportTrainer.prepare_dataset.
MAX_LENGTH = 1024

class OperativeReportTrainer:
    def __init__(self, csv_path):
        self.csv_path = csv_path
        self.tokenizer = BioGptTokenizer.from_pretrained(MODEL_NAME)
        self.model = BioGptForCausalLM.from_pretrained(MODEL_NAME)
        
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.model.resize_token_embeddings(len(self.tokenizer))
        
    def load_and_clean_data(self):
        records = []
        
        try:
            with open(self.csv_path, 'r', encoding='utf-8') as f:
                content = f.read()
            
            lines = content.split('\n')
            
            for line in lines[1:]:  # Skip header
                if not line.strip() or not line.startswith('"') or 'appendectomy' not in line:
                    continue
                
                try:
                    line = line.strip()[1:-1]  # Remove outer quotes
                    
                    first_comma = line.find(',')
                    if first_comma == -1:
                        continue
                        
                    brief = line[:first_comma].strip()
                    remainder = line[first_comma+1:].strip()
                    
                    if remainder.startswith('""'):
                        remainder = remainder[2:]
                        last_quote_comma = remainder.rfind('""')
                        if last_quote_comma != -1:
                            full_report = remainder[:last_quote_comma].replace('""', '"').strip()
                            
                            if brief and full_report:
                                records.append({
                                    'brief_description': brief,
                                    'full_report': full_report,
                                    'procedure_type': 'appendectomy'
                                })
                                
                except Exception:
                    continue
            
            df = pd.DataFrame(records)
            print(f"Loaded {len(df)} operative reports")
            return df
            
        except Exception as e:
            print(f"Error reading file: {e}")
            return pd.DataFrame()
    
    def create_training_prompts(self, df):
        prompts = []
        for _, row in df.iterrows():
            prompt = f"PROCEDURE: Appendectomy\nINDICATION: {row['brief_description']}\nOPERATIVE REPORT: {row['full_report']}\n<|endoftext|>"
            prompts.append(prompt)
        return prompts
    
    def prepare_dataset(self, prompts):
        def tokenize_function(examples):
            tokenized = self.tokenizer(
                examples['text'],
                truncation=True,
                padding=False,
                max_length=MAX_LENGTH,
                return_tensors=None
            )
            tokenized['labels'] = tokenized['input_ids'].copy()
            return tokenized
        
        dataset = Dataset.from_dict({'text': prompts})
        return dataset.map(tokenize_function, batched=True, remove_columns=['text'])
    
    def train_model(self, dataset, output_dir='./operative-report-model', epochs=3, batch_size=2):
        os.makedirs(output_dir, exist_ok=True)
        
        self.model.train()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
        
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
            pad_to_multiple_of=8
        )
        
        from torch.utils.data import DataLoader
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
        
        print(f"Training {epochs} epochs, {len(dataset)} samples, {len(dataloader)} batches")
        
        for epoch in range(epochs):
            total_loss = 0
            for batch_idx, batch in enumerate(dataloader):
                optimizer.zero_grad()
                outputs = self.model(**batch)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                
                if batch_idx % 2 == 0:
                    print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx+1}/{len(dataloader)}, Loss: {loss.item():.4f}")
            
            print(f"Epoch {epoch+1} avg loss: {total_loss / len(dataloader):.4f}")
        
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        print(f"Model saved to {output_dir}")
        return self.model
    
    def generate_report(self, brief_description, model_path='./operative-report-model'):
        model = BioGptForCausalLM.from_pretrained(model_path)
        tokenizer = BioGptTokenizer.from_pretrained(model_path)
        
        prompt = f"PROCEDURE: Appendectomy\nINDICATION: {brief_description}\nOPERATIVE REPORT:"
        inputs = tokenizer.encode(prompt, return_tensors='pt')
        attention_mask = torch.ones_like(inputs)
        
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                attention_mask=attention_mask,
                max_length=min(len(inputs[0]) + 400, 1024),
                temperature=0.3,
                do_sample=True,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
                early_stopping=True
            )
        
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        report_start = generated_text.find("OPERATIVE REPORT:") + len("OPERATIVE REPORT:")
        generated_report = generated_text[report_start:].strip()
        
        return self._clean_generated_text(generated_report)
    
    def _clean_generated_text(self, text):
        lines = text.split('.')
        cleaned_lines = []
        
        for line in lines:
            line = line.strip()
            if line and len(line) > 10:
                is_repetitive = any(line in prev_line or prev_line in line for prev_line in cleaned_lines[-3:])
                if not is_repetitive:
                    cleaned_lines.append(line)
                else:
                    break
        
        return '. '.join(cleaned_lines) + '.' if cleaned_lines else text

def main():
    """Run the demo pipeline: load reports, fine-tune, then generate one sample.

    Reads ``or_reports.csv``, prints a hint and returns early when no reports
    were loaded, otherwise trains for 5 epochs with batch size 1 and prints a
    report generated for a fixed test brief.
    """
    trainer = OperativeReportTrainer('or_reports.csv')

    reports = trainer.load_and_clean_data()
    if not len(reports):
        print("No data loaded! Check CSV format.")
        return

    training_prompts = trainer.create_training_prompts(reports)
    tokenized = trainer.prepare_dataset(training_prompts)
    trainer.train_model(tokenized, epochs=5, batch_size=1)

    test_brief = "Laparoscopic appendectomy for uncomplicated appendicitis in a 25-year-old patient"
    sample = trainer.generate_report(test_brief)

    print(f"\nTest Brief: {test_brief}")
    print(f"Generated Report: {sample}")

# Run the training/generation demo only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

pytorch_model.bin:  66%|######5   | 1.03G/1.56G [00:00<?, ?B/s]

Loaded 13 operative reports


model.safetensors:   0%|          | 0.00/1.56G [00:00<?, ?B/s]

Map:   0%|          | 0/13 [00:00<?, ? examples/s]

Training 5 epochs, 13 samples, 13 batches
Epoch 1/5, Batch 1/13, Loss: 4.4639
Epoch 1/5, Batch 3/13, Loss: 3.6042
Epoch 1/5, Batch 5/13, Loss: 3.0792
Epoch 1/5, Batch 7/13, Loss: 3.3809
Epoch 1/5, Batch 9/13, Loss: 2.7540
Epoch 1/5, Batch 11/13, Loss: 3.5141
Epoch 1/5, Batch 13/13, Loss: 3.2400
Epoch 1 avg loss: 3.4005
Epoch 2/5, Batch 1/13, Loss: 2.4145
Epoch 2/5, Batch 3/13, Loss: 2.8441
Epoch 2/5, Batch 5/13, Loss: 1.7898
Epoch 2/5, Batch 7/13, Loss: 2.4411
Epoch 2/5, Batch 9/13, Loss: 3.3419
Epoch 2/5, Batch 11/13, Loss: 2.7281
Epoch 2/5, Batch 13/13, Loss: 2.0990
Epoch 2 avg loss: 2.6106
Epoch 3/5, Batch 1/13, Loss: 1.8863
Epoch 3/5, Batch 3/13, Loss: 1.8038
Epoch 3/5, Batch 5/13, Loss: 2.0136
Epoch 3/5, Batch 7/13, Loss: 2.9729
Epoch 3/5, Batch 9/13, Loss: 1.9374
Epoch 3/5, Batch 11/13, Loss: 2.6402
Epoch 3/5, Batch 13/13, Loss: 2.4974
Epoch 3 avg loss: 2.2257
Epoch 4/5, Batch 1/13, Loss: 2.8353
Epoch 4/5, Batch 3/13, Loss: 1.7007
Epoch 4/5, Batch 5/13, Loss: 2.2809
Epoch 4/5, Ba




Test Brief: Laparoscopic appendectomy for uncomplicated appendicitis in a 25-year-old patient
Generated Report: The patient was taken to the operating room and placed supine on the table. A Foley catheter was inserted, and anesthesia was induced with IV midazolam 0. 1 mg / kg followed by fentanyl 1 mcg / kg and propofol 2 mg / min. Anesthesia was maintained with N2O and O2 delivered via an LMA. After obtaining adequate surgical conditions, a 5 mm port site was created at the level of umbilicus and insufflated to 15 mmHg pressure. An EndoGIA stapler was used to create a 10 mm port in the right lower quadrant. Once pneumoperitoneum had been achieved, CO2 insufflation was performed up to 20 mmHg. Then, a Veress needle was introduced into the abdomen through the umbilical incision. Next, two graspers were passed across the appendix, and the base of the appendix was visualized. This was done while holding it within its normal position. Two other graspers then were brought out from this are

In [15]:
# Test generation without GUI
import torch
from transformers import BioGptForCausalLM, BioGptTokenizer

class SimpleReportGenerator:
    """Generate appendectomy operative reports from a fine-tuned BioGPT checkpoint."""

    # Strings whose token sequences are forbidden during generation.
    _BLOCKED_WORDS = (
        "<", ">", "endoftext", "AbstractText", "NlmCategory",
        "UNASSIGNED", "ns0:", "mml:",
    )
    # Substrings treated as corruption in decoded output; the report is cut
    # at the first one found.
    _GARBAGE_TOKENS = (
        "<", "endoftext", "AbstractText", "NlmCategory",
        "UNASSIGNED", "ns0:", "mml:", "≥", "≤",
    )

    def __init__(self, model_path='./operative-report-model'):
        """Load model and tokenizer from ``model_path``."""
        self.model = BioGptForCausalLM.from_pretrained(model_path)
        self.tokenizer = BioGptTokenizer.from_pretrained(model_path)

    def generate_report(self, brief_description, temperature=0.3, max_length=400):
        """Generate an operative report for ``brief_description``.

        Args:
            brief_description: Short indication text inserted into the prompt.
            temperature: Sampling temperature (lower is more conservative).
            max_length: Number of tokens allowed beyond the prompt; the total
                sequence length is additionally capped at 1024.

        Returns:
            The generated report, cut at the first garbage token and at the
            first repeated sentence.
        """
        prompt = f"PROCEDURE: Appendectomy\nINDICATION: {brief_description}\nOPERATIVE REPORT:"

        inputs = self.tokenizer.encode(prompt, return_tensors='pt')
        # Single un-padded sequence, so every position is a real token.
        attention_mask = torch.ones_like(inputs)

        # Drop words that encode to nothing so the list passed to generate()
        # never contains an empty entry.
        encoded_words = (
            self.tokenizer.encode(word, add_special_tokens=False)
            for word in self._BLOCKED_WORDS
        )
        bad_words_ids = [ids for ids in encoded_words if ids]

        with torch.no_grad():
            # early_stopping is a beam-search option and has no effect when
            # sampling with a single beam, so it is not passed; generation
            # stops at eos_token_id or max_length.
            outputs = self.model.generate(
                inputs,
                attention_mask=attention_mask,
                max_length=min(len(inputs[0]) + max_length, 1024),
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                bad_words_ids=bad_words_ids or None,
            )

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Isolate the report first, then look for garbage. Scanning the full
        # text would also match characters such as "<" inside the prompt's
        # brief description and truncate everything.
        generated_report = self._extract_report(generated_text)
        generated_report = self._cut_at_garbage(generated_report)

        return self._clean_repetitive_text(generated_report)

    @staticmethod
    def _extract_report(generated_text):
        """Return the text after ``OPERATIVE REPORT:``, or all of it if the marker is missing."""
        marker = "OPERATIVE REPORT:"
        position = generated_text.find(marker)
        if position == -1:
            return generated_text.strip()
        return generated_text[position + len(marker):].strip()

    @classmethod
    def _cut_at_garbage(cls, text):
        """Truncate ``text`` at the earliest occurrence of any garbage token."""
        hits = [pos for pos in (text.find(token) for token in cls._GARBAGE_TOKENS) if pos != -1]
        return text[:min(hits)].strip() if hits else text

    def _clean_repetitive_text(self, text):
        """Drop very short fragments and cut the text at the first repeated sentence.

        Sentences are split on a period followed by whitespace, so decimal
        numbers such as ``0.1`` stay intact. A sentence counts as repetitive
        when it contains, or is contained in, one of the last two kept
        sentences; everything from that point on is discarded.

        Returns:
            The kept sentences joined with ``'. '`` and terminated by ``'.'``,
            or ``text`` unchanged when no sentence is kept.
        """
        import re

        cleaned = []
        for sentence in re.split(r'\.\s+', text.strip()):
            sentence = sentence.strip().rstrip('.').strip()
            if len(sentence) <= 10:
                continue
            if any(sentence in prev or prev in sentence for prev in cleaned[-2:]):
                break
            cleaned.append(sentence)

        return '. '.join(cleaned) + '.' if cleaned else text

# Test the generator
def test_generation():
    """Print a generated report for each of four sample appendectomy briefs.

    Uses a ``SimpleReportGenerator`` with its default model path and a
    sampling temperature of 0.3.
    """
    generator = SimpleReportGenerator()

    briefs = (
        "Laparoscopic appendectomy for acute appendicitis in a 30-year-old male",
        "Open appendectomy for perforated appendicitis with peritonitis",
        "Emergency appendectomy for gangrenous appendicitis",
        "Laparoscopic appendectomy without perforation in young patient",
    )
    banner = '=' * 60

    case_number = 0
    for brief in briefs:
        case_number += 1
        print(f"\n{banner}")
        print(f"TEST {case_number}: {brief}")
        print(banner)

        print(generator.generate_report(brief, temperature=0.3))
        print()

# Run the sample generations only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    test_generation()


TEST 1: Laparoscopic appendectomy for acute appendicitis in a 30-year-old male
The patient was taken to the operating room, placed supine on his left side. A midline incision was made and performed under direct vision. After adequate visualization of the appendix, it was identified as an abnormal structure between the cecum and mesoappendix. It was then dissected out from its base with a vascular stapler. Once hemostasis had been obtained, the appendix was grasped with a hemostatic clamp and divided using a vascular bipolar device. Next, the umbilical fascia was incised longitudinally. The peritoneum was entered through this area and the appendix ligated with a Vascular Bipolar Device. All other ports were removed under direct visualization without difficulty. The abdomen was insufflated at 15 mmHg pressure until CO2 insufflation pressures reached 25 mmHg. Subsequently, the abdomen was closed with interrupted 0 Vicryl sutures across all layers. Steri-Strips were applied over the skin 