In [1]:
import tensorflow as tf
import pandas as pd
from transformers import BertTokenizer, TFBertForSequenceClassification, AutoConfig
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, precision_score, recall_score,f1_score, accuracy_score
import numpy as np
import math
from tqdm import tqdm

In [2]:
model_name = "bert-base-uncased"
max_length_sequecnce = 256

In [3]:
bert_tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=True)
bert_model = TFBertForSequenceClassification.from_pretrained(model_name, num_labels=2)

All model checkpoint layers were used when initializing TFBertForSequenceClassification.

Some layers of TFBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
data_preprocess = pd.read_csv("multi_balanced_data_preprocess.csv")
data_bin = data_preprocess[["tweet", "Disorder"]]
data_bin = data_bin[data_bin["Disorder"] != "CONTROL"]
# diagnosed_group = data_bin[data_bin["Disorder"] != "CONTROL"]
# diagnosed_group["Disorder"] = "DIAGNOSED"
# control_group = data_bin[data_bin["Disorder"] == "CONTROL"]
# data_bin = pd.concat([diagnosed_group, control_group], axis=0)
# encode_target = {"DIAGNOSED": 1, "CONTROL": 0}
# data_bin["Disorder"] = data_bin["Disorder"].map(encode_target)
print(data_bin["Disorder"].value_counts())

Disorder
EATING DISORDER    5283
BIPOLAR            5186
SCHIZOPHRENIA      4959
PTSD               4936
AUTISM             4860
OCD                4843
ADHD               4806
DEPRESSION         4805
ANXIETY            4707
Name: count, dtype: int64


### GAN-BERT

In [8]:
import tensorflow as tf
from transformers import TFBertModel, BertTokenizer
from tensorflow.keras.layers import Dense, Input, Dropout
from tensorflow.keras.models import Model
import numpy as np

# Load the pre-trained BERT model and tokenizer
bert_model_name = 'bert-base-uncased'
bert_model = TFBertModel.from_pretrained(bert_model_name)
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

# Set the maximum sequence length
max_seq_length = 128

def create_generator(latent_dim, embedding_dim):
    # Define the generator model
    generator_input = Input(shape=(latent_dim,), dtype='float32')
    x = Dense(256, activation='relu')(generator_input)
    x = Dense(512, activation='relu')(x)
    x = Dense(embedding_dim, activation='linear')(x)
    generator = Model(generator_input, x, name='generator')
    return generator

def create_discriminator(input_shape):
    # Define the discriminator model
    discriminator_input = Input(shape=input_shape, dtype='float32')
    x = Dense(512, activation='relu')(discriminator_input)
    x = Dropout(0.5)(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(1, activation='sigmoid')(x)
    discriminator = Model(discriminator_input, x, name='discriminator')
    return discriminator

# Parameters
latent_dim = 100
embedding_dim = 768
# learning_rate 0.001
# Create generator and discriminator
generator = create_generator(latent_dim, embedding_dim)
discriminator = create_discriminator((embedding_dim,))

# Compile the discriminator
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Create the GAN model
discriminator.trainable = False
gan_input = Input(shape=(latent_dim,))
generated_embedding = generator(gan_input)
gan_output = discriminator(generated_embedding)
gan = Model(gan_input, gan_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')

# Function to encode text using BERT
def encode_texts(texts):
    input_ids = []
    attention_masks = []

    for text in texts:
        encoded = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_seq_length,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='tf',
        )
        input_ids.append(encoded['input_ids'])
        attention_masks.append(encoded['attention_mask'])

    input_ids = tf.concat(input_ids, axis=0)
    attention_masks = tf.concat(attention_masks, axis=0)
    return input_ids, attention_masks

# Function to get BERT embeddings
def get_bert_embeddings(texts):
    input_ids, attention_masks = encode_texts(texts)
    outputs = bert_model(input_ids, attention_mask=attention_masks)
    return outputs.last_hidden_state[:, 0, :]

# Training the GAN-BERT model
def train_gan_bert(texts, labels, epochs=10, batch_size=32):
    real_labels = np.ones((batch_size, 1))
    
    for epoch in range(epochs):
        idx = np.random.randint(0, len(texts), batch_size)
        real_texts = [texts[i] for i in idx]
        real_embeddings = get_bert_embeddings(real_texts)
        
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        generated_embeddings = generator.predict(noise)
        
        # Train the discriminator
        d_loss_real = discriminator.train_on_batch(real_embeddings, real_labels)
        d_loss_fake = discriminator.train_on_batch(generated_embeddings, real_labels)
        
        # Train the generator
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, real_labels)  # Use real_labels here
        
        print(f"{epoch+1}/{epochs} [D loss: {d_loss_real[0] + d_loss_fake[0]}] [G loss: {g_loss}]")




Some layers from the model checkpoint at bert-base-uncased were not used when initializing TFBertModel: ['nsp___cls', 'mlm___cls']
- This IS expected if you are initializing TFBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertModel were initialized from the model checkpoint at bert-base-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.


In [9]:
def classification_metrics(actuals: np.array, preds: np.array) -> pd.DataFrame:
    metrics_performance = pd.DataFrame()
    precision_0 = precision_score(actuals,preds, pos_label=0)
    precision_1 = precision_score(actuals,preds, pos_label=1)
    recall_0 = recall_score(actuals,preds,pos_label=0)
    recall_1 = recall_score(actuals,preds,pos_label=1)
    f1_0 = f1_score(actuals,preds,pos_label=0)
    f1_1 = f1_score(actuals,preds,pos_label=1)
    accuracy = accuracy_score(actuals,preds)
    metrics_performance["precision_0"] = [precision_0]
    metrics_performance["precision_1"] = [precision_1]
    metrics_performance["recall_0"] = [recall_0]
    metrics_performance["recall_1"] = [recall_1]
    metrics_performance["f1_0"] = [f1_0]
    metrics_performance["f1_1"] = [f1_1]
    metrics_performance["accuracy"] = [accuracy]
    return metrics_performance

In [21]:
def classification_metrics(actuals: np.array, preds: np.array, average:"binary") -> pd.DataFrame:
    metrics_performance = pd.DataFrame()
    precision = precision_score(actuals,preds,average=average)
    recall = recall_score(actuals,preds,average=average)
    f1 = f1_score(actuals,preds,average=average)
    accuracy = accuracy_score(actuals,preds)
    metrics_performance["precision"] = [precision]
    metrics_performance["recall"] = [recall]
    metrics_performance["f1"] = [f1]
    metrics_performance["accuracy"] = [accuracy]
    return metrics_performance

In [10]:
X = data_bin[["tweet"]]
y = data_bin[["Disorder"]]

In [11]:
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.1,random_state=42)

In [13]:
# Example usage
texts = X_train["tweet"].values.tolist()
labels = y_train["Disorder"].values.tolist()

train_gan_bert(texts, labels, epochs=10, batch_size=64)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


1/10 [D loss: 1.1940953731536865] [G loss: 0.220872700214386]
2/10 [D loss: 0.039488205686211586] [G loss: 0.006460283882915974]
3/10 [D loss: 0.001762085739756003] [G loss: 0.00010567634308245033]
4/10 [D loss: 0.0003586734101190814] [G loss: 4.347660706116585e-06]
5/10 [D loss: 3.56482897245769e-05] [G loss: 1.7233823257356562e-08]
6/10 [D loss: 7.173366027735106e-06] [G loss: 2.089856065978779e-09]
7/10 [D loss: 4.457669702591371e-06] [G loss: 4.467061387458671e-09]
8/10 [D loss: 7.710558898045946e-07] [G loss: 4.4053275567376704e-13]
9/10 [D loss: 3.7223866686690845e-07] [G loss: 6.021012410439147e-14]
10/10 [D loss: 4.90008976208974e-07] [G loss: 3.809942285349022e-17]


In [14]:
len(X_train["tweet"].values)/100

399.46

In [15]:
len(X_test["tweet"].values)

4439

In [16]:
train_preds = []
# start_point = 0
# end_point = 100
# for batch in tqdm(range(len(X_train["tweet"].values)//100)):
for train_text in X_train["tweet"].values.tolist()[:100]:
    train_embeddings = get_bert_embeddings(train_text)
    train_preds.append(np.ravel(discriminator.predict(train_embeddings[0:1]))[0])

    # print(np.squeeze(discriminator.predict(test_embeddings[0:1])))



In [17]:
test_preds = []
for test_text in X_test["tweet"].values[:100]:
    test_embeddings = get_bert_embeddings(test_text)
    test_preds.append(np.ravel(discriminator.predict(test_embeddings[0:1])))



In [18]:
train_preds = np.array(train_preds)
test_preds = np.array(test_preds)

In [None]:
train_pred = np.where(train_pred<0.6,0,1)
train_model_performance_lstm = classification_metrics(y_train,train_pred)

In [1]:
train_performance_table = classification_metrics(y_train[:100],train_preds)
test_performance_table = classification_metrics(y_test[:100],test_preds)