In [None]:
import datetime
import json
import os
import random
import statistics
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Optional, Union

import datasets
import jsonlines
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

In [None]:
class Dataset:
    """Tokenizes the splits of a dict-like dataset for use with the HF ``Trainer``.

    ``dataset`` maps split names (``"train"``, ``"test"`` and optionally
    ``"validation"``) to objects exposing ``map``, ``set_format`` and
    ``column_names`` (e.g. ``datasets.Dataset``).
    """

    def __init__(self, dataset, lowercase=False, batch_size=16, max_len=64, instance_col_name="text"):
        self.dataset = dataset
        # NOTE(review): lowercase and batch_size are stored but not used by this class.
        self.lowercase = lowercase
        self.batch_size = batch_size
        self.max_len = max_len
        # Column holding the raw text to tokenize.
        self.instance_col_name = instance_col_name

    def _encode(self, example):
        """Tokenize a batch of examples, truncating/padding every sequence to ``max_len``."""
        return self.tokenizer(
            example[self.instance_col_name],
            truncation=True,
            max_length=self.max_len,
            padding="max_length",
        )

    def format(self, dataset):
        """Tokenize ``dataset`` and expose the model inputs as torch tensors.

        ``token_type_ids`` is included only if the tokenizer produced it (BERT-style
        tokenizers do, RoBERTa-style ones do not).

        Raises:
            Exception: if ``input_ids``, ``attention_mask`` or ``label`` is missing
                after tokenization.
        """
        dataset = dataset.map(self._encode, batched=True)
        available = set(dataset.column_names)

        missing = [name for name in ("input_ids", "attention_mask", "label") if name not in available]
        if missing:
            raise Exception(f"Unable to set columns. Missing columns: {missing}")

        columns = ["input_ids"]
        if "token_type_ids" in available:
            columns.append("token_type_ids")
        columns += ["attention_mask", "label"]

        dataset.set_format(type="torch", columns=columns)
        return dataset

    def format_data(self, tokenizer, batch_size=None):
        """Tokenize the ``train``/``test`` splits (and ``validation`` if present) with ``tokenizer``.

        Results are stored on ``self.train_dataset``, ``self.test_dataset`` and,
        when a ``"validation"`` split exists, ``self.validation_dataset``.
        """
        print("Formatting data...")
        self.tokenizer = tokenizer
        if batch_size:
            self.batch_size = batch_size

        self.train_dataset = self.format(self.dataset["train"])
        if "validation" in self.dataset:
            self.validation_dataset = self.format(self.dataset["validation"])
        self.test_dataset = self.format(self.dataset["test"])
        print("Done formatting.")


def make_output_dir(output_dir):
    """Create ``<output_dir>/<YYYYmmdd-HHMMSS>`` (local time) and return its path.

    An already-existing directory with the same timestamp is reused.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    run_dir = os.path.join(output_dir, stamp)
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


def set_seed(seed):  # for reproducibility
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (plus every CUDA device if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def compute_metrics(p, average="macro"):
    """Convert Trainer evaluation output into accuracy / precision / recall / F1.

    Args:
        p: ``(predictions, label_ids)`` pair, e.g. a ``transformers.EvalPrediction``.
            ``predictions`` holds per-class scores; if it is a tuple (model
            returned extra outputs), its first element is used as the scores.
        average: averaging mode forwarded to the sklearn precision/recall/F1 scorers.

    Returns:
        dict with keys ``"accuracy"``, ``"precision"``, ``"recall"``, ``"f1-score"``.
    """
    scores, true = p
    if isinstance(scores, tuple):
        scores = scores[0]
    pred = np.argmax(scores, axis=-1)

    # zero_division=0 yields the same value sklearn's default "warn" does, without
    # emitting a warning on every eval step when a class is never predicted.
    return {
        "accuracy": accuracy_score(y_true=true, y_pred=pred),
        "precision": precision_score(y_true=true, y_pred=pred, average=average, zero_division=0),
        "recall": recall_score(y_true=true, y_pred=pred, average=average, zero_division=0),
        "f1-score": f1_score(y_true=true, y_pred=pred, average=average, zero_division=0),
    }

In [None]:
class TextClassifier:
    """Holds a Hugging Face tokenizer and builds / saves / loads a sequence-classification model.

    ``self.model`` is only created by ``model_init`` (or lazily by ``load_model``).
    """

    def __init__(
        self,
        model_name=None,
        lowercase=False,
        max_len=64,
        num_labels=2,
        output_dir="",
        seed=None,  # for reproducibility
    ):
        """Store model settings and load the tokenizer for ``model_name``.

        Args:
            model_name: hub id or local path passed to ``from_pretrained``.
            lowercase: value written to the tokenizer's ``do_lower_case`` attribute.
                NOTE(review): on fast tokenizers this attribute may not change the
                backend normalizer — confirm it has the intended effect.
            max_len: stored as ``tokenizer.model_max_length``.
            num_labels: number of classes for the classification head.
            output_dir: default directory used by ``save_model``.
            seed: if not None, passed to ``set_seed`` before loading the tokenizer.
        """
        self.output_dir = output_dir

        # Model parameters
        self.model_name = model_name
        self.do_lower_case = lowercase
        self.max_len = max_len
        self.num_labels = num_labels

        # `is not None` so that seed=0, a valid seed, is honoured as well.
        if seed is not None:
            set_seed(seed)

        print("Loading pre-trained tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

        try:
            self.tokenizer.do_lower_case = self.do_lower_case
        except AttributeError:
            try:
                self.tokenizer.do_lowercase_and_remove_accent = self.do_lower_case
            except AttributeError:
                raise Exception("Unable to set value for 'do_lower_case' or 'do_lowercase_and_remove_accent'")
        self.tokenizer.model_max_length = self.max_len
        print("Done.")

    def model_init(self):
        """Build the classification model; also used as the Trainer's ``model_init``.

        Loads pre-trained weights, unless the instance has a truthy
        ``random_weights`` attribute, in which case the same architecture is built
        from the pre-trained config with freshly initialised weights.
        """
        if getattr(self, "random_weights", False):
            print("Initializing model with random weights...")
            config = AutoConfig.from_pretrained(self.model_name, num_labels=self.num_labels)
            self.model = AutoModelForSequenceClassification.from_config(config)
            print(f"Initialized model from config: {self.model_name}")
        else:
            print("Loading pre-trained model...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=self.num_labels,
            )
            print(f"Loaded model from path: {self.model_name}")
        return self.model

    def save_model(self, save_path=None):
        """Save ``self.model``'s state dict as ``pytorch_model.bin`` inside a directory.

        Args:
            save_path: target directory; defaults to ``self.output_dir``. Created if missing.

        Returns:
            Path of the written weights file (loadable with ``load_model``).
        """
        save_dir = save_path if save_path else self.output_dir
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        model_path = os.path.join(save_dir, "pytorch_model.bin")
        # A state dict (not the pickled module) is what load_model expects.
        torch.save(self.model.state_dict(), model_path)
        print(f"Saved model to path: {model_path}")
        return model_path

    def load_model(self, path):
        """Load a state dict into ``self.model``.

        Args:
            path: weights file, or a directory containing ``pytorch_model.bin``.

        Raises:
            FileNotFoundError: if the resolved path does not exist.
        """
        if os.path.isdir(path):
            path = os.path.join(path, "pytorch_model.bin")
        if not os.path.exists(path):
            print(f"Model path does not exist: {path}")
            raise FileNotFoundError(f"The specified file path ({path}) does not exist!")

        if getattr(self, "model", None) is None:
            self.model_init()

        print("Loading local model state dict...")
        # map_location="cpu" lets GPU-saved weights load on CPU-only machines;
        # load_state_dict then copies them onto the model's own device.
        self.model.load_state_dict(torch.load(path, map_location="cpu"))
        print(f"Loaded model from path: {path}")


class TextClassifierTrainer(TextClassifier):
    """TextClassifier plus a Hugging Face ``Trainer`` fine-tuning and evaluation loop.

    Extra keyword arguments (``model_name``, ``lowercase``, ``max_len``,
    ``num_labels``) are forwarded to ``TextClassifier.__init__``.
    NOTE(review): ``early_stop`` is stored but not used anywhere in this class.
    """

    def __init__(
        self,
        output_dir="",
        epochs=5,
        batch_size=16,
        warmup_steps=500,
        weight_decay=0.9,
        random_weights=False,
        load_best_model_at_end=False,
        early_stop=False,
        eval_steps=500,
        seed=None,  # for reproducibility
        save_strategy="no",
        no_cuda=True,
        **kwargs,
    ):
        # Forward output_dir and seed so the base class can seed the RNGs before
        # it loads the tokenizer.
        super().__init__(output_dir=output_dir, seed=seed, **kwargs)
        self.output_dir = output_dir
        self.logging_dir = os.path.join(output_dir, "logs")

        # Training hyperparameters
        self.epochs = epochs
        self.batch_size = batch_size
        self.warmup_steps = warmup_steps
        self.weight_decay = weight_decay
        # Without saved checkpoints there is no "best model" to reload at the end.
        self.load_best_model_at_end = load_best_model_at_end if save_strategy != "no" else False
        self.random_weights = random_weights
        self.early_stop = early_stop
        self.eval_steps = eval_steps
        self.save_strategy = save_strategy
        self.no_cuda = no_cuda

        self.seed = seed

    def train(self, train_dataset, eval_dataset):
        """Fine-tune a freshly initialised model, then evaluate and predict on ``eval_dataset``.

        Args:
            train_dataset: tokenized training split.
            eval_dataset: tokenized evaluation split; assumed to also carry a
                ``uuid`` column used to key the predictions.

        Returns:
            ``(metrics, predictions)``: the dict from ``Trainer.evaluate()`` and a
            dict mapping each ``uuid`` to its predicted label as a string.
        """
        print("Initializing trainer...")

        # seed=None is not a valid TrainingArguments seed; pass it only when given
        # so TrainingArguments otherwise keeps its own default.
        seed_kwargs = {} if self.seed is None else {"seed": self.seed}

        training_args = TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=self.epochs,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size * 4,
            warmup_steps=self.warmup_steps,
            weight_decay=self.weight_decay,
            # NOTE(review): newer transformers versions rename this to `eval_strategy`.
            evaluation_strategy="steps",
            eval_steps=self.eval_steps,
            load_best_model_at_end=self.load_best_model_at_end,
            save_strategy=self.save_strategy,
            no_cuda=self.no_cuda,
            **seed_kwargs,
        )

        trainer = Trainer(
            # model_init lets the Trainer build the model after seeding itself.
            model_init=self.model_init,
            args=training_args,
            compute_metrics=compute_metrics,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer,
        )

        print("Trainer initialized.")
        print("Training...")
        trainer.train()
        print("Done training.")

        print("Evaluating...")
        res = trainer.evaluate()
        logits = trainer.predict(eval_dataset).predictions
        # Models returning extra outputs yield a tuple whose first entry is the logits.
        if isinstance(logits, tuple):
            logits = logits[0]
        labels = np.argmax(logits, axis=-1)

        assert len(labels) == len(eval_dataset)
        raw_res = {uuid: str(label) for uuid, label in zip(eval_dataset["uuid"], labels)}
        print("Evaluation results:")
        print(
            f'Eval loss: \t{res["eval_loss"]}, Eval Acc: \t{res["eval_accuracy"]}, Eval P: \t{res["eval_precision"]}, Eval R: \t{res["eval_recall"]}, Eval F1: \t{res["eval_f1-score"]}'
        )
        print("Done evaluating.")

        if self.save_strategy != "no":
            trainer.save_model()
            trainer.save_state()
            # Depending on the transformers version, weights are written as
            # pytorch_model.bin or model.safetensors.
            candidates = [
                os.path.join(self.output_dir, name) for name in ("pytorch_model.bin", "model.safetensors")
            ]
            existing = [path for path in candidates if os.path.isfile(path)]
            if not existing:
                raise FileNotFoundError(f"No saved model weights found in {self.output_dir}")
            self.trained_model_path = existing[0]
            print(f"Saved model to path: {self.output_dir}")

        return res, raw_res

In [None]:
def train(train_dataset, test_dataset, args):
    """Train and evaluate one model on a single train/test split.

    Args:
        train_dataset: raw (untokenized) training split.
        test_dataset: raw evaluation split.
        args: experiment config. ``max_len`` and ``num_labels`` are optional and
            default to 64 and 2, the constructor defaults of the classes used here.

    Returns:
        ``(metrics, predictions)`` as returned by ``TextClassifierTrainer.train``.
    """
    run_output_dir = args.experiment_output_dir
    if args.save_strategy != "no":
        # Each saved run gets its own timestamped sub-directory.
        run_output_dir = make_output_dir(args.experiment_output_dir)

    max_len = getattr(args, "max_len", 64)

    trainer = TextClassifierTrainer(
        model_name=args.model,
        lowercase=args.lowercase,
        max_len=max_len,
        num_labels=getattr(args, "num_labels", 2),
        epochs=args.epochs,
        batch_size=args.batch_size,
        output_dir=run_output_dir,
        warmup_steps=args.warmup_steps,
        weight_decay=args.weight_decay,
        load_best_model_at_end=args.load_best_model_at_end,
        random_weights=args.random_weights,
        early_stop=args.early_stop,
        eval_steps=args.eval_steps,
        save_strategy=args.save_strategy,
        seed=args.seed,
        no_cuda=args.no_cuda,
    )

    # Tokenize with the same max_len the tokenizer/model were configured with.
    dataset = Dataset(
        {"train": train_dataset, "test": test_dataset},
        max_len=max_len,
        instance_col_name=args.instance_col_name,
    )
    dataset.format_data(trainer.tokenizer)

    result, raw_result = trainer.train(dataset.train_dataset, dataset.test_dataset)

    return result, raw_result


def load_trainer(args):
    """Build a ``TextClassifierTrainer`` from ``args`` without training it.

    Forwards the same settings ``train`` uses. ``no_cuda``, ``max_len`` and
    ``num_labels`` are optional on ``args``; they default to True, 64 and 2, the
    constructor defaults of ``TextClassifierTrainer``/``TextClassifier``.
    """
    run_output_dir = args.experiment_output_dir
    if args.save_strategy != "no":
        run_output_dir = make_output_dir(args.experiment_output_dir)

    trainer = TextClassifierTrainer(
        model_name=args.model,
        lowercase=args.lowercase,
        max_len=getattr(args, "max_len", 64),
        num_labels=getattr(args, "num_labels", 2),
        epochs=args.epochs,
        batch_size=args.batch_size,
        output_dir=run_output_dir,
        warmup_steps=args.warmup_steps,
        weight_decay=args.weight_decay,
        load_best_model_at_end=args.load_best_model_at_end,
        random_weights=args.random_weights,
        early_stop=args.early_stop,
        eval_steps=args.eval_steps,
        save_strategy=args.save_strategy,
        seed=args.seed,
        # Without this the constructor default (no_cuda=True) forces CPU.
        no_cuda=getattr(args, "no_cuda", True),
    )

    return trainer


def _load_split(data, lang=None):
    """Load one data split as a DataFrame whose target column is named ``label``.

    Args:
        data: ``None``/empty (no split), a list of tab-separated file paths, or a
            ``pd.DataFrame`` used as-is.
        lang: if truthy and ``data`` is a list of paths, keep only rows whose
            ``lang`` column equals this value.

    Returns:
        A new DataFrame with ``is_variable`` renamed to ``label``, or ``None``.

    Raises:
        Exception: if a path does not exist or ``data`` has an unsupported type.
    """
    # Test for DataFrame first: a DataFrame in a boolean context raises ValueError.
    if isinstance(data, pd.DataFrame):
        frame = data
    elif not data:
        return None
    elif isinstance(data, list):
        if not all(os.path.exists(path) for path in data):
            raise Exception(f"Failed to load: {data}")
        frames = []
        for path in data:
            part = pd.read_csv(path, sep="\t")
            if lang:
                part = part[part["lang"] == lang]
            frames.append(part)
        frame = pd.concat(frames)
    else:
        raise Exception(f"Invalid type for {type(data)}")
    # rename() without inplace returns a copy, leaving a caller-supplied DataFrame untouched.
    return frame.rename(columns={"is_variable": "label"})


def _args_to_dict(args):
    """Collect the public, non-callable attributes of ``args`` as a JSON-safe dict.

    Unlike ``vars(args)``, this also sees class-level defaults of a plain config class.
    Values that ``json`` cannot serialise (e.g. a DataFrame) are replaced by a
    ``"<TypeName>"`` placeholder.
    """
    params = {}
    for name in dir(args):
        if name.startswith("_"):
            continue
        value = getattr(args, name)
        if callable(value):
            continue
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            value = f"<{type(value).__name__}>"
        params[name] = value
    return params


def _write_summary(results, results_file):
    """Print mean / sample std / population std per metric and append them to ``results_file``.

    Only keys containing accuracy, precision, recall or f1-score are summarised;
    with a single value, both standard deviations are reported as 0.
    """
    metric_names = ("accuracy", "precision", "recall", "f1-score")
    print("***** Cross-Validation Results *****")
    with jsonlines.open(results_file, "a") as writer:
        for key, values in results.items():
            if not any(metric in key for metric in metric_names):
                continue
            if len(values) > 1:
                mean = statistics.mean(values)
                std = statistics.stdev(values)
                pstd = statistics.pstdev(values)
            else:
                mean, std, pstd = values[0], 0, 0
            print(
                key + ":\n",
                "Mean:",
                round(mean, 4),
                "\tStd.:",
                round(std, 4),
                "\tPStd:",
                round(pstd, 4),
            )
            writer.write({key: {"Mean": mean, "Std": std, "PStd": pstd}})


def run(args):
    """Load data, train (optionally with k-fold cross-validation) and write results.

    Creates a timestamped experiment directory under ``args.output_dir`` (stored
    back on ``args.experiment_output_dir``). When ``args.train`` is set, trains one
    model per fold (or a single train/test run) and writes hyperparameters.jsonl
    (config plus per-metric mean/std), results_raw.jsonl, predictions.json and
    submission.json into that directory.
    """
    train_df = _load_split(args.train_data, args.lang)
    test_df = _load_split(args.test_data)

    for frame in (train_df, test_df):
        if frame is not None:
            assert args.instance_col_name in frame.columns.tolist()

    experiment_output_dir = make_output_dir(args.output_dir)
    args.experiment_output_dir = experiment_output_dir

    if not args.train:
        return

    results = defaultdict(list)
    raw_results = {}

    print("Training model...")
    if args.do_cross_validation:
        print("Running cross-validation...")
        # Both splits are pooled; folds are drawn from the combined data.
        combined_df = pd.concat([train_df, test_df]).copy()
        positions = combined_df.index.to_numpy()
        y = combined_df.label.to_numpy()

        if args.balance_splits:
            print("Balancing splits...")
            kf = StratifiedKFold(n_splits=args.n_cv_splits, random_state=args.seed, shuffle=True)
            splits = kf.split(positions, y)
        else:
            kf = KFold(n_splits=args.n_cv_splits, random_state=args.seed, shuffle=True)
            splits = kf.split(positions)

        # split() yields positional indices, hence iloc below.
        for fold, (train_index, test_index) in enumerate(tqdm(splits, total=args.n_cv_splits)):
            train_dataset = datasets.Dataset.from_pandas(combined_df.iloc[train_index])
            test_dataset = datasets.Dataset.from_pandas(combined_df.iloc[test_index])

            result, raw_result = train(train_dataset, test_dataset, args)
            for key, value in result.items():
                results[key].append(value)
            raw_results[str(fold)] = raw_result
    else:
        if train_df is None or test_df is None:
            raise ValueError("Training without cross-validation requires both train_data and test_data.")
        train_dataset = datasets.Dataset.from_pandas(train_df)
        test_dataset = datasets.Dataset.from_pandas(test_df)

        result, raw_result = train(train_dataset, test_dataset, args)
        for key, value in result.items():
            results[key].append(value)
        raw_results["0"] = raw_result

    # Save hyperparameters, followed by the aggregated metrics in the same file.
    results_file = os.path.join(experiment_output_dir, "hyperparameters.jsonl")
    with jsonlines.open(results_file, "a") as writer:
        writer.write(_args_to_dict(args))
    _write_summary(results, results_file)

    # Per-fold metric values (a JSON object despite the .jsonl extension).
    with open(os.path.join(experiment_output_dir, "results_raw.jsonl"), "w") as fp:
        json.dump(results, fp)

    # Save predictions of every fold.
    with open(os.path.join(experiment_output_dir, "predictions.json"), "w") as fp:
        json.dump(raw_results, fp)

    # NOTE(review): under cross-validation this is only fold 0's predictions.
    with open(os.path.join(experiment_output_dir, "submission.json"), "w") as fp:
        json.dump(raw_results["0"], fp)

In [None]:
@dataclass
class Args:
    """Experiment configuration consumed by ``run``.

    A dataclass so every setting is an instance attribute: each instance gets its
    own lists (no shared mutable class attributes) and ``vars(args)`` lists all
    hyperparameters.

    NOTE(review): gradient_accumulation_steps, learning_rate, adam_epsilon,
    max_grad_norm and max_steps are not passed to TextClassifierTrainer by
    train() or load_trainer(); ``test`` is not read by run().
    """

    train: bool = True
    test: bool = False
    lang: Optional[str] = None
    # List of TSV paths, or a DataFrame used as-is.
    train_data: Union[List[str], pd.DataFrame] = field(
        default_factory=lambda: ["./sv-ident/data/train/train.tsv", "./sv-ident/data/train/val.tsv"]
    )
    test_data: Union[List[str], pd.DataFrame] = field(default_factory=list)
    instance_col_name: str = "sentence"
    model: str = "KM4STfulltext/SSCI-SciBERT-e2"
    output_dir: str = "./runs/sv-ident/cross-val/en_sub/"
    max_len: int = 64
    lowercase: bool = False
    gradient_accumulation_steps: int = 1
    learning_rate: float = 5e-5
    weight_decay: float = 0.0
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    random_weights: bool = False
    batch_size: int = 16
    epochs: int = 1
    max_steps: int = -1
    warmup_steps: int = 0
    early_stop: bool = False
    eval_steps: int = 100
    num_labels: int = 2
    save_strategy: str = "no"
    load_best_model_at_end: bool = False
    seed: int = 0
    balance_splits: bool = False
    do_cross_validation: bool = True
    n_cv_splits: int = 10
    no_cuda: bool = False

args = Args()
# An explicit check, unlike `assert`, is not stripped when Python runs with -O.
if not (args.train or args.test):
    raise ValueError("Args must enable at least one of `train` or `test`.")
run(args)