In [1]:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import (
    BartForSequenceClassification,
    BartTokenizer,
    TrainingArguments,
    Trainer
)
from datasets import Dataset
import warnings

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# 1. Load and clean data
df = pd.read_csv("new_updated_data.csv")
df['description'] = df['description'].fillna('').astype(str)
df['tags'] = df['tags'].apply(lambda x: x.split(', ') if isinstance(x, str) else [])

# 2. Filter out empty tags and prepare labels
df = df[df['tags'].apply(len) > 0]  # Remove samples with no tags
all_tags = sorted(list(set(tag for tags in df['tags'] for tag in tags)))
num_classes = len(all_tags)
mlb = MultiLabelBinarizer(classes=all_tags)

# 3. Train-test split
train_df, eval_df = train_test_split(df, test_size=0.2, random_state=42)

# 4. Initialize tokenizer and model
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForSequenceClassification.from_pretrained(
    'facebook/bart-base',
    num_labels=num_classes,
    problem_type="multi_label_classification"
)

# 5. Dataset preparation with proper label handling
def prepare_dataset(df, mlb, is_train=False):
    # Get binary labels
    labels = mlb.fit_transform(df['tags']) if is_train else mlb.transform(df['tags'])
    
    # Convert to float32 for PyTorch
    labels = labels.astype(np.float32)
    
    return Dataset.from_dict({
        'text': df['description'].tolist(),
        'labels': labels.tolist()
    })

# Prepare datasets
train_dataset = prepare_dataset(train_df, mlb, is_train=True)
eval_dataset = prepare_dataset(eval_df, mlb)

# 6. Tokenization with proper handling
def tokenize_function(examples):
    tokenized = tokenizer(
        examples['text'],
        padding='max_length',
        truncation=True,
        max_length=512,
        return_tensors="pt"  # Return PyTorch tensors directly
    )
    
    # Convert labels to tensor
    tokenized['labels'] = torch.tensor(examples['labels'], dtype=torch.float)
    return tokenized

# Apply tokenization with batched processing
tokenized_train = train_dataset.map(tokenize_function, batched=True, batch_size=32)
tokenized_eval = eval_dataset.map(tokenize_function, batched=True, batch_size=32)

# Set format for PyTorch
tokenized_train.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
tokenized_eval.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# 7. Verify data
print("\nData Verification:")
print(f"Number of classes: {num_classes}")
print(f"Sample tags: {train_df.iloc[0]['tags']}")
print(f"Binarized labels: {tokenized_train[0]['labels'][:20]}...")  # First 20 labels
print(f"Number of active tags: {sum(tokenized_train[0]['labels'])}")
print(f"Input shape: {tokenized_train[0]['input_ids'].shape}")  # Should be (512,)

# 8. Training setup with adjusted parameters
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    learning_rate=3e-5,
    weight_decay=0.01,
    logging_steps=10,
    save_strategy="epoch",
    load_best_model_at_end=True,
    fp16=torch.cuda.is_available(),
    gradient_accumulation_steps=2,  # Helps with small batch sizes
    report_to="none"  # Disables wandb if not needed
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Move to CPU if needed
    if isinstance(logits, tuple):
        logits = logits[0]
    
    # Convert to probabilities and then to binary predictions
    preds = torch.sigmoid(torch.tensor(logits)) > 0.5
    labels = torch.tensor(labels)
    
    # Calculate metrics
    tp = (labels * preds).sum().float()
    fp = preds.sum().float() - tp
    fn = labels.sum().float() - tp
    
    precision = tp / (tp + fp + 1e-10)
    recall = tp / (tp + fn + 1e-10)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-10)
    accuracy = (preds == labels).float().mean()
    
    return {
        'accuracy': accuracy.item(),
        'precision': precision.item(),
        'recall': recall.item(),
        'f1': f1.item()
    }

# 9. Create trainer with error handling
try:
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        compute_metrics=compute_metrics,
    )

    # 10. Train
    print("\nStarting training...")
    trainer.train()

    # 11. Save everything
    model.save_pretrained("./fine_tuned_bart_book_tags")
    tokenizer.save_pretrained("./fine_tuned_bart_book_tags")
    import joblib
    joblib.dump(mlb, 'label_binarizer.pkl')
    
except Exception as e:
    print(f"Error during training: {str(e)}")
    # Print more debug info if error occurs
    print("\nDebug Info:")
    print(f"Model device: {next(model.parameters()).device}")
    print(f"Sample input shape: {tokenized_train[0]['input_ids'].shape}")
    print(f"Sample labels shape: {tokenized_train[0]['labels'].shape}")
    print(f"Number of samples: {len(tokenized_train)}")

  from .autonotebook import tqdm as notebook_tqdm
Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 3376/3376 [00:03<00:00, 876.76 examples/s]
Map: 100%|██████████| 844/844 [00:01<00:00, 767.11 examples/s]



Data Verification:
Number of classes: 776
Sample tags: ['Classics', 'Biography', 'Biography & Autobiography', 'Memoir', 'Language Arts & Disciplines', ' American']
Binarized labels: tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])...
Number of active tags: 6.0
Input shape: torch.Size([512])

Starting training...


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.0348,0.035196,0.994319,0.552133,0.122342,0.200301
2,0.0265,0.026215,0.994319,0.552133,0.122342,0.200301
3,0.0233,0.02473,0.994319,0.552133,0.122342,0.200301


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight'].
