# SETUP

In [None]:
from functools import reduce
import os
import re

import numpy as np
import pandas as pd

import torch

In [None]:
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
DATA_PATH = os.path.join(ROOT_PATH, 'data')
MIRNA_DATA_PATH = os.path.join(DATA_PATH, 'mirna.tsv')
MIRNA_MODEL_NAME = "mirna"
MIRNA_MODEL_VERSION = "0"
MIRNA_MODEL_CHECKPOINT = "checkpoint-1000"
MIRNA_MODEL_ROOT = os.path.join(os.getcwd(), f"model={MIRNA_MODEL_NAME}", f"version={MIRNA_MODEL_VERSION}")

In [None]:
os.makedirs(MIRNA_MODEL_ROOT, exist_ok=True)

In [None]:
MIRNA_DATA_PATH, MIRNA_MODEL_ROOT

# DATA

In [None]:
# Read the raw miRNA records in one go; the context manager closes the file handle
# (the bare open(...).read() left it open until garbage collection).
with open(MIRNA_DATA_PATH, encoding="utf-8") as mirna_file:
    mirna_str = mirna_file.read().strip()

In [None]:
# Split on any run of whitespace (spaces, tabs, \n, \r\n). Splitting on single ' ' / '\n'
# characters would produce empty or '\r'-suffixed fields for doubled separators, CRLF line
# endings or tab-separated input, which breaks the reshape into 3 columns below.
mirna_list = mirna_str.split()

In [None]:
len(mirna_list)

In [None]:
mirna_list[:9]

In [None]:
mirna_list[-9:]

In [None]:
mirna_array_1 = np.array(mirna_list)

In [None]:
mirna_array_1.shape

In [None]:
mirna_array_2 = mirna_array_1.reshape(-1,3)

In [None]:
ma = pd.DataFrame.from_records(mirna_array_2, columns=['ID', 'Accession', 'sequence'])

In [None]:
ma['ID'] = ma['ID'].map(lambda _: _[1:] if _.startswith('>') else _)

In [None]:
ma

In [None]:
# Alphabet check: the set of distinct characters used across all sequences.
set("".join(ma.sequence))

In [None]:
set(ma.iloc[0].sequence)

## Dataset

In [None]:
DATASET_TRAIN_FRACTION = 9/10

In [None]:
import pyarrow as pa
import pyarrow.parquet as pq

In [None]:
from datasets import Dataset, DatasetDict, load_dataset, splits

In [None]:
# Ordered split without shuffling: the first DATASET_TRAIN_FRACTION of the rows of `ma`
# become the training set and the remaining tail becomes the test set.
# NOTE(review): if the source file is grouped (e.g. by species), this tail is not a random
# sample -- consider shuffling before splitting.
mirna_train_dataset = Dataset.from_pandas(ma.iloc[:int(DATASET_TRAIN_FRACTION*len(ma))], split='train')
mirna_test_dataset = Dataset.from_pandas(ma.iloc[int(DATASET_TRAIN_FRACTION*len(ma)):], split='test')

In [None]:
isinstance(mirna_train_dataset.data.table, pa.Table)

In [None]:
MIRNA_TRAIN_PATH = os.path.join(MIRNA_MODEL_ROOT, f"{MIRNA_MODEL_NAME}_train.parquet")
pq.write_table(mirna_train_dataset.data.table, MIRNA_TRAIN_PATH)

In [None]:
MIRNA_TEST_PATH = os.path.join(MIRNA_MODEL_ROOT, f"{MIRNA_MODEL_NAME}_test.parquet")
pq.write_table(mirna_test_dataset.data.table, MIRNA_TEST_PATH)

In [None]:
mirna_datasets = DatasetDict(
    {
        "train": mirna_train_dataset,  # .shuffle().select(range(50000)),
        "test": mirna_test_dataset,  # .shuffle().select(range(500))
    }
)

In [None]:
_sequence0 = ma.iloc[0].sequence

In [None]:
MIRNA_MAX_SEQUENCE_LENGTH = max(max(len(tt) for tt in ds['sequence']) for ds in mirna_datasets.values())
MIRNA_MAX_SEQUENCE_LENGTH

# TOKENIZER

Built following the Hugging Face course tutorial: https://huggingface.co/course/chapter6/8?fw=pt

Specifically, a GPT-2-style byte-level BPE tokenizer.

In [None]:
from functools import reduce
def get_tokenizer_training_corpus(dataset_dict, *, chunk_size):
    """Yield the ``'sequence'`` column of every split in chunks of ``chunk_size``.

    Intended as the iterator passed to ``Tokenizer.train_from_iterator``.

    Args:
        dataset_dict: Mapping of split name -> dataset. Each value must support
            ``ds['sequence']`` returning an iterable of strings (a list, tuple or
            any other iterable column).
        chunk_size: Maximum number of sequences per yielded chunk; must be >= 1.

    Yields:
        Lists of sequences in split order, then row order. Every chunk except
        possibly the last holds exactly ``chunk_size`` items.

    Raises:
        ValueError: If ``chunk_size`` < 1. Raised on the first ``next()``, since
            this is a generator.
    """
    from itertools import chain

    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size!r}")
    # chain.from_iterable accepts any iterable column, whereas `list + column` only
    # works when the column is already a Python list.
    sequence_list = list(chain.from_iterable(ds['sequence'] for ds in dataset_dict.values()))
    for start in range(0, len(sequence_list), chunk_size):
        yield sequence_list[start:start + chunk_size]

In [None]:
from tokenizers import (
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer,
)

In [None]:
_tokenizer = Tokenizer(models.BPE()) # all tokens are known

In [None]:
_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

In [None]:
_tokenizer.pre_tokenizer.pre_tokenize_str(_sequence0)

In [None]:
TOKENIZER_TRAINER_CHUNK_SIZE = 200
VOCAB_SIZE = 100

In [None]:
tokenizer_trainer = trainers.BpeTrainer(vocab_size=VOCAB_SIZE, special_tokens=["<|endoftext|>"])

In [None]:
_tokenizer.train_from_iterator(get_tokenizer_training_corpus(mirna_datasets, chunk_size=TOKENIZER_TRAINER_CHUNK_SIZE), 
                              trainer=tokenizer_trainer)

In [None]:
len(_tokenizer.get_vocab())

In [None]:
_encoding = _tokenizer.encode(_sequence0)
_tokens = _encoding.tokens

In [None]:
print(_tokens)

In [None]:
_tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

In [None]:
encoding_ = _tokenizer.encode(_sequence0)
start, end = encoding_.offsets[4]
_sequence0[start:end]

In [None]:
_tokenizer.decoder = decoders.ByteLevel()

In [None]:
sequence0_ = _tokenizer.decode(_encoding.ids)

In [None]:
_sequence0 == sequence0_

In [None]:
TOKENIZER_PATH = os.path.join(MIRNA_MODEL_ROOT, f"{MIRNA_MODEL_NAME}_tokenizer.json")

In [None]:
_tokenizer.save(TOKENIZER_PATH)

In [None]:
_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

In [None]:
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=_tokenizer,
    bos_token="<|endoftext|>",
    eos_token="<|endoftext|>",
)

In [None]:
encoding = tokenizer.encode(_sequence0)

In [None]:
encoding

In [None]:
tokenizer.decode(encoding)

In [None]:
len(tokenizer)

In [None]:
len(tokenizer(mirna_datasets['train']['sequence'])['input_ids'])

## Tokenize datasets

In [None]:
MIRNA_MAX_SEQUENCE_LENGTH

In [None]:
MIRNA_CONTEXT_LENGTH = None

In [None]:
mirna_datasets

In [None]:
def tokenize(row, *, context_length=MIRNA_CONTEXT_LENGTH):
    """Tokenize a batch of rows; meant for ``Dataset.map(tokenize, batched=True)``.

    Args:
        row: Batch mapping whose ``"sequence"`` entry holds the raw sequence strings.
        context_length: Maximum number of tokens per output row. ``None`` (the value
            of ``MIRNA_CONTEXT_LENGTH`` in this notebook) means no truncation.

    Returns:
        The tokenizer output for the batch. It includes ``input_ids``, a per-row
        ``length`` (requested via ``return_length=True``) and, because
        ``return_overflowing_tokens=True``, an overflow-to-sample mapping. When
        truncating, a fast tokenizer emits the overflow of a long sequence as
        extra rows rather than dropping it.
    """
    return tokenizer(
        row["sequence"],
        # Only request truncation when a length is given. The tokenizer is built without
        # model_max_length, so truncation=True with max_length=None would just log a
        # warning and fall back to no truncation anyway.
        truncation=context_length is not None,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )


mirna_tokenized_datasets = mirna_datasets.map(tokenize, batched=True, remove_columns=mirna_datasets["train"].column_names)
mirna_tokenized_datasets

# MODEL

In [None]:
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig

config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer),
    n_ctx=MIRNA_CONTEXT_LENGTH,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

In [None]:
model = GPT2LMHeadModel(config)
model_size = sum(t.numel() for t in model.parameters())
print(f"GPT-2 size: {model_size/1000**2:.1f}M parameters")

In [None]:
from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

Let’s have a look at an example:

In [None]:
out = data_collator([mirna_tokenized_datasets["train"][i] for i in range(5)])
for key in out:
    print(f"{key} shape: {out[key].shape}")

## Train

In [None]:
# Index of the physical GPU to train on (numbered in PCI bus order, see below).
GPU = 3
# Set up the CUDA environment before CUDA is first initialized in this kernel.
# NOTE(review): torch is already imported in the SETUP cell, so "before importing torch" is not
# what matters here; these variables presumably only take effect if no CUDA call has happened
# yet -- restart the kernel if the device selection appears to be ignored.
import os
os.environ["WANDB_DISABLED"] = "TRUE"  # presumably turns off Weights & Biases logging in the Trainer
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # enumerate devices in PCI bus order
os.environ["CUDA_VISIBLE_DEVICES"] = f"{GPU}"  # This shrinks the GPU universe and maps cuda:0 to {GPU}

In [None]:
import torch
torch.cuda.device_count()

In [None]:
torch.cuda.current_device() # This really is device {GPU}

In [None]:
from transformers import Trainer, TrainingArguments
import datetime
# Run date/time stamps. NOTE(review): not referenced in the cells shown here; presumably meant for
# naming runs. `time` would also shadow the stdlib module name if it is imported later.
date = datetime.datetime.now().strftime('%Y-%m-%d')
time = datetime.datetime.now().strftime('%H.%M')


# NOTE(review): this first configuration is immediately overwritten by the second `training_args`
# below and never reaches the Trainer. It uses gradient accumulation, a cosine schedule with warmup,
# and step-based eval/logging/saving. It is kept for reference only.
training_args = TrainingArguments(
    output_dir=MIRNA_MODEL_ROOT,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    evaluation_strategy="steps",
    eval_steps=5_000,
    logging_steps=5_000,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1_000,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=5_000,
    #fp16=True,
    push_to_hub=False,
)

# Configuration actually used by the Trainer: evaluate once per epoch, 12 epochs, 16 per device.
training_args = TrainingArguments(
    output_dir=f"{MIRNA_MODEL_ROOT}",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    #place_model_on_device=torch.device(f"cuda:{GPU}"),
    push_to_hub=False,
    num_train_epochs=12.0,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
)

# Wire the model, tokenizer, causal-LM collator and the tokenized train/test splits together.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=data_collator,
    train_dataset=mirna_tokenized_datasets["train"],
    eval_dataset=mirna_tokenized_datasets["test"],
)

In [None]:
# Reuse a model previously saved to MIRNA_MODEL_ROOT if there is one; otherwise train and save it.
# Only OSError -- what transformers raises when no saved config/weights are found -- triggers
# training. Any other load failure now surfaces instead of silently starting a long training run.
try:
    model = GPT2LMHeadModel.from_pretrained(MIRNA_MODEL_ROOT)
except OSError:
    model.to("cuda:0")
    trainer.train()
    # Save from CPU. (`from_pt` is a loading option of from_pretrained, not a save option.)
    model.to("cpu").save_pretrained(MIRNA_MODEL_ROOT)
# cuda:0 is the single GPU exposed via CUDA_VISIBLE_DEVICES in the cell above.
model.to("cuda:0")

### Quick check

In [None]:
%%time
from transformers import pipeline, set_seed
# Run the pipeline on the device the model already lives on (cuda:0 after the training cell).
# NOTE(review): without `device`, some transformers versions place pipeline inputs on the CPU
# while the model stays on the GPU, which fails with a device mismatch.
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=model.device)
set_seed(42)
generator("", max_length=30, num_return_sequences=5)

# LATENT-SPACE

# #TODO:
* UMAP landscape of latent space on MIRNA
    - Extract lm_head LATENT INPUTS for each of the generated sequences:
        x Construct suitable args/kwargs for GPT2LMHeadModel.generate(input_ids, attention_mask, **generate_kwargs)
        - Compute latent codes for the training set sequences: 
            - Tokenize each sequence
            - Evaluate model on all of each of the sequence's initial segments, collecting the appropriate hidden states
            - Use training sequences long_latents and include them in UMAP
                   
    - UMAP the latent inputs
* PERPLEXITY/CROSS-ENTROPY measure of training set sequences
    - LITERATURE on how to evaluate the model performance by looking at perplexity
    - CODE for how to compute perplexity from model generation/transition scores
    - UNDERSTAND how CROSS-ENTROPY LOSS relates to PERPLEXITY
* PRETRAIN on other RNA datasets
* DECONFOUNDING/separation/whitening in latent space

From the `transformers.generation.utils.GenerationMixin.compute_transition_scores()` docstring:


"""

...

Examples:

        ```python
        >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM
        >>> import numpy as np

        >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> tokenizer.pad_token_id = tokenizer.eos_token_id
        >>> inputs = tokenizer(["Today is"], return_tensors="pt")

        >>> # Example 1: Print the scores for each token generated with Greedy Search
        >>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
        >>> transition_scores = model.compute_transition_scores(
        ...     outputs.sequences, outputs.scores, normalize_logits=True
        ... )
        >>> input_length = inputs.input_ids.shape[1]
        >>> generated_tokens = outputs.sequences[:, input_length:]
        >>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
        ...     # | token | token string | logits | probability
        ...     print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
        |   262 |  the     | -1.414 | 24.33%
        |  1110 |  day     | -2.609 | 7.36%
        |   618 |  when    | -2.010 | 13.40%
        |   356 |  we      | -1.859 | 15.58%
        |   460 |  can     | -2.508 | 8.14%

        >>> # Example 2: Reconstruct the sequence scores from Beam Search
        >>> outputs = model.generate(
        ...     **inputs,
        ...     max_new_tokens=5,
        ...     num_beams=4,
        ...     num_return_sequences=4,
        ...     return_dict_in_generate=True,
        ...     output_scores=True,
        ... )
        >>> transition_scores = model.compute_transition_scores(
        ...     outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
        ... )
        >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
        >>> # Tip: set `normalize_logits=True` to recompute the scores from the normalized logits.
        >>> output_length = inputs.input_ids.shape[1] + np.sum(transition_scores.numpy() < 0, axis=1)
        >>> length_penalty = model.generation_config.length_penalty
        >>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
        >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
        True
        ```
"""
        

`generator("", max_length=30, num_return_sequences=5)` results in this call to `GenerationMixin.generate()`:

`model.generate(input_ids=None, attention_mask=None, generate_kwargs={'max_length': 30, 'num_return_sequences': 5})`

We augment it as follows:

In [None]:
NUM_RETURN_SEQUENCES = 100
MAX_LENGTH = MIRNA_MAX_SEQUENCE_LENGTH
torch.manual_seed(0);

In [None]:
%%time
# Sample NUM_RETURN_SEQUENCES sequences from the BOS token, keeping scores, attentions and hidden
# states for the analyses below.
# NOTE(review): MAX_LENGTH is a character count (longest raw training sequence) used here as a
# token budget. Each token covers >= 1 character, so this is an upper bound, and if the model never
# learned to emit eos, every sample runs to this length.
outputs = model.generate(input_ids=None,
                         attention_mask=None,
                         do_sample=True,            # sampling; greedy decoding rejects num_return_sequences > 1
                         pad_token_id=tokenizer.pad_token_id,  # pad == eos (set before the collator); explicit to avoid the fallback warning
                         return_dict_in_generate=True,
                         output_scores=True,        # to compute perplexity/cross-entropy later
                         output_attentions=True,    # for viz
                         output_hidden_states=True, # for UMAP
                         max_length=MAX_LENGTH,
                         num_return_sequences=NUM_RETURN_SEQUENCES)

## SEQUENCES (HEAD)

In [None]:
N_SAMPLES = 5
tokenizer.batch_decode(outputs.sequences[:N_SAMPLES,:], skip_special_tokens=True)

## LATENTS (UMAP)

In [None]:
#model.transformer

In [None]:
len(outputs.hidden_states) # one per sequence element, except the last one -- the model is not evaluated on it as input

In [None]:
len(outputs.hidden_states[0]) # one entry per layer boundary: the embedding output plus the outputs of the 12 GPT2Blocks (the last with the final LayerNorm applied), 13 in total -- per transformers' GPT2Model; verify for other versions

In [None]:
outputs.hidden_states[0][-1].shape # last entry: output of the final LayerNorm after the 12 transformer blocks -- the input to lm_head, not the logits
# shape: [NUM_RETURN_SEQUENCES, 1, 768] -- presumably 1 because the prompt is the single BOS token and 768 is the GPT-2 hidden size; confirm

In [None]:
# long_latents: concat the activations from the last hidden layer for all generation steps
# H is one generation step's tuple of per-layer activations; h = H[-1] is its last-layer entry,
# expected to be [N, 1, D] (N = NUM_RETURN_SEQUENCES) -- the reshape to [N, D] fails otherwise
long_latents_list = [h.reshape(h.shape[0], h.shape[-1]) for h in [H[-1] for H in outputs.hidden_states]]
# Concatenating along the feature axis gives [N, D * len(outputs.hidden_states)].
long_latents = torch.cat(long_latents_list, dim=-1)

In [None]:
long_latents.shape

In [None]:
# short_latents: the last-layer activation from the final generation step, one row per sequence.
# A named variable is used instead of `_`: assigning `_` in IPython shadows (and stops updating)
# the automatic last-output variable.
last_step_latent = outputs.hidden_states[-1][-1]
short_latents = last_step_latent.reshape((last_step_latent.shape[0], last_step_latent.shape[-1]))

In [None]:
short_latents.shape

## UMAP

In [None]:
import umap
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# figure.figsize is in inches: (10, 5) gives a normal-sized plot, whereas (1000, 500) inches
# means tens of thousands of pixels per side, beyond matplotlib's per-axis image size limit.
sns.set(style='white', context='poster', rc={'figure.figsize': (10, 5)})
%matplotlib inline

In [None]:
fit = umap.UMAP()

In [None]:
%%time
lu = fit.fit_transform(long_latents.cpu().detach().numpy()) 

In [None]:
lu.shape

In [None]:
plt.scatter(x=lu[:,0], y=lu[:,1])

In [None]:
%%time
su = fit.fit_transform(short_latents.cpu().detach().numpy()) 

In [None]:
su.shape

In [None]:
plt.scatter(x=su[:,0], y=su[:,1])