In [None]:
import logging
import os
import sys
from copy import deepcopy

import torch
import transformers
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel

# Add the parent directory to the import path before the `src` imports below
# (presumably the repository root when the notebook runs from its own folder).
sys.path.append("../")

from src.models.rnn_model import RNN
from src.models.learned_skip_masking import (
    DeepSkipAugmenter,
    WeightedMaskClassificationLoss,
    DeepSkipAugmenterTrainer
)
from src.data.dataio import DataFiles, Dataset, truncate_fn

# Only show ERROR-level (and above) messages from the transformers library.
transformers.logging.set_verbosity_error()

# Root logger (no name given); it is handed to the trainer further down.
logger = logging.getLogger()

In [None]:
# Checkpoint name shared by the tokenizer here and the model cell below.
BERT_MODEL_NAME = "distilbert-base-uncased"

# Directory holding the augmentation JSON splits. The default is the original
# author's local checkout; set the AUGMENTATION_DATA_DIR environment variable
# to run the notebook on another machine without editing this cell.
AUGMENTATION_DATA_DIR = os.environ.get(
    "AUGMENTATION_DATA_DIR",
    "/Users/od/Documents/Mila/COMP-550/Final Project/comp-550-project/data/articles/augmentation",
)

# Single batch size used by both the training and validation loaders.
BATCH_SIZE = 4

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=BERT_MODEL_NAME)

# Training split.
train_data_files = DataFiles([os.path.join(AUGMENTATION_DATA_DIR, "train.json")])
train_dataset = Dataset(train_data_files, data_format="json")
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

# Validation split, loaded the same way as the training split.
valid_data_files = DataFiles([os.path.join(AUGMENTATION_DATA_DIR, "validation.json")])
valid_dataset = Dataset(valid_data_files, data_format="json")
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)


In [None]:
# Build the augmenter, its optimizer and loss, and the trainer.

# Seed torch's global RNG so that any random initialisation in the models
# built below is repeatable after a kernel restart.
SEED = 42
torch.manual_seed(SEED)

# Training hyper-parameters, named here so they are easy to find and tune.
LEARNING_RATE = 0.001
NUM_EPOCHS = 2
LOG_INTERVAL = 1

bert = AutoModel.from_pretrained(BERT_MODEL_NAME)

model = DeepSkipAugmenter(
    # Reuse the tokenizer loaded in the data cell rather than loading the same
    # checkpoint (BERT_MODEL_NAME) a second time.
    tokenizer=tokenizer,
    # Each RNN gets its own deep copy of BERT's word-embedding layer, so the
    # RNNs do not share that module object with `bert` or with each other.
    masking_model=RNN(output_size=2, embeddings_layer=deepcopy(bert.embeddings.word_embeddings)),
    unmasking_model=bert,
    classifier=RNN(output_size=5, embeddings_layer=deepcopy(bert.embeddings.word_embeddings))
)

# Only the masking model and the classifier are registered with the optimizer;
# the parameters of the unmasking model (`bert`) are not updated by it.
optimizer = torch.optim.Adam(
    params=[{"params": model.masking_model.parameters()},
            {"params": model.classifier.parameters()}],
    lr=LEARNING_RATE
)
criterion = WeightedMaskClassificationLoss()

# NOTE(review): no explicit training call is visible in this part of the
# notebook, yet the next cell reads `trainer.best_f1_score` — presumably the
# constructor runs the training loop; confirm against DeepSkipAugmenterTrainer.
trainer = DeepSkipAugmenterTrainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_dataloader,
    val_dataloader=valid_dataloader,
    logger=logger,
    num_epochs=NUM_EPOCHS,
    log_interval=LOG_INTERVAL
)

In [None]:
# F1 score: display the trainer's `best_f1_score` attribute as the cell output.
# NOTE(review): no explicit training call is visible before this cell —
# confirm that constructing DeepSkipAugmenterTrainer runs training; otherwise
# this value may not reflect a trained model.
trainer.best_f1_score

In [None]:
import os
from src.utils.json_utils import write_to_json

# Directory holding the plain-text training files. The default is the original
# author's local checkout; set the AUGMENTATION_DATA_DIR environment variable
# to point somewhere else without editing this cell.
PATH = os.environ.get(
    "AUGMENTATION_DATA_DIR",
    "/Users/od/Documents/Mila/COMP-550/Final Project/comp-550-project/data/articles/augmentation",
)

# Each line of the file becomes one entry, with surrounding whitespace
# (including the trailing newline) stripped. The encoding is given explicitly
# so decoding does not depend on the platform's locale default.
with open(os.path.join(PATH, "training_text.txt"), encoding="utf-8") as f:
    training_text = [line.strip() for line in f]

# Each line is parsed as an integer label; int() tolerates the trailing newline.
with open(os.path.join(PATH, "training_labels.txt"), encoding="utf-8") as f:
    training_labels = [int(line) for line in f]

# Largest label value present in the training labels (shown as cell output).
max(training_labels)
