In [None]:
from transformers import TrainingArguments, Trainer, AutoTokenizer
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from transformers import RobertaTokenizer, RobertaForSequenceClassification, RobertaModel,RobertaConfig
from sklearn.model_selection import StratifiedKFold
from lime.lime_text import LimeTextExplainer
from sklearn.metrics import classification_report
import warnings
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
from sklearn.metrics import precision_recall_fscore_support, matthews_corrcoef, balanced_accuracy_score


from SupCL_Seq import SupCsTrainer
import torch.nn as nn

warnings.filterwarnings('ignore')

# Pick exactly one screening dataset; the alternatives are kept for quick switching.
file_path = 'Hall_2012_cleaned.csv'
# file_path = 'Jeyaraman_2020_cleaned.csv'
# file_path = 'Radjenovic_2013_cleaned.csv'
# file_path = 'Smid_2020_cleaned.csv'

# Rows with any missing value are discarded before anything else sees the data.
df = pd.read_csv(file_path, delimiter=',').dropna(axis=0)
num_df = len(df)

print(f"Number of data: {num_df}")
class_counts_df = df['label_included'].value_counts()
print(class_counts_df)

In [None]:
def compute_metrics(eval_pred, average='binary'):
    """Compute classification metrics for a Hugging Face ``Trainer`` evaluation.

    Args:
        eval_pred: ``(logits, labels)`` pair as passed by ``Trainer``. ``logits`` may
            also be a tuple whose first element holds the class scores (models that
            return extra outputs).
        average: averaging mode forwarded to ``precision_recall_fscore_support``
            (``'binary'`` for the two-class task; ``'macro'``/``'micro'``/``'weighted'``
            for multi-class).

    Returns:
        dict with ``mcc``, ``balanced_accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)

    # zero_division=0 yields the same 0.0 the default would, without emitting a
    # warning when the model predicts no positive samples.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average=average, zero_division=0
    )
    mcc = matthews_corrcoef(labels, predictions)
    balanced_acc = balanced_accuracy_score(labels, predictions)

    return {
        'mcc': mcc,
        'balanced_accuracy': balanced_acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

In [None]:
from nltk.corpus import wordnet, words
import random
import pandas as pd

def get_synonyms(word):
    """Return the WordNet synonyms of ``word``, excluding ``word`` itself.

    Lemma names from every synset of ``word`` are collected; underscores in
    multi-word lemma names are replaced by spaces.

    Returns:
        A sorted, duplicate-free list (empty when WordNet has no synonyms). Sorting
        makes the order independent of Python's per-process string-hash
        randomisation, so ``random.choice`` over the result is reproducible under a
        fixed seed.
    """
    synonyms = {
        lemma.name().replace('_', ' ')
        for synset in wordnet.synsets(word)
        for lemma in synset.lemmas()
    }
    synonyms.discard(word)
    return sorted(synonyms)

def synonym_replacement(sentence, n):
    """Replace up to ``n`` distinct alphabetic words of ``sentence`` with WordNet synonyms.

    Every occurrence of a chosen word is replaced by the same randomly picked
    synonym. Words without synonyms are skipped and do not count towards ``n``.

    Args:
        sentence: whitespace-tokenised text.
        n: maximum number of distinct words to replace; ``n <= 0`` replaces nothing.

    Returns:
        The (possibly) augmented sentence with tokens joined by single spaces.
    """
    tokens = sentence.split()
    if n <= 0:
        return ' '.join(tokens)

    # sorted() instead of list(set(...)): the shuffle then depends only on the
    # `random` seed, not on string-hash randomisation.
    candidates = sorted({tok for tok in tokens if tok.isalpha()})
    random.shuffle(candidates)

    new_tokens = list(tokens)
    num_replaced = 0
    for target in candidates:
        synonyms = get_synonyms(target)
        if synonyms:
            replacement = random.choice(synonyms)
            new_tokens = [replacement if tok == target else tok for tok in new_tokens]
            num_replaced += 1
            if num_replaced >= n:
                break
    return ' '.join(new_tokens)

def random_insertion(sentence, n):
    """Insert up to ``n`` random synonyms at random positions in ``sentence``.

    Each round draws a word from the current (growing) token list. If that word
    has WordNet synonyms, one of them is inserted at a random position. Rounds
    whose word has no synonym insert nothing.

    Args:
        sentence: whitespace-tokenised text.
        n: number of insertion attempts.

    Returns:
        The augmented sentence with tokens joined by single spaces. Empty or
        whitespace-only input is returned as ``''`` instead of raising.
    """
    tokens = sentence.split()
    if not tokens:
        # random.choice([]) would raise IndexError.
        return ''
    for _ in range(n):
        synonyms = get_synonyms(random.choice(tokens))
        if synonyms:
            new_synonym = random.choice(synonyms)
            insert_position = random.randint(0, len(tokens))
            tokens.insert(insert_position, new_synonym)
    return ' '.join(tokens)

def random_deletion(sentence, p):
    """Drop each word of ``sentence`` independently with probability ``p``.

    Args:
        sentence: whitespace-tokenised text.
        p: deletion probability per word.

    Returns:
        The sentence unchanged when it has at most one word (including the empty
        string). Otherwise returns the surviving words joined by single spaces. If
        every word was dropped, one randomly chosen word is returned so the result
        is never empty.
    """
    tokens = sentence.split()
    if len(tokens) <= 1:
        # `<= 1` also covers '' where random.choice([]) below would raise.
        return sentence
    kept = [tok for tok in tokens if random.uniform(0, 1) > p]
    if not kept:
        return random.choice(tokens)
    return ' '.join(kept)

def augment_text(df, minority_class, augment_by, text_col='Corpus', label_col='label_included'):
    """Generate synthetic texts for one class using simple EDA-style perturbations.

    Draws ``int(n_minority * augment_by)`` texts, with replacement, from the rows whose
    ``label_col`` equals ``minority_class``. Each draw gets one randomly chosen
    perturbation: synonym replacement (1 word), random insertion (1 word) or random
    deletion (``p=0.25``).

    Args:
        df: source frame containing ``text_col`` and ``label_col``.
        minority_class: label value whose rows are augmented.
        augment_by: number of synthetic rows as a fraction of that class's size.
        text_col: name of the text column (default ``'Corpus'``).
        label_col: name of the label column (default ``'label_included'``).

    Returns:
        A new DataFrame with columns ``[text_col, label_col]`` containing only the
        synthetic rows; the original rows are not included.
    """
    # Build the candidate list once rather than on every iteration.
    minority_texts = df.loc[df[label_col] == minority_class, text_col].tolist()
    n_augmentations = int(len(minority_texts) * augment_by)

    augmented_texts = []
    for _ in range(n_augmentations):
        text = random.choice(minority_texts)
        augmentation_type = random.choice(['synonym_replacement', 'random_insertion', 'random_deletion'])
        if augmentation_type == 'synonym_replacement':
            text = synonym_replacement(text, n=1)
        elif augmentation_type == 'random_insertion':
            text = random_insertion(text, n=1)
        else:
            text = random_deletion(text, p=0.25)
        augmented_texts.append(text)

    augmented_df = pd.DataFrame(augmented_texts, columns=[text_col])
    augmented_df[label_col] = minority_class
    return augmented_df

# Oversample the minority class (label 1) with EDA-augmented abstracts amounting to
# 60% of its size, then shuffle. Python's `random` (used by the augmentation helpers)
# and the pandas shuffle are seeded so reruns draw the same samples.
random.seed(42)
df_augmented = augment_text(df, minority_class=1, augment_by=0.6)
df_sample = (
    pd.concat([df, df_augmented], ignore_index=True)
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

texts = df_sample['Corpus'].tolist()
labels = df_sample['label_included'].tolist()
class_counts = df_sample['label_included'].value_counts()
print(class_counts)


In [None]:
# Hold out the test split from the ORIGINAL (un-augmented) data, then augment only the
# training part. Splitting the augmented `df_sample` would place perturbed copies of
# test abstracts in the training set and inflate every evaluation metric.
train_df, test_df = train_test_split(
    df, test_size=0.2, stratify=df['label_included'], random_state=42
)
train_df = (
    pd.concat([train_df, augment_text(train_df, minority_class=1, augment_by=0.6)], ignore_index=True)
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

# Names read by the tokenisation, dataset and plotting cells.
train_texts, train_labels = train_df['Corpus'], train_df['label_included']
test_texts, test_labels = test_df['Corpus'], test_df['label_included']
val_texts, val_labels = test_texts, test_labels
train_x, train_y = train_texts, train_labels  # kept for backward compatibility


In [None]:
model_name = "roberta-base"

# Fast tokenizer; each split is padded to its own longest sequence and truncated.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

train_encodings, val_encodings = (
    tokenizer(list(split_texts), truncation=True, padding=True)
    for split_texts in (train_texts, test_texts)
)

In [None]:
def extract_embeddings(model, dataloader, device='cuda'):
    """Return the first-token (``<s>``, RoBERTa's [CLS]) embedding of every example.

    Args:
        model: encoder whose output exposes ``last_hidden_state`` (e.g. ``RobertaModel``).
            It is switched to eval mode and moved to ``device`` in place.
        dataloader: iterable of dict batches of tensors; a ``'labels'`` entry, if
            present, is not passed to the model.
        device: torch device to run on (default ``'cuda'``).

    Returns:
        ``np.ndarray`` stacking ``last_hidden_state[:, 0, :]`` of all batches, in
        dataloader order.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    import torch  # local import: torch is otherwise imported only in a later cell

    model.eval()
    model.to(device)

    embeddings = []
    with torch.no_grad():
        for batch in dataloader:
            inputs = {key: val.to(device) for key, val in batch.items() if key != 'labels'}
            outputs = model(**inputs)
            embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())

    if not embeddings:
        raise ValueError("dataloader yielded no batches; nothing to embed")
    return np.vstack(embeddings)

In [None]:
from transformers import TrainerCallback



class MetricsLogger(TrainerCallback):
    """Trainer callback that keeps in-memory histories of logged losses and metrics.

    Attributes (lists, appended in logging order):
        training_loss: values of ``logs['loss']``.
        validation_loss: values of ``logs['eval_loss']``.
        accuracy: values of ``logs['eval_accuracy']``; only populated when
            ``compute_metrics`` returns an ``'accuracy'`` key.
        balanced_accuracy: values of ``logs['eval_balanced_accuracy']``, logged when
            ``compute_metrics`` returns a ``'balanced_accuracy'`` key.
    """

    def __init__(self):
        super().__init__()
        self.training_loss = []
        self.validation_loss = []
        self.accuracy = []
        self.balanced_accuracy = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record any tracked keys present in ``logs`` (which may be ``None``)."""
        if not logs:
            return
        if 'loss' in logs:
            self.training_loss.append(logs['loss'])
        if 'eval_loss' in logs:
            self.validation_loss.append(logs['eval_loss'])
        if 'eval_accuracy' in logs:
            self.accuracy.append(logs['eval_accuracy'])
        if 'eval_balanced_accuracy' in logs:
            self.balanced_accuracy.append(logs['eval_balanced_accuracy'])

In [None]:
# Building the dataloader and extracting baseline embeddings needs `DataLoader`,
# `torch` and `train_dataset`, which only exist after the dataset cell below. Running
# that work here fails on a fresh kernel, so it is done once, after the datasets are
# built. This cell only records the number of classes.
num_labels = len(set(labels))

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with class labels.

    ``encodings`` maps a feature name to per-example sequences (e.g. a Hugging Face
    tokenizer output). Each item is a dict with one tensor per feature plus a
    ``'labels'`` tensor.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        # One tensor for all labels, so indexing yields a 0-d tensor directly.
        self.labels = torch.tensor(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {}
        for name, values in self.encodings.items():
            item[name] = torch.tensor(values[idx])
        item['labels'] = self.labels[idx]
        return item


# Wrap each tokenised split; label Series are converted to plain lists first.
train_dataset, val_dataset = (
    CustomDataset(split_encodings, split_labels.tolist())
    for split_encodings, split_labels in (
        (train_encodings, train_labels),
        (val_encodings, val_labels),
    )
)


In [None]:
# Pretrained encoder before contrastive training; its embeddings are the baseline.
model = RobertaModel.from_pretrained(model_name)
# shuffle=False keeps row i of the embedding matrix aligned with train_labels[i].
train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=32)
pre_train_embeddings = extract_embeddings(model, train_dataloader)

num_labels = len(set(labels))

In [None]:
# Training arguments for the supervised contrastive pre-training stage.
CL_args = TrainingArguments(
        output_dir = 'iterations_1/results',
        save_total_limit = 1,              # keep only the most recent checkpoint
        num_train_epochs=3,
        per_device_train_batch_size=10,
        evaluation_strategy = 'no',        # no evaluation during contrastive pre-training
        logging_strategy='epoch',          # log the training loss at the end of each epoch
        eval_steps = 500,                  # only used when evaluation_strategy='steps'; ignored here
        save_strategy='epoch',
        warmup_steps=50,
        learning_rate = 1e-05,
        report_to ='tensorboard',
        weight_decay=0.01,
        logging_dir='./logs',
    )
metrics_logger = MetricsLogger()

In [None]:
# Supervised contrastive trainer from SupCL_Seq wrapping the RoBERTa encoder `model`.
SupCL_trainer = SupCsTrainer.SupCsTrainer(
    model=model,
    args=CL_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # Contrastive hyper-parameters; their exact semantics are defined by SupCsTrainer.
    w_drop_out=[0.0, 0.05],
    temperature=0.05,
    def_drop_out=0.1,
    pooling_strategy='mean',
    # callbacks=[metrics_logger],
)

torch.cuda.reset_peak_memory_stats(device='cuda')

In [None]:
cs_baseline_dir = 'iterations_1/cs_baseline'

SupCL_trainer.train()
SupCL_trainer.save_model(cs_baseline_dir)  # checkpoint of the contrastively trained model

In [None]:
# Sequence classifier initialised from the contrastively trained checkpoint.
model = RobertaForSequenceClassification.from_pretrained('iterations_1/cs_baseline', num_labels=num_labels)
fine_tuned_base_model = model.roberta  # same module as model.base_model

# Freeze the encoder so that only the remaining (classification) parameters train.
for weight in fine_tuned_base_model.parameters():
    weight.requires_grad_(False)


In [None]:
# Same (unshuffled) dataloader as the baseline, now through the contrastively trained encoder.
post_train_embeddings = extract_embeddings(model=fine_tuned_base_model, dataloader=train_dataloader)

In [None]:

# Training arguments for fine-tuning the classifier on top of the frozen encoder.
args = TrainingArguments(
    output_dir='./results_1',
    logging_dir='./logs',
    report_to='tensorboard',
    # Schedule: evaluate, log and checkpoint once per epoch.
    num_train_epochs=10,
    evaluation_strategy='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=1,
    eval_steps=500,  # only consulted by 'steps' strategies
    # Optimisation.
    learning_rate=5e-03,
    weight_decay=0.01,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
)

In [None]:

# Fine-tune the classifier; metrics_logger keeps the per-epoch logged losses.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    callbacks=[metrics_logger],
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Project both embedding sets to 2-D with t-SNE and colour the points by class.
tsne = TSNE(n_components=2, random_state=42)
tsne_initial_embeddings = tsne.fit_transform(pre_train_embeddings)
tsne_final_embeddings = tsne.fit_transform(post_train_embeddings)

panels = [
    ('Initial Embeddings', tsne_initial_embeddings),
    ('Final Embeddings After Training', tsne_final_embeddings),
]
fig, axes = plt.subplots(1, 2, figsize=(16, 8))
for ax, (panel_title, points) in zip(axes, panels):
    scatter = ax.scatter(points[:, 0], points[:, 1], c=train_labels, cmap='coolwarm')
    fig.colorbar(scatter, ax=ax)
    ax.set_title(panel_title)

plt.show()


In [None]:
from sklearn.metrics import silhouette_score

# Class separation of the embeddings before vs. after contrastive training.
silhouette_by_stage = {
    'before': silhouette_score(pre_train_embeddings, train_labels),
    'after': silhouette_score(post_train_embeddings, train_labels),
}
pre_silhouette = silhouette_by_stage['before']
post_silhouette = silhouette_by_stage['after']

for stage, score in silhouette_by_stage.items():
    print(f"Silhouette score {stage} training: {score}")

In [None]:
# NOTE(review): `fold_metrics` is not created by any cell visible in this notebook;
# presumably it maps metric name -> list of per-fold values from a cross-validation
# loop. TODO: add that loop or this cell raises NameError on a fresh kernel.
REPORTED_METRICS = ('precision', 'recall', 'f1', 'mcc', 'balanced_accuracy')
std_dev_metrics = {name: np.std(fold_metrics[name]) for name in REPORTED_METRICS}

for metric, std_dev in std_dev_metrics.items():
    print(f"The standard deviation for {metric} is {std_dev:.4f}")

In [None]:
# Mean of every metric over folds. NOTE(review): depends on `fold_metrics`, which no
# visible cell defines — TODO confirm where it comes from.
aggregate_metrics = {}
for metric, per_fold_values in fold_metrics.items():
    aggregate_metrics[metric] = np.mean(per_fold_values)

print("Aggregate Metrics Across All Folds:")
for metric, value in aggregate_metrics.items():
    print(f"{metric}: {value:.4f}")

In [None]:
import matplotlib.pyplot as plt

import seaborn as sns


# Confusion matrix of the fine-tuned classifier on the held-out split. `conf_matrix`
# is computed here from `trainer`, so the cell does not depend on a variable that no
# other cell defines.
from sklearn.metrics import confusion_matrix

val_output = trainer.predict(val_dataset)
val_logits = val_output.predictions
if isinstance(val_logits, tuple):  # some models return extra outputs alongside the logits
    val_logits = val_logits[0]
val_predictions = np.argmax(val_logits, axis=-1)
# labels=[0, 1] fixes the row/column order to match the tick labels below.
conf_matrix = confusion_matrix(val_output.label_ids, val_predictions, labels=[0, 1])

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Negative', 'Positive'], yticklabels=['Negative', 'Positive'], ax=ax)
ax.set_title('Confusion Matrix')
ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')
plt.show()

In [None]:
import torch
import matplotlib.pyplot as plt
from lime.lime_text import LimeTextExplainer
from lime.lime_text import IndexedString,TextDomainMapper


In [None]:
def predict_proba(texts, batch_size=32):
    """Class probabilities for raw texts, in the form LIME's ``classifier_fn`` expects.

    Uses the notebook-level ``model`` and ``tokenizer``.

    Args:
        texts: sequence of strings (LIME passes perturbed copies of one document).
        batch_size: texts per forward pass. This bounds GPU memory when LIME
            requests many samples.

    Returns:
        ``np.ndarray`` of shape ``(len(texts), num_labels)`` with softmax probabilities.
    """
    # Disable dropout so repeated calls on the same text give identical probabilities.
    model.eval()
    if len(texts) == 0:
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512)
        inputs = {key: tensor.to(model.device) for key, tensor in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        chunks.append(torch.nn.functional.softmax(logits, dim=-1).cpu().numpy())
    return np.concatenate(chunks, axis=0)

In [None]:
from lime.lime_text import LimeTextExplainer

import numpy as np
# Release cached GPU memory left over from training before running the explainers.
torch.cuda.empty_cache()

lime_explainer = LimeTextExplainer(
    class_names=["Irrelevant", "Relevant"],  # names for class index 0 and 1
)

In [None]:
import scipy as sp
def f(x, max_length=500):
    """SHAP prediction wrapper: one-vs-rest logit of class 1 for each text in ``x``.

    Uses the notebook-level ``model`` and ``tokenizer``.

    Args:
        x: iterable of strings (SHAP passes masked variants of the explained texts).
        max_length: maximum number of tokens kept per text.

    Returns:
        1-D ``np.ndarray`` holding ``logit(p(class 1))`` per input text.
    """
    model.eval()  # no dropout while explaining
    encoded = tokenizer(
        [str(v) for v in x],
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoded = {key: tensor.to(model.device) for key, tensor in encoded.items()}
    with torch.no_grad():
        # Passing attention_mask keeps padding tokens from influencing the prediction.
        # Calling the model with input_ids alone attends to pad positions as if they
        # were real tokens.
        logits = model(**encoded).logits.cpu().numpy()
    # scipy's softmax is overflow-safe, unlike a raw np.exp over logits.
    scores = sp.special.softmax(logits, axis=-1)
    return sp.special.logit(scores[:, 1])

In [None]:
# this is for hall
# texts = ["towards logistic regression model fault prone code software project paper challenge logistic regression model able fault prone object class software project several study successful result design complexity metric purpose data exploration distribution metric varies project task project difficult first endeavor problem simple log transformation design complexity comparable project transformation project data spread data prediction model",
#        "uml design metric fault prone class java system identifying software problem implementation cheaper implementation hence fault proneness software module early software artifact software design beneficial software engineer early prediction fault motivation consideration composition usefulness uml design metric fault proneness java historical data significant industrial java system uml prediction model case study level detail message import sequence diagram significant predictor class fault proneness prediction model uml design metric better accuracy built code metric",
       
#        "nonstationary motor fault detection using recent quadratic time frequency representation use electric motor aerospace transportation industry condition time fault detection electric motor importance motor diagnostics nonstationary environment difficult sophisticated signal processing technique recent time plethora new time frequency distribution analysis nonstationary signal superior frequency resolution zhao atlas mark distribution distribution paper use new time frequency distribution nonstationary fault diagnostics electric motor common myth quadratic time frequency distribution suitable commercial implementation paper issue detail optimal discrete time implementation quadratic time frequency distribution time frequency representation digital signal processing platform method",
#        "determination bga structural defect joint defect ray laminography equipment software ray laminography recent year latest system feature small spl mu diameter location spl mu dimension ic package detail complex spaced structure microlaminographs resolution method microlaminography paper technology methodology result microlaminography system failure analysis ic packaging copper trace layer bga substrate bond wire short plane solder resist lot bga assembly subsequent verification destructive physical analysis dpa reconstruction individual solder ball assembly defect analysis normal ray dpa information"]

# # this is for Jeyaraman_2020_cleaned
# texts = ["clinical radiographical year long term outcome microfracture autologous chondrocyte implantation pair analysis purpose clinical radiographical long term outcome microfracture mfx first generation periosteum autologous chondrocyte implantation aci method subject knee joint aci microfracture post operative least year clinical pre post operative outcome numeric analog scale na pain lysholm tegner ikdc koos radiographical evaluation magnetic resonance mri assessment regenerate quality magnetic resonance observation cartilage repair tissue mocart knee osteoarthritis system mko relaxation time rt map microstructural cartilage analysis result mfx aci patient female male age year good long term outcome low pain score significant improved clinical score final lysholm functional na score higher mfx group lysholm mfx v aci na function mfx v aci mocart score qualitative difference kos analysis cartilage repair small defect significant better outcome relaxation time difference group region regenerate tissue conclusion study coherent statistical difference cartilage repair procedure mfx superior treatment small cartilage defect",
#         "novel minimally invasive technique cartilage repair human knee using arthroscopic microfracture injection mesenchymal stem cell hyaluronic acid prospective comparative study safety short term efficacy introduction current cell cartilage repair technique form scaffold separate surgical procedure novel scaffold le technique cartilage repair human knee arthroscopic microfracture outpatient intra articular injection autologous bone marrow mesenchymal stem cell msc hyaluronic acid ha material method seventy age sex lesion size knee symptomatic cartilage defect underwent cartilage repair technique open technique msc beneath periosteal patch defect prospective evaluation group international cartilage repair society icrs cartilage injury evaluation package question short form sf health survey international knee documentation committee ikdc subjective knee evaluation form lysholm knee scale tegner activity level scale postoperative magnetic resonance mri evaluation year patient significant adverse event course study final follow mean month significant improvement mean ikdc lysholm sf physical component score visual analogue pain score treatment group conclusion short term result novel technique comparable open procedure added advantage invasive single operation general anaesthesia safety efficacy ongoing trial key word chondral novel osteoarthritis regeneration",
#         "biological effect bone marrow concentrate knee pathology abstract population active incidence cost knee pain acute injury symptomatic knee osteoarthritis current treatment method short respect ability intra articular environment normal joint homeostasis basic science clinical evidence efficacy cell therapy bone marrow concentrate bmc promise nonsurgical joint treatment approach bmc inherent advantage treatment various knee pathology point care orthobiologic product delivers growth factor inflammatory protein mesenchymal stem cell evidence use bmc repair focal cartilage defect treatment generalized knee pain high quality study necessary clinical utility bmc particular attention patient selection aspiration processing reporting functional outcome",
#         "arthroscopic technique fixation three dimensional scaffold autologous chondrocyte transplantation structural property vitro model aim present study structural property matrix autologous chondrocyte implantation multiple fixation technique fresh porcine knee undergone load failure ultimate failure load yield load stiffness different technique fixation mm thick polymer fleece fixation biodegradable polylevolactide transosseous technique conventional suture fixation technique pin transosseous anchoring maximum load yield load higher group fixation group transosseous group conventional suture stiffness higher group group biomechanical data fixation technique fixation transosseous higher ultimate load yield load stiffness conventional suture technique time point data bioceed biotissue technology gmbh freiburg germany biomechanical data outstanding fixation strength arthroscopic technique bioceed matrix scaffold autologous chondrocyte transplantation thus arthroscopic fixation biomaterial patient turn research arthroscopic technique biomaterial"]
# # #this is for rajv
# texts = ["prediction fault proneness early phase object development complexity object software several metric chidamber kemerer metric object metric effectiveness viewpoint fault proneness object software evaluation metric design specification source code inner complexity class information algorithm class structure end design phase estimation fault proneness early phase effort fault newspaper publisher new method fault proneness object class early phase several complexity metric object software method checkpoint analysis design implementation phase fault prone applicable metric checkpoint",
#         "non invasive label free quantitative characterisation live cell monolayer culture paper development evanescent wave microscope predictive software quantitative study cellular process live cell quality high resolution use label technique current measurement capability live cell light microscopy critical data cell adhesion particular mouse neural stem cell tirm technique result tirm image high information content containing detail cell morphology cell adhesion combination time lapse provide information cell dynamic process motility",
#         "knowledge system faulty component detection production testing electronic device knowledge system faulty component detection identification production testing analog electronic board main part guided measuring probe diagnostic expert system result inductive machine learning technique diagnostic rule acquisition"]
# 10,01,01
# Abstracts explained by the LIME and SHAP cells below. This rebinds `texts`, replacing
# the corpus list built from df_sample earlier in the notebook.
# NOTE(review): the meaning of the "10,01,01" tag is not documented — TODO confirm
# (it may encode the expected labels of these samples).
texts = ["empirical approach software fault prediction measuring software quality term fault proneness data tomorrow programmer fault prone area project development faulty area previous developed project experienced professional development fault prone module person faulty area solution minimum time budget turn increase software quality client satisfaction fuzzy mean clustering technique prediction faulty non faulty module project datasets training module available nasa project cm pc jm requirement code metric combination metric model model result combination metric model best prediction model approach others literature accurate approach matlab",
         "conceptual coupling metric object oriented system coupling software maintainability metric predictor external software quality fault proneness impact analysis ripple effect change many measure object oo software specific dimension paper new set measure oo system conceptual coupling semantic information source code identifier comment case study open source software system new measure structural coupling case study conceptual coupling new dimension measure metric",
         "fault prediction capability program file logical coupling metric frequent change source file bug software metric source file paper approach set metric logical coupling source file metric historical data software change fixing post release bug propose set metric number bug capable bug prediction model experiment experimental result propose set metric number bug hence bug prediction model experiment accuracy bug predictor model"
    ]
# 01,10,10
# texts =["static analysis tool early indicator pre release defect density software development helpful early estimate defect density software component estimate fault prone area code empirical approach early prediction pre release density defect static analysis defect different static analysis tool actual pre release defect density window server strong positive correlation static analysis defect density pre release defect density pre release defect density actual pre release defect density high degree statistical significance discriminant analysis result static analysis tool high low quality component overall classification rate",
#        "fault cached history version history software system fault prone entity basic assumption fault isolation burst several related fault location likely fault location fault location location fault location location cache moment fault developer likely fault prone location useful verification validation resource fault prone file entity evaluation open source project revision cache selects source code file fault significant advance state art",
#        "fault content class challenge fault prediction today prediction possible low cost possible needing little data possible language average developer paper fault method summary available metric result sampled class fault content entire system method large software system java class line code evaluation fault generalization method good fault prone cluster possible value representative class"]
# #this is for smid
# texts =["data dependent prior mitigate small sample bias latent growth model mixed effect model mem latent growth model lgms interchangeable discipline specific nomenclature software implementation model interchangeable small sample size maximum likelihood estimation small sample bias mem lgms bayesian method dependent asymptotics issue choice factor covariance matrix prior distribution substantial influence small sample tutorial difference lgms mem data dependent prior established class method frequentist bayesian paradigm small sample bias prevalent lgm software additional programming bare minimum",
#        "country comparative survey analysis bayesian linear perspective meuleman billiet simulation study question many country accurate multilevel sem estimation comparative study author sample country accurate estimation bayesian estimation method structural equation much lower sample current study simulation meuleman billiet bayesian estimation lowest number country multilevel sem main result simulation sample country sufficient accurate bayesian estimation multilevel sem practicable number country available large scale comparative survey",
#        "structural equation interchangeable dyad structural equation sem straightforward fashion data interchangeable dyad dyad member author general strategy sem model estimation comparison fit assessment dyad level pairwise dyadic data application approach actor partner interdependence model confirmatory factor analysis latent growth curve analysis",
#        "performance method test upper level mediation presence nonnormal data monte carlo study statistical performance standard robust multilevel mediation analysis method indirect effect cluster experimental design various departure normality performance method upper level mediation process indirect effect effect group treatment person level outcome person level mediator method bias parametric percentile bootstrap empirical test best overall performance method nonnormal score distribution elevated type error rate poorer confidence interval coverage condition preliminary finding new mediation analysis method robust test indirect effect"]

# Explain the "Relevant" class (label 1) for each text and save one figure per text.
for i, text in enumerate(texts):
    print(f"Explanation for Text {i+1}:")
    exp = lime_explainer.explain_instance(text, predict_proba, num_features=20, labels=[1], num_samples=50)
    fig = exp.as_pyplot_figure()
    # Per-text filename: a fixed name would leave only the last explanation on disk.
    fig.savefig(f'lime_explanation_{i+1}.png', bbox_inches='tight')
    plt.show()
    plt.close(fig)
    # Separate name so the dataset frame `df` loaded at the top is not overwritten.
    weights_df = pd.DataFrame(exp.as_list(), columns=['Feature', 'Weight'])
    print(weights_df)

    exp.show_in_notebook(text=True)

In [None]:
import shap  # required here; no other cell imports it

# Text explainer over the wrapper `f`; passing the tokenizer makes SHAP build a text masker.
shap_explainer = shap.Explainer(f, tokenizer)

shap_values = shap_explainer(texts, fixed_context=1)

In [None]:

# Token-level highlight plot for each explained text.
for position, explanation in enumerate(shap_values, start=1):
    print(f"SHAP values for Text {position}:")
    shap.plots.text(explanation)

In [None]:
# Bar chart of the 40 largest token contributions for each explained text.
for position, explanation in enumerate(shap_values, start=1):
    print(f"SHAP values for Text {position}:")
    shap.plots.bar(explanation, max_display=40)
    plt.show()