In [1]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
!pip install pytorch-lightning

In [3]:
import pytorch_lightning as pl
import pandas as pd
import cv2
import os 
import torch
from PIL import Image
from torch import nn
from torch.utils.data import Dataset ,DataLoader, random_split
import numpy as np
import torch
from sklearn.model_selection import train_test_split 
import torchvision
from torchvision import transforms, datasets, models
import matplotlib.pyplot as plt
import torchmetrics
from torchmetrics.functional import accuracy
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torchvision.utils import make_grid
import math
import torch.nn.functional as F
import warnings
warnings.filterwarnings('ignore')

In [4]:
class WheatDataset(Dataset):
  def __init__(self, root, folder='train', transforms=None):
    self.transforms = []
    if transforms != None:
      self.transforms.append(transforms)
    self.root = root
    self.folder = folder
    box_data = pd.read_csv(os.path.join(root, "train.csv"))
    self.box_data = pd.concat([box_data, box_data.bbox.str.split('[').str.get(1).str.split(']').str.get(0).str.split(',', expand=True)], axis=1)
    self.imgs = list(os.listdir(os.path.join(root, self.folder)))

  def __len__(self):
    return len(self.imgs)

  def __getitem__(self, idx):
    img_path = os.path.join(os.path.join(self.root, self.folder), self.imgs[idx])
    img = Image.open(img_path).convert("RGB")
    df = self.box_data[self.box_data['image_id'] == self.imgs[idx].split('.')[0]]
    if df.shape[0] != 0:
      df[2] = df[0].astype(float) + df[2].astype(float) # bbox = [x, y, w, h] with this line we find xmax cause pytorch looks for [xmin, ymin, xmax, ymax]
      df[3] = df[1].astype(float) + df[3].astype(float) # we find ymax
      boxes = df[[0,1,2,3]].astype(float).values
      labels = np.ones(len(boxes))
    else:
      boxes = np.asarray([[0,0,0,0]])
      labels = np.ones(len(boxes))
    
    for i in self.transforms:
      img=i(img)

    targets = {}
    targets['boxes'] = torch.from_numpy(boxes).double()
    targets['labels'] = torch.from_numpy(labels).type(torch.int64)

    return img.double(), targets

In [5]:
root = "/content/drive/MyDrive/Datasets/global_wheat_dataset"
dataset = WheatDataset(root=root, folder="train", transforms=transforms.ToTensor())

In [6]:
pl.seed_everything(1)

INFO:pytorch_lightning.utilities.seed:Global seed set to 1


1

In [7]:
indices = torch.randperm(len(dataset)).tolist()
dataset_train = torch.utils.data.Subset(dataset, indices[:2500])
dataset_test = torch.utils.data.Subset(dataset, indices[2500:])
dataloader_train = DataLoader(dataset_train, batch_size=4, shuffle=True, collate_fn=lambda x:list(zip(*x)))
dataloader_test = DataLoader(dataset_test, batch_size=4, shuffle=False, collate_fn=lambda x:list(zip(*x)))

In [9]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
num_classes = 2  # 1 class (person) + background
in_features = model.roi_heads.box_predictor.cls_score.in_features

model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model=model.to(device)


params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.01)


Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth


  0%|          | 0.00/160M [00:00<?, ?B/s]

In [11]:
model.train()
from tqdm.notebook import tqdm
for epoch in tqdm(range(1)):
    for images,targets in tqdm(dataloader_train):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        model=model.double()
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        losses.backward()

        optimizer.zero_grad()
        optimizer.step()
        
    print("Loss = {:.4f} ".format(losses.item()))

torch.save(model.state_dict(), './model.pth')

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/625 [00:00<?, ?it/s]

AssertionError: ignored