In [5]:
# Import packages
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import numpy as np
from tqdm import tqdm

In [12]:
# Define the dataset parameters
num_samples = 32  # Small subset for testing
embedding_dim = 768  # Dimension of ECG and text embeddings
num_classes = 3  # Tertiary classification

# Generate random ECG and text embeddings
np.random.seed(42)
ecg_embeddings = np.random.randn(num_samples, embedding_dim).astype(np.float32)
text_embeddings = np.random.randn(num_samples, embedding_dim).astype(np.float32)

# Generate synthetic class labels (random integers between 0 and 2)
labels = np.random.randint(0, num_classes, size=(num_samples,))

# Convert to PyTorch tensors
ecg_embeddings = torch.tensor(ecg_embeddings)
text_embeddings = torch.tensor(text_embeddings)
labels = torch.tensor(labels, dtype=torch.long)

# Create a simple dataset
class ECGTextDataset(Dataset):
    def __init__(self, ecg_data, text_data, labels):
        self.ecg_data = ecg_data
        self.text_data = text_data
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.ecg_data[idx], self.text_data[idx], self.labels[idx]

# Instantiate dataset and dataloader for the small batch
dataset = ECGTextDataset(ecg_embeddings, text_embeddings, labels)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

In [11]:
# MLP Classifier for text embeddings
class MLPText(nn.Module):
    def __init__(self, input_dim = 768, hidden_dim = 256, num_classes = 3):
        super(MLPText, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
# MLP Classifier for ECG embeddings
class MLPECG(nn.Module):
    def __init__(self, input_dim = 768, hidden_dim = 256, num_classes = 3):
        super(MLPECG, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [14]:
# Initialize models
ecg_mlp = MLPECG(input_dim = embedding_dim, hidden_dim = 256, num_classes = num_classes)
text_mlp = MLPText(input_dim = embedding_dim, hidden_dim = 256, num_classes = num_classes)

# Initialize optimizer and loss functions
optimizer = torch.optim.Adam(list(ecg_mlp.parameters()) + list(text_mlp.parameters()), lr=0.001)
classification_loss = nn.CrossEntropyLoss()
kl_loss = nn.KLDivLoss(reduction="batchmean")



In [19]:
# Training
epochs = 3
for epoch in range(epochs):
    ecg_mlp.train()
    text_mlp.train()
    
    with tqdm(dataloader, desc = f"Epoch {epoch + 1}", total = len(dataloader)) as pbar:
        for ecg_batch, text_batch, label_batch in pbar:
            optimizer.zero_grad()
            
            # Forward pass
            ecg_outputs = ecg_mlp(ecg_batch) # [batch_size, num_classes]
            log_ecg_outputs = F.log_softmax(ecg_outputs, dim = 1)
            text_outputs = text_mlp(text_batch) # [batch_size, num_classes]
            soft_text_outputs = F.softmax(text_outputs, dim = 1)
            
            # Compute losses
            prediction_loss = classification_loss(ecg_outputs, label_batch)
            kl_divergence_loss = kl_loss(log_ecg_outputs, soft_text_outputs)
            
            # Total loss
            alpha = 0.5 # Weight for KL divergence loss
            total_loss = prediction_loss + alpha * kl_divergence_loss
            
            # Backpropagation
            total_loss.backward()
            optimizer.step()
            
            # Update
            pbar.set_postfix(prediction_loss=prediction_loss.item(), kl_loss=kl_divergence_loss.item(), total_loss=total_loss.item())
        
        print(f"Epoch {epoch+1}: Prediction Loss={prediction_loss.item():.4f}, KL Loss={kl_divergence_loss.item():.4f}, Total Loss={total_loss.item():.4f}")
        

Epoch 1: 100%|██████████| 4/4 [00:12<00:00,  3.17s/it, kl_loss=0.0111, prediction_loss=0.0459, total_loss=0.0515]


Epoch 1: Prediction Loss=0.0459, KL Loss=0.0111, Total Loss=0.0515


Epoch 2: 100%|██████████| 4/4 [00:12<00:00,  3.20s/it, kl_loss=0.00519, prediction_loss=0.0427, total_loss=0.0453]


Epoch 2: Prediction Loss=0.0427, KL Loss=0.0052, Total Loss=0.0453


Epoch 3: 100%|██████████| 4/4 [00:12<00:00,  3.22s/it, kl_loss=0.00562, prediction_loss=0.0463, total_loss=0.0491]

Epoch 3: Prediction Loss=0.0463, KL Loss=0.0056, Total Loss=0.0491



