In [None]:
!pip install torchmetrics
!pip install adapters peft

In [10]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertConfig, BertModel, AdamW, get_constant_schedule_with_warmup
import pandas as pd
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from adapters import AdapterCompositionBlock, AdapterConfig, UniPELTConfig, AutoAdapterModel
from torchmetrics.functional import f1_score, accuracy
from tqdm import tqdm
import pickle
from peft import LoraConfig, get_peft_model, TaskType

In [None]:
!pip install gdown

In [None]:
import gdown

gdown.download("https://drive.google.com/file/d/1k5LMwmYF7PF-BzYQNE2ULBae79nbM268/view?usp=drive_link", "subtaskB_train.jsonl", quiet=False, fuzzy=True)
gdown.download("https://drive.google.com/file/d/1oh9c-d0fo3NtETNySmCNLUc6H1j4dSWE/view?usp=drive_link", "subtaskB_dev.jsonl", quiet=False, fuzzy=True)

Downloading...
From (original): https://drive.google.com/uc?id=1k5LMwmYF7PF-BzYQNE2ULBae79nbM268
From (redirected): https://drive.google.com/uc?id=1k5LMwmYF7PF-BzYQNE2ULBae79nbM268&confirm=t&uuid=fa96b154-3ffa-47d4-8e7c-e4b6c4a864a6
To: /content/subtaskB_train.jsonl
100%|██████████| 155M/155M [00:05<00:00, 30.0MB/s]
Downloading...
From: https://drive.google.com/uc?id=1oh9c-d0fo3NtETNySmCNLUc6H1j4dSWE
To: /content/subtaskB_dev.jsonl
100%|██████████| 4.93M/4.93M [00:00<00:00, 94.2MB/s]


'subtaskB_dev.jsonl'

### Parameters

In [16]:
batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
max_length = 128
epoch_nums = 5
lr = 1e-4
epsilon = 1e-8
splits = [0.01, 0.05, 0.1, 0.5]

train_path = 'subtaskB_train.jsonl'
val_path = 'subtaskB_dev.jsonl'

discriminator_save_path = 'discriminator.pth'
bert_save_path = 'bert.pth'
report_path = 'report_Bert_adapter.csv'

### Data Preprocessing

In [None]:
train_data = pd.read_json(train_path,lines=True)
val_data = pd.read_json(val_path, lines=True)

label_dict = {'chatGPT':0, 'human':1, 'cohere':2, 'davinci':3, 'bloomz':4, 'dolly':5}
label2int = lambda label: label_dict[label]

train_text = list(train_data['text'])
label_train = list(train_data['model'].apply(label2int))
text_val= list(val_data['text'])
label_val = list(val_data['model'].apply(label2int))

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
splits = [0.01, 0.05, 0.1, 0.5]
train_datasets = []
for split in splits:
    labeled_text, _, label, _  = train_test_split(train_text,label_train,test_size=1-split)
    label = torch.LongTensor(label)
    tokenized_labeled_text = tokenizer(labeled_text, max_length=max_length, truncation=True, padding='max_length',return_tensors='pt')

    tokenized_text = {'input_ids':tokenized_labeled_text['input_ids'],
                      'attention_mask': tokenized_labeled_text['attention_mask'],
                      'label': label}

    train_dataset = TensorDataset(tokenized_text['input_ids'],tokenized_text['attention_mask'], tokenized_text['label'])
    train_datasets.append(train_dataset)
    print(f"train dataset for split {split} added.")

with open('train_datasets.pkl','wb') as f:
     pickle.dump(train_datasets,f)

tokenized_text = tokenizer(text_val, max_length=max_length, truncation=True, padding='max_length',return_tensors='pt')
val_dataset = TensorDataset(tokenized_text['input_ids'], tokenized_text['attention_mask'], torch.LongTensor(label_val))
with open('val_dataset.pkl','wb') as f:
     pickle.dump(val_dataset,f)

train dataset for split 0.01 added.
train dataset for split 0.05 added.
train dataset for split 0.1 added.
train dataset for split 0.5 added.


In [5]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
with open('/content/drive/MyDrive/train_datasets.pkl','rb') as f:
    train_datasets = pickle.load(f)

with open('/content/drive/MyDrive/val_dataset.pkl','rb') as f:
    val_dataset = pickle.load(f)

In [7]:
trainLoaders = []
for train_dataset in train_datasets:
    trainLoaders.append(DataLoader(train_dataset,batch_size=batch_size,shuffle=True))

valLoader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

### Model

In [8]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.feat = nn.Sequential(nn.Dropout(p=0.2), nn.Linear(768,768), nn.LeakyReLU(), nn.Dropout(p=0.2))
        self.logit = nn.Linear(768,6)

    def forward(self, x):
        feat = self.feat(x)
        logit = self.logit(feat)
        return feat, logit

class Bert(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = BertModel.from_pretrained('bert-base-uncased')
        lora_config = LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION, # this is necessary
        )

        # add LoRA adaptor
        self.model = get_peft_model(self.model, lora_config)

    def forward(self, input_ids, att_mask):
        return self.model(input_ids, att_mask)[0][:,0,:]

### Training and Validation

In [9]:
def validation(bert, discriminator, valLoader):
    with torch.no_grad():
        bert.eval()
        discriminator.eval()
        all_prediction = []
        all_targets = []
        for i, batch in tqdm(enumerate(valLoader), total=len(valLoader), desc=f'Validation'):

            input_ids = batch[0].cuda()
            att_mask = batch[1].cuda()
            targets = batch[2].cuda()

            y_bert = bert(input_ids, att_mask)
            logits = discriminator(y_bert)[1]

            preds = logits.max(dim=-1)[1]
            all_prediction.append(preds.cpu())
            all_targets.append(targets.cpu())


    return f1_score(preds, targets, 'multiclass', num_classes=6), accuracy(preds, targets, 'multiclass', num_classes=6)

In [17]:


#bert = torch.nn.parallel.DataParallel(bert, device_ids=list(range(2)), dim=0)
#generator = torch.nn.parallel.DataParallel(generator, device_ids=list(range(2)), dim=0)
#discriminator = torch.nn.parallel.DataParallel(discriminator, device_ids=list(range(2)), dim=0)



f1s = []
accs = []

for split, trainLoader in zip(splits,trainLoaders):
    discriminator = Discriminator().cuda()
    bert = Bert().cuda()

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = AdamW(list(discriminator.parameters()) + list(bert.parameters()), lr=lr)

    num_train_steps = int(len(trainLoader) * epoch_nums)
    num_warmup_steps = int(num_train_steps * 0.1)
    scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps = num_warmup_steps)

    for epoch in range(epoch_nums):
        discriminator.train()
        bert.train()

        current_loss = 0.0
        for i, batch in tqdm(enumerate(trainLoader), total=len(trainLoader), desc=f'({split}) epoch {epoch}'):

            input_ids = batch[0].cuda()
            att_mask = batch[1].cuda()
            targets = batch[2].cuda()

            y_bert = bert(input_ids, att_mask)
            logits = discriminator(y_bert)[1]

            loss = criterion(logits, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            current_loss += loss.item()

            scheduler.step()


        print(f'Loss: {current_loss / len(trainLoader)}')
        f1, acc = validation(bert, discriminator, valLoader)
        print(f'f1 score: {f1.item()}, accuracy: {acc.item()}')

        torch.save(discriminator.state_dict(), f'split_{split}_'+discriminator_save_path)
        torch.save(bert.state_dict(), f'split_{split}_'+bert_save_path)

    f1s.append(f1.item())
    accs.append(acc.item())

report = pd.DataFrame({"splits": splits, "accuracies": accs, "f1 score": f1s})
report.to_csv(report_path)

(0.01) epoch 0: 100%|██████████| 23/23 [00:15<00:00,  1.46it/s]


Loss: 1.789930779000987


Validation: 100%|██████████| 94/94 [00:40<00:00,  2.34it/s]


f1 score: 0.2083333283662796, accuracy: 0.2083333283662796


(0.01) epoch 1: 100%|██████████| 23/23 [00:14<00:00,  1.55it/s]


Loss: 1.7446655086849048


Validation: 100%|██████████| 94/94 [00:35<00:00,  2.67it/s]


f1 score: 0.375, accuracy: 0.375


(0.01) epoch 2: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Loss: 1.6917161630547566


Validation: 100%|██████████| 94/94 [00:33<00:00,  2.85it/s]


f1 score: 0.2916666567325592, accuracy: 0.2916666567325592


(0.01) epoch 3: 100%|██████████| 23/23 [00:16<00:00,  1.42it/s]


Loss: 1.650471127551535


Validation: 100%|██████████| 94/94 [00:37<00:00,  2.52it/s]


f1 score: 0.1666666716337204, accuracy: 0.1666666716337204


(0.01) epoch 4: 100%|██████████| 23/23 [00:14<00:00,  1.57it/s]


Loss: 1.5760386352953704


Validation: 100%|██████████| 94/94 [00:37<00:00,  2.54it/s]


f1 score: 0.3333333432674408, accuracy: 0.3333333432674408


(0.05) epoch 0: 100%|██████████| 111/111 [01:21<00:00,  1.35it/s]


Loss: 1.7523448939795967


Validation: 100%|██████████| 94/94 [00:29<00:00,  3.15it/s]


f1 score: 0.375, accuracy: 0.375


(0.05) epoch 1: 100%|██████████| 111/111 [01:21<00:00,  1.36it/s]


Loss: 1.4928501479260556


Validation: 100%|██████████| 94/94 [00:30<00:00,  3.04it/s]


f1 score: 0.4583333432674408, accuracy: 0.4583333432674408


(0.05) epoch 2: 100%|██████████| 111/111 [01:22<00:00,  1.35it/s]


Loss: 1.2009331788028683


Validation: 100%|██████████| 94/94 [00:29<00:00,  3.23it/s]


f1 score: 0.375, accuracy: 0.375


(0.05) epoch 3: 100%|██████████| 111/111 [01:20<00:00,  1.37it/s]


Loss: 1.0277780342746425


Validation: 100%|██████████| 94/94 [00:30<00:00,  3.11it/s]


f1 score: 0.3333333432674408, accuracy: 0.3333333432674408


(0.05) epoch 4: 100%|██████████| 111/111 [01:21<00:00,  1.36it/s]


Loss: 0.9049429641113625


Validation: 100%|██████████| 94/94 [00:30<00:00,  3.10it/s]


f1 score: 0.375, accuracy: 0.375


(0.1) epoch 0: 100%|██████████| 222/222 [02:44<00:00,  1.35it/s]


Loss: 1.6856546745643959


Validation: 100%|██████████| 94/94 [00:31<00:00,  2.96it/s]


f1 score: 0.3333333432674408, accuracy: 0.3333333432674408


(0.1) epoch 1: 100%|██████████| 222/222 [02:41<00:00,  1.38it/s]


Loss: 1.1684740636799786


Validation: 100%|██████████| 94/94 [00:28<00:00,  3.29it/s]


f1 score: 0.375, accuracy: 0.375


(0.1) epoch 2: 100%|██████████| 222/222 [02:44<00:00,  1.35it/s]


Loss: 0.9256818710683702


Validation: 100%|██████████| 94/94 [00:30<00:00,  3.13it/s]


f1 score: 0.3333333432674408, accuracy: 0.3333333432674408


(0.1) epoch 3: 100%|██████████| 222/222 [02:38<00:00,  1.40it/s]


Loss: 0.8022968336805567


Validation: 100%|██████████| 94/94 [00:29<00:00,  3.20it/s]


f1 score: 0.4166666567325592, accuracy: 0.4166666567325592


(0.1) epoch 4: 100%|██████████| 222/222 [02:40<00:00,  1.39it/s]


Loss: 0.7387058537554096


Validation: 100%|██████████| 94/94 [00:26<00:00,  3.55it/s]


f1 score: 0.4583333432674408, accuracy: 0.4583333432674408


(0.5) epoch 0: 100%|██████████| 1110/1110 [13:30<00:00,  1.37it/s]


Loss: 1.2501451225162625


Validation: 100%|██████████| 94/94 [00:31<00:00,  2.98it/s]


f1 score: 0.5, accuracy: 0.5


(0.5) epoch 1: 100%|██████████| 1110/1110 [13:33<00:00,  1.36it/s]


Loss: 0.711013347635398


Validation: 100%|██████████| 94/94 [00:28<00:00,  3.26it/s]


f1 score: 0.5833333134651184, accuracy: 0.5833333134651184


(0.5) epoch 2: 100%|██████████| 1110/1110 [13:24<00:00,  1.38it/s]


Loss: 0.6163750424309894


Validation: 100%|██████████| 94/94 [00:32<00:00,  2.94it/s]


f1 score: 0.5416666865348816, accuracy: 0.5416666865348816


(0.5) epoch 3: 100%|██████████| 1110/1110 [13:33<00:00,  1.36it/s]


Loss: 0.5702215953184677


Validation: 100%|██████████| 94/94 [00:24<00:00,  3.88it/s]


f1 score: 0.5833333134651184, accuracy: 0.5833333134651184


(0.5) epoch 4: 100%|██████████| 1110/1110 [13:34<00:00,  1.36it/s]


Loss: 0.5282231621645592


Validation: 100%|██████████| 94/94 [00:28<00:00,  3.33it/s]


f1 score: 0.4166666567325592, accuracy: 0.4166666567325592
