In [None]:
from datasets import load_dataset

from transformers import (
    AutoTokenizer,
    BertModel, BertConfig,
   )

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm

from peft import LoraConfig
import evaluate
import torch
import numpy as np

In [None]:
# Use the first CUDA GPU when torch reports one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


In [None]:

# Load the `evaluate` metric objects used by the rest of the notebook.
accuracy = evaluate.load("accuracy")
# NOTE: despite the variable name, this loads the "mse" metric; no square root is taken here.
rmse = evaluate.load("mse")

# Metric callback meant to be handed to a trainer later on.
def compute_metrics(p):
    """Score class predictions against reference labels.

    Args:
        p: A pair that unpacks into ``(predictions, labels)``. ``predictions``
            is reduced with ``np.argmax(..., axis=1)``, so it must have at
            least two dimensions (presumably per-class scores, one row per
            example -- confirm against the caller).

    Returns:
        A dict with the single key ``"accuracy"`` whose value is whatever
        ``accuracy.compute(...)`` returns.
        NOTE(review): if that call itself returns a dict, the result is
        nested (``{"accuracy": {...}}``) -- confirm callers expect that.
    """
    scores, references = p
    # Collapse the score axis to the index of the highest-scoring class.
    predicted_classes = np.argmax(scores, axis=1)
    accuracy_result = accuracy.compute(predictions=predicted_classes, references=references)
    return {"accuracy": accuracy_result}

In [None]:
class BertForSEQCLF(nn.Module):
    """Sequence-classification head: one linear layer over the first token's state."""

    def __init__(self, hidden_size, num_labels):
        """Create the head.

        Args:
            hidden_size: Number of input features of the linear classifier.
            num_labels: Number of output logits produced per example.
        """
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, sequence_output):
        """Return classification logits.

        Args:
            sequence_output: Tensor indexed as ``[:, 0]``; assumed to be
                ``(batch, seq_len, hidden_size)`` encoder output -- confirm
                against the caller.

        Returns:
            The classifier applied to position 0 along dimension 1.
        """
        # Position 0 along dim 1 is treated as the [CLS] token's hidden state.
        cls_state = sequence_output[:, 0]
        return self.classifier(cls_state)

class BertForTextSummarization(nn.Module):
    """Summarization head: a single linear projection applied to the encoder output."""

    def __init__(self, hidden_size):
        """Create the head.

        Args:
            hidden_size: Size of both the input and the output feature dimension.
        """
        super().__init__()
        # Placeholder decoder; a more sophisticated decoder may be preferable.
        self.decoder = nn.Linear(hidden_size, hidden_size)

    def forward(self, sequence_output):
        """Project ``sequence_output`` through the linear decoder and return the result."""
        projected = self.decoder(sequence_output)
        return projected

import torch.nn.functional as F

class BertForSTS(nn.Module):
    """Semantic-textual-similarity head: a linear layer to one score, limited to [0, 5]."""

    def __init__(self, hidden_size):
        """Create the head.

        Args:
            hidden_size: Number of input features of the scoring layer.
        """
        super().__init__()
        self.dense = nn.Linear(hidden_size, 1)

    def forward(self, pooled_output):
        """Return a similarity score per input row.

        The raw linear output is passed through ``relu(x) - relu(x - 5)``,
        a piecewise-linear stand-in for ``5 * sigmoid(x)`` built from two
        ReLUs: 0 for non-positive inputs, the input itself between 0 and 5,
        and 5 above that (up to floating-point rounding).

        Args:
            pooled_output: Features fed straight into the linear layer;
                presumably a pooled / [CLS] representation -- confirm
                against the caller.

        Returns:
            Tensor whose last dimension has size 1.
        """
        raw_score = self.dense(pooled_output)
        floor_part = F.relu(raw_score)
        # Subtracting the part above 5 caps the score at 5.
        overshoot = F.relu(raw_score - 5)
        return floor_part - overshoot

class BertForQuestionAnswering(nn.Module):
    """Extractive-QA head: a linear layer producing start and end logits per position."""

    def __init__(self, hidden_size):
        """Create the head.

        Args:
            hidden_size: Number of input features of the output layer.
        """
        super().__init__()
        # Two output channels: index 0 -> start logits, index 1 -> end logits.
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, sequence_output):
        """Compute span-boundary logits.

        Args:
            sequence_output: Features fed to the linear layer; assumed to be
                ``(batch, seq_len, hidden_size)`` -- confirm against the caller.

        Returns:
            Dict with keys ``'start_logits'`` and ``'end_logits'``; each is the
            corresponding output channel with the trailing size-1 dimension
            removed.
        """
        span_logits = self.qa_outputs(sequence_output)
        # Split the last dimension into its two channels and drop the singleton axis.
        start_logits, end_logits = (
            channel.squeeze(-1) for channel in span_logits.split(1, dim=-1)
        )
        return {'start_logits': start_logits, "end_logits": end_logits}
