In [2]:
import json
import pandas as pd
from collections import Counter
import os

# Step 1: Load the arXiv dataset
input_path = '/kaggle/input/arxiv/arxiv-metadata-oai-snapshot.json'
data = []
print("Loading arXiv JSON lines dataset...")
with open(input_path, 'r') as f:
    for line in f:
        data.append(json.loads(line))
print(f"Total papers loaded: {len(data)}")

# Step 2: Convert to DataFrame
df = pd.DataFrame(data)
print("Available columns:")
print(df.columns.tolist())

# Step 3: Count papers per category
category_counts = Counter()
for cats in df['categories']:
    for cat in cats.split():
        category_counts[cat] += 1

category_df = pd.DataFrame(category_counts.items(), columns=['Category', 'Count'])
category_df = category_df.sort_values('Count', ascending=False)

# Step 4: Filter categories with >5,000 papers
popular_categories = category_df[category_df['Count'] > 5000]['Category'].tolist()
print(f"\nCategories with > 5,000 papers: {len(popular_categories)}")
print(category_df[category_df['Category'].isin(popular_categories)].head(20))

# Step 5: Filter papers that have at least one popular category
def has_popular_category(category_str):
    return any(cat in popular_categories for cat in category_str.split())

filtered_df = df[df['categories'].apply(has_popular_category)]
print(f"\nFiltered papers: {len(filtered_df)}")

# Step 6: Keep only relevant columns and enrich text
filtered_df = filtered_df[['title', 'abstract', 'categories']]
filtered_df = filtered_df.dropna()
filtered_df['title_abstract'] = filtered_df['title'] + " " + filtered_df['abstract']

# Step 7: Save outputs
output_dir = '/kaggle/working'
os.makedirs(output_dir, exist_ok=True)

filtered_csv_path = os.path.join(output_dir, 'filtered_arxiv_papers.csv')
stats_csv_path = os.path.join(output_dir, 'category_statistics.csv')

filtered_df.to_csv(filtered_csv_path, index=False)
category_df.to_csv(stats_csv_path, index=False)

# Step 8: Print summary
print("\n Processing Complete")
print(f"Original total papers: {len(df)}")
print(f"Filtered papers: {len(filtered_df)}")
print(f"Filtered dataset saved to: {filtered_csv_path}")
print(f"Category statistics saved to: {stats_csv_path}")
print("\nSample rows:")
print(filtered_df.head())

Loading arXiv JSON lines dataset...
Total papers loaded: 2730173
Available columns:
['id', 'submitter', 'authors', 'title', 'comments', 'journal-ref', 'doi', 'report-no', 'categories', 'license', 'abstract', 'versions', 'update_date', 'authors_parsed']

Categories with > 5,000 papers: 128
               Category   Count
96                cs.LG  215824
0                hep-ph  187339
13               hep-th  173525
27             quant-ph  161515
114               cs.CV  154368
42                cs.AI  124844
7                 gr-qc  113475
9              astro-ph  105380
8     cond-mat.mtrl-sci   99501
6     cond-mat.mes-hall   95273
34              math.MP   83934
33              math-ph   83934
126               cs.CL   83033
20      cond-mat.str-el   77684
21   cond-mat.stat-mech   76808
136         astro-ph.CO   71512
1               math.CO   71159
110             stat.ML   70661
144         astro-ph.GA   69844
66              math.AP   67149

Filtered papers: 2702634

 Processing

In [3]:
import pandas as pd
import os
import csv

# File paths
file_path = '/kaggle/working/filtered_arxiv_papers.csv'
stats_path = '/kaggle/working/category_statistics.csv'

# Check file
print(f"Checking for {file_path}...")
if not os.path.exists(file_path):
    raise FileNotFoundError(f"File not found: {file_path}")

# Load dataset
print("Loading filtered_arxiv_papers.csv...")
df = pd.read_csv(file_path)
print(f"Loaded {len(df)} rows")
print(f"Columns: {df.columns.tolist()}")

# Enrich text
if 'title_abstract' not in df.columns:
    print("Enriching text column with title + abstract...")
    df['title_abstract'] = df['title'].astype(str) + " " + df['abstract'].astype(str)

# Load and filter category stats for math.* only
print("Loading category statistics...")
stats_df = pd.read_csv(stats_path)
stats_df = stats_df[stats_df['Category'].str.startswith('math.')]
stats_df = stats_df.sort_values('Count', ascending=False)

top_categories = stats_df['Category'].tolist()
print(f"\nUsing ALL {len(top_categories)} math categories:")
print(top_categories)

Checking for /kaggle/working/filtered_arxiv_papers.csv...
Loading filtered_arxiv_papers.csv...
Loaded 2702634 rows
Columns: ['title', 'abstract', 'categories', 'title_abstract']
Loading category statistics...

Using ALL 32 math categories:
['math.MP', 'math.CO', 'math.AP', 'math.PR', 'math.AG', 'math.OC', 'math.IT', 'math.NT', 'math.DG', 'math.NA', 'math.DS', 'math.FA', 'math.RT', 'math.ST', 'math.GT', 'math.GR', 'math.CA', 'math.QA', 'math.RA', 'math.CV', 'math.AT', 'math.LO', 'math.AC', 'math.OA', 'math.MG', 'math.SP', 'math.SG', 'math.CT', 'math.KT', 'math.GN', 'math.GM', 'math.HO']


In [4]:
# Map labels
label_map = {cat: idx for idx, cat in enumerate(top_categories)}

# Extract primary math category
def extract_primary_category(cat_str):
    if pd.isna(cat_str):
        return None
    for cat in cat_str.split():
        if cat.startswith('math.') and cat in label_map:
            return cat
    return None

# Apply extraction
print("Mapping categories to labels...")
df['category'] = df['categories'].apply(extract_primary_category)
df['label'] = df['category'].map(label_map)

# Filter valid rows
df = df[df['label'].notnull()].copy()
print(df)

Mapping categories to labels...
                                                     title  \
1                 Sparsity-certifying Graph Decompositions   
3        A determinant of Stirling cycle numbers counts...   
4        From dyadic $\Lambda_{\alpha}$ to $\Lambda_{\a...   
9        Partial cubes: structures, characterizations, ...   
10       Computing genus 2 Hilbert-Siegel modular forms...   
...                                                    ...   
2702530  Yang-Baxter Algebra for the n-Harmonic Oscilla...   
2702543  Integrable deformations of oscillator chains f...   
2702544  A note on real forms of the complex N=4 supers...   
2702558  Real forms of the complex twisted N=2 supersym...   
2702563  Vector NLS hierarchy solitons revisited: dress...   

                                                  abstract  \
1          We describe a new algorithm, the $(k,\ell)$-...   
3          We show that a determinant of Stirling cycle...   
4          In this paper we show how 

In [5]:
import pandas as pd
import os
import csv

# File paths
file_path = '/kaggle/working/filtered_arxiv_papers.csv'
stats_path = '/kaggle/working/category_statistics.csv'
output_path = '/kaggle/working/math_dataset_labeled.csv'
label_map_path = '/kaggle/working/label_map.csv'

# Check file existence
print(f"Checking for {file_path}...")
if not os.path.exists(file_path):
    raise FileNotFoundError(f"File not found: {file_path}")

# Load the dataset
print("Loading filtered_arxiv_papers.csv...")
df = pd.read_csv(file_path)
print(f"Loaded {len(df)} rows")
print(f"Columns: {df.columns.tolist()}")

# Enrich title_abstract if not present
if 'title_abstract' not in df.columns:
    print("Creating 'title_abstract' from title + abstract...")
    df['title'] = df['title'].fillna('')
    df['abstract'] = df['abstract'].fillna('')
    df['title_abstract'] = df['title'] + " " + df['abstract']

# Load category stats and filter math.* categories
print("Loading and filtering category statistics...")
stats_df = pd.read_csv(stats_path)
stats_df = stats_df[stats_df['Category'].str.startswith('math.')]
stats_df = stats_df.sort_values('Count', ascending=False)

top_categories = stats_df['Category'].tolist()
print(f"\nUsing ALL {len(top_categories)} math categories:")
print(top_categories)

# Create label mapping
label_map = {cat: idx for idx, cat in enumerate(top_categories)}

# Function to extract primary math category
def extract_primary_category(cat_str):
    if pd.isna(cat_str):
        return None
    for cat in cat_str.split():
        if cat.startswith('math.') and cat in label_map:
            return cat
    return None

# Apply category extraction and label mapping
print("Mapping categories to labels...")
df['category'] = df['categories'].apply(extract_primary_category)
df['label'] = df['category'].map(label_map)

# Filter out rows without math category
df = df[df['label'].notnull()].copy()

# Save label map to CSV
pd.Series(label_map).to_csv(label_map_path, header=['Label'], index_label='Category')
print(f"Saved label mapping to {label_map_path}")

# Drop unneeded columns
df = df.drop(columns=['title','abstract'])

# Save final labeled dataset
df.to_csv(output_path, index=False)
print(f"Saved cleaned dataset to {output_path}")

# Reload and print category distribution
print("\n--- Category Distribution ---")
df_loaded = pd.read_csv(output_path)
category_counts = df_loaded['category'].value_counts()

for category, count in category_counts.items():
    label = df_loaded[df_loaded['category'] == category]['label'].iloc[0]
    print(f"Category: {category} | Label: {label} | Papers: {count}")

print("\nTotal unique categories:", df_loaded['category'].nunique())
print("Total number of papers:", len(df_loaded))

Checking for /kaggle/working/filtered_arxiv_papers.csv...
Loading filtered_arxiv_papers.csv...
Loaded 2702634 rows
Columns: ['title', 'abstract', 'categories', 'title_abstract']
Loading and filtering category statistics...

Using ALL 32 math categories:
['math.MP', 'math.CO', 'math.AP', 'math.PR', 'math.AG', 'math.OC', 'math.IT', 'math.NT', 'math.DG', 'math.NA', 'math.DS', 'math.FA', 'math.RT', 'math.ST', 'math.GT', 'math.GR', 'math.CA', 'math.QA', 'math.RA', 'math.CV', 'math.AT', 'math.LO', 'math.AC', 'math.OA', 'math.MG', 'math.SP', 'math.SG', 'math.CT', 'math.KT', 'math.GN', 'math.GM', 'math.HO']
Mapping categories to labels...
Saved label mapping to /kaggle/working/label_map.csv
Saved cleaned dataset to /kaggle/working/math_dataset_labeled.csv

--- Category Distribution ---
Category: math.AP | Label: 2.0 | Papers: 56123
Category: math.CO | Label: 1.0 | Papers: 55201
Category: math.MP | Label: 0.0 | Papers: 51329
Category: math.OC | Label: 5.0 | Papers: 46835
Category: math.IT | Lab

In [12]:
print("\nTrain label distribution:")
print(train_df['label'].value_counts().sort_index())

print("\nVal label distribution:")
print(val_df['label'].value_counts().sort_index())

print("\nTest label distribution:")
print(test_df['label'].value_counts().sort_index())



Train label distribution:
label
0     2400
1     2400
2     2400
3     2400
4     2400
5     2400
6     2400
7     2400
8     2400
9     2400
10    2400
11    2400
12    2400
13    2400
14    2400
15    2400
16    2400
17    2400
18    2400
19    2400
20    2400
21    2400
22    2400
23    2400
24    2400
25    2400
26    2400
27    2400
28    2400
Name: count, dtype: int64

Val label distribution:
label
0     300
1     300
2     300
3     300
4     300
5     300
6     300
7     300
8     300
9     300
10    300
11    300
12    300
13    300
14    300
15    300
16    300
17    300
18    300
19    300
20    300
21    300
22    300
23    300
24    300
25    300
26    300
27    300
28    300
Name: count, dtype: int64

Test label distribution:
label
0     300
1     300
2     300
3     300
4     300
5     300
6     300
7     300
8     300
9     300
10    300
11    300
12    300
13    300
14    300
15    300
16    300
17    300
18    300
19    300
20    300
21    300
22    300
23    300
24 

In [13]:
import pandas as pd
from sklearn.model_selection import train_test_split
import os

# Load the dataset
df = pd.read_csv('/kaggle/working/math_dataset_labeled.csv')

# Check the columns
print(df.columns)

# Balance parameters
rows_per_category = 3000
top_k = 10
print(f"\nBalancing: {rows_per_category} rows per category for top {top_k} categories...")

# Get top K categories
top_categories = df['category'].value_counts().head(top_k).index.tolist()

# Collect balanced data
filtered_list = []
skipped = []

for cat in top_categories:
    cat_df = df[df['category'] == cat]
    available = len(cat_df)
    if available >= rows_per_category:
        filtered_list.append(cat_df.sample(n=rows_per_category, random_state=42))
    else:
        skipped.append((cat, available))
        print(f"⚠ Skipping {cat}: only {available} papers (needs ≥ {rows_per_category})")

# Combine into a new DataFrame
if not filtered_list:
    raise ValueError("No categories had enough papers to sample from!")

balanced_df = pd.concat(filtered_list).reset_index(drop=True)

# Create label map
used_categories = sorted(balanced_df['category'].unique())
label_map = {cat: i for i, cat in enumerate(used_categories)}
balanced_df['label'] = balanced_df['category'].map(label_map)

# Print label mapping
print("\nLabel mapping:")
for cat, idx in label_map.items():
    count = (balanced_df['category'] == cat).sum()
    print(f" Category: {cat:<20} | Label: {idx:<2} | Papers: {count}")

# Split into train/val/test (80/10/10)
print("\nSplitting dataset...")
train_df, temp_df = train_test_split(
    balanced_df, test_size=0.2, stratify=balanced_df['label'], random_state=42
)
val_df, test_df = train_test_split(
    temp_df, test_size=0.5, stratify=temp_df['label'], random_state=42
)

# Save CSVs
output_dir = '/kaggle/working/'
train_df.to_csv(os.path.join(output_dir, 'train.csv'), index=False)
val_df.to_csv(os.path.join(output_dir, 'val.csv'), index=False)
test_df.to_csv(os.path.join(output_dir, 'test.csv'), index=False)

# Summary
print(f"\nSaved train/val/test splits:")
print(f" Train size: {len(train_df)}")
print(f" Val size:   {len(val_df)}")
print(f" Test size:  {len(test_df)}")

print("\nTrain label distribution:")
print(train_df['label'].value_counts().sort_index())

print("\nVal label distribution:")
print(val_df['label'].value_counts().sort_index())

print("\nTest label distribution:")
print(test_df['label'].value_counts().sort_index())


Index(['categories', 'title_abstract', 'category', 'label'], dtype='object')

Balancing: 3000 rows per category for top 10 categories...

Label mapping:
 Category: math.AG              | Label: 0  | Papers: 3000
 Category: math.AP              | Label: 1  | Papers: 3000
 Category: math.CO              | Label: 2  | Papers: 3000
 Category: math.DG              | Label: 3  | Papers: 3000
 Category: math.IT              | Label: 4  | Papers: 3000
 Category: math.MP              | Label: 5  | Papers: 3000
 Category: math.NA              | Label: 6  | Papers: 3000
 Category: math.NT              | Label: 7  | Papers: 3000
 Category: math.OC              | Label: 8  | Papers: 3000
 Category: math.PR              | Label: 9  | Papers: 3000

Splitting dataset...

Saved train/val/test splits:
 Train size: 24000
 Val size:   3000
 Test size:  3000

Train label distribution:
label
0    2400
1    2400
2    2400
3    2400
4    2400
5    2400
6    2400
7    2400
8    2400
9    2400
Name: count, dtyp

In [22]:
import pandas as pd
from transformers import LongformerTokenizerFast
import torch
import pickle
import logging
import os

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('/kaggle/working/step2_log.txt'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)  # Fixed _name_ to __name__

# Define input/output directories
input_dir = '/kaggle/working/'
output_dir = '/kaggle/working/'

# Verify train.csv and val.csv exist
train_path = os.path.join(input_dir, 'train.csv')
val_path = os.path.join(input_dir, 'val.csv')

print(f"Checking for train.csv at {train_path}...")
print(f"Checking for val.csv at {val_path}...")
if not os.path.exists(train_path) or not os.path.exists(val_path):
    logger.error("train.csv or val.csv not found in /kaggle/working/.")
    raise FileNotFoundError("train.csv or val.csv not found. Please run Step 1 first.")

# Load datasets
print("Loading datasets...")
logger.info("Loading datasets...")
train_df = pd.read_csv(train_path)
val_df = pd.read_csv(val_path)
print(f"Loaded {len(train_df)} rows in train set")
print(f"Loaded {len(val_df)} rows in val set")

# Load tokenizer
print("Loading Longformer tokenizer...")
logger.info("Loading Longformer tokenizer...")
try:
    tokenizer = LongformerTokenizerFast.from_pretrained('allenai/longformer-base-4096')
    tokenizer.model_max_length = 1024  # Adjust to match your previous setup
except Exception as e:
    logger.error(f"Error loading tokenizer: {e}")
    raise

# Tokenization function
def tokenize_data(df, max_length=1024):
    texts = df['title_abstract'].tolist()
    labels = df['label'].tolist()
    encodings = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors='pt'
    )
    return {
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
        'labels': torch.tensor(labels)
    }

# Tokenize in batches
batch_size = 100
train_tokenized = []
val_tokenized = []

print("Tokenizing training data...")
logger.info("Tokenizing training data...")
for i in range(0, len(train_df), batch_size):
    batch_df = train_df[i:i + batch_size]
    tokenized_batch = tokenize_data(batch_df)
    train_tokenized.append(tokenized_batch)
    print(f"Tokenized train batch {i//batch_size + 1}/{len(train_df)//batch_size + 1}")
    logger.info(f"Tokenized train batch {i//batch_size + 1}/{len(train_df)//batch_size + 1}")

print("Tokenizing validation data...")
logger.info("Tokenizing validation data...")
for i in range(0, len(val_df), batch_size):
    batch_df = val_df[i:i + batch_size]
    tokenized_batch = tokenize_data(batch_df)
    val_tokenized.append(tokenized_batch)
    print(f"Tokenized val batch {i//batch_size + 1}/{len(val_df)//batch_size + 1}")
    logger.info(f"Tokenized val batch {i//batch_size + 1}/{len(val_df)//batch_size + 1}")

# Save tokenized datasets
print(f"Saving tokenized datasets to {output_dir}...")
logger.info(f"Saving tokenized datasets to {output_dir}...")
with open(os.path.join(output_dir, 'train_tokenized.pkl'), 'wb') as f:
    pickle.dump(train_tokenized, f)
with open(os.path.join(output_dir, 'val_tokenized.pkl'), 'wb') as f:
    pickle.dump(val_tokenized, f)

print(f"Train tokenized: {len(train_tokenized)} batches, Val tokenized: {len(val_tokenized)} batches")
logger.info(f"Train tokenized: {len(train_tokenized)} batches, Val tokenized: {len(val_tokenized)} batches")
print("Step 2 complete.")
logger.info("Step 2 complete.")

Checking for train.csv at /kaggle/working/train.csv...
Checking for val.csv at /kaggle/working/val.csv...
Loading datasets...


loading file vocab.json from cache at /root/.cache/huggingface/hub/models--allenai--longformer-base-4096/snapshots/301e6a42cb0d9976a6d6a26a079fef81c18aa895/vocab.json
loading file merges.txt from cache at /root/.cache/huggingface/hub/models--allenai--longformer-base-4096/snapshots/301e6a42cb0d9976a6d6a26a079fef81c18aa895/merges.txt
loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--allenai--longformer-base-4096/snapshots/301e6a42cb0d9976a6d6a26a079fef81c18aa895/tokenizer.json
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at None
loading file chat_template.jinja from cache at None
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--allenai--longformer-base-4096/snapshots/301e6a42cb0d9976a6d6a26a079fef81c18aa895/config.json
Model config LongformerConfig {
  "attention_mode": "longformer",
  "attention_probs_dropout_pro

Loaded 24000 rows in train set
Loaded 3000 rows in val set
Loading Longformer tokenizer...
Tokenizing training data...
Tokenized train batch 1/241
Tokenized train batch 2/241
Tokenized train batch 3/241
Tokenized train batch 4/241
Tokenized train batch 5/241
Tokenized train batch 6/241
Tokenized train batch 7/241
Tokenized train batch 8/241
Tokenized train batch 9/241
Tokenized train batch 10/241
Tokenized train batch 11/241
Tokenized train batch 12/241
Tokenized train batch 13/241
Tokenized train batch 14/241
Tokenized train batch 15/241
Tokenized train batch 16/241
Tokenized train batch 17/241
Tokenized train batch 18/241
Tokenized train batch 19/241
Tokenized train batch 20/241
Tokenized train batch 21/241
Tokenized train batch 22/241
Tokenized train batch 23/241
Tokenized train batch 24/241
Tokenized train batch 25/241
Tokenized train batch 26/241
Tokenized train batch 27/241
Tokenized train batch 28/241
Tokenized train batch 29/241
Tokenized train batch 30/241
Tokenized train batc

In [27]:
import os
import glob
import re
import torch
import numpy as np
import pickle
import gc
import logging
import time
import pandas as pd
from datasets import Dataset
from sklearn.metrics import precision_recall_fscore_support
from transformers import (
    LongformerForSequenceClassification,
    LongformerTokenizerFast,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding
)

# ----------------------------
# Setup
# ----------------------------
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.backends.cuda.matmul.allow_tf32 = True

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('/kaggle/working/training_log.txt'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

input_dir = '/kaggle/working/'
train_tokenized_path = os.path.join(input_dir, 'train_tokenized.pkl')
val_tokenized_path = os.path.join(input_dir, 'val_tokenized.pkl')
results_dir = os.path.join(input_dir, 'results')
os.makedirs(results_dir, exist_ok=True)

# ----------------------------
# Load Checkpoint
# ----------------------------
def get_latest_checkpoint(results_dir):
    checkpoint_dirs = glob.glob(os.path.join(results_dir, 'checkpoint-*'))
    if not checkpoint_dirs:
        return None
    checkpoint_nums = [int(re.search(r'checkpoint-(\d+)', d).group(1)) for d in checkpoint_dirs]
    return os.path.join(results_dir, f'checkpoint-{max(checkpoint_nums)}')

checkpoint_path = get_latest_checkpoint(results_dir)
print(f"Checkpoint: {checkpoint_path}" if checkpoint_path else "No checkpoints found.")

# ----------------------------
# Load Tokenized Data
# ----------------------------
def flatten_batches(batched_data):
    flat_data = []
    for batch in batched_data:
        for i in range(len(batch['input_ids'])):
            item = {
                'input_ids': batch['input_ids'][i][:1024],
                'attention_mask': batch['attention_mask'][i][:1024],
                'labels': int(batch['labels'][i])
            }
            flat_data.append(item)
    return flat_data

with open(train_tokenized_path, 'rb') as f:
    train_tokenized = pickle.load(f)
with open(val_tokenized_path, 'rb') as f:
    val_tokenized = pickle.load(f)

train_dataset = Dataset.from_list(flatten_batches(train_tokenized))
val_dataset = Dataset.from_list(flatten_batches(val_tokenized))

# ----------------------------
# Cast label to int64
# ----------------------------
from datasets import Value

train_dataset = train_dataset.cast_column("labels", Value("int64"))
val_dataset = val_dataset.cast_column("labels", Value("int64"))


# Confirm label range
train_labels = [x['labels'] for x in train_dataset]
val_labels = [x['labels'] for x in val_dataset]
print("Train label min/max:", min(train_labels), max(train_labels))
print("Val label min/max:", min(val_labels), max(val_labels))

# ----------------------------
# Tokenizer & Model
# ----------------------------
tokenizer = LongformerTokenizerFast.from_pretrained('allenai/longformer-base-4096')
tokenizer.model_max_length = 1024

# ✅ Always load from base model to avoid mismatch errors
model = LongformerForSequenceClassification.from_pretrained(
    'allenai/longformer-base-4096',
    num_labels=10
)
checkpoint_path = None  # ✅ prevent resuming from mismatched checkpoints


# Freeze all layers except classifier head
for name, param in model.named_parameters():
    if not name.startswith("classifier"):
        param.requires_grad = False
trainable_params = [name for name, param in model.named_parameters() if param.requires_grad]
logger.info(f"Trainable parameters: {trainable_params}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# ----------------------------
# Compute Metrics
# ----------------------------
def compute_metrics(pred):
    labels = pred.label_ids
    preds = np.argmax(pred.predictions, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    return {
        'accuracy': (preds == labels).mean(),
        'precision_weighted': precision,
        'recall_weighted': recall,
        'f1_weighted': f1
    }

# ----------------------------
# Training Arguments
# ----------------------------
training_args = TrainingArguments(
    output_dir=results_dir,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=200,
    learning_rate=5e-5,
    weight_decay=0.01,
    max_grad_norm=1.0,
    logging_dir=os.path.join(input_dir, 'logs'),
    logging_steps=10,
    logging_first_step=True,
    eval_strategy='steps',
    eval_steps=500,
    save_strategy='epoch',
    load_best_model_at_end=False,
    fp16=True,
    report_to='none',
    log_level="info",
    disable_tqdm=False
)

# ----------------------------
# Trainer
# ----------------------------
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    data_collator=DataCollatorWithPadding(tokenizer, padding=True),
)

# ----------------------------
# Train
# ----------------------------
print("Starting training...")
start_time = time.time()
trainer.train(resume_from_checkpoint=checkpoint_path)
end_time = time.time()

# ----------------------------
# Save Model & Evaluation
# ----------------------------
final_model_path = os.path.join(input_dir, 'final_model')
trainer.save_model(final_model_path)
print(f"Model saved to {final_model_path}")
print(f"Training completed in {(end_time - start_time)/60:.2f} minutes.")

metrics = trainer.evaluate()
pd.DataFrame([metrics]).to_csv(os.path.join(input_dir, "final_eval_metrics.csv"), index=False)
print("Metrics saved to final_eval_metrics.csv")

# ----------------------------
# Cleanup
# ----------------------------
del model, trainer
torch.cuda.empty_cache()
gc.collect()


No checkpoints found.


Casting the dataset:   0%|          | 0/24000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/3000 [00:00<?, ? examples/s]

Train label min/max: 0 9
Val label min/max: 0 9


loading file vocab.json from cache at /root/.cache/huggingface/hub/models--allenai--longformer-base-4096/snapshots/301e6a42cb0d9976a6d6a26a079fef81c18aa895/vocab.json
loading file merges.txt from cache at /root/.cache/huggingface/hub/models--allenai--longformer-base-4096/snapshots/301e6a42cb0d9976a6d6a26a079fef81c18aa895/merges.txt
loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--allenai--longformer-base-4096/snapshots/301e6a42cb0d9976a6d6a26a079fef81c18aa895/tokenizer.json
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at None
loading file chat_template.jinja from cache at None
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--allenai--longformer-base-4096/snapshots/301e6a42cb0d9976a6d6a26a079fef81c18aa895/config.json
Model config LongformerConfig {
  "attention_mode": "longformer",
  "attention_probs_dropout_pro

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
