In [12]:
# Cell 1 — Install and Setup
!pip install transformers accelerate -q
!pip install pandas
!pip install scikit-learn
!pip install numpy==1.26.4 -q
!pip install fastparquet

import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import pickle
import ast
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score
from torch.cuda.amp import GradScaler, autocast
import warnings
warnings.filterwarnings('ignore')

# Prefer the first CUDA GPU; fall back to CPU so the notebook still runs without one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Loss scaling only makes sense for CUDA mixed precision, so switch it off on CPU.
use_amp = device.type == 'cuda'
scaler = GradScaler(enabled=use_amp)
print(f"Device: {device}")
if use_amp:
    # These queries raise when no CUDA device exists, so they must stay guarded.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("GPU: none detected (running on CPU)")
print(f"PyTorch version: {torch.__version__}")
print("Mixed precision enabled" if use_amp else "Mixed precision disabled (no CUDA device)")

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To updat

In [13]:
# Cell 2 — Load Data and Tokenizer
# Read the three pre-split parquet files in train / val / test order.
train_df, val_df, test_df = map(
    pd.read_parquet,
    (
        '/workspace/train.parquet',
        '/workspace/val.parquet',
        '/workspace/test.parquet',
    ),
)

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('/workspace/mlb.pkl', 'rb') as mlb_file:
    mlb = pickle.load(mlb_file)
with open('/workspace/top50_codes.pkl', 'rb') as codes_file:
    top50_codes = pickle.load(codes_file)

# Parse every icd_codes cell with ast.literal_eval
# (presumably the column stores Python-literal strings — confirm against the parquet schema).
for split_df in (train_df, val_df, test_df):
    split_df['icd_codes'] = split_df['icd_codes'].apply(ast.literal_eval)

print("Loading Clinical-Longformer tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-Longformer")
print(f"Tokenizer loaded. Vocabulary size: {tokenizer.vocab_size:,}")
print(f"Train: {len(train_df):,}")
print(f"Val: {len(val_df):,}")
print(f"Test: {len(test_df):,}")

Loading Clinical-Longformer tokenizer...
Tokenizer loaded. Vocabulary size: 50,265
Train: 82,501
Val: 9,084
Test: 23,048


In [14]:
# Cell 3 — Dataset
class PLMICDDataset(Dataset):
    """Map-style dataset pairing note text with a per-sample label vector.

    Text is tokenized lazily in ``__getitem__``; every sample is truncated or
    padded to exactly ``max_length`` tokens.

    Args:
        texts: Indexable sequence of texts (each item is passed through ``str``).
        labels: Indexable sequence of label vectors, aligned with ``texts``.
        tokenizer: Callable tokenizer accepting ``max_length``, ``truncation``,
            ``padding`` and ``return_tensors`` and returning a mapping with
            ``input_ids`` and ``attention_mask`` tensors carrying a leading
            batch dimension of size 1.
        max_length: Fixed token length of every returned sample.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its tensors.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (batch dimension
            removed) and ``labels`` as a float32 tensor.
        """
        encoded = self.tokenizer(
            str(self.texts[idx]),
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        return {
            # squeeze(0) drops only the batch dimension; a bare squeeze() would
            # also collapse the sequence dimension when max_length == 1.
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
            # Explicit float32 copy, independent of the source array's dtype.
            'labels': torch.tensor(self.labels[idx], dtype=torch.float32)
        }

# Encode each split's icd_codes column with mlb.transform, in train / val / test order.
y_train, y_val, y_test = (
    mlb.transform(frame['icd_codes']) for frame in (train_df, val_df, test_df)
)

# All three splits share the same tokenizer and the same 512-token window.
train_dataset, val_dataset, test_dataset = (
    PLMICDDataset(frame['text_clean'].values, targets, tokenizer, max_length=512)
    for frame, targets in ((train_df, y_train), (val_df, y_val), (test_df, y_test))
)

print(f"Train: {len(train_dataset):,}")
print(f"Val: {len(val_dataset):,}")
print(f"Test: {len(test_dataset):,}")

Train: 82,501
Val: 9,084
Test: 23,048


In [15]:
# Cell 4 — DataLoaders
def collate_fn(batch):
    """Stack per-sample tensors into batch tensors.

    Args:
        batch: List of sample dicts, each holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors of matching shapes.

    Returns:
        dict with the same three keys, each stacked along a new leading
        dimension.
    """
    return {
        field: torch.stack([sample[field] for sample in batch])
        for field in ('input_ids', 'attention_mask', 'labels')
    }

# Settings shared by every split; only the training loader shuffles.
loader_options = dict(batch_size=8, num_workers=4, collate_fn=collate_fn)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

print(f"Train batches: {len(train_loader):,}")
print(f"Val batches: {len(val_loader):,}")
print(f"Test batches: {len(test_loader):,}")

Train batches: 10,313
Val batches: 1,136
Test batches: 2,881


In [16]:
# Cell 5 — PLM-ICD Model with Clinical-Longformer + Label Attention
class PLMICD(nn.Module):
    """Clinical-Longformer encoder followed by a per-label attention head.

    For each label, a learned attention distribution over token positions pools
    the encoder's token outputs into one label-specific vector; that vector is
    then scored with the label's own row of ``classifier`` to give one logit.

    Args:
        num_labels: Number of labels, i.e. logits produced per sample.
        dropout: Dropout probability applied to the encoder's token outputs.
    """

    def __init__(self, num_labels=50, dropout=0.1):
        super().__init__()
        self.longformer = AutoModel.from_pretrained("yikuan8/Clinical-Longformer")
        # Take the width from the checkpoint's config rather than hard-coding 768,
        # so the heads always match the encoder that was actually loaded.
        hidden_size = self.longformer.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.label_attention = nn.Linear(hidden_size, num_labels)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Compute one logit per label for each sample.

        Args:
            input_ids: Token id tensor passed straight to the encoder.
            attention_mask: Mask aligned with ``input_ids``; positions equal to
                0 are excluded from the label attention.

        Returns:
            Tensor of logits with one row per sample and one column per label.
        """
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        token_output = self.dropout(outputs.last_hidden_state)

        # One attention score per (token, label) pair.
        attention_scores = self.label_attention(token_output)
        # Work in at least float32: the -1e9 fill value is not representable in
        # float16, which is what the linear layer emits under CUDA autocast.
        attention_scores = attention_scores.to(
            torch.promote_types(attention_scores.dtype, torch.float32)
        )
        is_padding = attention_mask.unsqueeze(-1) == 0
        attention_scores = attention_scores.masked_fill(is_padding, -1e9)
        # Normalise over the token dimension so each label's weights sum to 1.
        attention_weights = torch.softmax(attention_scores, dim=1)

        # Attention-weighted sum of token vectors, one pooled vector per label.
        label_representations = torch.bmm(
            attention_weights.transpose(1, 2),
            token_output
        )

        # Score label i's vector with classifier row i only. This equals the
        # diagonal of classifier(label_representations) without materialising
        # the full labels x labels matrix.
        logits = (label_representations * self.classifier.weight).sum(dim=-1)
        logits = logits + self.classifier.bias
        return logits

# Build the 50-label model and move it onto the selected device in one step.
model = PLMICD(num_labels=50).to(device)

# Add up the element count of every parameter tensor.
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"Total parameters: {total_params:,}")
print("PLM-ICD Clinical-Longformer model loaded successfully")

Loading weights:   0%|          | 0/269 [00:00<?, ?it/s]

[1mLongformerModel LOAD REPORT[0m from: yikuan8/Clinical-Longformer
Key                                | Status     | 
-----------------------------------+------------+-
lm_head.bias                       | UNEXPECTED | 
lm_head.decoder.weight             | UNEXPECTED | 
longformer.embeddings.position_ids | UNEXPECTED | 
lm_head.dense.bias                 | UNEXPECTED | 
lm_head.layer_norm.bias            | UNEXPECTED | 
lm_head.layer_norm.weight          | UNEXPECTED | 
lm_head.dense.weight               | UNEXPECTED | 
lm_head.decoder.bias               | UNEXPECTED | 
pooler.dense.bias                  | MISSING    | 
pooler.dense.weight                | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


Total parameters: 148,736,356
PLM-ICD Clinical-Longformer model loaded successfully


In [17]:
# Cell 6 — Training Loop with Mixed Precision
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# AdamW with a linear schedule: warm up over the first 10% of steps, then decay.
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
num_epochs = 3
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=total_steps // 10,
    num_training_steps=total_steps
)

criterion = nn.BCEWithLogitsLoss()
best_val_f1 = 0
# Autocast is only switched on for CUDA; on CPU the forward pass stays in full precision.
amp_enabled = device.type == 'cuda'


def evaluate_micro_f1(model, loader):
    """Run ``model`` over ``loader`` and return the micro-averaged F1 score.

    Predictions are ``sigmoid(logits) > 0.5``. The model is put in eval mode
    and gradients are disabled for the pass.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            with torch.autocast(device_type=device.type, enabled=amp_enabled):
                logits = model(input_ids, attention_mask)
            # Threshold in float32 so half-precision rounding cannot flip a
            # prediction that sits right at the 0.5 boundary.
            preds = (torch.sigmoid(logits.float()) > 0.5).cpu().numpy()
            all_preds.extend(preds)
            # Labels are only needed on the host, so skip the device round trip.
            all_labels.extend(batch['labels'].numpy())
    return f1_score(all_labels, all_preds, average='micro')


for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch_idx, batch in enumerate(train_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        with torch.autocast(device_type=device.type, enabled=amp_enabled):
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        # The scaler lowers its scale exactly when it found inf/NaN gradients and
        # skipped optimizer.step(); do not advance the LR schedule for that step.
        if scaler.get_scale() >= scale_before:
            scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        if batch_idx % 1000 == 0:
            print(f"Epoch {epoch+1} | Batch {batch_idx}/{len(train_loader)} | Loss: {batch_loss:.4f}")

    val_f1 = evaluate_micro_f1(model, val_loader)
    avg_loss = total_loss / len(train_loader)
    print(f"\nEpoch {epoch+1} complete | Avg Loss: {avg_loss:.4f} | Val Micro F1: {val_f1:.4f}")

    # Keep only the checkpoint with the best validation score seen so far.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), '/workspace/plmicd_best.pt')
        print(f"Best model saved with F1: {best_val_f1:.4f}\n")

print(f"\nTraining complete. Best Val F1: {best_val_f1:.4f}")

Epoch 1 | Batch 0/10313 | Loss: 0.7122
Epoch 1 | Batch 1000/10313 | Loss: 0.2991
Epoch 1 | Batch 2000/10313 | Loss: 0.2574
Epoch 1 | Batch 3000/10313 | Loss: 0.2364
Epoch 1 | Batch 4000/10313 | Loss: 0.2663
Epoch 1 | Batch 5000/10313 | Loss: 0.1605
Epoch 1 | Batch 6000/10313 | Loss: 0.3431
Epoch 1 | Batch 7000/10313 | Loss: 0.2587
Epoch 1 | Batch 8000/10313 | Loss: 0.2966
Epoch 1 | Batch 9000/10313 | Loss: 0.1876
Epoch 1 | Batch 10000/10313 | Loss: 0.2449

Epoch 1 complete | Avg Loss: 0.2561 | Val Micro F1: 0.4920
Best model saved with F1: 0.4920

Epoch 2 | Batch 0/10313 | Loss: 0.2077
Epoch 2 | Batch 1000/10313 | Loss: 0.2735
Epoch 2 | Batch 2000/10313 | Loss: 0.1982
Epoch 2 | Batch 3000/10313 | Loss: 0.1513
Epoch 2 | Batch 4000/10313 | Loss: 0.2073
Epoch 2 | Batch 5000/10313 | Loss: 0.1675
Epoch 2 | Batch 6000/10313 | Loss: 0.1637
Epoch 2 | Batch 7000/10313 | Loss: 0.2140
Epoch 2 | Batch 8000/10313 | Loss: 0.2471
Epoch 2 | Batch 9000/10313 | Loss: 0.2099
Epoch 2 | Batch 10000/10313 |