In [31]:
import torch
import re
from datasets import load_dataset
from pathlib import Path

assert torch.cuda.is_available(), "CUDA not available"
device = torch.device("cuda")
seed = 42

dataset_id = "sujet-ai/Sujet-Finance-Instruct-177k"

topics = ['Analyst Update', 'Fed | Central Banks', 'Company | Product News', 'Treasuries | Corporate Debt', 'Dividend', 'Earnings', 'Energy | Oil', 'Financials', 'Currencies', 'General News | Opinion', 'Gold | Metals | Materials', 'IPO', 'Legal | Regulation', 'M&A | Investments', 'Macro', 'Markets', 'Politics', 'Personnel Change', 'Stock Commentary', 'Stock Movement']
topics_label2id = {label: str(i) for i, label in enumerate(topics)}

In [32]:
def topic_class_filter(example):
    return example["task_type"] == "topic_classification"

def topic_mapping(example):
    example["answer"] = topics_label2id[example["answer"]]
    return example

def text_cleaning(example):
    example["user_prompt"] = re.sub(r"https://t.co/.+", "", example["user_prompt"]).strip(" ")
    return example

rmv_cols = ["Unnamed: 0", "inputs", "system_prompt", "task_type", "dataset", "index_level", "conversation_id"]
rnm_cols = {"answer": "label", "user_prompt": "text"}

dataset = (load_dataset(dataset_id, split="train")
            .filter(topic_class_filter)
            .map(topic_mapping)
            .map(text_cleaning)
            .remove_columns(rmv_cols)
            .rename_columns(rnm_cols))
dataset = dataset.train_test_split(test_size=0.05, seed=seed)
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 16140
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 850
    })
})

In [34]:
hub_basepath = Path(r"C:\Users\samba\.cache\huggingface\hub")
dataset["train"].to_csv(hub_basepath / "datasets--Sujet--TopicClassification" / "train.csv")
dataset["test"].to_csv(hub_basepath / "datasets--Sujet--TopicClassification" / "test.csv")

Creating CSV from Arrow format:   0%|          | 0/17 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

88000