In [None]:
# Report the GPU(s) visible to this runtime before doing any work.
!nvidia-smi

In [None]:
# Fetch the peer-data repository into /content, make it the working directory,
# and pull in its git submodules.
%cd /content/
!git clone https://github.com/westphal-jan/peer-data
%cd /content/peer-data
!git submodule update --init --recursive

In [None]:
!pip install pytorch-lightning wandb python-dotenv catalyst sentence-transformers numpy requests nlpaug sentencepiece nltk

In [None]:
import os
import torch
import json
import glob
from pathlib import Path
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments, TrainerCallback, TrainerState, TrainerControl, AdamW
import wandb
from datetime import datetime
import pickle
import numpy as np
import nlpaug.augmenter.word as naw
from torch.utils.data import DataLoader
from copy import copy
from datetime import datetime

In [None]:
class PaperDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-computed encodings with their labels.

    ``encodings`` is a mapping from field name to an indexable collection with
    one entry per sample; ``labels`` is an indexable collection of the same
    length. Each sample is returned as a dict of tensors.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # One sample per label.
        return len(self.labels)

    def __getitem__(self, idx):
        # Convert every encoding field of sample ``idx`` to a tensor, then attach its label.
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

# class AugmentedPaperDataset(torch.utils.data.Dataset):
#     def __init__(self, texts, labels, tokenizer, augmenter, aug_p=0.0, seed=0):
#         """Augmenter needs to have the method 'augment'"""
#         self.texts = texts
#         self.labels = labels
#         self.tokenizer = tokenizer
#         self.augmenter = augmenter
#         self.encodings = tokenizer(train_texts, truncation=True, padding='max_length', max_length=512)
#         self.aug_p = aug_p
#         self.usages = [0] * len(texts)
#         self.rnd = np.random.RandomState(seed)
#         self.num_augmentations = 0

#     def __getitem__(self, idx):
#         if self.usages[idx] > 0 and self.rnd.random_sample() < self.aug_p:
#             self.texts[idx] = self.augmenter.augment(self.texts[idx])
#             new_encoding = tokenizer(self.texts[idx], truncation=True, padding='max_length', max_length=512)
#             for key, val in self.encodings.items():
#                 val[idx] = new_encoding[key]
#             self.usages[idx] = 0
#             self.num_augmentations += 1
        
#         item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
#         item['labels'] = torch.tensor(self.labels[idx])
#         self.usages[idx] += 1
#         return item

#     def __len__(self):
#         return len(self.labels)

def read_dataset(data_dir: Path, num_texts=None):
    """Load abstracts and acceptance labels from the ``*.json`` files in ``data_dir``.

    Each file must contain a ``"review"`` object with an ``"abstract"`` entry and
    an ``"accepted"`` entry; the latter is converted with ``int()``.

    Args:
        data_dir: Directory that directly contains the paper JSON files.
        num_texts: If not ``None``, only the first ``num_texts`` files (in sorted
            path order) are read.

    Returns:
        ``(texts, labels)``: two parallel lists, the abstracts and their
        integer labels.
    """
    # glob returns paths in filesystem order; sort them so the result (and the
    # subset selected by num_texts) is the same on every run and machine.
    file_paths = sorted(glob.glob(f"{data_dir}/*.json"))
    if num_texts is not None:
        file_paths = file_paths[:num_texts]
    texts = []
    labels = []
    for file_path in tqdm(file_paths):
        # JSON is UTF-8 by specification; don't depend on the platform default encoding.
        with open(file_path, encoding="utf-8") as f:
            review = json.load(f)["review"]
        texts.append(review["abstract"])
        labels.append(int(review["accepted"]))
    return texts, labels

In [None]:
# Load every paper, derive class weights from the label distribution, then
# split into train/validation sets and tokenize both.
dataset = "data/original"

data_dir = Path(dataset)
train_texts, train_labels = read_dataset(data_dir)

num_accepted = train_labels.count(1)
num_not_accepted = train_labels.count(0)

print(num_accepted, num_not_accepted)
# Without at least one sample per class the weight below would be inf or nan.
assert num_accepted > 0 and num_not_accepted > 0, "need at least one accepted and one rejected paper"
# Weight 1 for class 0 (not accepted); class 1 (accepted) is scaled by the class ratio.
label_weight = num_not_accepted / np.array([num_not_accepted, num_accepted])

# Fixed random_state so the train/validation split is identical across runs.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_texts, train_labels, test_size=0.2, random_state=0
)

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)

train_dataset = PaperDataset(train_encodings, train_labels)
val_dataset = PaperDataset(val_encodings, val_labels)

In [None]:
# back_translation_aug = naw.BackTranslationAug(
#     from_model_name='facebook/wmt19-en-de', 
#     to_model_name='facebook/wmt19-de-en',
#     device="cuda",
#     max_length=512
# )

# tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
# train_dataset = AugmentedPaperDataset(train_texts, train_labels, tokenizer, back_translation_aug, 0.5)
# val_dataset = AugmentedPaperDataset(val_texts, val_labels, tokenizer, back_translation_aug)

In [None]:
# # Experiment with nlp augmentation
# import nlpaug.augmenter.word as naw
# import nlpaug.augmenter.sentence as nas

# text, label = train_texts[1], train_labels[1]
# print(text)
# print(len(text.split(" ")), label)

# # Synonym Augmenter
# import nltk
# nltk.download('averaged_perceptron_tagger')
# nltk.download('wordnet')

# aug = naw.SynonymAug(aug_src='wordnet', aug_min=10, aug_max=512, aug_p=0.25)
# augmented_text = aug.augment(text)
# print(augmented_text)

# # Back Translation
# back_translation_aug = naw.BackTranslationAug(
#     from_model_name='facebook/wmt19-en-de', 
#     to_model_name='facebook/wmt19-de-en',
#     device="cuda",
#     max_length=512
# )
# back_translated_text = back_translation_aug.augment(text) # Unique outputs only working partly
# print(back_translated_text)

# # Contextual Sentence Augmentation
# # https://nlpaug.readthedocs.io/en/latest/augmenter/sentence/context_word_embs_sentence.html
# # context_aug = nas.ContextualWordEmbsForSentenceAug(model_path='gpt2')
# # augmented_text = context_aug.augment(text)
# # print(augmented_text)

In [None]:
# back_translation_aug.augment(back_translated_text)

In [None]:
# synonym_aug = naw.SynonymAug(aug_src='ppdb', model_path='ppdb-2.0-tldr')
# augmented_text = synonym_aug.augment(text)
# print(augmented_text)

In [None]:
# class MyCallback(TrainerCallback):

#     def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
#         """
#         Event called after an evaluation phase.
#         """
#         print("Callback")
#         eval_dataloader = kwargs["eval_dataloader"]
#         eval_model = kwargs["model"]
#         step = state.global_step
#         eval_model.eval()
#         tp = fp = fn = tn = 0
#         for x in tqdm(eval_dataloader):
#             input_ids = x["input_ids"].to(device)
#             output = eval_model(input_ids=input_ids)

#             prediction = output.logits.argmax(dim=1).cpu()
#             actual = x["labels"]
#             tp += ((prediction == 1) & (actual == 1)).sum()
#             fp += ((prediction == 1) & (actual == 0)).sum()
#             fn += ((prediction == 0) & (actual == 1)).sum()
#             tn += ((prediction == 0) & (actual == 0)).sum()
        
#         precision = tp / (tp + fp)
#         recall = tp / (tp + fn)
#         f1 = 2 * precision * recall / (precision + recall)
#         accuracy = (tp + tn) / (tp + tn + fp + fn)

#         metrics = {"step": step, "precision": precision, "recall": recall, "f1": f1,
#                    "accuracy": accuracy, "tp": tp, "fp": fp, "fn": fn, "tn": tn}
#         print(step, metrics)
#         try:
#             wandb.log(metrics)
#         except:
#             pass


# class ImballancedLabelTrainer(Trainer):
#     def __init__(self, label_weight, **kwargs):
#         super().__init__(**kwargs)
#         self._loss = torch.nn.CrossEntropyLoss(label_weight, reduction="sum")

#     def compute_loss(self, model, inputs, return_outputs=False):
#         labels = inputs.pop("labels")
#         outputs = model(**inputs)
#         logits = outputs.logits
#         loss = self._loss(logits, labels)
#         return (loss, outputs) if return_outputs else loss


def compute_metrics(eval_pred, prefix="eval"):
    """Compute binary-classification metrics from logits and reference labels.

    Args:
        eval_pred: ``(logits, labels)`` pair. ``logits`` is array-like with one
            row per sample and one column per class; ``labels`` is array-like
            with one class id per sample.
        prefix: String prepended (followed by ``/``) to every metric name.

    Returns:
        Dict with precision, recall, F1 and accuracy under
        ``{prefix}/metric/...`` and the confusion-matrix counts under
        ``{prefix}/classification/...``. All values are plain Python ``int`` /
        ``float`` so the dict can be serialised (e.g. to JSON) directly.
        A ratio whose denominator is zero is reported as ``0.0``.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    actual = np.asarray(labels)

    # int(...) turns the NumPy scalars into native ints; everything derived
    # from them below is then a native float as well.
    tp = int(((predictions == 1) & (actual == 1)).sum())
    fp = int(((predictions == 1) & (actual == 0)).sum())
    fn = int(((predictions == 0) & (actual == 1)).sum())
    tn = int(((predictions == 0) & (actual == 0)).sum())
    total = tp + tn + fp + fn

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    # Guard the empty evaluation set, which previously produced nan plus a warning.
    accuracy = (tp + tn) / total if total > 0 else 0.0

    metrics = {"metric/precision": precision, "metric/recall": recall, "metric/f1": f1, "metric/accuracy": accuracy,
            "classification/tp": tp, "classification/fp": fp, "classification/fn": fn, "classification/tn": tn}
    metrics = {f"{prefix}/{key}": metric for key, metric in metrics.items()}
    # "augmentation/train": train_dataset.num_augmentations, "augmentation/val": val_dataset.num_augmentations
    return metrics

# Smoke-test compute_metrics on a tiny hand-checked example:
# argmax gives predictions [1, 0] against labels [0, 1], i.e. one false
# positive and one false negative.
logits = [[1.0, 2.5], [1.3, 0.7]]
labels = [0, 1]

result = compute_metrics((logits, labels))
assert result["eval/classification/fp"] == 1 and result["eval/classification/fn"] == 1
assert result["eval/classification/tp"] == 0 and result["eval/classification/tn"] == 0
assert result["eval/metric/accuracy"] == 0
print(result)

# my_callback = MyCallback()

In [None]:
# training_args = TrainingArguments(
#     output_dir=f'test',          # output directory
#     num_train_epochs=10,              # total number of training epochs
#     per_device_train_batch_size=8,  # batch size per device during training
#     per_device_eval_batch_size=8,   # batch size for evaluation
#     warmup_steps=100,                # number of warmup steps for learning rate scheduler
#     weight_decay=0.01,               # strength of weight decay
#     logging_dir='./logs',            # directory for storing logs
#     logging_steps=20,
#     eval_steps=200,
#     evaluation_strategy="steps",
#     report_to="none",
#     logging_first_step=True,
# )

# trainer = ImballancedLabelTrainer(
#     model=model,                         # the instantiated 🤗 Transformers model to be trained
#     args=training_args,                  # training arguments, defined above
#     train_dataset=train_dataset,         # training dataset
#     eval_dataset=val_dataset,            # evaluation dataset
#     compute_metrics=compute_metrics,
#     label_weight=_label_weight         
# )

# loader = trainer.get_train_dataloader()
# for x in loader:
#     x = {key: t.to(device) for key, t in x.items()}
#     print()
#     break
# trainer.compute_loss(model, x)

In [None]:
# device = torch.device("cuda")

# model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased").to(device)

In [None]:
# wandb_logging = True
# wandb.login()

# run_name = datetime.now().strftime('%d-%m-%Y_%H_%M_%S')
# print("Run name:", run_name)
# if wandb_logging:
#     wandb.init(entity="paper-judging", project="huggingface", name=run_name)

# training_args = TrainingArguments(
#     output_dir=f'results/{run_name}',          # output directory
#     num_train_epochs=5,              # total number of training epochs
#     per_device_train_batch_size=16,  # batch size per device during training
#     per_device_eval_batch_size=16,   # batch size for evaluation
#     warmup_steps=100,                # number of warmup steps for learning rate scheduler
#     weight_decay=0.01,               # strength of weight decay
#     logging_dir='./logs',            # directory for storing logs
#     logging_steps=20,
#     eval_steps=200,
#     evaluation_strategy="steps",
#     report_to="wandb" if wandb_logging else "none",
#     logging_first_step=True,
# )

# _label_weight = torch.from_numpy(label_weight).float().to(device)
# trainer = ImballancedLabelTrainer(
#     model=model,                         # the instantiated 🤗 Transformers model to be trained
#     args=training_args,                  # training arguments, defined above
#     train_dataset=train_dataset,         # training dataset
#     eval_dataset=val_dataset,            # evaluation dataset
#     compute_metrics=compute_metrics,
#     label_weight=_label_weight         
# )

# trainer.evaluate()
# trainer.train()

In [None]:
def augment_texts(_texts, aug_p, augmenter, seed=0, verbose=True):
    """Return a copy of ``_texts`` in which a random subset has been augmented.

    Each text is selected independently: a uniform sample in ``[0, 1)`` is
    drawn per text from ``np.random.RandomState(seed)`` and the text is
    augmented when that sample is ``<= aug_p``.

    Args:
        _texts: Sequence of strings. It is not modified.
        aug_p: Selection threshold per text.
        augmenter: Object with an ``augment`` method that takes a list of
            strings and returns the augmented strings in the same order.
        seed: Seed for the selection RNG.
        verbose: If true, print how many texts were selected.

    Returns:
        ``(texts, aug_idx)``: the new list of texts and the list of indices
        that were augmented.

    Raises:
        ValueError: If the augmenter returns a different number of texts than
            it was given.
    """
    # TODO: Maybe only encode the augmented texts
    # Work on a plain list: a NumPy string array has a fixed item width, so
    # writing an augmented text longer than the longest original into it would
    # silently truncate the text.
    texts = list(_texts)
    rnd = np.random.RandomState(seed)
    r = rnd.random_sample(len(texts))
    aug_idx = np.argwhere(r <= aug_p).flatten().tolist()
    if verbose:
        print("Augmenting:", len(aug_idx))
    if not aug_idx:
        # Nothing selected; don't invoke the augmenter on an empty batch.
        return texts, aug_idx

    texts_to_augment = [texts[i] for i in aug_idx]
    with torch.no_grad():
        augmented_texts = augmenter.augment(texts_to_augment)
    if isinstance(augmented_texts, str):
        # Tolerate augmenters that hand back a bare string for a single input.
        augmented_texts = [augmented_texts]
    if len(augmented_texts) != len(aug_idx):
        raise ValueError(
            f"augmenter returned {len(augmented_texts)} texts for {len(aug_idx)} inputs"
        )
    for i, new_text in zip(aug_idx, augmented_texts):
        texts[i] = new_text
    return texts, aug_idx

# back_translation_aug = naw.BackTranslationAug(
#     from_model_name='facebook/wmt19-en-de', 
#     to_model_name='facebook/wmt19-de-en',
#     device="cuda",
#     max_length=512
# )
# 
# augmented_texts = augment_texts(train_texts[:10], aug_p=0.5, augmenter=back_translation_aug)
# for text, augmented_text in zip(train_texts[:10], augmented_texts):
#     if text != augmented_text:
#         print(text)
#         print(augmented_text)
#         print()
# augmented_train_encodings = tokenizer(augmented_train_texts, truncation=True, padding=True)

In [None]:
# --- Training setup: model, data loaders, optimizer, loss, logging, optional augmenter ---

# Seed torch's global RNG so the shuffled batch order (and any other torch
# randomness in this cell and the training loop) is repeatable across runs.
seed = 0
torch.manual_seed(seed)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
model.to(device)

train_batch_size, val_batch_size = 16, 16
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
# Working copy of the training texts; the training loop replaces it with augmented versions.
_train_texts = copy(train_texts)

val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)

optim = AdamW(model.parameters(), lr=5e-5)
# Class-weighted loss; reduction="sum" means every reported loss is a per-batch sum.
_label_weight = torch.from_numpy(label_weight).float().to(device)
loss_func = torch.nn.CrossEntropyLoss(_label_weight, reduction="sum")

wandb_logging = False

run_name = datetime.now().strftime('%d-%m-%Y_%H_%M_%S')
print("Run name:", run_name)
if wandb_logging:
    wandb.login()
    wandb.init(entity="paper-judging", project="huggingface", name=run_name)

output_dir=f'results/{run_name}'
# exist_ok: re-running this cell within the same second must not crash.
os.makedirs(output_dir, exist_ok=True)
num_train_epochs = 5
logging_steps, eval_steps = 20, 200
# logging_steps, eval_steps = 5, 10
aug_p = 0 # Set to 0 to turn off augmentation
_steps = 0

if aug_p > 0:
    ## Very slow ca. 30 texts / min
    # augmenter = naw.BackTranslationAug(
    #     from_model_name='facebook/wmt19-en-de', 
    #     to_model_name='facebook/wmt19-de-en',
    #     device="cuda",
    #     batch_size=train_batch_size,
    #     max_length=512
    # )

    # WordNet resources needed by the synonym augmenter.
    import nltk
    nltk.download('averaged_perceptron_tagger')
    nltk.download('wordnet')

    augmenter = naw.SynonymAug(aug_src='wordnet', aug_min=10, aug_max=512, aug_p=0.2)

def run_model(_batch):
    """Run one batch through the global ``model`` and score it with ``loss_func``.

    Every tensor in ``_batch`` is moved to the global ``device``; the
    ``"labels"`` entry is split off as the target and the remaining entries are
    passed to the model as keyword arguments. ``_batch`` itself is not modified.

    Returns:
        ``(loss, logits)`` for the batch.
    """
    on_device = {}
    for name, tensor in _batch.items():
        on_device[name] = tensor.to(device)
    targets = on_device.pop("labels")

    logits = model(**on_device).logits
    return loss_func(logits, targets), logits

# Fine-tuning loop: one optimizer step per batch, training loss logged every
# `logging_steps` steps, full validation pass every `eval_steps` steps, and
# (when aug_p > 0) the training texts re-augmented between epochs.
for epoch in range(num_train_epochs):
    print(f"Epoch: {epoch + 1}/{num_train_epochs}")
    for batch in tqdm(train_loader):
        # Re-enter train mode on every step: the periodic evaluation below
        # switches the model to eval mode.
        model.train()
        optim.zero_grad()
        train_loss, _ = run_model(batch)

        train_loss.backward()
        optim.step()

        metrics = {}
        if _steps % logging_steps == 0:
            metrics.update({"train/loss": train_loss.item(), "steps": _steps})

        if _steps % eval_steps == 0:
            model.eval()
            val_losses = []
            logits = []
            with torch.no_grad():
                for val_batch in val_loader:
                    _val_loss, _logits = run_model(val_batch)
                    val_losses.append(_val_loss.item())
                    logits.extend(_logits.tolist())

            # Mean over validation batches of the per-batch loss returned by run_model.
            val_loss = sum(val_losses) / len(val_losses)
            _metrics = compute_metrics((logits, val_labels))
            metrics["eval/loss"] = val_loss
            metrics.update(_metrics)

        if metrics:
            print(metrics)
            if wandb_logging:
                wandb.log(metrics)

        _steps += 1

    # TODO: Augmentation after epoch
    # Skipped after the final epoch, since no further training would use the result.
    # (Was `num_epochs`, an undefined name that raised NameError whenever aug_p > 0.)
    if aug_p > 0 and epoch + 1 != num_train_epochs:
        print("Augmenting with p:", aug_p, datetime.now())
        _train_texts, augmented_idx = augment_texts(_train_texts, aug_p=aug_p, augmenter=augmenter, seed=epoch, verbose=False)
        print("#Augmented texts:", len(augmented_idx), datetime.now())
        print()
        # Re-tokenize the (partly augmented) texts and rebuild the training loader.
        _train_encodings = tokenizer(_train_texts, truncation=True, padding=True)
        _train_dataset = PaperDataset(_train_encodings, train_labels)
        train_loader = DataLoader(_train_dataset, batch_size=train_batch_size, shuffle=True)
        if wandb_logging:
            wandb.log({"augmentation/aug_p": aug_p, "augmentation/num_texts": len(augmented_idx), "epoch": epoch})

# Persist the trained model under this run's results directory.
snapshot_dir = f"results/{run_name}/network-snapshot-latest"
model.save_pretrained(snapshot_dir)

if wandb_logging:
    # Upload every file of the snapshot, with paths taken relative to the run directory,
    # then close the wandb run.
    wandb.save(
        f"{snapshot_dir}/*",
        base_path=os.path.dirname(snapshot_dir),
    )
    wandb.finish()

In [None]:
# wandb.finish()

In [None]:
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
# model.to(device)

In [None]:
# val_batch_size = 16
# val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)
# _label_weight = torch.from_numpy(label_weight).float().to(device)
# loss_func = torch.nn.CrossEntropyLoss(_label_weight, reduction="sum")

# model.eval()
# val_losses = []
# logits = []
# with torch.no_grad():
#     for val_batch in tqdm(val_loader):
#         inputs = {key: val.to(device) for key, val in val_batch.items()}
#         _labels = inputs.pop("labels")
#         _outputs = model(**inputs)

#         _logits = _outputs.logits
#         _val_loss = loss_func(_logits, _labels)
#         val_losses.append(_val_loss.item())
#         logits.extend(_logits.tolist())

# val_loss = sum(val_losses) / len(val_losses)
# metrics = compute_metrics((logits, val_labels))

In [None]:
# snapshot_dir = f"results/{run_name}/network-snapshot-latest"
# model.save_pretrained(snapshot_dir)

# if wandb_logging:
#     wandb.save(f"{snapshot_dir}/*", os.path.dirname(snapshot_dir))
#     wandb.finish()

In [None]:
# !pip install wandb
# import wandb
# wandb.login()
# wandb.finish()

In [None]:
# from transformers import AutoTokenizer, AutoModelWithLMHead

# t5_tokenizer = AutoTokenizer.from_pretrained("t5-base")

# t5_model = AutoModelWithLMHead.from_pretrained("t5-base").to(device)

In [None]:
# input = "That's great"
# input_enc = t5_tokenizer(input, truncation=True, padding=True, return_tensors="pt")
# output = t5_model(**input_enc.to(device))
# print(output)
# print(torch.softmax(output.logits, dim=1))

In [None]:
# Example paper abstract (about FreezeD, a GAN fine-tuning baseline) kept as a
# realistic input for trying out the classifier by hand.
paper_abstract = """Generative adversarial networks (GANs) have shown
outstanding performance on a wide range of problems in
computer vision, graphics, and machine learning, but often require numerous training data and heavy computational resources. To tackle this issue, several methods introduce a transfer learning technique in GAN training. They,
however, are either prone to overfitting or limited to learning small distribution shifts. In this paper, we show that
simple fine-tuning of GANs with frozen lower layers of
the discriminator performs surprisingly well. This simple
baseline, FreezeD, significantly outperforms previous techniques used in both unconditional and conditional GANs.
We demonstrate the consistent effect using StyleGAN and
SNGAN-projection architectures on several datasets of Animal Face, Anime Face, Oxford Flower, CUB-200-2011, and
Caltech-256 datasets. The code and results are available at
https://github.com/sangwoomo/FreezeD."""

In [None]:
# Try the classifier on a few hand-picked inputs.
# (Named `sample_texts` rather than `input` so the builtin input() is not shadowed.)
sample_texts = ["It's incredibly bad", paper_abstract, "hello", "darkness"]
input_enc = tokenizer(sample_texts, truncation=True, padding=True, return_tensors="pt")
# The training loop may leave the model in train mode; switch to eval mode so
# dropout is disabled, and skip gradient tracking since this is inference only.
model.eval()
with torch.no_grad():
    output = model(**input_enc.to(device))
print(output)
# print(torch.softmax(output.logits, dim=1))

In [None]:
# prediction = output.logits.argmax(dim=1)
# actual = prediction
# n = tp = fp = fn = tn = 0
# tp += (prediction == 1) & (actual == 1)
# fp += (prediction == 1) & (actual == 0)
# fn += (prediction == 0) & (actual == 1)
# tn += (prediction == 0) & (actual == 0)

In [None]:
# !python huggingface_train.py