# Utils
Hyperparameters were first taken from the reference [code](https://github.com/fajri91/discourse_probing/tree/3f8c89c18a4eb217820667116ec17f6cec9b7e12/segment), but that code is mostly incorrect, so hyperparameters and parameters should be taken from the [paper](https://aclanthology.org/2021.naacl-main.301.pdf) instead.

Start with the GUM dataset only.

In [None]:
import os, json
import pandas as pd, numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl

from transformers import AutoTokenizer, AutoModel
from transformers import AdamW, get_linear_schedule_with_warmup

from datasets import Dataset, DatasetDict

from tqdm import tqdm
tqdm.pandas()

In [None]:
class DiscourseDataModule(pl.LightningDataModule):
    """Lightning data module for EDU (discourse unit) segmentation.

    Each JSON file is read by ``load_dataset``: its values (keyed "0", "1", ...)
    are lists of EDU strings. Every example is tokenized without special tokens,
    treating each EDU string as one pre-split "word", and labelled 1 on the last
    sub-word token of every EDU except the final one (0 everywhere else).
    """

    def __init__(self, mname, train_fl, dev_fl, test_fl):
        super().__init__()
        self.mname = mname
        # add_prefix_space=True — presumably required because byte-level BPE
        # tokenizers (RoBERTa/BART) are fed pre-split input below; confirm.
        self.tokenizer = AutoTokenizer.from_pretrained(mname, add_prefix_space=True)
        self.train_fl, self.validation_fl, self.test_fl = train_fl, dev_fl, test_fl
        self.train_batch_size = 8
        self.eval_batch_size = 32

    def load_dataset(self, fl):
        """Load one JSON split into a ``Dataset`` with a single "words" column.

        Assumes the JSON keys are the consecutive stringified indices
        "0".."n-1" (a missing key raises ``KeyError``).
        """
        with open(fl, encoding="utf-8") as f:
            dat = json.load(f)
        all_edus = [dat[str(k)] for k in range(len(dat))]
        return Dataset.from_dict({"words": all_edus})

    def setup(self, stage: str):
        self.dataset = DatasetDict({
            "train": self.load_dataset(self.train_fl),
            "validation": self.load_dataset(self.validation_fl),
            "test": self.load_dataset(self.test_fl),
        })

        for split in self.dataset.keys():
            # batch_size=None hands the whole split to convert_to_features in one
            # call, so padding=True gives every row of the split the same width.
            # With the default map batch size (1000) each chunk is padded to its
            # own width, and a DataLoader batch spanning two chunks cannot be
            # stacked by the default collate function.
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
                batch_size=None,
                remove_columns=["words"],
            )
            self.dataset[split].set_format(type="torch")
        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]

    def prepare_data(self):
        # Only instantiates the tokenizer (e.g. to fetch it once up front).
        AutoTokenizer.from_pretrained(self.mname, add_prefix_space=True)

    def train_dataloader(self):
        # Shuffling is safe because every row of a split has the same width.
        return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, shuffle=True)

    def val_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["validation"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def test_dataloader(self):
        # Serve the *test* split(s); previously the multi-split branch returned
        # the validation splits.
        test_splits = [x for x in self.dataset.keys() if "test" in x]
        if len(test_splits) == 1:
            return DataLoader(self.dataset[test_splits[0]], batch_size=self.eval_batch_size)
        elif len(test_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in test_splits]

    def convert_to_features(self, example_batch):
        """Tokenize a batch of EDU lists and attach boundary labels.

        Returns the tokenizer's features plus a "labels" list: for each row, 1 at
        the last token of every EDU except the final one, 0 elsewhere (including
        padding).
        """
        features = self.tokenizer(example_batch["words"], is_split_into_words=True,
                                  padding=True, truncation=False, add_special_tokens=False)
        labels = []
        for ft_ix, edus in enumerate(example_batch["words"]):
            word_ids = features.word_ids(batch_index=ft_ix)
            # Single pass: remember the last token position of every EDU index
            # (padding positions map to None and are skipped).
            last_token = {}
            for tok_ix, w_ix in enumerate(word_ids):
                if w_ix is not None:
                    last_token[w_ix] = tok_ix
            row = [0] * len(word_ids)
            # Ignore the last EDU: the end of the example is not a boundary to predict.
            for w_ix in range(len(edus) - 1):
                row[last_token[w_ix]] = 1
            labels.append(row)
        features["labels"] = labels
        return features

```python
import json
from tqdm import tqdm

# Ratio of EDUs to whitespace-separated words in the GUM training split. Each EDU
# ends on exactly one word, so this is the share of words that close an EDU.
with open("../input/edu-disrpt-2021-datasets-gum/gum_train_edus.json", encoding="utf-8") as f:
    dat = json.load(f)
tot_edus = 0
tot_words = 0
for key in tqdm(range(len(dat))):
    edus = dat[str(key)]
    tot_edus += len(edus)
    tot_words += sum(len(edu.split()) for edu in edus)
tot_edus/tot_words
```

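- Class 0 (no boundary) loss weight: 0.14986802621410866, the EDU/word ratio computed above, i.e. roughly the share of words that end an EDU.
- Class 1 (boundary) loss weight: 1 − class-0 weight ≈ 0.85.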

In [None]:
# Optimisation hyperparameters.
BATCH_SIZE = 32
N_EPOCHS = 20
WEIGHT_DECAY = 0.
LR = 1e-5
ADAM_EPS = 1e-8

# Module-level buffer of (epoch, gold spans, predicted spans) sample rows.
pred_labels = []

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Per-class BCE weights indexed by the 0/1 label: [no boundary, boundary].
loss_weights = torch.tensor([0.15, 0.85], dtype=torch.float32).to(device)

class ShallowSegmentor(pl.LightningModule):
    """Token-level EDU boundary classifier on a truncated pretrained encoder.

    Keeps only the first ``num_layers`` transformer layers of ``mname`` (just the
    encoder stack when ``is_bart``) and adds a dropout -> linear -> sigmoid head
    that scores every token as "last token of an EDU" (1) or not (0).
    """

    def __init__(self, mname, num_layers=6, is_bart=False):
        super().__init__()
        model = AutoModel.from_pretrained(mname)
        if is_bart:  # BART-base: encoder-only framework
            model.encoder.layers = model.encoder.layers[:num_layers]
            self.model = model.encoder
        else:  # RoBERTa-base
            model.encoder.layer = torch.nn.ModuleList(list(model.encoder.layer[:num_layers]))
            self.model = model
        self.linear = nn.Linear(self.model.config.hidden_size, 1)
        self.dropout = nn.Dropout(0.2)
        self.sigmoid = nn.Sigmoid()
        self.epoch_tracker = 0

    def forward(self, ip, mask_src):
        """Return per-token boundary probabilities, zeroed where ``mask_src`` is 0.

        ``ip``/``mask_src`` are assumed to be (batch, seq_len) token ids and
        attention mask — TODO confirm against the data module.
        """
        # The reference implementation ran the encoder under torch.no_grad()
        # (frozen); here the encoder is fine-tuned.
        top_vec = self.model(ip, attention_mask=mask_src)[0]
        vec = self.dropout(top_vec)
        # squeeze(-1) drops only the size-1 logit axis. A bare squeeze() would also
        # drop a size-1 seq axis, and (B,) * (B, 1) then broadcasts to (B, B).
        vec = self.linear(vec).squeeze(-1)
        return self.sigmoid(vec) * mask_src

    def _weighted_bce(self, output, label):
        """Summed BCE where each position is weighted by ``loss_weights[label]``."""
        # Move the global weights to the batch's device and pick one weight per token.
        batch_weights = loss_weights.to(label.device)[label]
        return F.binary_cross_entropy(output, label.float(), weight=batch_weights, reduction="sum")

    def training_step(self, batch, batch_idx):
        src, label, mask_src = batch["input_ids"], batch["labels"], batch["attention_mask"]
        output = self.forward(src, mask_src)
        loss = self._weighted_bce(output, label)
        self.log("train/loss", loss.item())
        return loss

    def _transform_predictions(self, tensor, mask_src):
        """Convert 0/1 boundary rows into lists of ``(start, end)`` index spans.

        Only the first ``mask_src.sum()`` positions of each row are read. A span
        is closed at every 1; a final span up to the last position is added when
        that position was not already a span end.
        """
        results = []
        array = tensor.detach().cpu().numpy()
        lengths = mask_src.sum(dim=-1).detach().cpu().numpy()
        for row, length in zip(array, lengths):
            cur_arr = row[:length]
            now = 0
            result = []
            for idy, value in enumerate(cur_arr):
                if value == 1:
                    result.append((now, idy))
                    now = idy
                if idy == len(cur_arr) - 1 and now != idy:
                    result.append((now, idy))
            results.append(result)
        return results

    def validation_step(self, batch, batch_idx):
        """Log the weighted validation loss; return gold spans, predicted spans, per-row F1."""
        src, label, mask_src = batch["input_ids"], batch["labels"], batch["attention_mask"]
        output = self.forward(src, mask_src)
        loss = self._weighted_bce(output, label)
        self.log("val/loss", loss.item())

        output = (output > 0.5).long()
        predictions = self._transform_predictions(output, mask_src)
        answers = self._transform_predictions(label, mask_src)
        f1s = []
        for prediction, answer in zip(predictions, answers):
            answer_set = set(answer)
            count = sum(item in answer_set for item in prediction)
            # Span F1 = 2*overlap / (|pred| + |gold|): 0 when exactly one side is
            # empty, and 1 when both are empty.
            if prediction or answer:
                f1 = 2.0 * count / (len(prediction) + len(answer))
            else:
                f1 = 1.0
            f1s.append(f1)
        return answers, predictions, f1s

    def validation_epoch_end(self, validation_step_outputs):
        """Log the mean per-row span F1 and store a few sample rows in ``pred_labels``."""
        global pred_labels
        f1_sum, f1_count = 0.0, 0
        for _, _, f1s in validation_step_outputs:
            f1_sum += sum(f1s)
            f1_count += len(f1s)
        if f1_count == 0:
            # Nothing was validated; avoid ZeroDivisionError / IndexError below.
            return
        self.log("val/f1", f1_sum / f1_count)

        # Keep the first 10 rows of the first batch only.
        answers, preds, _ = validation_step_outputs[0]
        pred_labels.extend(
            (self.epoch_tracker, str(gt), str(pr)) for gt, pr in zip(answers[:10], preds[:10])
        )
        self.epoch_tracker += 1

    def configure_optimizers(self):
        """AdamW over ALL trainable parameters with a per-step linear warmup/decay schedule."""
        no_decay = ["bias", "LayerNorm.weight"]
        # Use the real number of optimizer steps: the train loader's length (its
        # own batch size), not BATCH_SIZE, times the epoch budget.
        t_total = N_EPOCHS * len(gdm_data.train_dataloader())
        warmup_steps = int(0.1 * t_total)
        # self.named_parameters() (not self.model's) so the linear head is trained too.
        named = list(self.named_parameters())
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in named if not any(nd in n for nd in no_decay)],
                "weight_decay": WEIGHT_DECAY,
            },
            {"params": [p for n, p in named if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=LR, eps=ADAM_EPS)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
        )
        # t_total is counted in steps, so the scheduler must step every batch;
        # Lightning's default interval is "epoch".
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

In [None]:
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from kaggle_secrets import UserSecretsClient
import wandb
# Authenticate with Weights & Biases using the API key stored as a Kaggle secret.
user_secrets = UserSecretsClient()
wandbkey = user_secrets.get_secret("wandbkey")
wandb.login(key=wandbkey)

# Run-name prefix -> Hugging Face model id, in training order.
model_names = dict(bartb="facebook/bart-base", robb="roberta-base")
# At most one GPU: 1 if any CUDA device is visible, else 0.
AVAIL_GPUS = 1 if torch.cuda.device_count() > 0 else 0

In [None]:
# TODO: Save every model -- done
# TODO: Grad clip norm -- done
# TODO: W&B run per training job -- done
# TODO: Eval scores -- done (sort of?!)

# Directory holding the GUM EDU train/dev/test JSON files.
DATA_DIR = "../input/edu-disrpt-2021-datasets-gum"

# Train one segmentor per (backbone, encoder depth) pair, each in its own W&B run.
for mkey, mname in model_names.items():
    for num_layers in range(1, 7):
        local_name = f"{mkey}_{num_layers}layers"
        wandb_logger = WandbLogger(name=local_name, project="kaggle_feedback", log_model=True,
                                   notes="jdoesv/edu-segmentation-models-on-gum; V11",
                                   tags=["gum", "disc_segment", "smaller_layers"])

        checkpoint_callback = ModelCheckpoint(monitor='val/loss', dirpath=local_name,
                                              filename='epoch{epoch:02d}-val_loss{val/loss:.2f}',
                                              auto_insert_metric_name=False,
                                              #every_n_epochs=1,
                                              #save_top_k=N_EPOCHS
                                              )
        early_stop_callback = EarlyStopping(monitor="val/f1", patience=5, verbose=False, mode="max")

        pl.seed_everything(208)

        # Prepared and set up explicitly here; ShallowSegmentor.configure_optimizers
        # reads this global gdm_data.
        gdm_data = DiscourseDataModule(mname,
                                       f"{DATA_DIR}/gum_train_edus.json",
                                       f"{DATA_DIR}/gum_dev_edus.json",
                                       f"{DATA_DIR}/gum_test_edus.json")
        gdm_data.prepare_data()
        gdm_data.setup("fit")

        segmentor = ShallowSegmentor(mname, num_layers=num_layers, is_bart="bart" in mkey)
        wandb_logger.watch(segmentor)
        trainer = pl.Trainer(logger=wandb_logger,
                             max_epochs=N_EPOCHS,
                             gpus=AVAIL_GPUS,
                             callbacks=[checkpoint_callback, early_stop_callback],
                             gradient_clip_val=1.)
        try:
            trainer.fit(segmentor, gdm_data)

            columns = ["epoch_num", "gt", "pred"]
            wandb_logger.log_text(key="val/samples", columns=columns, data=pred_labels)
        finally:
            # Always reset the shared sample buffer and close the W&B run, even if
            # training raised, so no run is left open.
            pred_labels = []
            wandb.finish()
