# Setup and Imports

In [None]:
# IPython/Jupyter shell magics (the `!` prefix only works inside a notebook):
# install the `ankh` package, which provides the model loaded below, and
# `seqeval`, whose metrics are imported in the next cell. `--quiet` reduces
# pip's console output.
!pip install ankh --quiet
!pip install seqeval --quiet

In [None]:
import os
# These environment variables are set before torch is imported so that
# CUDA/cuBLAS see them when they initialise.
# ":4096:8" is the cuBLAS workspace setting PyTorch asks for when deterministic
# algorithms are enabled. NOTE(review): torch.use_deterministic_algorithms(True)
# is not called here, so this alone does not enforce determinism.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Make only the first GPU visible to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Presumably disables Weights & Biases reporting from transformers' Trainer
# (Trainer is imported below but not used in the visible cells).
os.environ['WANDB_DISABLED'] = 'true'

import torch
import numpy as np
import random

# Global RNG seed used for all libraries below.
seed = 7

# torch.manual_seed seeds the CPU and all CUDA generators; NumPy's global RNG
# and Python's `random` module are separate generators, seeded individually.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

import ankh

from torch import nn
from torch.utils.data import Dataset, DataLoader

from transformers import Trainer, TrainingArguments, EvalPrediction
from datasets import load_dataset

from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
from scipy import stats
from functools import partial
import pandas as pd
from tqdm.auto import tqdm

In [None]:
def get_num_params(model):
    """Return the total number of scalar parameters held by ``model``.

    Every tensor yielded by ``model.parameters()`` is counted, whether or not
    it requires gradients.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

In [None]:
# Prefer the first visible GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Available device:', device)

In [None]:
# Load the pretrained Ankh "large" model together with its tokenizer.
model, tokenizer = ankh.load_large_model()
# Inference mode: disables training-only behaviour such as dropout.
model.eval()
# Move the weights to the selected device (as the cell's last expression, the
# returned module is also displayed in the notebook output).
model.to(device=device)

In [None]:
# Report the total parameter count of the loaded model. Plain string literal:
# there is nothing to interpolate, so an f-prefix would be dead syntax.
print("Number of parameters:", get_num_params(model))

# Data

In [None]:
# Which split to process; it selects the CSV and prefixes the saved .pt files.
data_part = "test"

df = pd.read_csv(f"/kaggle/input/data-skripsi/{data_part}.csv")

# CSV columns holding the residue sequences, the labels and the disorder values.
input_column_name = 'input'
labels_column_name = 'labels'
disorder_column_name = 'disorder'

# One pandas Series per column, one row per sample.
sequences, labels, disorder = (
    df[column]
    for column in (input_column_name, labels_column_name, disorder_column_name)
)

In [None]:
def preprocess_dataset(sequences, labels, disorder, max_length=None):
    """Normalise raw per-residue columns into lists of tokens.

    Whitespace is removed from every sequence and label string, and each is
    split into single characters. Each disorder string is split on whitespace
    into tokens (e.g. ``"1.0  0.0"`` -> ``["1.0", "0.0"]``). All three are
    truncated to at most ``max_length`` items per sample.

    Args:
        sequences: iterable of residue strings (whitespace is ignored).
        labels: iterable of label strings, one character per residue
            (whitespace is ignored).
        disorder: iterable of whitespace-separated disorder-value strings.
        max_length: maximum number of items kept per sample. When ``None``,
            the length of the longest cleaned sequence is used, so sequences
            are never truncated; labels/disorder longer than that still are.

    Returns:
        ``(seqs, labels, disorder)``: three lists of lists of strings, in
        input order.

    Raises:
        ValueError: if the three inputs do not hold the same number of samples.
    """
    cleaned_seqs = ["".join(seq.split()) for seq in sequences]

    if max_length is None:
        # default=0 lets empty input through instead of crashing in max().
        max_length = max(map(len, cleaned_seqs), default=0)

    seqs = [list(seq)[:max_length] for seq in cleaned_seqs]
    label_chars = [list("".join(label.split()))[:max_length] for label in labels]
    # str.split() already collapses whitespace runs; no join/re-split needed.
    disorder_tokens = [values.split()[:max_length] for values in disorder]

    # Explicit check (not assert) so it still runs under `python -O`.
    if not len(seqs) == len(label_chars) == len(disorder_tokens):
        raise ValueError(
            f"sample count mismatch: {len(seqs)} sequences, "
            f"{len(label_chars)} labels, {len(disorder_tokens)} disorder rows"
        )
    return seqs, label_chars, disorder_tokens

In [None]:
def embed_dataset(model, sequences, shift_left = 0, shift_right = -1):
    """Embed each residue sequence with ``model``, one sample at a time.

    Each sample (a list of residue characters) is tokenised by the
    module-level ``tokenizer`` as pre-split words with special tokens added.
    It is then run through ``model`` on the module-level ``device`` with
    gradients disabled. The first model output (``[0]``) for the single batch
    row is converted to a NumPy array, and rows ``[shift_left:shift_right]``
    are kept. With the defaults this drops only the final token, presumably
    the end-of-sequence token the tokenizer appends; confirm for the tokenizer
    in use.

    Note: ``shift_right=0`` yields empty arrays (the slice ends at 0); pass
    ``None`` to keep every row through the end.

    Returns:
        A list of NumPy arrays, one per input sample, in input order.
    """
    collected = []
    with torch.no_grad():
        for residues in tqdm(sequences):
            # Batch of one: no padding happens, so no attention mask is needed.
            encoded = tokenizer.batch_encode_plus(
                [residues],
                add_special_tokens=True,
                padding=True,
                is_split_into_words=True,
                return_tensors="pt",
            )
            token_ids = encoded['input_ids'].to(device)
            hidden = model(input_ids=token_ids)[0]
            per_token = hidden[0].detach().cpu().numpy()
            collected.append(per_token[shift_left:shift_right])
    return collected

In [None]:
sequences, labels, disorder = preprocess_dataset(sequences, labels, disorder)

# Snapshot of the cleaned data, re-joined into one string per sample, so the
# preprocessing result can be inspected outside the notebook.
after_preprocess = pd.DataFrame({
    "sequence": list(map("".join, sequences)),
    "label": list(map("".join, labels)),
    # Disorder is kept as whitespace-separated per-residue tokens, so it is
    # re-joined with single spaces rather than concatenated.
    "disorder": list(map(" ".join, disorder)),
})

# NOTE(review): the filename does not include data_part, so running this
# notebook for another split overwrites the same CSV.
after_preprocess.to_csv("dataset_output.csv", index=False)


In [None]:
# Per-sample embeddings for the preprocessed sequences (default token shifts).
embeddings = embed_dataset(model=model, sequences=sequences)

In [None]:
# 8-class secondary-structure label alphabet (presumably DSSP Q8; the saved
# files are named "ssp8"). tag2id is the single source of truth: the reverse
# map and the tag set are derived from it, so the tables cannot drift apart.
tag2id = {'B': 0, 'C': 1, 'I': 2, 'T': 3, 'S': 4, 'E': 5, 'G': 6, 'H': 7}
id2tag = {idx: tag for tag, idx in tag2id.items()}
unique_tags = set(tag2id)

In [None]:
def encode_tags(labels, tag_map=None):
    """Convert per-residue tag strings to integer class ids.

    Args:
        labels: iterable of samples, each an iterable of tag strings.
        tag_map: optional tag -> id mapping. Defaults to the module-level
            ``tag2id``, so existing calls behave exactly as before.

    Returns:
        A new list of lists of ints; the input is not modified.

    Raises:
        KeyError: if a tag is missing from the mapping.
    """
    mapping = tag2id if tag_map is None else tag_map
    return [[mapping[tag] for tag in sample] for sample in labels]

In [None]:
# Map each per-residue tag character to its integer class id.
labels_encodings = encode_tags(labels=labels)

In [None]:
def mask_disorder(labels, masks):
    """Set label ids to -100 wherever the paired disorder token is "0.0".

    ``labels`` and ``masks`` are walked in lockstep (stopping at the shorter
    one), and each sample's token list is compared position by position. The
    comparison is an exact string match, so tokens like "0" or "0.00" are not
    masked. The label lists are modified in place, and the same outer list
    object is returned.

    -100 matches the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``,
    presumably so masked positions are skipped by a downstream loss.
    """
    for encoded, flags in zip(labels, masks):
        masked_positions = [pos for pos, flag in enumerate(flags) if flag == "0.0"]
        for pos in masked_positions:
            encoded[pos] = -100
    return labels

In [None]:
# Replace ids at positions whose disorder token is "0.0" with -100.
labels_encodings = mask_disorder(labels=labels_encodings, masks=disorder)

In [None]:
# Number of embedded samples.
print(f"{len(embeddings)}")

In [None]:
# Number of label sequences; it should match the embedding count printed above.
print(f"{len(labels_encodings)}")

In [None]:
# Persist the per-sample embeddings (NumPy arrays) and the masked label ids
# for the chosen split.
# NOTE(review): the embeddings are pickled NumPy arrays, not tensors. On
# PyTorch versions where torch.load defaults to weights_only=True, loading
# them may fail unless weights_only=False is passed (only for trusted files).
torch.save(obj=embeddings, f=f'{data_part}_ssp8_embeddings.pt')
torch.save(obj=labels_encodings, f=f'{data_part}_ssp8_labels.pt')

In [None]:
# Spot-check sample 3: with the default embed_dataset shifts, its number of
# embedding rows should presumably equal its number of per-residue labels.
# Guarded so datasets with fewer than 4 samples don't raise IndexError.
if len(embeddings) > 3:
    print(len(embeddings[3]), len(labels_encodings[3]))
else:
    print(f"Only {len(embeddings)} samples; nothing to spot-check at index 3.")