In [None]:
from collections import Counter, defaultdict
import re
import json
import pickle
from io import StringIO
from tqdm import tqdm
from pathlib import Path

import numpy as np
import pandas as pd

import seaborn as sns
from matplotlib import pyplot as plt

In [None]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, random_split

# torch.manual_seed(42)


In [None]:
# Use float64 as the default floating dtype for every tensor / module weight
# created from here on without an explicit dtype (e.g. MyReg's Linear layer).
torch.set_default_dtype(torch.double)

### model

In [None]:
class MyReg(nn.Module):
    """Single linear unit followed by a ReLU, used as a regression head.

    ``layers[0]`` is the ``nn.Linear(input_size, 1)`` layer; later cells read
    its weights directly. The trailing ReLU clamps every prediction at 0, so an
    example whose pre-activation is negative contributes no gradient.
    """

    def __init__(self, input_size):
        super().__init__()
        linear = nn.Linear(input_size, 1)
        activation = nn.ReLU()
        self.layers = nn.Sequential(linear, activation)

    def forward(self, x):
        prediction = self.layers(x)
        return prediction


In [None]:
class EarlyStopping:
    """Patience-based early stopping driven by the validation-loss history.

    After each epoch the instance is called with the list of validation losses
    so far (and optionally the epoch's training loss). An epoch is "bad" when

    * the validation loss did not improve on the previous epoch by at least
      ``min_delta`` (``prev - curr < min_delta``), or
    * the validation/training gap exceeds ``max_beta`` (``curr - train_loss >
      max_beta``); this check is skipped when ``train_loss`` is None.

    Bad epochs increment ``counter``; a good epoch resets it to 0, so
    ``tolerance`` is the number of *consecutive* bad epochs allowed. Once
    ``counter >= tolerance`` the ``early_stop`` flag becomes True.
    """

    def __init__(self, tolerance=5, min_delta=.001, max_beta=.1):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.max_beta = max_beta

        self.counter = 0
        self.early_stop = False

    def __call__(self, val_losses, train_loss=None):
        # Two validation losses are needed to measure an improvement.
        if len(val_losses) < 2:
            return

        curr, prev = val_losses[-1], val_losses[-2]
        no_improvement = (prev - curr) < self.min_delta
        # train_loss is optional: without it only the improvement test applies
        # (previously `curr - None` raised TypeError).
        overfitting = train_loss is not None and (curr - train_loss) > self.max_beta

        if no_improvement or overfitting:
            self.counter += 1
        else:
            # A good epoch ends the streak of bad ones.
            self.counter = 0

        if self.counter >= self.tolerance:
            self.early_stop = True


In [None]:
class MyTrainer:
    """Fit a ``MyReg`` head on a dataset using a random 60/20/20 split.

    The dataset must yield ``(features, target)`` pairs. Training uses
    mini-batches of ``batch_size``; the validation and test loaders use the
    ``DataLoader`` default batch size of 1, which ``gasp`` relies on when it
    calls ``.item()`` per example.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Source of ``(features, target)`` pairs.
    input_size : int
        Feature dimension passed to ``MyReg``.
    batch_size : int, default 32
        Training mini-batch size.
    lr : float, default 1e-3
        AdamW learning rate.
    """

    def __init__(self, dataset, input_size, batch_size=32, lr=1e-3) -> None:
        self.train_losses = list()
        self.val_losses = list()
        self.batch_size = batch_size
        # Split before building the model (as before) so the RNG draws for the
        # split and for the weight initialisation keep the same order.
        self.load_data(dataset)

        self.model = MyReg(input_size=input_size)
        self.lossf = nn.L1Loss()
        self.optif = AdamW(self.model.parameters(), lr=lr)

    def load_data(self, dataset):
        """Randomly split ``dataset`` 60/20/20 and build the three loaders."""
        train_dt, val_dt, test_dt = random_split(dataset, [0.6, 0.2, 0.2])
        self._train_dl = DataLoader(
            train_dt,
            batch_size=self.batch_size,
            shuffle=True,
        )
        self._val_dl = DataLoader(val_dt, shuffle=True)
        self._test_dl = DataLoader(test_dt, shuffle=True)

    def _to_model(self, inputs):
        """Move ``inputs`` to the model's device and dtype.

        Features may live on another device (the RoBERTa features are built on
        CUDA while the model is created on the CPU) or use another float dtype;
        casting here avoids device/dtype mismatch errors in ``forward``.
        """
        ref = next(self.model.parameters())
        return inputs.to(device=ref.device, dtype=ref.dtype)

    def _batch_loss(self, inputs, targets):
        """Return the loss of one batch (graph is built only if grad is enabled)."""
        outputs = self.model(self._to_model(inputs))
        # The head outputs (batch, 1); compare its single column with targets
        # cast to the output's float dtype/device (labels may be ints).
        targets = targets.to(device=outputs.device, dtype=outputs.dtype)
        return self.lossf(outputs[:, 0], targets)

    def _train_one(self):
        """Run one training epoch and return the mean per-batch loss."""
        # fit() leaves the model in eval mode after validation; switch back.
        self.model.train()
        epoch_loss = 0.0
        for inputs, targets in self._train_dl:
            self.optif.zero_grad()
            batch_loss = self._batch_loss(inputs, targets)
            batch_loss.backward()
            self.optif.step()
            epoch_loss += batch_loss.item()
        return epoch_loss / len(self._train_dl)

    def _val_one(self):
        """Return the mean validation loss, computed without gradient tracking."""
        self.model.eval()
        epoch_loss = 0.0
        with torch.no_grad():
            for inputs, targets in self._val_dl:
                epoch_loss += self._batch_loss(inputs, targets).item()
        return epoch_loss / len(self._val_dl)

    def fit(self, epochs=1, early_stopping=None):
        """Train for up to ``epochs`` epochs, recording train/validation losses.

        ``early_stopping``, if given, is called as
        ``early_stopping(self.val_losses, train_loss)`` after every epoch and
        training stops as soon as its ``early_stop`` attribute is true.
        """
        for epoch in range(epochs):
            t_loss = self._train_one()
            self.train_losses.append(t_loss)

            v_loss = self._val_one()
            self.val_losses.append(v_loss)
            print(f"validation loss {v_loss:.3f} at epoch {epoch}")

            if early_stopping is not None:
                early_stopping(self.val_losses, t_loss)
                if early_stopping.early_stop:
                    print("early stopping...")
                    break

    def gasp(self):
        """Return ``(prediction, target)`` Python-scalar pairs for the test split."""
        self.model.eval()
        with torch.no_grad():
            return [
                (self.model(self._to_model(inputs)).item(), target.item())
                for inputs, target in self._test_dl
            ]

In [None]:
def metrics(predictions, labels):
    """Print and return binary-classification metrics.

    Parameters
    ----------
    predictions, labels : iterable of 0/1
        Parallel sequences of predicted and true classes (1 = positive).

    Returns
    -------
    dict
        ``{"f1", "precision", "recall", "accuracy"}`` as floats. Any ratio whose
        denominator is zero (no predicted positives, no actual positives,
        precision + recall == 0, empty input) is reported as 0.0 instead of
        raising ZeroDivisionError.
    """

    def _ratio(num, den):
        return num / den if den else 0.0

    pairs = list(zip(predictions, labels))
    n = len(pairs)

    # Positive counts are taken from the paired values so iterators work too.
    p_true = sum(label for _, label in pairs)
    p_hat = sum(pred for pred, _ in pairs)

    tp = pairs.count((1, 1))
    tn = pairs.count((0, 0))

    precision = _ratio(tp, p_hat)
    recall = _ratio(tp, p_true)
    f1 = _ratio(2 * (precision * recall), precision + recall)
    accuracy = _ratio(tp + tn, n)

    print(f"F1: {f1:.2f}")
    print(f"Precision: {precision:.2f}")
    print(f"Recall: {recall:.2f}")
    print(f"Accuracy: {accuracy:.2f}")

    return {"f1": f1, "precision": precision, "recall": recall, "accuracy": accuracy}


### data

Load the three task-0 TSV files and merge them into a single, de-duplicated dataset.

In [None]:
def _concat_tsv_texts(texts):
    """Concatenate the contents of several TSV files into one table text.

    The first line of the first chunk is taken as the header. If a later chunk
    starts with the identical line, that repeated header is dropped so it is
    not parsed as a data row. Every non-empty chunk is newline-terminated so
    the last row of one file cannot fuse with the first row of the next.
    """
    header = None
    parts = []
    for text in texts:
        first_line, _, rest = text.partition("\n")
        if header is None:
            header = first_line
        elif first_line == header:
            text = rest
        if text and not text.endswith("\n"):
            text += "\n"
        parts.append(text)
    return "".join(parts)


# read_text uses universal newlines, like the previous open(..., "rt").
data_text = _concat_tsv_texts(
    Path("datasets", "task_0", f"{filename}.tsv").read_text(encoding="utf8")
    for filename in sorted(["ds", "ts_hs", "ts_ht"])
)

df = pd.read_csv(StringIO(data_text), sep="\t")
df = df.drop_duplicates().reset_index(names="old_idx").reset_index(names="new_idx")


In [None]:
# Labels are used as integer classes (0/1) downstream.
df["label"] = df["label"].astype(int)

In [None]:
import liwc

# Word -> LIWC categories parser, plus the ordered list of category names.
to_liwc, categories = liwc.load_token_parser("dic/LIWC2007_English080730.dic")
K = len(categories)  # number of LIWC categories = feature width

# Category name -> position in the feature vector.
kat_lookup = {category: position for position, category in enumerate(categories)}

A custom `Dataset` class must implement three methods: `__init__`, `__len__`, and `__getitem__`.

### Lexical (LIWC) features

In [None]:
class LexDataset(Dataset):
    """Map-style dataset of LIWC category-frequency features.

    Item ``idx`` is ``(X, y)``. ``X[i]`` is the number of hits of category
    ``category_names[i]`` over the words of the row's text, divided by the
    total number of hits (all zeros when no word matched or the text is
    missing). ``y`` is the row's label.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows are read with ``df.loc[idx, ...]``, so the index is assumed to
        be ``0..len(df)-1`` (as produced by the ``reset_index`` when loading).
    x_col, y_col : str
        Text and label column names.
    parse : callable, optional
        Maps a word to an iterable of category names. Defaults to the
        notebook-level ``to_liwc``.
    category_names : sequence of str, optional
        Feature order. Defaults to the notebook-level ``categories``.

    Features are cached per index after the first access; ``df`` is assumed
    not to change afterwards.
    """

    def __init__(self, df, x_col="pp_text", y_col="label", parse=None, category_names=None):
        self.df = df
        self.x_col = x_col
        self.y_col = y_col
        # Fall back to the LIWC globals loaded earlier in the notebook;
        # explicit arguments make the dataset usable on its own.
        self.parse = to_liwc if parse is None else parse
        self.category_names = list(categories if category_names is None else category_names)
        # Features depend only on the row's text: compute once per row instead
        # of once per row per epoch.
        self._feature_cache = {}

    def __len__(self):
        return len(self.df)

    def _features(self, text):
        """Return per-category hit frequencies for ``text`` (zeros if no hits)."""
        hits = {k: 0 for k in self.category_names}
        # Missing text (e.g. NaN) is treated as empty instead of crashing on .split().
        words = text.split() if isinstance(text, str) else []
        for word in words:
            for k in self.parse(word):
                hits[k] += 1
        total = sum(hits.values())
        if not total:
            return np.zeros(len(self.category_names))
        return np.array([hits[k] for k in self.category_names]) / total

    def __getitem__(self, idx):
        if idx not in self._feature_cache:
            self._feature_cache[idx] = self._features(self.df.loc[idx, self.x_col])
        # Hand out a copy so a caller mutating X cannot corrupt the cache.
        X = self._feature_cache[idx].copy()
        y = self.df.loc[idx, self.y_col]
        return X, y


In [None]:
# LIWC features from the pre-processed text, labels from the "label" column.
ds = LexDataset(df, x_col="pp_text", y_col="label")

In [None]:
# The input size is the LIWC feature width (K = number of categories, the
# length of every LexDataset feature vector) instead of a hard-coded 64.
trainer = MyTrainer(ds, input_size=K)
trainer.fit(
    epochs=100,
    early_stopping=EarlyStopping(
        tolerance=4,
        min_delta=0.0,
        max_beta=0.1,
    ),
)


In [None]:
# Train vs. validation loss per epoch for the LIWC head.
fig, ax = plt.subplots()
ax.plot(trainer.train_losses, label="train")
ax.plot(trainer.val_losses, label="validation")
ax.set(title="LIWC features: L1 loss per epoch", xlabel="epoch", ylabel="mean L1 loss")
ax.legend();

In [None]:
# Learned coefficients of the Linear layer (one per input feature).
weights = trainer.model.layers[0].weight.detach()[0].numpy()

In [None]:
# Test-set predictions and targets; predictions above 0.5 become class 1.
y_hat, y_true = zip(*trainer.gasp())
y_hat = [int(score > 0.5) for score in y_hat]

In [None]:
metrics(predictions=y_hat, labels=y_true)

### RoBERTa

In [None]:
import torch
# Sanity check that CUDA device 0 is visible before loading models onto it.
torch.cuda.get_device_name(device=0)

In [None]:
from transformers import (
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    DataCollatorWithPadding,
    AutoModelForSequenceClassification,
)
from datasets import load_dataset
import datasets

In [None]:
# NOTE(review): tokenizer and classifier come from different checkpoints;
# confirm they share the same vocabulary.
# NOTE(review): the default dtype was set to float64 earlier; confirm the model
# is loaded in the intended dtype.
tokenizer_hf = AutoTokenizer.from_pretrained("ShreyaR/finetuned-roberta-depression")
model = AutoModelForSequenceClassification.from_pretrained("ranieri-unimi/test-trainer")
model = model.to("cuda")

In [None]:
# Tokenize every text at once (padded to the longest, truncated to the model
# limit) and keep the encodings and labels on the GPU.
TH = tokenizer_hf(
    df["pp_text"].to_list(),
    return_tensors="pt",
    padding=True,
    truncation=True,
).to("cuda")
y = torch.tensor(df["label"]).to("cuda")

In [None]:
BATCH_SIZE = 16

# Collect the last-layer embedding of the first token (RoBERTa's <s>/CLS) for
# every text, batch by batch, then concatenate once.
cls_batches = []
with torch.no_grad():
    for start in range(0, len(y), BATCH_SIZE):
        batch = slice(start, start + BATCH_SIZE)
        result = model(
            input_ids=TH.input_ids[batch],
            # Pass the padding mask so padded positions do not affect the encoding.
            attention_mask=TH.attention_mask[batch],
            output_hidden_states=True,
        )
        # hidden_states[-1] is (batch, seq_len, hidden): take position 0 for
        # every example in the batch (the old [0, 0, :] kept only the first one).
        cls_batches.append(result.hidden_states[-1][:, 0, :])

# (n_texts, hidden) matrix of CLS embeddings; None if there were no texts.
X = torch.cat(cls_batches, dim=0) if cls_batches else None


In [None]:
class RobDataset(Dataset):
    """Map-style dataset over pre-computed feature rows and their labels.

    Item ``idx`` is ``(X_list[idx], y_list[idx])``, returned as-is; the
    dataset length is taken from ``y_list``.
    """

    def __init__(self, X_list, y_list):
        self.X_list, self.y_list = X_list, y_list

    def __len__(self):
        return len(self.y_list)

    def __getitem__(self, idx):
        features = self.X_list[idx]
        label = self.y_list[idx]
        return features, label

In [None]:
# CLS embeddings as features, the same labels as targets.
ds = RobDataset(X_list=X, y_list=y)

In [None]:
# The input size is the width of the extracted CLS embeddings (assumes X is
# 2-D: n_texts x hidden_size) instead of hard-coding roberta-base's 768.
trainer = MyTrainer(ds, input_size=X.shape[1])
trainer.fit(
    epochs=100,
    early_stopping=EarlyStopping(
        tolerance=4,
        min_delta=0.0,
        max_beta=0.1,
    ),
)

In [None]:
# Train vs. validation loss per epoch for the RoBERTa-embedding head.
fig, ax = plt.subplots()
ax.plot(trainer.train_losses, label="train")
ax.plot(trainer.val_losses, label="validation")
ax.set(title="RoBERTa CLS features: L1 loss per epoch", xlabel="epoch", ylabel="mean L1 loss")
ax.legend();

In [None]:
# Learned coefficients of the Linear layer (one per embedding dimension).
rob_weights = trainer.model.layers[0].weight.detach()[0].numpy()

In [None]:
# Test-set predictions and targets; predictions above 0.5 become class 1.
y_hat, y_true = zip(*trainer.gasp())
y_hat = [int(score > 0.5) for score in y_hat]

In [None]:
metrics(predictions=y_hat, labels=y_true)

In [None]:
# Number of learned coefficients in the RoBERTa head vs. the LIWC head.
(rob_weights.shape, weights.shape)