In [1]:
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import os
import warnings
from pytorch_lightning.callbacks import ModelCheckpoint

from transformers import BertModel, BertTokenizer
from abc import ABC
import pytorch_lightning as pl
import torch
from transformers import BertForTokenClassification, BertConfig, Adafactor, AdamW
from transformers import pipeline
from sklearn.metrics import balanced_accuracy_score, accuracy_score
import torch.nn.functional as F
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import label_ranking_average_precision_score, average_precision_score, f1_score
from datasets import load_dataset
from datasets import Dataset
from datasets import load_from_disk
import pandas as pd
from sklearn.model_selection import train_test_split
import json

In [2]:
class TokenDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in
                self.encodings.items()}  # keys are input_ids, token_type_ids, attention_mask, labels, values are stored as a list of lists
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return (len(self.labels))


class newDataset(torch.utils.data.Dataset):
    def __init__(self, dataframe):
        self.encodings = {"input_ids": dataframe["input_ids"], "token_type_ids": dataframe["token_type_ids"],
                          "attention_mask": dataframe["attention_mask"]}
        self.labels = {"labels": dataframe["labels"]}

    def __getitem__(self, idx):
        item = {"input_ids": self.encodings["input_ids"][idx], "token_type_ids": self.encodings["token_type_ids"][idx],
                "attention_mask": self.encodings["attention_mask"][idx], "labels": self.labels["labels"][idx]}
        return item

    def __len__(self):
        return (len(self.labels["labels"]))

class BertTokClassification(pl.LightningModule, ABC):
    def __init__(
            self,
            config: BertConfig = None,
            pretrained_dir: str = None,
            use_adafactor: bool = False,
            learning_rate=3e-5,
            **kwargs
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.use_adafactor = use_adafactor
        if pretrained_dir is None:
            self.bert = BertForTokenClassification(config, **kwargs)
        else:
            self.bert = BertForTokenClassification.from_pretrained(pretrained_dir, **kwargs)

    def forward(self, input_ids, attention_mask, labels):
        return self.bert(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        outputs = self(input_ids=input_ids.to(self.device), attention_mask=attention_mask.to(self.device), labels=labels.to(self.device, dtype=torch.int64))
        loss = outputs.loss


        def get_acc(labels, logits):
            sumList = []
            for i in range(len(labels)):
                y_pred = torch.max(logits[i], 1).indices
                score = accuracy_score(labels[i], y_pred)
                sumList.append(score)
            avg = sum(sumList) / len(labels)
            return avg


        accuracy1 = get_acc(labels.cpu(), outputs.logits.cpu())

        # accuracy = balanced_accuracy_score(master[0], master[1])
        self.log(
            "train_batch_accuracy",
            accuracy1,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        outputs = self(input_ids=input_ids.to(self.device), attention_mask=attention_mask.to(self.device), labels=labels.to(self.device, dtype=torch.int64))
        loss = outputs.loss
        self.log(
            "val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )

        # def get_balanced_accuracy(labels, logits):
        #     y_pred = torch.max(logits, 1).indices
        #     score = balanced_accuracy_score(labels, y_pred)
        #     return score
        #
        # def label_average_precision(labels, logits):
        #     y_pred = torch.max(prob, 1).indices
        #     score = label_ranking_average_precision_score(labels, y_pred)
        #     return score
        # def f1_calc(labels, logits):
        #     sumList = []
        #     for i in range(len(labels)):
        #         y_pred = torch.max(logits[i], 1).indices
        #         score = f1_score(labels[i], y_pred, average='macro')
        #         sumList.append(score)
        #     avg = sum(sumList) / len(labels)
        #     return avg




        # """
        # 1. Iterate over the batch:
        #     for each_label in labels.cpu():
        #         shorten the length of the list to its true length using attention list and the function [true_length]
        #
        #     for each_logit in outputs.logits.cpu():
        #         use torch.max(outputs.logits.cpu()[0], 1) to get the indices for each logit (best label prediction)
        #         shorten the indices to its proper label length
        #         compare the indices to the labels
        # """
        def true_length(y_attention_mask):  # finds the start and stop of the actual sequence
            switch = False
            start = 0
            stop = 0
            counter = 0
            attention_mask = list(y_attention_mask)
            for i in attention_mask:
                if int(i) == 1 and switch == False:
                    switch = True
                    start = counter
                elif int(i) == 0 and switch == True:
                    stop = counter
                    break
                elif counter == 511:
                    stop = 512
                counter += 1
            return (start, stop)


        def short_clean(attention_mask, labels, logits): #attention_mask, labels.cpu(), outputs.logits.cpu()

            def true_length(y_attention_mask):  # finds the start and stop of the actual sequence
                switch = False
                start = 0
                stop = 0
                counter = 0
                attention_mask = list(y_attention_mask)
                for i in attention_mask:
                    if int(i) == 1 and switch == False:
                        switch = True
                        start = counter
                    elif int(i) == 0 and switch == True:
                        stop = counter
                        break
                    counter += 1
                return (start, stop)

            masterPred = []
            masterTrue = []
            for batch_index in range(len(labels)):
                real_len = true_length(attention_mask[batch_index])
                predIndecies = torch.max(outputs.logits.cpu()[batch_index], 1).indices
                start = real_len[0]
                stop = real_len[1]
                currentTrue = torch.LongTensor(labels[batch_index][start:stop])
                currentPred = torch.LongTensor(predIndecies[start:stop])
                if len(currentTrue) == 0:
                    masterTrue.append(currentTrue.tolist())
                    masterPred.append(currentPred.tolist())
                    print(f"CURRENT-PRED LEN: {len(currentPred)}")
                    print(f"CURRENT-TRUE LEN:{len(currentTrue)}")

            return (masterTrue, masterPred)

        master = short_clean(attention_mask, labels.cpu(), outputs.logits.cpu())
        print("###################################")
        print(f"MASTER-TRUE: {master[0]}")
        print("###################################")
        print(f"MASTER-PRED: {master[1]}")
        print("###################################")
        print("=======")
        print(f"LABEL LEN: {len(labels.cpu())}")
        for i in range(len(labels.cpu())):
            print(f"SINGLE LABEL LEN: {len(labels.cpu()[i])}")
            print(f"ATTENTION LEN: {len(attention_mask[i])}")
            #print(attention_mask[i])
            print(f"TRUE LEN: {true_length(attention_mask[i])}")
        print("=======")
        print(f"LABELS: {labels.cpu()}")
        print("||||||||||||||||||||||||||||")
        print("=======")
        print(f"LOGITS LEN: {len(outputs.logits.cpu())}")
        for i in outputs.logits.cpu():
            print(f"SINGLE LOGIT LEN: {len(i)}")
        print(f"LOGIT SINGLE LIST LEN: {len(outputs.logits.cpu()[0][0])}")
        b_logit = torch.max(outputs.logits.cpu()[0], 1)
        b_logit_indices = torch.max(outputs.logits.cpu()[0], 1).indices
        print(f"LOGIT BEST: {b_logit}")
        print(f"LOGIT BEST INDICES: {b_logit_indices}")
        print("=======")
        print(f"LOGITS: {outputs.logits.cpu()}")

        # accuracy = label_average_precision(labels.cpu(), logits=outputs.logits.cpu()) #replaced get_balanced_accuracy(labels.cpu(), logits=outputs.logits.cpu()) with label ranking average precision

        # def balanced_accuracy_score(labels, logits):
        #     sumList = []
        #     for i in range(len(labels)):
        #         y_predList = []
        #         trueList = []
        #         y_pred = logits[i]
        #         previous = 0
        #         for lab in labels[i]:
        #             if lab == -100:
        #                 y_predList.append(y_pred[previous])
        #                 trueList.append(previous)
        #             else:
        #                 previous = lab
        #                 y_predList.append(y_pred[previous])
        #                 trueList.append(previous)
        #
        #         num = average_precision_score(trueList, y_predList)
        #         sumList.append(num)
        #     big = sum(sumList) / len(labels)
        #     return big

        def get_bal_acc(labels, logits):
            sumList = []
            for i in range(len(labels)):
                y_pred = torch.max(logits[i], 1).indices
                score = balanced_accuracy_score(labels[i], y_pred)
                sumList.append(score)
            avg = sum(sumList) / len(labels)
            return avg

        accuracy = get_bal_acc(labels.cpu(), outputs.logits.cpu())

        self.log(
            "val_accuracy",
            accuracy,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return {"val_loss": loss}

    def configure_optimizers(self):
        if self.use_adafactor:
            return Adafactor(
                self.parameters(),
                lr=self.learning_rate,
                eps=(1e-30, 1e-3),
                clip_threshold=1.0,
                decay_rate=-0.8,
                beta1=None,
                weight_decay=0.0,
                relative_step=False,
                scale_parameter=False,
                warmup_init=False)
        else:
            return AdamW(self.parameters(), lr=self.learning_rate)

    def save_pretrained(self, pretrained_dir):
        self.bert.save_pretrained(self, prtrained_dir)

    def predict_classes(self, input_ids, attention_mask, return_logits=False):
        output = self.bert(input_ids=input_ids.to(self.device), attention_mask=attention_mask)
        if return_logits:
            return output.logits
        else:
            probabilities = F.sigmoid(output.logits)
            predictions = torch.argmax(probabilities)
            return {"probabilities": probabilities, "predictions": predictions}

    def get_attention(self, input_ids, attention_mask, specific_attention_head: int = None):
        output = self.bert(inputs_ids=input_ids.to(self.device), attention_mask=attention_mask)
        if specific_attention_head is not None:
            last_layer = output.attentions[-1]  # grabs the last layer
            all_last_attention_heads = [torch.max(this_input[specific_attention_head], axis=0)[0].indices for this_input in last_layer]
            return all_last_attention_heads
        return output.attentions

In [None]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "true"

gpu_idx = 0
num_labels = 60
model_path = "/mnt/storage/grid/home/eric/hmm2bert/models/pullin/pullin>1000_whiteSpace_best_loss-epch9{1GPU}.pt"
data_folder = "pullin_parsed_data"
strat_val_name = "embedding_pullin_noDupes_whiteSpace_val>1000_stratified_domainPiece.pt"
strat_val_path = f"/mnt/storage/grid/home/eric/hmm2bert/{data_folder}/{strat_val_name}"
encoded_label_filename = "encoded_parsed_pullin_noDupes_whiteSpace>1000_withAA_not_domain.csv"
encoded_csv = f"/mnt/storage/grid/home/eric/hmm2bert/{data_folder}/{encoded_label_filename}"

tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)

encoded_test = torch.load(strat_val_path)

model = torch.load(model_path)
model.eval()

meatpipe = pipeline(task='ner', model=model.bert, tokenizer=tokenizer, device=gpu_idx)

In [None]:
# load csv to huggingface dataset AND pandas dataframe
dataset = load_dataset('csv', data_files=encoded_csv)
df = pd.read_csv(encoded_csv)
dataset = dataset['train']
print(dataset.column_names)

print(df.columns)
df = df.drop(["Unnamed: 0"], axis=1)
# df = df.drop(["start"], axis=1)
# df = df.drop(["stop"], axis=1)
print(df.columns)

#split the dataset into train and test, this produces a list with the row positions
num_rows_list = list(range(len(df)))
strat_train, strat_test = train_test_split(df, test_size=.2, stratify=df['labels'], random_state=420)

In [52]:
#start count @ 0
def labelLabel(label, sequence, start, stop):
    sequence = sequence.split()
    sequence = "".join(sequence)
    domainList = list(range(start, stop + 1))
    sequenceIndexCounter = 0
    trueList = []

    for currentNum in range(len(sequence)):  # iterate over each amino acid in a single sequence
        # ====
        shiftedNum = currentNum  #CHANGED FROM currentNum+1 TO currentNum # index that we compare to as sequence index starts at zero but start/stop starts at 1
        # ====
        if shiftedNum == domainList[sequenceIndexCounter]:
            trueList.append(label)
            if domainList[sequenceIndexCounter] == domainList[-1]:
                    sequenceIndexCounter = 0

            else:
                sequenceIndexCounter += 1
        else:
            trueList.append(0)
    return trueList

def domainAcc(true_label, predict):  # finds the start and stop of the actual sequence (range style where stop is actually stop - 1)
    switch = False
    start = None
    stop = len(true_label)
    counter = 0 #CHANGED FROM 1 TO 0
    domainLabel = None
    numCorrectLabels = 0
    numNotLabels = 0
    fullDomainCounter = 0
    
    #gets accuracy of prediction of whole sequence
    for i in range(len(true_label)):
        if true_label[i] == predict[i]:
            fullDomainCounter += 1
    fullDomainScore = fullDomainCounter/len(predict)

    #sets the start and stop of domain based on true label list
    for i in true_label:
        if int(i) != 0 and switch == False:
            switch = True
            start = counter
        elif int(i) == 0 and switch == True:
            stop = counter - 1
            break
        counter += 1
    
    #if domain is out of range and whole sequence is notDomain set start equal to stop + 1
    if start == None:
        start = stop + 1
    
    #calculates the number of amino acids in domain
    numInDomain = stop - (start)

    for i in true_label:
        if i != 0:
            domainLabel = i

    for i in predict[start - 1:stop]:
        if i == domainLabel:
            numCorrectLabels += 1
    try:
        domainLabelAcc = numCorrectLabels / numInDomain
    except:
        domainLabelAcc = 0
        
    for i in predict[:start - 1]:
        if i == 0:
            numNotLabels += 1

    for i in predict[stop - 1:]:
        if i == 0:
            numNotLabels += 1

    numNotInDomain = len(true_label) - numInDomain
    try:
        notDomainLabelAcc = numNotLabels / numNotInDomain
    except:
        notDomainLabelAcc = 0 #numNotLabels / len(true_label) #maybe change to 0?

    return [start, stop, domainLabelAcc, notDomainLabelAcc, fullDomainScore]

def matching_labels(true, predict, domainStart, domainStop):
    trueList = []
    predictListDomain = []
    predictListNotdomain = []
    for label in true:
        trueList.append(label)

    for label in predict[domainStart:domainStop]:
        predictListDomain.append(label)

    for label in predict[:domainStart + 1]:
        predictListNotdomain.append(label)
    for label in predict[domainStop:]:
        predictListNotdomain.append(label)

    trueList = set(trueList)
    trueList = list(trueList)
    predictListDomain = set(predictListDomain)
    predictListDomain = list(predictListDomain)
    predictListNotdomain = set(predictListNotdomain)
    predictListNotdomain = list(predictListNotdomain)

    return (trueList, predictListDomain, predictListNotdomain)

def evaluate_positions(true, predict):
    switch1 = True
    switch2 = True
    predStart = None
    trueStart = None
    predStop = None
    trueStop = None
    counter = 0



    for i in range(len(true)):

        if true[i] != 0 and switch2:
            switch2 = False
            trueStart = i
#                 print(f"REAL START POSITION {trueStart}")

        if predict[i] != 0 and switch1:
            switch1 = False
            predStart = i
#                 print(f"PRED START POSITION {predStart}")

        if true[i] == 0 and switch2 == False:
            switch2 = None
            trueStop = i - 1
#                 print(f"REAL STOP POSITION {trueStop}")

        if predict[i] == 0 and switch1 == False:
            switch1 = None
            predStop = i - 1
#                 print(f"PRED STOP POSITION {predStop}")

#             if predict[i] != true[i]:
#                 print(f"{counter}) {true[i]}|=|{predict[i]} >> NOT")
#             else:
#                 print(f"{counter}) {true[i]}|=|{predict[i]}")
        counter += 1

    if switch2 == False and trueStop == None:
        trueStop = len(true)

    if switch1 == False and predStop == None:
        predStop = len(predict)

    if trueStart == None and switch2 == True:
        print("HAAAAAAAAAAAAAAAAAAA NO DOMAIN")
        print(trueStart, trueStop, predStart, predStop)

    return (trueStart, trueStop, predStart, predStop)
#==================================


masterList = []
sampleCounter = 0

for i in range(len(strat_test)):
    labelDict = {}
    sampleDict = {}
    pred_labels = []
    tempList = meatpipe(strat_test.iloc[i]["sequence"], max_length=512)
    label = strat_test.iloc[i]["labels"]
    sequence = strat_test.iloc[i]["sequence"]
    start = strat_test.iloc[i]["start"]
    stop = strat_test.iloc[i]["stop"]
    true_labels = labelLabel(label, sequence, start, stop)
                                                                                                                                                                                             
    
    for aminoDict in tempList:
        labelNum = aminoDict["entity"]
        labelNum = int(labelNum[6:])
        pred_labels.append(labelNum)
        
    positions = evaluate_positions(true_labels, pred_labels)
    trueStart = positions[0]
    trueStop = positions[1]
    predStart = positions[2]
    predStop = positions[3]
    
    eval_metrics = domainAcc(true_labels, pred_labels)
    eval_start = eval_metrics[0]
    eval_stop = eval_metrics[1]
    domainLabelAcc = eval_metrics[2]
    notDomainLabelAcc = eval_metrics[3]
    fullDomainAcc = eval_metrics[4]
    
    domainStart = eval_start - 1
    domainStop = eval_stop
    
    items = matching_labels(true_labels, pred_labels, domainStart, domainStop)
    trueItems = items[0]
    predDomItems = items[1]
    predNotDomItems = items[2]
    
    if notDomainLabelAcc < 0:
        print("NEGATIVE VALUE")
#         print(f"{notDomainLabelAcc} = {eval_metrics[5]} / {eval_metrics[6]}")
#         print(start, stop)
#         print(trueStart, trueStop)
#         print(label)
    
    sampleDict["fullDomain_accuracy"] = float(fullDomainAcc)
    sampleDict["domain_accuracy"] = float(domainLabelAcc)
    sampleDict["notDomain_accuracy"] = float(notDomainLabelAcc)
    sampleDict["true_labels"] = list(trueItems)
    sampleDict["predDomain_labels"] = list(predDomItems)
    sampleDict["predNotDomain_labels"] = list(predNotDomItems)
    sampleDict["trueStart"] = trueStart
    sampleDict["trueStop"] = trueStop
    sampleDict["predStart"] = float(predStart)
    sampleDict["predStop"] = float(predStop)
    
    
    for i in trueItems:
        if i != 0:
            attentionLabel = str(i)
    
    
    labelDict[attentionLabel] = sampleDict

    masterList.append(labelDict)

    sampleCounter += 1
    if sampleCounter % 100 == 0:
        print(f"{sampleCounter} / {len(strat_test)}")
#     print(sampleDict)
#     print("=============================")



HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 18
100 / 96420
200 / 96420
300 / 96420
400 / 96420
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 56
500 / 96420
600 / 96420
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 55
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 32
700 / 96420
800 / 96420
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 10
900 / 96420
1000 / 96420
1100 / 96420
1200 / 96420
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 29
1300 / 96420
1400 / 96420
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 39
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 64
1500 / 96420
1600 / 96420
1700 / 96420
1800 / 96420
1900 / 96420
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 64
2000 / 96420
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 55
2100 / 96420
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 57
2200 / 96420
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 52
2300 / 96420
2400 / 96420
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 44
2500 / 96420
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 87
HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None

In [62]:
#parses through masterList to group the metrics of each sample by label for easier evaluation
masterEvalDict = {}
sampleCounter = 0
for labelDictLevel in masterList:
    lab = list(labelDictLevel.keys())
    lab = lab[0]
    labelDict = labelDictLevel[lab]
    if lab not in masterEvalDict:
        newDict = {
            "fullDomain_accuracy": [labelDict["fullDomain_accuracy"]],
            "domain_accuracy": [labelDict["domain_accuracy"]],
            "notDomain_accuracy": [labelDict["notDomain_accuracy"]],
            "true_labels": list(labelDict["true_labels"]),
            "predDomain_labels": list(labelDict["predDomain_labels"]),
            "predNotDomain_labels": list(labelDict["predNotDomain_labels"])
        }
        if labelDict["trueStart"] == labelDict["predStart"]:
            newDict["start"] = [1]
        else:
            newDict["start"] = [0]
            
        if labelDict["trueStop"] == labelDict["predStop"]:
            newDict["stop"] = [1]
        else:
            newDict["stop"] = [0]

        masterEvalDict[lab] = newDict
    else:
        innerLabelDict = masterEvalDict[lab]
        innerLabelDict["fullDomain_accuracy"].append(labelDict["fullDomain_accuracy"])
        innerLabelDict["domain_accuracy"].append(labelDict["domain_accuracy"])
        innerLabelDict["notDomain_accuracy"].append(labelDict["notDomain_accuracy"])

        innerLabelDict["true_labels"] = list(innerLabelDict["true_labels"])
        innerLabelDict["true_labels"] += labelDict["true_labels"]
        innerLabelDict["true_labels"] = set(innerLabelDict["true_labels"])
        

        innerLabelDict["predDomain_labels"] = list(innerLabelDict["predDomain_labels"])
        innerLabelDict["predDomain_labels"] += labelDict["predDomain_labels"]
        innerLabelDict["predDomain_labels"] = set(innerLabelDict["predDomain_labels"])
        

        innerLabelDict["predNotDomain_labels"] = list(innerLabelDict["predNotDomain_labels"])
        innerLabelDict["predNotDomain_labels"] += labelDict["predNotDomain_labels"]
        innerLabelDict["predNotDomain_labels"] = set(innerLabelDict["predNotDomain_labels"])
        
        if labelDict["trueStart"] == labelDict["predStart"]:
            innerLabelDict["start"].append(1)
        else:
            innerLabelDict["start"].append(0)
            
        if labelDict["trueStop"] == labelDict["predStop"]:
            innerLabelDict["stop"].append(1)
        else:
            innerLabelDict["stop"].append(0)
    
    sampleCounter += 1
    if sampleCounter % 100 == 0:
        print(f"{sampleCounter} / {len(strat_test)}")

100 / 96420
200 / 96420
300 / 96420
400 / 96420
500 / 96420
600 / 96420
700 / 96420
800 / 96420
900 / 96420
1000 / 96420
1100 / 96420
1200 / 96420
1300 / 96420
1400 / 96420
1500 / 96420
1600 / 96420
1700 / 96420
1800 / 96420
1900 / 96420
2000 / 96420
2100 / 96420
2200 / 96420
2300 / 96420
2400 / 96420
2500 / 96420
2600 / 96420
2700 / 96420
2800 / 96420
2900 / 96420
3000 / 96420
3100 / 96420
3200 / 96420
3300 / 96420
3400 / 96420
3500 / 96420
3600 / 96420
3700 / 96420
3800 / 96420
3900 / 96420
4000 / 96420
4100 / 96420
4200 / 96420
4300 / 96420
4400 / 96420
4500 / 96420
4600 / 96420
4700 / 96420
4800 / 96420
4900 / 96420
5000 / 96420
5100 / 96420
5200 / 96420
5300 / 96420
5400 / 96420
5500 / 96420
5600 / 96420
5700 / 96420
5800 / 96420
5900 / 96420
6000 / 96420
6100 / 96420
6200 / 96420
6300 / 96420
6400 / 96420
6500 / 96420
6600 / 96420
6700 / 96420
6800 / 96420
6900 / 96420
7000 / 96420
7100 / 96420
7200 / 96420
7300 / 96420
7400 / 96420
7500 / 96420
7600 / 96420
7700 / 96420
7800 / 9

In [64]:
#compiles the metrics for each label into a label average
resultDict = {}
sampleCounter = 0

for eachLabel in masterEvalDict:
    innerLabDict = {}

    metricData = masterEvalDict[eachLabel]
    fullDom_acc = sum(metricData["fullDomain_accuracy"]) / len(metricData["fullDomain_accuracy"])
    dom_acc = sum(metricData["domain_accuracy"]) / len(metricData["domain_accuracy"])
    notDom_acc = sum(metricData["notDomain_accuracy"]) / len(metricData["notDomain_accuracy"])

    true_lab = list(metricData["true_labels"])
    predDomLab = list(metricData["predDomain_labels"])
    predNotdomLab = list(metricData["predNotDomain_labels"])
    
    start_acc = sum(metricData["start"]) / len(metricData["start"])
    stop_acc = sum(metricData["stop"]) / len(metricData["stop"])
    
    innerLabDict["fullDomain_accuracy"] = fullDom_acc
    innerLabDict["domain_accuracy"] = dom_acc
    innerLabDict["notDomain_accuracy"] = notDom_acc

    innerLabDict["true_labels"] = true_lab
    innerLabDict["pred_domain_labels"] = predDomLab
    innerLabDict["pred_notDomain_labels"] = predNotdomLab
    
    innerLabDict["start_acc"] = start_acc
    innerLabDict["stop_acc"] = stop_acc

    resultDict[eachLabel] = innerLabDict
    
    sampleCounter += 1
    print(f"{sampleCounter} / {len(masterEvalDict)}")

1 / 59
2 / 59
3 / 59
4 / 59
5 / 59
6 / 59
7 / 59
8 / 59
9 / 59
10 / 59
11 / 59
12 / 59
13 / 59
14 / 59
15 / 59
16 / 59
17 / 59
18 / 59
19 / 59
20 / 59
21 / 59
22 / 59
23 / 59
24 / 59
25 / 59
26 / 59
27 / 59
28 / 59
29 / 59
30 / 59
31 / 59
32 / 59
33 / 59
34 / 59
35 / 59
36 / 59
37 / 59
38 / 59
39 / 59
40 / 59
41 / 59
42 / 59
43 / 59
44 / 59
45 / 59
46 / 59
47 / 59
48 / 59
49 / 59
50 / 59
51 / 59
52 / 59
53 / 59
54 / 59
55 / 59
56 / 59
57 / 59
58 / 59
59 / 59


In [79]:
resultDict

{'2': {'fullDomain_accuracy': 0.4576951344848438,
  'domain_accuracy': 0.40806765185650623,
  'notDomain_accuracy': 0.6160044400988957,
  'true_labels': [0, 2],
  'pred_domain_labels': [0, 33, 2, 32, 3, 5, 7, 47, 23, 24, 57, 58, 28, 25],
  'pred_notDomain_labels': [0,
   2,
   3,
   4,
   5,
   7,
   14,
   17,
   22,
   23,
   24,
   25,
   26,
   28,
   29,
   30,
   32,
   33,
   34,
   35,
   37,
   38,
   40,
   44,
   45,
   46,
   47,
   48,
   50,
   53,
   57,
   58],
  'start_acc': 0.0,
  'stop_acc': 0.0},
 '58': {'fullDomain_accuracy': 0.5968739396224102,
  'domain_accuracy': 0.3625884689647297,
  'notDomain_accuracy': 0.7005999797022135,
  'true_labels': [0, 58],
  'pred_domain_labels': [0,
   32,
   2,
   3,
   33,
   5,
   38,
   35,
   7,
   47,
   48,
   17,
   53,
   23,
   24,
   25,
   58,
   28,
   57],
  'pred_notDomain_labels': [0,
   2,
   3,
   5,
   6,
   7,
   8,
   9,
   11,
   12,
   13,
   15,
   17,
   20,
   21,
   23,
   24,
   25,
   26,
   28,
   29,
 

In [98]:
num = 10
for i in resultDict[num]:
    print(f"{i}: {resultDict[num][i]}")

fullDomain_accuracy: 0.6433898992023841
domain_accuracy: 0.641329544782269
notDomain_accuracy: -2.592142937261512
true_labels: [0, 10]
pred_domain_labels: [0, 10, 53, 25, 58]
pred_notDomain_labels: [0, 10]


In [77]:
#fullDomain Accuracy

i80 = []
i60 = []
i40 = []
i20 = []

for i in resultDict:
    tempDict = {}
    dic = resultDict[i]
    tempDict[i] = resultDict[i]
    if (dic["fullDomain_accuracy"]) >= .8:
        i80.append(tempDict)
    elif (dic["fullDomain_accuracy"]) >= .6:
        i60.append(tempDict)
    elif (dic["fullDomain_accuracy"]) >= .4:
        i40.append(tempDict)
    elif (dic["fullDomain_accuracy"]) >= .2:
        i20.append(tempDict)

csvDict = {"label": [],
          "fullDomain_accuracy": []
          }

domainList = ['AA_not_domain', 'ABM', 'ACYLTRANSFERASE', 'ACYL_ADENYLATING', 'ACYL_COA_DEHYDROGENASE', 'ADENYLATION', 'AMIDOTRANSFERASE', 'AMINOTRANSFERASE', 'AMR_116', 'AMR_201', 'AMR_211', 'AMR_246', 'AMR_261', 'AMR_262', 'APE_AT', 'AfsA', 'AgrB', 'Asparagine_synthase', 'BACTERIOCIN_16', 'CAPREOMYCIDINE_SYNTHASE_1', 'CDPS', 'CHLORINATION', 'CLF', 'CONDENSATION', 'C_METHYLTRANSFERASE', 'DEHYDRATASE', 'DhpF', 'ECTOINE_SYNTHASE', 'ENOYLREDUCTASE', 'FORMYLTRANSFERASE', 'GLYCOSYLTRANSFERASE', 'IPM', 'KETOREDUCTASE', 'KETOSYNTHASE', 'KSA', 'LanB', 'LanC', 'LanKC', 'LanM', 'LasI', 'Lasso_precursors', 'LazA', 'LazC', 'MelC2', 'NISa', 'NISb', 'NISc', 'NITROREDUCTASE', 'PPTASE', 'PabC', 'PepM', 'PhzD', 'Putative_lantipeptide', 'Putative_lasso_peptide', 'Putative_phosphoramidate', 'ST_1', 'SbnA', 'THIOESTERASE', 'THIOLATION', 'Transglutaminase']

print("===i80===")
for i in i80:
#     print(i)
    for p in i:
        print(p, i[p]["fullDomain_accuracy"])
        csvDict["label"].append(domainList[int(p)])
        csvDict["fullDomain_accuracy"].append(i[p]["fullDomain_accuracy"])
print("===i60===")
for i in i60:
#     print(i)
    for p in i:
        print(p, i[p]["fullDomain_accuracy"])
        csvDict["label"].append(domainList[int(p)])
        csvDict["fullDomain_accuracy"].append(i[p]["fullDomain_accuracy"])
print("===i40===")
for i in i40:
#     print(i)
    for p in i:
        print(p, i[p]["fullDomain_accuracy"])
        csvDict["label"].append(domainList[int(p)])
        csvDict["fullDomain_accuracy"].append(i[p]["fullDomain_accuracy"])
print("===i20===")
for i in i20:
#     print(i)
    for p in i:
        print(p, i[p]["fullDomain_accuracy"])
        csvDict["label"].append(domainList[int(p)])
        csvDict["fullDomain_accuracy"].append(i[p]["fullDomain_accuracy"])


===i80===
35 0.8884683133358449
===i60===
32 0.6313884517736658
34 0.6039916371272273
===i40===
2 0.4576951344848438
58 0.5968739396224102
5 0.4773354089551971
30 0.5192282397198564
28 0.5966694246420666
9 0.40319114756383395
23 0.5501324816667955
44 0.41641125543727986
55 0.4977396441736124
3 0.4800080092874754
57 0.4514117369217191
45 0.4082879456572668
46 0.4250270205072633
36 0.429661393337344
59 0.5282408768573664
27 0.49927999384489524
25 0.45728624116604893
50 0.4040064326254271
29 0.5582947223167593
24 0.44306045570427804
47 0.410407723429879
10 0.48113360523554677
1 0.4418995674928282
===i20===
17 0.35827140456549805
33 0.3504383748408206
48 0.30328655587923214
26 0.2849499927099932
40 0.32515271702049187
39 0.28658804239919483
7 0.3411208390654438
4 0.267945174736306
49 0.2659552923232213
56 0.2535223476956692
37 0.2523713346416921
11 0.25054732883075176
12 0.21500476153366024
14 0.24443160441052186
41 0.39619034915532786
21 0.36386975785551934
6 0.2770003611586277
31 0.38906

In [78]:
import pandas as pd

df = pd.DataFrame(csvDict)
df.to_csv("/mnt/storage/grid/home/eric/hmm2bert/pullin_parsed_data/pullin>1000_whiteSpace_fullDomain_acc(epch9).csv")

In [75]:
#domain Accuracy

i80 = []
i60 = []
i40 = []
i20 = []

for i in resultDict:
    tempDict = {}
    dic = resultDict[i]
    tempDict[i] = resultDict[i]

    if dic["domain_accuracy"] >= .8:
        i80.append(tempDict)
    elif dic["domain_accuracy"] >= .6:
        i60.append(tempDict)
    elif dic["domain_accuracy"] >= .4:
        i40.append(tempDict)
    elif dic["domain_accuracy"] >= .2:
        i20.append(tempDict)
        

csvDict2 = {"label": [],
          "domain_accuracy": []
          }

domainList = ['AA_not_domain', 'ABM', 'ACYLTRANSFERASE', 'ACYL_ADENYLATING', 'ACYL_COA_DEHYDROGENASE', 'ADENYLATION', 'AMIDOTRANSFERASE', 'AMINOTRANSFERASE', 'AMR_116', 'AMR_201', 'AMR_211', 'AMR_246', 'AMR_261', 'AMR_262', 'APE_AT', 'AfsA', 'AgrB', 'Asparagine_synthase', 'BACTERIOCIN_16', 'CAPREOMYCIDINE_SYNTHASE_1', 'CDPS', 'CHLORINATION', 'CLF', 'CONDENSATION', 'C_METHYLTRANSFERASE', 'DEHYDRATASE', 'DhpF', 'ECTOINE_SYNTHASE', 'ENOYLREDUCTASE', 'FORMYLTRANSFERASE', 'GLYCOSYLTRANSFERASE', 'IPM', 'KETOREDUCTASE', 'KETOSYNTHASE', 'KSA', 'LanB', 'LanC', 'LanKC', 'LanM', 'LasI', 'Lasso_precursors', 'LazA', 'LazC', 'MelC2', 'NISa', 'NISb', 'NISc', 'NITROREDUCTASE', 'PPTASE', 'PabC', 'PepM', 'PhzD', 'Putative_lantipeptide', 'Putative_lasso_peptide', 'Putative_phosphoramidate', 'ST_1', 'SbnA', 'THIOESTERASE', 'THIOLATION', 'Transglutaminase']

print("===i80===")
for i in i80:
#     print(i)
    for p in i:
        print(p, i[p]["domain_accuracy"])
        csvDict2["label"].append(domainList[int(p)])
        csvDict2["domain_accuracy"].append(i[p]["domain_accuracy"])
print("===i60===")
for i in i60:
#     print(i)
    for p in i:
        print(p, i[p]["domain_accuracy"])
        csvDict2["label"].append(domainList[int(p)])
        csvDict2["domain_accuracy"].append(i[p]["domain_accuracy"])
print("===i40===")
for i in i40:
#     print(i)
    for p in i:
        print(p, i[p]["domain_accuracy"])
        csvDict2["label"].append(domainList[int(p)])
        csvDict2["domain_accuracy"].append(i[p]["domain_accuracy"])
print("===i20===")
for i in i20:
#     print(i)
    for p in i:
        print(p, i[p]["domain_accuracy"])
        csvDict2["label"].append(domainList[int(p)])
        csvDict2["domain_accuracy"].append(i[p]["domain_accuracy"])



===i80===
35 0.922905706321317
===i60===
59 0.6093828855681871
34 0.6046613642726261
===i40===
2 0.40806765185650623
48 0.42867607993093604
30 0.521760264896504
28 0.4394982614677542
32 0.44196503496455924
9 0.433713525029486
44 0.4298261216366602
55 0.5018924912042765
45 0.41819681086643773
46 0.44434342725241804
36 0.4355434892050209
27 0.5055173441105751
47 0.443159512977647
10 0.4837003274823915
1 0.5023859885430279
===i20===
58 0.3625884689647297
5 0.3938023220442825
17 0.3601502372689123
33 0.29004507777528127
26 0.27982567835113803
40 0.2884972497331793
39 0.24828579056246422
7 0.23096483013961963
4 0.26553179622482836
23 0.38460417832690014
49 0.2615427019248544
56 0.21860325835639793
37 0.24868724676855278
3 0.39197747649000975
57 0.2984728260096482
14 0.2170364838932869
41 0.3838856405972542
21 0.3118725212298015
25 0.39903318289262785
50 0.3857071932708305
29 0.3962661370140107
24 0.3180261305707753
31 0.3606145520739758
43 0.32018547314568546
18 0.30424743334727894
38 0.334

In [76]:
import pandas as pd

df1 = pd.DataFrame(csvDict2)
df1.to_csv("/mnt/storage/grid/home/eric/hmm2bert/pullin_parsed_data/pullin>1000_whiteSpace_domain_acc(epch9).csv")

In [72]:
#notdomain Accuracy

i80 = []
i60 = []
i40 = []
i20 = []

for i in resultDict:
    tempDict = {}
    dic = resultDict[i]
    tempDict[i] = resultDict[i]

    if dic["notDomain_accuracy"] >= .8:
        i80.append(tempDict)
    elif dic["notDomain_accuracy"] >= .6:
        i60.append(tempDict)
    elif dic["notDomain_accuracy"] >= .4:
        i40.append(tempDict)
    elif dic["notDomain_accuracy"] >= .2:
        i20.append(tempDict)
        

csvDict3 = {"label": [],
          "notDomain_accuracy": []
          }

domainList = ['AA_not_domain', 'ABM', 'ACYLTRANSFERASE', 'ACYL_ADENYLATING', 'ACYL_COA_DEHYDROGENASE', 'ADENYLATION', 'AMIDOTRANSFERASE', 'AMINOTRANSFERASE', 'AMR_116', 'AMR_201', 'AMR_211', 'AMR_246', 'AMR_261', 'AMR_262', 'APE_AT', 'AfsA', 'AgrB', 'Asparagine_synthase', 'BACTERIOCIN_16', 'CAPREOMYCIDINE_SYNTHASE_1', 'CDPS', 'CHLORINATION', 'CLF', 'CONDENSATION', 'C_METHYLTRANSFERASE', 'DEHYDRATASE', 'DhpF', 'ECTOINE_SYNTHASE', 'ENOYLREDUCTASE', 'FORMYLTRANSFERASE', 'GLYCOSYLTRANSFERASE', 'IPM', 'KETOREDUCTASE', 'KETOSYNTHASE', 'KSA', 'LanB', 'LanC', 'LanKC', 'LanM', 'LasI', 'Lasso_precursors', 'LazA', 'LazC', 'MelC2', 'NISa', 'NISb', 'NISc', 'NITROREDUCTASE', 'PPTASE', 'PabC', 'PepM', 'PhzD', 'Putative_lantipeptide', 'Putative_lasso_peptide', 'Putative_phosphoramidate', 'ST_1', 'SbnA', 'THIOESTERASE', 'THIOLATION', 'Transglutaminase']

print("===i80===")
for i in i80:
#     print(i)
    for p in i:
        print(p, i[p]["notDomain_accuracy"])
        csvDict3["label"].append(domainList[int(p)])
        csvDict3["notDomain_accuracy"].append(i[p]["notDomain_accuracy"])
print("===i60===")
for i in i60:
#     print(i)
    for p in i:
        print(p, i[p]["notDomain_accuracy"])
        csvDict3["label"].append(domainList[int(p)])
        csvDict3["notDomain_accuracy"].append(i[p]["notDomain_accuracy"])
print("===i40===")
for i in i40:
#     print(i)
    for p in i:
        print(p, i[p]["notDomain_accuracy"])
        csvDict3["label"].append(domainList[int(p)])
        csvDict3["notDomain_accuracy"].append(i[p]["notDomain_accuracy"])
print("===i20===")
for i in i20:
#     print(i)
    for p in i:
        print(p, i[p]["notDomain_accuracy"])
        csvDict3["label"].append(domainList[int(p)])
        csvDict3["notDomain_accuracy"].append(i[p]["notDomain_accuracy"])


===i80===
26 0.8319195949143526
30 0.9113267836691589
39 0.8729722410993923
49 0.8799108965945573
57 0.8143083198025104
14 0.8117058235047103
27 0.9710232625571626
29 0.8300773923321882
6 0.9451262618291333
51 0.9095113612963526
43 0.8162990695103397
18 0.9333700602542765
===i60===
2 0.6160044400988957
58 0.7005999797022135
5 0.7303550415915867
28 0.7674708617002134
32 0.741833303866938
7 0.6049045795159611
23 0.7809435043792882
56 0.6685265289387909
3 0.7507299197204561
11 0.7479891284639724
12 0.6216342408923545
21 0.7713591658142271
34 0.7404236104571775
24 0.6423377725573116
8 0.7944855023904508
31 0.6997339010968244
22 0.6464601981768804
13 0.6008726641460956
16 0.7462207300491273
===i40===
17 0.5285718426341856
33 0.5233082860116292
40 0.5990331074136818
4 0.4991174873094625
55 0.4875919060881039
45 0.5303765582505622
46 0.49903874046720675
36 0.5003447814439411
59 0.48657964840607576
41 0.5319284163422167
25 0.5876771069691155
50 0.4910393004545077
47 0.5095869535504906
42 0.526

In [73]:
import pandas as pd

df1 = pd.DataFrame(csvDict3)
df1.to_csv("/mnt/storage/grid/home/eric/hmm2bert/pullin_parsed_data/pullin>1000_whiteSpace_notDomain_acc(epch9).csv")

In [74]:
finalZ = {}
finalZ["0"] = resultDict
with open("/mnt/storage/grid/home/eric/hmm2bert/pullin_parsed_data/pullin>1000_whiteSpace_results.json", "w") as file:
    json.dump(finalZ, file)

TypeError: Object of type 'int64' is not JSON serializable