In [None]:
import pandas as pd
import numpy as np
import os
import cv2

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torchvision
import torch.nn.functional as F

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler

from matplotlib import pyplot as plt

In [None]:
# Build the Faster R-CNN skeleton and restore the fine-tuned parameters.
# The backbone is created without pretrained weights because every parameter
# is overwritten from the checkpoint below anyway.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained_backbone=False)

num_classes = 2  # 1 class (wheat) + background

# Replace the box predictor with one sized for num_classes, reusing the
# feature width that the existing classification layer consumes.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Checkpoint (presumably written by the training notebook) is a dict whose
# 'model_state_dict' entry holds the detector parameters; it is mapped to the
# CPU so loading works regardless of where it was saved.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
weights_path = ('/kaggle/input/my-own-starter-fft-fasterrcnn-train/'
                'fft_faster_rcnn_weights_w_metric_6_epochs.pth.tar')
trained_weights = torch.load(weights_path, map_location=torch.device('cpu'))
model.load_state_dict(trained_weights['model_state_dict'])

In [None]:
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(device)

def format_prediction_string(boxes, scores):
    """Serialise detections into the space-separated submission format.

    Each box row is read as (x_min, y_min, x_max, y_max, ...) and truncated to
    int; it is emitted as ``"<score:.4f> <x_min> <y_min> <width> <height>"``.
    All detections are joined with single spaces; empty input gives ``""``.

    Args:
        boxes: numpy array with one box per row (at least 4 columns).
        scores: iterable of confidence scores, paired positionally with boxes.

    Returns:
        str: the prediction string for one image.
    """
    int_boxes = boxes.astype(int)
    return " ".join(
        f'{score:.4f} {x_min} {y_min} {x_max - x_min} {y_max - y_min}'
        for score, (x_min, y_min, x_max, y_max, *_rest) in zip(scores, int_boxes)
    )

In [None]:
# boxes=np.array([])
# scores=np.array([])
# print(type(boxes))
# print(boxes.astype(int))

# pred_strings = []
# for s, b in zip(scores, boxes.astype(int)):
#     pred_strings.append(f'{s:.4f} {b[0]} {b[1]} {b[2] - b[0]} {b[3] - b[1]}')
# " ".join(pred_strings)

In [None]:
# Load the precomputed mean FFT spectra and build binary frequency masks.
# NOTE(review): assumes 'wheat_freq' and 'im_freq' are (pad2, pad2, 3) arrays of
# strictly positive spectrum magnitudes (np.log is applied below) -- confirm.
all_ffts = np.load('../input/wheat-mean-ffts-200-images/wheat_detection_mean_ffts_200_images.npz')
wheat_freq = all_ffts['wheat_freq']
im_freq = all_ffts['im_freq']
pad2 = 1024  # FFT grid size; images are resized to pad2 x pad2 before masking

thr_list = [0.2, 0.4, 0.6, 0.8]  # thresholds on the spectrum difference, in units of 1e-7
n_thr = len(thr_list)
exclude = 120  # width of the outer (highest-frequency, in centred layout) band always masked out

# mask[:, :, channel, threshold_index] is a 0/1 mask in *unshifted* FFT layout,
# i.e. it can multiply np.fft.fft2(channel) directly.
mask = np.zeros((pad2, pad2, 3, n_thr))
for i in range(3):
    plot_wheat = np.log(wheat_freq[:, :, i])
    plot_im = np.log(im_freq[:, :, i])

    # If the log spectra are all positive, zero the first row and column
    # (the DC terms in unshifted layout) so they do not dominate the normalisation.
    if np.min(plot_wheat) > 0 and np.min(plot_im) > 0:
        plot_wheat[0, :] = 0
        plot_wheat[:, 0] = 0
        plot_im[0, :] = 0
        plot_im[:, 0] = 0

    # Scale each log spectrum to unit L1 norm so the two are comparable.
    plot_wheat = plot_wheat / np.sum(np.abs(plot_wheat))
    plot_im = plot_im / np.sum(np.abs(plot_im))

    # Positive where wheat crops carry relatively more energy than whole images;
    # fftshift moves DC to the centre so the border below is the high-frequency band.
    fft_diff = np.fft.fftshift(plot_wheat - plot_im)

    for kt, f_thr in enumerate(thr_list):
        channel_mask = (fft_diff > f_thr * 1e-7).astype(float)
        # Blank the outer band. This must touch channel i only: zeroing every
        # channel here would also hit channels already shifted back to FFT
        # layout, where the border holds the *low* frequencies.
        # NOTE(review): if the training notebook built its masks with the
        # all-channel zeroing, the weights were trained on those masks --
        # rebuild both sides consistently.
        channel_mask[:exclude, :] = 0
        channel_mask[-exclude:, :] = 0
        channel_mask[:, :exclude] = 0
        channel_mask[:, -exclude:] = 0

        plt.figure()
        plt.imshow(channel_mask, vmin=0, vmax=1)
        plt.colorbar()
        plt.title(f'channel {i}, threshold {f_thr}')
        plt.pause(0.1)

        # Undo the centring so the mask lines up with np.fft.fft2 output
        # (ifftshift is the exact inverse of fftshift for any grid size).
        mask[:, :, i, kt] = np.fft.ifftshift(channel_mask)

In [None]:
def mask_input(images, mask, thr_list, pad2):
    """Filter each image in the Fourier domain with the precomputed masks.

    For every image and colour channel, the channel's 2-D FFT is multiplied by
    the masks of all ``len(thr_list)`` thresholds and the real parts of the
    inverse transforms are summed. The result is then rescaled to [0, 1] per
    image, jointly over the three channels.

    Args:
        images: iterable of CPU numpy arrays / torch tensors, each indexable as
            ``image[j, :, :]`` with shape (pad2, pad2) for j in 0..2.
        mask: array of shape (pad2, pad2, 3, >= len(thr_list)), unshifted FFT layout.
        thr_list: thresholds the masks were built for; only its length is used.
        pad2: spatial size of the images and masks.

    Returns:
        list of float32 torch tensors of shape (3, pad2, pad2), one per input
        image. An image whose filtered result is constant maps to all zeros
        instead of NaN. Empty input returns an empty list.
    """
    n_thr = len(thr_list)
    # ifft2 and np.real are linear, so summing the per-threshold results equals
    # a single inverse transform with the summed mask (n_thr times fewer FFTs).
    combined_mask = np.asarray(mask)[:, :, :, :n_thr].sum(axis=3)

    new_images = []
    for image in images:
        # Fresh buffer per image: reusing one buffer would add the previous
        # (already normalised) image into this one.
        im_masked = np.zeros((3, pad2, pad2))
        for j in range(3):
            spectrum = np.fft.fft2(np.asarray(image[j, :, :]))
            im_masked[j] = np.real(np.fft.ifft2(spectrum * combined_mask[:, :, j]))

        # normalize to 0-1 (guard against a constant result -> division by zero)
        im_masked -= im_masked.min()
        peak = im_masked.max()
        if peak > 0:
            im_masked /= peak

        new_images.append(torch.from_numpy(im_masked).float())
    return new_images

In [None]:
# Run the FFT-masked detector over every test image and collect submission rows.
# Images not already pad2 x pad2 are resized to pad2 for the Fourier masking and
# resized back to their original size before the forward pass, so predicted
# boxes are in the original image's pixel coordinates.
# Score cut-off for Faster R-CNN (differs from the one used with Mask R-CNN).
detection_threshold = 0.80

test_dir = "../input/global-wheat-detection/test/"

model.eval()
results = []
# sorted(): os.listdir order is arbitrary; this makes the run reproducible.
for file_name in sorted(os.listdir(test_dir)):
    image_path = os.path.join(test_dir, file_name)

    bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if bgr is None:
        raise ValueError(f"could not read image: {image_path}")
    # HWC BGR uint8 -> CHW RGB float32 in [0, 1]
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    image = torch.tensor(np.transpose(rgb, (2, 0, 1)), dtype=torch.float)

    orig_h, orig_w = image.shape[1], image.shape[2]
    reinterpolate = (orig_h, orig_w) != (pad2, pad2)
    if reinterpolate:
        # [None] adds the batch dim F.interpolate expects; [0] removes it again
        # (explicit indexing, unlike squeeze, never drops a size-1 channel dim).
        image = F.interpolate(image[None], size=(pad2, pad2),
                              mode='bilinear', align_corners=False)[0]

    # mask the input image in the frequency domain
    image = mask_input([image], mask, thr_list, pad2)[0].to(device)

    # Back to original size (as a batch of one) for the forward pass
    if reinterpolate:
        image = F.interpolate(image[None], size=(orig_h, orig_w),
                              mode='bilinear', align_corners=False)
    else:
        image = image[None]

    with torch.no_grad():
        outputs = model(image)

    boxes = outputs[0]['boxes'].cpu().numpy()
    scores = outputs[0]['scores'].cpu().numpy()

    keep = scores >= detection_threshold
    boxes = boxes[keep].astype(np.int32)
    scores = scores[keep]
    # strip the extension regardless of its length
    image_id = os.path.splitext(file_name)[0]

    print(image_id)

    results.append({
        'image_id': image_id,
        'PredictionString': format_prediction_string(boxes, scores)
    })

In [None]:
# One row per test image, in the column order the competition expects.
test_df = pd.DataFrame(results, columns=['image_id', 'PredictionString'])
test_df.to_csv('submission.csv', index=False)

# Last expression of the cell, so the notebook actually renders the preview
# (a .head() in the middle of a cell is evaluated and discarded).
test_df.head()