In [106]:
import logging
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import datasets
from datasets import load_dataset
import nibabel as nib
from nilearn.connectome import ConnectivityMeasure
from nilearn.maskers import NiftiLabelsMasker
import numpy as np
import torch
import transformers
from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint

In [107]:
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    text_model_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained text model name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    freeze_vision_model: bool = field(
        default=False, metadata={"help": "Whether to freeze the vision model parameters or not."}
    )
    freeze_text_model: bool = field(
        default=False, metadata={"help": "Whether to freeze the text model parameters or not."}
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_path: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    data_dir: Optional[str] = field(
        default=None, metadata={"help": "The data directory containing input files."})
    image_column: Optional[str] = field(
        default="image_path",
        metadata={
            "help": "The name of the column in the datasets containing the full image file paths."},
    )
    caption_column: Optional[str] = field(
        default="caption",
        metadata={
            "help": "The name of the column in the datasets containing the image captions."},
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file (a jsonlines file)."},
    )
    max_seq_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )


dataset_path_mapping = {
    "./dataset_loading_scripts/abide.py": ("image_path", "caption"),
}


def collate_fn(batch: list[dict]):
    past_values = torch.stack([sample['time_series'] for sample in batch])
    target_values = torch.stack([sample['label'] for sample in batch])
    return {
        'past_values': past_values,
        'target_values': target_values
    }

In [108]:
args_list = [
    '--tokenizer_name', '../pretrained_models/roberta-base',
    '--text_model_name', '../pretrained_models/roberta-base',
    '--trust_remote_code',
    '--freeze_text_model',
    '--dataset_path', './dataset_loading_scripts/abide.py',
    '--data_dir', '/bigdata/yanting/datasets/nilearn_data',
    '--output_dir', './outputs',
    '--overwrite_output_dir',
    '--do_train',
    '--do_eval',
    '--eval_strategy', 'epoch',
    '--per_device_train_batch_size', '64',
    '--per_device_eval_batch_size', '1',
    '--learning_rate', '1e-4',
    '--weight_decay', '1e-4',
    '--num_train_epochs', '20',
    '--lr_scheduler_type', 'cosine_with_restarts',
    '--logging_steps', '1',
    '--save_strategy', 'epoch',
    '--save_safetensors', 'False',
    '--dataloader_drop_last', 'True',
    '--dataloader_num_workers', '8',
    '--run_name', 'brainnettf_asd',
    '--remove_unused_columns', 'False',
    '--report_to', 'wandb'
]

parser = HfArgumentParser(
    (ModelArguments, DataTrainingArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    args_list)

In [109]:
last_checkpoint = get_last_checkpoint(training_args.output_dir)

In [110]:
ds = load_dataset(
    path=data_args.dataset_path,
    data_dir=data_args.data_dir,
    split='train',
    trust_remote_code=model_args.trust_remote_code
).train_test_split(
    test_size=.2,
    stratify_by_column='label',
    seed=42,
)
ds_train_val = ds['train']
ds_test = ds['test']
ds_train_val = ds_train_val.train_test_split(
    test_size=.2,
    stratify_by_column='label',
    seed=42,
)
ds_train = ds_train_val['train']
ds_val = ds_train_val['test']
dataset = datasets.DatasetDict({
    'train': ds_train,
    'validation': ds_val,
    'test': ds_test
})

In [111]:
config = transformers.PatchTSTConfig(
        num_input_channels=200,
        num_targets=2,
        context_length=512,
        patch_length=12,
        stride=12,
        use_cls_token=True,
    )
model = transformers.PatchTSTForClassification.from_pretrained(
    last_checkpoint
)

In [112]:
def transform_images(batch):
    time_series_lst = [np.loadtxt(
        time_series_path, dtype=np.float32
    ) for time_series_path in batch['time_series_path']]  # bs x sequence_length x num_input_channels

    bs = len(time_series_lst)
    sequence_length = 512
    num_input_channels = time_series_lst[0].shape[-1]

    mask = np.zeros(
        (bs, sequence_length, num_input_channels), dtype=np.bool_
    )

    for i in range(len(time_series_lst)):
        time_series = time_series_lst[i]
        # truncate
        if time_series.shape[0] > sequence_length:
            time_series = time_series[:sequence_length]
        # mask
        mask[i, :time_series.shape[0]] = 1
        # pad
        time_series_lst[i] = np.pad(
            time_series,
            ((0, sequence_length - time_series.shape[0]), (0, 0))
        )
    time_series_lst = np.stack(time_series_lst, axis=0)

    batch['time_series'] = torch.from_numpy(time_series_lst)
    batch['mask'] = torch.from_numpy(mask)
    batch['label'] = torch.tensor(batch['label'])
    return batch

In [113]:
eval_dataset = dataset["test"]
if data_args.max_eval_samples is not None:
    max_eval_samples = min(
        len(eval_dataset), data_args.max_eval_samples)
    eval_dataset = eval_dataset.select(range(max_eval_samples))

# Transform images on the fly as doing it on the whole dataset takes too much time.
eval_dataset.set_transform(transform_images)

In [114]:
device = 'cuda:1'
model.eval()
model.to(device)

PatchTSTForClassification(
  (model): PatchTSTModel(
    (scaler): PatchTSTScaler(
      (scaler): PatchTSTStdScaler()
    )
    (patchifier): PatchTSTPatchify()
    (masking): Identity()
    (encoder): PatchTSTEncoder(
      (embedder): PatchTSTEmbedding(
        (input_embedding): Linear(in_features=36, out_features=512, bias=True)
      )
      (positional_encoder): PatchTSTPositionalEncoding(
        (positional_dropout): Identity()
      )
      (layers): ModuleList(
        (0): PatchTSTEncoderLayer(
          (self_attn): PatchTSTAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout_path1): Identity()
          (norm_sublayer1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (f

In [115]:
from tqdm import tqdm

In [116]:
y_true = []
y_pred = []
y_score = []
for sample in tqdm(eval_dataset):
    labels = sample['label']
    mask = sample['mask']
    pixel_values = sample['time_series']
    y_true.append(labels)
    with torch.no_grad():
        res = model(pixel_values.unsqueeze(0).to(device))
        logits = res['prediction_logits']
        y_score.append(logits.squeeze(0).cpu().numpy())
        y_pred.append(logits.argmax().item())

100%|██████████| 175/175 [00:06<00:00, 27.19it/s]


In [117]:
import sklearn
from sklearn.metrics import accuracy_score

In [118]:
np.unique(y_true, return_counts=True), np.unique(y_pred, return_counts=True)

((array([0, 1]), array([94, 81])), (array([0, 1]), array([168,   7])))

In [119]:
auroc = float(sklearn.metrics.roc_auc_score(y_true, np.array(y_score)[:,1]))
acc = accuracy_score(y_true, y_pred)
sen = sklearn.metrics.recall_score(y_true, y_pred)
spc = sklearn.metrics.recall_score(y_true, y_pred, pos_label=0)
auroc*100, acc*100, sen*100, spc*100

(54.62306277909115, 54.285714285714285, 4.938271604938271, 96.80851063829788)