In [1]:
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import requests
import os
import numpy as np

import torch
from torchvision import models
from torch.nn import CrossEntropyLoss
from torch.nn.functional import softmax
from torch.optim import Adam,lr_scheduler
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import Dataset

In [2]:
Deeplabv3plus_path = r'E:\codes\python\area51m\seg_halfprecision\plant-segmentation'
Data_path = r'E:\codes\python\area51m\seg_halfprecision\plant-segmentation'
CroppedTrainingLabelNames = 'CroppedTrainingLabelNames.csv'
CroppedTestLabelNames = 'CroppedTestLabelNames.csv'
CroppedTraining_input_folder = 'cropped_input'
CroppedTraining_label_folder = 'cropped_label'
CroppedTest_input_folder = 'cropped_input_test'
CroppedTest_label_folder = 'cropped_label_test'
ModelPath = r'E:\data\MODELS' 
TwoStageModelName = 'twostage_model'

In [3]:
from utils import *
import sys
import os

device = torch.device("cuda:0")

sys.path.append(Deeplabv3plus_path)
sys.path.append(Deeplabv3plus_path+r'\pytorch_deeplab_xception')
from pytorch_deeplab_xception.modeling import deeplab

os.environ['TORCH_HOME'] = ModelPath

In [4]:
num_c = 2

dl = deeplab.DeepLab(num_classes=num_c,backbone = 'resnet').to(device)

for name, param in dl.named_parameters():                
    if name.startswith('backbone'):
        param.requires_grad = False

In [5]:
image_size = (200,200)

transform = transforms.Compose([transforms.Resize(image_size),
                                transforms.ColorJitter(brightness = 0.15,
                                                       saturation = 0.1,
                                                       hue = 0.01,
                                                       contrast = 0.15),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])

In [6]:
data = pd.read_csv(r'{0}\{1}'.format(Data_path,CroppedTrainingLabelNames))

dataset_train = Segdata(transform,image_size,
                        CroppedTrainingLabelNames,Data_path,[CroppedTraining_input_folder,CroppedTraining_label_folder])

dataset_test = Segdata(transform,image_size,
                       CroppedTestLabelNames,Data_path,[CroppedTest_input_folder,CroppedTest_label_folder])

dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=6, 
                                         shuffle=True, num_workers=0,drop_last  = True)

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, num_workers=0,drop_last  = False)

In [7]:
plist = [
         {'params': dl.decoder.parameters(),'lr':1e-2}
         ]

freq = 1/get_frequency(data,image_size,num_c,Data_path,CroppedTraining_label_folder)

criterion = CrossEntropyLoss(weight=freq)

optimizer = Adam(params=plist)

scheduler = lr_scheduler.StepLR(optimizer, step_size=33,gamma=0.1)

train_loss_list = []
test_loss_list = []

best_val_loss = float('inf')

n_epoch = 50

In [None]:
for epoch in range(n_epoch):
    training_loss,training_iou = train(dl,dataloader_train,optimizer,criterion,train_loss_list,num_c)
    val_loss,val_iou = evaluate(dl,dataloader_test,criterion,test_loss_list,num_c)
    scheduler.step()
    print('finished {0}th training'.format(epoch+1))
    print('training loss: {0}, training iou: {1}'.format(training_loss.item(),training_iou))
    print('validation loss: {0}, validation iou: {1}'.format(val_loss.item(),val_iou))
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(dl.state_dict(), r'{0}\{1}_best.pt'.format(ModelPath,TwoStageModelName))
torch.save(dl.state_dict(), r'{0}\{1}_last.pt'.format(ModelPath,TwoStageModelName))

finished 1th training
training loss: 0.5442003027118486, training iou: 0.46611311668062305
validation loss: 0.4940067201231917, validation iou: 0.43870349676925646
