1. Setup Environment

Install required packages.

In [None]:
# Pinned versions keep the environment reproducible.
# NOTE(review): later cells also import matplotlib and psutil, which this line does not install - confirm the runtime already provides them.
%pip install -q transformers==4.20.1 datasets==2.10.0 pandas==1.4.2 numpy==1.22.4 scikit-learn==1.1.1 torch==1.11.0 nltk==3.7 imbalanced-learn==0.9.1


In [None]:
import torch
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import accuracy_score, f1_score, classification_report
import nltk
from nltk.corpus import stopwords
from imblearn.over_sampling import SMOTE


In [None]:
# Fetch the NLTK stop-word corpus; preprocess_text reads its English list.
nltk.download('stopwords')


In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')


2. Create and Preprocess drug_use_data.csv

Load SetFit/ade_corpus_v2_classification train split, create CSV, and preprocess.

In [None]:
import re
import urllib.request

# File names of the dataset splits inside the Hub repository
splits = dict(train='train.jsonl', test='test.jsonl')

# First attempt: read the training split through the hf:// protocol
try:
    df = pd.read_json(
        "hf://datasets/SetFit/ade_corpus_v2_classification/" + splits["train"],
        lines=True,
    )
except Exception as load_error:
    # Best-effort fallback: download the file over HTTPS and parse the local copy
    print(f"hf:// loading failed: {load_error}")
    print("Falling back to direct URL...")
    url = "https://huggingface.co/datasets/SetFit/ade_corpus_v2_classification/resolve/main/train.jsonl"
    urllib.request.urlretrieve(url, "train.jsonl")
    df = pd.read_json("train.jsonl", lines=True)

# Drug keyword -> substance class. Insertion order matters: assign_labels
# stops at the first keyword it finds while iterating this dict.
substance_map = {
    **dict.fromkeys(
        ['morphine', 'oxycodone', 'fentanyl', 'hydrocodone', 'heroin', 'codeine', 'tramadol'],
        'opioid',
    ),
    **dict.fromkeys(['cocaine', 'methamphetamine', 'amphetamine'], 'stimulant'),
    **dict.fromkeys(['placebo', 'heparin'], 'none'),
}

# Symptom keywords searched for in every text
symptom_list = [
    'nausea', 'confusion', 'drowsiness', 'overdose',
    'dizziness', 'vomiting', 'fatigue', 'headache',
    'anxiety', 'seizure', 'hematoma', 'rash',
    'pain', 'constipation', 'dyspnea', 'pruritus',
]

def assign_labels(text, original_label=None):
    """Derive a substance class and a list of symptom tags from free text.

    Matching is plain case-insensitive substring search against the
    module-level ``substance_map`` and ``symptom_list``.

    Args:
        text: Value to label; it is converted with ``str()`` and lower-cased.
        original_label: Optional ADE flag of the row. When it equals 1 and no
            symptom keyword is found, the symptoms become ``['adverse_event']``.

    Returns:
        ``(substance, symptoms)`` where ``substance`` is the class of the first
        drug keyword found in ``substance_map`` order (``'none'`` if there is no
        match) and ``symptoms`` is a non-empty list: the matched keywords in
        ``symptom_list`` order, else ``['adverse_event']`` or ``['none']``.
    """
    lowered = str(text).lower()

    # The first matching drug keyword (in dict order) decides the class
    substance = next(
        (category for drug, category in substance_map.items() if drug in lowered),
        'none',
    )

    symptoms = [keyword for keyword in symptom_list if keyword in lowered]

    # No keyword hit: fall back to the ADE flag, then to the 'none' tag
    if not symptoms:
        symptoms = ['adverse_event'] if original_label == 1 else ['none']

    return substance, symptoms

# Label every row; the corpus' own 'label' column is forwarded when present
if 'label' in df.columns:
    label_pairs = [
        assign_labels(text, label) for text, label in zip(df['text'], df['label'])
    ]
else:
    label_pairs = [assign_labels(text) for text in df['text']]

# Transpose the (substance, symptoms) pairs into two columns
df['substance_label'], df['symptom_labels'] = zip(*label_pairs)

# Persist the raw text with its derived labels before any further processing
csv_columns = ['text', 'substance_label', 'symptom_labels']
df[csv_columns].to_csv('drug_use_data.csv', index=False)
print('Dataset saved as drug_use_data.csv')

# Preprocess text
_ENGLISH_STOP_WORDS = None  # NLTK English stop words, loaded on first use


def preprocess_text(text, stop_words=None):
    """Normalise one document for TF-IDF vectorisation.

    Steps, in order: lower-case; strip URLs, @mentions and #hashtags; drop
    non-ASCII characters; drop everything that is neither a word character nor
    whitespace; remove stop words and collapse whitespace.

    Args:
        text: Value to clean; it is converted with ``str()`` first.
        stop_words: Optional container of lower-case words to remove. When
            omitted, the NLTK English list is used. That list is loaded once
            and cached, instead of being rebuilt for every row.

    Returns:
        The cleaned text with tokens separated by single spaces.
    """
    global _ENGLISH_STOP_WORDS
    if stop_words is None:
        if _ENGLISH_STOP_WORDS is None:
            _ENGLISH_STOP_WORDS = frozenset(stopwords.words('english'))
        stop_words = _ENGLISH_STOP_WORDS

    text = str(text).lower()
    text = re.sub(r'http\S+|www\S+|https\S+', '', text)  # URLs
    text = re.sub(r'@\w+', '', text)                     # @mentions
    text = re.sub(r'#\w+', '', text)                     # #hashtags
    text = re.sub(r'[^\x00-\x7F]+', '', text)            # non-ASCII characters
    text = re.sub(r'[^\w\s]', '', text)                  # punctuation
    return ' '.join(word for word in text.split() if word not in stop_words)

# Local imports: this cell uses TfidfVectorizer before the cell that imports it
from collections import Counter
from imblearn.over_sampling import SMOTE
from sklearn.feature_extraction.text import TfidfVectorizer

# Clean the text in place (labels were derived from the raw text beforehand)
df['text'] = df['text'].apply(preprocess_text)

# Encode labels: substance class -> integer id, in order of first appearance
substance_classes = df['substance_label'].unique()
substance2id = {label: idx for idx, label in enumerate(substance_classes)}
df['substance_label'] = df['substance_label'].map(substance2id)

# Multi-hot encode the symptom lists; reuse df's index so concat aligns row by row
mlb = MultiLabelBinarizer()
symptom_encoded = mlb.fit_transform(df['symptom_labels'])
symptom_df = pd.DataFrame(symptom_encoded, columns=mlb.classes_, index=df.index)
symptom_columns = mlb.classes_

# Combine dataframes
df = pd.concat([df[['text', 'substance_label']], symptom_df], axis=1)

# Split BEFORE oversampling. Oversampling duplicates rows, so doing it first
# would place copies of training rows in the test set and inflate the metrics.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['substance_label'])
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

y_train = train_df['substance_label'].to_numpy()
train_counts = Counter(y_train.tolist())
print("Original distribution:", train_counts)

# TF-IDF features are only used to decide WHICH training rows to duplicate
temp_vectorizer = TfidfVectorizer(max_features=1000, stop_words='english')
X_tfidf_temp = temp_vectorizer.fit_transform(train_df['text'])  # kept sparse

try:
    # SMOTE needs k_neighbors < size of the smallest class
    smote = SMOTE(random_state=42, k_neighbors=min(3, min(train_counts.values()) - 1))
    X_balanced, y_balanced = smote.fit_resample(X_tfidf_temp, y_train)
    y_balanced = np.asarray(y_balanced)

    # The mapping below relies on the original rows coming first in the output
    n_original = len(train_df)
    if not np.array_equal(y_balanced[:n_original], y_train):
        raise ValueError("unexpected SMOTE output ordering")

    # A synthetic vector has no text. Stand in the most similar real training
    # row OF THE SAME CLASS, so the duplicated text keeps its own label.
    X_synthetic = X_balanced[n_original:]
    y_synthetic = y_balanced[n_original:]
    donor_positions = np.empty(len(y_synthetic), dtype=np.int64)
    for label in np.unique(y_synthetic):
        synthetic_rows = np.flatnonzero(y_synthetic == label)
        class_rows = np.flatnonzero(y_train == label)
        class_features_t = X_tfidf_temp[class_rows].T
        # Chunked so the dense similarity matrix stays small
        for start in range(0, len(synthetic_rows), 2048):
            chunk = synthetic_rows[start:start + 2048]
            similarities = (X_synthetic[chunk] @ class_features_t).toarray()
            donor_positions[chunk] = class_rows[similarities.argmax(axis=1)]

    train_df = pd.concat([train_df, train_df.iloc[donor_positions]], ignore_index=True)

    print("Balanced distribution:", Counter(train_df['substance_label']))
except ValueError as e:
    print(f"SMOTE failed: {e}, using original data with manual balancing")
    # Fallback: duplicate training rows of every class below 10% of the training set
    minority_threshold = len(train_df) * 0.1  # 10% threshold
    minority_data = []
    for label in train_df['substance_label'].unique():
        label_data = train_df[train_df['substance_label'] == label]
        if len(label_data) < minority_threshold:
            multiplier = int(minority_threshold / len(label_data)) + 1
            minority_data.append(pd.concat([label_data] * multiplier, ignore_index=True))

    if minority_data:
        train_df = pd.concat([train_df] + minority_data, ignore_index=True)
        print("Manual balancing applied")

print(f'Training samples: {len(train_df)}, Test samples: {len(test_df)}')

Dataset saved as drug_use_data.csv
Original distribution: Counter({0: 17510, 1: 102, 2: 25})


3. Create TF-IDF Features and Datasets

Use TF-IDF features and create custom dataset.

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer
import torch
import numpy as np
from scipy import sparse

# Fit the TF-IDF vocabulary on the training texts only, then reuse it for the test texts
print("Creating TF-IDF features...")
tfidf_settings = dict(
    max_features=2000,    # cap on vocabulary size
    stop_words='english',
    ngram_range=(1, 2),   # unigrams and bigrams
    dtype=np.float32,
    min_df=3,             # skip terms seen in fewer than 3 documents
    max_df=0.90,          # skip terms seen in more than 90% of documents
)
vectorizer = TfidfVectorizer(**tfidf_settings)

# Both results stay sparse here; the dense conversion happens further down
X_train_tfidf_sparse = vectorizer.fit_transform(train_df['text'])
X_test_tfidf_sparse = vectorizer.transform(test_df['text'])

print(f"TF-IDF sparse matrix shape: {X_train_tfidf_sparse.shape}")
print(f"Memory usage (sparse): ~{X_train_tfidf_sparse.data.nbytes / 1024**2:.1f} MB")
print(f"Vocabulary size: {len(vectorizer.vocabulary_)}")

# Densify a sparse matrix slice by slice
def sparse_to_dense_batched(sparse_matrix, batch_size=1000):
    """Return a dense float32 copy of ``sparse_matrix``, filled in row batches.

    The full output array is allocated up front; ``batch_size`` only bounds
    the size of each temporary ``toarray()`` result. A progress line is
    printed for every tenth batch, starting with the first.

    Args:
        sparse_matrix: 2-D matrix supporting row slicing and ``toarray()``.
        batch_size: Number of rows converted per step.

    Returns:
        ``numpy.ndarray`` of dtype float32 with the same shape as the input.
    """
    n_samples, n_features = sparse_matrix.shape[0], sparse_matrix.shape[1]
    dense_array = np.zeros((n_samples, n_features), dtype=np.float32)

    for batch_no, start in enumerate(range(0, n_samples, batch_size)):
        stop = min(start + batch_size, n_samples)
        dense_array[start:stop] = sparse_matrix[start:stop].toarray()

        if batch_no % 10 == 0:  # progress update every 10 batches
            print(f"Processed {start}/{n_samples} samples...")

    return dense_array

# Densify both TF-IDF matrices. On MemoryError the conversion is retried once
# with a smaller batch size; a second MemoryError is re-raised with advice.
# NOTE(review): sparse_to_dense_batched allocates the full dense array up front
# regardless of batch_size, so the retry may not lower peak memory - confirm.
print("Converting sparse matrices to dense (this may take a moment)...")
try:
    # First attempt: 500 rows per batch
    X_train_tfidf = sparse_to_dense_batched(X_train_tfidf_sparse, batch_size=500)
    X_test_tfidf = sparse_to_dense_batched(X_test_tfidf_sparse, batch_size=500)
    print("Conversion completed successfully!")
except MemoryError:
    print("Still not enough memory. Using even smaller batch size...")
    try:
        # Second attempt: 100 rows per batch (both matrices are converted again)
        X_train_tfidf = sparse_to_dense_batched(X_train_tfidf_sparse, batch_size=100)
        X_test_tfidf = sparse_to_dense_batched(X_test_tfidf_sparse, batch_size=100)
        print("Conversion completed with smaller batches!")
    except MemoryError:
        print("Memory still insufficient. Switching to sparse-compatible approach...")
        # No sparse code path is implemented here; the cell stops with an error
        raise MemoryError("Consider using a machine with more RAM or further reducing max_features")

print(f"Final TF-IDF feature shape: {X_train_tfidf.shape}")

# Memory-efficient dataset class
class TFIDFDataset(torch.utils.data.Dataset):
    """Map-style dataset over dense TF-IDF rows and their two label sets.

    The arrays are kept as NumPy arrays and turned into tensors per item.
    Inputs that already have the target dtype are stored without copying, so
    wrapping a large float32 feature matrix does not double its memory.

    Args:
        features: 2-D array-like of TF-IDF values, one row per sample.
        substance_labels: 1-D array-like of integer class ids.
        symptom_labels: 2-D array-like of multi-hot symptom indicators.

    Raises:
        ValueError: If the three inputs do not have the same number of samples.
    """

    def __init__(self, features, substance_labels, symptom_labels):
        # np.asarray only copies when a dtype conversion is actually needed
        self.features = np.asarray(features, dtype=np.float32)
        self.substance_labels = np.asarray(substance_labels, dtype=np.int64)
        self.symptom_labels = np.asarray(symptom_labels, dtype=np.float32)

        n_samples = len(self.features)
        if len(self.substance_labels) != n_samples or len(self.symptom_labels) != n_samples:
            raise ValueError(
                "features, substance_labels and symptom_labels must have the same "
                f"number of samples, got {n_samples}, {len(self.substance_labels)} "
                f"and {len(self.symptom_labels)}"
            )

        print(f"Dataset created with {len(self.features)} samples")
        print(f"Feature shape: {self.features.shape}")
        print(f"Memory usage: ~{self.features.nbytes / 1024**2:.1f} MB")

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors.

        Keys: ``'x'`` (float32 features), ``'substance_labels'`` (int64 class
        id) and ``'symptom_labels'`` (float32 multi-hot vector).
        """
        # Tensors are built lazily, one sample at a time
        return {
            'x': torch.from_numpy(self.features[idx]).float(),
            'substance_labels': torch.from_numpy(np.array(self.substance_labels[idx])).long(),
            'symptom_labels': torch.from_numpy(self.symptom_labels[idx]).float()
        }

    def __len__(self):
        """Number of samples."""
        return len(self.features)

# Ensure symptom columns exist in both dataframes
missing_train_cols = [col for col in symptom_columns if col not in train_df.columns]
missing_test_cols = [col for col in symptom_columns if col not in test_df.columns]

if missing_train_cols:
    print(f"Adding missing columns to train_df: {missing_train_cols}")
    for col in missing_train_cols:
        train_df[col] = 0

if missing_test_cols:
    print(f"Adding missing columns to test_df: {missing_test_cols}")
    for col in missing_test_cols:
        test_df[col] = 0

# Multi-hot symptom targets, with columns in symptom_columns order
train_symptom_data = train_df[symptom_columns].values
test_symptom_data = test_df[symptom_columns].values

print(f"Train symptom data shape: {train_symptom_data.shape}")
print(f"Test symptom data shape: {test_symptom_data.shape}")

# Create datasets with memory management
import gc

# The sparse matrices are no longer needed once the dense copies exist
if 'X_train_tfidf_sparse' in locals():
    del X_train_tfidf_sparse
if 'X_test_tfidf_sparse' in locals():
    del X_test_tfidf_sparse
gc.collect()

try:
    print("Creating training dataset...")
    train_dataset = TFIDFDataset(
        X_train_tfidf,
        train_df['substance_label'].values,
        train_symptom_data
    )

    print("Creating test dataset...")
    test_dataset = TFIDFDataset(
        X_test_tfidf,
        test_df['substance_label'].values,
        test_symptom_data
    )

    print("Datasets created successfully!")

    # Verify dataset integrity by pulling one sample
    sample = train_dataset[0]
    print(f"Sample data shapes - Features: {sample['x'].shape}, "
          f"Substance: {sample['substance_labels'].shape}, "
          f"Symptoms: {sample['symptom_labels'].shape}")

    # Drop the notebook-level names; the datasets hold the feature arrays from here on
    del X_train_tfidf, X_test_tfidf
    gc.collect()
    print("Memory cleanup completed!")

except Exception as e:
    print(f"Error creating datasets: {e}")
    print("Debugging information:")
    print(f"Available memory info:")
    # psutil is optional: a missing package must not replace the error being reported
    try:
        import psutil
    except ImportError:
        print("psutil is not installed; skipping memory diagnostics")
    else:
        memory = psutil.virtual_memory()
        print(f"Total RAM: {memory.total / 1024**3:.1f} GB")
        print(f"Available RAM: {memory.available / 1024**3:.1f} GB")
        print(f"Used RAM: {memory.percent}%")
    raise

print("TF-IDF processing completed successfully!")

4. Define Custom Model

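A feed-forward multi-task network over the TF-IDF features: shared hidden layers feed two heads, one for substance classification and one for multi-label symptom prediction. No BioBERT or other transformer is used.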

In [None]:
class EnhancedMultiTaskModel(torch.nn.Module):
    """Feed-forward network with a shared trunk and two task heads.

    Trunk: input BatchNorm, three Linear+ReLU+BatchNorm stages
    (512 -> 256 -> 128) and a linear skip path from the normalised input
    into the 128-d representation. Each head gates that representation
    with its own sigmoid "attention" vector before classifying.

    Args:
        input_size: Number of input features per sample.
        num_substance_classes: Number of mutually exclusive substance classes.
        num_symptom_labels: Number of independent symptom labels.
        focal_alpha: Scale factor of the focal loss on the substance head.
        focal_gamma: Focusing exponent of the focal loss.
        substance_loss_weight: Weight of the substance loss in the total loss.
        symptom_loss_weight: Weight of the symptom loss in the total loss.
    """

    def __init__(self, input_size, num_substance_classes, num_symptom_labels,
                 focal_alpha=0.25, focal_gamma=2.0,
                 substance_loss_weight=0.7, symptom_loss_weight=0.3):
        super(EnhancedMultiTaskModel, self).__init__()

        # Input normalization
        self.input_norm = torch.nn.BatchNorm1d(input_size)

        # Shared trunk
        self.hidden1 = torch.nn.Linear(input_size, 512)
        self.norm1 = torch.nn.BatchNorm1d(512)
        self.hidden2 = torch.nn.Linear(512, 256)
        self.norm2 = torch.nn.BatchNorm1d(256)
        self.hidden3 = torch.nn.Linear(256, 128)
        self.norm3 = torch.nn.BatchNorm1d(128)

        # Skip path: projects the normalised input straight to the trunk width
        self.residual = torch.nn.Linear(input_size, 128)

        # Dropout with different rates
        self.dropout1 = torch.nn.Dropout(0.2)
        self.dropout2 = torch.nn.Dropout(0.3)
        self.dropout3 = torch.nn.Dropout(0.2)

        # Per-task gates: a value in (0, 1) for each of the 128 shared features
        self.substance_attention = torch.nn.Sequential(
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.Sigmoid()
        )

        self.symptom_attention = torch.nn.Sequential(
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.Sigmoid()
        )

        # Per-task classifiers
        self.substance_classifier = torch.nn.Sequential(
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(64, num_substance_classes)
        )

        self.symptom_classifier = torch.nn.Sequential(
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(64, num_symptom_labels)
        )

        # Kept as plain attributes so the architecture can be rebuilt from a checkpoint
        self.input_size = input_size
        self.num_substance_classes = num_substance_classes
        self.num_symptom_labels = num_symptom_labels

        # Loss hyper-parameters
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        self.substance_loss_weight = substance_loss_weight
        self.symptom_loss_weight = symptom_loss_weight

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform Linear weights, zero biases, identity BatchNorm affine."""
        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.BatchNorm1d):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x, substance_labels=None, symptom_labels=None):
        """Run both heads and, when both label tensors are given, the joint loss.

        Args:
            x: Float tensor whose last dimension is ``input_size``.
            substance_labels: Optional integer class ids, one per sample.
            symptom_labels: Optional float multi-hot targets, one row per sample.

        Returns:
            Dict with ``'loss'`` (scalar tensor, or ``None`` unless both label
            tensors are supplied), ``'substance_logits'``, ``'symptom_logits'``,
            ``'substance_probs'`` (softmax) and ``'symptom_probs'`` (sigmoid).
        """
        # Input normalization
        x_norm = self.input_norm(x)

        # Shared trunk
        hidden = torch.relu(self.hidden1(x_norm))
        hidden = self.norm1(hidden)
        hidden = self.dropout1(hidden)

        hidden = torch.relu(self.hidden2(hidden))
        hidden = self.norm2(hidden)
        hidden = self.dropout2(hidden)

        hidden = torch.relu(self.hidden3(hidden))
        hidden = self.norm3(hidden)

        # Add the skip path, then regularise
        hidden = hidden + torch.relu(self.residual(x_norm))
        hidden = self.dropout3(hidden)

        # Task-specific gating of the shared representation
        substance_features = hidden * self.substance_attention(hidden)
        symptom_features = hidden * self.symptom_attention(hidden)

        # Generate logits
        substance_logits = self.substance_classifier(substance_features)
        symptom_logits = self.symptom_classifier(symptom_features)

        loss = None
        if substance_labels is not None and symptom_labels is not None:
            # Focal loss: cross entropy scaled by (1 - p_true)^gamma, so samples
            # that are already classified well contribute less.
            ce_loss = torch.nn.functional.cross_entropy(substance_logits, substance_labels, reduction='none')
            pt = torch.exp(-ce_loss)  # probability assigned to the true class
            substance_loss = (self.focal_alpha * (1 - pt) ** self.focal_gamma * ce_loss).mean()

            # Symptom positives are re-weighted by the negative/positive ratio
            # of the CURRENT batch, clamped so rare labels cannot dominate.
            symptom_pos_counts = symptom_labels.sum(dim=0) + 1e-6
            symptom_neg_counts = (1 - symptom_labels).sum(dim=0) + 1e-6
            symptom_pos_weights = torch.clamp(symptom_neg_counts / symptom_pos_counts, min=0.1, max=10.0)

            symptom_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                symptom_logits,
                symptom_labels,
                pos_weight=symptom_pos_weights
            )

            # Fixed-weight combination of the two task losses
            loss = self.substance_loss_weight * substance_loss + self.symptom_loss_weight * symptom_loss

        return {
            'loss': loss,
            'substance_logits': substance_logits,
            'symptom_logits': symptom_logits,
            'substance_probs': torch.softmax(substance_logits, dim=-1),
            'symptom_probs': torch.sigmoid(symptom_logits)
        }

# Create model with proper input size.
# The width is read from the dataset: the X_train_tfidf name was deleted
# after the datasets were built, so it can no longer be used here.
actual_input_size = train_dataset.features.shape[1]
print(f"Using input size: {actual_input_size}")

model = EnhancedMultiTaskModel(
    input_size=actual_input_size,
    num_substance_classes=len(substance_classes),
    num_symptom_labels=len(symptom_columns)
)

model.to(device)

# Print model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model created with {total_params:,} total parameters")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model device: {next(model.parameters()).device}")

# Model summary
print("\nModel Architecture:")
print(f"Input size: {actual_input_size}")
print(f"Substance classes: {len(substance_classes)}")
print(f"Symptom labels: {len(symptom_columns)}")
print(f"Hidden layers: 512 -> 256 -> 128")
print("Features: Batch normalization, residual connections, attention mechanisms, focal loss")

5. Train Model


In [None]:
from sklearn.metrics import precision_score, recall_score
import gc
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
import matplotlib.pyplot as plt

# Custom training loop (the model returns a dict with a combined 'loss' entry)
def train_model(model, train_dataset, test_dataset, num_epochs=10, batch_size=16, learning_rate=5e-4):
    """Train ``model`` and restore the weights of its best epoch.

    Each epoch runs one pass over ``train_dataset`` (AdamW, gradient-norm
    clipping at 1.0, cosine learning-rate annealing) and then scores the model
    on ``test_dataset``. The epoch with the highest substance accuracy on
    ``test_dataset`` is snapshotted and loaded back at the end.

    NOTE(review): ``test_dataset`` is used for model selection here, so
    metrics later reported on the same data are optimistic.

    Args:
        model: Module whose forward accepts ``x``, ``substance_labels`` and
            ``symptom_labels`` and returns a dict with ``'loss'``,
            ``'substance_logits'`` and ``'symptom_logits'``.
        train_dataset: Dataset yielding dicts with keys ``'x'``,
            ``'substance_labels'`` and ``'symptom_labels'``.
        test_dataset: Dataset of the same form, used for per-epoch evaluation.
        num_epochs: Number of passes over ``train_dataset``.
        batch_size: Batch size for both loaders.
        learning_rate: Initial AdamW learning rate.

    Returns:
        ``(model, history)`` where ``history`` maps ``'train_loss'``,
        ``'val_loss'``, ``'substance_acc'``, ``'symptom_f1'``,
        ``'symptom_precision'`` and ``'symptom_recall'`` to per-epoch lists.
    """
    # BatchNorm1d cannot compute batch statistics from one sample in training
    # mode, so a trailing batch holding a single sample is dropped.
    drop_single_tail = len(train_dataset) > batch_size and len(train_dataset) % batch_size == 1
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              pin_memory=False, drop_last=drop_single_tail)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=False)

    # Optimizer and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.001)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    # Training history
    history = {
        'train_loss': [], 'val_loss': [],
        'substance_acc': [], 'symptom_f1': [],
        'symptom_precision': [], 'symptom_recall': []
    }

    best_substance_acc = 0.0
    best_model_state = None

    print("Starting training...")
    print(f"Total epochs: {num_epochs}")
    print(f"Batch size: {batch_size}")
    print(f"Learning rate: {learning_rate}")
    print(f"Total training batches: {len(train_loader)}")
    print(f"Total validation batches: {len(test_loader)}")

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()

            # Move data to device
            x = batch['x'].to(device)
            substance_labels = batch['substance_labels'].to(device)
            symptom_labels = batch['symptom_labels'].to(device)

            # Forward pass
            outputs = model(x, substance_labels=substance_labels, symptom_labels=symptom_labels)
            loss = outputs['loss']

            # Backward pass with gradient-norm clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()
            train_batches += 1

            # Log progress every 50 batches
            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, "
                      f"Loss: {loss.item():.4f}")

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        avg_train_loss = train_loss / max(train_batches, 1)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_batches = 0
        all_substance_preds = []
        all_substance_labels = []
        all_symptom_preds = []
        all_symptom_labels = []

        with torch.no_grad():
            for batch in test_loader:
                x = batch['x'].to(device)
                substance_labels = batch['substance_labels'].to(device)
                symptom_labels = batch['symptom_labels'].to(device)

                outputs = model(x, substance_labels=substance_labels, symptom_labels=symptom_labels)
                loss = outputs['loss']

                val_loss += loss.item()
                val_batches += 1

                # Collect predictions (symptoms thresholded at 0.5)
                substance_preds = torch.argmax(outputs['substance_logits'], dim=1)
                symptom_preds = (torch.sigmoid(outputs['symptom_logits']) > 0.5).float()

                all_substance_preds.extend(substance_preds.cpu().numpy())
                all_substance_labels.extend(substance_labels.cpu().numpy())
                all_symptom_preds.extend(symptom_preds.cpu().numpy())
                all_symptom_labels.extend(symptom_labels.cpu().numpy())

        avg_val_loss = val_loss / max(val_batches, 1)

        # Calculate metrics
        all_substance_preds = np.array(all_substance_preds)
        all_substance_labels = np.array(all_substance_labels)
        all_symptom_preds = np.array(all_symptom_preds)
        all_symptom_labels = np.array(all_symptom_labels)

        substance_accuracy = accuracy_score(all_substance_labels, all_substance_preds)
        symptom_f1 = f1_score(all_symptom_labels, all_symptom_preds, average='micro', zero_division=0)
        symptom_precision = precision_score(all_symptom_labels, all_symptom_preds, average='micro', zero_division=0)
        symptom_recall = recall_score(all_symptom_labels, all_symptom_preds, average='micro', zero_division=0)

        # Update learning rate
        scheduler.step()

        # Save best model. state_dict() hands out references to the live
        # tensors, so each one is cloned; a shallow dict copy would silently
        # follow later epochs and the "best" weights would be the last ones.
        if substance_accuracy > best_substance_acc:
            best_substance_acc = substance_accuracy
            best_model_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f"New best model saved! Substance accuracy: {best_substance_acc:.4f}")

        # Store history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['substance_acc'].append(substance_accuracy)
        history['symptom_f1'].append(symptom_f1)
        history['symptom_precision'].append(symptom_precision)
        history['symptom_recall'].append(symptom_recall)

        # Print epoch results
        print(f"\nEpoch {epoch+1}/{num_epochs} Results:")
        print(f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        print(f"Substance Accuracy: {substance_accuracy:.4f}")
        print(f"Symptom F1: {symptom_f1:.4f}, Precision: {symptom_precision:.4f}, Recall: {symptom_recall:.4f}")
        print(f"Learning Rate: {scheduler.get_last_lr()[0]:.6f}")
        print("-" * 60)

        # Clear cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"\nLoaded best model with substance accuracy: {best_substance_acc:.4f}")

    return model, history

# Inference-only pass that gathers predictions, probabilities and targets
def evaluate_model(model, test_dataset, batch_size=16):
    """Run ``model`` over ``test_dataset`` and collect its outputs as arrays.

    The model is put in eval mode and called without labels, so no loss is
    computed. Symptom predictions are the sigmoid scores thresholded at 0.5.

    Args:
        model: Module returning a dict with ``'substance_logits'`` and
            ``'symptom_logits'`` when called with a feature batch.
        test_dataset: Dataset yielding dicts with keys ``'x'``,
            ``'substance_labels'`` and ``'symptom_labels'``.
        batch_size: Batch size of the evaluation loader.

    Returns:
        Dict of NumPy arrays keyed ``'substance_preds'``, ``'substance_labels'``,
        ``'symptom_preds'``, ``'symptom_labels'``, ``'substance_probs'`` and
        ``'symptom_probs'``.
    """
    model.eval()
    loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    result_keys = (
        'substance_preds', 'substance_labels',
        'symptom_preds', 'symptom_labels',
        'substance_probs', 'symptom_probs',
    )
    collected = {key: [] for key in result_keys}

    with torch.no_grad():
        for batch in loader:
            features = batch['x'].to(device)
            substance_targets = batch['substance_labels'].to(device)
            symptom_targets = batch['symptom_labels'].to(device)

            outputs = model(features)
            substance_logits = outputs['substance_logits']
            symptom_scores = torch.sigmoid(outputs['symptom_logits'])

            per_batch = {
                'substance_preds': torch.argmax(substance_logits, dim=1),
                'substance_labels': substance_targets,
                'symptom_preds': (symptom_scores > 0.5).float(),
                'symptom_labels': symptom_targets,
                'substance_probs': torch.softmax(substance_logits, dim=1),
                'symptom_probs': symptom_scores,
            }
            for key, tensor in per_batch.items():
                collected[key].extend(tensor.cpu().numpy())

    return {key: np.array(values) for key, values in collected.items()}

# Release cached GPU memory and collectable Python objects before the run
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

# Train the model
print("Starting model training...")
training_config = dict(
    num_epochs=10,
    batch_size=16,
    learning_rate=5e-4,
)
trained_model, training_history = train_model(
    model=model,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    **training_config,
)

print("\nTraining completed!")

6. Evaluate Model

Evaluate and print results.

In [None]:
# Evaluate the trained model
print("\nEvaluating trained model...")
eval_results = evaluate_model(trained_model, test_dataset, batch_size=16)

# Print evaluation metrics (symptom scores are micro-averaged over all labels)
substance_accuracy = accuracy_score(eval_results['substance_labels'], eval_results['substance_preds'])
symptom_f1 = f1_score(eval_results['symptom_labels'], eval_results['symptom_preds'], average='micro', zero_division=0)
symptom_precision = precision_score(eval_results['symptom_labels'], eval_results['symptom_preds'], average='micro', zero_division=0)
symptom_recall = recall_score(eval_results['symptom_labels'], eval_results['symptom_preds'], average='micro', zero_division=0)

print(f'Final Evaluation Results:')
print(f'Substance Accuracy: {substance_accuracy:.4f}')
print(f'Symptom F1 Score: {symptom_f1:.4f}')
print(f'Symptom Precision: {symptom_precision:.4f}')
print(f'Symptom Recall: {symptom_recall:.4f}')

# Get predictions for classification reports
substance_preds = eval_results['substance_preds']
symptom_preds = eval_results['symptom_preds']

# Class ids are positions in substance_classes. Passing them explicitly keeps
# the report aligned with target_names even when a class is missing from both
# the targets and the predictions (otherwise the name count would not match).
print('\nSubstance Classification Report:')
print(classification_report(eval_results['substance_labels'], substance_preds,
                          labels=list(range(len(substance_classes))),
                          target_names=substance_classes, zero_division=0))

print('\nSymptom Classification Report:')
print(classification_report(eval_results['symptom_labels'], symptom_preds,
                          target_names=symptom_columns, zero_division=0))

# Optional: visualise the per-epoch training history
def plot_training_history(history):
    """Draw a 2x2 grid of training curves and show it.

    Panels: training/validation loss, substance accuracy, symptom F1, and
    symptom precision/recall. The x-axis of every panel is the epoch index.

    Args:
        history: Mapping with per-epoch lists under ``'train_loss'``,
            ``'val_loss'``, ``'substance_acc'``, ``'symptom_f1'``,
            ``'symptom_precision'`` and ``'symptom_recall'``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # One entry per panel: (axis, title, y-label, curves); each curve is
    # (history key, legend label, colour), with None meaning the default colour
    panels = [
        (axes[0, 0], 'Training and Validation Loss', 'Loss',
         [('train_loss', 'Training Loss', None),
          ('val_loss', 'Validation Loss', None)]),
        (axes[0, 1], 'Substance Classification Accuracy', 'Accuracy',
         [('substance_acc', 'Substance Accuracy', 'green')]),
        (axes[1, 0], 'Symptom Classification F1 Score', 'F1 Score',
         [('symptom_f1', 'Symptom F1', 'orange')]),
        (axes[1, 1], 'Symptom Precision and Recall', 'Score',
         [('symptom_precision', 'Precision', 'red'),
          ('symptom_recall', 'Recall', 'blue')]),
    ]

    for ax, title, y_label, curves in panels:
        for key, legend_label, colour in curves:
            if colour is None:
                ax.plot(history[key], label=legend_label)
            else:
                ax.plot(history[key], label=legend_label, color=colour)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

# Plot the per-epoch curves recorded by train_model
plot_training_history(training_history)

# The trained weights are written to disk in the "Save Model" section below.

7. Save Model

In [None]:
import pickle
import torch

# Save the trained weights on their own
torch.save(trained_model.state_dict(), './tfidf_drug_use_model.pt')
print('Model state dict saved to ./tfidf_drug_use_model.pt')

# Save the fitted TF-IDF vectorizer (needed to featurise new text the same way)
with open('./tfidf_vectorizer.pkl', 'wb') as f:
    pickle.dump(vectorizer, f)
print('TF-IDF vectorizer saved to ./tfidf_vectorizer.pkl')

# Save everything needed to rebuild the model in one file. The config keys
# mirror the EnhancedMultiTaskModel constructor arguments and only read
# attributes that class actually defines.
model_info = {
    'model_state_dict': trained_model.state_dict(),
    'model_config': {
        'input_size': trained_model.hidden1.in_features,
        'num_substance_classes': trained_model.num_substance_classes,
        'num_symptom_labels': trained_model.num_symptom_labels
    },
    # Stored as plain lists of str so the file does not depend on NumPy pickling
    'substance_classes': [str(label) for label in substance_classes],
    'symptom_columns': [str(label) for label in symptom_columns],
    'training_history': training_history
}

torch.save(model_info, './complete_model_info.pt')
print('Complete model information saved to ./complete_model_info.pt')

print('All files saved successfully!')

# Example of how to load the model later:
def load_trained_model(model_path, vectorizer_path, device='cpu', model_class=None):
    """Load the pickled vectorizer and the saved model bundle.

    SECURITY: both ``pickle.load`` and ``torch.load`` can execute arbitrary
    code while unpickling. Only load files you created yourself.

    Args:
        model_path: Path of the bundle written with ``torch.save`` (a dict
            holding at least ``'model_state_dict'``).
        vectorizer_path: Path of the pickled TF-IDF vectorizer.
        device: Device the tensors are mapped to and the model is moved to.
        model_class: Optional model class, e.g. ``EnhancedMultiTaskModel``.
            When given, it is called with ``input_size``,
            ``num_substance_classes`` and ``num_symptom_labels``, the saved
            weights are loaded, and the model (in eval mode) is stored under
            ``model_info['model']``. Sizes come from
            ``model_info['model_config']`` when present there, otherwise from
            the weight shapes in the state dict.

    Returns:
        ``(loaded_vectorizer, model_info)``.
    """
    # Load vectorizer
    with open(vectorizer_path, 'rb') as f:
        loaded_vectorizer = pickle.load(f)

    # Load complete model info
    model_info = torch.load(model_path, map_location=device)

    if model_class is not None:
        state_dict = model_info['model_state_dict']
        config = model_info.get('model_config') or {}

        # (config key, state-dict tensor, axis of that tensor holding the size)
        size_sources = (
            ('input_size', 'hidden1.weight', 1),
            ('num_substance_classes', 'substance_classifier.3.weight', 0),
            ('num_symptom_labels', 'symptom_classifier.3.weight', 0),
        )
        sizes = {}
        for key, tensor_name, axis in size_sources:
            if key in config:
                sizes[key] = int(config[key])
            else:
                sizes[key] = int(state_dict[tensor_name].shape[axis])

        loaded_model = model_class(**sizes)
        loaded_model.load_state_dict(state_dict)
        loaded_model.to(device)
        loaded_model.eval()
        model_info['model'] = loaded_model

    return loaded_vectorizer, model_info

# Uncomment and use this to test loading:
# loaded_vectorizer, loaded_model_info = load_trained_model('./complete_model_info.pt', './tfidf_vectorizer.pkl')