## Intro
This notebook will:
#### extrac the embeddings for each model (baseline vs fine tuned) 
#### train a shallow model (xgboost ) 
#### test the validation set accuracy against the shallow model results 


In [1]:
import pathlib
import torch

from esm import FastaBatchedDataset, pretrained

In [2]:
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")

def extract_embeddings(model, fasta_file,caslabel, output_dir, tokens_per_batch=4096, seq_length=1022,repr_layers=[33]):
    from transformers import EsmTokenizer, EsmModel
    import torch
    # read the fasta file and extract sequences

    
    seqs =[seq]
    inputs = tokenizer(seqs, return_tensors="pt", padding=True, truncation=True)
    outputs = model(**inputs)
    last_hidden_states = outputs.last_hidden_state
    x = last_hidden_states.detach()
    x =x.mean(axis=1)

NameError: name 'EsmTokenizer' is not defined

In [None]:
def extract_embeddingsx(model_name, fasta_file,caslabel, output_dir, tokens_per_batch=4096, seq_length=1022,repr_layers=[33]):
    
    model, alphabet = pretrained.load_model_and_alphabet(model_name)
    model.eval()
    

    if torch.cuda.is_available():
        model = model.cuda('cuda:1')
    
#     caslist = ['cas1','cas2','cas3','cas4','cas10','cas12','cas14','cas9','cas13a', 'cas13b', 'cas13c','cas13d']
    dataset = FastaBatchedDataset.from_file(fasta_file)
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

    data_loader = torch.utils.data.DataLoader(
        dataset, 
        collate_fn=alphabet.get_batch_converter(seq_length), 
        batch_sampler=batches
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):

            print(f'Processing batch {batch_idx + 1} of {len(batches)}')

            if torch.cuda.is_available():
                toks = toks.to(device="cuda:1", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            

            logits = out["logits"].to(device="cpu")
            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}
            
            for i, label in enumerate(labels):
                entry_id = label.split()[0]
                
                filename = output_dir / f"{entry_id}.pt"
                truncate_len = min(seq_length, len(strs[i]))

                result = {"entry_id": entry_id}
                caslabel = ""
                for word in caslist:
                    if word.lower() in label.lower():
                        caslabel = word.lower()
                    
                result['label'] = caslabel
                result["mean_representations"] = {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }
                
                
                
                torch.save(result, filename)
#                 if i >1:
#                     return(result)
                


In [None]:
def extract_embeddings0(model_name, fasta_file, caslabel, output_dir, tokens_per_batch=4096, seq_length=1022, repr_layers=[33]):
    from Bio import SeqIO
    import torch
    from torch.utils.data import DataLoader
    
    # Load your model and alphabet (assuming a custom function or similar exists)
    model, alphabet = pretrained.load_model_and_alphabet(model_name)
    model.eval()

    # Setup model for multiple GPUs
    if torch.cuda.is_available():
        model = model.cuda()
        model = torch.nn.DataParallel(model)

    # Load dataset
    dataset = FastaBatchedDataset.from_file(fasta_file)
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

    # Prepare DataLoader
    data_loader = DataLoader(
        dataset, 
        collate_fn=alphabet.get_batch_converter(seq_length), 
        batch_sampler=batches
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(f'Processing batch {batch_idx + 1} of {len(batches)}')

            toks = toks.cuda()  # Ensure tokens are on GPU
            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            
            # Process outputs
            logits = out["logits"].cpu()
            representations = {layer: t.cpu() for layer, t in out["representations"].items()}

            for i, label in enumerate(labels):
                entry_id = label.split()[0]
                filename = output_dir / f"{entry_id}.pt"
                truncate_len = min(seq_length, len(strs[i]))

                result = {"entry_id": entry_id, "label": caslabel}
                result["mean_representations"] = {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }
                
                
                torch.save(result, filename)
                
                


In [None]:
caslist = ['cas1','cas2','cas3','cas4','cas5','cas6','cas7','cas8','cas9','cas10','cas11','cas12','cas13']


In [None]:
# model_name = 'esm2_t33_650M_UR50D'
model_name = 'esm2_t6_8M_UR50D'


In [None]:
torch.cuda.empty_cache()


In [None]:
rep_layer_number = 6 # for the 8 m model, for 650 use 33 

for cas in caslist: 
    
    casfolder = f"/home/salaris/protein_model/data/{cas}/"
    
    training_fasta_file = pathlib.Path(casfolder + cas + '_training.fasta')
    validation_fasta_file = pathlib.Path(casfolder + cas + '_validation.fasta')
    
    training_embedding_folder = pathlib.Path(casfolder  + "_" +model_name + "_" + 'embeddings/' +  'training/')
    validation_embedding_folder = pathlib.Path(casfolder  + "_" +model_name + "_" + 'embeddings/' +   'validation/')
    print(training_embedding_folder, validation_embedding_folder)
    print(training_fasta_file, validation_fasta_file)
    extract_embeddings(model_name, 
                       fasta_file= training_fasta_file,caslabel= cas, 
                       output_dir= training_embedding_folder, tokens_per_batch=2048 * 2, seq_length=1022,repr_layers=[rep_layer_number])

    extract_embeddings(model_name, 
                       fasta_file= validation_fasta_file,caslabel= cas, 
                       output_dir= validation_embedding_folder, tokens_per_batch=2048 * 2 , seq_length=1022,repr_layers=[rep_layer_number])
    