## Data preparation and settings

### Environment setup (run in Colab only; skip otherwise)

In [None]:
%%capture
# Install the third-party packages this notebook imports; %%capture hides pip's output.
# NOTE(review): adapter-transformers appears to ship its own `transformers` package -
# confirm that installing it next to `transformers` is intended.
!pip install pytorch-lightning
!pip install transformers
!pip install adapter-transformers
!pip install scikit-learn 

In [None]:
# Print the GPU status reported by the runtime.
!nvidia-smi

### Data Inspection

In [None]:
import json
import pandas as pd
import torch
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import Dataset

In [None]:
# toy dataReader for exploration 
class DataReader:
    '''Load a JSON Lines file into a pandas DataFrame for quick exploration.

    Attributes:
        raw: The lines of the file exactly as read (line endings included).
        raw_objects: One parsed JSON object per non-blank line.
        df: DataFrame built from ``raw_objects``.
    '''

    def __init__(self, json_name, shuffle=False):
        '''Read and parse ``json_name``.

        Args:
            json_name: Path to a file holding one JSON object per line.
            shuffle: Accepted for interface compatibility; currently unused.
        '''
        with open(json_name, 'r') as json_file:
            self.raw = list(json_file)
        # Blank lines (e.g. a trailing empty line) are skipped; json.loads
        # would otherwise raise on them.
        self.raw_objects = [json.loads(line) for line in self.raw if line.strip()]
        self.df = pd.DataFrame(self.raw_objects)

    def get_stats(self):
        '''Return the first rows of the DataFrame (``DataFrame.head()``).'''
        return self.df.head()

    def get_data(self):
        '''Return the ``text`` and ``intent`` columns as a pair of Series.'''
        return self.df['text'], self.df['intent']

In [None]:
# Mount Google Drive at /content/drive (Colab only); the dataset paths used in
# this notebook point into the Drive folder.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Locations of the ACL-ARC JSON Lines splits inside the Drive folder.
train_path = "./drive/MyDrive/DL project/project/acl-arc/train.jsonl"
# NOTE: the "validation" path points at the test split (test.jsonl).
val_path = "./drive/MyDrive/DL project/project/acl-arc/test.jsonl"
# Load each split into a DataFrame for inspection.
train_acl = DataReader(json_name=train_path).df
test_acl = DataReader(json_name=val_path).df

In [None]:
# Preview the first rows of the training split.
train_acl.head()

In [None]:
# Simple label view
# Sorted so the label order is identical on every run; iterating a set of
# strings does not guarantee a stable order across interpreter sessions.
labels = sorted(set(train_acl['intent']))
# Bar chart of the number of training examples per intent.
train_acl.groupby('intent').count()['text'].plot.bar()

### Dataset

In [None]:
class ACL_Dataset(Dataset):
    '''Torch dataset over a JSON Lines file of citation-intent examples.

    Every line of the file is a JSON object that must provide an ``intent``
    label and a ``cleaned_cite_text`` string. After construction ``self.data``
    is a DataFrame holding the parsed records plus one 0/1 column per label,
    and ``self.attribute`` is the ordered list of those label columns.

    Args:
        data_path: Path to the JSON Lines file.
        tokenizer: Object exposing a Hugging Face style ``encode_plus`` method.
        max_token_len: Length every example is truncated / padded to.
        attribute: Optional ordered list of labels to encode. Pass the same
            list to several datasets to give them identical label columns.
            Defaults to the sorted labels found in the file itself.
    '''

    def __init__(self, data_path, tokenizer, max_token_len=128, attribute=None):
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.attribute = list(attribute) if attribute is not None else None
        self.max_token_len = max_token_len
        self._prepare_data()

    def _prepare_data(self):
        '''
        Place to add other data preparations (sampling / train&test separation)

        Parses the file into ``self.data`` and appends the one-hot label columns.

        Raises:
            NotImplementedError: if ``data_path`` ends with "csv".
        '''
        if str(self.data_path).endswith("csv"):
            # Fail immediately: continuing would leave self.data undefined and
            # surface later as an unrelated AttributeError.
            raise NotImplementedError("Not yet implemented for csv")

        with open(self.data_path, 'r') as json_file:
            records = [json.loads(line) for line in json_file if line.strip()]
        self.data = pd.DataFrame(records)

        intents = self.data['intent']
        if self.attribute is None:
            # Derive the labels from this file, in a deterministic order, so
            # each column name always matches the values stored under it.
            self.attribute = sorted(intents.dropna().unique())

        # One float column per label; an intent that is not in self.attribute
        # gets an all-zero row.
        one_hot = pd.DataFrame(
            {label: (intents == label).astype(float) for label in self.attribute},
            index=self.data.index,
        )
        self.data = self.data.join(one_hot)

    def __len__(self):
        '''Return the number of examples.'''
        return len(self.data)

    def __getitem__(self, index):
        '''Return the tokenized text and one-hot label of example ``index``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (the tokenizer
            output flattened to 1-D) and ``labels`` (float32 tensor with one
            entry per label in ``self.attribute``).
        '''
        item = self.data.iloc[index]
        labels = torch.tensor([float(item[name]) for name in self.attribute], dtype=torch.float32)
        text = str(item.cleaned_cite_text)
        tokens = self.tokenizer.encode_plus(text, add_special_tokens=True,
                    return_tensors='pt', truncation=True, max_length=self.max_token_len,
                    padding="max_length", return_attention_mask=True)

        return {"input_ids": tokens.input_ids.flatten(), "attention_mask": tokens.attention_mask.flatten(), "labels": labels}

In [None]:
from transformers import AutoTokenizer
# Load the tokenizer for the roberta-base checkpoint and build a standalone
# training dataset with it (used by the single-example sanity check later on).
model_name = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
train = ACL_Dataset(train_path, tokenizer)


In [None]:
# train.__getitem__(0)

### Data module

In [None]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader

In [None]:
class ACL_DataLoader(pl.LightningDataModule):
    '''LightningDataModule serving the training and validation datasets.

    Args:
        train_path: JSON Lines file with the training examples.
        val_path: JSON Lines file with the validation examples.
        batch_size: Batch size used by every dataloader.
        max_token_length: Token length each example is truncated / padded to.
        model_name: Checkpoint name whose tokenizer is loaded.
    '''

    def __init__(self, train_path, val_path, batch_size: int = 32, max_token_length: int = 128, model_name="roberta-base"):
        super().__init__()
        self.train_path = train_path
        self.val_path = val_path
        self.batch_size = batch_size
        self.max_token_length = max_token_length
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def setup(self, stage=None):
        '''Build the datasets required for ``stage``.

        The paths and token length stored on the instance are used, so the
        module does not depend on notebook-level globals.
        '''
        if stage in (None, "fit"):
            self.train_dataset = ACL_Dataset(self.train_path, self.tokenizer, max_token_len=self.max_token_length)

        if stage in (None, "fit", "validate", "predict"):
            self.val_dataset = ACL_Dataset(self.val_path, self.tokenizer, max_token_len=self.max_token_length)
            self.attributes = self.val_dataset.attribute

    def train_dataloader(self):
        '''Return the training dataloader (reshuffled every epoch).'''
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=4, shuffle=True)

    def val_dataloader(self):
        '''Return the validation dataloader (fixed order).'''
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=4, shuffle=False)

    def predict_dataloader(self):
        '''Return the dataloader used for prediction: the validation set, fixed order.'''
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=4, shuffle=False)

In [None]:
# Build the data module with its default batch size, create its datasets and
# grab the training dataloader for inspection.
acl_datamodule = ACL_DataLoader(train_path=train_path, val_path=val_path)
acl_datamodule.setup()
acl_dataloader = acl_datamodule.train_dataloader()

In [None]:
# Number of batches in the training dataloader (not the number of samples).
len(acl_dataloader)

### Model

In [None]:
from transformers import AutoModel, AdamW, get_cosine_schedule_with_warmup
import torch.nn as nn
import math
from torchmetrics.functional.classification import f1_score, auroc

'''
TODO change model accordingly
'''
class ACL_Classifier(pl.LightningModule):
    '''Pre-trained encoder followed by a two-layer feed-forward classifier.

    Args:
        config: dict read by this class. Required keys: ``model_name``,
            ``n_labels``, ``lr``, ``weight_decay``, ``train_size`` (number of
            training samples), ``batch_size`` and ``warmup`` (fraction of all
            optimizer steps used for warm-up). Optional key: ``n_epochs``
            (defaults to 1) for the length of the learning-rate schedule.
    '''

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.pre_trained_model = AutoModel.from_pretrained(config['model_name'], return_dict=True)
        # TODO Adapters to be added when base is trained
        self.pre_trained_adapter = None
        hidden_size = self.pre_trained_model.config.hidden_size
        self.hidden = nn.Linear(hidden_size, hidden_size)
        self.classification = nn.Linear(hidden_size, self.config['n_labels'])

        # Initialization module -- xavier or some others?
        torch.nn.init.xavier_uniform_(self.hidden.weight)
        torch.nn.init.xavier_uniform_(self.classification.weight)

        # others
        self.relu = nn.ReLU()
        self.loss_func = nn.BCEWithLogitsLoss(reduction="mean")
        # Currently not applied in forward().
        self.dropout = nn.Dropout()

    def forward(self, input_ids, attention_mask, labels=None):
        '''Return ``(loss, logits)`` for a batch.

        Args:
            input_ids: Token ids, forwarded to the encoder.
            attention_mask: Mask forwarded to the encoder and used for pooling
                (assumed to be 1 for real tokens and 0 for padding - TODO confirm).
            labels: Optional targets reshaped to ``(-1, n_labels)``.

        Returns:
            ``loss`` is the BCE-with-logits loss, or ``0`` when ``labels`` is
            None; ``logits`` is the raw output of the classification layer.
        '''
        # pre_trained model output
        encoded = self.pre_trained_model(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoded.last_hidden_state
        # Mean pooling restricted to unmasked positions, so padded positions
        # do not dilute the sentence representation.
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        # FF classifier
        logits = self.classification(self.relu(self.hidden(pooled)))
        # loss (change accordingly with the type of loss function used)
        loss = 0
        if labels is not None:
            n_labels = self.config['n_labels']
            loss = self.loss_func(logits.view(-1, n_labels), labels.view(-1, n_labels))
        return loss, logits

    def training_step(self, batch, batch_index):
        '''Run the forward pass on ``batch`` and log ``train_loss``.'''
        loss, output = self(**batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": output, "labels": batch["labels"]}

    def validation_step(self, batch, batch_index):
        '''Run the forward pass on ``batch`` and log ``val_loss``.'''
        loss, output = self(**batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return {"val_loss": loss, "predictions": output, "labels": batch["labels"]}

    def predict_step(self, batch, batch_index):
        '''Return the raw logits for ``batch``.'''
        _, output = self(**batch)
        return output

    def configure_optimizers(self):
        '''Create AdamW and a cosine learning-rate schedule with warm-up.

        The schedule covers every optimizer step of the run
        (``ceil(train_size / batch_size) * n_epochs``) and is advanced once per
        step; Lightning's default of once per epoch would not match a length
        expressed in steps.
        '''
        optimizer = AdamW(self.parameters(), lr=self.config['lr'], weight_decay=self.config['weight_decay'], no_deprecation_warning=True, correct_bias=False)
        steps_per_epoch = math.ceil(self.config['train_size'] / self.config['batch_size'])
        total_steps = steps_per_epoch * self.config.get('n_epochs', 1)
        warmup_steps = math.floor(total_steps * self.config['warmup'])
        scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

In [None]:
config = {
    # Some randomly typed-in initial configs
    'model_name': 'roberta-base',
    'batch_size': 256,
    'lr': 1e-4,
    'warmup': 0.06,
    'weight_decay': 0.01,
    'n_epochs': 100,
    # Number of training samples (not batches): the model divides this value
    # by 'batch_size' to size its learning-rate schedule.
    'train_size': len(acl_datamodule.train_dataset),
    # Counted from the training data itself, so it does not depend on the
    # notebook-level `labels` variable.
    'n_labels': int(train_acl['intent'].nunique()),
}


In [None]:
'''Single output sanity check'''
model = ACL_Classifier(config=config)
idx = 0
# Fetch (and tokenize) the example once instead of once per field.
sample = train[idx]
input_ids = sample['input_ids']
attention_mask = sample['attention_mask']
# Separate name so the notebook-level `labels` list is not overwritten.
sample_labels = sample['labels']
# Inference only: switch off dropout and gradient tracking.
model.eval()
with torch.no_grad():
    loss, output = model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0), sample_labels.unsqueeze(0))
print("loss" + str(loss))
print("raw prediction: " + str(output))

print("label: " + str(sample_labels))
# The dataset item carries no 'text' entry; read the source text from the
# underlying DataFrame instead.
print(train.data.iloc[idx].cleaned_cite_text)

In [None]:
# Copy datamodule here - for convenience
acl_datamodule = ACL_DataLoader(train_path=train_path, val_path=val_path, batch_size=config['batch_size'])
acl_datamodule.setup()
model = ACL_Classifier(config=config)

# `accelerator`/`devices` replace the `gpus` argument, which recent
# pytorch-lightning releases no longer accept.
trainer = pl.Trainer(max_epochs=config['n_epochs'], accelerator="gpu", devices=1, num_sanity_val_steps=10)
trainer.fit(model, acl_datamodule)

In [None]:
# Load the TensorBoard notebook extension and open it on the ./lightning_logs/ directory.
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/