In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertModel, BertTokenizer
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, precision_recall_fscore_support
import pandas as pd
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
from torch.utils.data import DataLoader, TensorDataset

# Pick the compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Read the samples, keep only the label and text columns, and drop rows
# where either of the two is missing.
df = pd.read_csv('mtsamples.csv')
df = df.loc[:, ['medical_specialty', 'transcription']].dropna()

# Tokenizer for the 'bert-base-uncased' checkpoint (shared by every model below).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Keep only the specialties that appear more than 300 times.
category_counts = df['medical_specialty'].value_counts()
valid_categories = category_counts.loc[category_counts > 300].index
df = df.loc[df['medical_specialty'].isin(valid_categories)]

# Report which specialties survived the frequency filter.
print("Valid categories with more than 300 occurrences:")
print(valid_categories)
# Tokenization helper
def preprocess_text(text):
    """Tokenize one transcription with the module-level ``tokenizer``.

    The text is truncated or padded to exactly 128 tokens and the encoding
    is returned with PyTorch tensors (``return_tensors='pt'``).
    """
    encoding = tokenizer(
        text,
        max_length=128,
        truncation=True,
        padding='max_length',
        return_tensors='pt',
    )
    return encoding

# Replace the raw text with its encoding, then copy the two tensors the models
# need into their own columns; `[0]` takes the first row of each tensor.
df['transcription'] = df['transcription'].apply(preprocess_text)
df['input_ids'] = df['transcription'].apply(lambda enc: enc['input_ids'][0])
df['attention_mask'] = df['transcription'].apply(lambda enc: enc['attention_mask'][0])

# Stack the per-row tensors into single tensors, one row per sample.
X = torch.stack(list(df['input_ids']))
attention_masks = torch.stack(list(df['attention_mask']))

# Map specialty names to integer class ids.
label_encoder = LabelEncoder()
y = torch.tensor(label_encoder.fit_transform(df['medical_specialty']))

# Hold out 10% of the samples; ids, labels and masks are split with the same indices.
X_train, X_test, y_train, y_test, train_masks, test_masks = train_test_split(
    X,
    y,
    attention_masks,
    random_state=42,
    test_size=0.1,
)

# Datasets/loaders that yield (ids, mask, label) -- used by the BERT and LSTM models.
train_dataset = TensorDataset(X_train, train_masks, y_train)
test_dataset = TensorDataset(X_test, test_masks, y_test)
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False)

# Datasets/loaders that yield (ids, label) -- used by the CNN model, which takes no mask.
train_dataset_cnn = TensorDataset(X_train, y_train)
test_dataset_cnn = TensorDataset(X_test, y_test)
train_loader_cnn = DataLoader(dataset=train_dataset_cnn, batch_size=8, shuffle=True)
test_loader_cnn = DataLoader(dataset=test_dataset_cnn, batch_size=8, shuffle=False)

# Define BERT model
class BERTModel(nn.Module):
    """Pretrained BERT encoder followed by a linear classification head.

    The head maps the 768-dimensional hidden state at sequence position 0
    to one logit per entry of the module-level ``valid_categories``.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(768, len(valid_categories))

    def forward(self, ids, mask):
        """Return classification logits for a batch.

        Args:
            ids: Token-id tensor passed to the encoder as its first input.
            mask: Tensor forwarded to the encoder as ``attention_mask``.

        Returns:
            Output of the linear head applied to the position-0 hidden state.
        """
        # With return_dict=False the encoder returns a tuple; only its first
        # element is used here.
        last_hidden_state, *_ = self.bert(ids, attention_mask=mask, return_dict=False)
        first_token_state = last_hidden_state[:, 0]
        return self.fc(first_token_state)

# Define CNN model
class CNNModel(nn.Module):
    """Two-layer 1-D convolutional classifier over learned token embeddings."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=30522, embedding_dim=128)
        self.conv1 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=32, kernel_size=5, padding=2)
        self.pool = nn.MaxPool1d(kernel_size=2)
        # 32 channels x 64 positions: the flattened size after pooling a
        # 128-token sequence by a factor of two.
        self.fc = nn.Linear(32 * 64, len(valid_categories))

    def forward(self, x):
        """Return one logit per category for each row of token ids in ``x``."""
        embedded = self.embedding(x)
        # Conv1d wants channels before positions, so swap the last two axes.
        features = embedded.transpose(1, 2)
        features = F.relu(self.conv1(features))
        features = F.relu(self.conv2(features))
        pooled = self.pool(features)
        flattened = pooled.reshape(pooled.size(0), -1)
        return self.fc(flattened)

# Define LSTM model
class LSTMModel(nn.Module):
    """Bidirectional LSTM classifier over learned token embeddings.

    The previous version cast the raw token ids to float and passed the 2-D
    ``(batch, 128)`` tensor straight to the LSTM. PyTorch interprets a 2-D
    input as a single unbatched sequence, so the batch axis was treated as
    the time axis and samples in a batch influenced each other's outputs.
    This version embeds the ids first and runs the LSTM over the token axis.

    Args:
        vocab_size: Number of rows in the embedding table. The default
            matches the table size used by ``CNNModel``.
        embed_dim: Size of each token embedding.
        hidden_dim: Hidden size of each LSTM direction.
    """

    def __init__(self, vocab_size=30522, embed_dim=128, hidden_dim=128):
        super().__init__()
        # padding_idx=0 keeps the embedding of id 0 fixed at zero.
        # NOTE(review): assumes the tokenizer's pad id is 0 -- confirm.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, len(valid_categories))

    def forward(self, ids, masks=None):
        """Return one logit per category for each sequence in the batch.

        Args:
            ids: Integer token ids of shape ``(batch, seq_len)``.
            masks: Optional attention mask of the same shape (1 = real token,
                0 = padding). When omitted every position counts as real.
                Assumes padding sits at the end of each sequence.

        Returns:
            Tensor of shape ``(batch, len(valid_categories))``.
        """
        embedded = self.embedding(ids)

        if masks is None:
            lengths = torch.full((ids.size(0),), ids.size(1), dtype=torch.int64)
        else:
            # pack_padded_sequence needs CPU int64 lengths that are all >= 1.
            lengths = masks.sum(dim=1).clamp(min=1).to(torch.int64).cpu()

        # Packing makes the LSTM stop at each sequence's true end, so the
        # final hidden states are not computed from padding positions.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm(packed)

        # h_n[-2] / h_n[-1] are the final forward / backward hidden states.
        final_state = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.fc(final_state)

# Initialize models
# Seed PyTorch's global RNG first so that randomly initialised weights and the
# shuffling order of DataLoaders created with shuffle=True are repeatable
# across runs. The value mirrors the random_state used for the data split.
torch.manual_seed(42)

model_bert = BERTModel().to(device)
model_cnn = CNNModel().to(device)
model_lstm = LSTMModel().to(device)

# BERT is fine-tuned with a small learning rate; the models trained from
# scratch use a larger one.
optimizer_bert = optim.Adam(model_bert.parameters(), lr=2e-5)
optimizer_cnn = optim.Adam(model_cnn.parameters(), lr=1e-3)
optimizer_lstm = optim.Adam(model_lstm.parameters(), lr=1e-3)

# Shared loss for all three classifiers (expects raw logits and integer class ids).
criterion = nn.CrossEntropyLoss()

# Training function for BERT and LSTM models
def train_model(model, optimizer, train_loader, model_name, epochs=1):
    """Train a mask-aware model and print its mean batch loss for every epoch.

    Args:
        model: Module called as ``model(ids, masks)`` that returns logits.
        optimizer: Optimizer holding ``model``'s parameters.
        train_loader: Iterable yielding ``(ids, masks, targets)`` batches.
        model_name: Label used in the progress message.
        epochs: Number of passes over ``train_loader``.

    Uses the module-level ``device`` and ``criterion``.
    """
    model.train()
    for epoch in range(epochs):
        # Reset per epoch: accumulating across epochs made the printed
        # "average" grow with the epoch index whenever epochs > 1.
        epoch_loss = 0.0
        for ids, masks, targets in train_loader:
            ids, masks, targets = ids.to(device), masks.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(ids, masks)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # max(..., 1) avoids a ZeroDivisionError on an empty loader.
        average_loss = epoch_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch + 1}, {model_name} trained with average loss: {average_loss:.4f}")

# Training function for CNN model
def train_model_cnn(model, optimizer, train_loader, model_name, epochs=1):
    """Train a mask-free model and print its mean batch loss for every epoch.

    Args:
        model: Module called as ``model(ids)`` that returns logits.
        optimizer: Optimizer holding ``model``'s parameters.
        train_loader: Iterable yielding ``(ids, targets)`` batches.
        model_name: Label used in the progress message.
        epochs: Number of passes over ``train_loader``.

    Uses the module-level ``device`` and ``criterion``.
    """
    model.train()
    for epoch in range(epochs):
        # Reset per epoch: accumulating across epochs made the printed
        # "average" grow with the epoch index whenever epochs > 1.
        epoch_loss = 0.0
        for ids, targets in train_loader:
            ids, targets = ids.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(ids)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # max(..., 1) avoids a ZeroDivisionError on an empty loader.
        average_loss = epoch_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch + 1}, {model_name} trained with average loss: {average_loss:.4f}")

# Updated evaluation function for BERT and LSTM models
def evaluate_model(model, test_loader, model_name):
    """Evaluate a mask-aware model and print accuracy, precision, recall and F1.

    Precision, recall and F1 are support-weighted averages over the classes.
    A class that is never predicted contributes 0 to precision
    (``zero_division=0``) instead of raising an undefined-metric warning.

    Args:
        model: Module called as ``model(ids, masks)`` that returns logits.
        test_loader: Iterable yielding ``(ids, masks, targets)`` batches.
        model_name: Label used in the printed report.

    Returns:
        dict with keys ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'``.
    """
    model.eval()
    all_targets = []
    all_predictions = []
    with torch.no_grad():
        for ids, masks, targets in test_loader:
            outputs = model(ids.to(device), masks.to(device))
            predicted = outputs.argmax(dim=1)
            # Targets are only compared on the host, so they never need to
            # visit the accelerator.
            all_targets.extend(targets.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    # Calculate metrics
    accuracy = np.mean(np.array(all_targets) == np.array(all_predictions))
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_targets, all_predictions, average='weighted', zero_division=0
    )

    print(f"{model_name} accuracy: {accuracy:.4f}")
    print(f"{model_name} precision: {precision:.4f}")
    print(f"{model_name} recall: {recall:.4f}")
    print(f"{model_name} F1 score: {f1:.4f}")

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}

# Updated evaluation function for CNN model
def evaluate_model_cnn(model, test_loader, model_name):
    """Evaluate a mask-free model and print accuracy, precision, recall and F1.

    Precision, recall and F1 are support-weighted averages over the classes.
    A class that is never predicted contributes 0 to precision
    (``zero_division=0``) instead of raising an undefined-metric warning.

    Args:
        model: Module called as ``model(ids)`` that returns logits.
        test_loader: Iterable yielding ``(ids, targets)`` batches.
        model_name: Label used in the printed report.

    Returns:
        dict with keys ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'``.
    """
    model.eval()
    all_targets = []
    all_predictions = []
    with torch.no_grad():
        for ids, targets in test_loader:
            outputs = model(ids.to(device))
            predicted = outputs.argmax(dim=1)
            # Targets are only compared on the host, so they never need to
            # visit the accelerator.
            all_targets.extend(targets.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    # Calculate metrics
    accuracy = np.mean(np.array(all_targets) == np.array(all_predictions))
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_targets, all_predictions, average='weighted', zero_division=0
    )

    print(f"{model_name} accuracy: {accuracy:.4f}")
    print(f"{model_name} precision: {precision:.4f}")
    print(f"{model_name} recall: {recall:.4f}")
    print(f"{model_name} F1 score: {f1:.4f}")

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}


# Fit each base model for one epoch, then score all three on the held-out split.
for epoch in range(1):
    train_model(
        model=model_bert, optimizer=optimizer_bert,
        train_loader=train_loader, model_name="BERT Model", epochs=1,
    )
    train_model(
        model=model_lstm, optimizer=optimizer_lstm,
        train_loader=train_loader, model_name="LSTM Model", epochs=1,
    )
    train_model_cnn(
        model=model_cnn, optimizer=optimizer_cnn,
        train_loader=train_loader_cnn, model_name="CNN Model", epochs=1,
    )

    evaluate_model(model=model_bert, test_loader=test_loader, model_name="BERT Model")
    evaluate_model(model=model_lstm, test_loader=test_loader, model_name="LSTM Model")
    evaluate_model_cnn(model=model_cnn, test_loader=test_loader_cnn, model_name="CNN Model")


Valid categories with more than 300 occurrences:
Index([' Surgery', ' Consult - History and Phy.',
       ' Cardiovascular / Pulmonary', ' Orthopedic'],
      dtype='object', name='medical_specialty')
Epoch 1, BERT Model trained with average loss: 0.7321
Epoch 1, LSTM Model trained with average loss: 1.2825
Epoch 1, CNN Model trained with average loss: 1.0036
BERT Model accuracy: 0.7296
BERT Model precision: 0.7524
BERT Model recall: 0.7296
BERT Model F1 score: 0.7126
LSTM Model accuracy: 0.4764
LSTM Model precision: 0.2270
LSTM Model recall: 0.4764
LSTM Model F1 score: 0.3074
CNN Model accuracy: 0.6824
CNN Model precision: 0.6021
CNN Model recall: 0.6824
CNN Model F1 score: 0.5980


  _warn_prf(average, modifier, msg_start, len(result))


In [None]:

# Define the ImprovedStackingMLP meta-learner that combines the three base models' outputs
class ImprovedStackingMLP(nn.Module):
    """Three-layer MLP that maps concatenated base-model outputs to class logits.

    The input width is ``3 * len(valid_categories)`` (one score per category
    from each of three base models). Dropout with p=0.3 follows each of the
    two hidden layers.
    """

    def __init__(self):
        super().__init__()
        n_classes = len(valid_categories)
        self.fc1 = nn.Linear(3 * n_classes, 100)
        self.dropout1 = nn.Dropout(p=0.3)
        self.fc2 = nn.Linear(100, 50)
        self.dropout2 = nn.Dropout(p=0.3)
        self.fc3 = nn.Linear(50, n_classes)

    def forward(self, x):
        """Return unnormalised class scores for each row of ``x``."""
        hidden = self.dropout1(F.relu(self.fc1(x)))
        hidden = self.dropout2(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)

# Function to normalize stacking inputs with standard scaling
def normalize_data(data):
    """Standardise each column of ``data`` to zero mean and unit variance.

    Args:
        data: Array-like whose first axis indexes samples.

    Returns:
        ``(data - mean) / std`` computed along axis 0. Columns with zero
        standard deviation are divided by 1 instead, so a constant column
        maps to all zeros rather than NaN/inf.
    """
    mean = np.mean(data, axis=0)
    std = np.std(data, axis=0)
    # A constant feature has std == 0; dividing by it would poison every
    # downstream computation with NaN.
    safe_std = np.where(std == 0, 1.0, std)
    return (data - mean) / safe_std

# Collect a model's raw outputs over a data loader, with the model in evaluation mode
def get_model_predictions(model, data_loader):
    """Run ``model`` over ``data_loader`` and return all outputs as one array.

    ``BERTModel`` and ``LSTMModel`` instances are called with the first two
    items of every batch (ids and masks); ``CNNModel`` instances are called
    with the first item only. Gradients are disabled and the model is put in
    evaluation mode.

    Raises:
        ValueError: If ``model`` is none of the three known classes (raised
            when the first batch is processed).

    Returns:
        NumPy array of the per-batch outputs concatenated along axis 0.
    """
    model.eval()
    batch_outputs = []

    with torch.no_grad():
        for batch in data_loader:
            if isinstance(model, (BERTModel, LSTMModel)):
                ids, masks = batch[:2]
                outputs = model(ids.to(device), masks.to(device))
            elif isinstance(model, CNNModel):
                outputs = model(batch[0].to(device))
            else:
                raise ValueError("Unknown model type")

            batch_outputs.append(outputs.cpu().numpy())

    return np.concatenate(batch_outputs, axis=0)

# Generate predictions from BERT, CNN, and LSTM models.
# Dedicated, unshuffled loaders are built from the held-out datasets here
# because `train_loader` / `test_loader` are rebound to the stacking loaders
# further down in this cell. Reading `test_loader` directly would make a
# second run of this cell feed stacking features into the base models.
base_eval_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)
base_eval_loader_cnn = DataLoader(test_dataset_cnn, batch_size=8, shuffle=False)

print("Generating predictions from BERT model...")
bert_preds = get_model_predictions(model_bert, base_eval_loader)
print("Generating predictions from CNN model...")
cnn_preds = get_model_predictions(model_cnn, base_eval_loader_cnn)
print("Generating predictions from LSTM model...")
lstm_preds = get_model_predictions(model_lstm, base_eval_loader)

# Convert logits to probabilities (softmax across the class axis)
bert_probs = F.softmax(torch.tensor(bert_preds), dim=1).numpy()
cnn_probs = F.softmax(torch.tensor(cnn_preds), dim=1).numpy()
lstm_probs = F.softmax(torch.tensor(lstm_preds), dim=1).numpy()

# Meta-learner features: the three models' class probabilities side by side,
# standardised column-wise.
stacking_inputs = normalize_data(
    np.concatenate((bert_probs, cnn_probs, lstm_probs), axis=1)
)

# Wrap features and labels as tensors and bundle them into a dataset.
stacking_inputs_tensor = torch.tensor(stacking_inputs).float()
stacking_labels_tensor = torch.tensor(y_test.numpy()).long()
stacking_dataset = TensorDataset(stacking_inputs_tensor, stacking_labels_tensor)

# 90/10 split of the stacking samples.
# NOTE(review): the labels are `y_test`, so the meta-learner is both fitted and
# scored on subsets of the base models' held-out set, and the standardisation
# above is computed before this split. Metrics reported downstream are
# therefore optimistic -- confirm whether a separate validation fold is intended.
train_inputs, test_inputs, train_labels, test_labels = train_test_split(
    stacking_inputs_tensor,
    stacking_labels_tensor,
    random_state=42,
    test_size=0.1,
)

train_data = TensorDataset(train_inputs, train_labels)
test_data = TensorDataset(test_inputs, test_labels)

# NOTE: these two names previously referred to the base-model loaders; from
# here on they refer to the stacking loaders.
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=32, shuffle=False)

def train_and_evaluate_improved_stacking_model(model, optimizer, train_loader, test_loader, epochs=100, patience=50, min_delta=0.01, accuracy_threshold=0.80):
    """Train the stacking model with early stopping, printing metrics every epoch.

    After each epoch the model is scored on ``test_loader``. Training stops
    when the test accuracy exceeds ``accuracy_threshold`` (the current weights
    are kept), or when it has not improved by more than ``min_delta`` for
    ``patience`` consecutive epochs. In the patience case, and when all epochs
    run, the best snapshot is restored.

    NOTE(review): model selection uses ``test_loader``, so the reported test
    metrics are not an unbiased estimate -- consider a separate validation set.

    Args:
        model: Module called as ``model(inputs)`` that returns logits.
        optimizer: Optimizer holding ``model``'s parameters.
        train_loader: Iterable yielding ``(inputs, targets)`` training batches.
        test_loader: Iterable yielding ``(inputs, targets)`` evaluation batches.
        epochs: Maximum number of epochs.
        patience: Epochs without sufficient improvement before stopping.
        min_delta: Minimum accuracy gain that counts as an improvement.
        accuracy_threshold: Test accuracy above which training stops at once.

    Uses the module-level ``device``.
    """
    best_accuracy = 0
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0
        correct_train_preds = 0
        total_train_preds = 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            correct_train_preds += (predicted == targets).sum().item()
            total_train_preds += targets.size(0)

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        avg_train_loss = train_loss / max(len(train_loader), 1)
        train_accuracy = correct_train_preds / max(total_train_preds, 1)

        # Evaluation phase (on test data)
        model.eval()
        test_loss = 0
        correct_test_preds = 0
        total_test_preds = 0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = model(inputs)
                loss = F.cross_entropy(outputs, targets)

                test_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                correct_test_preds += (predicted == targets).sum().item()
                total_test_preds += targets.size(0)

                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        avg_test_loss = test_loss / max(len(test_loader), 1)
        test_accuracy = correct_test_preds / max(total_test_preds, 1)

        # Support-weighted metrics; zero_division=0 scores a never-predicted
        # class as 0 instead of emitting an undefined-metric warning.
        f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
        precision = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
        recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)

        # Print metrics for the current epoch
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Training loss: {avg_train_loss:.4f}, Training accuracy: {train_accuracy:.4f}")
        print(f"Test loss: {avg_test_loss:.4f}, Test accuracy: {test_accuracy:.4f}")
        print(f"F1 score: {f1:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print("-" * 100)

        # Check for early stopping based on accuracy threshold
        if test_accuracy > accuracy_threshold:
            print("Test accuracy exceeded threshold. Stopping training.")
            # The current weights are the ones that crossed the threshold;
            # drop any older snapshot so they are not rolled back below.
            best_model_state = None
            break

        # Check for early stopping based on improvement
        if test_accuracy > best_accuracy + min_delta:
            best_accuracy = test_accuracy
            epochs_no_improve = 0
            # state_dict() returns references to the live parameters, so the
            # tensors must be cloned or later optimizer steps would silently
            # overwrite the "best" snapshot.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print("Early stopping triggered.")
            break

    # Load the best model state
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

# Build the stacking model and its optimizer, then run the training/evaluation loop.
improved_stacking_model = ImprovedStackingMLP().to(device)
improved_stacking_model_optimizer = optim.Adam(params=improved_stacking_model.parameters(), lr=1e-2)
train_and_evaluate_improved_stacking_model(
    model=improved_stacking_model,
    optimizer=improved_stacking_model_optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=100,
    patience=25,
    min_delta=0.01,
    accuracy_threshold=0.80,
)


Generating predictions from BERT model...
Generating predictions from CNN model...
Generating predictions from LSTM model...
Epoch 1/100
Training loss: 1.0658, Training accuracy: 0.5933
Test loss: 0.9712, Test accuracy: 0.6667
F1 score: 0.5339
Precision: 0.4456
Recall: 0.6667
----------------------------------------------------------------------------------------------------
Epoch 2/100
Training loss: 0.7225, Training accuracy: 0.7560
Test loss: 0.5539, Test accuracy: 0.7500
F1 score: 0.6822
Precision: 0.6975
Recall: 0.7500
----------------------------------------------------------------------------------------------------
Epoch 3/100
Training loss: 0.6681, Training accuracy: 0.7225
Test loss: 0.5643, Test accuracy: 0.7083
F1 score: 0.6097
Precision: 0.5894
Recall: 0.7083
----------------------------------------------------------------------------------------------------
Epoch 4/100
Training loss: 0.6346, Training accuracy: 0.7321
Test loss: 0.4900, Test accuracy: 0.7083
F1 score: 0.60

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 5/100
Training loss: 0.5611, Training accuracy: 0.7751
Test loss: 0.4585, Test accuracy: 0.7500
F1 score: 0.6626
Precision: 0.5938
Recall: 0.7500
----------------------------------------------------------------------------------------------------
Epoch 6/100
Training loss: 0.5462, Training accuracy: 0.7512
Test loss: 0.5492, Test accuracy: 0.7500
F1 score: 0.6614
Precision: 0.6100
Recall: 0.7500
----------------------------------------------------------------------------------------------------
Epoch 7/100
Training loss: 0.5702, Training accuracy: 0.7751
Test loss: 0.4721, Test accuracy: 0.7500
F1 score: 0.6614
Precision: 0.6100
Recall: 0.7500
----------------------------------------------------------------------------------------------------
Epoch 8/100
Training loss: 0.5135, Training accuracy: 0.7512
Test loss: 0.4581, Test accuracy: 0.7500
F1 score: 0.7221
Precision: 0.7639
Recall: 0.7500
-----------------------------------------------------------------------------------------

  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
#Prediction
def preprocess_input(transcript):
    """Turn one raw transcript into the feature row the stacking model expects.

    The stacking model was trained on column-standardised softmax
    probabilities from the three base models, so the same transformation is
    applied here. The previous version returned raw logits, which do not
    match the training features. It also left the inputs on the CPU and
    reloaded the tokenizer on every call.

    Args:
        transcript: Text of a single transcription.

    Returns:
        Tensor of shape ``(1, 3 * n_classes)`` on the module-level ``device``.

    Uses the module-level ``tokenizer``, ``device``, the three base models,
    and the arrays ``bert_probs`` / ``cnn_probs`` / ``lstm_probs`` that the
    stacking features were built from.
    """
    encoding = tokenizer(
        transcript,
        add_special_tokens=True,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    # The base models live on `device`; their inputs must be there too.
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    for base_model in (model_bert, model_cnn, model_lstm):
        base_model.eval()

    with torch.no_grad():
        base_logits = (
            model_bert(input_ids, attention_mask),
            model_cnn(input_ids),
            model_lstm(input_ids, attention_mask),
        )
        # Same order as the training features: BERT, CNN, LSTM.
        sample_probs = torch.cat(
            [F.softmax(logits, dim=1) for logits in base_logits], dim=1
        )

    # Recompute the standardisation statistics from the arrays the stacking
    # features were built from, so inference is scaled the same way.
    reference = np.concatenate([bert_probs, cnn_probs, lstm_probs], axis=1)
    ref_mean = np.mean(reference, axis=0)
    ref_std = np.std(reference, axis=0)
    ref_std = np.where(ref_std == 0, 1.0, ref_std)  # avoid division by zero

    mean_t = torch.as_tensor(ref_mean, dtype=sample_probs.dtype, device=sample_probs.device)
    std_t = torch.as_tensor(ref_std, dtype=sample_probs.dtype, device=sample_probs.device)
    return (sample_probs - mean_t) / std_t

In [None]:
input_transcript = "Left total knee arthroplasty performed due to severe osteoarthritis. Patient placed in supine position under general anesthesia. Tourniquet inflated. Tibial and femoral components implanted. Patellar component inserted. Wound closed in layers. Postoperative pain management initiated."
preprocessed_input = preprocess_input(input_transcript)

# Inference only: disable dropout and gradient tracking, and make sure the
# features are on the same device as the stacking model.
improved_stacking_model.eval()
with torch.no_grad():
    outputs = improved_stacking_model(preprocessed_input.to(device))
predicted_probs = F.softmax(outputs, dim=1)
predicted_label_idx = torch.argmax(predicted_probs, dim=1)

# Class indices were assigned by `label_encoder`, which orders classes by
# sorted name, not by frequency. A hand-written list in value-count order
# maps indices to the wrong specialties, so the names are read back from the
# encoder instead (stripping the leading space present in the raw labels).
labels = [name.strip() for name in label_encoder.classes_]
predicted_label = labels[predicted_label_idx.item()]

In [None]:
# Show the specialty name selected for `input_transcript`.
print(predicted_label)

Orthopedic
