In [1]:
!pip install peft

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from peft import LoraConfig, get_peft_model, TaskType
from transformers import BertTokenizer, BertConfig, BertModel, AdamW, get_constant_schedule_with_warmup
import pandas as pd
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torchmetrics.functional import f1_score, accuracy
from tqdm import tqdm
import pickle
import random
import re
import nltk
import subprocess
from nltk.corpus import stopwords
from nltk.tokenize import wordpunct_tokenize
from nltk.stem import WordNetLemmatizer
from nltk.tokenize.treebank import TreebankWordDetokenizer

In [3]:
!pip install gdown

In [4]:
import gdown

gdown.download("https://drive.google.com/file/d/1k5LMwmYF7PF-BzYQNE2ULBae79nbM268/view?usp=drive_link", "subtaskB_train.jsonl", quiet=False, fuzzy=True)
gdown.download("https://drive.google.com/file/d/1oh9c-d0fo3NtETNySmCNLUc6H1j4dSWE/view?usp=drive_link", "subtaskB_dev.jsonl", quiet=False, fuzzy=True)

### Parameters

In [2]:
batch_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
max_length = 128
epoch_nums = 3
lr = 1e-4
epsilon = 1e-8
splits = [0.01, 0.05, 0.1, 0.5]

train_path = '/kaggle/working/subtaskB_train.jsonl'
val_path = '/kaggle/working/subtaskB_dev.jsonl'

discriminator_save_path = 'discriminator_G2.pth'
generator_save_path = 'generator_G2.pth'
bert_save_path = 'bert_G2.pth'
report_path = 'report_GAN-BERT_G2.csv'

### Data Preprocessing

In [3]:
train_data = pd.read_json(train_path,lines=True)
val_data = pd.read_json(val_path, lines=True)

label_dict = {'chatGPT':0, 'human':1, 'cohere':2, 'davinci':3, 'bloomz':4, 'dolly':5}
label2int = lambda label: label_dict[label]

train_text = list(train_data['text'])
label_train = list(train_data['model'].apply(label2int))
text_val= list(val_data['text'])
label_val = list(val_data['model'].apply(label2int))

In [4]:
# Download and unzip wordnet
try:
    nltk.data.find('wordnet.zip')
except:
    nltk.download('wordnet', download_dir='/kaggle/working/')
    command = "unzip /kaggle/working/corpora/wordnet.zip -d /kaggle/working/corpora"
    subprocess.run(command.split())
    nltk.data.path.append('/kaggle/working/')

# Now you can import the NLTK resources as usual
from nltk.corpus import wordnet

# Download and unzip stopwords
try:
    nltk.data.find('stopwords.zip')
except:
    nltk.download('stopwords', download_dir='/kaggle/working/')
    command = "unzip /kaggle/working/corpora/stopwords.zip -d /kaggle/working/corpora"
    subprocess.run(command.split())
    nltk.data.path.append('/kaggle/working/')

# Now you can import the NLTK resources as usual
from nltk.corpus import stopwords

[nltk_data] Downloading package wordnet to /kaggle/working/...
Archive:  /kaggle/working/corpora/wordnet.zip
   creating: /kaggle/working/corpora/wordnet/
  inflating: /kaggle/working/corpora/wordnet/lexnames  
  inflating: /kaggle/working/corpora/wordnet/data.verb  
  inflating: /kaggle/working/corpora/wordnet/index.adv  
  inflating: /kaggle/working/corpora/wordnet/adv.exc  
  inflating: /kaggle/working/corpora/wordnet/index.verb  
  inflating: /kaggle/working/corpora/wordnet/cntlist.rev  
  inflating: /kaggle/working/corpora/wordnet/data.adj  
  inflating: /kaggle/working/corpora/wordnet/index.adj  
  inflating: /kaggle/working/corpora/wordnet/LICENSE  
  inflating: /kaggle/working/corpora/wordnet/citation.bib  
  inflating: /kaggle/working/corpora/wordnet/noun.exc  
  inflating: /kaggle/working/corpora/wordnet/verb.exc  
  inflating: /kaggle/working/corpora/wordnet/README  
  inflating: /kaggle/working/corpora/wordnet/index.sense  
  inflating: /kaggle/working/corpora/wordnet/data.

replace /kaggle/working/corpora/stopwords/dutch? [y]es, [n]o, [A]ll, [N]one, [r]ename:  NULL
(EOF or read error, treating as "[N]one" ...)


In [5]:
def preprocess_text(text, lemmatizer, stop_words):
    text = re.sub(r'[^\w\s]', '', text)
    text = text.lower()
    text = wordpunct_tokenize(text)
    text = [lemmatizer.lemmatize(token) for token in text]
    text = [lemmatizer.lemmatize(token, "v") for token in text]
    text = [re.sub(r'\b[0-9]+\b', '<NUM>', token) for token in text]
    text = [token for token in text if token != '<NUM>']
    text = [token for token in text if token not in stop_words]
    return TreebankWordDetokenizer().detokenize(text)

In [6]:
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

bow = torch.tensor([0], dtype = torch.long)
for text in tqdm(train_text):
    text = preprocess_text(text, lemmatizer, stop_words)
    tokens = tokenizer(text, max_length=max_length, add_special_tokens=False, truncation=False, padding=False, return_tensors='pt')
    bow = torch.cat((bow, tokens['input_ids'].view(-1)))
    
for text in tqdm(text_val):
    text = preprocess_text(text, lemmatizer, stop_words)
    tokens = tokenizer(text, max_length=max_length, add_special_tokens=False, truncation=False, padding=False, return_tensors='pt')
    bow = torch.cat((bow, tokens['input_ids'].view(-1)))

bow = bow[1:]    
with open('bow_list.pkl','wb') as f:
    pickle.dump(bow, f)


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

100%|██████████| 71027/71027 [18:11<00:00, 65.08it/s] 
100%|██████████| 3000/3000 [01:02<00:00, 47.85it/s]


In [None]:
with open('/kaggle/working/bow_list.pkl','rb') as f:
    bow = pickle.load(f)

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
splits = [0.01, 0.05, 0.1, 0.5]
train_datasets = []
for split in splits:
    labeled_text, unlabeled_text, label, _  = train_test_split(train_text,label_train,test_size=1-split)
    label = torch.tensor(label)
    tokenized_labeled_text = tokenizer(labeled_text, max_length=max_length, truncation=True, padding='max_length',return_tensors='pt')
    tokenized_unlabeled_text = tokenizer(unlabeled_text, max_length=max_length, truncation=True, padding='max_length',return_tensors='pt')
    rep_factor = int(np.log(len(unlabeled_text)/len(labeled_text)))
    if split == 0.5:
        rep_factor = 1
        
    tokenized_text = {'input_ids':torch.cat([tokenized_labeled_text['input_ids'].repeat((rep_factor,1)),tokenized_unlabeled_text['input_ids']],dim=0),
                      'attention_mask': torch.cat([tokenized_labeled_text['attention_mask'].repeat((rep_factor,1)),tokenized_unlabeled_text['attention_mask']],dim=0),
                      'label': torch.cat([label.repeat(rep_factor),-torch.ones(len(unlabeled_text))])}
    
    train_dataset = TensorDataset(tokenized_text['input_ids'],tokenized_text['attention_mask'], tokenized_text['label'].type(torch.int32))
    train_datasets.append(train_dataset)
    print(f"train dataset for split {split} added.")

with open('train_datasets.pkl','wb') as f:
     pickle.dump(train_datasets,f)

tokenized_text = tokenizer(text_val, max_length=max_length, truncation=True, padding='max_length',return_tensors='pt')
val_dataset = TensorDataset(tokenized_text['input_ids'], tokenized_text['attention_mask'], torch.tensor(label_val).type(torch.int32))
with open('val_dataset.pkl','wb') as f:
     pickle.dump(val_dataset,f)

In [4]:
with open('/kaggle/working/train_datasets.pkl','rb') as f:
    train_datasets = pickle.load(f)

with open('/kaggle/working/val_dataset.pkl','rb') as f:
    val_dataset = pickle.load(f)

In [5]:
trainLoaders = []
for train_dataset in train_datasets:
    trainLoaders.append(DataLoader(train_dataset,batch_size=batch_size,shuffle=True))

valLoader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

### Model

In [84]:
class Generator1(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(100,768), nn.LeakyReLU(), nn.Dropout(p=0.1), nn.Linear(768,768))

    def forward(self):
        epsilon = torch.randn(batch_size,100).cuda()
        return self.model(epsilon)

class Generator2(nn.Module):
    def __init__(self, bow, input_size):
        super().__init__()
        self.model = Bert()
        self.bow = bow
        self.input_size = input_size

    def forward(self):
        indices = torch.randint(0, len(self.bow), (batch_size,self.input_size))
        samples = self.bow[indices]
        return self.model(samples.long(), torch.ones((batch_size, self.input_size), dtype=torch.long))


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.feat = nn.Sequential(nn.Dropout(p=0.1), nn.Linear(768,768), nn.LeakyReLU(), nn.Dropout(p=0.1))
        self.logit = nn.Linear(768,7)

    def forward(self, x):
        feat = self.feat(x)
        logit = self.logit(feat)
        return feat, logit

class Bert(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = BertModel.from_pretrained('bert-base-uncased')
        lora_config = LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION, # this is necessary
        )

        # add LoRA adaptor
        self.model = get_peft_model(self.model, lora_config)

    def forward(self, input_ids, att_mask):
        return self.model(input_ids, att_mask)[0][:,0,:]

### Training and Validation

In [11]:
class GAN_Bert_Loss(nn.Module):
    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, logits_bert, targets, logits_gen, feat_bert, feat_G):
        # logits have shape B x 7
        # targets has shape B
        probs_G = torch.nn.functional.softmax(logits_gen,dim=-1)
        lossG_unsup = -torch.mean(torch.log(1-probs_G[:,-1]+epsilon))
        lossG_feat = torch.mean(torch.pow(feat_G.mean(dim=0) - feat_bert.mean(dim=0),2))

        logits_bert_labeled = logits_bert[:,0:-1] # has shape B x 6
        label_mask = targets != -1
        lossD_sup = 0
        if label_mask.sum()>0:
            lossD_sup = self.criterion(logits_bert[:,0:-1][label_mask], targets[label_mask])

        probs_bert = torch.nn.functional.softmax(logits_bert,dim=-1)
        probs_G_detached = probs_G.detach()
        lossD_unsup = -torch.mean(torch.log(1-probs_bert[:,-1]+epsilon)) - torch.mean(torch.log(probs_G_detached[:,-1]+epsilon))


        lossD = lossD_sup + lossD_unsup
        lossG = lossG_feat + lossG_unsup
        
        return lossD, lossG

### Training G1

In [12]:
def validation(bert, discriminator, valLoader):
    with torch.no_grad():
        bert.eval()
        discriminator.eval()
        all_prediction = []
        all_targets = []
        for i, batch in tqdm(enumerate(valLoader), total=len(valLoader), desc=f'Validation'):

            input_ids = batch[0].cuda()
            att_mask = batch[1].cuda()
            targets = batch[2].type(torch.long).cuda()

            y_bert = bert(input_ids, att_mask)
            feat_bert, logit_bert = discriminator(y_bert)

            preds = logit_bert[:,0:-1].max(dim=-1)[1]
            all_prediction.append(preds.cpu())
            all_targets.append(targets.cpu())


    return f1_score(preds, targets, 'multiclass', num_classes=6), accuracy(preds, targets, 'multiclass', num_classes=6)

### Training G2

In [None]:
f1s = []
accs = []

for split, trainLoader  in zip(splits,trainLoaders):
    generator = Generator1().cuda()
    discriminator = Discriminator().cuda()
    bert = Bert().cuda()

    bert = torch.nn.parallel.DataParallel(bert, device_ids=list(range(2)), dim=0)
    generator = torch.nn.parallel.DataParallel(generator, device_ids=list(range(2)), dim=0)
    discriminator = torch.nn.parallel.DataParallel(discriminator, device_ids=list(range(2)), dim=0)

    criterion = GAN_Bert_Loss().cuda()
    gen_optimizer = AdamW(list(generator.parameters()), lr=lr)
    dis_optimizer = AdamW(list(bert.parameters())+list(discriminator.parameters()), lr=lr)
    
    num_train_steps = int(len(trainLoader) * epoch_nums)
    num_warmup_steps = int(num_train_steps * 0.1)
    scheduler_d = get_constant_schedule_with_warmup(dis_optimizer, num_warmup_steps = num_warmup_steps)
    scheduler_g = get_constant_schedule_with_warmup(gen_optimizer, num_warmup_steps = num_warmup_steps)
    
    for epoch in range(epoch_nums):
        generator.train()
        discriminator.train()
        bert.train()
        loss_gen = 0.0
        loss_dis = 0.0
        
        for i, batch in tqdm(enumerate(trainLoader), total=len(trainLoader), desc=f'({split}) epoch {epoch}'):

            input_ids = batch[0].cuda()
            att_mask = batch[1].cuda()
            targets = batch[2].type(torch.long).cuda()

            y_bert = bert(input_ids, att_mask)
            y_gen = generator()
            feat_G, logit_G = discriminator(y_gen)
            feat_bert, logit_bert = discriminator(y_bert)

            lossD, lossG = criterion(logit_bert, targets, logit_G, feat_bert, feat_G)

            gen_optimizer.zero_grad()
            lossG.backward(retain_graph=True)
            gen_optimizer.step()

            dis_optimizer.zero_grad()
            lossD.backward()
            dis_optimizer.step()

            loss_gen += lossG.item()
            loss_dis += lossD.item()

            scheduler_d.step()
            scheduler_g.step()

        print(f'loss: {(loss_gen+loss_dis)/len(trainLoader)}, Generator Loss: {loss_gen / len(trainLoader)}, Discriminator Loss: {loss_dis/len(trainLoader)}')
        f1, acc = validation(bert, discriminator, valLoader)
        print(f'f1 score: {f1.item()}, accuracy: {acc.item()}')

        torch.save(discriminator.state_dict(), f'split_{split}_'+discriminator_save_path)
        torch.save(generator.state_dict(), f'split_{split}_'+generator_save_path)
        torch.save(bert.state_dict(), f'split_{split}_'+bert_save_path)

    f1s.append(f1.item())
    accs.append(acc.item())

report = pd.DataFrame({"splits": splits, "accuracies": accs, "f1 score": f1s})
report.to_csv(report_path)

In [None]:
f1s = []
accs = []

for split, trainLoader  in zip(splits,trainLoaders):
    generator = Generator2(bow, 128).cuda()
    discriminator = Discriminator().cuda()
    bert = Bert().cuda()

    bert = torch.nn.parallel.DataParallel(bert, device_ids=list(range(2)), dim=0)
    generator = torch.nn.parallel.DataParallel(generator, device_ids=list(range(2)), dim=0)
    discriminator = torch.nn.parallel.DataParallel(discriminator, device_ids=list(range(2)), dim=0)

    criterion = GAN_Bert_Loss().cuda()
    gen_optimizer = AdamW(list(generator.parameters()), lr=lr)
    dis_optimizer = AdamW(list(bert.parameters())+list(discriminator.parameters()), lr=lr)
    
    num_train_steps = int(len(trainLoader) * epoch_nums)
    num_warmup_steps = int(num_train_steps * 0.1)
    scheduler_d = get_constant_schedule_with_warmup(dis_optimizer, num_warmup_steps = num_warmup_steps)
    scheduler_g = get_constant_schedule_with_warmup(gen_optimizer, num_warmup_steps = num_warmup_steps)
    
    for epoch in range(epoch_nums):
        generator.train()
        discriminator.train()
        bert.train()
        loss_gen = 0.0
        loss_dis = 0.0
        
        for i, batch in tqdm(enumerate(trainLoader), total=len(trainLoader), desc=f'({split}) epoch {epoch}'):

            input_ids = batch[0].cuda()
            att_mask = batch[1].cuda()
            targets = batch[2].type(torch.long).cuda()

            y_bert = bert(input_ids, att_mask)
            y_gen = generator()
            feat_G, logit_G = discriminator(y_gen)
            feat_bert, logit_bert = discriminator(y_bert)

            lossD, lossG = criterion(logit_bert, targets, logit_G, feat_bert, feat_G)

            gen_optimizer.zero_grad()
            lossG.backward(retain_graph=True)
            gen_optimizer.step()

            dis_optimizer.zero_grad()
            lossD.backward()
            dis_optimizer.step()

            loss_gen += lossG.item()
            loss_dis += lossD.item()

            scheduler_d.step()
            scheduler_g.step()

        print(f'loss: {(loss_gen+loss_dis)/len(trainLoader)}, Generator Loss: {loss_gen / len(trainLoader)}, Discriminator Loss: {loss_dis/len(trainLoader)}')
        f1, acc = validation(bert, discriminator, valLoader)
        print(f'f1 score: {f1.item()}, accuracy: {acc.item()}')

        torch.save(discriminator.state_dict(), f'split_{split}_'+discriminator_save_path)
        torch.save(generator.state_dict(), f'split_{split}_'+generator_save_path)
        torch.save(bert.state_dict(), f'split_{split}_'+bert_save_path)

    f1s.append(f1.item())
    accs.append(acc.item())

report = pd.DataFrame({"splits": splits, "accuracies": accs, "f1 score": f1s})
report.to_csv(report_path)