In [9]:
# Import packages
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm import tqdm

In [2]:
# Load BioBERT Model
# Single source of truth for the checkpoint id, so the tokenizer and the
# encoder can never be loaded from two different checkpoints by accident.
BIOBERT_CHECKPOINT = "dmis-lab/biobert-v1.1"

# `tokenizer` and `model` are read as globals by the embedding helpers defined
# in the following cells. `from_pretrained` returns the model in evaluation
# mode, so dropout inside the encoder is inactive when embeddings are computed.
tokenizer = AutoTokenizer.from_pretrained(BIOBERT_CHECKPOINT)
model = AutoModel.from_pretrained(BIOBERT_CHECKPOINT)



In [3]:
# Encode text into a sentence-level ([CLS]) embedding
def encode_text(text, *, tok=None, encoder=None):
    """Embed `text` as the final-layer hidden state of its first ([CLS]) token.

    Args:
        text: Text to embed; forwarded unchanged to the tokenizer.
        tok: Optional tokenizer callable. When omitted, the notebook-level
            `tokenizer` global is used.
        encoder: Optional encoder callable whose result exposes
            `last_hidden_state`. When omitted, the notebook-level `model`
            global is used. Pass the BioBERT encoder explicitly if this is
            called after the cell that rebinds `model` to the MLP classifier.

    Returns:
        A 2-D tensor with one row per encoded sequence (the example below
        prints `torch.Size([1, 768])` for a single sentence). It is computed
        under `torch.no_grad()`, so it carries no autograd history.
    """
    # Resolve collaborators lazily: the globals are only touched when the
    # caller did not supply an override.
    active_tok = tokenizer if tok is None else tok
    active_encoder = model if encoder is None else encoder

    # Tokenize input text (padded, and truncated to at most 512 tokens)
    inputs = active_tok(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # Forward pass without gradient tracking; the encoder is used as a frozen
    # feature extractor here.
    with torch.no_grad():
        outputs = active_encoder(**inputs)

    # Position 0 of every sequence is the [CLS] token (sentence embedding)
    return outputs.last_hidden_state[:, 0, :]

# Example usage
# Smoke-check the encoder on a single sentence and show the embedding's dimensions.
text = "The patient's heart is fine."
vector = encode_text(text)
print(vector.size())  # one sentence -> one row; the saved run printed torch.Size([1, 768])

torch.Size([1, 768])


In [4]:
# Get embedding
def get_cls_embedding(text):
    """Return the [CLS] embedding for `text`.

    This is an alias of `encode_text` (defined in the previous cell), kept so
    existing calls keep working. Delegating avoids maintaining two identical
    copies of the tokenize -> forward -> take-token-0 logic.

    Args:
        text: Text to embed (this notebook passes one sentence per call).

    Returns:
        A 2-D tensor with one row per encoded sequence, i.e. `(1, hidden_size)`
        for a single sentence.
    """
    return encode_text(text)

# Toy dataset
# NOTE: `texts`, `labels` and `embeddings` are all reassigned by the
# training-data cell further down, so these values only live until then.
texts = [
    "Heart failure detected.",
    "No signs of cardiovascular issues.",
    "Possible arrhythmia found.",
]
labels = [1, 0, 1]  # 1: Disease, 0: No Disease

# Convert texts to embeddings
# Each call yields one row; stack the rows along dim 0 and convert to NumPy.
embeddings = torch.cat(tuple(map(get_cls_embedding, texts)), dim=0).numpy()

In [5]:
# MLP Classifier
class MLPClassifier(nn.Module):
    """Two-layer perceptron head: Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_dim: Size of each input feature vector (default 768).
        hidden_dim: Width of the hidden layer (default 256).
        num_classes: Number of output logits (default 2).
        dropout: Drop probability of the dropout layer. Defaults to 0.3, the
            value that used to be hard-coded, so existing calls are unaffected.
    """

    def __init__(self, input_dim = 768, hidden_dim = 256, num_classes = 2, dropout = 0.3):
        super().__init__()
        # Attribute names are kept as-is so state_dict keys stay stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map features `x` of shape (..., input_dim) to raw logits (..., num_classes).

        No softmax is applied; the output is meant for a loss that consumes
        logits (the notebook pairs it with `nn.CrossEntropyLoss`).
        """
        hidden = self.relu(self.fc1(x))
        # Dropout is only active in training mode (`self.train()`).
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

In [6]:
# Training data for the MLP
# Example medical sentences and labels (1 = disease, 0 = no disease)
texts = [
    "Patient shows signs of arrhythmia.",
    "No signs of cardiovascular issues.",
    "ECG indicates possible heart failure.",
    "Heart rate appears normal.",
]

labels = torch.tensor([1, 0, 1, 0])  # Binary classification

# Convert texts to embeddings
# One embedding row per sentence, concatenated along dim 0 in the same order
# as `labels`.
embeddings = torch.cat(tuple(get_cls_embedding(sentence) for sentence in texts), dim=0)

In [7]:
# Create a dataset
# Seed PyTorch's global RNG before anything stochastic happens. The DataLoader
# shuffle order, the MLP weight initialisation and the dropout masks all draw
# from it, so without a seed every Restart & Run All gives a different curve.
SEED = 42
torch.manual_seed(SEED)

dataset = TensorDataset(embeddings, labels)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Initialize MLP
# NOTE: this rebinds the global name `model`, which until this cell referred to
# the BioBERT encoder that `encode_text` / `get_cls_embedding` read as a global.
# All embeddings must therefore be computed before this cell runs; re-running
# an embedding cell afterwards would call the MLP instead of BioBERT.
model = MLPClassifier()
criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
optimizer = optim.Adam(model.parameters(), lr = 0.001)

In [8]:
# Training loop with mini-batches
epochs = 10

# Select training mode explicitly. The classifier contains a Dropout layer, and
# relying on the module's default mode would silently train without dropout if
# `model.eval()` had been called before this cell is (re-)run.
model.train()

for epoch in range(epochs):
    # One progress bar per epoch; the context manager closes it cleanly.
    with tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}") as pbar:
        for batch in pbar:
            inputs, targets = batch
            # Gradients accumulate by default, so clear them for each batch.
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # The displayed value is the loss of the current mini-batch only.
            pbar.set_postfix(loss=loss.item())  # Update progress bar

Epoch 1/10: 100%|██████████| 2/2 [00:03<00:00,  1.57s/it, loss=0.697]
Epoch 2/10: 100%|██████████| 2/2 [00:03<00:00,  1.55s/it, loss=0.798]
Epoch 3/10: 100%|██████████| 2/2 [00:03<00:00,  1.55s/it, loss=0.63] 
Epoch 4/10: 100%|██████████| 2/2 [00:03<00:00,  1.55s/it, loss=0.459]
Epoch 5/10: 100%|██████████| 2/2 [00:03<00:00,  1.55s/it, loss=0.369]
Epoch 6/10: 100%|██████████| 2/2 [00:03<00:00,  1.65s/it, loss=0.382]
Epoch 7/10: 100%|██████████| 2/2 [00:03<00:00,  1.55s/it, loss=0.322]
Epoch 8/10: 100%|██████████| 2/2 [00:03<00:00,  1.55s/it, loss=0.208]
Epoch 9/10: 100%|██████████| 2/2 [00:03<00:00,  1.55s/it, loss=0.31] 
Epoch 10/10: 100%|██████████| 2/2 [00:03<00:00,  1.55s/it, loss=0.184]
