In [None]:
import string
import numpy as np
import os
import xml.etree.ElementTree as ET
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
from transformers import DataCollatorWithPadding
from sklearn.metrics import classification_report, multilabel_confusion_matrix
from sklearn.model_selection import train_test_split, StratifiedKFold
import torch.optim as optim
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

# Download the NLTK resources used by preprocess_text(): the English stop-word
# list, WordNet (for lemmatization) and the Punkt tokenizer models. Recent NLTK
# releases load Punkt for word_tokenize() from the 'punkt_tab' package instead
# of the pickled 'punkt' one, so fetch both to work across versions.
for resource in ('stopwords', 'wordnet', 'punkt', 'punkt_tab'):
    nltk.download(resource)

# Function to parse a single XML file
def parse_xml(file_path):
    """Read one annotated record from an XML file.

    The document must contain a ``<TEXT>`` element holding the note and a
    ``<TAGS>`` element whose children each carry a ``met`` attribute.

    Args:
        file_path: path of the XML file to parse.

    Returns:
        ``(text, tags)``: ``text`` is the whitespace-stripped content of
        ``<TEXT>`` (``''`` when the element is empty); ``tags`` maps each
        ``<TAGS>`` child's tag name to its ``met`` attribute value.

    Raises:
        ValueError: if the document has no ``<TEXT>`` or no ``<TAGS>`` element.
        KeyError: if a ``<TAGS>`` child has no ``met`` attribute.
    """
    root = ET.parse(file_path).getroot()

    text_elem = root.find('TEXT')
    if text_elem is None:
        raise ValueError(f"{file_path}: missing <TEXT> element")
    tags_elem = root.find('TAGS')
    # Compare with None explicitly: an Element without children is falsy.
    if tags_elem is None:
        raise ValueError(f"{file_path}: missing <TAGS> element")

    # Element.text is None for an empty <TEXT/>; treat that as an empty note.
    text = (text_elem.text or '').strip()
    tags = {tag.tag: tag.attrib['met'] for tag in tags_elem}
    return text, tags

# Directory containing the XML files
xml_dir = 'part1'

# One dict per XML file: the three criterion tags used as labels plus the note text
data = []

# Parse all XML files. os.listdir() returns names in arbitrary order; sorting
# makes the row order of `df` -- and therefore the seeded CV folds built from
# it later -- reproducible across machines and filesystems.
for file_name in sorted(os.listdir(xml_dir)):
    if not file_name.endswith('.xml'):
        continue
    file_path = os.path.join(xml_dir, file_name)
    text, tags = parse_xml(file_path)
    # Keep only the three criteria used as labels (KeyError if a file lacks one).
    filtered_tags = {key: tags[key] for key in ['ABDOMINAL', 'CREATININE', 'MAJOR-DIABETES']}
    filtered_tags['text'] = text
    data.append(filtered_tags)

# Convert the list to a pandas DataFrame
df = pd.DataFrame(data)

# Translation table deleting every ASCII punctuation character except '-', so
# hyphenated terms keep their hyphen. Built once instead of on every call.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation.replace('-', ''))
# Stop-word set and lemmatizer shared by every call; previously both were
# rebuilt for each document. Requires the NLTK downloads above to have run.
_STOP_WORDS = set(stopwords.words('english'))
_LEMMATIZER = WordNetLemmatizer()


# Text preprocessing function
def preprocess_text(text):
    """Normalize a note: lower-case, strip punctuation, drop stop words, lemmatize.

    NOTE(review): NLTK's English stop-word list contains negations such as
    "no" and "not", so a phrase like "no abdominal surgery" loses its negation
    here -- confirm this is intended for clinical text.

    Args:
        text: raw note text.

    Returns:
        The remaining lemmatized tokens joined by single spaces.
    """
    text = text.lower()
    # Remove punctuation except hyphens (digits are not punctuation and stay).
    text = text.translate(_PUNCT_TABLE)
    words = nltk.word_tokenize(text)
    words = [word for word in words if word not in _STOP_WORDS]
    words = [_LEMMATIZER.lemmatize(word) for word in words]
    return ' '.join(words)

# Normalize every note with preprocess_text into a new column.
df['clean_text'] = df['text'].apply(preprocess_text)

# Binary-encode each criterion column in place: 'met' -> 1, anything else -> 0.
for criterion in ('ABDOMINAL', 'CREATININE', 'MAJOR-DIABETES'):
    df[criterion] = df[criterion].map(lambda status: int(status == 'met'))

# Model inputs: a (records x 3) label matrix in the column order above, and the
# cleaned texts in the same row order.
labels = df[['ABDOMINAL', 'CREATININE', 'MAJOR-DIABETES']].to_numpy()
texts = df['clean_text'].tolist()

# Initialize the BioBERT tokenizer
tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-v1.1')

# Tokenize the text data
def tokenize_texts(texts, max_length=512):
    """Tokenize a list of texts into padded, truncated PyTorch tensors.

    ``max_length`` is passed explicitly because ``truncation=True`` alone
    truncates to the tokenizer's ``model_max_length``. When the pretrained
    tokenizer config does not define one, no truncation happens, and notes
    longer than BERT's 512 position embeddings make the forward pass fail.

    Args:
        texts: list of strings to encode.
        max_length: maximum number of tokens per sequence (default 512, the
            BERT position-embedding limit).

    Returns:
        The tokenizer's encoding (``input_ids``, ``attention_mask``, ...) with
        every sequence padded to the longest one in ``texts``.
    """
    return tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt')

class MedicalDataset(Dataset):
    """Map-style dataset pairing pre-tokenized encodings with multi-hot labels.

    Args:
        texts: mapping from encoding field name (e.g. ``input_ids``) to an
            indexable whose first axis is the example index; in this file it is
            the output of ``tokenize_texts``.
        labels: per-example label rows indexable by example index; each row is
            returned as a float tensor under the ``'labels'`` key.
    """

    def __init__(self, texts, labels):
        self.texts, self.labels = texts, labels

    def __len__(self):
        # The dataset size is taken from the labels, not from the encodings.
        return len(self.labels)

    def __getitem__(self, idx):
        encoded = {name: column[idx] for name, column in self.texts.items()}
        target = torch.tensor(self.labels[idx], dtype=torch.float)
        return {**encoded, 'labels': target}

# Collator shared by every DataLoader below: pads the encodings in each batch
# to a common length using the BioBERT tokenizer.
data_collator = DataCollatorWithPadding(tokenizer)

# Function to train and evaluate the model
def train_and_evaluate(train_loader, val_loader, model, optimizer, scheduler, device, epochs=15, accumulation_steps=4, threshold=0.5):
    """Fine-tune ``model`` on ``train_loader``, then evaluate it on ``val_loader``.

    Training uses gradient accumulation. Each batch loss is divided by
    ``accumulation_steps``, and the optimizer and scheduler step once every
    ``accumulation_steps`` batches, plus once for a trailing partial group at
    the end of each epoch.

    Args:
        train_loader: DataLoader yielding dict batches with a ``'labels'`` entry.
        val_loader: DataLoader with the same batch format, used for evaluation.
        model: called as ``model(**batch, labels=labels)`` during training; its
            output must expose ``.loss`` and ``.logits``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per optimizer step.
        device: device the model and every batch are moved to.
        epochs: number of passes over ``train_loader``.
        accumulation_steps: batches whose gradients are summed per optimizer step.
        threshold: probability cut-off; a label is predicted positive when
            ``sigmoid(logit) > threshold``.

    Returns:
        ``(report, confusion_matrices)``: sklearn's ``classification_report``
        string for the three labels and ``multilabel_confusion_matrix`` on the
        validation set.
    """
    model.to(device)
    num_batches = len(train_loader)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        # Start each epoch with clean gradients; afterwards they are zeroed per optimizer step.
        optimizer.zero_grad()

        for step, batch in enumerate(train_loader):
            # Move the batch to the device
            labels = batch.pop('labels').to(device)
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()  # logged loss is the un-scaled batch loss
            (loss / accumulation_steps).backward()

            # Also step on the epoch's last batch. Otherwise the gradients of a
            # trailing partial group (num_batches % accumulation_steps != 0) are
            # discarded by the zero_grad() at the start of the next epoch.
            if (step + 1) % accumulation_steps == 0 or step + 1 == num_batches:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

        avg_train_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1}, Loss: {avg_train_loss}")

    # Evaluate the model on the validation set
    model.eval()
    probabilities, true_labels = [], []

    with torch.no_grad():
        for batch in val_loader:
            labels = batch.pop('labels')
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(**batch).logits
            # Logits are unbounded per-label scores; map them to probabilities so
            # that `threshold` is a probability cut-off rather than a logit one.
            probabilities.append(torch.sigmoid(logits).cpu().numpy())
            true_labels.append(labels.cpu().numpy())

    probabilities = np.concatenate(probabilities, axis=0)
    true_labels = np.concatenate(true_labels, axis=0)
    pred_labels = (probabilities > threshold).astype(int)

    target_names = ['ABDOMINAL', 'CREATININE', 'MAJOR-DIABETES']
    report = classification_report(true_labels, pred_labels, target_names=target_names)
    return report, multilabel_confusion_matrix(true_labels, pred_labels)

# Initialize the model and optimizer, and select the compute device.
model = BertForSequenceClassification.from_pretrained('dmis-lab/biobert-v1.1', num_labels=3)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, eps=1e-8)
# Use the GPU when one is visible; otherwise fall back to the CPU instead of
# failing at the first `.to(device)` call on a CPU-only machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stratified KFold cross-validation
NUM_EPOCHS = 15
ACCUMULATION_STEPS = 4  # passed to train_and_evaluate and used for the schedule length
kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# Prepare data for stratification: encode each row's full label combination as
# one integer (label powerset, e.g. [1, 0, 1] -> 5). np.argmax(label) put rows
# with no positive label in the same stratum as ABDOMINAL-only rows and ignored
# co-occurring labels. StratifiedKFold only warns (it does not fail) when a
# combination has fewer members than n_splits.
y = labels @ (2 ** np.arange(labels.shape[1])[::-1])

# Perform cross-validation
fold_results = []

for fold, (train_val_index, test_index) in enumerate(kf.split(texts, y)):
    print(f"Fold {fold + 1}")

    train_val_texts = [texts[i] for i in train_val_index]
    test_texts = [texts[i] for i in test_index]
    train_val_labels = labels[train_val_index]
    test_labels = labels[test_index]

    # Split train_val into training and validation sets (80% training, 10% validation)
    train_texts, val_texts, train_labels, val_labels = train_test_split(train_val_texts, train_val_labels, test_size=0.1111, random_state=42)  # 0.1111 * 90% = 10%

    train_dataset = MedicalDataset(tokenize_texts(train_texts), train_labels)
    val_dataset = MedicalDataset(tokenize_texts(val_texts), val_labels)
    test_dataset = MedicalDataset(tokenize_texts(test_texts), test_labels)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=data_collator)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=data_collator)
    # `test_loader` of the last fold is also used by the final evaluation after the loop.
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=data_collator)

    # Fresh model and optimizer for every fold. Reusing one model across folds
    # trains later folds on data that earlier folds held out, which leaks
    # validation/test records into training and invalidates the CV results.
    model = BertForSequenceClassification.from_pretrained('dmis-lab/biobert-v1.1', num_labels=3)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, eps=1e-8)

    # The scheduler advances once per optimizer step, i.e. once every
    # ACCUMULATION_STEPS batches (rounded up), not once per batch.
    steps_per_epoch = (len(train_loader) + ACCUMULATION_STEPS - 1) // ACCUMULATION_STEPS
    total_steps = steps_per_epoch * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=total_steps // 10, num_training_steps=total_steps)

    # Train and evaluate
    fold_report, fold_confusion_matrices = train_and_evaluate(
        train_loader, val_loader, model, optimizer, scheduler, device,
        epochs=NUM_EPOCHS, accumulation_steps=ACCUMULATION_STEPS,
    )
    fold_results.append((fold_report, fold_confusion_matrices))

# Report the validation metrics collected for each cross-validation fold.
for fold_number, (report, conf_matrices) in enumerate(fold_results, start=1):
    print(f"Results for Fold {fold_number}:")
    print("Classification Report:")
    print(report)
    print("Confusion Matrices:")
    for label, matrix in zip(['ABDOMINAL', 'CREATININE', 'MAJOR-DIABETES'], conf_matrices):
        print(f"Confusion Matrix for {label}:")
        print(matrix)

# Persist the model held in `model` after the cross-validation loop, together
# with the tokenizer, into a single directory.
# TODO: replace the placeholder directory name with a real output path.
model_save_path = 'path_to_save_final_model_after_cv'
for artifact in (model, tokenizer):
    artifact.save_pretrained(model_save_path)

# Final evaluation on `test_loader`, i.e. the held-out test split built in the
# last cross-validation iteration, using `model` as it stands after the loop.
label_names = ['ABDOMINAL', 'CREATININE', 'MAJOR-DIABETES']
model.eval()
predictions, true_labels = [], []

with torch.no_grad():
    for batch in test_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        label_ids = batch.pop('labels')
        logits = model(**batch).logits
        # The model returns raw logits; convert them to per-label probabilities
        # so that the 0.5 cut-off below is a probability threshold.
        predictions.append(torch.sigmoid(logits).cpu().numpy())
        true_labels.append(label_ids.cpu().numpy())

predictions = np.concatenate(predictions, axis=0)  # per-label probabilities
true_labels = np.concatenate(true_labels, axis=0)
pred_labels = (predictions > 0.5).astype(int)

conf_matrices = multilabel_confusion_matrix(true_labels, pred_labels)
print("Confusion Matrices:")
for i, label in enumerate(label_names):
    print(f"Confusion Matrix for {label}:")
    print(conf_matrices[i])

class_report = classification_report(true_labels, pred_labels, target_names=label_names)
print("Classification Report:")
print(class_report)
