In [None]:
import torch
import glob
import json
import os
import cv2
import pandas as pd
from PIL import Image
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torchvision.transforms as tt
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

%matplotlib inline
matplotlib.rcParams['font.size'] = 14
matplotlib.rcParams['figure.figsize'] = (25, 25)

In [None]:
# Read the training annotations and the sample submission.
train_df = pd.read_csv('../input/global-wheat-detection/train.csv')
sub_df = pd.read_csv('../input/global-wheat-detection/sample_submission.csv')

# `bbox` arrives as JSON text; decode every entry into a Python list and
# re-attach the decoded values as the last column of the frame.
bbox = [json.loads(raw_box) for raw_box in train_df['bbox']]
train_df = train_df.drop(columns=['bbox'])
train_df['bbox'] = bbox
train_df

In [None]:
'''
bbox
image_id
labels
area
iscrowd
'''

class WheatDataset(Dataset):
    """Detection dataset backed by a dataframe of per-box annotations.

    Every item is ``(image, target)`` where ``target`` is a dict with the keys
    ``boxes`` (corner format), ``image_id``, ``labels`` (all ones),
    ``area`` and ``iscrowd`` (all zeros).

    Args:
        root: Directory holding the ``<image_id>.jpg`` files.
        dataframe: Frame with an ``image_id`` column and a ``bbox`` column whose
            entries are ``[x, y, width, height]`` lists.
        transforms: Optional callable ``(image, target) -> (image, target)``.

    Note:
        Boxes are grouped per image once at construction time, so later edits
        to ``dataframe`` are not reflected in the items.
    """

    def __init__(self, root, dataframe, transforms=None):
        self.root = root
        self.dataframe = dataframe
        self.transforms = transforms
        # Ids in first-appearance order; `images` holds the matching file paths.
        self.image_ids = dataframe['image_id'].unique().tolist()
        self.images = [root + '/' + image_id + '.jpg' for image_id in self.image_ids]
        # Group the boxes once here instead of filtering the whole frame on
        # every __getitem__ call.
        self._boxes_by_id = {
            image_id: group.tolist()
            for image_id, group in dataframe.groupby('image_id', sort=False)['bbox']
        }

    def __getitem__(self, index):
        # The id is taken from the stored list rather than parsed back out of
        # the path, so ids or directories containing '.' cannot confuse it.
        image_name = self.image_ids[index]
        img = Image.open(self.images[index]).convert('RGB')

        # `device` is the notebook-level torch.device.
        # reshape(-1, 4) keeps the column indexing below valid for any row count.
        boxes = torch.tensor(
            self._boxes_by_id[image_name], dtype=torch.float32, device=device
        ).reshape(-1, 4)
        # Convert [x, y, width, height] into [x1, y1, x2, y2].
        boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
        boxes[:, 3] = boxes[:, 1] + boxes[:, 3]
        image_id = torch.tensor([index], device=device)
        labels = torch.ones((len(boxes), ), dtype=torch.int64, device=device)
        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        iscrowd = torch.zeros((len(boxes), ), dtype=torch.int64, device=device)

        target = {
            'boxes': boxes,
            'image_id': image_id,
            'labels': labels,
            'area': area,
            'iscrowd': iscrowd,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.images)

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

root = '../input/global-wheat-detection/train'
dataframe = train_df

# Build the dataset without transforms and display a single sample.
dataset = WheatDataset(root, dataframe)
dataset[90]

In [None]:
%cd ../input/myfile
!pip install pycocotools

In [None]:
import utils
import transforms as T
from engine import train_one_epoch, evaluate

def get_transform(train):
    """Assemble the (image, target) transform pipeline.

    Args:
        train: When truthy, ``T.RandomHorizontalFlip(0.5)`` is appended after
            the tensor conversion.

    Returns:
        A ``T.Compose`` over the selected transforms.
    """
    steps = [T.ToTensor()]
    if train:
        steps += [T.RandomHorizontalFlip(0.5)]
    return T.Compose(steps)

In [None]:
# Step back up two directory levels, then enter `working`.
%cd ..
%cd ..
%cd working

In [None]:
# Number of outputs of the box classifier.
num_classes = 2


def get_model(num_classes):
    """Create a Faster R-CNN (ResNet-50 FPN) with a fresh box predictor.

    The network is built without pretrained weights; its ROI-head predictor is
    replaced by a ``FastRCNNPredictor`` sized for ``num_classes`` outputs.

    Args:
        num_classes: Number of outputs of the replacement box predictor.

    Returns:
        The constructed model (not yet moved to any device).
    """
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model


model = get_model(num_classes).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a GPU session still loads when this notebook falls back to the CPU.
model.load_state_dict(
    torch.load('../input/fasterrcnn-resnet50-fpn/fasterrcnn_resnet50_fpn.pth', map_location=device)
)

In [None]:
# Fix the seed so the permutation below is repeatable.
torch.manual_seed(42)

# Two views over the same annotations: one with the training transforms,
# one with the evaluation transforms.
train_ds = WheatDataset(root, dataframe, transforms=get_transform(train=True))
valid_ds = WheatDataset(root, dataframe, transforms=get_transform(train=False))

# A single shuffled index list is shared by both views; its last 50 entries
# form the validation subset and everything before them the training subset.
indices = torch.randperm(len(train_ds)).tolist()
train_ds = Subset(dataset=train_ds, indices=indices[:-50])
valid_ds = Subset(dataset=valid_ds, indices=indices[-50:])

train_dl = DataLoader(
    dataset=train_ds,
    collate_fn=utils.collate_fn,
    batch_size=2,
    shuffle=True,
)
valid_dl = DataLoader(
    dataset=valid_ds,
    collate_fn=utils.collate_fn,
    batch_size=1,
    shuffle=False,
)

In [None]:
# Hand only the parameters that require gradients to the optimizer.
params = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)
# StepLR with step_size=3 and gamma=0.1.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

In [None]:
# num_epochs = 10
# for epoch in range(num_epochs):
#     train_one_epoch(model, optimizer, train_dl, device, epoch, print_freq=10)
#     lr_scheduler.step()
#     evaluate(model, valid_dl, device)

# Pseudo Labeling and Retraining

In [None]:
class WheatTestDataset(Dataset):
    """Unlabelled image folder that yields ``(image, image_id)`` pairs.

    Args:
        root: Directory whose entries are all treated as images.
        transform: Optional callable ``(image, target) -> (image, target)``;
            it is called with an empty target dict and only the image is kept.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # glob returns entries in arbitrary order; sorting keeps item indices
        # stable from one run to the next.
        self.images = sorted(glob.glob(os.path.join(root, '*')))

    @staticmethod
    def _image_id(image_path):
        """Return the file name of ``image_path`` without directory or extension."""
        return os.path.splitext(os.path.basename(image_path))[0]

    def __getitem__(self, index):
        image_path = self.images[index]
        image_id = self._image_id(image_path)

        image = Image.open(image_path).convert('RGB')

        if self.transform is not None:
            # The transform takes (image, target); an empty dict stands in for
            # the annotations this dataset does not have.
            image, _ = self.transform(image, {})
        return image, image_id

    def __len__(self):
        return len(self.images)

# Loader over the test folder: one image per batch, original order,
# evaluation transforms.
test_root = '../input/global-wheat-detection/test'
test_ds = WheatTestDataset(test_root, transform=get_transform(train=False))
test_dl = DataLoader(
    dataset=test_ds,
    collate_fn=utils.collate_fn,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Build one row per confident detection on the test images, using the same
# columns as the training frame: image_id, width, height, source, bbox.
plabel_columns = ['image_id', 'width', 'height', 'source', 'bbox']
threshold = 0.5
plabel_records = []

model.eval()
with torch.no_grad():
    for image, image_id in test_dl:
        prediction = model([image[0].to(device)])[0]

        # Keep detections scoring above the threshold, then turn the corner
        # boxes [x1, y1, x2, y2] into [x, y, width, height].
        boxes = prediction['boxes'][prediction['scores'] > threshold].cpu()
        boxes[:, 2] -= boxes[:, 0]
        boxes[:, 3] -= boxes[:, 1]

        # The image is a channel-first tensor, so its last two dimensions are
        # (height, width).
        h, w = image[0].shape[-2:]

        for box in boxes.tolist():
            plabel_records.append({
                'image_id': image_id[0],
                'width': w,
                'height': h,
                'source': 'pseudo_label',
                'bbox': [round(b, 0) for b in box],
            })

# Build the frame once; the explicit columns keep the schema even when no
# detection passed the threshold.
plabel_df = pd.DataFrame(plabel_records, columns=plabel_columns)

In [None]:
# Stack the annotated rows and the pseudo-labelled rows into one frame with a
# fresh 0..n-1 index.
final_plabel_df = pd.concat([train_df, plabel_df], axis=0, ignore_index=True)
final_plabel_df

In [None]:
'''
boxes
image_id
labels
area
iscrowd
'''
class PLabelDataset(Dataset):
    """Detection dataset over a mix of annotated and pseudo-labelled images.

    Every item is ``(image, target)`` where ``target`` is a dict with the keys
    ``boxes`` (corner format), ``image_id``, ``labels`` (all ones),
    ``area`` and ``iscrowd`` (all zeros).

    Args:
        df: Frame with ``image_id``, ``source`` and ``bbox`` columns; ``bbox``
            entries are ``[x, y, width, height]`` lists.
        transform: Optional callable ``(image, target) -> (image, target)``.
        train_root: Directory of images whose source is not ``'pseudo_label'``.
        test_root: Directory of images whose source is ``'pseudo_label'``.

    Note:
        Boxes and sources are grouped per image once at construction time, so
        later edits to ``df`` are not reflected in the items.
    """

    def __init__(self, df, transform=None,
                 train_root='../input/global-wheat-detection/train',
                 test_root='../input/global-wheat-detection/test'):
        self.df = df
        self.transform = transform
        self.image_names = self.df['image_id'].unique()
        self.train_root = train_root
        self.test_root = test_root
        # Group once here instead of filtering the whole frame on every
        # __getitem__ call.
        self._boxes_by_name = {}
        self._source_by_name = {}
        for image_name, rows in df.groupby('image_id', sort=False):
            self._boxes_by_name[image_name] = rows['bbox'].tolist()
            # The first row of an image decides which folder it is read from.
            self._source_by_name[image_name] = rows['source'].iloc[0]

    def _image_path(self, image_name):
        """Return the .jpg path of ``image_name`` in the folder matching its source."""
        if self._source_by_name[image_name] == 'pseudo_label':
            root = self.test_root
        else:
            root = self.train_root
        return os.path.join(root, image_name + '.jpg')

    def __getitem__(self, index):
        image_name = self.image_names[index]
        image = Image.open(self._image_path(image_name)).convert('RGB')

        # `device` is the notebook-level torch.device.
        # reshape(-1, 4) keeps the column indexing below valid for any row count.
        boxes = torch.tensor(
            self._boxes_by_name[image_name], device=device, dtype=torch.float
        ).reshape(-1, 4)
        # Convert [x, y, width, height] into [x1, y1, x2, y2].
        boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
        boxes[:, 3] = boxes[:, 1] + boxes[:, 3]

        image_id = torch.tensor([index], device=device, dtype=torch.int64)
        labels = torch.ones((len(boxes),), device=device, dtype=torch.int64)
        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        iscrowd = torch.zeros((len(boxes),), device=device, dtype=torch.int64)

        target = {
            'boxes': boxes,
            'image_id': image_id,
            'labels': labels,
            'area': area,
            'iscrowd': iscrowd,
        }

        if self.transform is not None:
            image, target = self.transform(image, target)
        return image, target

    def __len__(self):
        return len(self.image_names)
    
# Transform-free dataset over the combined annotated + pseudo-labelled frame.
ds = PLabelDataset(final_plabel_df)

In [None]:
# Shuffle once and convert to plain Python ints (as the first split does), so
# the datasets are indexed with ints rather than 0-dim tensors.
indices = torch.randperm(len(ds)).tolist()
# The first 90% of the shuffled positions train, the remainder validates.
split_at = int(len(indices) * 0.9)

plabel_ds = PLabelDataset(final_plabel_df, get_transform(train=True))
test_plabel_ds = PLabelDataset(final_plabel_df, get_transform(train=False))

train_ds = Subset(plabel_ds, indices=indices[:split_at])
valid_ds = Subset(test_plabel_ds, indices=indices[split_at:])

train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=utils.collate_fn)
valid_dl = DataLoader(valid_ds, batch_size=8, shuffle=False, collate_fn=utils.collate_fn)

In [None]:
num_epochs = 4

# Fresh optimizer and scheduler for the retraining stage, built over the
# parameters that require gradients.
params = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0001,
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=1,
    gamma=0.5,
)

# Each epoch: one training pass, one scheduler step, one evaluation pass.
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, train_dl, device, epoch, print_freq=10)
    lr_scheduler.step()
    evaluate(model, valid_dl, device)

In [None]:
# Loader over the test folder for the visual check: one image per batch,
# original order, evaluation transforms.
test_root = '../input/global-wheat-detection/test'
test_ds = WheatTestDataset(test_root, transform=get_transform(train=False))
test_dl = DataLoader(
    dataset=test_ds,
    collate_fn=utils.collate_fn,
    batch_size=1,
    shuffle=False,
)

# Detections at or below this score are not drawn.
threshold = 0.5

def show_result(image: torch.Tensor, boxes):
    """Draw boxes on an image and show it on the current matplotlib axes.

    Args:
        image: Channel-first image tensor; values are multiplied by 255 and
            cast to bytes, so they are expected to lie in [0, 1].
        boxes: Iterable of ``(x1, y1, x2, y2)`` corner boxes (tensor rows or
            plain sequences).
    """
    # Move to host memory and lay the pixels out contiguously as H x W x C,
    # which is what OpenCV draws on.
    img = np.ascontiguousarray(
        image.detach().cpu().permute(1, 2, 0).mul(255).byte().numpy()
    )
    for box in boxes:
        # cv2.rectangle needs integer pixel coordinates, not float tensors.
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(img, (x1, y1), (x2, y2), (225, 0, 0), 3)
    plt.imshow(img)
    plt.axis(False)
    
# The 5 x 2 subplot grid below has room for ten images; stop there instead of
# failing when the test folder holds more.
max_shown = 10

model.eval()
with torch.no_grad():
    for i, (image, image_id) in enumerate(test_dl):
        if i >= max_shown:
            break
        prediction = model([image[0].to(device)])

        boxes = prediction[0]['boxes']
        scores = prediction[0]['scores']

        # Draw only detections above the score threshold; move them to the CPU
        # before they are handed to the plotting helper.
        boxes = boxes[scores > threshold].cpu()
        plt.subplot(5, 2, i + 1)
        show_result(image[0], boxes)
plt.tight_layout()
plt.show()

In [None]:
# Write the model's state dict to a file in the current working directory.
torch.save(model.state_dict(), 'fasterrcnn_resnet50_fpn_plabel.pth')