In [3]:
def convert_to_positive(score):
    """Linearly map ``score`` from [-1, 1] onto [0, 100], clamping values outside.

    Example mapping: -1 -> 0, 0 -> 50, 1 -> 100. Adjust the scale if the
    model's raw output range is different.
    """
    scaled = (score + 1) * 50
    capped_above = min(100, scaled)
    return max(0, capped_above)

def to_markdown(text):
    """Render a small subset of Markdown as HTML.

    Despite the name, the output is HTML:
      * newlines become ``<br>``;
      * ``**bold**`` becomes ``<strong>bold</strong>``;
      * ``*italic*`` becomes ``<em>italic</em>``.

    A marker pair only counts when the opening marker is not followed by
    whitespace and the closing marker is not preceded by whitespace. So
    arithmetic such as ``2 * 3 * 4`` is left untouched. Unpaired asterisks
    are kept literally instead of becoming unclosed tags.

    NOTE(review): the text is not HTML-escaped; escape untrusted input first.
    """
    import re

    html = text.replace("\n", "<br>")
    # Bold first, so that "**" is not consumed as two italic markers.
    html = re.sub(r"\*\*(?!\s)(.+?)(?<!\s)\*\*", r"<strong>\1</strong>", html)
    html = re.sub(r"\*(?!\s)(.+?)(?<!\s)\*", r"<em>\1</em>", html)
    return html


In [1]:
import torch
import pickle
import torch.nn as nn
from argparse import Namespace
from transformers import DebertaModel

# Module-level hyper-parameters read by the model definition and the
# input-preparation code below. hidden_size is expected to equal the DeBERTa
# backbone width: the model concatenates a hidden_size-wide content vector
# with the pooled encoder output.
config = Namespace(
    hidden_size=768,
    hidden_dropout_prob=0.05,
)

class BertForSequenceRegression(nn.Module):
    """DeBERTa encoder plus an MLP head that regresses ``num_marks`` scores.

    The encoder output is mean-pooled over the sequence axis and concatenated
    with an externally supplied ``content_vector``. The result is passed through
    Linear -> Dropout -> ReLU (twice) -> Linear -> Linear.

    Submodule attribute names are deliberately kept as-is: they are the
    ``state_dict`` keys of the saved checkpoints.

    Args:
        pretrained_model_name: Hugging Face model id or local path for the backbone.
        num_marks: number of regression outputs.
        hidden_dropout_prob: dropout probability used in the head. When ``None``,
            falls back to the module-level ``config.hidden_dropout_prob``.
    """

    def __init__(self, pretrained_model_name="microsoft/deberta-base", num_marks=2,
                 hidden_dropout_prob=None):
        super().__init__()
        self.num_marks = num_marks
        self.bert = DebertaModel.from_pretrained(pretrained_model_name)

        # forward() concatenates content_vector with the pooled encoder output and
        # feeds a 2*hidden_size vector into hidden_1. So the head width must match
        # the backbone: read it from the backbone instead of an unrelated global.
        hidden_size = self.bert.config.hidden_size
        if hidden_dropout_prob is None:
            hidden_dropout_prob = config.hidden_dropout_prob

        self.hidden_1 = nn.Linear(2 * hidden_size, 2 * hidden_size)
        self.notline_1 = nn.ReLU()
        self.dropout_1 = nn.Dropout(hidden_dropout_prob)
        self.hidden_2 = nn.Linear(2 * hidden_size, hidden_size)
        self.notline_2 = nn.ReLU()
        self.dropout_2 = nn.Dropout(hidden_dropout_prob)
        self.hidden = nn.Linear(hidden_size, 128)
        self.regres = nn.Linear(128, num_marks)

    def forward(self, input_ids, content_vector, token_type_ids=None, attention_mask=None, labels=None):
        """Return the regression scores, one column per mark.

        Args:
            input_ids: token ids forwarded to the backbone.
            content_vector: concatenated with the pooled encoder output along
                dim 1. Assumed to be (batch, hidden_size) — TODO confirm with caller.
            token_type_ids: forwarded to the backbone unchanged.
            attention_mask: forwarded to the backbone unchanged.
            labels: accepted for API compatibility; unused.

        Returns:
            Output of ``self.regres``: last dimension is ``num_marks``
            (presumably (batch, num_marks) for 2-D inputs).
        """
        encoded = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # NOTE(review): plain mean over the sequence axis, so padding positions are
        # included in the average. Kept unchanged because the saved checkpoints
        # were presumably trained with this pooling; a masked mean would shift
        # their predictions.
        pooled = encoded.mean(1)

        features = torch.cat((content_vector, pooled), dim=1)
        features = self.notline_1(self.dropout_1(self.hidden_1(features)))
        features = self.notline_2(self.dropout_2(self.hidden_2(features)))
        return self.regres(self.hidden(features))

# Inference device; checkpoints are remapped onto it when loaded.
device = torch.device("cpu")

# Each model is built on the pretrained "microsoft/deberta-base" backbone and
# then overwritten with a fine-tuned checkpoint. load_state_dict is strict by
# default, so every key of the module must be present in the file.
model_notadhd_debarta = BertForSequenceRegression().to(device)
model_notadhd_debarta.load_state_dict(
    torch.load("Models/Not ADHD/Debarta/deb_model_3.pt", map_location=device)
)

model_adhd_debarta = BertForSequenceRegression().to(device)
model_adhd_debarta.load_state_dict(
    torch.load("Models/ADHD/Debarta/adhd_deb_model_3.pt", map_location=device)
)

# Both models are used for inference only: switch off dropout.
model_notadhd_debarta.eval()
model_adhd_debarta.eval()

# SECURITY: pickle.load can execute arbitrary code; only open trusted files.
with open("Models/Not ADHD/Debarta/bert_tokenizer.pkl", "rb") as f:
    tokenizer = pickle.load(f)

summary = "Hello My name is taimour"
max_length = 260  # total sequence length, including [CLS], [SEP] and padding

tokens = tokenizer.tokenize(summary)
cls_token_id = tokenizer.convert_tokens_to_ids("[CLS]")
sep_token_id = tokenizer.convert_tokens_to_ids("[SEP]")
pad_token_id = tokenizer.convert_tokens_to_ids("[PAD]")

# Truncate the word pieces *before* wrapping them in [CLS]/[SEP]. This way
# [SEP] is still present when the summary is longer than max_length.
body_ids = tokenizer.convert_tokens_to_ids(tokens)[: max_length - 2]
token_index = [cls_token_id] + body_ids + [sep_token_id]
n_real_tokens = len(token_index)
token_index = token_index + [pad_token_id] * (max_length - n_real_tokens)

input_ids = torch.tensor([token_index], dtype=torch.long, device=device)
# 1 for real tokens (including [CLS]/[SEP]), 0 for padding.
attention_mask = torch.zeros_like(input_ids)
attention_mask[0, :n_real_tokens] = 1

# TODO(review): content_vector is random noise, not an embedding of real
# content. Replace it with the vector the model was trained with. The fixed
# seed at least makes the scores below reproducible across runs.
CONTENT_VECTOR_SEED = 42
content_generator = torch.Generator().manual_seed(CONTENT_VECTOR_SEED)
content_vector = torch.randn(1, config.hidden_size, generator=content_generator).to(device)

with torch.no_grad():
    scores = model_notadhd_debarta(
        input_ids=input_ids,
        content_vector=content_vector,
        attention_mask=attention_mask,
    )

# Row 0 is the single example; its two columns are (content, wording).
content_score, wording_score = scores[0].tolist()


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Rescale both raw regression outputs onto the clamped 0-100 display range.
positive_content_score, positive_wording_score = (
    convert_to_positive(content_score),
    convert_to_positive(wording_score),
)


In [5]:
# Display the rescaled (content, wording) pair as this cell's output.
(positive_content_score, positive_wording_score)


(20.8264023065567, 15.861183404922485)