# CasDiscovery Prediction v1.0

`Input`: suspicious proteins in .fasta/.faa format.

`Output`: a predicted label for each protein: `cas9`, `cas12`, `cas13` or `noncas`.

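*Note*: before running the cells below, make sure you have saved/added the following folder to your Google Drive ("MyDrive" folder). The notebook loads the model from `MyDrive/models/CasDiscovery` and the default input from `MyDrive/inputs/suspicious.faa`: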

- [model](https://drive.google.com/drive/folders/1y4WKwsoBsqBb_R2Cdj0cwYiLIPnBXj01?usp=sharing)

In [None]:
#@title Step.01 setup **Environment** (~6m)
%%time
# One-time Colab setup: mount Google Drive, install conda via condacolab,
# install pyfaidx, and create the `protein/` input folder.
import os, time, signal
import sys, random, string, re
## mount google drive
from google.colab import drive
drive.mount('/content/drive')

## install conda
!pip install -q condacolab
import condacolab
# NOTE(review): condacolab.install() is documented to restart the Python
# kernel, so the imports made above are presumably NOT available to later
# cells after a fresh run — later cells should import what they use. Confirm.
condacolab.install()
!conda update -n base -c conda-forge conda -y

## packages install
!pip install pyfaidx

## protein to be predicted
# NOTE(review): plain `mkdir` prints an error if `protein/` already exists
# (e.g. when this cell is re-run); `!` shell failures do not stop the cell.
!mkdir protein

In [None]:
################### prediction codes ###################
import ast
import logging
import pandas as pd
import numpy as np
import pyfaidx
from sklearn.utils import shuffle
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import argparse
from transformers import AutoTokenizer, EsmForSequenceClassification

## Class
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that emits formatted records through ``tqdm.write``.

    Routing output through tqdm keeps log lines from colliding with any
    progress bars tqdm is currently drawing.
    """

    def __init__(self, level=logging.NOTSET):
        super().__init__(level=level)

    def emit(self, record):
        # Any failure (formatting or writing) is delegated to the standard
        # logging error hook instead of propagating to the caller.
        try:
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)

class SequenceDataset(Dataset):
    """Map-style dataset over already-tokenized sequences.

    Args:
        inputs: tokenizer output; only its ``input_ids`` and ``attention_mask``
            entries are kept and are indexed along their first axis.
        labels: integer class id for each sequence (converted to a tensor).
        names: sequence identifiers, stored so predictions can be mapped back
            to names via the ``ids`` field of each item.
    """

    def __init__(self, inputs, labels, names):
        self.input_ids, self.attention_mask = inputs['input_ids'], inputs['attention_mask']
        self.labels = torch.tensor(labels)
        self.names = names

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        sample = dict(
            labels=self.labels[idx],
            input_ids=self.input_ids[idx],
            attention_mask=self.attention_mask[idx],
        )
        # Positional index, used downstream to look up ``self.names``.
        sample['ids'] = idx
        return sample

    def get_num_samples_per_class(self):
        """Return how many samples carry each label id, as a Python list."""
        return self.labels.bincount().tolist()

## Functions
def parse_head_mask(value):
    """Parse a Python-literal string (e.g. ``"[[1, 0], [1, 1]]"``) into a tensor.

    Intended for use as an argparse ``type=`` converter.

    Args:
        value: string holding a Python literal (nested lists of numbers).

    Returns:
        ``torch.Tensor`` built from the parsed literal.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a valid literal or the
            literal cannot be converted into a tensor.
    """
    try:
        # Parse once (the literal used to be evaluated twice, with a debug print).
        parsed = ast.literal_eval(value)
        return torch.tensor(parsed)
    except (ValueError, SyntaxError, TypeError) as err:
        # literal_eval raises SyntaxError on malformed text and ValueError on
        # non-literals; torch.tensor raises ValueError/TypeError on ragged or
        # non-numeric data. All of these mean "bad argument" to argparse.
        raise argparse.ArgumentTypeError(f"Invalid head_mask value: {value}") from err

def create_dataset(tokenizer, fasta_dir, max_seq_len, label_to_id_fn, random_seed):
    """Read every FASTA file in ``fasta_dir`` and wrap it as a SequenceDataset.

    Args:
        tokenizer: Hugging Face tokenizer used to encode all sequences in one call.
        fasta_dir: directory scanned by ``read_fasta``.
        max_seq_len: truncation length passed to the tokenizer.
        label_to_id_fn: callable mapping a list of label strings to integer ids.
        random_seed: if not ``None``, labels/sequences/names are shuffled
            together with this seed; ``None`` keeps file order.

    Returns:
        SequenceDataset holding the encoded sequences, label ids and names.
    """
    labels, sequences, names = read_fasta(fasta_dir)
    if random_seed is not None:
        # Shuffle the three parallel lists jointly so they stay aligned.
        labels, sequences, names = shuffle(labels, sequences, names, random_state=random_seed) # type: ignore
    encoded = tokenizer(
        sequences,
        padding=True,
        truncation=True,
        max_length=max_seq_len,
        return_tensors='pt',
        add_special_tokens=True,
    )
    return SequenceDataset(encoded, label_to_id_fn(labels), names)

def read_fasta(fasta_dir):
    """Collect every sequence from the ``.faa``/``.fasta`` files in ``fasta_dir``.

    Each file's label is the part of its filename before the first ``.``
    (``cas9.faa`` -> ``cas9``). Sequences are upper-cased, newlines and ``*``
    are removed, and the residues U, Z, O, B are replaced with ``X``.

    Args:
        fasta_dir: directory to scan (non-recursive); other files are ignored.

    Returns:
        Three parallel lists ``(labels, sequences, names)``, where ``names`` are
        the FASTA record names. Files are visited in sorted filename order.
    """
    # Imported locally: the setup cell that imports os/re/time also runs
    # condacolab.install(), which restarts the kernel and drops those names.
    import os
    import re
    import time

    strip_pattern = re.compile(r"[\n\*]")
    rare_residue_pattern = re.compile(r"[UZOB]")
    labels, sequences, names = [], [], []
    # sorted() makes the output order independent of the filesystem listing.
    for fasta_file in sorted(os.listdir(fasta_dir)):
        if not fasta_file.endswith(('.faa', '.fasta')):
            continue
        label = fasta_file.split('.')[0]
        # Default rebuild behaviour re-indexes when the FASTA is newer than its
        # .fai, so a replaced input is never read through a stale index; the
        # context manager closes the underlying file handle.
        with pyfaidx.Fasta(os.path.join(fasta_dir, fasta_file)) as fasta:
            for record in fasta:
                # Upper-case first so the [UZOB] substitution also applies to
                # lowercase input (the substitution regex is uppercase-only).
                seq = str(record).upper()
                seq = strip_pattern.sub('', seq)
                seq = rare_residue_pattern.sub("X", seq)
                labels.append(label)
                sequences.append(seq)
                names.append(record.name)
    print(f"Read {len(labels)} sequences from {fasta_dir}, sequences: {len(sequences)}, names: {len(names)} from fasta_dir: {fasta_dir}")
    time.sleep(1) # avoid multi process issues
    return labels, sequences, names

def get_dataloader(tokenizer, label_to_id_fn, random_seed, args):
    """Build the evaluation DataLoader for ``args.eval_dataset_dir``.

    ``args`` must expose ``eval_dataset_dir``, ``max_seq_len`` and
    ``eval_batch_size``. The DataLoader itself does not shuffle; any
    reordering comes from ``create_dataset`` when ``random_seed`` is set.
    """
    return DataLoader(
        create_dataset(tokenizer, args.eval_dataset_dir, args.max_seq_len, label_to_id_fn, random_seed),
        batch_size=args.eval_batch_size,
    )

def get_label_to_id_fn(all_labels):
    """Return a function mapping label strings to their index in ``all_labels``.

    Labels that do not appear in ``all_labels`` map to ``0`` (the first class).
    The returned function looks labels up in ``all_labels`` on every call.
    """
    def _index_or_zero(label):
        try:
            return all_labels.index(label)
        except ValueError:
            return 0

    def label_to_id_fn(labels):
        return list(map(_index_or_zero, labels))

    return label_to_id_fn

In [None]:
#@title Step.02 run **CasDiscovery Prediction** (~1m)
%%time
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(TqdmLoggingHandler())

DATASET_TRAINING_KEYS = ['labels', 'input_ids', 'attention_mask']

#@markdown **Parameters** settings
#@markdown ---

from google.colab import files
protein_dir = 'protein'
input_protein = 'suspicious.faa' #@param {type:"string"}
#@markdown - input protein `.faa` file. default `suspicious.faa`.
#@markdown - or you can upload your own protein `.faa` file.
if input_protein == 'suspicious.faa':
  !cp drive/MyDrive/inputs/suspicious.faa protein/
else:
  input_protein = files.upload()
  input_protein = list(input_protein.keys())[0]
  !mv {input_protein} protein/

max_seq_len = 1560 #@param ["1560"] {type:"raw"}
#@markdown - maximum length of your protein sequence. default `1560`.
batch_size = 1 #@param ["1"] {type:"raw"}
random_seed = 42 #@param ["42"] {type:"raw"}

out_table = 'pred_result.csv' #@param {type:"string"}
#@markdown - output table.

args = argparse.ArgumentParser(description='CasDiscovery')
args.eval_dataset_dir = protein_dir
args.max_seq_len = max_seq_len
args.eval_batch_size = batch_size
model_name = 'drive/MyDrive/models/CasDiscovery'
all_labels = ['cas9','cas12','cas13','noncas']
dataset_random_seed = None

label_to_id_fn = get_label_to_id_fn(all_labels)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmForSequenceClassification.from_pretrained(model_name, num_labels=len(all_labels)).cuda().eval()
eval_dataloader = get_dataloader(tokenizer, label_to_id_fn, dataset_random_seed, args)

merged_ids = torch.tensor([], dtype=torch.int)
merged_predicted_label_ids = torch.tensor([], dtype=torch.int)
merged_logits = torch.tensor([], dtype=torch.float)

with torch.no_grad():
  for batch in eval_dataloader:
    inputs = {k: v for k, v in batch.items() if k in DATASET_TRAINING_KEYS}
    label_ids = batch.get("labels")
    ids = batch.get("ids")
    inputs['labels'] = inputs['labels'].cuda()
    inputs['input_ids'] = inputs['input_ids'].cuda()
    inputs['attention_mask'] = inputs['attention_mask'].cuda()
    outputs = model(**inputs)
    logits = outputs.get("logits")
    predicted_probs = torch.softmax(logits, dim=1)
    predicted_label_ids = torch.argmax(logits, dim=1)

    merged_ids = torch.cat((merged_ids, ids.cpu()), dim=0)
    merged_predicted_label_ids = torch.cat((merged_predicted_label_ids, predicted_label_ids.cpu()), dim=0)
    merged_logits = torch.cat((merged_logits, predicted_probs.cpu()), dim=0)

    merged_predicted_labels = [all_labels[label_id] for label_id in merged_predicted_label_ids]
    merged_names = [eval_dataloader.dataset.names[id] for id in merged_ids]
    df = pd.DataFrame({
      "name": merged_names,
      "predicted_label": merged_predicted_labels,
    })
    for i, label in enumerate(all_labels):
      df[f"prob: {label}"] = [f"{round(prob[i] * 100, 2)}%" for prob in merged_logits.numpy()]
    df.to_csv(out_table, sep='\t',index=False)

print('Prediction finished!\nResults output to: %s' % out_table)
print('------------------------------')
print('Total protein: %d' % df.shape[0])
print('Cas9: %d' % df[df.predicted_label=='cas9'].shape[0])
print('Cas12: %d' % df[df.predicted_label=='cas12'].shape[0])
print('Cas13: %d' % df[df.predicted_label=='cas13'].shape[0])

In [None]:
# Display only the sequence names and their predicted labels.
df[['name','predicted_label']]

In [None]:
#@title Step.03 Package and download results
from google.colab import files
!zip -r CasPrediction.zip {out_table}
files.download(f"CasPrediction.zip")

# *Note*

This pipeline uses a fine-tuned ESM-2 model (`facebook/esm2_t33_650M_UR50D`) with `EsmForSequenceClassification`. Given suspicious proteins adjacent to CRISPR arrays, it helps identify potential single-effector Cas enzymes (`cas9`, `cas12` or `cas13`).

The output table `pred_result.csv` (tab-separated, despite the `.csv` extension) contains the following columns:

- `name`. The ID (FASTA record name) of each protein.

- `predicted_label`. Model predicted labels.

- `prob: cas9`. The softmax probability (shown as a percentage) that this protein is a Cas9 enzyme.

- `prob: cas12`. The softmax probability (shown as a percentage) that this protein is a Cas12 enzyme.

- `prob: cas13`. The softmax probability (shown as a percentage) that this protein is a Cas13 enzyme.

- `prob: noncas`. The softmax probability (shown as a percentage) that this protein is not a single-effector Cas enzyme (`cas9`, `cas12` or `cas13`).