In [3]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
import json

In [4]:
import json
import numpy as np
import torch

def read_question_answer_file(file_path):
    """Reads a JSONL file with question-answer data and returns a list of dictionaries."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            data.append(json.loads(line))  # Parse each line as JSON
    return data

# Load your dataset
dataset_path = r'med_data/phrases_no_exclude_train.jsonl' 
# dataset_path = r'C:\Users\ranad\OneDrive - University of Glasgow\Attachments\Msc Final Year project\Data\MedQA-USMLE-4-options\phrases_no_exclude_train.jsonl'  # Replace with the path to your JSON file
questions_data = read_question_answer_file(dataset_path)

In [5]:
print(questions_data[0]['question'])
print(questions_data[0]['answer_idx'])
print(questions_data[0]['options'])

A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?
D
{'A': 'Ampicillin', 'B': 'Ceftriaxone', 'C': 'Doxycycline', 'D': 'Nitrofurantoin'}


In [4]:
# !pip install wikipedia
# !pip install Wikipedia-API

In [6]:
import re

def clean_text(text):
  """Removes special characters and extra whitespace from text.

  Args:
    text: The input text to be cleaned.

  Returns:
    The cleaned text.
  """

  # Remove special characters using regular expression
  text = re.sub(r"[^\w\s]", "", text)

  # Remove extra whitespace
  text = " ".join(text.split())

  return text

In [7]:
def save_string_to_file(text, filename):
  """Saves a string to a text file.

  Args:
    text: The string to be saved.
    filename: The name of the file to create.
  """

  with open(filename, 'w', encoding='utf-8') as f:
    f.write(text)

# # Example usage:
# my_string = "This is the text I want to save."
# save_string_to_file(my_string, "output.txt")

In [8]:
def sanitize_filename(filename):
  """Sanitizes a filename by replacing special characters with underscores.

  Args:
    filename: The original filename.

  Returns:
    The sanitized filename.
  """

  # Replace non-alphanumeric characters with underscores
  filename = re.sub(r'[^\w]', '_', filename)

  # Remove leading and trailing underscores
  filename = filename.strip('_')

  return filename

In [9]:
def read_txt_file(file_path):
  """Reads data from a JSON file.

  Args:
    file_path: The path to the text file.

  Returns:
    The parsed txt data as a Python object.
  """

  # Open the file and read its contents
  with open(file_path, 'r') as file:
    data = file.read()
  return data

# # Example usage:
# file_path = r'C:\Users\ranad\Documents\WikiData\Perform_a_diagnostic_loop_electrosurgical_excision.txt'
# data = read_txt_file(file_path)
# print(data)

# Generate bert embeddings

In [10]:
import torch
import numpy as np
from transformers import BertTokenizer, BertModel

# Initialize the tokenizer and model (bert-base-uncased)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Ensure the model runs on GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bert_model.to(device)


# Function to compute embeddings
def get_embeddings(text,tokenizer,model):
    
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():  # No need to compute gradients
        outputs = model(**inputs)

    last_hidden_state = outputs.last_hidden_state
    embeddings = last_hidden_state.mean(dim=1)

    # Compute the mean of the token embeddings to get a fixed-size representation
    embeddings = embeddings.squeeze().cpu().numpy()  # (batch size x hidden size)

    return embeddings



# Generate medical domain based bert embeddings

<h1>Dictionary format data</h1>

In [11]:
# pubmed_dir = r'C:\Users\ranad\Documents\WikiData'
# pubmed_dir =r'WikiData'
pubmed_dir =r'Wiki_for_embeddings'
# Dictionary to store the original data (not embeddings) by file key


def extract_dataDict(directory):
    data_dict = {}
    # Iterate through each JSON file in the directory
    for filename in os.listdir(directory):
        filepath = os.path.join(directory, filename)
        if os.path.isfile(filepath):
            with open(filepath, 'r') as file:
                data = json.load(file)
                file_key = sanitize_filename(filename[:-4])  
                data_dict[file_key.lower()] = data
    return data_dict

data_dict = extract_dataDict(pubmed_dir)

In [10]:
print(len(data_dict))
data_dict['ampicillin']

3627


[['Ampicillin is an antibiotic belonging to the aminopenicillin class of the penicillin family. The drug is used to prevent and treat a number of bacterial infections, such as respiratory tract infections, urinary tract infections, meningitis, salmonellosis, and endocarditis. It may also be used to prevent group B streptococcal infection in newborns. It is used by mouth, by injection into a muscle, or intravenously.',
  'Common side effects include rash, nausea, and diarrhea. It should not be used in people who are allergic to penicillin. Serious side effects may include Clostridium difficile colitis or anaphylaxis. While usable in those with kidney problems, the dose may need to be decreased. Its use during pregnancy and breastfeeding appears to be generally safe.',
  "Ampicillin was discovered in 1958 and came into commercial use in 1961. It is on the World Health Organization's List of Essential Medicines. The World Health Organization classifies ampicillin as critically important f

In [12]:
from tqdm import tqdm
medQA_Filtered_data = []

for key, data in tqdm(data_dict.items()):
        for texts in data:
            for text in texts:
                if len(text.split())>15:
                    medQA_Filtered_data.append(text)

100%|██████████| 3627/3627 [00:01<00:00, 2086.25it/s]


In [13]:
len(medQA_Filtered_data)

598850

In [14]:
def getEmbeddingsListwise(data_list):
    embeddings = []
    for text in tqdm(data_list):
        embedding = get_embeddings(text, bert_tokenizer, bert_model)
        embeddings.append(embedding)
    return embeddings

fiass_embeddings = getEmbeddingsListwise(medQA_Filtered_data)

 65%|██████▍   | 388691/598850 [38:04<21:55, 159.72it/s]  IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|██████████| 598850/598850 [58:35<00:00, 170.37it/s]


# FIASS Installation and Storage

In [15]:
!pip install faiss-cpu 

Collecting faiss-cpu
  Downloading faiss_cpu-1.8.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.7 kB)
Downloading faiss_cpu-1.8.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (27.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m27.0/27.0 MB[0m [31m51.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.8.0.post1
[0m

In [16]:
import faiss

embeddings = np.vstack(fiass_embeddings) 

# Get the dimensionality of the embeddings
embedding_dim = embeddings.shape[1]

# Create a FAISS index
# IndexFlatL2 is a simple, exact nearest neighbor search index with L2 distance
index = faiss.IndexFlatL2(embedding_dim)

# Add embeddings to the index
index.add(embeddings)

# Check the number of vectors in the index
print(f"Number of vectors in the index: {index.ntotal}")

# Save the index to a file for later use
faiss.write_index(index, 'FAISS/QA_Wiki_index.index')

Number of vectors in the index: 598850


In [17]:
import faiss
# Load the FAISS index
faiss_index = faiss.read_index('FAISS/QA_Wiki_index.index')

# Function to search the FAISS index
def search_faiss_index(query, index, tokenizer, model, top_k=5):
    # Compute the query embedding
    query_embedding = get_embeddings(query,tokenizer, model)

    # Reshape the query embedding to match FAISS expected input format
    query_embedding = np.expand_dims(query_embedding, axis=0).astype('float32')

    # Search the FAISS index
    distances, indices = index.search(query_embedding, top_k)

    return distances, indices

In [18]:
# Example usage
query = "Nitrofurantoin"
top_k = 10

distances, indices = search_faiss_index(query, faiss_index, bert_tokenizer, bert_model, top_k=10)

# Print results
print(f"Top-{top_k} results:")
for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
    print(f"Result {i + 1}:")
    print(f"Index: {idx}, Distance: {distance}")
    # If you have the original text stored, you can retrieve it like this:
    # print(f"Text: {embeddings_dict['some_key'][idx]}")

Top-10 results:
Result 1:
Index: 254068, Distance: 33.83258056640625
Result 2:
Index: 472844, Distance: 33.83258056640625
Result 3:
Index: 232256, Distance: 33.9600944519043
Result 4:
Index: 268898, Distance: 33.9600944519043
Result 5:
Index: 404617, Distance: 33.9600944519043
Result 6:
Index: 456275, Distance: 33.9600944519043
Result 7:
Index: 562584, Distance: 33.9600944519043
Result 8:
Index: 61860, Distance: 36.544273376464844
Result 9:
Index: 290549, Distance: 37.16175842285156
Result 10:
Index: 558961, Distance: 37.173057556152344


# Library installations

In [1]:
!pip install langchain
!pip install transformers
!pip install accelerate
!pip install bitsandbytes
!pip install --upgrade pip
!pip install --upgrade langchain
!pip install langchain_community
!pip list | grep langchain
!pip list | grep langchain_community

!pip install -U langchain-huggingface

Collecting langchain
  Downloading langchain-0.2.14-py3-none-any.whl.metadata (7.1 kB)
Collecting SQLAlchemy<3,>=1.4 (from langchain)
  Downloading SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.6 kB)
Collecting aiohttp<4.0.0,>=3.8.3 (from langchain)
  Downloading aiohttp-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.5 kB)
Collecting async-timeout<5.0.0,>=4.0.0 (from langchain)
  Downloading async_timeout-4.0.3-py3-none-any.whl.metadata (4.2 kB)
Collecting langchain-core<0.3.0,>=0.2.32 (from langchain)
  Downloading langchain_core-0.2.33-py3-none-any.whl.metadata (6.2 kB)
Collecting langchain-text-splitters<0.3.0,>=0.2.0 (from langchain)
  Downloading langchain_text_splitters-0.2.2-py3-none-any.whl.metadata (2.1 kB)
Collecting langsmith<0.2.0,>=0.1.17 (from langchain)
  Downloading langsmith-0.1.99-py3-none-any.whl.metadata (13 kB)
Collecting pydantic<3,>=1 (from langchain)
  Downloading pydantic-2.8.2-py3-none-a

# Building the pipeline with the langchain

In [2]:
import os
import langchain

### prompts
from langchain import PromptTemplate, LLMChain

### models
# from langchain.llms import HuggingFacePipeline
# from langchain.embeddings import HuggingFaceInstructEmbeddings


import torch
import transformers
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline
)


#model = "/kaggle/input/llama-2/pytorch/7b-chat-hf/1"
#model = "/kaggle/input/llama-3/transformers/8b-chat-hf/1"
model = "meta-llama/Meta-Llama-3-8B"
#model = "Undi95/Meta-Llama-3-8B-hf"

tokenizer = AutoTokenizer.from_pretrained(model)

        
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_compute_dtype = torch.float16,
    bnb_4bit_use_double_quant = True,
)

model_llama = AutoModelForCausalLM.from_pretrained(
    model,
    quantization_config = bnb_config,
    device_map = 'auto',
    do_sample=False,
    token='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'
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

# Designing a context summary prompt

In [21]:
# question = questions_data[0]['question']
# options = "\nA. Ampicillin\nB. Ceftriaxone\nC. Doxycycline\nD. Nitrofurantoin\n"

# option_dict = questions_data[0]['options']

# # Search FAISS index for each option and retrieve the corresponding text
# retrieved_contexts = []
# for option_key, option_text in option_dict.items():
#    text_data = data_dict.get(sanitize_filename(option_text).lower())
#    if (text_data and len(text_data)>0) :
#      retrieved_context = truncate_to_words(text_data,1000)
#      retrieved_contexts.append(retrieved_context)

# context = '\n'.join([s.strip() + '.' for s in retrieved_contexts])
# context

### FOR EMBEDDINGS ##########
question = questions_data[0]['question']
options = "\nA. Ampicillin\nB. Ceftriaxone\nC. Doxycycline\nD. Nitrofurantoin\n"

option_dict = questions_data[0]['options']

# Search FAISS index for each option and retrieve the corresponding text
retrieved_contexts = []
for option_key, option_text in option_dict.items():
    distances, indices = search_faiss_index(option_text, faiss_index, bert_tokenizer, bert_model, top_k=10)
    retrieved_context = " ".join([medQA_Filtered_data[idx] for i, (distance, idx) in enumerate(zip(distances[0], indices[0]))])
    retrieved_context = truncate_to_words(retrieved_context,1000)
    retrieved_contexts.append(retrieved_context)

context = '\n'.join([s.strip() + '.' for s in retrieved_contexts])
context

"Amantadine and rimantadine have been introduced to combat influenza. These agents act on penetration and uncoating. Amantadine and rimantadine have been introduced to combat influenza. These agents act on penetration and uncoating. Amantadine and rimantadine have been introduced to combat influenza. These agents act on penetration and uncoating. Amantadine and rimantadine have been introduced to combat influenza. These agents act on penetration and uncoating. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against Ebola virus. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against Ebola virus. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against Ebola virus. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against Ebola virus. There are also interferon-inducing drugs, notably tilorone that is shown to be effective agains

In [22]:
# Use the summarization prompt to generate a summary
summarization_prompt = f"""[INST] Summarize the following text concisely:[/INST]
{context}
"""

with torch.no_grad():
    summary_output = model_llama.generate(
        **tokenizer(summarization_prompt, return_tensors="pt", truncation=True).to(model_llama.device),
        max_new_tokens=256,
    )

# Decode and print the generated summary
summary_text = tokenizer.decode(summary_output[0], skip_special_tokens=True)
summary_text = summary_text.replace("[INST] Summarize the following text concisely:[/INST]", "").strip()
print(summary_text)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Amantadine and rimantadine have been introduced to combat influenza. These agents act on penetration and uncoating. Amantadine and rimantadine have been introduced to combat influenza. These agents act on penetration and uncoating. Amantadine and rimantadine have been introduced to combat influenza. These agents act on penetration and uncoating. Amantadine and rimantadine have been introduced to combat influenza. These agents act on penetration and uncoating. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against Ebola virus. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against Ebola virus. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against Ebola virus. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against Ebola virus. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against

In [23]:
# prompt generation
template_context = """
###Context###:\n {context} \n [INST]Answer the following below question using the Context.[/INST]\n
###Question###:\n {question} \n[INST]Select the correct option only. No explanation required[/INST]\n
Options: {options}

#Answer:""" # Force a single-line response

prompt_template_context = PromptTemplate(template=template_context, input_variables=["question", "options", "context"])
prompt_context = prompt_template_context.format(question=question, options=options, context=summary_text) 

In [25]:

# pipeline(prompt_context, max_new_tokens=1)

tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token 
inputs = tokenizer(prompt_context, return_tensors='pt', truncation=True, padding="max_length", max_length=8000).to(model_llama.device)
outputs = model_llama.generate(**inputs, max_new_tokens=1)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



###Context###:
 Amantadine and rimantadine have been introduced to combat influenza. These agents act on penetration and uncoating. Amantadine and rimantadine have been introduced to combat influenza. These agents act on penetration and uncoating. Amantadine and rimantadine have been introduced to combat influenza. These agents act on penetration and uncoating. Amantadine and rimantadine have been introduced to combat influenza. These agents act on penetration and uncoating. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against Ebola virus. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against Ebola virus. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against Ebola virus. There are also interferon-inducing drugs, notably tilorone that is shown to be effective against Ebola virus. There are also interferon-inducing drugs, notably tilorone that is shown to be 

In [63]:
position = response.find('#Answer:')
prediction = response[position+8 :position+10].strip()
prediction

'B'

# Data Loader with RAG using FAISS index

In [20]:
def truncate_to_words(text, max_words=4000):
  """Truncates a text string to a maximum number of words.

  Args:
    text: The input text string.
    max_words: The maximum number of words to keep.

  Returns:
    The truncated text string.
  """

  words = text.split()
  if len(words) <= max_words:
    return text
  else:
    truncated_text = ' '.join(words[:max_words])
    return truncated_text

In [27]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from tqdm import tqdm  # For progress bar

class QuestionAnswerDataset_RAG(Dataset):
    def __init__(self, questions_data, faiss_texts):
        self.questionData = questions_data
        self.faiss_texts = faiss_texts
    

    def __len__(self):
        return len(self.questionData)

    def __getitem__(self, idx):
        question_data = self.questionData[idx]
        question = question_data['question']
        options = question_data['options']
        options_str = "\n".join([f"{key}. {value}" for key, value in options.items()])
        answer = question_data['answer_idx']
        
        # Search FAISS index for each option and retrieve the corresponding text
        retrieved_contexts = []
        for option_key, option_text in options.items():
            # code for full text extarction.
            # ctx = self.faiss_texts.get(sanitize_filename(option_text).lower())
            # if (ctx and len(ctx)>0) :
            #     retrieved_context = truncate_to_words(ctx,1000)
            #     retrieved_contexts.append(retrieved_context)
            distances, indices = search_faiss_index(option_text, faiss_index, bert_tokenizer, bert_model, top_k=10)
            retrieved_context = " ".join([medQA_Filtered_data[idx] for i, (distance, idx) in enumerate(zip(distances[0], indices[0]))])
            retrieved_context = truncate_to_words(retrieved_context,1000)
            retrieved_contexts.append(retrieved_context)

        # Combine the contexts with the question and options for the final return
        combined_context = truncate_to_words("\n".join(retrieved_contexts))

        # Use the summarization prompt to generate a summary
        summarization_prompt = f"""[INST] Summarize the following text concisely:[/INST]
        {combined_context}
        """
        
        with torch.no_grad():
            summary_output = model_llama.generate(
                **tokenizer(summarization_prompt, return_tensors="pt").to(model_llama.device),
                max_new_tokens=256,
            )
        
        # Decode and print the generated summary
        summary_text = tokenizer.decode(summary_output[0], skip_special_tokens=True)
        summary_text = summary_text.replace("[INST] Summarize the following text concisely:[/INST]", "").strip()

        return question, options_str, answer, summary_text

    def get_text_embedding(self, text):
        # Placeholder for the actual embedding generation logic
        # Replace this with the method to generate embeddings from text
        return np.random.rand(768)  # Example: Replace with actual embedding



# Load your dataset
dataset_RAG = QuestionAnswerDataset_RAG(questions_data[:1000], data_dict)
dataloader_RAG = DataLoader(dataset_RAG, batch_size=4, shuffle=False)  # Adjust batch_size as needed

# Running Batches with context

In [28]:
correct_predictions = 0
total_predictions = 0
responses = []
answers = []

for batch in tqdm(dataloader_RAG):
    questions, options_strs, answer_idxs, combined_contexts = batch
    tokenizer.padding_side = 'left'
    tokenizer.pad_token = tokenizer.eos_token
    prompts = [prompt_template_context.format(question=question, options=options_str, context=combined_context) for question, options_str,combined_context in zip(questions, options_strs,combined_contexts)]
    
    inputs = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True, max_length=8000).to(model_llama.device)
    
    with torch.no_grad():
        outputs = model_llama.generate(**inputs,max_new_tokens=1)
    
    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    
    #print(decoded_outputs)
    for decoded_output, answer in zip(decoded_outputs, answer_idxs):
        position = decoded_output.find('#Answer:')
        answer_pred = decoded_output[position+8 :position+10].strip()
        #print(answer_pred)
        if answer == answer_pred.strip():
            correct_predictions += 1
        
        responses.append(answer_pred)
        answers.append(answer)
        total_predictions += 1

  0%|          | 0/250 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  0%|          | 1/250 [00:47<3:18:22, 47.80s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  1%|          | 2/250 [01:36<3:19:21, 48.23s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end

OutOfMemoryError: CUDA out of memory. Tried to allocate 7.13 GiB. GPU 0 has a total capacty of 23.70 GiB of which 6.07 GiB is free. Process 130156 has 17.62 GiB memory in use. Of the allocated memory 11.23 GiB is allocated by PyTorch, and 5.31 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [66]:
print(f"Accuracy: {correct_predictions / len(responses):.2%}")
correct_predictions

Accuracy: 36.90%


369

In [44]:
responses.count('n:')
# responses.index('n:')
# question
responses
filtered_list = [element for element in responses if element not in ['A', 'B', 'C', 'D']]
filtered_list

['C',
 'A',
 'C',
 'A',
 'D',
 'B',
 'D',
 'B',
 'D',
 'B',
 'D',
 'A',
 'D',
 'C',
 'C',
 'D',
 'D',
 'D',
 'D',
 'D',
 'A',
 'D',
 'A',
 'C',
 'D',
 'D',
 'D',
 'C',
 'D',
 'D',
 'B',
 'D',
 'A',
 'A',
 'C',
 'A',
 'A',
 'A',
 'D',
 'C',
 'B',
 'D',
 'C',
 'A',
 'D',
 'A',
 'D',
 'C',
 'A',
 'C',
 'D',
 'D',
 'D',
 'A',
 'B',
 'D',
 'A',
 'D',
 'D',
 'D',
 'D',
 'D',
 'A',
 'B',
 'D',
 'A',
 'A',
 'D',
 'D',
 'A',
 'A',
 'A',
 'B',
 'A',
 'B',
 'A',
 'A',
 'D',
 'D',
 'A',
 'B',
 'A',
 'C',
 'D',
 'A',
 'B',
 'D',
 'A',
 'D',
 'A',
 'A',
 'D',
 'B',
 'A',
 'A',
 'B',
 'A',
 'A',
 'D',
 'C',
 'D',
 'D',
 'D',
 'C',
 'D',
 'A',
 'A',
 'D',
 'A',
 'D',
 'A',
 'A',
 'A',
 'B',
 'D',
 'A',
 'A',
 'D',
 'D',
 'D',
 'D',
 'A',
 'A',
 'C',
 'B',
 'D',
 'B',
 'D',
 'B',
 'A',
 'A',
 'D',
 'D',
 'C',
 'D',
 'D',
 'A',
 'C',
 'D',
 'D',
 'D',
 'A',
 'D',
 'D',
 'A',
 'C',
 'D',
 'C',
 'B',
 'C',
 'B',
 'A',
 'A',
 'D',
 'C',
 'D',
 'A',
 'D',
 'D',
 'A',
 'A',
 'D',
 'C',
 'C',
 'B',
 'A',
 'B'

In [None]:
def free_gpu_cache():
  """Frees the GPU cache and memory."""
  if torch.cuda.is_available():
      with torch.cuda.device(0):
          torch.cuda.empty_cache()