In [35]:
import os
import re
import numpy as np
import torch
import json
import pprint
import random

from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

In [2]:
MAX_LEN = 512  # @param {type:"integer"}
TRAIN_BATCH_SIZE = 64  # @param {type:"integer"}
VALID_BATCH_SIZE = 32  # @param {type:"integer"}
EPOCHS = 1  # @param {type:"integer"}
LEARNING_RATE = 1e-5  # @param {type:"number"}
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

Downloading config.json: 100%|██████████| 481/481 [00:00<00:00, 945kB/s]
Downloading vocab.json: 100%|██████████| 899k/899k [00:00<00:00, 3.00MB/s]
Downloading merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 2.24MB/s]
Downloading tokenizer.json: 100%|██████████| 1.36M/1.36M [00:00<00:00, 3.28MB/s]


In [16]:
texts = [
    ("Hello, my dog is cute", "positive"),
    ("Hello, my dog is cute", "positive"),
]
pprint.pprint(tokenizer.batch_encode_plus(
    texts,
    padding="max_length",
    max_length=MAX_LEN,
    add_special_tokens=True,
    truncation=True,
    return_attention_mask=True,
))

{'attention_mask': [[1,
                     1,
                     1,
                     1,
                     1,
                     1,
                     1,
                     1,
                     1,
                     1,
                     1,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                

## PAN23

In [4]:
class PAN23Dataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.len = len(os.listdir(path))

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        file_path = os.path.join(self.path, f"{index}.json")
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

In [17]:
class PAN23CollatorFn:
    def __init__(self, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __call__(self, batch):
        texts = [(item["text1"], item["text2"]) for item in batch]
        labels = [item["label"] for item in batch]

        encoding = self.tokenizer.batch_encode_plus(
            texts,
            padding="max_length",
            max_length=self.max_len,
            add_special_tokens=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        return dict(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            labels=torch.tensor(labels),
        )

In [9]:
train_dataset = PAN23Dataset("../data/pan23/transformed/pan23-task1-train")

In [18]:
train_data_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=PAN23CollatorFN(tokenizer, MAX_LEN),
)

In [20]:
for batch in train_data_loader:
    pprint.pprint(batch)
    break

{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[    0,   713,    16,  ...,     1,     1,     1],
        [    0,   170,   308,  ...,     1,     1,     1],
        [    0,  8346,     4,  ...,     1,     1,     1],
        ...,
        [    0,   894, 23079,  ...,     1,     1,     1],
        [    0,   100,  1266,  ...,     1,     1,     1],
        [    0,  5975,     6,  ...,     1,     1,     1]]),
 'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0,
        1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1])}


## Blogposts

In [26]:
path = "../data/blogposts"
post1, post2 = random.sample(os.listdir(path), 2)
print(post1, post2)

4942 2797


In [28]:
class BlogDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.len = len(os.listdir(path))

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        author_path = os.path.join(self.path, str(index))
        author_files = os.listdir(author_path)

        post1_ind, post2_ind = random.sample(author_files, 2)
        post1_path = os.path.join(author_path, post1_ind)
        post2_path = os.path.join(author_path, post2_ind)

        with open(post1_path, "r", encoding="utf-8") as f:
            post1 = f.read()
        with open(post2_path, "r", encoding="utf-8") as f:
            post2 = f.read()

        return {
            "post1": post1,
            "post2": post2,
        }

In [32]:
dataset = BlogDataset("../data/blogposts")
print(len(dataset))
print(dataset[0])

18536
{'post1': "             If you click on my profile you'll make a not-so-startling discovery...I was born in Year of the Pig, as they say in the Korean/Chinese calendar.  But blogger.com figured it would be more appropriate to call it Year of the Boar/bore...thanks guys.  Anyways, you may be wondering how a fat, lazy, smelly Canadian guy born in a little town waaaaaay up north finds himself in a place like Seoul...and Yeouido, no less, where only a handful of foreigners visit, let alone live and work.  The culprit is my wife.  She is Korean, as you may know, and when I was doing financial consulting in Canada we came across an interesting client.  He found us through his wife's (see a pattern here?  Korean women rule the men) reading my column in the Vancouver Chosun (Chosun is the name of the last dynasty to rule Korea--it was ended by the 1910-1945 Japanese occupation) which you'll find  urlLink here . It's all in Korean, my email and webpage have changed, though.  Anyways, this

In [37]:
class BlogCollatorFn:
    def __init__(self, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __call__(self, batch):
        posts = []
        for item in batch:
            posts.append(self.__clean_text(item["post1"]))
            posts.append(self.__clean_text(item["post2"]))

        encoding = self.tokenizer.batch_encode_plus(
            posts,
            padding="max_length",
            max_length=self.max_len,
            truncation=True,
            return_tensors="pt",
        )

        return dict(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
        )

    def __clean_text(self, text):
        text = re.sub(r"\s+", " ", text)
        words = text.split()
        if len(words) > self.max_len:
            start = random.randint(0, len(words) - self.max_len)
            words = words[start:start + self.max_len]
        return " ".join(words)


In [38]:
data_loader = DataLoader(
    dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=BlogCollatorFn(tokenizer, MAX_LEN),
)

In [39]:
for batch in data_loader:
    pprint.pprint(batch)
    break

{'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[    0,   463, 13351,  ...,    20, 22780,     2],
        [    0,   100,   437,  ...,     1,     1,     1],
        [    0, 31414,     6,  ...,    78,   633,     2],
        ...,
        [    0, 21518, 15305,  ...,     1,     1,     1],
        [    0,   100,   437,  ...,     1,     1,     1],
        [    0,  5625,    21,  ...,     1,     1,     1]])}
