<a href="https://colab.research.google.com/github/rohitdutta2510/Claim-Span-identification-using-LLMs/blob/main/Claims_QnA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so that files under /content/drive (the dataset CSVs and
# saved model directories referenced later in this notebook) are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import json
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AdamW, get_linear_schedule_with_warmup, pipeline
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np
import ast
from tqdm import tqdm

In [None]:
# Dataset CSV locations on the mounted Drive; each is read by getDataset().
# Alternative training file, currently disabled:
# TRAIN_PATH = '/content/drive/MyDrive/MTP/Dataset/kgp_train.csv'
TRAIN_PATH = '/content/drive/MyDrive/MTP/Dataset/gen_combined.csv'
TEST_PATH = '/content/drive/MyDrive/MTP/Dataset/kgp_test.csv'
VAL_PATH = '/content/drive/MyDrive/MTP/Dataset/kgp_dev.csv'

In [None]:
def transformTokenIndexToStringIndex(row):
    """Convert the first span of a row from token indices to character indices.

    ``row`` must support ``row[key]`` for three string fields:

    * ``span_start_index`` / ``span_end_index`` - JSON lists of token indices;
      only element 0 of each list is used.
    * ``tokens`` - a Python-literal list of token strings.

    Character offsets are measured in the string obtained by joining the
    tokens with single spaces.

    Returns ``(start, end)`` where ``start`` is the offset of the first
    character of the start token and ``end`` is the offset of the last
    character of the end token (inclusive). Either value is -1 when the
    corresponding token index does not occur in ``tokens``.
    """
    first_token = json.loads(row['span_start_index'])[0]
    last_token = json.loads(row['span_end_index'])[0]
    words = ast.literal_eval(row['tokens'])

    char_start, char_end = -1, -1
    offset = 0  # character offset at which the current token begins
    for position, word in enumerate(words):
        if position == first_token:
            char_start = offset
        if position == last_token:
            char_end = offset + len(word) - 1
        # Step over this token plus the single space that follows it.
        offset += len(word) + 1
    return char_start, char_end

In [None]:
def getDataset(filepath):
    """Load a claim-span CSV and attach character-level span boundaries.

    Reads ``filepath`` with pandas and adds three columns:

    * ``context``     - the row's ``tokens`` list joined with single spaces,
    * ``start_index`` - first value returned by
      ``transformTokenIndexToStringIndex`` for the row,
    * ``end_index``   - second value returned by
      ``transformTokenIndexToStringIndex`` for the row.

    Rows whose ``end_index`` is -1 are dropped.

    Returns the filtered DataFrame.
    """
    data = pd.read_csv(filepath)
    data['context'] = data['tokens'].apply(lambda token: ' '.join(ast.literal_eval(token)))

    # Convert each row once and split the (start, end) pair afterwards. The
    # conversion re-parses the row's serialized columns, so calling it once
    # per output column doubled that work for no benefit.
    spans = data.apply(transformTokenIndexToStringIndex, axis=1)
    data['start_index'] = spans.apply(lambda span: span[0])
    data['end_index'] = spans.apply(lambda span: span[1])

    data = data[data['end_index'] != -1]  # removing datapoints having end_index = -1

    return data

In [None]:
# getDataset(TRAIN_PATH).iloc[1]

In [None]:
# Training configuration
batch_size, num_epochs, learning_rate = 32, 10, 2e-6

# Use the GPU when torch reports one as available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
device = DEVICE

# Name of the question-answering checkpoint loaded by the cells below.
model_name = "deepset/roberta-base-squad2"

In [None]:
def custom_metrics(pred_starts, pred_ends, true_starts, true_ends):
    """Mean span-overlap precision, recall and F1 over a set of examples.

    The four arguments are index-aligned sequences of inclusive span
    boundaries: example ``i`` has predicted span
    ``[pred_starts[i], pred_ends[i]]`` and reference span
    ``[true_starts[i], true_ends[i]]``.

    Per example:

    * overlap   = number of positions shared by both spans (0 if disjoint),
    * precision = overlap / predicted span length (0 if that length is <= 0,
      i.e. the predicted end precedes the predicted start),
    * recall    = overlap / reference span length (0 if that length is <= 0),
    * F1        = harmonic mean of precision and recall (0 if both are 0).

    Returns ``(precision, recall, f1)``, each averaged over all examples.
    Returns ``(0.0, 0.0, 0.0)`` when no examples are given.
    """
    n = len(pred_starts)
    if n == 0:
        # Nothing to average; avoid dividing by zero below.
        return (0.0, 0.0, 0.0)

    total_pre = 0
    total_recall = 0
    total_f1 = 0
    for i in range(n):
        pred_start = pred_starts[i]
        pred_end = pred_ends[i]
        true_start = true_starts[i]
        true_end = true_ends[i]

        # Size of the intersection of the two inclusive ranges, floored at 0.
        # Precision and recall share this numerator.
        span_overlap = max(min(true_end, pred_end) - max(true_start, pred_start) + 1, 0)

        pred_span_length = pred_end - pred_start + 1
        cur_pre = span_overlap / pred_span_length if pred_span_length > 0 else 0
        total_pre += cur_pre

        true_span_length = true_end - true_start + 1
        # Guard against a degenerate reference span so one bad label cannot
        # abort the whole evaluation with a ZeroDivisionError.
        cur_recall = span_overlap / true_span_length if true_span_length > 0 else 0
        total_recall += cur_recall

        if (cur_recall + cur_pre) <= 0:
            cur_f1 = 0
        else:
            cur_f1 = (2 * cur_recall * cur_pre) / (cur_recall + cur_pre)
        total_f1 += cur_f1

    return (total_pre / n, total_recall / n, total_f1 / n)

In [None]:
def _nearest_token(encoding, batch_index, char_index, step, text_length):
    """Find the token covering the closest tokenized character.

    Starting at ``char_index`` and moving ``step`` characters at a time,
    returns the first non-None result of
    ``encoding.char_to_token(batch_index, char_index)``. Returns None once the
    search leaves ``[0, text_length)`` without finding one.
    """
    while 0 <= char_index < text_length:
        token_index = encoding.char_to_token(batch_index, char_index)
        if token_index is not None:
            return token_index
        char_index += step
    return None


def getDataLoader(filepath):
    """Build a DataLoader of tokenized (context, question) pairs with span labels.

    Loads the CSV at ``filepath`` through ``getDataset``, pairs every context
    with the same fixed question, tokenizes the pairs with the module-level
    ``tokenizer`` and converts each example's character-level span boundaries
    into token positions.

    Returns a ``DataLoader`` (batch size ``batch_size``, default ordering)
    over a ``TensorDataset`` of
    ``(input_ids, attention_mask, start_positions, end_positions)``.

    Raises ``ValueError`` if a span boundary cannot be mapped to any token.
    """
    data = getDataset(filepath)
    # The same question is asked of every context.
    questions = ["What is being claimed about covid-19 vaccine in this tweet?"] * len(data)
    contexts = data['context'].tolist()
    # NOTE(review): the context is passed as the first sequence and the
    # question as the second; confirm this matches the order used at
    # inference time.
    inputs = tokenizer(contexts, questions, padding=True, truncation=True, return_tensors="pt", return_token_type_ids=True).to(device)

    start_positions = []
    end_positions = []
    start_index = data['start_index'].tolist()
    end_index = data['end_index'].tolist()
    for i in range(len(start_index)):
        context_length = len(contexts[i])
        # A character may map to no token, so walk forward from the span start
        # and backward from the span end to the nearest character that does.
        # The walk is bounded by the context so an unmappable span raises
        # instead of looping forever.
        start_token = _nearest_token(inputs, i, start_index[i], 1, context_length)
        end_token = _nearest_token(inputs, i, end_index[i], -1, context_length)
        if start_token is None or end_token is None:
            raise ValueError(
                f"Could not map character span ({start_index[i]}, {end_index[i]}) "
                f"of example {i} in {filepath} to token positions"
            )
        start_positions.append(start_token)
        end_positions.append(end_token)

    start_positions = torch.tensor(start_positions)
    end_positions = torch.tensor(end_positions)

    dataset = TensorDataset(inputs.input_ids, inputs.attention_mask, start_positions, end_positions)
    return DataLoader(dataset, batch_size=batch_size)


In [None]:
# Load the tokenizer and the pretrained question-answering model named by
# `model_name`, and move the model to `device`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name).to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/79.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/496M [00:00<?, ?B/s]

In [None]:
train_dataloader = getDataLoader(TRAIN_PATH)
val_data_loader = getDataLoader(VAL_PATH)
# NOTE(review): getDataLoader builds its DataLoader without shuffling, so the
# training batches are visited in the same order every epoch - confirm that
# this is intended.

best_f1_score = 0

# Fine-tuning loop
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()

    total_loss = 0
    total_precision = 0
    total_f1 = 0
    total_recall = 0
    num_batch = len(train_dataloader)
    for batch in tqdm(train_dataloader, desc="Training"):
        input_ids, attention_mask, start_positions, end_positions = [t.to(device) for t in batch]

        # Passing the gold positions makes the model return the span loss.
        outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
        predicted_start = outputs.start_logits.argmax(dim=1)
        predicted_end = outputs.end_logits.argmax(dim=1)
        # Convert to plain Python ints once per batch so the metric loop does
        # not index the tensors element by element.
        prec, rec, f1 = custom_metrics(
            predicted_start.tolist(),
            predicted_end.tolist(),
            start_positions.tolist(),
            end_positions.tolist(),
        )
        total_precision += prec
        total_recall += rec
        total_f1 += f1
        loss = outputs.loss
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()

    # Loss is the sum over batches; the metrics are means of per-batch means.
    print(f"Epoch {epoch+1}: Loss: {total_loss}, Precision: {total_precision/num_batch}, Recall: {total_recall/num_batch}, F1 Score: {total_f1/num_batch}")

    # ---- Validation pass ----
    model.eval()

    true_span_start = []
    true_span_end = []
    pred_span_start = []
    pred_span_end = []

    with torch.no_grad():
        for batch in tqdm(val_data_loader, desc="Validation"):
            input_ids, attention_mask, start_positions, end_positions = [t.to(device) for t in batch]

            outputs = model(input_ids, attention_mask=attention_mask)
            predicted_start = outputs.start_logits.argmax(dim=1)
            predicted_end = outputs.end_logits.argmax(dim=1)

            true_span_start.extend(start_positions.tolist())
            true_span_end.extend(end_positions.tolist())
            pred_span_start.extend(predicted_start.tolist())
            pred_span_end.extend(predicted_end.tolist())

    # Score the predictions accumulated over the WHOLE validation set. Scoring
    # the loop variables instead would only cover the final batch, which
    # misreports the metrics and misguides checkpoint selection.
    precision_val, recall_val, f1_val = custom_metrics(pred_span_start, pred_span_end, true_span_start, true_span_end)

    print(f"Val Precision: {precision_val}, Val Recall: {recall_val}, Val F1 Score: {f1_val}")

    # Keep the checkpoint with the best validation F1 seen so far (ties
    # overwrite the earlier checkpoint).
    if f1_val >= best_f1_score:
        model.save_pretrained("/content/drive/MyDrive/MTP/fine_tuned_combined_model")
        best_f1_score = f1_val


Training: 100%|██████████| 195/195 [03:50<00:00,  1.18s/it]


Epoch 1: Loss: 492.78871488571167, Precision: 0.46893808245658875, Recall: 0.5117174983024597, F1 Score: 0.44410011172294617


Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Val Precision: 0.5333333611488342, Val Recall: 0.952000081539154, Val F1 Score: 0.6611687541007996


Training: 100%|██████████| 195/195 [03:52<00:00,  1.19s/it]


Epoch 2: Loss: 363.16892647743225, Precision: 0.5581360459327698, Recall: 0.6677841544151306, F1 Score: 0.5647165775299072


Validation: 100%|██████████| 33/33 [00:09<00:00,  3.67it/s]


Val Precision: 0.5817949175834656, Val Recall: 0.952000081539154, Val F1 Score: 0.7006672024726868


Training: 100%|██████████| 195/195 [03:52<00:00,  1.19s/it]


Epoch 3: Loss: 334.0686229467392, Precision: 0.583550214767456, Recall: 0.7022126317024231, F1 Score: 0.5940567851066589


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.69it/s]


Val Precision: 0.5817949175834656, Val Recall: 0.952000081539154, Val F1 Score: 0.7006672024726868


Training: 100%|██████████| 195/195 [03:52<00:00,  1.19s/it]


Epoch 4: Loss: 316.88877165317535, Precision: 0.5982397198677063, Recall: 0.7175949811935425, F1 Score: 0.6089369058609009


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.67it/s]


Val Precision: 0.5817949175834656, Val Recall: 0.952000081539154, Val F1 Score: 0.7006672024726868


Training: 100%|██████████| 195/195 [03:52<00:00,  1.19s/it]


Epoch 5: Loss: 305.93818616867065, Precision: 0.6074138879776001, Recall: 0.7311949133872986, F1 Score: 0.620825469493866


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.69it/s]


Val Precision: 0.5817949175834656, Val Recall: 0.952000081539154, Val F1 Score: 0.7006672024726868


Training: 100%|██████████| 195/195 [03:51<00:00,  1.19s/it]


Epoch 6: Loss: 297.9094614982605, Precision: 0.6100097298622131, Recall: 0.7349936962127686, F1 Score: 0.6242656707763672


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.69it/s]


Val Precision: 0.5817949175834656, Val Recall: 0.952000081539154, Val F1 Score: 0.7006672024726868


Training: 100%|██████████| 195/195 [03:52<00:00,  1.19s/it]


Epoch 7: Loss: 291.67316538095474, Precision: 0.6170781850814819, Recall: 0.7390387058258057, F1 Score: 0.6307056546211243


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.67it/s]


Val Precision: 0.5817949175834656, Val Recall: 0.952000081539154, Val F1 Score: 0.7006672024726868


Training: 100%|██████████| 195/195 [03:51<00:00,  1.19s/it]


Epoch 8: Loss: 283.8448808193207, Precision: 0.622242271900177, Recall: 0.7434734106063843, F1 Score: 0.635040283203125


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.69it/s]


Val Precision: 0.5817949175834656, Val Recall: 0.952000081539154, Val F1 Score: 0.7006672024726868


Training: 100%|██████████| 195/195 [03:51<00:00,  1.19s/it]


Epoch 9: Loss: 278.89671808481216, Precision: 0.6267263889312744, Recall: 0.7503165602684021, F1 Score: 0.6401208639144897


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.67it/s]


Val Precision: 0.5817949175834656, Val Recall: 0.952000081539154, Val F1 Score: 0.7006672024726868


Training: 100%|██████████| 195/195 [03:52<00:00,  1.19s/it]


Epoch 10: Loss: 273.03790628910065, Precision: 0.6307471990585327, Recall: 0.7565194368362427, F1 Score: 0.645780622959137


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.68it/s]


Val Precision: 0.5817949175834656, Val Recall: 0.952000081539154, Val F1 Score: 0.7006672024726868


In [None]:
# Save the fine-tuned model if needed
# model.save_pretrained("fine_tuned_model")

In [None]:
# Evaluation on the test set
# NOTE(review): this scores whatever `model` currently holds in memory, which
# is not necessarily the checkpoint written by save_pretrained during training
# - confirm which weights should be reported.
model.eval()
true_span_start = []
true_span_end = []
pred_span_start = []
pred_span_end = []

test_data_loader = getDataLoader(TEST_PATH)
with torch.no_grad():
    for batch in tqdm(test_data_loader, desc="Testing"):
        input_ids, attention_mask, start_positions, end_positions = [t.to(device) for t in batch]

        outputs = model(input_ids, attention_mask=attention_mask)
        predicted_start = outputs.start_logits.argmax(dim=1)
        predicted_end = outputs.end_logits.argmax(dim=1)

        # Collect plain Python ints for every example in the test set.
        true_span_start.extend(start_positions.tolist())
        true_span_end.extend(end_positions.tolist())
        pred_span_start.extend(predicted_start.tolist())
        pred_span_end.extend(predicted_end.tolist())

# Score the predictions accumulated over the WHOLE test set. Scoring the loop
# variables instead would only cover the final batch.
precision_test, recall_test, f1_test = custom_metrics(pred_span_start, pred_span_end, true_span_start, true_span_end)

print(f"Test Precision: {precision_test}, Test Recall: {recall_test}, Test F1 Score: {f1_test}")

Testing: 100%|██████████| 33/33 [00:08<00:00,  3.75it/s]

Test Precision: 0.0, Test Recall: 0.0, Test F1 Score: 0.0





In [None]:
# Qualitative check: run the fine-tuned model on a hand-written example
# through the Hugging Face question-answering pipeline (on CPU).
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the fine-tuned model
# NOTE(review): this directory differs from the one the training cell saves
# to ("/content/drive/MyDrive/MTP/fine_tuned_combined_model") - confirm that
# loading this other checkpoint is intended.
model_checkpoint = "/content/drive/MyDrive/MTP/fine_tuned_model"
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint).to("cpu")

# Create a Q&A pipeline
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)

# Example evaluation: the same question used to build the training data,
# asked of a few sample contexts. Only test_context_1 is sent to the model.
test_question = "What is being claimed about covid-19 vaccine in this tweet?"
test_context = '''Hey guy, hope you all are living well. I just got to know that
Vaccine is made using human blood.
vaccine correlated to heart attack.
I think thats gross. By the way did you know about this?
'''

test_context_1 = '''Vaccine Myth Buster: Contraindication: Allergic reaction on previous dose of vaccine +ve after
 vaccination? That's because vaccine will fully protect you only after 2 wks of 2nd dose, Wear mask.
 Covid +ve = then take vaccine after 6 wks of recovery. Both vaccines are equally good.'''


test_context_2 = '''I watched a tutorial from a RN in Atlanta who was black. They said that they mRNA
protein that \u00e2 \u20ac \u2122 s in Pfizer is bad for the melanin in people who are
black. so, could imagine being black and getting the covid vaccine And 1 of the side effects 6 months is you get
vitiligo \u00f0\u0178 \u02dc \u201d'''

QA_input = {
    'question': test_question,
    'context': test_context_1
}

# Ask for the two highest-scoring answer spans. `top_k` is the supported
# keyword; the older `topk` spelling is deprecated.
answer = qa_pipeline(QA_input, top_k=2)

print("\nQuestion:", test_question)
print("\nAnswer:", answer)




Question: What is being claimed about covid-19 vaccine in this tweet?

Answer: [{'score': 0.17221011221408844, 'start': 39, 'end': 108, 'answer': 'Allergic reaction on previous dose of vaccine +ve after\n vaccination?'}, {'score': 0.05471419915556908, 'start': 124, 'end': 184, 'answer': 'vaccine will fully protect you only after 2 wks of 2nd dose,'}]
