# Multi-Author Writing Style Analysis
by: Noah Syrkis

In [None]:
import torch
from torch import nn, optim
import wandb
import os
import json
from collections import Counter, defaultdict
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
import warnings; warnings.simplefilter('ignore')

In [None]:
# Root folder holding the PAN22 datasets; used as the default `data_path`
# argument by the loading helpers in this file.
data_path = 'data/pan22/'

# Reserved vocabulary tokens: padding gets id 0 and unknown words get id 1.
PAD, UNK = '<PAD>', '<UNK>'

## Define functions

In [None]:
def get_files(dataset, split='train', data_path=data_path):
    """Return the (text file, truth file) path pairs for one dataset split.

    The split folder is ``<data_path>/<dataset>/<split>``; its directory
    listing is handed to ``make_pairs`` which does the actual pairing.
    """
    split_dir = os.path.join(data_path, dataset, split)
    return make_pairs(os.listdir(split_dir), split_dir)


def make_pairs(files, folder_path):
    """Pair every ``<name>.txt`` file with its ``truth-<name>.json`` file.

    Args:
        files: iterable of file names (not paths) found in ``folder_path``.
        folder_path: folder the names are joined onto.

    Returns:
        A list of ``(text_path, truth_path)`` tuples, one per ``.txt`` entry,
        in the order the names were given. Names without a ``.txt`` suffix
        (including the truth files themselves) are skipped. The truth path is
        derived from the text name; its existence is not checked here.
    """
    text_suffix = '.txt'
    pairs = []
    for name in files:
        if not name.endswith(text_suffix):
            continue
        # Swap only the trailing suffix: str.replace would also rewrite a
        # '.txt' occurring earlier in the name and point at the wrong file.
        truth_name = 'truth-' + name[:-len(text_suffix)] + '.json'
        pairs.append((os.path.join(folder_path, name), os.path.join(folder_path, truth_name)))
    return pairs


def get_vocab(dataset, n_vocab=10000, split='train', data_path=data_path):
    """Build word<->id lookup tables from the text files of one split.

    Tokens are produced by whitespace splitting. The ``n_vocab`` most frequent
    tokens are kept; ``PAD`` and ``UNK`` are always placed first, so they get
    ids 0 and 1.

    Args:
        dataset: dataset folder name below ``data_path``.
        n_vocab: number of most frequent tokens to keep.
        split: split folder name below the dataset folder.
        data_path: root folder of the datasets.

    Returns:
        ``(word2idx, idx2word)`` as ``defaultdict`` objects. Looking up an
        unknown word yields the id of ``UNK``; looking up an unknown id
        yields the ``UNK`` token.
    """
    freqs = Counter()
    for text_path, _ in get_files(dataset, split, data_path):
        # Separate handle name so the loop variable is not rebound.
        with open(text_path, 'r') as fh:
            for line in fh:
                freqs.update(line.split())

    # Drop corpus tokens that collide with the reserved ones; otherwise the
    # later duplicate would overwrite the reserved id in word2idx.
    reserved = (PAD, UNK)
    words = [w for w, _ in freqs.most_common(n_vocab) if w not in reserved]
    vocab = [PAD, UNK] + words
    unk_idx = vocab.index(UNK)

    # Fall back to the UNK constant itself so both tables agree on its spelling.
    idx2word = defaultdict(lambda: UNK, enumerate(vocab))
    word2idx = defaultdict(lambda: unk_idx, {w: i for i, w in enumerate(vocab)})
    return word2idx, idx2word


def plot_seqs(seq1, seq2, title):
    """Plot two sequences side by side on a linear and a log scale.

    ``seq1`` is drawn in white and labelled 'train', ``seq2`` in grey and
    labelled 'val'. The left panel shows the raw values, the right panel
    their natural logarithm. The figure uses the dark background style and
    is shown immediately.
    """
    plt.style.use('dark_background')
    _, (ax_raw, ax_log) = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, transform applied to both sequences, panel heading)
    panels = (
        (ax_raw, lambda seq: seq, 'Loss'),
        (ax_log, np.log, 'Log Loss'),
    )
    for ax, transform, heading in panels:
        ax.plot(transform(seq1), label='train', color='white')
        ax.plot(transform(seq2), label='val', color='grey')
        ax.set_title(heading, color='white')
        ax.legend()

    plt.suptitle(title)
    plt.show()

## Define dataset class

In [None]:
class Dataset(torch.utils.data.Dataset):
    """One sample per document: token ids per line plus the truth labels.

    Each item is built from a ``.txt`` file and its ``truth-*.json`` file.
    Every line of the text file becomes one row of word ids, truncated or
    right-padded with id 0 to ``max_len`` tokens.

    Args:
        dataset: dataset folder name below ``data_path``.
        split: split folder name below the dataset folder.
        data_path: root folder of the datasets.
        max_len: fixed number of token ids kept per line.
        word2idx: optional word -> id mapping to reuse (e.g. from the
            training split). It must handle unknown words itself, as the
            ``defaultdict`` returned by ``get_vocab`` does.
        idx2word: optional id -> word mapping to reuse. If either mapping is
            missing, both are rebuilt from this split via ``get_vocab``.
    """

    def __init__(self, dataset, split='train', data_path=data_path, max_len=32, word2idx=None, idx2word=None):
        self.files = get_files(dataset, split, data_path)
        if word2idx is None or idx2word is None:
            # Build the vocabulary from the same data_path the files come
            # from; previously the caller's data_path was silently ignored.
            self.word2idx, self.idx2word = get_vocab(dataset, split=split, data_path=data_path)
        else:
            self.word2idx = word2idx
            self.idx2word = idx2word
        # Not used inside this class; kept so existing attribute access still works.
        self.pad = torch.nn.utils.rnn.pad_sequence
        self.max_len = max_len

    def __len__(self):
        """Return the number of (text, truth) file pairs."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(text, authors, changes, paragraph_authors)`` where ``text`` is
            the long tensor produced by ``preprocess`` and the other three are
            float tensors read from the truth JSON keys ``'authors'``
            (reshaped to one element), ``'changes'`` and
            ``'paragraph-authors'``.
        """
        text_file, truth_file = self.files[idx]
        with open(text_file, 'r') as f:
            lines = f.readlines()
        text = self.preprocess(lines)
        with open(truth_file, 'r') as f:
            truth = json.load(f)
        authors = torch.tensor(truth['authors']).reshape(1)
        changes = torch.tensor(truth['changes'])
        paragraph_authors = torch.tensor(truth['paragraph-authors'])
        return text, authors.float(), changes.float(), paragraph_authors.float()

    def preprocess(self, text):
        """Convert a list of text lines to a ``(len(text), max_len)`` id tensor.

        Lines are split on whitespace, mapped through ``word2idx``, cut to
        ``max_len`` tokens and right-padded with the padding id 0.
        """
        pad_id = 0
        rows = []
        for line in text:
            ids = [self.word2idx[token] for token in line.split()][:self.max_len]
            rows.append(ids + [pad_id] * (self.max_len - len(ids)))
        # Explicit dtype and reshape keep an empty document as a (0, max_len)
        # long tensor instead of a float tensor of shape (0,).
        return torch.tensor(rows, dtype=torch.long).reshape(-1, self.max_len)

## Define model class

In [None]:
class Model(torch.nn.Module):
    """LSTM model with three regression heads for a single document.

    The model works on one sample at a time (no batches of documents). The
    input is a ``(n_paragraphs, n_tokens)`` tensor of word ids, as produced
    by ``Dataset.preprocess``. A token-level LSTM encodes each paragraph into
    one vector; three paragraph-level LSTM heads then predict

    * the number of authors (one value for the document),
    * a change score for every paragraph boundary (``n_paragraphs - 1``),
    * an author score for every paragraph (``n_paragraphs``).

    Args:
        params: dict with ``'vocab_size'``, ``'embedding_dim'`` and
            ``'hidden_dims'`` (three sizes: encoder LSTM, paragraph vector,
            head LSTMs).
    """

    def __init__(self, params):
        super().__init__()
        self.embedding = nn.Embedding(params['vocab_size'], params['embedding_dim'])
        # batch_first=True makes paragraphs the batch and tokens the sequence.
        # Without it the LSTM treats the paragraph axis as time and the token
        # axis as batch, so `[:, -1, :]` would not select the last time step.
        self.lstm = nn.LSTM(params['embedding_dim'], params['hidden_dims'][0], batch_first=True)
        self.linear = nn.Linear(params['hidden_dims'][0], params['hidden_dims'][1])

        self.a_lstm = nn.LSTM(params['hidden_dims'][1], params['hidden_dims'][2])
        self.a_linear = nn.Linear(params['hidden_dims'][2], 1)

        self.c_lstm = nn.LSTM(params['hidden_dims'][1], params['hidden_dims'][2])
        self.c_linear = nn.Linear(params['hidden_dims'][2], 1)

        self.pa_lstm = nn.LSTM(params['hidden_dims'][1], params['hidden_dims'][2])
        self.pa_linear = nn.Linear(params['hidden_dims'][2], 1)

    def forward(self, x):
        """Return ``(authors, changes, paragraph_authors)`` for one document.

        Args:
            x: long tensor of word ids, shape ``(n_paragraphs, n_tokens)``.

        Returns:
            Tensors of shape ``(1,)``, ``(n_paragraphs - 1,)`` and
            ``(n_paragraphs,)``.
        """
        x = self.embedding(x)          # (paragraphs, tokens, embedding_dim)
        x, _ = self.lstm(x)            # (paragraphs, tokens, hidden_dims[0])
        # Only the last time step is used, so project just that slice.
        x = self.linear(x[:, -1, :])   # (paragraphs, hidden_dims[1])
        a = self.authors(x)
        c = self.changes(x)
        pa = self.paragraph_authors(x, int(a))
        return a, c, pa

    def authors(self, x):
        """Predict the author count from the sequence of paragraph vectors."""
        x, _ = self.a_lstm(x)  # unbatched input: paragraphs are the time axis
        x = self.a_linear(x)
        # Output after the last paragraph summarises the whole document.
        return x[-1]

    def changes(self, x):
        """Predict one change score per boundary between adjacent paragraphs."""
        x, _ = self.c_lstm(x)
        x = self.c_linear(x)
        # The first paragraph has no predecessor, hence no boundary score.
        return x[1:].reshape(-1)

    def paragraph_authors(self, x, n_authors):
        """Predict one author score per paragraph.

        ``n_authors`` is accepted for interface compatibility but is not used.
        """
        x, _ = self.pa_lstm(x)
        x = self.pa_linear(x)
        # No softmax here: over the size-1 feature dimension it always yields
        # 1.0, which made the output constant and blocked all gradients.
        # Flatten to (n_paragraphs,) like `changes`, matching the target shape.
        return x.reshape(-1)

## Define training loop

In [None]:
def _sample_loss(model, criterion, x, a, c, pa):
    """Return the summed loss of the three model heads for one sample."""
    a_hat, c_hat, pa_hat = model(x)
    # Reshape each prediction to its target's shape so the criterion never
    # broadcasts silently (e.g. (P, 1) against (P,)); a genuine size
    # mismatch raises here instead.
    return (
        criterion(a_hat.reshape(a.shape), a)
        + criterion(c_hat.reshape(c.shape), c)
        + criterion(pa_hat.reshape(pa.shape), pa)
    )


def train(model, ds_train, ds_valid, criterion, optimizer, epochs=10):
    """Train ``model`` one sample at a time and validate after every epoch.

    Args:
        model: module returning ``(authors, changes, paragraph_authors)``
            predictions for one sample.
        ds_train: iterable of ``(x, authors, changes, paragraph_authors)``
            training samples.
        ds_valid: iterable of validation samples with the same layout.
        criterion: loss called as ``criterion(prediction, target)``.
        optimizer: optimizer over the model parameters.
        epochs: number of passes over ``ds_train``.

    The per-sample loss is the sum of the three head losses. It is logged to
    wandb as ``'train_loss'`` / ``'valid_loss'`` and shown in the progress
    bar. An active wandb run is required.
    """
    wandb.watch(model)
    for epoch in range(epochs):
        model.train()
        for x, a, c, pa in (pbar := tqdm(ds_train)):
            optimizer.zero_grad()
            loss = _sample_loss(model, criterion, x, a, c, pa)
            loss.backward()
            optimizer.step()
            pbar.set_description(f'Epoch {epoch + 1}, Train Loss: {loss.item():.4f}')
            wandb.log({'train_loss': loss.item()})

        model.eval()
        with torch.no_grad():
            for x, a, c, pa in (pbar := tqdm(ds_valid)):
                # Uses the same loss as training; the author-count term was
                # previously compared against the paragraph-author target.
                loss = _sample_loss(model, criterion, x, a, c, pa)
                pbar.set_description(f'Epoch {epoch + 1}, Valid Loss: {loss.item():.4f}')
                wandb.log({'valid_loss': loss.item()})

## Instantiate dataset, model, and train model

In [None]:
# Build the training split first so that its vocabulary can be shared with
# the validation split (both then map words to the same ids).
ds_train = Dataset('dataset1', split='train')
ds_valid = Dataset(
    'dataset1',
    split='validation',
    word2idx=ds_train.word2idx,
    idx2word=ds_train.idx2word,
)

# Hyperparameters; also logged to wandb as the run config.
params = {
    'embedding_dim': 50,
    'vocab_size': len(ds_train.word2idx),
    'epochs': 10,
    'hidden_dims': [100, 50, 25],
}

model = Model(params)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# NOTE(review): params['epochs'] is 10 but training below runs for a single
# epoch, so the logged config does not match the run - confirm which is meant.
wandb.init(project='mawsa', config=params)
train(model, ds_train, ds_valid, criterion, optimizer, epochs=1)
wandb.finish()