# Train GPT-2 on scientific literature from a specific domain

In this notebook, I will show how to teach GPT-2 to write scientific papers on cancer immunotherapy. I will train the GPT-2 model on PubMed Central full text (plain text).

## Preparation

Import the necessary libraries and clear the GPU memory cache.

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

torch.cuda.empty_cache()

# Optional 4-bit quantization for larger models (disabled here):
# quantization_config = BitsAndBytesConfig(load_in_4bit=True)

Next, define the device and the model and tokenizer IDs.

In [2]:
device = "cuda"
model_id = "gpt2"
tokenizer_id = "gpt2"
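
Hard-coding `"cuda"` fails on a machine without a GPU. A minimal sketch of a fallback, assuming PyTorch may or may not be installed; `pick_device` is a hypothetical helper name, not part of the original notebook:

```python
def pick_device(cuda_available: bool) -> str:
    """Return "cuda" when a GPU is usable, otherwise fall back to "cpu"."""
    return "cuda" if cuda_available else "cpu"


try:
    import torch

    device = pick_device(torch.cuda.is_available())
except ImportError:
    # PyTorch is not installed, so nothing can run on a GPU.
    device = pick_device(False)

print(device)
```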

Load the model into memory (a large LLM may cause an out-of-memory error).

In [3]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    # quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
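
Whether a model is likely to fit in memory can be estimated roughly as parameter count times bytes per parameter. The sketch below assumes about 124 million parameters for the base `gpt2` checkpoint and ignores activations, the key/value cache, and framework overhead; `weight_memory_gib` is a hypothetical helper.

```python
def weight_memory_gib(n_params: float, bytes_per_param: float) -> float:
    """Approximate memory needed to hold the weights alone, in GiB."""
    return n_params * bytes_per_param / 2**30


GPT2_PARAMS = 124e6  # assumed approximate size of the base "gpt2" checkpoint

# Compare common weight precisions (4-bit corresponds to 0.5 bytes per parameter).
for label, nbytes in [("float32", 4), ("float16", 2), ("4-bit", 0.5)]:
    print(f"{label}: {weight_memory_gib(GPT2_PARAMS, nbytes):.2f} GiB")
```

The same function applied to a multi-billion-parameter model shows why quantization becomes relevant there and not for `gpt2`.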


## Prompt

Now let's prepare the prompt...

In [4]:
messages = [
    {"role": "user",
     "content": """
Title: Single cell RNA sequencing of primary breast cancer.
Summary: We performed single cell RNA sequencing (RNA-seq) for 549 primary breast cancer cells and lymph node metastases from 11 patients with distinct molecular subtypes (BC01-BC02, estrogen receptor positive (ER+); BC03, double positive (ER+ and HER2+); BC03LN, lymph node metastasis of BC03; BC04-BC06, human epidermal growth factor receptor 2 positive (HER2+); BC07-BC11, triple-negative breast cancer (TNBC); BC07LN, lymph node metastasis of BC07) and matched bulk tumors. We separated these single cells into epithelial tumor and tumor-infiltrating immune cells using inferred CNVs from RNA-seq. The refined single cell profiles for the tumor and immune cells provide key expression signatures of breast cancer and the surrounding microenvironment.
Overall design: All single-cell mRNA expression profiles were acquired from eleven patients (BC01-BC11) including two lymph node metastases (BC03LN, BC07LN) (549 samples). We applied four filtering criteria so as to remove samples with low sequencing quality and finally obtained 515 single cell sequencing data. Matched bulk tumor tissues and/or pooled cells were also sequenced and analyzed by the single cell RNA-seq pipeline (14 samples). Bulk tumor transcriptomes showed significant correlations with the average of single cell transcriptomes.
Question: What is the data type?
"""}
]

In [5]:
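# Render the chat messages as one prompt string; this needs a tokenizer that defines a chat template.
prompt = tokenizer.apply_chat_template(messages, tokenize=False)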
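
`apply_chat_template` relies on the tokenizer defining a chat template, which a base checkpoint such as `gpt2` may not do. A minimal sketch of a fallback that joins the message contents into plain text when no template is available; `build_prompt` and the stand-in `PlainTokenizer` class are hypothetical, used only so the example runs without downloading a model:

```python
def build_prompt(messages, tokenizer) -> str:
    """Use the tokenizer's chat template if it has one, else join the contents."""
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    # No template: fall back to the raw text of each message.
    return "\n".join(m["content"].strip() for m in messages)


class PlainTokenizer:
    """Stand-in for a tokenizer without a chat template (illustration only)."""
    chat_template = None


demo = [{"role": "user", "content": "Question: What is the data type?\n"}]
print(build_prompt(demo, PlainTokenizer()))
```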

In [6]:
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)


In [7]:
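generation_args = {
    "max_new_tokens": 512,       # upper bound on the number of generated tokens
    "return_full_text": False,   # return only the completion, not the prompt
    "temperature": 0.1,          # low temperature: close to greedy decoding
    "do_sample": True,
}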

In [9]:
prompt

'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nTitle: Single cell RNA sequencing of primary breast cancer.\nSummary: We performed single cell RNA sequencing (RNA-seq) for 549 primary breast cancer cells and lymph node metastases from 11 patients with distinct molecular subtypes (BC01-BC02, estrogen receptor positive (ER+); BC03, double positive (ER+ and HER2+); BC03LN, lymph node metastasis of BC03; BC04-BC06, human epidermal growth factor receptor 2 positive (HER2+); BC07-BC11, triple-negative breast cancer (TNBC); BC07LN, lymph node metastasis of BC07) and matched bulk tumors. We separated these single cells into epithelial tumor and tumor-infiltrating immune cells using inferred CNVs from RNA-seq. The refined single cell profiles for the tumor and immune cells provide key expression signatures of breast cancer and the surrounding microenvironment.\nOverall design: All single-cell mRNA expression profiles were acquired from eleven patients (BC01-BC11) including two lym

In [10]:
output = pipe(prompt, **generation_args)


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
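
A device-side assert during generation is often an out-of-range index, for example a token ID at or above the model's embedding size, or a sequence longer than its position table. The traceback above does not say which applies here, so the following is only a sketch of a pre-flight check. It assumes the standard GPT-2 limits (vocabulary size 50257, 1024 positions); `check_generation_inputs` is a hypothetical helper, and in the notebook the IDs would come from `tokenizer(prompt)["input_ids"]`.

```python
def check_generation_inputs(input_ids, max_new_tokens,
                            vocab_size=50257, n_positions=1024):
    """Return a list of problems that could trigger an index error on the GPU."""
    problems = []
    bad = [t for t in input_ids if t < 0 or t >= vocab_size]
    if bad:
        problems.append(f"{len(bad)} token id(s) outside [0, {vocab_size})")
    total = len(input_ids) + max_new_tokens
    if total > n_positions:
        problems.append(f"{total} tokens requested, limit is {n_positions}")
    return problems


# Toy inputs: an ID beyond the assumed vocabulary, an over-long prompt, a valid case.
print(check_generation_inputs([15496, 128000], max_new_tokens=512))
print(check_generation_inputs([0] * 600, max_new_tokens=512))
print(check_generation_inputs([0] * 100, max_new_tokens=512))
```

Rerunning the failing cell on the CPU, or with `CUDA_LAUNCH_BLOCKING=1` as the message suggests, usually produces a more specific error.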


In [10]:
print(output[0]['generated_text'])



The data type is the number of cells in a tumor cell and the number of cells in the tumor cell. The number of cells in a tumor cell is the number of cells in the tumor cell. The number of cells in the tumor cell is the number of cells in the tumor cell.
The number of cells in the tumor cell is the number of cells in the tumor cell. The number of cells in the tumor cell is the number of cells in the tumor cell.
The number of cells in the tumor cell is the number of cells in the tumor cell.
The number of cells in the tumor cell is the number of cells in the tumor cell.
The number of cells in the tumor cell is the number of cells in the tumor cell.
The number of cells in the tumor cell is the number of cells in the tumor cell.
The number of cells in the tumor cell is the number of cells in the tumor cell.
The number of cells in the tumor cell is the number of cells in the tumor cell.
The number of cells in the tumor cell is the number of cells in the tumor cell.
The number of cells in th
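
The output above degenerates into a loop, which is common for small models decoded at a very low temperature. One possible mitigation, sketched here and not tested on this notebook, is to add the `repetition_penalty` and `no_repeat_ngram_size` generation parameters from `transformers`. The specific values are assumptions to tune, and `generation_args_v2` is a hypothetical name.

```python
# Same settings as before, plus two anti-repetition parameters.
generation_args_v2 = {
    "max_new_tokens": 512,
    "return_full_text": False,
    "do_sample": True,
    "temperature": 0.7,          # assumed value; higher than the original 0.1
    "repetition_penalty": 1.2,   # assumed value; penalizes already-generated tokens
    "no_repeat_ngram_size": 3,   # assumed value; forbids repeating any 3-gram
}

# In the notebook this would be passed as: pipe(prompt, **generation_args_v2)
print(sorted(generation_args_v2))
```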