In [2]:
# from functools import partial
import multiprocessing
import os
import random
import string
import sys
from pathlib import Path

from Bio import SeqIO
from functools import partial
from PIL import Image
import numpy as np
from imgaug import augmenters as iaa
import pandas as pd
import cv2
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator
from detectron2.data import build_detection_test_loader
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.structures import BoxMode
from detectron2.utils.visualizer import Visualizer
from google.colab.patches import cv2_imshow
import tqdm
from collections import Counter

# `Path(__file__)` is unavailable inside a notebook, so the project root is
# configured here. Set PROTEIN_DOMAIN_DETECTION_ROOT to run from another
# checkout; otherwise the original hard-coded location is used.
ProjectRoot = Path(os.environ.get('PROTEIN_DOMAIN_DETECTION_ROOT',
                                  '/home/satish27may/ProteinDomainDetection'))
print(f"ProjectRoot: {str(ProjectRoot)}")

# Make project-local packages importable.
sys.path.append(str(ProjectRoot))

def protein_seq2image(item=None, color_map=None):
    """
    Render one protein sequence as a fixed-size image and save it to disk.

    `item` is a tuple ``(sequence, img_name, img_h, img_w)``. Column ``i`` of the
    ``img_h x img_w`` RGB canvas is filled with ``color_map[sequence[i]]``; the
    columns after the end of the sequence keep the padding value. The result is
    written to ``img_name``. Raises AssertionError when the sequence is longer
    than ``img_w`` and KeyError for residues missing from `color_map`.
    """
    sequence, img_name, img_h, img_w = item
    assert len(sequence)<=img_w, f"!! Len of sequence({len(sequence)}) should be less than img_w({img_w}), "
    # NOTE(review): 500 does not fit in uint8; after the astype(np.uint8) cast
    # below the padding columns wrap around to 244.
    canvas = np.full((img_h, img_w, 3), (500, 500, 500))
    for column, residue in enumerate(sequence):
        # Paint the whole column (every row) with this residue's colour.
        canvas[:, column] = color_map[residue]
    pil_image = Image.fromarray(canvas.astype(np.uint8))
    assert pil_image.size == (img_w, img_h), f"{pil_image.size}!=({img_w},{img_h})"
    pil_image.save(img_name)
    
def protein_seq2image_seqlen(item=None, color_map=None):
    """
    Render one protein sequence as an image exactly as wide as the sequence.

    `item` is ``(sequence, img_name, img_h, img_w)``. The supplied ``img_w`` is
    ignored: the width is ``len(sequence)``, so there are no padding columns.
    Column ``i`` is filled with ``color_map[sequence[i]]`` and the image is
    saved to ``img_name``. Raises KeyError for residues missing from `color_map`.
    """
    sequence, img_name, img_h, _ignored_img_w = item
    # Width follows the sequence, so a "sequence fits in img_w" check would be
    # trivially true and is omitted.
    img_w = len(sequence)
    canvas = np.full((img_h, img_w, 3), (500, 500, 500))
    for column, residue in enumerate(sequence):
        canvas[:, column] = color_map[residue]
    pil_image = Image.fromarray(canvas.astype(np.uint8))
    assert pil_image.size == (img_w, img_h), f"{pil_image.size}!=({img_w},{img_h})"
    pil_image.save(img_name)



class Data:
    """Load Pfam sequences/domains for a set of classes and render them as images.

    Each class name has the form ``"<SuperClass>-<ClassId>"`` and is used to
    locate ``<class_name>___full_sequence_data.fasta`` and
    ``<class_name>___domain_data.fasta`` under
    ``ProjectRoot/data/PfamData/<SuperClass>___full_sequence_data``.
    """

    def __init__(self, classes) -> None:
        super().__init__()
        self.classes = classes
        self.data_dir = ProjectRoot/'data'
        self.images_dir = ProjectRoot/'data/PfamData/protein_seq_images'
        self.images_dir.mkdir(parents=True, exist_ok=True)

        # Grey level per letter: 'A' -> 10, 'B' -> 20, ..., 'Z' -> 260.
        # NOTE(review): 260 exceeds the uint8 range and wraps when the image
        # helpers cast to uint8.
        self.color_map = {
            amino_acid: (10 * (offset + 1),) * 3
            for offset, amino_acid in enumerate(string.ascii_uppercase)
        }

    def create_protein_domain_data(self):
        """Return one row per unique sequence with all of its domains.

        Columns: Sequence, Class, SuperClass, name (multi-valued fields joined
        with '||', one entry per domain), SeqLen, dom, dom_pos, dom_len (lists,
        one entry per domain; dom_pos holds the FASTA id's start-end with 1
        subtracted from both) and img_pth.
        """
        seq_data_records = []
        domain_data_records = []
        for class_name in self.classes:
            (self.images_dir/f"{class_name}").mkdir(exist_ok=True, parents=True)
            # rsplit keeps super-class names that themselves contain '-' intact.
            super_class, class_id = class_name.rsplit('-', 1)
            fasta_dir = self.data_dir/f'PfamData/{super_class}___full_sequence_data'
            full_seq_data = fasta_dir/f'{class_name}___full_sequence_data.fasta'
            dom_data = fasta_dir/f'{class_name}___domain_data.fasta'

            # str(record.seq) is the public API; Seq._data is private and is
            # bytes (not str) in newer Biopython releases.
            for record in SeqIO.parse(full_seq_data, 'fasta'):
                sequence = str(record.seq)
                seq_data_records.append({'Sequence': sequence,
                                         'name': record.name,
                                         'id': record.id,
                                         'Class': class_id,
                                         'SeqLen': len(sequence),
                                         'SuperClass': super_class})

            # Domain ids look like "<seq_id>/<start>-<end>".
            for record in SeqIO.parse(dom_data, 'fasta'):
                domain = str(record.seq)
                id_parts = record.id.split('/')
                domain_data_records.append({'id': id_parts[0],
                                            'dom': domain,
                                            'dom_pos': tuple(int(pos)-1 for pos in id_parts[-1].split('-')),
                                            'dom_len': len(domain)})

        seq_data_df = pd.DataFrame(data=seq_data_records)
        domain_data_df = pd.DataFrame(data=domain_data_records)
        all_data = pd.merge(seq_data_df, domain_data_df, how='inner', on='id')
        all_data.drop_duplicates(inplace=True)

        # Images of different class sets go to different directories; a shared
        # directory let e.g. the open-set run overwrite the training images.
        dataset_dir = self.images_dir/'__'.join(self.classes)
        dataset_dir.mkdir(parents=True, exist_ok=True)

        records = []
        # groupby(sort=False) yields sequences in order of first appearance (as
        # .unique() did) and keeps row order inside each group, without
        # re-scanning the whole frame once per sequence.
        for index, (sequence, sequence_df) in enumerate(all_data.groupby('Sequence', sort=False)):
            records.append({'Sequence': sequence,
                            'Class': '||'.join(sequence_df['Class']),
                            'SuperClass': '||'.join(sequence_df['SuperClass']),
                            'name': '||'.join(sequence_df['name']),
                            'SeqLen': sequence_df['SeqLen'].tolist()[0],
                            'dom': sequence_df['dom'].tolist(),
                            'dom_pos': sequence_df['dom_pos'].tolist(),
                            'dom_len': sequence_df['dom_len'].tolist(),
                            'img_pth': dataset_dir/f"img_{index}.png",
                            })
        return pd.DataFrame(data=records)

    def _ensure_class_dirs(self):
        """Create one sub-directory of ``images_dir`` per configured class."""
        for class_name in self.classes:
            (self.images_dir/f'{class_name}').mkdir(parents=True, exist_ok=True)

    def _render_images(self, render_fn, data_df, img_h, img_w):
        """Render every row's Sequence to its img_pth with *render_fn*, in parallel."""
        for parent in {Path(img_name).parent for img_name in data_df['img_pth']}:
            parent.mkdir(parents=True, exist_ok=True)
        worker = partial(render_fn, color_map=self.color_map)
        items = [(sequence, img_name, img_h, img_w)
                 for sequence, img_name in zip(data_df['Sequence'], data_df['img_pth'])]
        with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
            pool.map(worker, items)

    def create_protein_seq_images(self, data_df, img_h, img_w):
        """Render fixed-size ``img_h x img_w`` images (padded on the right)."""
        unique_classes = sorted({cls_id
                                 for cls_name in data_df['Class'].unique()
                                 for cls_id in cls_name.split('||')})
        print(f"Generating images of dim {img_h}x{img_w} for classes: {unique_classes}")
        self._ensure_class_dirs()
        self._render_images(protein_seq2image, data_df, img_h, img_w)

    def create_protein_seq_len_images(self, data_df, img_h, img_w):
        """Render ``img_h``-tall images whose width equals each sequence's length."""
        self._ensure_class_dirs()
        self._render_images(protein_seq2image_seqlen, data_df, img_h, img_w)
            
class Detectron:
    """Train, evaluate and probe a detectron2 RetinaNet on protein-sequence images.

    Each Pfam domain is one box spanning the full image height between the
    domain's start and end columns; its category is the index of
    ``"<SuperClass>-<ClassId>"`` in ``classes``.
    """

    def __init__(self, classes=None, model_dir=None, img_h=None, img_w=None) -> None:
        super().__init__()
        self.classes = classes
        self.img_h = img_h
        self.img_w = img_w
        self.model_dir = model_dir
        # Start every run from an empty output directory.
        self._trash_model_dir()
        self.model_dir.mkdir(parents=True, exist_ok=True)
        self.cfg = get_cfg()
        self.cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/retinanet_R_50_FPN_3x.yaml"))
        self.cfg.DATASETS.TRAIN = ("train",)
        self.cfg.DATASETS.TEST = ("valid",)
        self.cfg.INPUT.RANDOM_FLIP = "vertical"
        self.cfg.TEST.DETECTIONS_PER_IMAGE = 100

        self.cfg.INPUT.MIN_SIZE_TRAIN = 800#64
        self.cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice"
        self.cfg.INPUT.MAX_SIZE_TRAIN = 1330#300
        self.cfg.INPUT.MIN_SIZE_TEST = 800#64
        self.cfg.INPUT.MAX_SIZE_TEST = 1330#300

        self.cfg.TEST.AUG.FLIP = False
        self.cfg.DATALOADER.NUM_WORKERS = 8
        self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/retinanet_R_50_FPN_3x.yaml")
        self.cfg.SOLVER.IMS_PER_BATCH = 8
        self.cfg.SOLVER.BASE_LR = 3e-3
        self.cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupCosineLR"
        print(f"Number of classes model is seeing: {len(classes)}")
        self.cfg.MODEL.RETINANET.NUM_CLASSES = len(classes)
        self.cfg.MODEL.BACKBONE.FREEZE_AT=3
        # Experiments tried: MODEL.RESNETS.NORM="BN"; RETINANET.FOCAL_LOSS_GAMMA=0.5
        # (didn't work); RETINANET.IOU_THRESHOLDS=[0.4, 0.99].

        self.cfg.OUTPUT_DIR =str(self.model_dir)
        os.makedirs(self.cfg.OUTPUT_DIR, exist_ok=True)

    def _trash_model_dir(self):
        """Move an existing ``model_dir`` to the trash with the ``trash-put`` CLI."""
        import subprocess
        if not Path(self.model_dir).exists():
            return
        try:
            # Argument list, no shell: paths with spaces/metacharacters are safe.
            subprocess.run(["trash-put", str(self.model_dir)], check=False)
        except FileNotFoundError:
            print(f"trash-put not found; keeping existing {self.model_dir}")

    def get_value_counts(self, df, col, sep):
        """Count the *sep*-separated values of column *col* across all rows."""
        return Counter(value for cell in df[col] for value in cell.split(sep))

    @staticmethod
    def _augment_and_save(augmenter, img_pth, suffix) -> str:
        """Apply *augmenter* to the image at *img_pth*, save it next to the source, return its path."""
        img_pth = Path(img_pth)
        with Image.open(img_pth) as pil_img:
            img = np.array(pil_img)
        # Batch of two with the second slot left blank (as originally written);
        # only the first, augmented image is kept. The shape comes from the
        # image itself because sequence-length images vary in width.
        img_arrs = np.zeros((2,) + img.shape)
        img_arrs[0] = img
        images_aug = augmenter(images=img_arrs)[0]
        aug_img_pth = img_pth.parent/f'{img_pth.stem}_{suffix}.png'
        Image.fromarray(images_aug.astype(np.uint8)).save(aug_img_pth)
        return str(aug_img_pth)

    @staticmethod
    def add_gaussian_noise(img_pth, dim, index)->str:
        """Save a salt-and-pepper (p=0.05) noised copy of *img_pth*; return its path.

        The name and ``_gaussian_noise_`` suffix are kept for compatibility even
        though the augmenter is ``iaa.SaltAndPepper``. *dim* is accepted for
        backward compatibility; the image's own shape is used.
        """
        seq = iaa.Sequential([iaa.SaltAndPepper(0.05)])
        return Detectron._augment_and_save(seq, img_pth, f'gaussian_noise_{index}')

    @staticmethod
    def add_cutout(img_pth, dim, index)->str:
        """Save a copy of *img_pth* with 10-20 random cutouts; return its path.

        *dim* is accepted for backward compatibility; the image's own shape is used.
        """
        seq = iaa.Sequential([iaa.Cutout(nb_iterations=(10, 20), size=0.05, squared=False)])
        return Detectron._augment_and_save(seq, img_pth, f'cutout_{index}')

    def augment_data(self, data_df):
        """Oversample minority classes with augmented images up to the majority count.

        Returns *data_df* plus the augmented rows, shuffled, with a fresh index.
        """
        class_freq = dict(self.get_value_counts(data_df, 'Class', '||'))
        max_value = max(class_freq.values())
        dim = (self.img_h, self.img_w)
        print('Augmenting train data.................')
        print(f'Max samples: {max_value}')
        class_ids = [cls_name.split('-')[-1] for cls_name in self.classes]
        new_rows = []
        for cls_id in class_ids:
            class_df = data_df[data_df['Class'].str.contains(cls_id, regex=False)]
            all_rows = [row for _, row in class_df.iterrows()]
            num_augs = max_value - class_freq.get(cls_id, 0)
            print(f"Creating {num_augs} augs for {cls_id}")
            if not all_rows:
                continue  # nothing to sample augmentations from
            # Each iteration produces two augmented rows.
            for index in range(int(round(num_augs/2))):
                row = random.choice(all_rows)
                img_pth = Path(row['img_pth'])
                gaussian_noise = row.copy()
                gaussian_noise['img_pth'] = self.add_gaussian_noise(img_pth, dim, index)
                cutout = row.copy()
                cutout['img_pth'] = self.add_cutout(img_pth, dim, index)
                new_rows.extend([cutout, gaussian_noise])
        if new_rows:
            # Previously the augmented frame was built but never concatenated.
            data_df = pd.concat([data_df, pd.DataFrame(data=new_rows)], axis='rows')
        data_df = data_df.reset_index(drop=True)
        return data_df.sample(frac=1)

    def create_train_valid_data(self, data):
        """Split *data* 70/30 per class into shuffled train and valid frames.

        A multi-class sequence matches several classes and could previously be
        drawn into train for one class and valid for another. Each row (by
        index label, assumed unique) now appears at most once and never in both
        splits.
        """
        train_dfs, valid_dfs = [], []
        for class_name in self.classes:
            class_id = class_name.split('-')[-1]
            class_df = data[data['Class'].str.contains(class_id, regex=False)].sample(frac=1)
            num_train_samples = int(round(class_df.shape[0]*0.7))
            train_dfs.append(class_df.iloc[:num_train_samples, :])
            valid_dfs.append(class_df.iloc[num_train_samples:, :])
        train_data = pd.concat(train_dfs, axis='rows')
        train_data = train_data[~train_data.index.duplicated()].sample(frac=1)
        # train_data = self.augment_data(train_data)  # augmentation currently disabled
        valid_data = pd.concat(valid_dfs, axis='rows')
        keep_valid = ~valid_data.index.duplicated() & ~valid_data.index.isin(train_data.index)
        valid_data = valid_data[keep_valid].sample(frac=1)
        print(f'Train data dist:\n')
        print(self.get_value_counts(train_data,'Class','||'))
        print(f'Valid data dist:\n')
        print(self.get_value_counts(valid_data,'Class','||'))
        return train_data, valid_data

    def register_custom_data(self, data, mode, img_h, img_w):
        """Register *data* as detectron2 dataset *mode* (e.g. "train"/"valid").

        *img_h*/*img_w* are kept for compatibility; the real size is read from
        each image. Re-registering a name (e.g. re-running a notebook cell)
        replaces the previous registration.
        """
        print(f"Registring {mode} data.......")
        print(f"Classes selected: {self.classes}")
        data = data.reset_index(drop=True)
        self.C2I = {class_name: index for index, class_name in enumerate(self.classes)}
        dicts_list = []
        for index in tqdm.tqdm(range(data.shape[0])):
            row = data.iloc[index]
            class_list = row['Class'].split('||')
            super_class_list = row['SuperClass'].split('||')
            with Image.open(row['img_pth']) as pil_img:
                img_w, img_h = pil_img.size
            # One full-height box per domain; an empty dom_pos list now yields
            # no annotations instead of reusing the previous row's.
            # NOTE(review): dom_pos ends look 0-based inclusive, so x2 may be one
            # pixel short of the domain's last column in XYXY_ABS terms - confirm.
            annts = []
            for dom_index, (x1, x2) in enumerate(row['dom_pos']):
                annts.append({'bbox': [x1, 0, x2, img_h],
                              'bbox_mode': BoxMode.XYXY_ABS,
                              'category_id': self.C2I[f"{super_class_list[dom_index]}-{class_list[dom_index]}"],
                              })
            dicts_list.append({'file_name': row['img_pth'],
                               'height': img_h,
                               'width': img_w,
                               'image_id': index,
                               'annotations': annts,
                               })

        def get_data():
            return dicts_list

        if mode in DatasetCatalog.list():
            DatasetCatalog.remove(mode)
        if mode in MetadataCatalog.list():
            MetadataCatalog.remove(mode)
        DatasetCatalog.register(mode, get_data)
        MetadataCatalog.get(mode).set(thing_classes = self.classes)

    def train(self):
        """Train from the configured initial weights and return the trainer."""
        trainer = DefaultTrainer(self.cfg)
        trainer.resume_or_load(resume=False)
        trainer.train()
        return trainer

    def evaluate(self, trainer):
        """Print COCO bbox metrics of ``trainer.model`` on the "valid" dataset."""
        evaluator = COCOEvaluator("valid", ("bbox",), False, output_dir=self.cfg.OUTPUT_DIR)
        print(trainer.test(self.cfg, trainer.model, evaluator))

    def _recall_counts(self, records, theta):
        """Per-class ``{'num_bboxes', 'pred_bboxes'}`` at score threshold *theta*.

        *records* holds ``(gt_category_ids, [(score, pred_category_id), ...])``
        per image. Matching is by class only (no IoU), and each ground-truth
        box is credited at most once, so recall never exceeds 1.
        """
        class_recall_map = {class_name: {'num_bboxes': 0, 'pred_bboxes': 0} for class_name in self.classes}
        for gt_classes, predictions in records:
            pred_counts = Counter(pred_cls for score, pred_cls in predictions if score >= theta)
            for category_id, num_gt in Counter(gt_classes).items():
                stats = class_recall_map[self.I2C[category_id]]
                stats['num_bboxes'] += num_gt
                stats['pred_bboxes'] += min(num_gt, pred_counts[category_id])
        return class_recall_map

    def inference(self, thresholds=(0.5, 0.8, 0.9, 0.99)):
        """Load the final weights, print class-level recall on "valid" per threshold, return the predictor."""
        self.cfg.MODEL.WEIGHTS = os.path.join(self.cfg.OUTPUT_DIR, "model_final.pth")  # path to the model we just trained
        self.cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.5  # set a custom testing threshold
        predictor = DefaultPredictor(self.cfg)
        self.I2C = {value: key for key, value in self.C2I.items()}

        # Run the model once per image and reuse the predictions for every threshold.
        records = []
        for d in tqdm.tqdm(DatasetCatalog.get("valid")):
            im = cv2.imread(str(d["file_name"]))
            fields = predictor(im)['instances'].get_fields()
            scores = fields['scores'].cpu().numpy().tolist()
            pred_classes = fields['pred_classes'].cpu().numpy().tolist()
            gt_classes = [anntn['category_id'] for anntn in d['annotations']]
            records.append((gt_classes, list(zip(scores, pred_classes))))

        for theta in thresholds:
            print(f"Calculating recall @ {theta}")
            class_recall_map = self._recall_counts(records, theta)
            print(f"Recall: @{theta}: \n")
            for cls_name, value in class_recall_map.items():
                recall = value['pred_bboxes']/value['num_bboxes'] if value['num_bboxes'] else float('nan')
                print(f" {cls_name}: {recall}, {value['pred_bboxes'], value['num_bboxes']}")
        return predictor

    def test_for_open_set(self, class_df, classes, img_h, img_w, predictor):
        """Histogram the top detection score on sequences of unseen *classes*.

        Sequences also present in *class_df* are dropped. Each image's highest
        score is counted in one half-open bin: >=0.99, [0.9, 0.99), [0.8, 0.9)
        or [0.7, 0.8). Images with no detections are not counted.
        """
        data_handler = Data(classes)
        protein_domain_data = data_handler.create_protein_domain_data()
        num_rows = protein_domain_data.shape[0]
        protein_domain_data = protein_domain_data[~protein_domain_data['Sequence'].isin(class_df['Sequence'])]
        print(f"Dropping {abs(num_rows - protein_domain_data.shape[0])} sequences which are common with {self.classes}'s sequences")
        protein_domain_data = protein_domain_data[protein_domain_data['SeqLen']<img_w]
        data_handler.create_protein_seq_len_images(protein_domain_data, img_h, img_w)
        all_imgs = protein_domain_data['img_pth'].tolist()
        print(f"Using {len(all_imgs)} images from classes {classes} for open set recognition test...........")

        bins = [(0.99, float('inf')), (0.9, 0.99), (0.8, 0.9), (0.7, 0.8)]
        counts = dict.fromkeys(bins, 0)
        for img_pth in tqdm.tqdm(all_imgs):
            im = cv2.imread(str(img_pth))
            scores = predictor(im)['instances'].get_fields()['scores']
            if len(scores) == 0:
                continue  # nothing above the predictor's score threshold
            top_score = scores.max().item()
            for low, high in bins:
                if low <= top_score < high:
                    counts[(low, high)] += 1
                    break

        total = len(all_imgs)
        for (low, high), count in counts.items():
            label = f">={low}" if high == float('inf') else f"in [{low}, {high})"
            fraction = count/total if total else 0.0
            print(f"Number of predictions with score {label} {count} out of {total}: {fraction}")
      
      
# if __name__ == "__main__":
# if __name__ == "__main__":
# Pfam classes used only for the open-set test (never shown to the model).
all_classes = ['Amidase_2-PF01510',
'Amidase_3-PF01520',]
# Pfam classes the detector is trained on. Other candidates tried:
# 'SH3_3-PF08239', 'SH3_4-PF06347', 'SH3_5-PF08460', 'LysM-PF01476'.
classes=['CHAP-PF05257']
img_h, img_w = 64, 300 # padding is also controlled by this img_w
seq_len =300
seq_buckets = (0, img_w)

data_block = Data(classes)
protein_data = data_block.create_protein_domain_data()
protein_data = protein_data[protein_data['SeqLen']<seq_len]

# Keep only sequences whose domains fall in each class's dominant length
# window, chosen from the dom_len distributions (CHAP 60-120, SH3_4 50-63,
# SH3_3 45-75). To enable more classes add e.g. 'PF06347': (52, 58) and
# 'PF08239': (50, 60).
dom_len_windows = {'PF05257': (70, 102)}


def domains_in_window(row, class_id, low, high):
    """True if *row* has at least one *class_id* domain and every one has length in [low, high].

    Previously a flag was overwritten per domain, so only the last domain's
    length decided whether the sequence was kept.
    """
    lens = [dom_len for cls_id, dom_len in zip(row['Class'].split('||'), row['dom_len'])
            if cls_id == class_id]
    return bool(lens) and all(low <= dom_len <= high for dom_len in lens)


selected_indexes = set()
for class_id, (low, high) in dom_len_windows.items():
    class_data = protein_data[protein_data['Class'].str.contains(class_id)]
    for index, row in class_data.iterrows():
        if domains_in_window(row, class_id, low, high):
            selected_indexes.add(index)

num_rows = protein_data.shape[0]
protein_data = protein_data[protein_data.index.isin(selected_indexes)]
print(f"Number of records dropeed after selecting major domains based on len: {num_rows - protein_data.shape[0]}")

# Fixed-width alternative: data_block.create_protein_seq_images(protein_data, img_h, img_w)
data_block.create_protein_seq_len_images(protein_data, img_h, img_w)

model_block = Detectron(classes=classes, model_dir=ProjectRoot/f"models/{'_'.join(classes)}_{seq_buckets[0]}_{seq_buckets[1]}",img_h=img_h,img_w=img_w)
# NOTE(review): 30 iterations looks like a smoke-test budget, likely far too
# short for a useful model - confirm before reading the metrics.
model_block.cfg.SOLVER.MAX_ITER = 30
train_data, valid_data = model_block.create_train_valid_data(protein_data)
model_block.register_custom_data(train_data, 'train', img_h, img_w)
model_block.register_custom_data(valid_data, 'valid', img_h, img_w)
trainer = model_block.train()
model_block.evaluate(trainer)
predictor = model_block.inference()
model_block.test_for_open_set(protein_data, all_classes, img_h, img_w, predictor)
    


  warn("IPython.utils.traitlets has moved to a top-level traitlets package.")


ProjectRoot: /home/satish27may/ProteinDomainDetection
Number of records dropeed after selecting major domains based on len: 106


Loading config /home/satish27may/anaconda3/envs/detectron2/lib/python3.7/site-packages/detectron2/model_zoo/configs/COCO-Detection/../Base-RetinaNet.yaml with yaml.unsafe_load. Your machine may be at risk if the file contains malicious content.
 40%|████      | 468/1163 [00:00<00:00, 4676.05it/s]

Number of classes model is seeing: 1
Train data dist:

Counter({'PF05257': 1190})
Valid data dist:

Counter({'PF05257': 507})
Registring train data.......
Classes selected: ['CHAP-PF05257']


100%|██████████| 1163/1163 [00:00<00:00, 4732.71it/s]
100%|██████████| 499/499 [00:00<00:00, 5084.63it/s]


Registring valid data.......
Classes selected: ['CHAP-PF05257']
[32m[12/11 14:15:32 d2.engine.defaults]: [0mModel:
RetinaNet(
  (backbone): FPN(
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelP6P7(
      (p6): Conv2d(2048, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (p7): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): Fro

Skip loading parameter 'head.cls_score.weight' to the model due to incompatible shapes: (720, 256, 3, 3) in the checkpoint but (9, 256, 3, 3) in the model! You might want to double check if this is expected.
Skip loading parameter 'head.cls_score.bias' to the model due to incompatible shapes: (720,) in the checkpoint but (9,) in the model! You might want to double check if this is expected.


[32m[12/11 14:15:34 d2.engine.train_loop]: [0mStarting training from iteration 0


  tensor = torch.from_numpy(np.ascontiguousarray(img))
  tensor = torch.from_numpy(np.ascontiguousarray(img))
  tensor = torch.from_numpy(np.ascontiguousarray(img))
  tensor = torch.from_numpy(np.ascontiguousarray(img))
  tensor = torch.from_numpy(np.ascontiguousarray(img))
  tensor = torch.from_numpy(np.ascontiguousarray(img))
  tensor = torch.from_numpy(np.ascontiguousarray(img))
  tensor = torch.from_numpy(np.ascontiguousarray(img))


[32m[12/11 14:15:44 d2.utils.events]: [0m eta: 0:00:04  iter: 19  total_loss: 1.71  loss_cls: 1.278  loss_box_reg: 0.4319  time: 0.4518  data_time: 0.0432  lr: 1.7781e-05  max_mem: 5636M
[32m[12/11 14:15:49 d2.utils.events]: [0m eta: 0:00:00  iter: 29  total_loss: 1.305  loss_cls: 0.9851  loss_box_reg: 0.3299  time: 0.4425  data_time: 0.0164  lr: 2.4628e-07  max_mem: 5636M
[32m[12/11 14:15:49 d2.engine.hooks]: [0mOverall training speed: 28 iterations in 0:00:12 (0.4425 s / it)
[32m[12/11 14:15:49 d2.engine.hooks]: [0mTotal training time: 0:00:13 (0:00:00 on hooks)
[32m[12/11 14:15:49 d2.data.build]: [0mDistribution of instances among all 1 categories:
[36m|   category   | #instances   |
|:------------:|:-------------|
| CHAP-PF05257 | 507          |
|              |              |[0m
[32m[12/11 14:15:49 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in inference: [ResizeShortestEdge(short_edge_length=(800, 800), max_size=1330, sample_style='choice')]
[32m

  1%|          | 4/499 [00:00<00:15, 32.41it/s]

Calculating recall @ 0.5


100%|██████████| 499/499 [00:16<00:00, 30.10it/s]
  1%|          | 4/499 [00:00<00:14, 33.08it/s]

Recall: @0.5: 

 CHAP-PF05257: 0.0, (0, 507)
Calculating recall @ 0.8


100%|██████████| 499/499 [00:16<00:00, 30.05it/s]
  1%|          | 4/499 [00:00<00:15, 32.64it/s]

Recall: @0.8: 

 CHAP-PF05257: 0.0, (0, 507)
Calculating recall @ 0.9


100%|██████████| 499/499 [00:16<00:00, 30.29it/s]
  1%|          | 4/499 [00:00<00:14, 33.14it/s]

Recall: @0.9: 

 CHAP-PF05257: 0.0, (0, 507)
Calculating recall @ 0.99


100%|██████████| 499/499 [00:16<00:00, 30.30it/s]


Recall: @0.99: 

 CHAP-PF05257: 0.0, (0, 507)
Dropping 0 sequences which are common with ['CHAP-PF05257']'s sequences


  0%|          | 3/11364 [00:00<07:00, 27.03it/s]

Using 11364 images from classes ['Amidase_2-PF01510', 'Amidase_3-PF01520'] for open set recognition test...........


100%|██████████| 11364/11364 [06:07<00:00, 30.91it/s]

Number of predictions with score >0.99 0 out of 11364: 0.0
Number of predictions with score >0.9 0 out of 11364: 0.0
Number of predictions with score >0.8 0 out of 11364: 0.0
Number of predictions with score >0.7 0 out of 11364: 0.0





# Less than 10k training

In [None]:
# from functools import partial
import multiprocessing
import os
import random
import string
import sys
from pathlib import Path

from Bio import SeqIO
from functools import partial
from PIL import Image
import numpy as np
from imgaug import augmenters as iaa
import pandas as pd
import cv2
from ast import literal_eval
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator
from detectron2.data import build_detection_test_loader
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.structures import BoxMode
from detectron2.utils.visualizer import Visualizer
from google.colab.patches import cv2_imshow
import tqdm
from collections import Counter

# ProjectRoot = Path(__file__).resolve().parent.parent.parent
# Repository root, hard-coded for this notebook machine (the commented-out line
# above derives it from the module's own location instead).
ProjectRoot = Path('/home/satish27may/ProteinDomainDetection')
print(f"ProjectRoot: {ProjectRoot}")

# Let notebook cells import the project's own packages.
sys.path.append(f"{ProjectRoot}")

def protein_seq2image(item=None, color_map=None):
    """
    Create an image from a sequence
    """
    sequence,img_name, img_h, img_w = item
    assert len(sequence)<=img_w, f"!! Len of sequence({len(sequence)}) should be less than img_w({img_w}), "
    image = np.full((img_h, img_w,3), (500,500,500))
    for index in range(len(sequence)):
        image[:, index, :] = color_map[sequence[index]]
    pil_image = Image.fromarray(image.astype(np.uint8))
    assert pil_image.size == (img_w, img_h), f"{pil_image.size}!=({img_w},{img_h})"
    pil_image.save(img_name)
    
def protein_seq2image_seqlen(item=None, color_map=None):
    """
    Render one protein sequence as an image exactly as wide as the sequence and save it.

    Unlike ``protein_seq2image`` there is no padding. The ``img_w`` element of
    *item* is ignored and the width is ``len(sequence)``, so no length check is
    needed.

    Args:
        item: ``(sequence, img_name, img_h, img_w)`` tuple (``img_w`` unused).
        color_map: mapping from residue letter to an RGB triple. Components above
            255 wrap when cast to uint8.

    Raises:
        KeyError: if a residue of the sequence is missing from ``color_map``.
    """
    sequence, img_name, img_h, _ = item
    img_w = len(sequence)
    colors = np.array([color_map[residue] for residue in sequence], dtype=np.int64).reshape(-1, 3)
    # Every column holds a residue, so the image is the colour row repeated img_h times.
    image = np.tile(colors[np.newaxis, :, :], (img_h, 1, 1))
    pil_image = Image.fromarray(image.astype(np.uint8))
    assert pil_image.size == (img_w, img_h), f"{pil_image.size}!=({img_w},{img_h})"
    pil_image.save(img_name)



class Data:
    """
    Build per-sequence protein-domain tables from Pfam FASTA files and render
    the sequences as images.

    Inputs are read from ``ProjectRoot/data/PfamData/<SuperClass>___full_sequence_data/``.
    Images are written under ``ProjectRoot/data/PfamData/protein_seq_images``.

    NOTE(review): image files are named ``img_<row index>.png`` directly inside
    ``protein_seq_images`` rather than per class. Two ``Data`` instances built from
    different class lists therefore write to the same file names. Confirm this is
    intended.
    """

    def __init__(self, classes) -> None:
        """
        Args:
            classes: class names of the form ``"<SuperClass>-<PfamId>"``,
                e.g. ``"CHAP-PF05257"``.
        """
        super().__init__()
        self.classes = classes
        self.data_dir = ProjectRoot/'data'
        self.images_dir = ProjectRoot/'data/PfamData/protein_seq_images'
        if not self.images_dir.exists():
            self.images_dir.mkdir(parents=True, exist_ok=True)

        # Grayscale palette: 'A' -> 10, 'B' -> 20, ..., 'Z' -> 260.
        # 260 wraps to 4 when images are cast to uint8.
        self.color_map = {
            amino_acid: (10 * (offset + 1),) * 3
            for offset, amino_acid in enumerate(string.ascii_uppercase)
        }

    def create_protein_domain_data(self):
        """
        Parse the full-sequence and domain FASTA files of every class and
        return a DataFrame with one row per unique sequence.

        A row aggregates every (sequence, domain) match of that sequence:

        - ``Class``, ``SuperClass`` and ``name`` are ``'||'``-joined.
        - ``dom``, ``dom_pos`` and ``dom_len`` are parallel lists. ``dom_pos``
          holds ``(start, end)`` tuples parsed from the ``<id>/<start>-<end>``
          record id and shifted by -1.
        - ``img_pth`` is the file the sequence image is (or will be) written to.
        """
        seq_data_records = []
        domain_data_records = []
        for class_name in tqdm.tqdm(self.classes):
            (self.images_dir/f"{class_name}").mkdir(exist_ok=True, parents=True)
            super_class, class_id = class_name.split('-')
            class_dir = self.data_dir/f'PfamData/{super_class}___full_sequence_data'
            full_seq_data = class_dir/f'{class_name}___full_sequence_data.fasta'
            dom_data = class_dir/f'{class_name}___domain_data.fasta'
            # Parse the full sequences of this class.
            for record in SeqIO.parse(full_seq_data, 'fasta'):
                # str(record.seq) is the public way to get the residues. The private
                # Seq._data attribute is bytes in recent Biopython, which breaks the
                # colour lookup and the CSV round-trip.
                sequence = str(record.seq)
                seq_data_records.append({'Sequence': sequence,
                                         'name': record.name,
                                         'id': record.id,
                                         'Class': class_id,
                                         'SeqLen': len(sequence),
                                         'SuperClass': super_class})
            # Parse the domains of this class.
            for record in SeqIO.parse(dom_data, 'fasta'):
                domain = str(record.seq)
                id_parts = record.id.split('/')
                domain_data_records.append({'id': id_parts[0],
                                            'dom': domain,
                                            # presumably 1-based positions converted to 0-based
                                            'dom_pos': tuple(int(pos) - 1 for pos in id_parts[-1].split('-')),
                                            'dom_len': len(domain)})
        seq_data_df = pd.DataFrame(data=seq_data_records)
        domain_data_df = pd.DataFrame(data=domain_data_records)
        all_data = pd.merge(seq_data_df, domain_data_df, how='inner', on='id')
        all_data.drop_duplicates(inplace=True)

        # groupby(sort=False) yields the sequences in first-appearance order (the
        # same order as Series.unique()) in a single pass. Rescanning the whole
        # frame with a boolean mask for every unique sequence was O(n * m).
        grouped = all_data.groupby('Sequence', sort=False)
        records = [self._sequence_record(index, sequence, sequence_df)
                   for index, (sequence, sequence_df)
                   in enumerate(tqdm.tqdm(grouped, total=grouped.ngroups))]
        return pd.DataFrame(data=records)

    def _sequence_record(self, index, sequence, sequence_df):
        """Collapse all domain rows of one unique *sequence* into a single output record."""
        return {'Sequence': sequence,
                'Class': '||'.join(sequence_df['Class']),
                'SuperClass': '||'.join(sequence_df['SuperClass']),
                'name': '||'.join(sequence_df['name']),
                'SeqLen': sequence_df['SeqLen'].tolist()[0],
                'dom': sequence_df['dom'].tolist(),
                'dom_pos': sequence_df['dom_pos'].tolist(),
                'dom_len': sequence_df['dom_len'].tolist(),
                'img_pth': self.images_dir/f"img_{index}.png",
                }

    @staticmethod
    def _unique_classes(data_df):
        """Return the distinct class ids found in the ``'||'``-joined ``Class`` column."""
        return list({cls for value in data_df['Class'].unique() for cls in value.split('||')})

    def _ensure_class_dirs(self):
        """Create ``images_dir/<class_name>`` for every configured class."""
        for class_name in self.classes:
            (self.images_dir/f'{class_name}').mkdir(parents=True, exist_ok=True)

    def create_protein_seq_images(self, data_df, img_h, img_w):
        """
        Render every sequence of *data_df* to its ``img_pth`` as an ``img_h`` x
        ``img_w`` padded image (see ``protein_seq2image``), using a process pool.
        """
        unique_classes = self._unique_classes(data_df)
        print(f"Generating images of dim {img_h}x{img_w} for classes: {unique_classes}")
        self._ensure_class_dirs()

        partial_protein_seq2image = partial(protein_seq2image, color_map=self.color_map)
        print('Creating items')
        items = [(sequence, img_name, img_h, img_w)
                 for sequence, img_name in zip(data_df['Sequence'], data_df['img_pth'])]
        print('Creating images with multi processing')
        # Leave one core free, but never ask Pool for 0 processes (single-core machine).
        with multiprocessing.Pool(processes=max(1, multiprocessing.cpu_count() - 1)) as p:
            p.map(partial_protein_seq2image, items)

    def create_protein_seq_len_images(self, data_df, img_h, img_w):
        """
        Render every sequence of *data_df* to its ``img_pth`` as an image as wide
        as the sequence itself (see ``protein_seq2image_seqlen``, which ignores
        *img_w*), using 12 worker processes.
        """
        self._ensure_class_dirs()
        partial_protein_seq2image_seqlen = partial(protein_seq2image_seqlen, color_map=self.color_map)
        items = [(sequence, img_name, img_h, img_w)
                 for sequence, img_name in zip(data_df['Sequence'], data_df['img_pth'])]
        with multiprocessing.Pool(processes=12) as p:
            p.map(partial_protein_seq2image_seqlen, items)

    def filter_data(self, data_df, class_id, min_dom_len, max_dom_len):
        """
        Return the index labels of the rows of *class_id* whose domains all satisfy
        ``min_dom_len <= dom_len <= max_dom_len``.

        A row belongs to the class when ``class_id`` occurs in its ``Class`` string
        (``str.contains``). In multi-class rows, the other classes' domains must fit
        the window too. Rows without domains are excluded.
        """
        class_df = data_df[data_df['Class'].str.contains(class_id)]
        return [index for index, dom_lens in class_df['dom_len'].items()
                if dom_lens and all(min_dom_len <= dom_len <= max_dom_len for dom_len in dom_lens)]
            
class Detectron:
    """
    Train and evaluate a detectron2 RetinaNet (R50-FPN 3x, COCO-pretrained) that
    localises protein domains as boxes on sequence images.

    Intended call order, as in this notebook's driver code:
    create_train_valid_data -> register_custom_data('train') and
    register_custom_data('valid') -> train -> evaluate -> inference ->
    test_for_open_set.
    """

    def __init__(self, classes=None, model_dir=None, img_h=None, img_w=None) -> None:
        """
        Args:
            classes: class names ("<SuperClass>-<PfamId>"). Their list order defines
                the category ids assigned in ``register_custom_data``.
            model_dir: ``pathlib.Path`` used as cfg.OUTPUT_DIR. It is deleted and
                recreated, so any previous run stored in it is lost.
            img_h, img_w: height and width of the sequence images (used by augmentation).
        """
        import shutil

        super().__init__()
        self.classes = classes
        self.img_h = img_h
        self.img_w = img_w
        self.model_dir = model_dir
        # Start from an empty output directory. rmtree avoids building a shell
        # command from the path, which broke on spaces and shell metacharacters.
        shutil.rmtree(self.model_dir, ignore_errors=True)
        self.model_dir.mkdir(parents=True, exist_ok=True)
        self.cfg = get_cfg()
        self.cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/retinanet_R_50_FPN_3x.yaml"))
        self.cfg.DATASETS.TRAIN = ("train",)
        self.cfg.DATASETS.TEST = ("valid",)
        self.cfg.INPUT.RANDOM_FLIP = "vertical"
        self.cfg.TEST.DETECTIONS_PER_IMAGE = 100

        self.cfg.INPUT.MIN_SIZE_TRAIN = 800  # previously 64
        self.cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice"
        self.cfg.INPUT.MAX_SIZE_TRAIN = 1330  # previously 300
        self.cfg.INPUT.MIN_SIZE_TEST = 800  # previously 64
        self.cfg.INPUT.MAX_SIZE_TEST = 1330  # previously 300

        self.cfg.TEST.AUG.FLIP = False
        self.cfg.DATALOADER.NUM_WORKERS = 8
        self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/retinanet_R_50_FPN_3x.yaml")
        self.cfg.SOLVER.IMS_PER_BATCH = 32
        self.cfg.SOLVER.BASE_LR = 3e-3
        self.cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupCosineLR"
        # self.cfg.SOLVER.MAX_ITER = 3000
        print(f"Number of classes model is seeing: {len(classes)}")
        self.cfg.MODEL.RETINANET.NUM_CLASSES = len(classes)
        self.cfg.MODEL.BACKBONE.FREEZE_AT = 2

        # Experiments kept for reference:
        # self.cfg.MODEL.RESNETS.NORM = "BN"
        # self.cfg.MODEL.RETINANET.FOCAL_LOSS_GAMMA = 0.5  # didn't work
        # self.cfg.MODEL.RETINANET.IOU_THRESHOLDS = [0.4, 0.99]

        self.cfg.OUTPUT_DIR = str(self.model_dir)
        os.makedirs(self.cfg.OUTPUT_DIR, exist_ok=True)

    def get_value_counts(self, df, col, sep):
        """Count the individual labels in *col*, where each cell holds *sep*-joined labels."""
        return Counter(label for value in df[col] for label in value.split(sep))

    @staticmethod
    def _augment_image(img_pth, dim, augmenter, suffix) -> str:
        """
        Apply the imgaug *augmenter* to the image at *img_pth* and save the result
        next to it as ``<stem><suffix>.png``. Return the new path as a str.
        """
        # Batch of two: only slot 0 holds the real image; the zero-filled slot 1 is discarded.
        img_arrs = np.zeros((2, dim[0], dim[1], 3))
        with Image.open(img_pth) as img:  # close the file handle promptly
            img_arrs[0, :, :, :] = np.array(img)
        images_aug = augmenter(images=img_arrs)[0]
        aug_img_pth = img_pth.parent/f'{img_pth.stem}{suffix}.png'
        Image.fromarray(images_aug.astype(np.uint8)).save(aug_img_pth)
        return str(aug_img_pth)

    @staticmethod
    def add_gaussian_noise(img_pth, dim, index) -> str:
        """
        Save a noisy copy of the image at *img_pth* (shape *dim* = (h, w)) and return its path.

        Despite the name, the noise is ``iaa.SaltAndPepper(0.05)``, not Gaussian.
        The name is kept because it also appears in the output file name.
        """
        return Detectron._augment_image(img_pth, dim, iaa.Sequential([iaa.SaltAndPepper(0.05)]),
                                        f'_gaussian_noise_{index}')

    @staticmethod
    def add_cutout(img_pth, dim, index) -> str:
        """Save a copy of the image at *img_pth* with 10-20 random cutouts and return its path."""
        return Detectron._augment_image(
            img_pth, dim,
            iaa.Sequential([iaa.Cutout(nb_iterations=(10, 20), size=0.05, squared=False)]),
            f'_cutout_{index}')

    def augment_data(self, data_df):
        """
        Oversample minority classes of *data_df* with augmented image copies.

        Each class receives about ``max_count - class_count`` new rows, half
        salt-and-pepper and half cutout. Source rows are drawn at random from that
        class. Returns the original and augmented rows combined and shuffled.
        """
        class_freq = dict(self.get_value_counts(data_df, 'Class', '||'))
        max_value = max(class_freq.values())
        # Use this instance's image size, not the notebook-global img_h/img_w.
        dim = (self.img_h, self.img_w)
        print('Augmenting train data.................')
        print(f'Max samples: {max_value}')
        classes = [cls_name.split('-')[1] for cls_name in self.classes]
        new_rows = []
        for cls_name in classes:
            class_df = data_df[data_df['Class'].str.contains(cls_name)]
            all_rows = [row for _, row in class_df.iterrows()]
            num_cls_samples = class_freq.get(cls_name, 0)
            num_augs = max_value - num_cls_samples
            print(f"Creating {num_augs} augs for {cls_name}")
            if not all_rows:
                # Nothing to augment from; random.choice would raise on an empty list.
                print(f"No rows for {cls_name}; skipping augmentation")
                continue
            for index in range(int(round(num_augs / 2))):
                row = random.choice(all_rows)
                img_pth = Path(row['img_pth'])
                gaussian_noise = row.copy()
                gaussian_noise['img_pth'] = self.add_gaussian_noise(img_pth, dim, index)
                cutout = row.copy()
                cutout['img_pth'] = self.add_cutout(img_pth, dim, index)
                new_rows.extend([cutout, gaussian_noise])
        aug_data = pd.DataFrame(data=new_rows)
        data_df = pd.concat([data_df, aug_data], axis='rows')
        data_df.reset_index(drop=True, inplace=True)
        return data_df.sample(frac=1)

    def create_train_valid_data(self, data):
        """
        Split *data* 70/30 per class into train and valid frames, then class-balance
        train with ``augment_data``.

        Rows with several classes ('||'-joined ``Class``) are drawn once per class,
        so the same row could land in both splits, or twice in one. Such rows are
        kept once, and only in train, to avoid train/valid leakage. This assumes
        the index labels of *data* are unique (true for a freshly read CSV).
        """
        print('Creaing train and valid data')
        train_dfs, valid_dfs = [], []
        for class_name in tqdm.tqdm(self.classes):
            class_id = class_name.split('-')[-1]
            class_df = data[data['Class'].str.contains(class_id)].sample(frac=1)
            num_train_samples = int(round(class_df.shape[0] * 0.7))
            train_dfs.append(class_df.iloc[:num_train_samples, :])
            valid_dfs.append(class_df.iloc[num_train_samples:, :])
        train_data = pd.concat(train_dfs, axis='rows')
        train_data = train_data[~train_data.index.duplicated()]
        valid_data = pd.concat(valid_dfs, axis='rows')
        valid_data = valid_data[~valid_data.index.duplicated() & ~valid_data.index.isin(train_data.index)]
        train_data = self.augment_data(train_data.sample(frac=1))  # augment train data only
        valid_data = valid_data.sample(frac=1)
        print(f'Train data dist:\n')
        print(self.get_value_counts(train_data, 'Class', '||'))
        print(f'Valid data dist:\n')
        print(self.get_value_counts(valid_data, 'Class', '||'))
        return train_data, valid_data

    def register_custom_data(self, data, mode, img_h, img_w):
        """
        Register *data* with detectron2 under the dataset name *mode* (e.g. 'train'
        or 'valid') and set its ``thing_classes`` metadata.

        Each row becomes one record: the image at ``img_pth`` plus one full-height
        XYXY box per domain, spanning columns ``dom_pos`` = (x1, x2) and labelled
        ``"<SuperClass>-<Class>"``. The image size is read from the file, so the
        *img_h*/*img_w* arguments are unused. An existing registration under
        *mode* is replaced so the cell can be re-run.
        """
        print(f"Registring {mode} data.......")
        print(f"Classes selected: {self.classes}")
        data = data.reset_index(drop=True)
        self.C2I = {class_name: index for index, class_name in enumerate(self.classes)}
        dicts_list = []
        for index in tqdm.tqdm(range(data.shape[0])):
            class_list = data['Class'][index].split('||')
            super_class_list = data['SuperClass'][index].split('||')
            with Image.open(data['img_pth'][index]) as pil_img:
                img_w, img_h = pil_img.size
            # One box per domain. A row without domains gets an empty list instead
            # of a NameError or the previous row's boxes.
            annts = [{'bbox': [x1, 0, x2, img_h],
                      'bbox_mode': BoxMode.XYXY_ABS,
                      'category_id': self.C2I[f"{super_class_list[dom_index]}-{class_list[dom_index]}"],
                      }
                     for dom_index, (x1, x2) in enumerate(data['dom_pos'][index])]
            dicts_list.append({'file_name': data['img_pth'][index],
                               'height': img_h,
                               'width': img_w,
                               'image_id': index,
                               'annotations': annts,
                               })

        def get_data():
            return dicts_list

        # Registering an existing name raises, and changing thing_classes on existing
        # metadata is rejected, so drop any previous registration first.
        if mode in DatasetCatalog.list():
            DatasetCatalog.remove(mode)
        if mode in MetadataCatalog.list():
            MetadataCatalog.remove(mode)
        DatasetCatalog.register(mode, get_data)
        MetadataCatalog.get(mode).set(thing_classes=self.classes)

    def train(self):
        """Train on the registered 'train' dataset and return the DefaultTrainer."""
        trainer = DefaultTrainer(self.cfg)
        # Without this call cfg.MODEL.WEIGHTS (the COCO checkpoint) is never loaded
        # and training starts from randomly initialised weights.
        trainer.resume_or_load(resume=False)
        trainer.train()
        return trainer

    def evaluate(self, trainer):
        """Print COCO bbox metrics of ``trainer.model`` on the registered 'valid' dataset."""
        evaluator = COCOEvaluator("valid", ("bbox",), False, output_dir=self.cfg.OUTPUT_DIR)
        print(trainer.test(self.cfg, trainer.model, evaluator))

    @staticmethod
    def _count_recalled(gt_classes, predictions, theta):
        """
        Count the ground-truth boxes of one image that are recalled at score threshold *theta*.

        A box of class c is recalled when a prediction of class c with score >= theta
        exists. Each prediction recalls at most one box, so the count for a class
        is min(#boxes of c, #confident predictions of c). Box overlap (IoU) is not checked.

        Args:
            gt_classes: category ids of the image's ground-truth boxes.
            predictions: ``(category_id, score)`` pairs predicted for the image.

        Returns:
            Counter mapping category id to recalled-box count (only ids with count > 0).
        """
        confident = Counter(cls for cls, score in predictions if score >= theta)
        return Counter(gt_classes) & confident

    def _report_recall(self, predictor, thetas):
        """
        Print per-class recall on the registered 'valid' dataset for each threshold in *thetas*.

        The model runs once per image and its predictions are reused for every
        threshold. A class with no ground-truth boxes reports ``nan``.
        """
        per_image = []
        for d in tqdm.tqdm(DatasetCatalog.get("valid")):
            im = cv2.imread(str(d["file_name"]))
            fields = predictor(im)['instances'].get_fields()
            predictions = list(zip(fields['pred_classes'].cpu().tolist(),
                                   fields['scores'].cpu().tolist()))
            per_image.append(([anntn['category_id'] for anntn in d['annotations']], predictions))

        for theta in thetas:
            print(f"Calculating recall @ {theta}")
            class_recall_map = {class_name: {'num_bboxes': 0, 'pred_bboxes': 0} for class_name in self.classes}
            for gt_classes, predictions in per_image:
                for category_id in gt_classes:
                    class_recall_map[self.I2C[category_id]]['num_bboxes'] += 1
                for category_id, recalled in self._count_recalled(gt_classes, predictions, theta).items():
                    class_recall_map[self.I2C[category_id]]['pred_bboxes'] += recalled
            print(f"Recall: @{theta}: \n")
            for cls_name, value in class_recall_map.items():
                recall = value['pred_bboxes'] / value['num_bboxes'] if value['num_bboxes'] else float('nan')
                print(f" {cls_name}: {recall}, {value['pred_bboxes'], value['num_bboxes']}")

    def inference(self):
        """
        Build a predictor from ``model_final.pth`` in the output directory, print
        per-class recall on 'valid' at several score thresholds, and return the predictor.

        Requires ``register_custom_data`` to have run first (it sets ``self.C2I``).
        """
        self.cfg.MODEL.WEIGHTS = os.path.join(self.cfg.OUTPUT_DIR, "model_final.pth")  # the model just trained
        self.cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.5  # custom testing threshold
        predictor = DefaultPredictor(self.cfg)
        self.I2C = {value: key for key, value in self.C2I.items()}
        try:
            self._report_recall(predictor, (0.5, 0.8, 0.9, 0.99))
        except Exception as err:
            # Recall is diagnostic only: report the failure but still return the predictor.
            print(f"Recall computation failed: {err!r}")
        return predictor

    @staticmethod
    def _count_at_thresholds(max_scores, thresholds):
        """Return ``{t: number of scores >= t}`` for each threshold *t*."""
        return {t: sum(score >= t for score in max_scores) for t in thresholds}

    def test_for_open_set(self, class_df, classes, img_h, img_w, predictor):
        """
        Open-set check: run *predictor* on sequences of *classes* (classes the model
        was not trained on) and report how many images get a detection at or above
        several score thresholds. Lower is better.

        Sequences also present in *class_df* and sequences with ``SeqLen >= img_w``
        are excluded. Images are regenerated via ``Data.create_protein_seq_images``,
        which writes ``img_<i>.png`` files into the shared image directory.
        Counts are cumulative: an image counted at 0.99 is also counted at 0.9, 0.8 and 0.7.
        """
        data_handler = Data(classes)
        protein_domain_data = data_handler.create_protein_domain_data()
        num_rows = protein_domain_data.shape[0]
        protein_domain_data = protein_domain_data[~protein_domain_data['Sequence'].isin(class_df['Sequence'])]
        print(f"Dropping {abs(num_rows - protein_domain_data.shape[0])} sequences which are common with {self.classes}'s sequences")
        protein_domain_data = protein_domain_data[protein_domain_data['SeqLen'] < img_w]
        data_handler.create_protein_seq_images(protein_domain_data, img_h, img_w)
        all_imgs = protein_domain_data['img_pth'].tolist()
        print(f"Using {len(all_imgs)} images from classes {classes} for open set recognition test...........")
        max_scores = []
        for img_pth in tqdm.tqdm(all_imgs):
            im = cv2.imread(str(img_pth))
            scores = predictor(im)['instances'].get_fields()['scores']
            # Images with no detection above the predictor's own threshold add nothing.
            if len(scores):
                max_scores.append(scores.max().item())
        total = len(all_imgs)
        for theta, count in self._count_at_thresholds(max_scores, (0.99, 0.9, 0.8, 0.7)).items():
            ratio = count / total if total else 0.0
            print(f"Number of predictions with score >{theta} {count} out of {total}: {ratio}")
      
      

        
        
# if __name__ == "__main__":
# Classes held out of training; used as out-of-distribution data for the open-set test.
all_classes = ['Amidase_2-PF01510',
               'Amidase_3-PF01520', ]
# Classes the detector is trained on ("<SuperClass>-<PfamId>").
classes = ['SH3_3-PF08239', 'peptidase-PF01433', 'Lysozyme-PF01183', 'Lysozyme-PF05838',
           'Lysozyme-PF01374', 'Lysozyme-PF11860', 'Lysozyme-PF00182', 'CHAP-PF05257',
           'Lysozyme-PF04965', 'peptidase-PF05193', 'Lysozyme-PF00959', 'SH3_4-PF06347',
           'Lysozyme-PF13702', 'Lysozyme-PF03245', 'Lysozyme-PF18013']

img_h, img_w = 64, 300  # padding is also controlled by this img_w
seq_len = 300
# Keep sequences with seq_buckets[0] <= SeqLen < seq_buckets[1].
seq_buckets = (0, img_w)

data_block = Data(classes)
protein_data = data_block.create_protein_domain_data()
# Round-trip through CSV; the list/tuple columns come back as strings and are parsed below.
protein_data.to_csv(ProjectRoot/'data/PfamData/less_than_10k_samples_all_seq_len_data.csv', index=False)
protein_data = pd.read_csv(ProjectRoot/'data/PfamData/less_than_10k_samples_all_seq_len_data.csv')
for col in ('dom', 'dom_len', 'dom_pos'):
    protein_data[col] = protein_data[col].apply(literal_eval)
protein_data = protein_data[protein_data['SeqLen'].ge(seq_buckets[0]) & protein_data['SeqLen'].lt(seq_buckets[1])]

# Drop domain-length outliers: per Pfam id, an inclusive (min, max) domain-length
# window. A row of that class is kept only if all of its domains fit (Data.filter_data).
cls_dom_len_fltr_mp = {'PF08239': (50, 60),
                       'PF01433': (150, 240),
                       'PF01183': (160, 200),
                       'PF05838': (80, 100),
                       'PF01374': (190, 220),
                       'PF11860': (160, 190),
                       'PF00182': (50, 240),
                       'PF05257': (70, 102),
                       'PF04965': (85, 105),
                       'PF05193': (160, 200),
                       'PF00959': (110, 130),
                       'PF06347': (52, 58),
                       'PF13702': (145, 175),
                       'PF03245': (120, 140),
                       'PF18013': (120, 165),
                       }
filtered_indexes = []
for cls_nm in classes:
    cls_id = cls_nm.split('-')[-1]
    filtered_indexes.extend(data_block.filter_data(protein_data, cls_id, *cls_dom_len_fltr_mp[cls_id]))

num_rows = len(protein_data)
protein_data = protein_data[protein_data.index.isin(filtered_indexes)]
print(f"Dropped {num_rows - len(protein_data)} sequences based on domain len filtering")
    
# # Select majority domains from each class based on dom len dist
# # CHAP>>60-120 SH3_4>>50-63 SH3_3>>45-75
# # chap_indexs = []
# # chap_data = protein_data[protein_data['Class'].str.contains('PF05257')]
# # for index, row in chap_data.iterrows():
# #     flag=False
# #     for dom_len in row['dom_len']:
# #         if dom_len >= 70 and dom_len <=102:
# #             flag=True
# #         else:
# #             flag=False
# #     if flag:
# #         chap_indexs.append(index)
        
# # sh3_4_indexs = []
# # sh3_4_data = protein_data[protein_data['Class'].str.contains('PF06347')]
# # for index, row in sh3_4_data.iterrows():
# #     flag=False
# #     for dom_len in row['dom_len']:
# #         if dom_len >=52 and dom_len <=58:
# #             flag=True
# #         else:
# #             flag=False
# #     if flag:
# #         sh3_4_indexs.append(index)
            
# # sh3_3_indexs = []
# # sh3_3_data = protein_data[protein_data['Class'].str.contains('PF08239')]
# # for index, row in sh3_3_data.iterrows():
# #     flag=False
# #     for dom_len in row['dom_len']:
# #         if dom_len >=50 and dom_len <=60:
# #             flag=True
# #         else:
# #             flag=False
# #     if flag:
# #         sh3_3_indexs.append(index)

# # num_rows = protein_data.shape[0]
# # protein_data = protein_data[protein_data.index.isin(chap_indexs)]#+sh3_3_indexs+sh3_4_indexs)]
# # print(f"Number of records dropeed after selecting major domains based on len: {num_rows - protein_data.shape[0]}")
            
# data_block.create_protein_seq_images(protein_data, img_h, img_w)
# # data_block.create_protein_seq_len_images(protein_data, img_h, img_w)

# model_block = Detectron(classes=classes, model_dir=ProjectRoot/f"models/{'_'.join(classes)}_{seq_buckets[0]}_{seq_buckets[1]}",img_h=img_h,img_w=img_w)
# model_block.cfg.SOLVER.MAX_ITER = 3000
# train_data, valid_data = model_block.create_train_valid_data(protein_data)
# model_block.register_custom_data(train_data, 'train', img_h, img_w)
# model_block.register_custom_data(valid_data, 'valid', img_h, img_w)
# trainer = model_block.train()
# model_block.evaluate(trainer)
# predictor = model_block.inference()
# #class_index = all_classes.index(classes[0])
# #del all_classes[class_index]
# model_block.test_for_open_set(protein_data, all_classes, img_h, img_w, predictor)
    


  0%|          | 0/15 [00:00<?, ?it/s]

ProjectRoot: /home/satish27may/ProteinDomainDetection


100%|██████████| 15/15 [00:01<00:00,  7.88it/s]
31947it [03:04, 169.89it/s]

# Check class data

In [None]:
# Check class data: count the class labels in protein_data, one label per domain
# in the '||'-joined Class column. (This cell was a bare `class` keyword, which is
# a SyntaxError.)
Counter(cls_id for value in protein_data['Class'] for cls_id in value.split('||'))

# Multi-GPU training


In [None]:
import logging
import os
from collections import OrderedDict
import torch

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    verify_results,
)
from detectron2.modeling import GeneralizedRCNNWithTTA


class Trainer(DefaultTrainer):
    """
    DefaultTrainer that evaluates every test dataset with a COCOEvaluator and
    can run an extra test-time-augmentation (TTA) evaluation pass.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Return a COCOEvaluator for *dataset_name*.

        Results go to *output_folder*, which defaults to ``<cfg.OUTPUT_DIR>/inference``.
        """
        folder = output_folder if output_folder is not None else os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, output_dir=folder)

    @classmethod
    def test_with_TTA(cls, cfg, model):
        """
        Evaluate *model* wrapped in GeneralizedRCNNWithTTA on every dataset in
        ``cfg.DATASETS.TEST``. Each result key gets a ``_TTA`` suffix.

        NOTE: the TTA wrapper targets R-CNN style models only.
        """
        logging.getLogger("detectron2.trainer").info("Running inference with test-time augmentation ...")
        tta_model = GeneralizedRCNNWithTTA(cfg, model)
        tta_dir = os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
        evaluators = [cls.build_evaluator(cfg, name, output_folder=tta_dir) for name in cfg.DATASETS.TEST]
        results = cls.test(cfg, tta_model, evaluators)
        return OrderedDict((key + "_TTA", value) for key, value in results.items())


def setup(args):
    """
    Build the frozen detectron2 config for this run.

    ``args.config_file`` is merged first, then the command-line overrides in
    ``args.opts``. The config is then frozen and handed to detectron2's
    ``default_setup`` together with *args*.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    default_setup(config, args)
    return config


def main(args):
    """
    Per-process entry point run by detectron2's ``launch``.

    With ``args.eval_only``: load the weights, evaluate (plus TTA when enabled)
    and return the results.

    Otherwise: train, then run COCO bbox evaluation of the final model on the
    "valid" dataset. Those results are printed on the main process and returned.

    NOTE(review): "valid" and the config's training datasets must be registered
    inside each launched process. Registrations made in other notebook cells are
    presumably not visible to spawned workers; confirm.
    """
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    # For anything fancier than the standard training logic, write a custom loop
    # (see detectron2's plain_train_net.py) or subclass the trainer.
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    trainer.train()
    # distributed=True: in a multi-GPU launch each process evaluates only its shard
    # of "valid", so predictions must be gathered across ranks before scoring.
    evaluator = COCOEvaluator("valid", ("bbox",), True, output_dir=cfg.OUTPUT_DIR)
    res = trainer.test(cfg, trainer.model, evaluator)
    if comm.is_main_process():
        print(res)
    return res


if __name__ == "__main__":
    # Script entry point (appears adapted from detectron2's training tool): parse
    # detectron2's standard CLI flags, then hand `main` to `launch` with the parsed
    # GPU/machine/dist-url settings.
    # NOTE(review): inside a notebook, __name__ == "__main__" is also true, and
    # parse_args() reads the kernel's own sys.argv. This cell presumably has to run
    # as a standalone script with --config-file/--num-gpus; confirm.
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )