In [1]:
from google.colab import drive

# Attach Google Drive to the Colab filesystem; the dataset path used later
# in this notebook lives underneath this mount point.
DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
!pip install pytorch-lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.7.1-py3-none-any.whl (701 kB)
[K     |████████████████████████████████| 701 kB 12.9 MB/s 
[?25hCollecting PyYAML>=5.4
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 68.9 MB/s 
[?25hCollecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2022.7.1-py3-none-any.whl (141 kB)
[K     |████████████████████████████████| 141 kB 57.6 MB/s 
[?25hCollecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.9.3-py3-none-any.whl (419 kB)
[K     |████████████████████████████████| 419 kB 54.8 MB/s 
Collecting tensorboard>=2.9.1
  Downloading tensorboard-2.10.0-py3-none-any.whl (5.9 MB)
[K     |████████████████████████████████| 5.9 MB 57.1 MB/s 
Collecting pyDeprecate>=0.3.1
  Downloading pyDep

In [3]:
import pytorch_lightning as pl
import pandas as pd
import cv2
import os 
import torch
from PIL import Image
from torch import nn
from torch.utils.data import Dataset ,DataLoader, random_split
import numpy as np
import torch
from sklearn.model_selection import train_test_split 
import torchvision
from torchvision import transforms, datasets, models
import matplotlib.pyplot as plt
import torchmetrics
from torchmetrics.functional import accuracy
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torchvision.utils import make_grid
import math
import torch.nn.functional as F
import warnings
warnings.filterwarnings('ignore')

In [4]:
class WheatDataset(Dataset):
  """Wheat-head detection dataset yielding ``(image, targets)`` pairs.

  Annotations are read from ``<root>/train.csv``, which must provide an
  ``image_id`` column and a ``bbox`` column holding strings of the form
  ``"[x, y, w, h]"``. Images are read from ``<root>/<folder>``; the file name
  up to its first ``.`` is matched against ``image_id``.

  Args:
    root: Directory containing ``train.csv`` and the image folder.
    folder: Name of the image sub-directory inside ``root``.
    transforms: Optional callable applied to the loaded RGB PIL image. When
      the result is not a tensor (for example when no transform is given) the
      image is converted to a channels-first tensor scaled to [0, 1].

  Each item is a tuple ``(image, targets)`` where ``image`` is a float64
  tensor and ``targets`` is a dict with:
    * ``'boxes'``: float64 tensor of shape (N, 4) in
      ``[xmin, ymin, xmax, ymax]`` order;
    * ``'labels'``: int64 tensor of shape (N,), all ones (single class).
  Images without any annotation get N == 0.
  """

  def __init__(self, root, folder='train', transforms=None):
    self.transforms = []
    if transforms is not None:
      self.transforms.append(transforms)
    self.root = root
    self.folder = folder
    box_data = pd.read_csv(os.path.join(root, "train.csv"))
    # Strip the surrounding brackets and split "[x, y, w, h]" into four
    # string columns named 0..3, appended next to the original columns.
    coords = (box_data.bbox.str.split('[').str.get(1)
              .str.split(']').str.get(0)
              .str.split(',', expand=True))
    self.box_data = pd.concat([box_data, coords], axis=1)
    # os.listdir returns entries in arbitrary order; sorting keeps the
    # index -> image mapping stable so seeded splits are reproducible.
    self.imgs = sorted(os.listdir(os.path.join(root, self.folder)))

  def __len__(self):
    """Return the number of files found in the image folder."""
    return len(self.imgs)

  def __getitem__(self, idx):
    """Load image ``idx`` and build its detection targets."""
    file_name = self.imgs[idx]
    img = Image.open(os.path.join(self.root, self.folder, file_name)).convert("RGB")

    rows = self.box_data[self.box_data['image_id'] == file_name.split('.')[0]]
    if len(rows) != 0:
      # Work on a private copy so the cached annotation table is never touched.
      boxes = rows[[0, 1, 2, 3]].astype(float).to_numpy(copy=True)
      # CSV boxes are [x, y, w, h]; the detector expects
      # [xmin, ymin, xmax, ymax], so add the origin to the size columns.
      boxes[:, 2] += boxes[:, 0]
      boxes[:, 3] += boxes[:, 1]
    else:
      # Image without annotations: emit an empty target rather than a
      # zero-area box labelled as wheat, which detection models reject.
      boxes = np.zeros((0, 4), dtype=np.float64)

    for transform in self.transforms:
      img = transform(img)
    if not torch.is_tensor(img):
      # No tensor-producing transform was supplied: convert the HxWxC uint8
      # image to a CxHxW tensor scaled to [0, 1].
      img = torch.from_numpy(np.array(img)).permute(2, 0, 1) / 255

    targets = {
      'boxes': torch.from_numpy(boxes).double(),
      'labels': torch.ones(len(boxes), dtype=torch.int64),
    }

    return img.double(), targets
    

In [None]:
# Location of the wheat dataset on the mounted Drive.
root = "/content/drive/MyDrive/Datasets/global_wheat_dataset"

# Convert every loaded PIL image to a tensor before it is returned.
to_tensor = transforms.ToTensor()
dataset = WheatDataset(transforms=to_tensor, folder="train", root=root)

In [None]:
# Sanity check: fetch the first sample so the cell displays the
# (image, targets) pair produced by WheatDataset.__getitem__.
dataset[0]

In [None]:
# Seed the random number generators through Lightning's helper so the random
# permutation used for the train/test split below is repeatable.
pl.seed_everything(1)

In [None]:
# Number of samples, i.e. the number of files listed in the image folder.
len(dataset)

In [None]:
# Scratch check of slicing with a negative stop: arr[:-13] drops the last 13
# of the 15 elements, leaving [0, 1].
# NOTE(review): `arr` is not referenced by any other cell visible here — confirm
# whether this cell is still needed.
arr = np.arange(15)
arr[:-13]

In [None]:
def _collate_detection(samples):
  """Regroup [(image, targets), ...] into [(image, ...), (targets, ...)].

  Samples are kept as tuples instead of being stacked, since each one
  carries its own targets dict.
  """
  return list(zip(*samples))


# Shuffle all sample indices once, then cut the permutation at 2500:
# the first part is used for training, the remainder for testing.
indices = torch.randperm(len(dataset)).tolist()
train_indices, test_indices = indices[:2500], indices[2500:]

dataset_train = torch.utils.data.Subset(dataset, train_indices)
dataset_test = torch.utils.data.Subset(dataset, test_indices)

dataloader_train = DataLoader(
  dataset_train,
  batch_size=4,
  shuffle=True,
  collate_fn=_collate_detection,
)
dataloader_test = DataLoader(
  dataset_test,
  batch_size=4,
  shuffle=False,
  collate_fn=_collate_detection,
)

In [None]:
# Pull a single batch from the training loader. Its collate_fn regroups the
# samples, so `images` is a tuple of image tensors and `labels` a tuple of
# the matching target dicts.
images,labels=next(iter(dataloader_train))
from matplotlib import patches
def view(images,labels,k,std=1,mean=0):
    """Plot up to ``k`` images of a batch with their ground-truth boxes.

    Args:
        images: Sequence of channels-first (C, H, W) image tensors.
        labels: Sequence of target dicts; ``labels[i]['boxes']`` holds one
            ``[xmin, ymin, xmax, ymax]`` row per box.
        k: Number of samples to draw. Values larger than the batch are
            clamped to the batch size.
        std: Scale applied to the pixels before display (scalar or
            per-channel sequence).
        mean: Offset applied to the pixels before display (scalar or
            per-channel sequence).

    The displayed image is ``clip(std * image + mean, 0, 1)``. The inputs are
    never modified, so the same batch can be viewed repeatedly or reused.
    """
    images = list(images)
    labels = list(labels)
    count = min(k, len(images), len(labels))

    # Two columns; at least two rows so small batches keep the 2x2 layout,
    # with extra rows added when more than four samples are requested.
    n_cols = 2
    n_rows = max(2, (count + n_cols - 1) // n_cols)

    figure = plt.figure(figsize=(30, 30))
    for i in range(count):
        # (C, H, W) -> (H, W, C) for matplotlib, then undo the normalisation.
        pixels = images[i].detach().cpu().numpy().transpose((1, 2, 0))
        pixels = np.clip(np.array(std) * pixels + np.array(mean), 0, 1)

        ax = figure.add_subplot(n_rows, n_cols, i + 1)
        ax.imshow(pixels)

        # The array may share memory with the target tensor, so width and
        # height are derived per box instead of being written back into it.
        boxes = labels[i]['boxes'].detach().cpu().numpy()
        for xmin, ymin, xmax, ymax in boxes:
            ax.add_patch(patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                           linewidth=2, edgecolor='w', facecolor='none'))

# Draw four samples of the batch fetched above together with their boxes.
view(images,labels,4)

In [None]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

class LitModel(pl.LightningModule):
  """Lightning wrapper around torchvision's Faster R-CNN (ResNet-50 FPN).

  The detector is created with torchvision's default pretrained weights and
  its box predictor is replaced by a fresh head with two outputs
  (background + wheat).
  """

  def __init__(self):
    super().__init__()
    self.weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=self.weights)
    self.num_classes = 2 # 1 class (wheat) + background
    self.in_features = self.model.roi_heads.box_predictor.cls_score.in_features

    # Swap the pretrained classification head for one sized to our classes.
    self.model.roi_heads.box_predictor = FastRCNNPredictor(self.in_features, self.num_classes)

  def forward(self, x):
    """Run inference on a list of images and return the model's detections.

    The detector is switched to eval mode for the call and its previous
    train/eval mode is restored afterwards, so calling this in the middle of
    training does not leave the model stuck in eval mode.
    """
    was_training = self.model.training
    self.model.eval()
    try:
      return self.model(x)
    finally:
      self.model.train(was_training)

  def training_step(self, batch, batch_idx):
    """Compute the summed detection loss for one batch.

    Args:
      batch: ``(images, targets)`` as produced by the data loader: a sequence
        of image tensors and a sequence of target dicts.
      batch_idx: Index of the batch (unused).

    Returns:
      Scalar tensor holding the sum of all loss terms returned by the model.
    """
    images, targets = batch

    # Cast the batch to the model's parameter dtype instead of converting the
    # whole model on every step; the module's dtype then stays the same
    # before, during and after training.
    dtype = next(self.model.parameters()).dtype
    images = [img.to(dtype) for img in images]
    targets = [
      {key: (val.to(dtype) if torch.is_floating_point(val) else val)
       for key, val in target.items()}
      for target in targets
    ]

    loss_dict = self.model(images, targets)
    losses = sum(loss_dict.values())
    # The batch is a pair of tuples, so state the batch size explicitly
    # rather than letting Lightning guess it from the collection.
    self.log_dict(loss_dict, batch_size=len(images))
    return losses

  def configure_optimizers(self):
    """Plain SGD over all parameters with a fixed learning rate of 0.01."""
    return torch.optim.SGD(self.parameters(), lr=0.01)


In [None]:
# Build the detector and the Lightning trainer, then fit on the training loader.
model = LitModel()

# Use a single device when CUDA is available, otherwise leave the choice to Lightning.
n_devices = 1 if torch.cuda.is_available() else None

training_callbacks = [
    LearningRateMonitor(logging_interval="step"),
    TQDMProgressBar(refresh_rate=20),
]

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="auto",
    devices=n_devices,
    callbacks=training_callbacks,
)
trainer.fit(model, dataloader_train)