In [1]:
import pandas as pd
import numpy as np
import random
import nltk
from nltk.corpus import wordnet
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import BertTokenizer, BertForMaskedLM
from transformers import RobertaTokenizer,RobertaForMaskedLM,RobertaForSequenceClassification
import argostranslate.package
import argostranslate.translate
import torch
from tqdm import tqdm

# Seed every RNG the augmenters draw from so runs are repeatable.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# NLTK resources used by nltk.pos_tag and the WordNet lookups below.
# NLTK >= 3.9 loads the English tagger from 'averaged_perceptron_tagger_eng';
# older releases use 'averaged_perceptron_tagger'. Requesting both is harmless
# because nltk.download reports (rather than raises on) an unknown id by default.
for _nltk_resource in ('averaged_perceptron_tagger', 'averaged_perceptron_tagger_eng', 'wordnet', 'omw-1.4'):
    nltk.download(_nltk_resource)

# Masked-LM used by bert_embedding_replacement.
bert_model_name = 'bert-base-uncased'
mlm_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
mlm_bert_model = BertForMaskedLM.from_pretrained(bert_model_name)
mlm_bert_model.eval()  # inference only; make the mode explicit


def _find_argos_package(packages, from_code, to_code):
    """Return the Argos Translate package translating from_code -> to_code.

    Raises:
        LookupError: if the package index has no such language pair (instead of
            the bare StopIteration that ``next(filter(...))`` would raise).
    """
    for package in packages:
        if package.from_code == from_code and package.to_code == to_code:
            return package
    raise LookupError(f"No Argos Translate package for {from_code} -> {to_code} in the package index")


# Install the en<->zh models used by argos_back_translation.
argostranslate.package.update_package_index()

available_packages = argostranslate.package.get_available_packages()

en_to_zh = _find_argos_package(available_packages, "en", "zh")
argostranslate.package.install_from_path(en_to_zh.download())

zh_to_en = _find_argos_package(available_packages, "zh", "en")
argostranslate.package.install_from_path(zh_to_en.download())

def get_synonyms(word, pos):
    """Return WordNet synonyms of *word* restricted to one part of speech.

    Args:
        word: Surface token to look up (``wordnet.synsets`` lowercases and
            lemmatizes it internally).
        pos: Penn Treebank tag such as ``'NN'`` or ``'JJ'``; mapped to a WordNet
            POS via :func:`get_wordnet_pos`. Tags with no WordNet POS yield an
            empty set.

    Returns:
        A set of lemma names with ``_`` turned into spaces, so multi-word lemmas
        read naturally once joined back into a sentence. *word* itself is
        excluded, compared case-insensitively.
    """
    wordnet_pos = get_wordnet_pos(pos)
    if wordnet_pos is None:
        return set()
    # WordNet files many adjectives as "satellite" synsets (pos 's'); they are
    # still adjectives, so accept them whenever an adjective was requested.
    if wordnet_pos == wordnet.ADJ:
        accepted_pos = {wordnet.ADJ, wordnet.ADJ_SAT}
    else:
        accepted_pos = {wordnet_pos}

    lowered = word.lower()
    synonyms = set()
    for synset in wordnet.synsets(word):
        if synset.pos() not in accepted_pos:
            continue
        for lemma in synset.lemmas():
            candidate = lemma.name().replace('_', ' ')
            # Drop the word itself (WordNet lemmas are usually lowercase).
            if candidate.lower() != lowered:
                synonyms.add(candidate)
    return synonyms

def get_wordnet_pos(treebank_tag):
    """Translate a Penn Treebank POS tag into a WordNet POS constant.

    Only the tag's first character matters: ``J`` -> adjective, ``V`` -> verb,
    ``N`` -> noun, ``R`` -> adverb. Any other tag, including ``''``, maps to
    ``None``.
    """
    prefix_to_wordnet = {
        'J': wordnet.ADJ,
        'V': wordnet.VERB,
        'N': wordnet.NOUN,
        'R': wordnet.ADV,
    }
    return prefix_to_wordnet.get(treebank_tag[:1])

def synonym_replacement(sentence, n=1):
    """Replace up to *n* distinct words of *sentence* with WordNet synonyms.

    Candidate words are visited in random order. Nouns and adjectives (Penn
    Treebank tags starting with ``N``/``J``) come before every other tag. When
    a word is replaced, all of its occurrences receive the same synonym.

    Args:
        sentence: Whitespace-tokenized text.
        n: Maximum number of distinct words to replace.

    Returns:
        The sentence re-joined with single spaces.
    """
    words = sentence.split()
    new_words = words.copy()
    # Tag the words in their original order so the tagger sees real context.
    pos_tags = nltk.pos_tag(words)

    preferred_prefixes = ('N', 'J')
    # Each distinct (word, tag) pair once, in a deterministic order before the
    # seeded shuffle (a set's order would depend on string hashing).
    candidates = list(dict.fromkeys(pos_tags))
    random.shuffle(candidates)
    # Stable sort: preferred tags first, shuffled order kept within each group.
    candidates.sort(key=lambda pair: not pair[1].startswith(preferred_prefixes))

    replaced_words = set()
    for word, tag in candidates:
        if len(replaced_words) >= n:
            break
        if word in replaced_words:
            continue
        synonyms = get_synonyms(word, tag)
        if synonyms:
            # sorted(): set iteration order is not controlled by random.seed.
            synonym = random.choice(sorted(synonyms))
            new_words = [synonym if token == word else token for token in new_words]
            replaced_words.add(word)

    return ' '.join(new_words)

def random_insertion(sentence, n=1, max_ratio=0.1):
    """Insert up to *n* synonym words at random positions in *sentence*.

    Args:
        sentence: Whitespace-tokenized text.
        n: Requested number of insertions.
        max_ratio: Caps insertions at this fraction of the sentence's word
            count (but always allows at least one), so short sentences are not
            swamped. The default 0.1 matches the previous hard-coded cap.

    Returns:
        The augmented sentence re-joined with single spaces, or ``''`` for an
        empty/whitespace-only sentence (there is no word to draw from).
    """
    words = sentence.split()
    if not words:
        return ''

    max_inserts = max(1, int(len(words) * max_ratio))
    new_words = words.copy()
    for _ in range(min(n, max_inserts)):
        add_word(new_words)

    return ' '.join(new_words)

def add_word(new_words):
    """Insert a noun synonym of a random word of *new_words*, in place.

    A random token is drawn from *new_words*. One of its WordNet noun synonyms
    (or the token itself, if it has none) is inserted at a random position.

    Args:
        new_words: Mutable list of tokens, modified in place. An empty list is
            left untouched.
    """
    if not new_words:
        return
    random_word = random.choice(new_words)
    # get_synonyms expects a Penn Treebank tag; 'NN' maps to WordNet nouns.
    # A lowercase 'n' maps to no POS, so no synonym would ever be found.
    synonyms = sorted(get_synonyms(random_word, 'NN'))  # sorted: reproducible under random.seed
    inserted = random.choice(synonyms) if synonyms else random_word
    # randint is inclusive on both ends, so len(new_words) allows appending.
    new_words.insert(random.randint(0, len(new_words)), inserted)

def bert_embedding_replacement(sentence, top_k=5, replace_rate=0.15):
    """Replace a random subset of words with BERT masked-LM predictions.

    Each chosen position is masked independently in the *original* sentence.
    One of the masked LM's top-``top_k`` predictions for that mask is picked at
    random. WordPiece continuations (``##...``) and the original word are
    skipped.

    Args:
        sentence: Whitespace-tokenized text.
        top_k: Number of highest-scoring vocabulary entries to sample from.
        replace_rate: Fraction of words to replace (at least one, at most all).

    Returns:
        The augmented sentence re-joined with single spaces, or ``''`` for an
        empty/whitespace-only sentence. A position keeps its original word when
        none of the top-k predictions is a usable replacement.
    """
    words = sentence.split()
    if not words:
        return ''

    num_words_to_replace = min(len(words), max(1, int(len(words) * replace_rate)))
    positions = random.sample(range(len(words)), num_words_to_replace)

    new_words = words.copy()
    for idx in positions:
        masked_words = words.copy()
        masked_words[idx] = mlm_tokenizer.mask_token
        inputs = mlm_tokenizer(' '.join(masked_words), return_tensors='pt')
        mask_positions = torch.where(inputs["input_ids"][0] == mlm_tokenizer.mask_token_id)[0]
        if mask_positions.numel() == 0:
            continue

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = mlm_bert_model(**inputs).logits
        top_ids = torch.topk(logits[0, mask_positions[0], :], top_k).indices.tolist()

        original = words[idx].lower()
        decoded = (mlm_tokenizer.decode([token_id]).strip() for token_id in top_ids)
        candidates = [
            token for token in decoded
            if token and not token.startswith('##') and token != original
        ]
        if candidates:
            new_words[idx] = random.choice(candidates)

    return ' '.join(new_words)

def tfidf_augmentation(sentences, augmentation_rate=0.05):
    """Swap each sentence's highest-TF-IDF words for random WordNet synonyms.

    A TF-IDF model is fitted on *sentences* as the corpus. In every sentence,
    the ``max(1, int(len(words) * augmentation_rate))`` distinct words with the
    highest TF-IDF weight are selected. Each occurrence of a selected word is
    replaced independently by a random WordNet lemma that differs from the
    word itself.

    Args:
        sentences: List of strings; also the corpus the vectorizer is fitted on.
        augmentation_rate: Fraction of each sentence's words to replace.

    Returns:
        A list of augmented sentences aligned with *sentences*. A sentence whose
        augmentation raises is returned unchanged, and a warning is emitted.
    """
    import warnings

    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform(sentences)
    # Term -> column index: O(1) lookup instead of scanning the feature names.
    vocabulary = vectorizer.vocabulary_

    def pick_synonym(word):
        """Return a random WordNet lemma for *word* other than *word*, else *word*."""
        lowered = word.lower()
        candidates = set()
        for synset in wordnet.synsets(word):
            for lemma in synset.lemmas():
                name = lemma.name().replace('_', ' ')
                if name.lower() != lowered:
                    candidates.add(name)
        # sorted(): set order depends on string hashing, not on random.seed.
        return random.choice(sorted(candidates)) if candidates else word

    def augment_sentence(row_index, sentence):
        words = sentence.split()
        num_words_to_augment = max(1, int(len(words) * augmentation_rate))
        # Reuse the row computed by fit_transform instead of re-transforming.
        row = tfidf_matrix[row_index]
        row_scores = dict(zip(row.indices, row.data))

        # The vectorizer lowercases its vocabulary, so look words up lowercased.
        word_scores = {}
        for word in words:
            column = vocabulary.get(word.lower())
            if column is not None and word not in word_scores:
                word_scores[word] = row_scores.get(column, 0.0)
        # Stable sort: ties keep first-occurrence order.
        ranked = sorted(word_scores, key=word_scores.get, reverse=True)
        words_to_augment = set(ranked[:num_words_to_augment])

        return ' '.join(pick_synonym(word) if word in words_to_augment else word for word in words)

    augmented = []
    for row_index, sentence in enumerate(sentences):
        try:
            augmented.append(augment_sentence(row_index, sentence))
        except Exception as exc:  # best-effort: keep the sentence, but say so
            warnings.warn(f"tfidf_augmentation left sentence {row_index} unchanged: {exc!r}")
            augmented.append(sentence)
    return augmented

def argos_back_translation(sentence):
    """Paraphrase *sentence* by round-tripping it English -> Chinese -> English.

    Uses the locally installed Argos Translate en->zh and zh->en models.
    """
    translate = argostranslate.translate.translate
    return translate(translate(sentence, 'en', 'zh'), 'zh', 'en')

# Fine-tuned 5-class RoBERTa classifier used to score augmented sentences.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Prefer the GPU but fall back to CPU so the script still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=5)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (recent PyTorch versions accept weights_only=True to restrict this).
model.load_state_dict(torch.load('Labeling_Roberta_model.bin', map_location=device))
model.to(device)
model.eval()  # inference only; make the mode explicit

def get_predictions(sentence, label):
    """Return the classifier's probability for class *label* on *sentence*.

    Args:
        sentence: Text to classify (truncated to the tokenizer's max length).
        label: Integer class index into the model's output logits.

    Returns:
        The softmax probability of *label*, as a Python float.
    """
    # Send inputs to wherever the model lives instead of hard-coding 'cuda'.
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True).to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)
    return probs[0, int(label)].item()
    
def _score_augmentations(sentences, labels, tfidf_augmented_sentences, file_name):
    """Apply every augmentation to every sentence and score each result.

    Returns one row per (sentence, augmentation):
    ``[group_id, original_sentence, augmented_sentence, label, aug_select, probability]``,
    where ``group_id`` is the sentence's row position in the input file.
    """
    # aug_select ids 0-3 in this order; id 4 is the precomputed TF-IDF output.
    augmenters = (
        synonym_replacement,
        argos_back_translation,
        random_insertion,
        bert_embedding_replacement,
    )
    rows = []
    sentence_pairs = tqdm(zip(sentences, labels), total=len(sentences),
                          desc=f"Enhancing sentences for {file_name}")
    for position, (sentence, label) in enumerate(sentence_pairs):
        candidates = [augment(sentence) for augment in augmenters]
        # TF-IDF augmentation was computed over the whole file; index by position.
        candidates.append(tfidf_augmented_sentences[position])
        for aug_id, augmented_sentence in enumerate(candidates):
            probability = get_predictions(augmented_sentence, label)
            rows.append([position, sentence, augmented_sentence, label, aug_id, probability])
    return rows


def _select_top_two(aug_df, file_name):
    """Return [sentence, label, best aug id, second-best aug id] for each sentence group."""
    best_augmentations = []
    groups = aug_df.groupby('group_id')
    for _, group in tqdm(groups, total=groups.ngroups,
                         desc=f"Selecting best augmentations for {file_name}"):
        # Stable sort: on probability ties the lower augmentation id wins.
        ranked = group.sort_values(by='probability', ascending=False, kind='stable')
        top, runner_up = ranked.iloc[0], ranked.iloc[1]
        best_augmentations.append([
            top['original_sentence'],
            top['label'],
            top['aug_select'],
            runner_up['aug_select'],
        ])
    return best_augmentations


def process_data(file_name, aug_select_file_name):
    """Rank the augmentations of every sentence in a CSV and save the top two.

    Reads *file_name* (needs ``sentence`` and ``label`` columns) and applies
    five augmentations to each sentence:

    - 0: synonym_replacement
    - 1: argos_back_translation
    - 2: random_insertion
    - 3: bert_embedding_replacement
    - 4: TF-IDF synonym replacement

    Each augmentation is scored with :func:`get_predictions`, i.e. the
    probability the classifier assigns to the sentence's label. One row per
    sentence is written to *aug_select_file_name*, with columns
    ``sentence, label, aug_select_01, aug_select_02`` (the best and
    second-best augmentation ids).
    """
    df = pd.read_csv(file_name)
    # Work positionally so the TF-IDF list lines up regardless of df's index.
    sentences = df['sentence'].tolist()
    labels = df['label'].tolist()
    tfidf_augmented_sentences = tfidf_augmentation(sentences)

    augmented_data = _score_augmentations(sentences, labels, tfidf_augmented_sentences, file_name)
    aug_df = pd.DataFrame(
        augmented_data,
        columns=['group_id', 'original_sentence', 'augmented_sentence', 'label', 'aug_select', 'probability'],
    )

    best_augmentations = _select_top_two(aug_df, file_name)
    final_df = pd.DataFrame(best_augmentations, columns=['sentence', 'label', 'aug_select_01', 'aug_select_02'])
    final_df.to_csv(aug_select_file_name, index=False, encoding='utf-8')

# Augment the test split and write the selected augmentation ids.
input_csv, output_csv = 'test.csv', 'test_aug_select.csv'
print('process test.csv...')
process_data(input_csv, output_csv)




[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\Administrator\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\Administrator\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\Administrator\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Be

process train.csv...


Enhancing sentences for train.csv: 100%|█████████████████████████████████████████| 8544/8544 [1:03:56<00:00,  2.23it/s]
Selecting best augmentations for train.csv: 100%|████████████████████████████████| 8544/8544 [00:01<00:00, 4590.87it/s]


process validation.csv...


Enhancing sentences for validation.csv: 100%|██████████████████████████████████████| 1101/1101 [08:14<00:00,  2.23it/s]
Selecting best augmentations for validation.csv: 100%|███████████████████████████| 1101/1101 [00:00<00:00, 4528.02it/s]


process test.csv...


Enhancing sentences for test.csv: 100%|████████████████████████████████████████████| 2210/2210 [16:24<00:00,  2.24it/s]
Selecting best augmentations for test.csv: 100%|█████████████████████████████████| 2210/2210 [00:00<00:00, 4565.68it/s]
