In [140]:
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import torch
import torch.nn as nn
import csv
import numpy as np
import typing as t
import csv
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from sklearn.metrics import precision_recall_fscore_support
import json
import pandas as pd

In [141]:
# Run on Apple's MPS backend when torch reports it as both available and built
# into this install; otherwise fall back to the CPU.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available() and torch.backends.mps.is_built()
    else "cpu"
)

device

device(type='mps')

In [142]:
# Shared WordPiece tokenizer; the training / evaluation / scoring functions
# below read this module-level name directly.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [143]:
class WikiData(Dataset):
  """Map-style dataset over a three-column CSV of ``(id, label, text)`` rows.

  The whole file is read eagerly in ``__init__`` and kept in ``self.data`` as a
  list of ``(id, label, text)`` string tuples.

  Args:
    csv_path: Path to the comma-delimited, UTF-8 encoded CSV file.
    limit: Maximum number of rows to keep. A falsy value (the default ``0``)
      means "keep every row".

  Raises:
    ValueError: If a non-empty row does not have exactly three columns.
  """

  def __init__(self, csv_path:str = "", limit: int = 0):
    self.csv_path = csv_path
    self.limit = limit
    self.data = self._load_csv()

  def _load_csv(self):
    """Read the CSV at ``self.csv_path`` and return its rows as a list of tuples."""
    rows = []
    # newline="" hands line-ending handling to the csv module, as its docs
    # require; without it, newlines embedded in quoted text fields are
    # rewritten by Python's universal-newline translation.
    with open(self.csv_path, "r", encoding="utf-8", newline="") as f:
      reader = csv.reader(f, delimiter=",")
      for row in reader:
        if self.limit and len(rows) >= self.limit:
          break
        if not row:
          # csv.reader yields [] for blank lines (e.g. a trailing newline).
          continue
        if len(row) != 3:
          raise ValueError(
              f"{self.csv_path}: line {reader.line_num}: expected 3 columns "
              f"(id, label, text), got {len(row)}"
          )
        _id, label, text = row
        rows.append((_id, label, text))
    return rows

  def __len__(self):
    """Number of rows loaded (after applying ``limit``)."""
    return len(self.data)

  def __getitem__(self, idx: int):
    """Return ``(text, label)`` where label is 0 for "standard" and 1 otherwise."""
    _id, label, text = self.data[idx]
    label = 0 if label == "standard" else 1
    return text, label

In [144]:
class GPT_Dataset(Dataset):
    """Map-style dataset over the ``"gpt_generation"`` entries of a JSON file.

    Each entry is expected to be a mapping with at least the keys ``"topic"``,
    ``"level"`` and ``"text"``; indexing yields them as a 3-tuple in that order.
    """

    def __init__(self, db_path: str) -> None:
        self.db_path = db_path
        self.data = self._load_data()

    def _load_data(self):
        """Parse the JSON file and return a fresh list of its generated entries."""
        with open(self.db_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        # Only "gpt_generation" is used. An earlier, disabled variant also
        # appended payload["wikipedia"]["standard"] / ["simple"] items tagged
        # with a matching "level" value.
        return list(payload["gpt_generation"])

    def _unpack(self, row: dict):
        """Project one entry onto ``(topic, level, text)``."""
        return row["topic"], row["level"], row["text"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        topic, level, text = self._unpack(self.data[idx])
        return topic, level, text
    

In [145]:
class ReadabilityClassifier(nn.Module):
  """Frozen BERT encoder followed by an LSTM and a 2-way log-softmax head.

  Each token's BERT embedding is concatenated with its attention-mask bit, the
  resulting sequence is run through an LSTM, and the LSTM output at the final
  time step is projected to two log-probabilities.

  Args:
    hidden_size: Hidden size of the LSTM (and input size of the linear head).
    n_lstm_layers: Number of stacked LSTM layers.
  """

  def __init__(self,
               hidden_size: int = 126,
               n_lstm_layers: int = 1
               ):
    super().__init__()

    self.bert = BertModel.from_pretrained("bert-base-uncased")

    # BERT is used as a fixed feature extractor: none of its weights train.
    for param in self.bert.parameters():
      param.requires_grad = False

    self.lstm = nn.LSTM(
        # BERT embedding width plus one extra channel for the attention mask
        # (768 + 1 for bert-base-uncased), derived from the loaded config
        # instead of being hard-coded.
        input_size = self.bert.config.hidden_size + 1,
        hidden_size = hidden_size,
        num_layers = n_lstm_layers,
        batch_first = True
    )
    self.linear = nn.Linear(hidden_size, 2)
    self.softmax = nn.LogSoftmax(dim=1)

  def forward(self, tokens):
    """Return per-class log-probabilities of shape ``(batch, 2)``.

    Args:
      tokens: Tokenizer output; it is unpacked as keyword arguments into BERT
        and must contain an ``"attention_mask"`` entry.
    """
    embedded = self.bert(**tokens).last_hidden_state
    # Same (batch, seq, 1) layout as before, but cast to the embedding dtype so
    # the concatenation does not depend on implicit int -> float promotion.
    attention = tokens["attention_mask"].reshape(embedded.shape[0], -1, 1)
    attention = attention.to(embedded.dtype)
    embedded = torch.cat((embedded, attention), dim=2)
    output, _ = self.lstm(embedded)
    # NOTE(review): this keeps the LAST time step. With padded batches that
    # position is padding for shorter sequences; the mask channel is the only
    # cue the LSTM has about it. Left unchanged so existing checkpoints keep
    # their behaviour - confirm this is intended.
    output = output[:, -1, :]
    output = self.linear(output)
    return self.softmax(output)


In [146]:
def train(
    model: ReadabilityClassifier, 
    criterion: nn.Module,
    dataloader: DataLoader,
    optimizer,
    n_epochs: int = 1,
):
  """Train ``model`` and write one checkpoint per epoch.

  Batches from ``dataloader`` are ``(texts, labels)``; texts are tokenized with
  the notebook-level ``tokenizer`` and moved to the notebook-level ``device``.
  After each epoch the mean batch loss is printed and a checkpoint holding the
  epoch index, model state, optimizer state and mean loss is saved to
  ``checkpoints/epoch_{epoch}.tar``.

  Args:
    model: Classifier to optimise in place.
    criterion: Loss applied to ``model(tokens)`` and the label tensor.
    dataloader: Yields ``(texts, labels)`` batches.
    optimizer: Optimizer already bound to the model's parameters.
    n_epochs: Number of passes over ``dataloader``.
  """
  # Local import: `os` is not among the notebook's top-level imports.
  import os

  # Create the output directory up front so that a missing folder cannot throw
  # away a finished epoch when torch.save runs.
  os.makedirs("checkpoints", exist_ok=True)

  for epoch in range(n_epochs):
    loop = tqdm(dataloader)
    losses = []
    for texts, labels in loop:
      optimizer.zero_grad()
      tokens = tokenizer(
          texts, 
          return_tensors="pt", 
          padding=True, 
          truncation=True
      ).to(device)
      labels = labels.to(device)
      output = model(tokens)
      loss = criterion(output, labels)
      losses.append(loss.item())
      loss.backward()
      optimizer.step()

    # An empty dataloader yields no losses; report NaN instead of dividing by 0.
    mean_loss = sum(losses) / len(losses) if losses else float("nan")
    print(f"Loss at epoch {epoch}: {round(mean_loss, 4)}")
    
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": mean_loss
    }, f"checkpoints/epoch_{epoch}.tar")

In [147]:
def evaluate(model: nn.Module, dataloader: DataLoader):
  """Run ``model`` over ``dataloader`` and print per-class precision/recall/F1.

  Batches are ``(texts, levels)``; texts are tokenized with the notebook-level
  ``tokenizer`` and moved to the notebook-level ``device``. The predicted class
  is the argmax over the model's two outputs. The model's train/eval mode is
  left as the caller set it.

  Args:
    model: Module mapping tokenizer output to ``(batch, 2)`` scores.
    dataloader: Yields ``(texts, levels)`` with integer class labels in {0, 1}.

  Returns:
    ``(y_true, y_pred)``: the gold and predicted labels for every example seen.
  """
  y_true = []
  y_pred = []
  with torch.no_grad():
    for texts, levels in tqdm(dataloader):
      tokens = tokenizer(
          texts, 
          return_tensors="pt", 
          padding=True, 
          truncation=True
      ).to(device)
      output = model(tokens)
      y_pred.extend(output.argmax(dim=1).cpu().detach().numpy())
      y_true.extend(levels.cpu().detach().numpy())

  # Score the labels accumulated over ALL batches. Previously only the last
  # batch's `levels` / `output` were passed here, so the printed metrics
  # ignored the rest of the data.
  precision, recall, fscore, _ = precision_recall_fscore_support(
      y_true = y_true, 
      y_pred = y_pred,
      labels=[0, 1],
  )
  print()
  print("Precision", precision)
  print("Recall", recall)
  print("Fscore", fscore)
  return y_true, y_pred
      

In [148]:
def score_gpt_data(model: ReadabilityClassifier, dataloader: DataLoader):
    """Score every generated text and collect P(class 0) per topic and level.

    Batches are ``(topics, levels, texts)``; texts are tokenized with the
    notebook-level ``tokenizer`` and moved to the notebook-level ``device``.
    The model returns log-probabilities, so ``exp`` of column 0 is the
    probability of class 0 (the class ``WikiData`` assigns to "standard").

    Args:
        model: Classifier returning ``(batch, 2)`` log-probabilities.
        dataloader: Yields ``(topics, levels, texts)`` batches.

    Returns:
        ``{topic: {level: probability}}`` with probabilities rounded to three
        decimals. If the same (topic, level) pair occurs more than once, the
        last occurrence wins.
    """
    scores = {}
    # Pure inference: skip autograd bookkeeping for the LSTM / linear head.
    with torch.no_grad():
        for topics, levels, texts in dataloader:
            tokens = tokenizer(
                texts,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(device)
            output = model(tokens)
            for topic, level, score in zip(topics, levels, output):
                p_standard = torch.exp(score).cpu().detach().numpy()[0]
                scores.setdefault(topic, {})[level] = round(p_standard, 3)
    return scores

In [149]:
# Parse the training CSV once and reuse the same dataset object for both the
# size report and the DataLoader (it was previously read from disk twice).
dst = WikiData("./distrib/dataset_train.csv")
print(len(dst))

train_dataloader = DataLoader(
    dst,
    batch_size=50,
    shuffle=True
)

284154


In [150]:
# Build the classifier on the selected device (nn.Module.to returns the module).
model = ReadabilityClassifier().to(device)

# NLLLoss pairs with the LogSoftmax output of the classifier.
criterion = nn.NLLLoss()
optimizer = Adam(model.parameters(), lr=2e-5)

# Training is disabled here in favour of restoring saved weights. To retrain:
# train(model, criterion, train_dataloader, optimizer, n_epochs=5)
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
checkpoint = torch.load("checkpoints/epoch_2.tar", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [151]:
# Held-out evaluation, currently disabled. Uncomment to build a loader over the
# test CSV and print precision / recall / F-score via evaluate().
# test_dataloader = DataLoader(
#     WikiData("./distrib/dataset_test.csv"),
#     batch_size=50,
# )

# y_true, y_pred = evaluate(model, test_dataloader)

In [152]:
# Score the GPT-generated texts in fixed-size batches, in file order.
gpt_dataloader = DataLoader(GPT_Dataset("./distrib/dataset.json"), batch_size=10)
scores = score_gpt_data(model, gpt_dataloader)


In [161]:
# Tabulate the scores: one row per topic, one column per level, plus an "Avg"
# column (row means) and an "Avg" row (column means), all to three decimals.
topics = list(scores)
# Union of the levels across all topics, in first-seen order, so a topic that
# lacks a level yields NaN instead of raising KeyError.
levels = list(dict.fromkeys(level for by_level in scores.values() for level in by_level))

# Build the frame in one go with a numeric dtype; the previous object-dtype
# frame was filled cell by cell and made the final round() a no-op.
df = pd.DataFrame(
    [[scores[topic].get(level) for level in levels] for topic in topics],
    index=topics,
    columns=levels,
    dtype="float64",
)

# Row means first, then column means - the "Avg"/"Avg" cell is therefore the
# mean of the (rounded) per-topic averages, as before.
df["Avg"] = df.mean(axis=1).round(3)
df.loc["Avg"] = df.mean(axis=0).round(3)

df = df.round(3)

ltx = df.style.to_latex()
df

Unnamed: 0,A1,A2,B1,B2,C1,C2,Avg
Color Blindness,0.02,0.022,0.056,0.095,0.015,0.03,0.04
The Great Depression,0.014,0.128,0.027,0.89,0.342,0.461,0.31
Butterflies,0.016,0.027,0.014,0.041,0.021,0.281,0.067
Dogs,0.028,0.069,0.013,0.208,0.088,0.041,0.075
Semantics,0.014,0.144,0.013,0.541,0.475,0.262,0.242
The Internet,0.012,0.027,0.291,0.043,0.031,0.041,0.074
The Moon,0.022,0.073,0.026,0.029,0.053,0.812,0.169
Dinosaurs,0.019,0.015,0.016,0.272,0.052,0.384,0.126
Economics,0.031,0.012,0.018,0.344,0.222,0.697,0.221
Quantum Mechanics,0.011,0.024,0.121,0.465,0.313,0.628,0.26
