In [2]:
import os
import pathlib
import argparse
import collections

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as tF
import pytorch_lightning as pl
import tokenizers
import datasets


# Debug switch read by `main`: when True the Trainer is created with
# `overfit_batches=0.001` and the final `trainer.test()` call is skipped.
DEBUG_RUN = False

In [3]:
class HFDataset(torch.utils.data.Dataset):
    """Map-style ``torch`` dataset that delegates length and indexing to ``hfdf``.

    The wrapped object only needs to support ``len()`` and ``[]``; it is stored
    by reference on ``self.hfdf`` without copying or preprocessing.
    """

    def __init__(self, hfdf):
        self.hfdf = hfdf

    def __len__(self):
        """Return the number of items in the wrapped object."""
        wrapped = self.hfdf
        return len(wrapped)

    def __getitem__(self, idx):
        """Return the wrapped object's item at position ``idx``."""
        wrapped = self.hfdf
        return wrapped[idx]

In [8]:
class LitSegmenterBaseline(pl.LightningModule):
    """Token-level classifier: embedding -> (bi)LSTM -> linear layer.

    The module owns its tokenizer and its datasets (loaded from disk in
    ``__init__``) and builds its own train/test dataloaders, so
    ``trainer.fit(model)`` can be called without passing any data.

    Parameters
    ----------
    hidden_size : int
        Hidden size of each LSTM direction.
    tokenizer_uri : str
        Path to a ``tokenizers`` JSON file, loaded with ``Tokenizer.from_file``.
    dataset_uri : str or list of str
        One path, or several paths, readable by ``datasets.load_from_disk``.
        When several are given, the splits of all datasets are concatenated
        split by split, using the split names of the first dataset.
    batch_size : int
        Batch size used by every dataloader built by this module.
    num_layers : int
        Number of stacked LSTM layers. Inter-layer dropout of 0.1 is enabled
        only when there is more than one layer.
    bidirectional : bool
        Whether the LSTM is bidirectional.
    num_classes : int
        Size of the output layer (number of token classes).
    pad_token : str
        Token whose vocabulary id is used to pad ``input_ids``; falls back to
        id 0 when the token is not in the vocabulary.
    embedding_dim : int
        Width of the token embeddings (and LSTM input size).
    num_workers : int
        Number of worker processes for every dataloader.
    """

    # Label value used to pad target sequences. It is passed to the loss as
    # `ignore_index` and positions carrying it are excluded from the metrics.
    IGNORE_INDEX = -100

    def __init__(
        self,
        hidden_size: int,
        tokenizer_uri: str,
        dataset_uri: "str | list[str]",
        batch_size: int,
        num_layers: int = 1,
        bidirectional: bool = True,
        num_classes: int = 4,
        pad_token: str = "[PAD]",
        embedding_dim: int = 768,
        num_workers: int = 8,
    ):
        super().__init__()

        self.tokenizer = tokenizers.Tokenizer.from_file(tokenizer_uri)

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pad_id = self.tokenizer.get_vocab().get(pad_token, 0)

        self.hfdf = self._load_datasets(dataset_uri)
        print(self.hfdf)

        self.embeddings = nn.Embedding(
            num_embeddings=self.tokenizer.get_vocab_size(),
            embedding_dim=embedding_dim,
            padding_idx=self.pad_id,
        )

        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers.
            dropout=0.0 if num_layers == 1 else 0.1,
            bidirectional=bidirectional,
            proj_size=0,
        )

        # A bidirectional LSTM concatenates both directions on the last axis.
        self.lin_out = nn.Linear(
            (1 + int(bidirectional)) * hidden_size,
            num_classes,
        )

    @staticmethod
    def _load_datasets(dataset_uri):
        """Load one dataset, or load several and concatenate them per split.

        Raises
        ------
        ValueError
            If ``dataset_uri`` is an empty collection.
        """
        if isinstance(dataset_uri, str):
            return datasets.load_from_disk(dataset_uri)

        dfs = [datasets.load_from_disk(uri) for uri in dataset_uri]

        if not dfs:
            raise ValueError("'dataset_uri' must contain at least one dataset path.")

        # Split names are taken from the first dataset; every other dataset is
        # expected to provide the same splits.
        merged = {
            key: datasets.concatenate_datasets([df[key] for df in dfs])
            for key in dfs[0].keys()
        }

        return datasets.DatasetDict(merged)

    def fn_pad_sequences(self, batch):
        """Collate dataset items into padded ``(input_ids, labels)`` tensors.

        Each item must provide ``"input_ids"`` and ``"labels"``. Inputs are
        padded with ``self.pad_id`` and labels with ``IGNORE_INDEX``, both
        batch-first.

        This is a method rather than a closure created in ``__init__`` so that
        the module (and dataloaders referencing this function) stay picklable.
        """
        X = [torch.tensor(item["input_ids"], dtype=torch.int) for item in batch]
        y = [torch.tensor(item["labels"]) for item in batch]

        X = nn.utils.rnn.pad_sequence(X, padding_value=self.pad_id, batch_first=True)
        y = nn.utils.rnn.pad_sequence(y, padding_value=self.IGNORE_INDEX, batch_first=True)

        return X, y

    def forward(self, X):
        """Return per-token class logits.

        ``X`` is either a batch-first tensor of token ids or a single raw
        string, which is tokenized here and treated as a batch of one.
        """
        out = X

        if isinstance(out, str):
            # `tokenizers.Tokenizer` is not callable; `encode` returns an
            # Encoding whose `ids` are the token ids. Wrap them in a list to
            # add the batch dimension expected by the batch-first LSTM.
            token_ids = self.tokenizer.encode(out).ids
            out = torch.tensor([token_ids], dtype=torch.int, device=self.device)

        out = self.embeddings(out)
        out, *_ = self.lstm(out)
        out = self.lin_out(out)

        return out

    @staticmethod
    def _compute_pred_metrics(y_preds, y, phase: str) -> dict[str, float]:
        """Compute the loss and precision/recall metrics for one batch.

        Logits and targets are flattened to one row per token. Positions whose
        target equals ``IGNORE_INDEX`` are ignored by the loss and removed
        before computing the metrics.

        Returns a dict with the loss (a tensor, keyed ``"loss"`` for the train
        phase and ``"<phase>_loss"`` otherwise), per-class precision/recall,
        their macro averages, and an F1 score computed from the macro-averaged
        precision and recall (not the mean of per-class F1 scores).
        """
        ignore_index = LitSegmenterBaseline.IGNORE_INDEX
        # The number of classes is the width of the logits.
        num_classes = y_preds.shape[-1]

        y_preds = y_preds.reshape(-1, num_classes)
        y = y.reshape(-1)

        loss = F.cross_entropy(input=y_preds, target=y, ignore_index=ignore_index)

        # Boolean mask instead of a Python loop over every token.
        keep = y != ignore_index
        valid_preds = y_preds[keep]
        valid_targets = y[keep]

        per_cls_recall = tF.recall(
            preds=valid_preds,
            target=valid_targets,
            num_classes=num_classes,
            average=None,
        )

        per_cls_precision = tF.precision(
            preds=valid_preds,
            target=valid_targets,
            num_classes=num_classes,
            average=None,
        )

        macro_precision = float(per_cls_precision.mean().item())
        macro_recall = float(per_cls_recall.mean().item())
        # The small epsilon avoids a division by zero when both terms are zero.
        macro_f1_score = (
            2.0 * macro_precision * macro_recall / (1e-8 + macro_precision + macro_recall)
        )

        out = {
            f"{(phase + '_') if phase != 'train' else ''}loss": loss,
            **{f"{phase}_cls_{i}_precision": float(val) for i, val in enumerate(per_cls_precision)},
            **{f"{phase}_cls_{i}_recall": float(val) for i, val in enumerate(per_cls_recall)},
            f"{phase}_macro_precision": macro_precision,
            f"{phase}_macro_recall": macro_recall,
            f"{phase}_macro_f1_score": macro_f1_score,
        }

        return out

    @staticmethod
    def _agg_stats(step_outputs):
        """Average every metric over the step outputs of one epoch.

        Returns a dict mapping ``"avg_<key>"`` to the unweighted mean of that
        key across steps (batches are not weighted by their size).
        """
        agg_items = collections.defaultdict(list)

        for items in step_outputs:
            for key, val in items.items():
                if not isinstance(val, torch.Tensor):
                    val = torch.tensor(val)

                agg_items[key].append(val)

        return {
            f"avg_{key}": float(torch.stack(vals).mean().item())
            for key, vals in agg_items.items()
        }

    def _shared_step(self, batch, phase: str, prog_bar: bool = False):
        """Run the model on one batch, log its metrics per epoch and return them."""
        X, y = batch
        y_preds = self(X)

        out = self._compute_pred_metrics(y_preds, y, phase=phase)

        self.log_dict(
            out,
            on_step=False,
            on_epoch=True,
            prog_bar=prog_bar,
            logger=True,
        )

        return out

    def _log_epoch_stats(self, step_outputs):
        """Log the ``avg_*`` aggregates of one epoch's step outputs."""
        out = self._agg_stats(step_outputs)

        self.log_dict(
            out,
            on_step=False,
            on_epoch=True,
            logger=True,
        )

    def training_step(self, batch, batch_idx: int):
        """Compute and log the training metrics; the returned dict holds ``"loss"``."""
        return self._shared_step(batch, phase="train", prog_bar=True)

    def training_epoch_end(self, training_step_outputs):
        """Log epoch-level averages of the training metrics."""
        self._log_epoch_stats(training_step_outputs)

    # NOTE(review): the validation hooks below are disabled in the original
    # notebook and are kept verbatim; the reason is not recorded here.
#     def validation_step(self, batch, batch_idx: int):
#         X, y = batch
#         y_preds = self.forward(X)

#         out = self._compute_pred_metrics(y_preds, y, phase="val")

#         self.log_dict(
#             out,
#             on_step=False,
#             on_epoch=True,
#             logger=True,
#         )

#         return out

#     def validation_epoch_end(self, validation_step_outputs):
#         out = self._agg_stats(validation_step_outputs)

#         self.log_dict(
#             out,
#             on_step=False,
#             on_epoch=True,
#             logger=True,
#         )

    def test_step(self, batch, batch_idx: int):
        """Compute and log the test metrics for one batch."""
        return self._shared_step(batch, phase="test")

    def test_epoch_end(self, test_step_outputs):
        """Log epoch-level averages of the test metrics."""
        self._log_epoch_stats(test_step_outputs)

    def configure_optimizers(self):
        """Adam (lr=5e-4) with a StepLR schedule (x0.75 every 2 scheduler steps)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=5e-4)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.75)
        return [optimizer], [lr_scheduler]

    def _build_dataloader(self, split: str, shuffle: bool):
        """Build a dataloader over ``self.hfdf[split]`` with padded batches."""
        return torch.utils.data.DataLoader(
            dataset=HFDataset(self.hfdf[split]),
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self.fn_pad_sequences,
        )

    def train_dataloader(self):
        """Shuffled dataloader over the ``"train"`` split."""
        return self._build_dataloader("train", shuffle=True)

#     def val_dataloader(self):
#         df_eval = HFDataset(self.hfdf["eval"])

#         eval_dataloader = torch.utils.data.DataLoader(
#             dataset=df_eval,
#             batch_size=self.batch_size,
#             shuffle=False,
#             num_workers=8,
#             collate_fn=self.fn_pad_sequences,
#         )

#         return eval_dataloader

    def test_dataloader(self):
        """Unshuffled dataloader over the ``"test"`` split."""
        return self._build_dataloader("test", shuffle=False)

In [12]:
def main(args):
    """Train and test one `LitSegmenterBaseline` per configuration.

    Parameters
    ----------
    args : argparse.Namespace
        Trainer options produced by a parser extended with
        ``pl.Trainer.add_argparse_args``.

    For every ``(hidden_size, batch_size)`` pair a fresh model is built and
    fitted. Unless ``DEBUG_RUN`` is set, the model is then evaluated on the
    test split.
    """
    # Gradients are accumulated so that every optimizer step sees roughly this
    # many examples, regardless of the per-step batch size.
    target_effective_batch_size = 128

    # (hidden_size, batch_size) pairs; re-enable entries to sweep them.
    configs = [
        # (512, 32),
        (256, 32),
        # (128, 64),
        # (64, 64),
        # (32, 64),
    ]

    for hidden_size, batch_size in configs:
        # Never drop below 1: a batch size above the target would otherwise
        # yield 0 accumulation steps.
        accumulate_grad_batches = max(1, target_effective_batch_size // batch_size)

        model = LitSegmenterBaseline(
            hidden_size=hidden_size,
            batch_size=batch_size,
            tokenizer_uri="../tokenizers/6000_subwords/tokenizer.json",
            dataset_uri=[
                "./final_curated_dataset_for_training",
                "../data/refined_datasets/ccjs_segmentados_train_test_splits/ccjs_segmentados_train_test_splits_curados",
                "../data/refined_datasets/emendas_variadas_segmentadas_train_test_splits/emendas_variadas_train_test_splits",
            ],
        )

        trainer = pl.Trainer.from_argparse_args(
            args,
            overfit_batches=0.001 if DEBUG_RUN else 0.0,
            accumulate_grad_batches=accumulate_grad_batches,
        )

        trainer.fit(model)

        if not DEBUG_RUN:
            # Called without a model on purpose: the trainer then picks the
            # model/checkpoint itself (the run log shows it restoring a saved
            # checkpoint before testing).
            trainer.test()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)

    # An explicit argument list is parsed so the cell does not depend on
    # `sys.argv`. Option names are spelled out in full (`--gpus`, not `--gpu`)
    # instead of relying on argparse's prefix abbreviation, which would break
    # as soon as another option starting with the same prefix exists.
    args = parser.parse_args(
        """
        --gpus 1
        --max_epochs 10
        --log_every_n_steps 1000
        --precision 32
    """.split()
    )

    main(args)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type      | Params
-----------------------------------------
0 | embeddings | Embedding | 4.6 M 
1 | lstm       | LSTM      | 2.1 M 
2 | lin_out    | Linear    | 2.1 K 
-----------------------------------------
6.7 M     Trainable params
0         Non-trainable params
6.7 M     Total params
26.845    Total estimated model params size (MB)


DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 159808
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2602
    })
})


Training: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
Restoring states from the checkpoint path at /media/nvme/segmentador/notebooks/lightning_logs/version_15/checkpoints/epoch=9-step=12489.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /media/nvme/segmentador/notebooks/lightning_logs/version_15/checkpoints/epoch=9-step=12489.ckpt


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_cls_0_precision': 0.998159646987915,
 'avg_test_cls_0_recall': 0.999201774597168,
 'avg_test_cls_1_precision': 0.9796888828277588,
 'avg_test_cls_1_recall': 0.9497243165969849,
 'avg_test_cls_2_precision': 0.9082029461860657,
 'avg_test_cls_2_recall': 0.8206210136413574,
 'avg_test_cls_3_precision': 0.6968904733657837,
 'avg_test_cls_3_recall': 0.5787816643714905,
 'avg_test_loss': 0.01663857512176037,
 'avg_test_macro_f1_score': 0.8640515208244324,
 'avg_test_macro_precision': 0.8957353830337524,
 'avg_test_macro_recall': 0.8370822072029114,
 'test_cls_0_precision': 0.998175323009491,
 'test_cls_0_recall': 0.9992001056671143,
 'test_cls_1_precision': 0.9800203442573547,
 'test_cls_1_recall': 0.95093834400177,
 'test_cls_2_precision': 0.9074270725250244,
 'test_cls_2_recall': 0.8203122019767761,
 'test_cls_3_precision': 0.6964413523674011,
 'test_cls_3_recall': 0.580051

# Results

## 512 hidden units (lr_scheduler etc)
{'avg_test_cls_0_precision': 0.9982964396476746,
 'avg_test_cls_0_recall': 0.9993011951446533,
 'avg_test_cls_1_precision': 0.9845898747444153,
 'avg_test_cls_1_recall': 0.9579233527183533,
 'avg_test_cls_2_precision': 0.9040074944496155,
 'avg_test_cls_2_recall': 0.8204734921455383,
 'avg_test_cls_3_precision': 0.6785829663276672,
 'avg_test_cls_3_recall': 0.5864138007164001,
 'avg_test_loss': 0.01806369051337242,
 'avg_test_macro_f1_score': 0.8640120029449463,
 'avg_test_macro_precision': 0.8913691639900208,
 'avg_test_macro_recall': 0.8410279750823975,
 'test_cls_0_precision': 0.9982959032058716,
 'test_cls_0_recall': 0.9993012547492981,
 'test_cls_1_precision': 0.9845767617225647,
 'test_cls_1_recall': 0.9579482078552246,
 'test_cls_2_precision': 0.903980016708374,
 'test_cls_2_recall': 0.820344090461731,
 'test_cls_3_precision': 0.6786746978759766,
 'test_cls_3_recall': 0.5857861638069153,
 'test_loss': 0.018070800229907036,
 'test_macro_f1_score': 0.8639229536056519,
 'test_macro_precision': 0.8913817405700684,
 'test_macro_recall': 0.8408451080322266}

## 512 hidden units (regular setup)
{'avg_test_cls_0_precision': 0.9982461929321289,
 'avg_test_cls_0_recall': 0.9994046688079834,
 'avg_test_cls_1_precision': 0.986571729183197,
 'avg_test_cls_1_recall': 0.9582132697105408,
 'avg_test_cls_2_precision': 0.9132260084152222,
 'avg_test_cls_2_recall': 0.8215959072113037,
 'avg_test_cls_3_precision': 0.7250827550888062,
 'avg_test_cls_3_recall': 0.535641610622406,
 'avg_test_loss': 0.010228903032839298,
 'avg_test_macro_f1_score': 0.8643761277198792,
 'avg_test_macro_precision': 0.905781626701355,
 'avg_test_macro_recall': 0.8287138938903809,
 'test_cls_0_precision': 0.9982464909553528,
 'test_cls_0_recall': 0.9994055032730103,
 'test_cls_1_precision': 0.9865930676460266,
 'test_cls_1_recall': 0.9582584500312805,
 'test_cls_2_precision': 0.9132982492446899,
 'test_cls_2_recall': 0.8214690089225769,
 'test_cls_3_precision': 0.7250284552574158,
 'test_cls_3_recall': 0.5351755023002625,
 'test_loss': 0.010232209227979183,
 'test_macro_f1_score': 0.8643065094947815,
 'test_macro_precision': 0.9057917594909668,
 'test_macro_recall': 0.8285772800445557}

## 256 hidden units (lr_scheduler)
{'avg_test_cls_0_precision': 0.9983095526695251,
 'avg_test_cls_0_recall': 0.9992923736572266,
 'avg_test_cls_1_precision': 0.9841660857200623,
 'avg_test_cls_1_recall': 0.9603431820869446,
 'avg_test_cls_2_precision': 0.906660258769989,
 'avg_test_cls_2_recall': 0.8225072026252747,
 'avg_test_cls_3_precision': 0.6957836151123047,
 'avg_test_cls_3_recall': 0.5835158228874207,
 'avg_test_loss': 0.01692836359143257,
 'avg_test_macro_f1_score': 0.8669429421424866,
 'avg_test_macro_precision': 0.8962297439575195,
 'avg_test_macro_recall': 0.8414146900177002,
 'test_cls_0_precision': 0.9983090758323669,
 'test_cls_0_recall': 0.9992936849594116,
 'test_cls_1_precision': 0.9841693043708801,
 'test_cls_1_recall': 0.9603489637374878,
 'test_cls_2_precision': 0.9068716764450073,
 'test_cls_2_recall': 0.822350025177002,
 'test_cls_3_precision': 0.6960282921791077,
 'test_cls_3_recall': 0.58326256275177,
 'test_loss': 0.016935868188738823,
 'test_macro_f1_score': 0.8669469952583313,
 'test_macro_precision': 0.8963446021080017,
 'test_macro_recall': 0.8413137197494507}

## 256 hidden units (regular setup)
{'avg_test_cls_0_precision': 0.9983312487602234,
 'avg_test_cls_0_recall': 0.9993693232536316,
 'avg_test_cls_1_precision': 0.9855096340179443,
 'avg_test_cls_1_recall': 0.9620329737663269,
 'avg_test_cls_2_precision': 0.9168782830238342,
 'avg_test_cls_2_recall': 0.8202478289604187,
 'avg_test_cls_3_precision': 0.7135547995567322,
 'avg_test_cls_3_recall': 0.5649248957633972,
 'avg_test_loss': 0.01007597055286169,
 'avg_test_macro_f1_score': 0.8680000305175781,
 'avg_test_macro_precision': 0.9035684466362,
 'avg_test_macro_recall': 0.8366436958312988,
 'test_cls_0_precision': 0.998330295085907,
 'test_cls_0_recall': 0.999370276927948,
 'test_cls_1_precision': 0.9855092763900757,
 'test_cls_1_recall': 0.9620238542556763,
 'test_cls_2_precision': 0.9171323776245117,
 'test_cls_2_recall': 0.8202000856399536,
 'test_cls_3_precision': 0.7137227058410645,
 'test_cls_3_recall': 0.5646311640739441,
 'test_loss': 0.010080181993544102,
 'test_macro_f1_score': 0.8680046200752258,
 'test_macro_precision': 0.9036736488342285,
 'test_macro_recall': 0.8365563750267029}


## 128 hidden units
{'avg_test_cls_0_precision': 0.99830561876297,
 'avg_test_cls_0_recall': 0.999434769153595,
 'avg_test_cls_1_precision': 0.9863963723182678,
 'avg_test_cls_1_recall': 0.960133969783783,
 'avg_test_cls_2_precision': 0.9214738011360168,
 'avg_test_cls_2_recall': 0.8192890286445618,
 'avg_test_cls_3_precision': 0.7188460826873779,
 'avg_test_cls_3_recall': 0.5316470861434937,
 'avg_test_loss': 0.008982779458165169,
 'avg_test_macro_f1_score': 0.8646918535232544,
 'avg_test_macro_precision': 0.9062554240226746,
 'avg_test_macro_recall': 0.8276262879371643,
 'test_cls_0_precision': 0.9983052611351013,
 'test_cls_0_recall': 0.9994353652000427,
 'test_cls_1_precision': 0.9863983988761902,
 'test_cls_1_recall': 0.9601314067840576,
 'test_cls_2_precision': 0.9216869473457336,
 'test_cls_2_recall': 0.8193143010139465,
 'test_cls_3_precision': 0.7187364101409912,
 'test_cls_3_recall': 0.5313534736633301,
 'test_loss': 0.008986305445432663,
 'test_macro_f1_score': 0.8646671175956726,
 'test_macro_precision': 0.9062817692756653,
 'test_macro_recall': 0.8275586366653442}
 
 ## 64 hidden units
 {'avg_test_cls_0_precision': 0.9981099367141724,
 'avg_test_cls_0_recall': 0.9995588660240173,
 'avg_test_cls_1_precision': 0.9872504472732544,
 'avg_test_cls_1_recall': 0.9569665193557739,
 'avg_test_cls_2_precision': 0.938145637512207,
 'avg_test_cls_2_recall': 0.7984744310379028,
 'avg_test_cls_3_precision': 0.8033229112625122,
 'avg_test_cls_3_recall': 0.4425460398197174,
 'avg_test_loss': 0.009166688658297062,
 'avg_test_macro_f1_score': 0.8598015904426575,
 'avg_test_macro_precision': 0.9317071437835693,
 'avg_test_macro_recall': 0.7993864417076111,
 'test_cls_0_precision': 0.9981100559234619,
 'test_cls_0_recall': 0.9995586276054382,
 'test_cls_1_precision': 0.9872363209724426,
 'test_cls_1_recall': 0.9569570422172546,
 'test_cls_2_precision': 0.938208818435669,
 'test_cls_2_recall': 0.7985380291938782,
 'test_cls_3_precision': 0.8028951287269592,
 'test_cls_3_recall': 0.4424935579299927,
 'test_loss': 0.00917029194533825,
 'test_macro_f1_score': 0.8597609400749207,
 'test_macro_precision': 0.931612491607666,
 'test_macro_recall': 0.799386739730835}
 
 ## 32 hidden units
 {'avg_test_cls_0_precision': 0.997940182685852,
 'avg_test_cls_0_recall': 0.9996035099029541,
 'avg_test_cls_1_precision': 0.9890697002410889,
 'avg_test_cls_1_recall': 0.9552433490753174,
 'avg_test_cls_2_precision': 0.9365561604499817,
 'avg_test_cls_2_recall': 0.785433292388916,
 'avg_test_cls_3_precision': 0.7959164381027222,
 'avg_test_cls_3_recall': 0.31288468837738037,
 'avg_test_loss': 0.009863470681011677,
 'avg_test_macro_f1_score': 0.8375892043113708,
 'avg_test_macro_precision': 0.92987060546875,
 'avg_test_macro_recall': 0.7632912397384644,
 'test_cls_0_precision': 0.9979405999183655,
 'test_cls_0_recall': 0.9996036887168884,
 'test_cls_1_precision': 0.9890552759170532,
 'test_cls_1_recall': 0.9552522897720337,
 'test_cls_2_precision': 0.9367883801460266,
 'test_cls_2_recall': 0.7855103015899658,
 'test_cls_3_precision': 0.7954726219177246,
 'test_cls_3_recall': 0.31284019351005554,
 'test_loss': 0.009864856489002705,
 'test_macro_f1_score': 0.8375713229179382,
 'test_macro_precision': 0.9298143982887268,
 'test_macro_recall': 0.7633016109466553}