## Intro
This notebook will:
#### extract the embeddings for each model (baseline vs. fine-tuned)
#### train a shallow model (XGBoost)
#### test the validation set accuracy against the shallow model results 


In [2]:
import pathlib
import torch

from esm import FastaBatchedDataset, pretrained, 

In [39]:
def extract_embeddings(model_name, fasta_file,caslabel, output_dir, tokens_per_batch=4096, seq_length=1022,repr_layers=[33]):
    """Save one mean-pooled embedding file per sequence of a FASTA file.

    The model is loaded with ``pretrained.load_model_and_alphabet``, the FASTA
    file is split into token-budgeted batches, and for every sequence a
    ``<entry_id>.pt`` file is written to ``output_dir`` with ``torch.save``.
    Each saved dict has the keys ``"entry_id"``, ``"label"`` and
    ``"mean_representations"`` (a ``{layer: tensor}`` mapping).

    Args:
        model_name: Model identifier passed to
            ``pretrained.load_model_and_alphabet``.
        fasta_file: FASTA file read by ``FastaBatchedDataset.from_file``.
        caslabel: Label stored when none of the names in the module-level
            ``caslist`` occurs in a sequence header.
        output_dir: Directory that receives the ``.pt`` files; created
            (including parents) if it does not exist. ``str`` or ``Path``.
        tokens_per_batch: Token budget handed to
            ``dataset.get_batch_indices``.
        seq_length: Truncation length given to the batch converter and used
            as the upper bound of the residues that are averaged.
        repr_layers: Layers whose representations are requested from the
            model. The default list is only read, never mutated.

    Returns:
        None. Results are written to disk.

    Note:
        The model and the token batches are not moved to a GPU here.
        The label lookup reads the global ``caslist``, which must be defined
        before this function is called.
    """
    output_dir = pathlib.Path(output_dir)

    model, alphabet = pretrained.load_model_and_alphabet(model_name)
    model.eval()

    dataset = FastaBatchedDataset.from_file(fasta_file)
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(seq_length),
        batch_sampler=batches,
    )

    output_dir.mkdir(parents=True, exist_ok=True)

    # Normalise the label vocabulary once instead of once per sequence.
    cas_words = [word.lower() for word in caslist]
    # Previously the caller's label was discarded; it is now the fallback for
    # headers that mention none of the known names.
    fallback_label = str(caslabel).lower() if caslabel else ""

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):

            print(f'Processing batch {batch_idx + 1} of {len(batches)}')

            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            representations = {
                layer: t.to(device="cpu")
                for layer, t in out["representations"].items()
            }

            for i, label in enumerate(labels):
                entry_id = label.split()[0]
                truncate_len = min(seq_length, len(strs[i]))

                # The last matching name wins, so with 'cas1' listed before
                # 'cas10'..'cas13' the more specific name is the one kept.
                header = label.lower()
                detected_label = fallback_label
                for word in cas_words:
                    if word in header:
                        detected_label = word

                result = {
                    "entry_id": entry_id,
                    "label": detected_label,
                    # Position 0 is skipped (presumably the start token the
                    # ESM batch converter prepends - confirm for the chosen
                    # alphabet); only real residues are averaged.
                    "mean_representations": {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    },
                }

                torch.save(result, output_dir / f"{entry_id}.pt")
                


In [40]:
# Names 'cas1' through 'cas13', in ascending order. extract_embeddings keeps
# the LAST name found in a header, so 'cas1' must stay ahead of 'cas10'..'cas13'.
caslist = [f"cas{number}" for number in range(1, 14)]


In [41]:
# Pretrained model identifier handed to extract_embeddings; it is also
# embedded in the name of the per-family embeddings directory.
model_name = "esm2_t33_650M_UR50D"


In [None]:
# For every name in caslist, embed its training and validation FASTA files
# and store the results under
# <casfolder>/_<model_name>_embeddings/{training,validation}.
for cas in caslist:

    casfolder = f"/home/salaris/protein_model/data/{cas}/"
    cas_dir = pathlib.Path(casfolder)

    training_fasta_file = cas_dir / f"{cas}_training.fasta"
    validation_fasta_file = cas_dir / f"{cas}_validation.fasta"

    # The directory name keeps its leading underscore so that the paths stay
    # identical to the ones produced by earlier runs.
    embedding_root = cas_dir / f"_{model_name}_embeddings"
    training_embedding_folder = embedding_root / "training"
    validation_embedding_folder = embedding_root / "validation"

    print(training_embedding_folder, validation_embedding_folder)
    # Fixed: this line used to print names that are not defined anywhere in
    # the notebook, which raises NameError on a fresh kernel.
    print(training_fasta_file, validation_fasta_file)

    # Training split first, then validation, with identical settings.
    for split_fasta, split_folder in (
        (training_fasta_file, training_embedding_folder),
        (validation_fasta_file, validation_embedding_folder),
    ):
        extract_embeddings(
            model_name,
            fasta_file=split_fasta,
            caslabel=cas,
            output_dir=split_folder,
            tokens_per_batch=2048,
            seq_length=1022,
            repr_layers=[33],
        )


/home/salaris/protein_model/data/cas1/_esm2_t33_650M_UR50D_embeddings/training /home/salaris/protein_model/data/cas1/_esm2_t33_650M_UR50D_embeddings/validation
/home/salaris/protein_model/data/cas1/cas1_training.fasta /home/salaris/protein_model/data/cas1/cas1_validation.fasta
Processing batch 1 of 2341
Processing batch 2 of 2341
Processing batch 3 of 2341
Processing batch 4 of 2341
Processing batch 5 of 2341
Processing batch 6 of 2341
Processing batch 7 of 2341
Processing batch 8 of 2341
Processing batch 9 of 2341
Processing batch 10 of 2341
Processing batch 11 of 2341
Processing batch 12 of 2341
Processing batch 13 of 2341
Processing batch 14 of 2341
Processing batch 15 of 2341
Processing batch 16 of 2341
Processing batch 17 of 2341
Processing batch 18 of 2341
Processing batch 19 of 2341
Processing batch 20 of 2341
Processing batch 21 of 2341
Processing batch 22 of 2341
Processing batch 23 of 2341
Processing batch 24 of 2341
Processing batch 25 of 2341
Processing batch 26 of 2341
Pro