In [169]:
import pytorch_lightning as pl
import os
import sys
from datetime import datetime
import gc
import pathlib
import torch
from torch import nn
from torch.utils.data import DataLoader
import optuna
import numpy as np
import pandas as pd
import sklearn
from transformers import AutoTokenizer, AdamW, AutoConfig, AutoModelForQuestionAnswering
from tqdm import tqdm
from typing import List, Dict, Union, Tuple
os.environ["LOGGING_INI"] = "../logging.ini"
import questionanswering as qa
from questionanswering import squad

# Configuration

In [170]:
# --- Hardware / training budget -------------------------------------------
gpus: List[int] = [0]
epochs = 20
batch_size = 8  # electra-small=64, electra-large=8
patience = 5  # early-stopping patience forwarded to _callbacks()

# --- Hyperparameter search ranges, each given as (low, high) ----------------
# low == high pins the value, so the study effectively evaluates one setting.
lr = (1e-5, 1e-5)
# `do_swa` is read by the Optuna study cell and by the final-training cell but
# was never assigned anywhere, so a fresh "Restart & Run All" died with a
# NameError. It is defined here next to the SWA ranges it switches on.
# NOTE(review): True is an assumption based on the SWA ranges below holding
# real (non-sentinel) values - confirm the intended setting.
do_swa = True
swa_anneal_epochs = (5, 5)
swa_start_epoch = (3, 3)
# The three ranges below are not referenced by any cell visible in this notebook.
ca_T_max = (4, 4)
cawr_T_0 = (-1, -1)
cawr_T_mult = (1.0, 1.0)

# --- Pretrained checkpoint selection ------------------------------------------
model_name = "xlm_roberta"
PRETRAINED_DIR = "../pretrained/"
pretrained_map = {
    "albert": f"{PRETRAINED_DIR}albert-base-v2",
    "electra_small": f"{PRETRAINED_DIR}google/electra-small-discriminator",
    "electra": f"{PRETRAINED_DIR}google/electra-base-discriminator",
    "electra_large": f"{PRETRAINED_DIR}google/electra-large-discriminator",
    "distilroberta": f"{PRETRAINED_DIR}distilroberta-base",
    "mpnet": f"{PRETRAINED_DIR}microsoft/mpnet-base",
    "roberta": f"{PRETRAINED_DIR}roberta-base",
    "xlm_roberta": f"{PRETRAINED_DIR}xlm-roberta-base",
    "xlm_roberta_large": f"{PRETRAINED_DIR}xlm-roberta-large",
}
# Presumably groups the checkpoints by tokenizer family (WordPiece /
# SentencePiece / BPE); none of the three sets is used in the visible cells.
wp_models = {"electra_small", "electra", "electra_large", "mpnet"}
sp_models = {"albert", "xlm_roberta", "xlm_roberta_large"}
bpe_models = {"roberta", "distilroberta"}
pretrained_dir = pretrained_map[model_name]

# --- Tokenization ---------------------------------------------------------------
model_max_length = 512  # passed to AutoTokenizer.from_pretrained
stride = 128  # passed to the tokenizer call that returns overflowing windows

# --- Cross-validation / study ---------------------------------------------------
n_trials = 1
n_splits = 3
min_fold_score = 0.5  # becomes MyObjective.score_threshold
gradient_checkpointing = False
num_workers = 0  # multi-process data loading
accelerator = None

# --- Display options, environment and seeding -----------------------------------
pd.set_option("use_inf_as_na", True)
pd.set_option("max_info_columns", 9999)
pd.set_option("display.max_columns", 9999)
pd.set_option("display.max_rows", 9999)
pd.set_option('max_colwidth', 9999)
tqdm.pandas()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
pl.seed_everything(31)

31

In [171]:
# Select the compute device: CUDA when a GPU is usable, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # List every visible GPU together with its current memory usage in GiB.
    for i in range(torch.cuda.device_count()):
        print(f"{i}: {torch.cuda.get_device_name(i)}")
        allocated_gb = round(torch.cuda.memory_allocated(i)/1024**3,1)
        print('Memory Allocated:\t', allocated_gb, 'GB')
        reserved_gb = round(torch.cuda.memory_reserved(i)/1024**3,1)
        print('Memory Cached:\t\t', reserved_gb, 'GB')
print(f"device={device}")

0: NVIDIA GeForce GTX 1060 6GB
Memory Allocated:	 0.0 GB
Memory Cached:		 0.0 GB
device=cuda


In [172]:
# Load the tokenizer that belongs to the selected checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_dir,
    model_max_length=model_max_length,
)
# The padding side decides the (question, context) ordering in the encoding cell.
is_right_padding = "right" == tokenizer.padding_side
print(f"{repr(tokenizer)}\ninput_keys={tokenizer.model_input_names}")

PreTrainedTokenizerFast(name_or_path='../pretrained/xlm-roberta-base', vocab_size=250002, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})
input_keys=['input_ids', 'attention_mask']


In [173]:
# Smoke test: build the QA model once from the checkpoint so its effective
# configuration can be printed, then drop both references again.
config = AutoConfig.from_pretrained(pretrained_dir)
config.gradient_checkpointing = gradient_checkpointing
model = AutoModelForQuestionAnswering.from_pretrained(
    pretrained_dir,
    config=config,
)
print(f"{model.config!r}")
del config
del model

XLMRobertaConfig {
  "_name_or_path": "../pretrained/xlm-roberta-base",
  "architectures": [
    "XLMRobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "xlm-roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.10.0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 250002
}



In [174]:
%%time
df = pd.read_parquet(f"input/train.parquet")
#df = df.sample(frac=0.01)
contexts = df["context"].tolist()
questions = df["question"].tolist()
answer_start = df["answer_start"].tolist()
answer_length = df["a_length"].tolist()
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 130319 entries, 0 to 130318
Data columns (total 10 columns):
 #   Column          Non-Null Count   Dtype 
---  ------          --------------   ----- 
 0   id              130319 non-null  object
 1   title           130319 non-null  object
 2   question        130319 non-null  object
 3   answer_text     130319 non-null  object
 4   answer_start    130319 non-null  int16 
 5   context         130319 non-null  object
 6   qc_length       130319 non-null  int32 
 7   a_length        130319 non-null  int32 
 8   qc_word_length  130319 non-null  int32 
 9   a_word_length   130319 non-null  int32 
dtypes: int16(1), int32(4), object(5)
memory usage: 7.2+ MB
Wall time: 608 ms


In [175]:
%%time
s1, s2 = contexts, questions
truncation = "only_first"
if is_right_padding:
    s1, s2 = questions, contexts
    truncation = "only_second"
x = tokenizer(
    s1, 
    s2, 
    truncation=truncation, 
    padding="max_length",
    stride=stride,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    return_special_tokens_mask=False,
)
print(f"{repr(x.keys())}")
# all only supports torch.uint8 and torch.bool dtypes
#special_tokens_mask = torch.tensor(x.pop("special_tokens_mask"), dtype=torch.uint8)
overflow_to_sample_mapping = x.pop("overflow_to_sample_mapping")
print(f"len(overflow_to_sample_mapping)={len(overflow_to_sample_mapping)}")
offset_mapping = x.pop("offset_mapping")
print(f"len(offset_mapping)={len(offset_mapping)}")

dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])
len(overflow_to_sample_mapping)=131167
len(offset_mapping)=131167
Wall time: 3min 16s


In [176]:
%%time
# Build the start/end labels for every tokenized window and attach them to the
# encoding `x`. The conversion itself lives in squad.position_labels (not
# visible here); presumably it maps each character-level answer span onto token
# indices via the offset mapping - confirm against the helper.
# NOTE(review): the saved output of this cell is a ValueError ("answer span
# cannot be found!", prev=15928), so later cells cannot run from a fresh kernel
# until that sample is handled.
start_positions, end_positions = squad.position_labels(
    offset_mapping=offset_mapping,
    overflow_to_sample_mapping=overflow_to_sample_mapping,
    answer_start=answer_start,
    answer_length=answer_length,
)
x["start_positions"] = start_positions
x["end_positions"] = end_positions
print(f"len(start_positions)={len(start_positions)}, len(end_positions)={len(end_positions)}")
# One label pair is expected per tokenized window.
assert len(x["input_ids"]) == len(start_positions)

ValueError: answer span cannot be found!
prev=15928, i=60, j=61, start=0, end=27, offsets=[(0, 0), (0, 2), (3, 7), (8, 12), (13, 16), (17, 20), (21, 26), (27, 37), (38, 43), (44, 47), (48, 57), (58, 62), (62, 63), (0, 0), (0, 0), (0, 3), (4, 11), (12, 19), (19, 20), (21, 23), (24, 27), (28, 34), (35, 40), (41, 48), (49, 51), (52, 55), (56, 60), (61, 68), (69, 74), (75, 80), (81, 88), (89, 91), (92, 99), (100, 105), (106, 117), (118, 123), (123, 124), (125, 126), (125, 126), (127, 136), (137, 140), (141, 146), (147, 150), (151, 156), (156, 160), (161, 164), (165, 168), (169, 174), (175, 182), (183, 190), (190, 192), (193, 195), (196, 197), (198, 202), (203, 205), (206, 207), (206, 211), (212, 213), (214, 215), (215, 217), (217, 220), (220, 221), (222, 223), (222, 223), (224, 226), (226, 230), (231, 235), (236, 237), (238, 242), (243, 249), (250, 254), (255, 262), (263, 265), (266, 276), (277, 280), (281, 286), (287, 293), (294, 295), (294, 295), (296, 297), (296, 304), (305, 308), (309, 315), (315, 316), (317, 319), (320, 327), (328, 335), (336, 341), (342, 343), (342, 343), (344, 350), (350, 353), (354, 357), (357, 358), (358, 361), (362, 363), (364, 366), (367, 371), (371, 374), (375, 377), (378, 379), (378, 379), (380, 384), (385, 386), (385, 386), (387, 390), (391, 396), (397, 402), (402, 407), (408, 412), (413, 414), (415, 420), (421, 422), (423, 428), (428, 430), (430, 432), (433, 437), (438, 439), (438, 439), (440, 443), (444, 446), (446, 447), (447, 448), (449, 453), (453, 455), (456, 463), (464, 473), (473, 477), (478, 485), (485, 489), (490, 496), (497, 501), (502, 509), (510, 511), (512, 514), (515, 518), (519, 522), (523, 528), (529, 534), (535, 541), (542, 548), (549, 554), (555, 560), (560, 562), (562, 565), (566, 573), (573, 575), (576, 578), (579, 580), (581, 582), (581, 586), (587, 588), (589, 590), (590, 592), (592, 595), (596, 597), (598, 609), (610, 611), (610, 611), (612, 619), (620, 621), (620, 621), (622, 625), (625, 628), (629, 636), (637, 642), (642, 646), (646, 647), (648, 651), 
(652, 655), (656, 659), (660, 662), (663, 669), (670, 681), (681, 682), (683, 685), (686, 689), (690, 696), (697, 702), (702, 707), (707, 708), (708, 709), (710, 723), (724, 731), (732, 733), (734, 740), (741, 745), (746, 756), (757, 760), (761, 771), (771, 773), (774, 781), (782, 785), (786, 791), (791, 792), (792, 793), (794, 801), (802, 804), (804, 807), (808, 813), (814, 817), (818, 821), (822, 827), (828, 831), (832, 837), (837, 841), (841, 842), (843, 844), (843, 844), (845, 848), (848, 851), (852, 859), (860, 865), (865, 869), (869, 870), (871, 877), (877, 878), (879, 882), (883, 888), (888, 890), (890, 893), (894, 896), (897, 908), (909, 913), (914, 916), (917, 920), (921, 926), (927, 929), (930, 939), (939, 941), (942, 947), (948, 950), (950, 953), (954, 956), (956, 961), (961, 963), (963, 970), (971, 972), (971, 972), (973, 977), (978, 979), (978, 982), (983, 988), (989, 992), (993, 995), (996, 1002), (1002, 1004), (1005, 1010), (1011, 1014), (1015, 1017), (1018, 1021), (1021, 1024), (1025, 1032), (1033, 1039), (1040, 1043), (1043, 1044), (1045, 1047), (1048, 1052), (1053, 1054), (1055, 1059), (1060, 1065), (1066, 1073), (1074, 1077), (1078, 1086), (1087, 1091), (1092, 1095), (1096, 1099), (1099, 1102), (1103, 1105), (1106, 1110), (1110, 1112), (1113, 1122), (1123, 1128), (1129, 1138), (1138, 1139), (1140, 1142), (1143, 1144), (1145, 1146), (1145, 1150), (1151, 1152), (1153, 1154), (1154, 1156), (1156, 1159), (1160, 1171), (1172, 1177), (1178, 1180), (1181, 1186), (1187, 1188), (1187, 1188), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), 
(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]

In [163]:
# Scratch probe: tokenize the same word with different surrounding punctuation
# (without special tokens) to compare the resulting token ids.
# NOTE(review): the execution count (163) is out of sequence with the
# neighbouring cells; this looks like a debugging aid rather than a pipeline step.
y = tokenizer([
    "2003", "(2003)", "2003,", "2003.", "2003!", "2003?", "$2003", "-2003", "buddha", "buddha's"
], add_special_tokens=False)
y["input_ids"]

[[6052],
 [119845],
 [88847],
 [46730],
 [6052, 38],
 [6052, 32],
 [3650, 34498],
 [20, 34498],
 [177155, 11],
 [177155, 11, 25, 7]]

In [168]:
# Scratch check: character offset of the last occurrence of "three" in a
# sample sentence (str.rfind returns the highest matching index).
tmp = "in 1958 , korolev upgraded the r - 7 to be able to launch a 400 - kilogram 880 lb payload to the moon . three"
tmp.rfind("three")

104

In [177]:
# Inspect the training row at position 15928, the `prev=15928` value reported
# in the saved ValueError output of the label-building cell (presumably the
# index of the failing sample - confirm against squad.position_labels).
df.iloc[15928]

id                                                                                                                                                                                                                                                                                                                                                                                                                             56e0f816231d4119001ac511
title                                                                                                                                                                                                                                                                                                                                                                                                                                        Space_Race
question                                                                                                                

In [112]:
# Wrap the encoding (including the label columns added earlier) in the
# project's Dataset type and print the first item as a sanity check.
# NOTE(review): the execution count (112) predates the cells above, so the
# saved output comes from an earlier kernel run.
ds = squad.Dataset(x)
tmp = ds[0]
# (number of items, length of the first item's input_ids)
shape = (len(ds), tmp['input_ids'].size()[0])
print(f"ds.shape={shape}, device={tmp['input_ids'].device}\n{tmp}")

ds.shape=(131165, 512), device=cpu
{'input_ids': tensor([     0,   3229,   6777,    186,   9480,    329,   4034, 141753,   5700,
            32,      2,      2,    186,   9480,  11676,    706,  35788,   3714,
          1577,     20,  17734,     42,    248,    333, 233844, 160548,    170,
        247085,  29319, 243547,    248,    186,     13,     20, 134181,     20,
          5154, 103122,  11691,    201,      6,      4,  26771,     83,    142,
         60680,   5367,     56,      6,      4,  11531,  70035,      6,      4,
         17164, 108558,    136, 215542,      6,      5, 103122,    136, 165249,
            23,  19660,  19386,      6,      4,    120,  53499,      6,      4,
          2412,  51339,    297,     23,  67842,   5367,    214,    136,    123,
         21896, 130412,      7,    237,     10,  29041,      6,      4,    136,
         49175,     47,  65536,     23,     70,  72399,  11704,      7,    237,
         37105,   5367,     56,    111,   1690,   1230,    275,  23040,

In [None]:
def _callbacks(patience: int, monitor: str = "val_loss", verbose: bool = True):
    """Create the Lightning callbacks used by every trainer in this notebook.

    Args:
        patience: Forwarded to ``EarlyStopping`` as its ``patience``.
        monitor: Name of the logged metric watched by both early stopping and
            checkpointing.
        verbose: Forwarded to ``EarlyStopping`` and ``ModelCheckpoint``.

    Returns:
        A list holding, in order, an ``EarlyStopping``, a ``ModelCheckpoint``
        that keeps only the single best checkpoint (``save_top_k=1``) and a
        ``LearningRateMonitor`` logging once per epoch.
    """
    early_stopping = pl.callbacks.EarlyStopping(
        monitor=monitor, patience=patience, verbose=verbose
    )
    checkpoint = pl.callbacks.ModelCheckpoint(
        monitor=monitor, verbose=verbose, save_top_k=1
    )
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="epoch")
    return [early_stopping, checkpoint, lr_monitor]

In [None]:
class MyObjective:
    """Optuna objective that cross-validates one hyperparameter configuration.

    Calling the instance with an Optuna ``trial`` samples a learning rate (and,
    when ``do_swa`` is set, the SWA anneal/start epochs), trains one
    ``squad.Model`` per fold produced by ``splitter`` and returns the lowest
    fold score. One summary dict per trial is appended to ``self.history``.

    Note: the data loaders read the notebook-level global ``num_workers``; it
    is not a constructor argument.
    """

    def __init__(
        self,
        ds,
        splitter,
        epochs: int,
        batch_size: int,
        patience: int,
        job_dir: str,
        pretrained_dir: str,
        lr: Tuple[float, float],
        do_swa: bool,
        swa_anneal_epochs: Tuple[int, int],
        swa_start_epoch: Tuple[int, int],
        gradient_checkpointing: bool,
        score_threshold: float,
        gpus: List[int],
        accelerator: Union[str, None],
        stratify_labels: List[int],
    ):
        """Store the search configuration; no work is done until ``__call__``.

        Args:
            ds: Indexable dataset supporting ``len``; each item must expose
                ``"start_positions"`` and ``"end_positions"`` entries with an
                ``.item()`` method.
            splitter: Object whose ``split(X, y)`` yields
                ``(train_indices, val_indices)`` pairs.
            epochs: Trainer ``max_epochs`` and the scheduler's ``T_max``.
            batch_size: Batch size of both data loaders.
            patience: Early-stopping patience forwarded to ``_callbacks``.
            job_dir: Root directory; each fold uses
                ``{job_dir}/trial_{id}/fold_{k}`` as trainer root.
            pretrained_dir: Checkpoint location forwarded to ``squad.Model``.
            lr: ``(low, high)`` bounds of the log-uniform learning-rate search.
            do_swa: Whether the SWA hyperparameters are sampled and used.
            swa_anneal_epochs: ``(low, high)`` integer range, used if ``do_swa``.
            swa_start_epoch: ``(low, high)`` integer range, used if ``do_swa``.
            gradient_checkpointing: Forwarded to ``squad.Model``.
            score_threshold: A fold scoring at or below this value stops the
                evaluation of the remaining folds for that trial.
            gpus: Forwarded to ``qa.HfTrainer``.
            accelerator: Forwarded to ``qa.HfTrainer``; may be ``None``.
            stratify_labels: Passed as ``y`` to ``splitter.split`` next to a
                dummy ``X`` of length ``len(ds)``.
        """
        self.ds = ds
        self.splitter = splitter
        self.epochs = epochs
        self.batch_size = batch_size
        self.patience = patience
        self.job_dir = job_dir
        self.pretrained_dir = pretrained_dir
        self.lr = lr
        self.do_swa = do_swa
        self.swa_anneal_epochs = swa_anneal_epochs
        self.swa_start_epoch = swa_start_epoch
        self.gradient_checkpointing = gradient_checkpointing
        self.gpus = gpus
        self.accelerator = accelerator
        # One summary dict per completed trial (hyperparameters + fold scores).
        self.history: List[Dict[str, Union[str, int, float]]] = []
        self.score_threshold = score_threshold
        self.stratify_labels = stratify_labels

    def _score_fold(self, vi, y_pred_start, y_pred_end):
        """Return the mean ``qa.dice_coefficient`` over one validation fold.

        ``y_pred_start[i]`` / ``y_pred_end[i]`` are taken as the predictions
        for dataset item ``vi[i]``.
        """
        sample_scores = []
        for i, j in enumerate(vi):
            # Index the dataset once per sample; both labels come from the same item.
            item = self.ds[j]
            sample_scores.append(
                qa.dice_coefficient(
                    true_start=item["start_positions"].item(),
                    true_end=item["end_positions"].item(),
                    pred_start=y_pred_start[i],
                    pred_end=y_pred_end[i],
                )
            )
        return np.mean(sample_scores)

    def __call__(self, trial) -> float:
        """Run one cross-validated trial and return its worst fold score.

        Args:
            trial: Optuna trial used to sample the hyperparameters.

        Returns:
            The minimum of the per-fold mean scores (``hist["score_worst"]``).
        """
        hist = {
            "trial_id": trial.number,
            "lr": trial.suggest_loguniform(
                "lr", self.lr[0], self.lr[1]
            ),
        }
        if self.do_swa:
            hist["swa_anneal_epochs"] = trial.suggest_int(
                "swa_anneal_epochs", self.swa_anneal_epochs[0], self.swa_anneal_epochs[1]
            )
            hist["swa_start_epoch"] = trial.suggest_int(
                "swa_start_epoch", self.swa_start_epoch[0], self.swa_start_epoch[1]
            )
        trial_id = hist['trial_id']
        scores = []
        max_epochs = 0
        # The splitter only needs X for its length; the content is irrelevant.
        dummy = np.zeros(len(self.ds))
        for fold, (ti, vi) in enumerate(self.splitter.split(dummy, self.stratify_labels)):
            directory = f"{self.job_dir}/trial_{trial_id}/fold_{fold}"
            train_ds = torch.utils.data.Subset(self.ds, ti)
            val_ds = torch.utils.data.Subset(self.ds, vi)
            # SWA disabled: no scheduler params and start epoch -1.
            swa_scheduler_params = None
            swa_start_epoch = -1
            if self.do_swa:
                swa_start_epoch = hist["swa_start_epoch"]
                swa_scheduler_params = {
                    "swa_lr": hist["lr"],
                    "anneal_epochs": hist["swa_anneal_epochs"],
                    "anneal_strategy": "cos"
                }
            model = squad.Model(
                pretrained_dir=self.pretrained_dir,
                gradient_checkpointing=self.gradient_checkpointing,
                lr=hist["lr"],
                scheduler_params={"T_max": self.epochs},
                swa_scheduler_params=swa_scheduler_params,
                swa_start_epoch=swa_start_epoch,
            )
            trainer = qa.HfTrainer(
                default_root_dir=directory,
                gpus=self.gpus,
                auto_select_gpus=True,
                accelerator=self.accelerator,
                max_epochs=self.epochs,
                callbacks=_callbacks(patience=self.patience),
                deterministic=True
            )
            # `num_workers` is the notebook-level global, not an attribute.
            trainer.fit(
                model,
                train_dataloader=DataLoader(
                    train_ds, batch_size=self.batch_size, shuffle=True, num_workers=num_workers),
                val_dataloaders=DataLoader(
                    val_ds, batch_size=self.batch_size, shuffle=False, num_workers=num_workers),
            )
            max_epochs = max(max_epochs, trainer.current_epoch + 1)
            # Assumes the model keeps validation logits aligned with the
            # (unshuffled) validation loader order - TODO confirm in squad.Model.
            start_logits = model.start_logits.cpu().numpy()
            end_logits = model.end_logits.cpu().numpy()
            print(f"start_logits.shape={start_logits.shape}, end_logits.shape={end_logits.shape}")
            y_pred_start = np.argmax(start_logits, axis=1)
            y_pred_end = np.argmax(end_logits, axis=1)
            score = self._score_fold(vi, y_pred_start, y_pred_end)
            print(f"score={score:.4f}, fold={fold}, trial={trial_id}")
            hist[f"fold_{fold}_score"] = score
            scores.append(score)
            # Drop the per-fold objects before the next model is built.
            del model, trainer, train_ds, val_ds
            gc.collect()
            # A fold at or below the threshold ends the trial early.
            if score <= self.score_threshold:
                break
        hist["max_epochs"] = max_epochs
        hist["score_mean"] = np.mean(scores)
        hist["score_std"] = np.std(scores)
        hist["score_worst"] = min(scores)
        self.history.append(hist)
        return hist["score_worst"]

In [None]:
# Every run writes into its own timestamped directory under ../models/<model_name>/.
ts = f"{datetime.now():%Y%m%d_%H%M%S}"
job_dir = f"../models/{model_name}/{ts}"
#job_dir = "tmp"
os.makedirs(job_dir, exist_ok=True)
print(f"job_dir={job_dir}")

In [None]:
%%time
# Run the cross-validated hyperparameter search.
# NOTE(review): `do_swa` and `is_impossible` must be defined by an earlier
# cell; confirm both exist before running this on a fresh kernel.
# `is_impossible` is passed as the stratification labels next to a dummy X of
# len(ds).
obj = MyObjective(
    ds=ds,
    splitter=sklearn.model_selection.StratifiedKFold(n_splits=n_splits, shuffle=True),
    epochs=epochs,
    batch_size=batch_size,
    patience=patience,
    job_dir=job_dir,
    pretrained_dir=pretrained_dir,
    lr=lr,
    do_swa=do_swa,
    swa_anneal_epochs=swa_anneal_epochs,
    swa_start_epoch=swa_start_epoch,
    gradient_checkpointing=gradient_checkpointing,
    score_threshold=min_fold_score,
    accelerator=accelerator,
    gpus=gpus,
    stratify_labels=is_impossible,
)
# The objective returns the worst fold's mean qa.dice_coefficient; higher is better.
study = optuna.create_study(direction="maximize")
study.optimize(obj, n_trials=n_trials)

In [None]:
# Collect the per-trial summaries, rank them best-first by their worst fold
# score and persist the table next to the trained models.
history = (
    pd.DataFrame.from_records(obj.history)
    .sort_values("score_worst", ascending=False, ignore_index=True)
)
history.to_csv(f"{job_dir}/cv.csv", index=False)
history.head()

# Train final model on best hyperparameters

In [None]:
# Build the final model from the first row of `history` (the best trial when
# the table has been sorted by score_worst in descending order).
best = history.iloc[0]
scheduler_params = {"T_max": epochs}
swa_scheduler_params = None
# Use a cell-specific name here: the configuration's `swa_start_epoch` is the
# (low, high) search range consumed by MyObjective. Rebinding it to an int
# would break any re-run of the study cell.
best_swa_start_epoch = -1
if do_swa:
    # Row values are cast back to int before use.
    best_swa_start_epoch = int(best["swa_start_epoch"])
    swa_scheduler_params = {
        "swa_lr": best["lr"],
        "anneal_epochs": int(best["swa_anneal_epochs"]),
        "anneal_strategy": "cos"
    }
    # With SWA the scheduler horizon ends where averaging starts.
    scheduler_params = {"T_max": best_swa_start_epoch}
model = squad.Model(
    pretrained_dir=pretrained_dir,
    gradient_checkpointing=gradient_checkpointing,
    lr=best["lr"],
    scheduler_params=scheduler_params,
    swa_scheduler_params=swa_scheduler_params,
    swa_start_epoch=best_swa_start_epoch,
)
print(repr(model))

In [None]:
# Carve out a single stratified hold-out set: only the first of the 25 folds
# is used, the rest of the data becomes the training subset.
splitter = sklearn.model_selection.StratifiedKFold(n_splits=25, shuffle=True)
dummy = np.zeros(len(ds))
ti, vi = next(iter(splitter.split(dummy, is_impossible)))
train_ds = torch.utils.data.Subset(ds, ti)
val_ds = torch.utils.data.Subset(ds, vi)
print(f"len(train_ds)={len(train_ds)}, len(val_ds)={len(val_ds)}")

In [None]:
%%time
trainer = qa.HfTrainer(
    default_root_dir=job_dir,
    gpus=gpus,
    auto_select_gpus=True,
    accelerator=accelerator,
    max_epochs=epochs,
    callbacks=_callbacks(patience=patience),
    deterministic=True
)
trainer.fit(
    model,
    train_dataloader=DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers),
    val_dataloaders=DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers),
)