In [None]:
class OutputFits(FitsFile):
    """Wrapper around a GALFIT output FITS file.

    On construction this reads the image HDUs at indices 2 and 3 (named
    "model" and "residual" by default), relying on ``FitsFile.__init__`` for
    the observation. It rebuilds the GALFIT header/feedme from the model
    HDU's header and loads bulge geometry from a SpArcFiRe CSV for later use
    by ``generate_masked_residual``.
    """

    def __init__(self, filepath, csv, names=None):
        """Load the GALFIT output FITS file and its companion CSV.

        Parameters
        ----------
        filepath : str
            Path to the GALFIT output FITS file.
        csv : str
            Path to the SpArcFiRe CSV with bulge center/radius and input
            size (see ``_residual_helper``).
        names : list of str, optional
            Names for HDUs 2 and 3, in order. ``None`` or an empty list
            means ``["model", "residual"]``. Note that the rest of this
            class reads ``self.model``, so "model" should be among them.
        """
        FitsFile.__init__(self, filepath = filepath, wait = True)

        # None (or an empty list, as before) selects the default names; a
        # None default avoids a shared mutable default argument.
        if not names:
            names = ["model", "residual"]

        try:
            # Skip the primary HDU (0) and the observation, which
            # FitsFile.__init__ already loaded.
            for num, name in zip(range(2, 4), names):
                self.all_hdu[name] = HDU(name, self.file[num].header, self.file[num].data)

            # Presumably exposes each HDU as an attribute (self.model,
            # self.residual, ...) -- the code below relies on self.model.
            self.update_params(**self.all_hdu)

            # The model HDU's header carries GALFIT's fit parameters.
            self.header = dict(self.model.header)
            _header = GalfitHeader()
            # Call the helper directly since we only have the header dict.
            _header.from_file_helper(self.header)
            self.feedme = FeedmeContainer(path_to_feedme = filepath, header = _header)
            self.feedme.from_file(self.header)

            # For convenience, the model is the default data.
            self.data = self.model.data
            self._residual_helper(csv)
        finally:
            # Release the underlying file even if parsing above fails.
            self.close()

    def generate_masked_residual(self, mask, *, show_plots=True):
        """Compute the star-masked residual, normalized by the smaller norm.

        First, a square box around the bulge is set to the darkest model
        pixel. The bulge center and radius come from the SpArcFiRe CSV and
        are rescaled from the input image size to the model image size.
        This modifies ``self.model.data`` in place.

        Parameters
        ----------
        mask : object with a ``.data`` array
            Star mask for the uncropped image. It is cropped to the GALFIT
            fit region and inverted, so nonzero mask pixels become 0 (dropped).
        show_plots : bool, keyword-only
            Show the diagnostic matplotlib plots (model, bulge box, darkened
            model). Defaults to True, matching the previous behavior.

        Returns
        -------
        The masked residual ``(observation - model) * crop_mask`` divided by
        ``min(norm(masked observation), norm(masked model))``, or None if a
        ValueError (e.g. mismatched shapes) occurs.

        Also sets ``masked_residual``, ``norm_observation``, ``norm_model``,
        ``norm_residual``, ``masked_residual_normalized`` and ``nmr`` on self.
        """
        import numpy as np

        model_data = self.model.data

        if show_plots:
            plt.imshow(model_data)
            plt.show()

        bulge_values = (
            getattr(self, "bulge_center_row", None),
            getattr(self, "bulge_center_col", None),
            getattr(self, "bulge_rad", None),
            getattr(self, "input_size", None),
        )

        if any(value is None for value in bulge_values):
            # _residual_helper could not read the CSV; previously this
            # crashed with AttributeError.
            print("Bulge information unavailable, skipping bulge box...")
        else:
            # Darkest pixel, so the box does not wash out the rest of the plot.
            dark = np.nanmin(model_data)

            # Idea 1: the SpArcFiRe numbers are pixel locations in the input
            # image; scale them down to the model image size.
            # NOTE(review): rows and columns are both scaled by the model's
            # column count and by input_size, which assumes square images.
            n_cols = len(model_data[0])
            row = self.bulge_center_row * n_cols // self.input_size
            col = self.bulge_center_col * n_cols // self.input_size
            rad = self.bulge_rad * n_cols // self.input_size

            # Clamp at 0: a negative start index would make numpy slicing
            # wrap around to the far edge and select an empty/wrong box.
            start_row = max(int(row - rad), 0)
            end_row = int(row + rad + 1)

            start_col = max(int(col - rad), 0)
            end_col = int(col + rad + 1)

            # Box out according to the radius.
            if show_plots:
                plt.imshow(model_data[start_row:end_row, start_col:end_col])
                plt.show()

            model_data[start_row:end_row, start_col:end_col] = dark

        if show_plots:
            plt.imshow(model_data)
            plt.show()

        crop_box = self.feedme.header.region_to_fit

        # Region bounds appear to be 1-based and inclusive; convert to a
        # python slice. NOTE(review): the same bounds (crop_box[0],
        # crop_box[1]) are used for both axes, assuming a square fit region.
        box_min, box_max = crop_box[0] - 1, crop_box[1]

        # Invert the mask since galfit keeps 0 valued areas.
        crop_mask = 1 - mask.data[box_min:box_max, box_min:box_max]

        try:
            self.masked_residual = (self.observation.data - model_data)*crop_mask

            self.norm_observation = slg.norm(crop_mask*self.observation.data)
            self.norm_model = slg.norm(crop_mask*model_data)
            self.norm_residual = slg.norm(crop_mask*self.residual.data)
            # Normalize by the smaller of the observation/model norms.
            self.masked_residual_normalized = self.masked_residual/min(self.norm_observation, self.norm_model)
            # "nmr": norm of the masked residual, normalized.
            self.nmr = slg.norm(self.masked_residual_normalized)

        except ValueError:
            print(f"There is probably an observation error with galaxy {self.gname}, continuing...")
            return None

        return self.masked_residual_normalized

    def _residual_helper(self, csv):
        """Read bulge geometry and input size from a SpArcFiRe CSV (best-effort).

        Reads the first row of *csv* and sets ``bulge_center_row``,
        ``bulge_center_col``, ``input_size`` and ``bulge_rad`` on self. On any
        failure all four are set to None and a message is printed, instead
        of raising.
        """
        # Defaults let generate_masked_residual detect missing information.
        self.bulge_center_row = None
        self.bulge_center_col = None
        self.input_size = None
        self.bulge_rad = None

        try:
            info = pd.read_csv(csv)
            # Column names carry a leading space, as they appear in the CSV.
            center_row = float(info[' inputCenterR'][0])
            center_col = float(info[' inputCenterC'][0])
            # First whitespace token with its first character stripped;
            # presumably the iptSz field looks like "[rows cols]" -- confirm.
            input_size = float(info[' iptSz'][0].split()[0][1:])
            bulge_rad = float(info[' bulgeMajAxsLen'][0])
        except Exception as err:
            # Deliberately best-effort: report and continue without bulge info.
            print(f"Could not read bulge information from {csv}: {err!r}")
            return

        # Assign only after everything parsed, so the attributes are
        # never left partially filled in.
        self.bulge_center_row = center_row
        self.bulge_center_col = center_col
        self.input_size = input_size
        self.bulge_rad = bulge_rad
            

            
            
if __name__ == "__main__":
    # Star import is expected to supply the names used below
    # (TEST_DATA_DIR, pj, np, FitsFile, ...); none are defined in this file.
    from RegTest.RegTest import *

    # Testing from_file
    #TEST_DATA_DIR = '/home/azraz/run2_1000_galfit'
    gname = "1237671124296532233"
    obs   = pj(TEST_DATA_DIR, "sparcfire-in", f"{gname}.fits")
    model = pj(TEST_DATA_DIR, "sparcfire-out", gname, f"{gname}_out.fits")
    csv   = pj(TEST_DATA_DIR, "sparcfire-out", gname, f"{gname}.csv")
    mask  = pj(TEST_DATA_DIR, "sparcfire-out", gname, f"{gname}_star-rm.fits")

    test_obs   = FitsFile(obs)
    test_model = OutputFits(model, csv)
    test_mask  = FitsFile(mask)

    print(test_obs.observation)
    print()
    print(test_model.feedme)
    print()
    print(test_model.model)

    # Purposefully do not fill in some of the header parameters
    # since those do not exist in the output FITS header.
    # This reminds the user/programmer that the OutputFits object
    # only represents the header, nothing more, nothing less, and
    # that a different method must be used to fill in the header.
    #print(test_model.feedme.header)

    print("\nThese should all be the same .")
    print(np.shape(test_model.observation.data))
    print(np.shape(test_model.data))
    print(np.shape(test_model.residual.data))
    crop_box = test_model.feedme.header.region_to_fit
    # Side length of the fit region; + 1 because both bounds are inclusive.
    crop_size = crop_box[1] - crop_box[0] + 1
    print(f"({crop_size}, {crop_size})")
    print("Andddd pre crop")
    print(np.shape(test_obs.observation.data))