In [1]:
import time
import sys
from pyspark import SparkConf, SparkContext
import json
from pyspark.sql import SparkSession
import math 
import numpy as np
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from pyspark.sql.functions import col, pandas_udf, PandasUDFType, udf
from pyspark.sql.types import ArrayType, FloatType, DoubleType, IntegerType
from transformers import BertTokenizer
from pytorch_pretrained_bert import BertModel, BertForMaskedLM
from pytorch_pretrained_bert import BertConfig
from pyspark.ml.functions import predict_batch_udf

  from .autonotebook import tqdm as notebook_tqdm


## Model

In [2]:
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

In [3]:
class BertForSequenceClassification(nn.Module):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels`: the number of classes for the classifier. Default = 2.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token. (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_labels].
    Outputs:
        if `labels` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, num_labels].
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_labels = 2
    model = BertForSequenceClassification(config, num_labels)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, num_labels=[2,3]): # Change number of labels here.
        super(BertForSequenceClassification, self).__init__()
        self.num_labels = num_labels
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.bert_gate = nn.Sequential(
                    nn.Linear(1, config.hidden_size),
                    nn.ReLU(),
                    nn.Linear(config.hidden_size, config.hidden_size),
                    nn.Sigmoid(),
                )

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier0 = nn.Linear(config.hidden_size*2, num_labels[0])
        self.classifier1 = nn.Linear(config.hidden_size, num_labels[1])
        #self.fc1 = nn.Linear(config.hidden_size*2, 512)
        nn.init.xavier_normal_(self.bert_gate[0].weight)
        nn.init.xavier_normal_(self.bert_gate[2].weight)
        nn.init.xavier_normal_(self.classifier0.weight)
        nn.init.xavier_normal_(self.classifier1.weight)

    '''def forward_once(self, x):
        # Forward pass
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output'''

    def forward_once(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        #logits = self.classifier(pooled_output

        return pooled_output

    def forward(self, task, task_features, input_ids1, input_ids2):
        if task == 'fakenews':
            # forward pass of input 1
            print(task_features.shape)
            output1 = 2*self.bert_gate(task_features) * self.forward_once(input_ids1, token_type_ids=None, attention_mask=None, labels=None)
            # forward pass of input 2
            output2 = 2*self.bert_gate(task_features) * self.forward_once(input_ids2, token_type_ids=None, attention_mask=None, labels=None)

            out = torch.cat((output1, output2), 1)
            #print(out.shape)

            logits = self.classifier0(out)
        elif task == 'sentimental':
            # forward pass of input 1
            output1 = 2*self.bert_gate(task_features) * self.forward_once(input_ids1, token_type_ids=None, attention_mask=None, labels=None)
            
            #print(out.shape)
            logits = self.classifier1(output1)

        return logits

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True

## Test with Pyspark

##### RDD

In [4]:
spark = SparkSession.builder \
    .master('local[*]') \
    .config("spark.driver.memory", "24g") \
    .appName('my-cool-app') \
    .getOrCreate()

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
best_model_wts = 'bert_model_test_noFC1_triBERT_binary_focalloss.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net = BertForSequenceClassification()
# net = BERT_CLASSIFIER(5, 100)

state_dict = torch.load(best_model_wts)
bc_model_state = spark.sparkContext.broadcast(state_dict)

csv_file_path = "sentimental/sentimental_data.csv"
# Read the CSV file into a DataFrame
# df = spark.read.parquet(csv_file_path, header=True, inferSchema=True)
df = spark.read.csv(csv_file_path, header=True, inferSchema=True)

rdd = df.limit(10000).rdd
# print(rdd.take(5))
def get_model_for_eval():
    # Broadcast the model state_dict
    # Load the state dictionary into the model
    # net.load_state_dict(bc_model_state.value)
    net.to(device)
    net.eval()
    return net

def compute_prediction(data):
    # data = sc.parallelize(candidate_data)
    def preprocess_text(row):
        sentiment_model = get_model_for_eval()
        tokenized_text = tokenizer.encode_plus(
            row['text of the tweet'],
            add_special_tokens=True,
            max_length=128,
            return_token_type_ids=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True
        )
        prediction = sentiment_model(task = 'sentimental', task_features = torch.tensor([[1.]]).to(device) , input_ids1 = tokenized_text['input_ids'].to(device), input_ids2 = None)
        prediction = torch.argmax(prediction, dim = -1).cpu().detach().item()
        id_ = row['id of the tweet']
        return (id_, (prediction))
     
    data_predict = data.map(preprocess_text)

    return data_predict 
start_time = time.time()
data_predict = compute_prediction(rdd)
output = data_predict.collect()
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/11/15 18:14:44 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/11/15 18:14:45 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


cuda:0


[Stage 3:>                                                          (0 + 1) / 1]

Elapsed time: 196.22374486923218 seconds


                                                                                

##### UDF

In [None]:
spark = SparkSession.builder \
    .master('local[*]') \
    .config("spark.driver.memory", "24g") \
    .appName('my-cool-app') \
    .getOrCreate()

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
best_model_wts = 'bert_model_test_noFC1_triBERT_binary_focalloss.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net = BertForSequenceClassification()
# net = BERT_CLASSIFIER(5, 100)

state_dict = torch.load(best_model_wts)
bc_model_state = spark.sparkContext.broadcast(state_dict)


# kk = ['I would like', 'I love u', 'Hi you are nice']
# df = spark.sparkContext.parallelize([[kk[j]] for j in range(3)]).toDF()

csv_file_path = "sentimental/sentimental_data.csv"
# Read the CSV file into a DataFrame
# df = spark.read.parquet(csv_file_path, header=True, inferSchema=True)
df = spark.read.csv(csv_file_path, header=True, inferSchema=True)
print(df.count())
def get_model_for_eval():
  # Broadcast the model state_dict
  # Load the state dictionary into the model
  net.load_state_dict(bc_model_state.value)
  net.to(device)
  net.eval()
  return net

def one_row_predict(x):
    model = get_model_for_eval()
    tokenized_text = tokenizer.encode_plus(
            x,
            add_special_tokens=True,
            max_length=128,
            return_token_type_ids=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True
        )
    prediction = model(task = 'sentimental', task_features = torch.tensor([[1.]]).to(device) , input_ids1 = tokenized_text['input_ids'].to(device), input_ids2 = None)
    prediction = torch.argmax(prediction, dim = -1).cpu().detach().item()
    return prediction
start_time = time.time()
one_row_udf = udf(one_row_predict, IntegerType())
df = df.withColumn('pred_one_row', one_row_udf(col('text of the tweet')))

# df.write.csv("predictions.csv", header=True, mode='overwrite')
df.show(10)
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/11/15 17:37:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/11/15 17:37:54 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


cuda:0


KeyboardInterrupt: 