# Dataset tokenization

Some code from https://github.com/huggingface/olm-training

In [1]:
from transformers import (
    AutoTokenizer,
)
from datasets import load_dataset

In [3]:
dataset = load_dataset("text", 
                       data_files={"train": "data/train.txt", 
                                   "dev": "data/dev.txt"},
                      cache_dir="data/cache")

In [5]:
tokenizer = AutoTokenizer.from_pretrained("g5_tokenizer")

In [6]:
def tokenize(example):
    tokenized_example = tokenizer(
       example["text"], return_special_tokens_mask=True
    )
    return tokenized_example

In [None]:
tokenized_ds = dataset.map(tokenize, remove_columns=["text"], batched=True)

In [None]:
max_len = 1110 # This number is to have an actual input size of 1000 for the model

# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We add a little padding so these tokens can be evenly split into examples with max_len # of tokens.
    if total_length >= args.max_len:
        remainder  = total_length - (total_length // max_len) * max_len
        if remainder > 0:
            concatenated_examples["input_ids"] += [tokenizer.pad_token_id]*(max_len - remainder)
            concatenated_examples["special_tokens_mask"] += [1]*(max_len - remainder)
            concatenated_examples["attention_mask"] += [0]*(max_len - remainder)
            if "token_type_ids" in concatenated_examples:
                # token_type_ids is 0 - we don't support next-sentence-prediction.
                concatenated_examples["token_type_ids"] += [0]*(max_len - remainder)
            total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + max_len] for i in range(0, total_length, max_len)]
        for k, t in concatenated_examples.items()
    }
    return result

In [None]:
# Note that because the batch size is 1000, the fraction of examples with pad tokens will only be <= 1/1000.
# The rest of the examples will have a full max_len tokens without padding.
tokenized_ds = tokenized_ds.map(group_texts, batched=True, batch_size=1000, num_proc=4)

In [None]:
print(f"the dataset contains in total {len(tokenized_ds)*max_len} tokens")

tokenized_ds.save_to_disk("g5_dataset")

## ensembl ping for translation task

In [32]:
from biomart import BiomartServer
import time
from tqdm import tqdm

import pandas as pd
mart_df = pd.read_csv("data/mart_export.csv")
df = mart_df[["Gene stable ID", "Mouse gene stable ID"]].copy()

# connect
server = BiomartServer("http://www.ensembl.org/biomart")
mart = server.datasets['hsapiens_gene_ensembl']

def fetch_sequences_in_batches(gene_ids, batch_size):
    sequences = {}
    for i in tqdm(range(0, len(gene_ids), batch_size)):
        batch_ids = gene_ids[i:i + batch_size]
        response = mart.search({
            'filters': {
                'ensembl_gene_id': batch_ids
            },
            'attributes': [
                'ensembl_gene_id', 'coding'
            ]
        })
        for line in response.iter_lines():
            parts = line.decode('utf-8').split("\t")
            gene_id, sequence = parts[0], parts[1]
            sequences[gene_id] = sequence
        time.sleep(0.5)
    return sequences

# Fetch sequences in batches
gene_ids = df['Gene stable ID'].tolist()
sequences = fetch_sequences_in_batches(gene_ids, batch_size=250)

# Add the sequences to the DataFrame
df['gene_sequence'] = df['Gene stable ID'].map(sequences)
flipped = dict((v,k) for k,v in sequences.items())

100%|██████████| 718/718 [2:29:04<00:00, 12.46s/it]  


## preprocess translation data

In [1]:
import pandas as pd

seq_df = pd.read_csv("../data/prot_sequences_dedup.csv").dropna()

In [2]:
from datasets import Dataset

dataset = Dataset.from_pandas(seq_df)
dataset = dataset.remove_columns(["Gene stable ID", "Protein stable ID", "Mouse protein or transcript stable ID", "Mouse gene stable ID"])
dataset

Dataset({
    features: ['hum_seq', 'mouse_seq'],
    num_rows: 116025
})

In [6]:
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")

def preprocess_function(examples):
    model_inputs = tokenizer(examples["hum_seq"], text_target=examples["mouse_seq"], max_length=512, padding='max_length', truncation=True)
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/116025 [00:00<?, ? examples/s]

In [8]:
cleaned_dataset = tokenized_dataset.remove_columns(["hum_seq", "mouse_seq"])

In [9]:
cleaned_dataset.save_to_disk("../g5_prot_translation_data")
cleaned_dataset

Saving the dataset (0/2 shards):   0%|          | 0/116025 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 116025
})

## test translations

In [1]:
import pandas as pd

seq_df = pd.read_csv("../data/mart_sequences.csv").dropna()

In [2]:
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

In [3]:
tokenizer = AutoTokenizer.from_pretrained("../g5_tokenizer")
model = AutoModelForSeq2SeqLM.from_pretrained("../g5_human_mouse_finetune_v2/model")



In [54]:
human_seq = "ATACCCATGGCCAACCTCCTACTCCTCATTGTACCCATTCTAATCGCAATGGCATTCCTAATGCTTACCGAACGAAAAATTCTAGGCTATATACAACTACGCAAAGGCCCCAACGTTGTAGGCCCCTACGGGCTACTACAACCCTTCGCTGACGCCATAAAACTCTTCACCAAAGAGCCCCTAAAACCCGCCACATCTACCATCACCCTCTACATCACCGCCCCGACCTTAGCTCTCACCATCGCTCTTCTACTATGAACCCCCCTCCCCATACCCAACCCCCTGGTCAACCTCAACCTAGGCCTCCTATTTATTCTAGCCACCTCTAGCCTAGCCGTTTACTCAATCCTCTGATCAGGGTGAGCATCAAACTCAAACTACGCCCTGATCGGCGCACTGCGAGCAGTAGCCCAAACAATCTCATATGAAGTCACCCTAGCCATCATTCTACTATCAACATTACTAATAAGTGGCTCCTTTAACCTCTCCACCCTTATCACAACACAAGAACACCTCTGATTACTCCTGCCATCATGACCCTTGGCCATAATATGATTTATCTCCACACTAGCAGAGACCAACCGAACCCCCTTCGACCTTGCCGAAGGGGAGTCCGAACTAGTCTCAGGCTTCAACATCGAATACGCCGCAGGCCCCTTCGCCCTATTCTTCATAGCCGAATACACAAACATTATTATAATAAACACCCTCACCACTACAATCTTCCTAGGAACAACATATGACGCACTCTCCCCTGAACTCTACACAACATATTTTGTCACCAAGACCCTACTTCTAACCTCCCTGTTCTTATGAATTCGAACAGCATACCCCCGATTCCGCTACGACCAACTCATACACCTCCTATGAAAAAACTTCCTACCACTCACCCTAGCATTACTTATATGATATGTCTCCATACCCATTACAATCTCCAGCATTCCCCCTCAAACCTA"

inputs = tokenizer(human_seq, return_tensors="pt").input_ids

outputs = model.generate(inputs, max_new_tokens=1100, penalty_alpha=0.6, top_k=4, )

pred_mouse_seq = tokenizer.decode(outputs[0], skip_special_tokens=True)

In [55]:
pred_mouse_seq

'GGCTGGAGGAGGATGAACACTATGATTACCACCAGGAGATTGCCAGGTCATCCTATGCCGACATGCTACATGACAAAGACAGAAATATAAAATACTACCAGGGTATCCGGGCAGCTGTGAGCAGGGTGAAAGACAGAGGACAGAAGGCCTTGGTTCTTGACATTGGCACTGGCACAGGCCTCTTGTCAATGATGGCAGTTACTGCAGGGGCTGACTTCTGCTATGCTATCGAGGTTTTTAAGCCTATGGCTGAGGCTGCTGTGAAGATTGTGGAGAGGAATGGCTTCAGTGATAAGATTAAAGTCATTAACAAGCACTCCACTGAGGTGACAGTCGGACCAGATGGTGACTTGCCGTGTCGTGCTAACATTCTGATCACGGAGCTGTTTGACACAGAGCTGATTGGGGAGGGAGCGCTGCCCTCTTATGAGCATGCACACAAGCATCTTGTCCAGGAAGACTGCGAGGCAGTGCCACACAGGGCAACTGTCTATGCCCAGCTGGTGGAGTCCCGAAGGATGTGGTCCTGGAACAAGCTGTTTCCCGTCCGTGTCCGGACGAGTCTAGGCGAGCAGGTCATCGTCCCCCCCTCAGAATTGGAGAGGTGTCCTGGTGCGCCTTCAGTCTGTGACATTCAGCTGAACCAGGTGTCGCCTGCTGACTTCACTGTCCTCAGTGATGTGCTGCCAATGTTCAGCGTGGACTTCAGCAAGCAAGTCAGCAGCTCGGCAGCGTGCCATAGCAGGCAGTTTGTACCTTTGGCGTCTGGCCAAGCACAGGTGGTTCTGTCCTGGTGGGACATTGAAATGGACCCTGAGGGCAAGATCAAGTGCACCATGGCACCCTTTTGGGCACAGACAGATCCGCAGGAGCTTCAGGTAAGAGGCAGGAGCTGA'

In [49]:
seq_df.iloc[0]["gene_sequence"]

'ATACCCATGGCCAACCTCCTACTCCTCATTGTACCCATTCTAATCGCAATGGCATTCCTAATGCTTACCGAACGAAAAATTCTAGGCTATATACAACTACGCAAAGGCCCCAACGTTGTAGGCCCCTACGGGCTACTACAACCCTTCGCTGACGCCATAAAACTCTTCACCAAAGAGCCCCTAAAACCCGCCACATCTACCATCACCCTCTACATCACCGCCCCGACCTTAGCTCTCACCATCGCTCTTCTACTATGAACCCCCCTCCCCATACCCAACCCCCTGGTCAACCTCAACCTAGGCCTCCTATTTATTCTAGCCACCTCTAGCCTAGCCGTTTACTCAATCCTCTGATCAGGGTGAGCATCAAACTCAAACTACGCCCTGATCGGCGCACTGCGAGCAGTAGCCCAAACAATCTCATATGAAGTCACCCTAGCCATCATTCTACTATCAACATTACTAATAAGTGGCTCCTTTAACCTCTCCACCCTTATCACAACACAAGAACACCTCTGATTACTCCTGCCATCATGACCCTTGGCCATAATATGATTTATCTCCACACTAGCAGAGACCAACCGAACCCCCTTCGACCTTGCCGAAGGGGAGTCCGAACTAGTCTCAGGCTTCAACATCGAATACGCCGCAGGCCCCTTCGCCCTATTCTTCATAGCCGAATACACAAACATTATTATAATAAACACCCTCACCACTACAATCTTCCTAGGAACAACATATGACGCACTCTCCCCTGAACTCTACACAACATATTTTGTCACCAAGACCCTACTTCTAACCTCCCTGTTCTTATGAATTCGAACAGCATACCCCCGATTCCGCTACGACCAACTCATACACCTCCTATGAAAAAACTTCCTACCACTCACCCTAGCATTACTTATATGATATGTCTCCATACCCATTACAATCTCCAGCATTCCCCCTCAAACCTA'

In [48]:
seq_df.iloc[0]["mouse_gene_sequence"]

'GTGTTCTTTATTAATATCCTAACACTCCTCGTCCCCATTCTAATCGCCATAGCCTTCCTAACATTAGTAGAACGCAAAATCTTAGGGTACATACAACTACGAAAAGGCCCTAACATTGTTGGTCCATACGGCATTTTACAACCATTTGCAGACGCCATAAAATTATTTATAAAAGAACCAATACGCCCTTTAACAACCTCTATATCCTTATTTATTATTGCACCTACCCTATCACTCACACTAGCATTAAGTCTATGAGTTCCCCTACCAATACCACACCCATTAATTAATTTAAACCTAGGGATTTTATTTATTTTAGCAACATCTAGCCTATCAGTTTACTCCATTCTATGATCAGGATGAGCCTCAAACTCCAAATACTCACTATTCGGAGCTTTACGAGCCGTAGCCCAAACAATTTCATATGAAGTAACCATAGCTATTATCCTTTTATCAGTTCTATTAATAAATGGATCCTACTCTCTACAAACACTTATTACAACCCAAGAACACATATGATTACTTCTGCCAGCCTGACCCATAGCCATAATATGATTTATCTCAACCCTAGCAGAAACAAACCGGGCCCCCTTCGACCTGACAGAAGGAGAATCAGAATTAGTATCAGGGTTTAACGTAGAATACGCAGCCGGCCCATTCGCGTTATTCTTTATAGCAGAGTACACTAACATTATTCTAATAAACGCCCTAACAACTATTATCTTCCTAGGACCCCTATACTATATCAATTTACCAGAACTCTACTCAACTAACTTCATAATAGAAGCTCTACTACTATCATCAACATTCCTATGGATCCGAGCATCTTATCCACGCTTCCGTTACGATCAACTTATACATCTTCTATGAAAAAACTTTCTACCCCTAACACTAGCATTATGTATGTGACATATTTCTTTACCAATTTTTACAGCGGGAGTACCACCATACATATAG'

In [59]:
def percent_match(a, b):
    i = 0
    print(a,b)
    for x, y in zip(a, b):
        if x == y:
            i += 1
    
    return i / len(a)

def translate(x):
    return Seq(x).translate()

percent_match(human_seq, seq_df.iloc[0]["mouse_gene_sequence"])
percent_match(translate(human_seq), translate(seq_df.iloc[0]["mouse_gene_sequence"]))
print(translate(pred_mouse_seq))

ATACCCATGGCCAACCTCCTACTCCTCATTGTACCCATTCTAATCGCAATGGCATTCCTAATGCTTACCGAACGAAAAATTCTAGGCTATATACAACTACGCAAAGGCCCCAACGTTGTAGGCCCCTACGGGCTACTACAACCCTTCGCTGACGCCATAAAACTCTTCACCAAAGAGCCCCTAAAACCCGCCACATCTACCATCACCCTCTACATCACCGCCCCGACCTTAGCTCTCACCATCGCTCTTCTACTATGAACCCCCCTCCCCATACCCAACCCCCTGGTCAACCTCAACCTAGGCCTCCTATTTATTCTAGCCACCTCTAGCCTAGCCGTTTACTCAATCCTCTGATCAGGGTGAGCATCAAACTCAAACTACGCCCTGATCGGCGCACTGCGAGCAGTAGCCCAAACAATCTCATATGAAGTCACCCTAGCCATCATTCTACTATCAACATTACTAATAAGTGGCTCCTTTAACCTCTCCACCCTTATCACAACACAAGAACACCTCTGATTACTCCTGCCATCATGACCCTTGGCCATAATATGATTTATCTCCACACTAGCAGAGACCAACCGAACCCCCTTCGACCTTGCCGAAGGGGAGTCCGAACTAGTCTCAGGCTTCAACATCGAATACGCCGCAGGCCCCTTCGCCCTATTCTTCATAGCCGAATACACAAACATTATTATAATAAACACCCTCACCACTACAATCTTCCTAGGAACAACATATGACGCACTCTCCCCTGAACTCTACACAACATATTTTGTCACCAAGACCCTACTTCTAACCTCCCTGTTCTTATGAATTCGAACAGCATACCCCCGATTCCGCTACGACCAACTCATACACCTCCTATGAAAAAACTTCCTACCACTCACCCTAGCATTACTTATATGATATGTCTCCATACCCATTACAATCTCCAGCATTCCCCCTCAAACCTA GTGTTCTTTATTAATATCCTAACACTCCTCGTCCCCATTCTAA

In [56]:
percent_match(pred_mouse_seq, seq_df.iloc[0]["mouse_gene_sequence"])

0.25

In [23]:
percent_match(human_seq, pred_mouse_seq)

0.21652719665271966

In [27]:
from Bio.Seq import Seq

dna_seq = Seq("ATGATGCATCGTAC")
protein_seq = dna_seq.translate()

print(protein_seq) 

MMHR




In [32]:
translate(pred_mouse_seq)

Seq('GWRRMNTMITTRRLPGHPMPTCYMTKTEI*NTTRVSGQL*AG*KTEDRRPWFLT...AGA')

In [33]:
translate(seq_df.iloc[0]["mouse_gene_sequence"])

Seq('VFFINILTLLVPILIAIAFLTLVERKILGYIQLRKGPNIVGPYGILQPFADAIK...YI*')