In [None]:
from pathlib import Path
import datetime

import pandas as pd

from datasets import load_dataset, Features, Value
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer

# Data directories, expressed relative to the working directory the notebook runs in.
RAW_DATA_DIR = Path("data") / "raw"
PROCESSED_DATA_DIR = Path("data") / "processed"

In [None]:
# TODO: Use cleaned data + new labelled data for training
# Path.iterdir() yields entries in arbitrary order, so the list is sorted to keep
# it identical across runs and machines. Sub-directories are skipped: only
# regular files are meaningful as data files.
train_files = sorted(str(p) for p in PROCESSED_DATA_DIR.iterdir() if p.is_file())
train_files

In [None]:
# Read the tagged articles, keeping only the listed columns. "-" is treated as
# a missing value and "Published" is parsed as a datetime.
tagged_raw = pd.read_csv(
    "tagged_articles.csv",
    usecols=["Published", "Headline", "Summary", "Theme", "New Index", "New Sub Index"],
    na_values="-",
    parse_dates=["Published"],
)

# Normalise headers to snake_case, e.g. "New Sub Index" -> "new_sub_index".
tagged_snake = tagged_raw.rename(
    columns=lambda col_name: col_name.lower().replace(" ", "_")
)

# The label is "<theme> > <new_index>"; a missing part is replaced by an empty
# string before joining, so the separator is always present.
label_parts = tagged_snake[["theme", "new_index"]].fillna("")
t_df = tagged_snake.assign(label=label_parts.agg(" > ".join, axis="columns"))

# Written to disk so the dataset loader can read it back as parquet.
t_df.to_parquet("test.parquet")

In [None]:
# TODO: Replace with taxonomy

# Labels seen fewer than this many times are excluded from training
# (presumably so every label has at least one same-label pair — TODO confirm).
MIN_EXAMPLES_PER_LABEL = 2

# `headline` is the text the trainer learns from, and the CSV load turns "-"
# placeholders into nulls. Count labels only over rows that have a headline so
# the threshold reflects usable training examples, not rows that get dropped.
rows_with_text = t_df.dropna(subset=["headline"])
label_counts = rows_with_text["label"].value_counts()
min_labels_list = label_counts[label_counts >= MIN_EXAMPLES_PER_LABEL].index.to_list()
# NOTE(review): rows with neither theme nor index carry the label " > " and are
# kept as a class of their own if frequent enough — confirm this is intended.

# TODO: Replace with duckdb schema
# Explicit schema for the parquet file; column names match the snake_case
# columns of `t_df` plus the derived `label`.
features = Features({
    'published': Value('timestamp[ns]'),
    'headline': Value('string'),
    'summary': Value('string'),
    'theme': Value('string'),
    'new_index': Value('string'),
    'new_sub_index': Value('string'),
    'label': Value('string'),
})

# TODO: Load real data when ready
# dataset = load_dataset("parquet", data_files={'train': train_files}, features=features)
# Keep rows that have a headline and whose label met the minimum-count threshold.
dataset = load_dataset(
    "parquet", data_files={'train': "test.parquet"}, features=features
).filter(
    lambda row: row['headline'] is not None and row['label'] in min_labels_list
)

# # Fast train for testing (sample_dataset would need importing from setfit)
# train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=4)
train_dataset = dataset["train"]

# Load the pretrained checkpoint as a SetFit model, with ./cached_models as the
# cache directory.
model = SetFitModel.from_pretrained(
    "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
    cache_dir="cached_models",
)

# Training configuration, gathered in one place so the knobs are easy to scan.
trainer_kwargs = {
    "loss_class": CosineSimilarityLoss,
    "metric": "accuracy",
    "batch_size": 16,
    # The number of text pairs to generate for contrastive learning
    "num_iterations": 20,
    # The number of epochs to use for contrastive learning
    "num_epochs": 1,
    # Map dataset columns to the text/label names expected by the trainer
    "column_mapping": {"headline": "text", "label": "label"},
}
trainer = SetFitTrainer(model=model, train_dataset=train_dataset, **trainer_kwargs)

trainer.train()

# The output directory is named after today's ISO date, evaluated once training
# has finished; a second run on the same day writes to the same directory.
save_dir = f"trained_models/{datetime.date.today().isoformat()}"
trainer.model.save_pretrained(save_dir)