In [1]:
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import os
import warnings
from pytorch_lightning.callbacks import ModelCheckpoint

from transformers import BertModel, BertTokenizer
from abc import ABC
import pytorch_lightning as pl
import torch
from transformers import BertForTokenClassification, BertConfig, Adafactor, AdamW
from transformers import pipeline
from sklearn.metrics import balanced_accuracy_score, accuracy_score
import torch.nn.functional as F
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import label_ranking_average_precision_score, average_precision_score, f1_score
from datasets import load_dataset
from datasets import Dataset
from datasets import load_from_disk
import pandas as pd
from sklearn.model_selection import train_test_split
import json

In [2]:
class TokenDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in
                self.encodings.items()}  # keys are input_ids, token_type_ids, attention_mask, labels, values are stored as a list of lists
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return (len(self.labels))


class newDataset(torch.utils.data.Dataset):
    def __init__(self, dataframe):
        self.encodings = {"input_ids": dataframe["input_ids"], "token_type_ids": dataframe["token_type_ids"],
                          "attention_mask": dataframe["attention_mask"]}
        self.labels = {"labels": dataframe["labels"]}

    def __getitem__(self, idx):
        item = {"input_ids": self.encodings["input_ids"][idx], "token_type_ids": self.encodings["token_type_ids"][idx],
                "attention_mask": self.encodings["attention_mask"][idx], "labels": self.labels["labels"][idx]}
        return item

    def __len__(self):
        return (len(self.labels["labels"]))

class BertTokClassification(pl.LightningModule, ABC):
    def __init__(
            self,
            config: BertConfig = None,
            pretrained_dir: str = None,
            use_adafactor: bool = False,
            learning_rate=3e-5,
            **kwargs
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.use_adafactor = use_adafactor
        if pretrained_dir is None:
            self.bert = BertForTokenClassification(config, **kwargs)
        else:
            self.bert = BertForTokenClassification.from_pretrained(pretrained_dir, **kwargs)

    def forward(self, input_ids, attention_mask, labels):
        return self.bert(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        outputs = self(input_ids=input_ids.to(self.device), attention_mask=attention_mask.to(self.device), labels=labels.to(self.device, dtype=torch.int64))
        loss = outputs.loss


        def get_acc(labels, logits):
            sumList = []
            for i in range(len(labels)):
                y_pred = torch.max(logits[i], 1).indices
                score = accuracy_score(labels[i], y_pred)
                sumList.append(score)
            avg = sum(sumList) / len(labels)
            return avg


        accuracy1 = get_acc(labels.cpu(), outputs.logits.cpu())

        # accuracy = balanced_accuracy_score(master[0], master[1])
        self.log(
            "train_batch_accuracy",
            accuracy1,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        outputs = self(input_ids=input_ids.to(self.device), attention_mask=attention_mask.to(self.device), labels=labels.to(self.device, dtype=torch.int64))
        loss = outputs.loss
        self.log(
            "val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )

        # def get_balanced_accuracy(labels, logits):
        #     y_pred = torch.max(logits, 1).indices
        #     score = balanced_accuracy_score(labels, y_pred)
        #     return score
        #
        # def label_average_precision(labels, logits):
        #     y_pred = torch.max(prob, 1).indices
        #     score = label_ranking_average_precision_score(labels, y_pred)
        #     return score
        # def f1_calc(labels, logits):
        #     sumList = []
        #     for i in range(len(labels)):
        #         y_pred = torch.max(logits[i], 1).indices
        #         score = f1_score(labels[i], y_pred, average='macro')
        #         sumList.append(score)
        #     avg = sum(sumList) / len(labels)
        #     return avg




        # """
        # 1. Iterate over the batch:
        #     for each_label in labels.cpu():
        #         shorten the length of the list to its true length using attention list and the function [true_length]
        #
        #     for each_logit in outputs.logits.cpu():
        #         use torch.max(outputs.logits.cpu()[0], 1) to get the indices for each logit (best label prediction)
        #         shorten the indices to its proper label length
        #         compare the indices to the labels
        # """
        def true_length(y_attention_mask):  # finds the start and stop of the actual sequence
            switch = False
            start = 0
            stop = 0
            counter = 0
            attention_mask = list(y_attention_mask)
            for i in attention_mask:
                if int(i) == 1 and switch == False:
                    switch = True
                    start = counter
                elif int(i) == 0 and switch == True:
                    stop = counter
                    break
                elif counter == 511:
                    stop = 512
                counter += 1
            return (start, stop)


        def short_clean(attention_mask, labels, logits): #attention_mask, labels.cpu(), outputs.logits.cpu()

            def true_length(y_attention_mask):  # finds the start and stop of the actual sequence
                switch = False
                start = 0
                stop = 0
                counter = 0
                attention_mask = list(y_attention_mask)
                for i in attention_mask:
                    if int(i) == 1 and switch == False:
                        switch = True
                        start = counter
                    elif int(i) == 0 and switch == True:
                        stop = counter
                        break
                    counter += 1
                return (start, stop)

            masterPred = []
            masterTrue = []
            for batch_index in range(len(labels)):
                real_len = true_length(attention_mask[batch_index])
                predIndecies = torch.max(outputs.logits.cpu()[batch_index], 1).indices
                start = real_len[0]
                stop = real_len[1]
                currentTrue = torch.LongTensor(labels[batch_index][start:stop])
                currentPred = torch.LongTensor(predIndecies[start:stop])
                if len(currentTrue) == 0:
                    masterTrue.append(currentTrue.tolist())
                    masterPred.append(currentPred.tolist())
                    print(f"CURRENT-PRED LEN: {len(currentPred)}")
                    print(f"CURRENT-TRUE LEN:{len(currentTrue)}")

            return (masterTrue, masterPred)

        master = short_clean(attention_mask, labels.cpu(), outputs.logits.cpu())
        print("###################################")
        print(f"MASTER-TRUE: {master[0]}")
        print("###################################")
        print(f"MASTER-PRED: {master[1]}")
        print("###################################")
        print("=======")
        print(f"LABEL LEN: {len(labels.cpu())}")
        for i in range(len(labels.cpu())):
            print(f"SINGLE LABEL LEN: {len(labels.cpu()[i])}")
            print(f"ATTENTION LEN: {len(attention_mask[i])}")
            #print(attention_mask[i])
            print(f"TRUE LEN: {true_length(attention_mask[i])}")
        print("=======")
        print(f"LABELS: {labels.cpu()}")
        print("||||||||||||||||||||||||||||")
        print("=======")
        print(f"LOGITS LEN: {len(outputs.logits.cpu())}")
        for i in outputs.logits.cpu():
            print(f"SINGLE LOGIT LEN: {len(i)}")
        print(f"LOGIT SINGLE LIST LEN: {len(outputs.logits.cpu()[0][0])}")
        b_logit = torch.max(outputs.logits.cpu()[0], 1)
        b_logit_indices = torch.max(outputs.logits.cpu()[0], 1).indices
        print(f"LOGIT BEST: {b_logit}")
        print(f"LOGIT BEST INDICES: {b_logit_indices}")
        print("=======")
        print(f"LOGITS: {outputs.logits.cpu()}")

        # accuracy = label_average_precision(labels.cpu(), logits=outputs.logits.cpu()) #replaced get_balanced_accuracy(labels.cpu(), logits=outputs.logits.cpu()) with label ranking average precision

        # def balanced_accuracy_score(labels, logits):
        #     sumList = []
        #     for i in range(len(labels)):
        #         y_predList = []
        #         trueList = []
        #         y_pred = logits[i]
        #         previous = 0
        #         for lab in labels[i]:
        #             if lab == -100:
        #                 y_predList.append(y_pred[previous])
        #                 trueList.append(previous)
        #             else:
        #                 previous = lab
        #                 y_predList.append(y_pred[previous])
        #                 trueList.append(previous)
        #
        #         num = average_precision_score(trueList, y_predList)
        #         sumList.append(num)
        #     big = sum(sumList) / len(labels)
        #     return big

        def get_bal_acc(labels, logits):
            sumList = []
            for i in range(len(labels)):
                y_pred = torch.max(logits[i], 1).indices
                score = balanced_accuracy_score(labels[i], y_pred)
                sumList.append(score)
            avg = sum(sumList) / len(labels)
            return avg

        accuracy = get_bal_acc(labels.cpu(), outputs.logits.cpu())

        self.log(
            "val_accuracy",
            accuracy,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return {"val_loss": loss}

    def configure_optimizers(self):
        if self.use_adafactor:
            return Adafactor(
                self.parameters(),
                lr=self.learning_rate,
                eps=(1e-30, 1e-3),
                clip_threshold=1.0,
                decay_rate=-0.8,
                beta1=None,
                weight_decay=0.0,
                relative_step=False,
                scale_parameter=False,
                warmup_init=False)
        else:
            return AdamW(self.parameters(), lr=self.learning_rate)

    def save_pretrained(self, pretrained_dir):
        self.bert.save_pretrained(self, prtrained_dir)

    def predict_classes(self, input_ids, attention_mask, return_logits=False):
        output = self.bert(input_ids=input_ids.to(self.device), attention_mask=attention_mask)
        if return_logits:
            return output.logits
        else:
            probabilities = F.sigmoid(output.logits)
            predictions = torch.argmax(probabilities)
            return {"probabilities": probabilities, "predictions": predictions}

    def get_attention(self, input_ids, attention_mask, specific_attention_head: int = None):
        output = self.bert(inputs_ids=input_ids.to(self.device), attention_mask=attention_mask)
        if specific_attention_head is not None:
            last_layer = output.attentions[-1]  # grabs the last layer
            all_last_attention_heads = [torch.max(this_input[specific_attention_head], axis=0)[0].indices for this_input in last_layer]
            return all_last_attention_heads
        return output.attentions

In [3]:
####################################################
gpu_idx = 0
num_labels = 60

#model to evaluate with
model_path = "/mnt/storage/grid/home/eric/hmm2bert/models/pullin/pullin>1000_whiteSpace_best_loss-epch9{1GPU}.pt"
#encoded csv to use
data_folder = "pullin_parsed_data"
encoded_label_filename = "encoded_parsed_pullin_noDupes_whiteSpace>1000_withAA_not_domain.csv"
encoded_csv = f"/mnt/storage/grid/home/eric/hmm2bert/{data_folder}/{encoded_label_filename}"
####################################################



os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "true"

#load in tokenizer, model (eval mode) and pipeline
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
model = torch.load(model_path)
model.eval()
meatpipe = pipeline(task='ner', model=model.bert, tokenizer=tokenizer, device=gpu_idx)

# load csv to pandas dataframe
df = pd.read_csv(encoded_csv)
print(df.columns)
df = df.drop(["Unnamed: 0"], axis=1)
print(df.columns)

#split the dataset into train and test, this produces a list with the row positions
strat_train, strat_test = train_test_split(df, test_size=.2, stratify=df['labels'], random_state=420)

Index(['Unnamed: 0', 'sequence', 'labels', 'start', 'stop'], dtype='object')
Index(['sequence', 'labels', 'start', 'stop'], dtype='object')


In [60]:
#start count @ 0
def labelLabel(label, sequence, start, stop):
    sequence = sequence.split()
    sequence = "".join(sequence)
    domainList = list(range(start, stop + 1))
    sequenceIndexCounter = 0
    trueList = []

    for currentNum in range(len(sequence)):  # iterate over each amino acid in a single sequence
        # ====
        shiftedNum = currentNum  #CHANGED FROM currentNum+1 TO currentNum # index that we compare to as sequence index starts at zero but start/stop starts at 1
        # ====
        if shiftedNum == domainList[sequenceIndexCounter]:
            trueList.append(label)
            if domainList[sequenceIndexCounter] == domainList[-1]:
                    sequenceIndexCounter = 0

            else:
                sequenceIndexCounter += 1
        else:
            trueList.append(0)
    return trueList

def domainAcc(true_label, predict):  # finds the start and stop of the actual sequence (range style where stop is actually stop - 1)
    switch = False
    start = None
    stop = len(true_label)
    counter = 0 #CHANGED FROM 1 TO 0
    domainLabel = None
    numCorrectLabels = 0
    numNotLabels = 0
    fullDomainCounter = 0
    
    #gets accuracy of prediction of whole sequence
    for i in range(len(true_label)):
        if true_label[i] == predict[i]:
            fullDomainCounter += 1
    fullDomainScore = fullDomainCounter/len(predict)

    #sets the start and stop of domain based on true label list
    for i in true_label:
        if int(i) != 0 and switch == False:
            switch = True
            start = counter
        elif int(i) == 0 and switch == True:
            stop = counter - 1
            break
        counter += 1
    
    #if domain is out of range and whole sequence is notDomain set start equal to stop + 1
    if start == None:
        start = stop + 1
    
    #calculates the number of amino acids in domain
    numInDomain = stop - (start)

    for i in true_label:
        if i != 0:
            domainLabel = i

    for i in predict[start - 1:stop]:
        if i == domainLabel:
            numCorrectLabels += 1
    try:
        domainLabelAcc = numCorrectLabels / numInDomain
    except:
        domainLabelAcc = 0
        
    for i in predict[:start - 1]:
        if i == 0:
            numNotLabels += 1

    for i in predict[stop - 1:]:
        if i == 0:
            numNotLabels += 1

    numNotInDomain = len(true_label) - numInDomain
    try:
        notDomainLabelAcc = numNotLabels / numNotInDomain
    except:
        notDomainLabelAcc = 0 #numNotLabels / len(true_label) #maybe change to 0?

    return [start, stop, domainLabelAcc, notDomainLabelAcc, fullDomainScore]

def matching_labels(true, predict, domainStart, domainStop):
    trueList = []
    predictListDomain = []
    predictListNotdomain = []
    for label in true:
        trueList.append(label)

    for label in predict[domainStart:domainStop]:
        predictListDomain.append(label)

    for label in predict[:domainStart + 1]:
        predictListNotdomain.append(label)
    for label in predict[domainStop:]:
        predictListNotdomain.append(label)

    trueList = set(trueList)
    trueList = list(trueList)
    predictListDomain = set(predictListDomain)
    predictListDomain = list(predictListDomain)
    predictListNotdomain = set(predictListNotdomain)
    predictListNotdomain = list(predictListNotdomain)

    return (trueList, predictListDomain, predictListNotdomain)

def evaluate_positions(true, predict):
    switch1 = True
    switch2 = True
    predStart = None
    trueStart = None
    predStop = None
    trueStop = None
    counter = 0



    for i in range(len(true)):

        if true[i] != 0 and switch2:
            switch2 = False
            trueStart = i
#                 print(f"REAL START POSITION {trueStart}")

        if predict[i] != 0 and switch1:
            switch1 = False
            predStart = i
#                 print(f"PRED START POSITION {predStart}")

        if true[i] == 0 and switch2 == False:
            switch2 = None
            trueStop = i - 1
#                 print(f"REAL STOP POSITION {trueStop}")

        if predict[i] == 0 and switch1 == False:
            switch1 = None
            predStop = i - 1
#                 print(f"PRED STOP POSITION {predStop}")

#             if predict[i] != true[i]:
#                 print(f"{counter}) {true[i]}|=|{predict[i]} >> NOT")
#             else:
#                 print(f"{counter}) {true[i]}|=|{predict[i]}")
        counter += 1

    if switch2 == False and trueStop == None:
        trueStop = len(true)

    if switch1 == False and predStop == None:
        predStop = len(predict)

    if trueStart == None and switch2 == True:
        print("HAAAAAAAAAAAAAAAAAAA NO DOMAIN")
        print(trueStart, trueStop, predStart, predStop)

    return (trueStart, trueStop, predStart, predStop)
#==================================
def metric_extractor(strat_test):
    
    #parses through each sample and adds metrics to dict which is then appended to a list
    masterList = []
    sampleCounter = 0

    for i in range(len(strat_test)):
        labelDict = {}
        sampleDict = {}
        pred_labels = []
        tempList = meatpipe(strat_test.iloc[i]["sequence"], max_length=512)
        label = strat_test.iloc[i]["labels"]
        sequence = strat_test.iloc[i]["sequence"]
        start = strat_test.iloc[i]["start"]
        stop = strat_test.iloc[i]["stop"]
        true_labels = labelLabel(label, sequence, start, stop)


        for aminoDict in tempList:
            labelNum = aminoDict["entity"]
            labelNum = int(labelNum[6:])
            pred_labels.append(labelNum)

        positions = evaluate_positions(true_labels, pred_labels)
        trueStart = positions[0]
        trueStop = positions[1]
        predStart = positions[2]
        predStop = positions[3]

        eval_metrics = domainAcc(true_labels, pred_labels)
        eval_start = eval_metrics[0]
        eval_stop = eval_metrics[1]
        domainLabelAcc = eval_metrics[2]
        notDomainLabelAcc = eval_metrics[3]
        fullDomainAcc = eval_metrics[4]

        domainStart = eval_start - 1
        domainStop = eval_stop

        items = matching_labels(true_labels, pred_labels, domainStart, domainStop)
        trueItems = items[0]
        predDomItems = items[1]
        predNotDomItems = items[2]

        if notDomainLabelAcc < 0:
            print("NEGATIVE VALUE")
    #         print(f"{notDomainLabelAcc} = {eval_metrics[5]} / {eval_metrics[6]}")
    #         print(start, stop)
    #         print(trueStart, trueStop)
    #         print(label)

        sampleDict["fullDomain_accuracy"] = float(fullDomainAcc)
        sampleDict["domain_accuracy"] = float(domainLabelAcc)
        sampleDict["notDomain_accuracy"] = float(notDomainLabelAcc)
        sampleDict["true_labels"] = list(trueItems)
        sampleDict["predDomain_labels"] = list(predDomItems)
        sampleDict["predNotDomain_labels"] = list(predNotDomItems)
        sampleDict["trueStart"] = trueStart
        sampleDict["trueStop"] = trueStop
        sampleDict["predStart"] = float(predStart)
        sampleDict["predStop"] = float(predStop)


        for i in trueItems:
            if i != 0:
                attentionLabel = str(i)


        labelDict[attentionLabel] = sampleDict

        masterList.append(labelDict)

        sampleCounter += 1
        if sampleCounter % 100 == 0:
            print(f"{sampleCounter} / {len(strat_test)}")
    #     print(sampleDict)
    #     print("=============================")
    return masterList


In [61]:
masterList = metric_extractor(strat_test)

HAAAAAAAAAAAAAAAAAAA NO DOMAIN
None None 0 18
100 / 96420


KeyboardInterrupt: 

In [52]:
finalZ = {"0": masterList}
with open("/mnt/storage/grid/home/eric/hmm2bert/pullin_parsed_data/pullin>1000_TESTING", "w") as file:
    json.dump(finalZ, file)

TypeError: Object of type 'int64' is not JSON serializable

In [23]:
masterList

[{'2': {'fullDomain_accuracy': 0.41479099678456594,
   'domain_accuracy': 0.4155405405405405,
   'notDomain_accuracy': 0.6,
   'true_labels': [0, 2],
   'predDomain_labels': [0, 2],
   'predNotDomain_labels': [0, 2],
   'trueStart': 7.0,
   'trueStop': 303.0,
   'predStart': 0.0,
   'predStop': 112.0}},
 {'58': {'fullDomain_accuracy': 0.6911764705882353,
   'domain_accuracy': 0.21875,
   'notDomain_accuracy': 0.8115942028985508,
   'true_labels': [0, 58],
   'predDomain_labels': [0, 58],
   'predNotDomain_labels': [0, 25, 58, 5],
   'trueStart': 259.0,
   'trueStop': 323.0,
   'predStart': 0.0,
   'predStop': 0.0}},
 {'5': {'fullDomain_accuracy': 0.474609375,
   'domain_accuracy': 0.375,
   'notDomain_accuracy': 0.825,
   'true_labels': [0, 5],
   'predDomain_labels': [0, 5],
   'predNotDomain_labels': [0, 58, 5],
   'trueStart': 19.0,
   'trueStop': 411.0,
   'predStart': 0.0,
   'predStop': 104.0}},
 {'17': {'fullDomain_accuracy': 0.25,
   'domain_accuracy': 0.28227571115973743,
   '

In [58]:
s = json.dumps(masterList)
open("/mnt/storage/grid/home/eric/hmm2bert/pullin_parsed_data/pullin>1000_TESTING", "w").write(s)

TypeError: Object of type 'int64' is not JSON serializable

In [57]:
for g in masterList:
    for i in g:
        print(type(g))
        print(type(i))
        di = g[i]
        for title in di:
            print(type(title), type(di[title]))

<class 'dict'>
<class 'str'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'str'> <class 'list'>
<class 'str'> <class 'list'>
<class 'str'> <class 'list'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'dict'>
<class 'str'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'str'> <class 'list'>
<class 'str'> <class 'list'>
<class 'str'> <class 'list'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'dict'>
<class 'str'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'str'> <class 'list'>
<class 'str'> <class 'list'>
<class 'str'> <class 'list'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'str'> <class 'float'>
<class 'dict'>
<class 