In [7]:
# Cell 1 — Imports and Setup
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import pickle
import ast
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score
import warnings
warnings.filterwarnings('ignore')

# Seed PyTorch's RNGs so shuffling, dropout and the randomly initialised
# classification heads are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Select the compute device once; everything below moves tensors to `device`.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"Device: {device}")
# Only query the GPU name when a GPU exists: get_device_name() raises on
# CPU-only machines, which would defeat the CPU fallback chosen above.
if use_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("GPU: not available")

# Enable mixed precision for faster training (CUDA only)
from torch.cuda.amp import GradScaler, autocast
# With enabled=False the scaler's scale/unscale_/step/update calls are no-ops
# around the optimizer step, so the same training code also runs on CPU.
scaler = GradScaler(enabled=use_cuda)
print("Mixed precision enabled" if use_cuda else "Mixed precision disabled (CUDA not available)")

Device: cuda
GPU: NVIDIA GeForce RTX 4070 Laptop GPU
Mixed precision enabled


In [8]:
# Cell 2 — Load Data and Tokenizer
# Label artefacts written by the preprocessing step.
# NOTE: pickle.load executes arbitrary code embedded in the file; only load
# pickles that were produced by this project.
with open('../data/Processed/mlb.pkl', 'rb') as mlb_file:
    mlb = pickle.load(mlb_file)
with open('../data/Processed/top50_codes.pkl', 'rb') as codes_file:
    top50_codes = pickle.load(codes_file)

# Pre-split note tables.
train_df = pd.read_csv('../data/Processed/train.csv')
val_df = pd.read_csv('../data/Processed/val.csv')
test_df = pd.read_csv('../data/Processed/test.csv')

# CSV stores each list of ICD codes as its string repr; parse it back into a
# Python list for every split.
for _split_frame in (train_df, val_df, test_df):
    _split_frame['icd_codes'] = _split_frame['icd_codes'].apply(ast.literal_eval)

print("Loading Clinical-Longformer tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-Longformer")
print(f"Tokenizer loaded. Vocabulary size: {tokenizer.vocab_size:,}")

# Split sizes, as a quick sanity check on what was loaded.
print(f"\nTrain: {len(train_df):,}")
print(f"Val: {len(val_df):,}")
print(f"Test: {len(test_df):,}")

Loading Clinical-Longformer tokenizer...
Tokenizer loaded. Vocabulary size: 50,265

Train: 82,501
Val: 9,084
Test: 23,048


In [9]:
# Cell 3 — PLM-ICD Dataset (each note is truncated/padded to max_length tokens)
class PLMICDDataset(Dataset):
    """Map-style dataset pairing note texts with their label vectors.

    Tokenization happens lazily in ``__getitem__``: each text is truncated or
    padded to exactly ``max_length`` tokens, so every item has the same
    sequence length and items can be stacked directly into a batch.

    Args:
        texts: Indexable collection of notes; each entry is passed through
            ``str()`` before tokenization.
        labels: Indexable collection of per-note label vectors, aligned with
            ``texts`` by position.
        tokenizer: Callable accepting ``(text, max_length=..., truncation=...,
            padding=..., return_tensors=...)`` and returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors that carry a
            leading batch axis of size 1.
        max_length: Token length every note is truncated/padded to.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of notes."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize note ``idx`` and return its model inputs and labels.

        Returns:
            dict with ``'input_ids'`` and ``'attention_mask'`` (the tokenizer
            output without its batch axis) and ``'labels'`` (float32 tensor).
        """
        text = str(self.texts[idx])
        label = self.labels[idx]

        tokens = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        return {
            # squeeze(0) removes only the batch axis. A bare squeeze() would
            # also collapse the sequence axis whenever max_length == 1.
            'input_ids': tokens['input_ids'].squeeze(0),
            'attention_mask': tokens['attention_mask'].squeeze(0),
            # torch.tensor always treats `label` as data; the legacy
            # torch.FloatTensor constructor interprets an integer argument as
            # a size and returns uninitialized memory.
            'labels': torch.tensor(label, dtype=torch.float32)
        }

# Binarize the ICD code lists of every split with the fitted binarizer.
y_train, y_val, y_test = (
    mlb.transform(split_frame['icd_codes'])
    for split_frame in (train_df, val_df, test_df)
)

# One dataset per split; all of them use the same tokenizer and a 512-token
# window.
train_dataset = PLMICDDataset(
    texts=train_df['text_clean'].values, labels=y_train, tokenizer=tokenizer, max_length=512
)
val_dataset = PLMICDDataset(
    texts=val_df['text_clean'].values, labels=y_val, tokenizer=tokenizer, max_length=512
)
test_dataset = PLMICDDataset(
    texts=test_df['text_clean'].values, labels=y_test, tokenizer=tokenizer, max_length=512
)

print(f"Train: {len(train_dataset):,}")
print(f"Val: {len(val_dataset):,}")
print(f"Test: {len(test_dataset):,}")

Train: 82,501
Val: 9,084
Test: 23,048


In [10]:
# Cell 4 — Collate Function and DataLoaders
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Stack a list of per-sample dicts into one dict of batched tensors.

    Args:
        batch: Sequence of dicts, each holding equally-shaped tensors under
            the keys 'input_ids', 'attention_mask' and 'labels'.

    Returns:
        dict with the same three keys, each tensor gaining a new leading
        batch dimension.
    """
    field_names = ('input_ids', 'attention_mask', 'labels')
    return {
        field: torch.stack([sample[field] for sample in batch])
        for field in field_names
    }

# Settings common to all three loaders; only `shuffle` differs per split.
_loader_kwargs = dict(batch_size=16, num_workers=0, collate_fn=collate_fn)

# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print(f"Train batches: {len(train_loader):,}")
print(f"Val batches: {len(val_loader):,}")
print(f"Test batches: {len(test_loader):,}")

Train batches: 5,157
Val batches: 568
Test batches: 1,441


In [11]:
# Cell 5 — PLM-ICD with Clinical-Longformer Backbone
class PLMICD(nn.Module):
    """Label-wise attention classifier on top of a pretrained encoder.

    For every label, a softmax over the token axis picks out the tokens that
    matter for that label; the resulting per-label document vector is scored
    by that label's own classifier row.

    Args:
        num_labels: Number of output labels (one logit each).
        dropout: Dropout probability applied to the encoder's token states.
        model_name: Checkpoint passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, num_labels=50, dropout=0.1, model_name="yikuan8/Clinical-Longformer"):
        super().__init__()
        self.longformer = AutoModel.from_pretrained(model_name)
        # Read the encoder width from its config instead of assuming 768, so
        # a different checkpoint does not silently mismatch the heads.
        hidden_size = self.longformer.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.label_attention = nn.Linear(hidden_size, num_labels)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Compute one logit per label for each sequence in the batch.

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: Same shape as ``input_ids``; 0 marks padding.

        Returns:
            Tensor of shape (batch, num_labels) with raw (pre-sigmoid) logits.
        """
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # (batch, seq_len, hidden)
        token_output = self.dropout(outputs.last_hidden_state)

        # (batch, seq_len, num_labels): one attention score per token and label.
        attention_scores = self.label_attention(token_output)

        # Push padded positions to the most negative finite value of the
        # scores' dtype so they get zero weight after the softmax. A finite
        # value (rather than -inf) keeps an all-padding row from becoming NaN.
        padding_positions = (attention_mask == 0).unsqueeze(-1)
        attention_scores = attention_scores.masked_fill(
            padding_positions, torch.finfo(attention_scores.dtype).min
        )

        # Normalise over the token axis, independently for every label.
        attention_weights = torch.softmax(attention_scores, dim=1)

        # (batch, num_labels, hidden): attention-weighted sum of token states.
        label_representations = torch.bmm(
            attention_weights.transpose(1, 2),
            token_output
        )

        # Score label i with classifier row i only. This equals the diagonal
        # of classifier(label_representations) without materialising the full
        # (num_labels x num_labels) matrix per sample.
        logits = (label_representations * self.classifier.weight).sum(dim=-1)
        logits = logits + self.classifier.bias

        return logits

# Build the 50-label model and place it on the selected device.
model = PLMICD(num_labels=50).to(device)

# Count every parameter (backbone plus both linear heads).
total_params = sum(param.numel() for param in model.parameters())
print(f"Total parameters: {total_params:,}")
print("Clinical-Longformer PLM-ICD model loaded successfully")

Some weights of LongformerModel were not initialized from the model checkpoint at yikuan8/Clinical-Longformer and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total parameters: 148,736,356
Clinical-Longformer PLM-ICD model loaded successfully


In [None]:
# Cell 6 — Training Loop with Mixed Precision
import os

from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

BEST_MODEL_PATH = '../models/plmicd_longformer_best.pt'
LOG_EVERY = 5000  # print the running loss every LOG_EVERY training batches

# Create the checkpoint directory up front: otherwise a missing folder only
# surfaces at the first torch.save, after a full epoch of training.
os.makedirs(os.path.dirname(BEST_MODEL_PATH), exist_ok=True)

optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
num_epochs = 3
total_steps = len(train_loader) * num_epochs
# Linear warmup over the first 10% of steps, then linear decay.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=total_steps // 10,
    num_training_steps=total_steps
)

# Multi-label objective: an independent sigmoid/BCE term per label.
criterion = nn.BCEWithLogitsLoss()
best_val_f1 = 0

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch_idx, batch in enumerate(train_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        with autocast():
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        # Unscale before clipping so the 1.0 threshold applies to the true
        # gradient norm rather than the loss-scaled one.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # GradScaler skips optimizer.step() when it finds inf/NaN gradients
        # and then lowers its scale in update(). Advance the LR schedule only
        # when the optimizer really stepped, so skipped batches do not use up
        # warmup/decay steps.
        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        if scaler.get_scale() >= scale_before:
            scheduler.step()

        # One host sync per step; reused for both the running sum and the log.
        loss_value = loss.item()
        total_loss += loss_value

        if batch_idx % LOG_EVERY == 0:
            print(f"Epoch {epoch+1} | Batch {batch_idx}/{len(train_loader)} | Loss: {loss_value:.4f}")

    # Validation
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            with autocast():
                logits = model(input_ids, attention_mask)
            # Threshold the per-label probabilities at 0.5.
            preds = (torch.sigmoid(logits) > 0.5).cpu().numpy()
            # Store predictions and targets as integer indicator rows so both
            # arguments to f1_score have the same type. The targets never
            # need to visit the GPU.
            all_preds.extend(preds.astype(int))
            all_labels.extend(batch['labels'].numpy().astype(int))

    val_f1 = f1_score(all_labels, all_preds, average='micro')
    avg_loss = total_loss / len(train_loader)
    print(f"\nEpoch {epoch+1} complete | Avg Loss: {avg_loss:.4f} | Val Micro F1: {val_f1:.4f}")

    # Keep only the checkpoint with the best validation micro-F1 so far.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print(f"Best model saved with F1: {best_val_f1:.4f}\n")

print(f"\nTraining complete. Best Val F1: {best_val_f1:.4f}")

Epoch 1 | Batch 0/5157 | Loss: 0.7005
