In [None]:
# Each input image has 4 channels: red, green, blue and near-infrared (NIR).
################################################################################################
# Hyperparameters and dataset constants shared by the code below.
BATCH_SIZE = 8  # batch size passed to both DataLoaders in train()
LEARNING_RATE = 0.02  # learning rate passed to the SGD optimizer
EPOCHS = 5  # number of passes over the training split
CROP = 256  # image side length assumed by calculateMean
TEST_SIZE = 40668  # not referenced in the visible code; presumably the test-set image count -- TODO confirm
RESIZE = 256  # only referenced by a commented-out Resize transform
# Per-channel means handed to transforms.Normalize in PlanetDataset.mytransforms
# (presumably produced by calculateMean -- TODO confirm).
RMEAN = 0.6287
GMEAN = 0.5494
BMEAN = 0.3999
NIRMEAN = 0.0808
################################################################################################

import os 
import pandas as pd  # for lookup in annotation file
import torch
from torch.nn.utils.rnn import pad_sequence  # pad batch
from torch.utils.data import DataLoader, Dataset
from PIL import Image  # Load img
import torchvision.transforms as transforms
import torch.nn as nn
import statistics
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import random

class PlanetDataset(Dataset):
    """Multi-label image dataset driven by an annotation CSV.

    Each CSV row supplies an ``image_name`` and a space-separated ``tags``
    string. Items are returned as ``(image_tensor, multi_hot_labels)`` where
    the label vector holds one float32 entry per known tag.

    Args:
        csvfile: Path to the annotation CSV (columns ``image_name``, ``tags``).
        imgdir: Directory that holds the image files.
        transform: Optional callable applied to the opened PIL image instead
            of the built-in ``mytransforms`` pipeline. It must return a
            tensor. ``None`` (the default) selects ``mytransforms``.
        extension: Image file extension, without the leading dot.
        max_rows: Maximum number of CSV rows to use, or ``None`` for all of
            them. Defaults to 4000, the limit that used to be hard-coded.
    """

    def __init__(self, csvfile, imgdir, transform=None, extension='jpg', max_rows=4000):
        self.csvfile = csvfile
        self.imgdir = imgdir
        self.extension = extension

        # nrows=None makes pandas read the whole file.
        self.df = pd.read_csv(csvfile, nrows=max_rows)

        self.transform = transform
        self.images = self.df['image_name']
        self.tags = self.df['tags']
        # Spelling of the keys must match the tag strings found in the CSV.
        self.tags_to_index = {
            'agriculture': 0, 'artisinal_mine': 1, 'bare_ground': 2, 'blooming': 3,
            'blow_down': 4, 'clear': 5, 'cloudy': 6, 'conventional_mine': 7, 'cultivation': 8,
            'habitation': 9, 'haze': 10, 'partly_cloudy': 11, 'primary': 12, 'road': 13,
            'selective_logging': 14, 'slash_burn': 15, 'water': 16,
        }
        self.index_to_tags = {index: tag for tag, index in self.tags_to_index.items()}

    def __len__(self):
        """Return the number of annotation rows in use."""
        return len(self.df)

    def mytransforms(self, img):
        """Apply the default augmentation and normalisation pipeline.

        Args:
            img: PIL image as returned by ``Image.open``.

        Returns:
            The augmented image as a tensor, normalised with the module-level
            per-channel means and a standard deviation of 0.5 per channel.
        """
        # Disabled experiments, kept for reference:
        #img = transforms.Resize((RESIZE, RESIZE))(img)
        #if random.random() > 0.5:
        #    img = TF.adjust_gamma(img, 1, gain=1)
        #if random.random() > 0.5:
        #    img = TF.adjust_saturation(img, 2)
        img = transforms.RandomRotation(90)(img)
        img = transforms.RandomRotation(180)(img)
        img = transforms.RandomHorizontalFlip()(img)
        img = transforms.RandomVerticalFlip()(img)
        #img = transforms.RandomCrop(CROP)(img)
        # NOTE(review): ColorJitter is built with default arguments, which
        # looks like a no-op -- confirm against the torchvision docs.
        img = transforms.ColorJitter()(img)
        #img = transforms.GaussianBlur(5)(img)
        img = transforms.ToTensor()(img)
        img = transforms.Normalize((RMEAN, GMEAN, BMEAN, NIRMEAN), (0.5, 0.5, 0.5, 0.5))(img)

        return img

    def _encode_tags(self, alltags):
        """Return the float32 multi-hot vector for a space-separated tag string.

        Raises:
            KeyError: If a tag is not present in ``tags_to_index``.
        """
        # Building the tensor directly pins the dtype to float32 even when no
        # tag is set; a Python list of ints would otherwise become int64.
        labels = torch.zeros(len(self.tags_to_index), dtype=torch.float32)
        # split() without an argument tolerates repeated or trailing spaces.
        for tag in alltags.split():
            labels[self.tags_to_index[tag]] = 1.0
        return labels

    def __getitem__(self, index):
        """Load, transform and label the sample at position ``index``.

        Returns:
            Tuple ``(img, labels)``: the transformed image tensor and its
            float32 multi-hot label vector.
        """
        # .iloc makes the lookup positional regardless of the frame's index.
        image_name = self.images.iloc[index] + '.' + self.extension
        alltags = self.tags.iloc[index]

        # The context manager closes the underlying file once the pixel data
        # has been converted to a tensor by the transform.
        with Image.open(os.path.join(self.imgdir, image_name)) as pil_img:
            if self.transform is not None:
                img = self.transform(pil_img)
            else:
                img = self.mytransforms(pil_img)
        img = self.addAuxiliaryLayers(img)

        return img, self._encode_tags(alltags)

    def addAuxiliaryLayers(self, img):
        """Return ``img`` unchanged.

        Called from ``__getitem__`` as the hook for stacking derived index
        channels onto the image; every candidate below is currently disabled.
        """
        # With red, green, blue, nir = img[0], img[1], img[2], img[3]:
        #img = torch.vstack((img, ((nir - red) / (nir+red)).unsqueeze(0) ))
        #img = torch.vstack(( img, ( 2.5 * ((nir-red)/(nir+6*red-7.5*blue+1)) ).unsqueeze(0) ))
        #img = torch.vstack(( img, ( (nir-red)/(nir+red+0.5) * 1.5 ).unsqueeze(0) ))
        #img = torch.vstack(( img, ( (2*nir+1-torch.sqrt((2*nir+1)**2 - 8*(nir-red)))/2 ).unsqueeze(0) ))
        #img = torch.vstack(( img, ( (green-nir)/(green+nir) ).unsqueeze(0) ))
        #img = torch.vstack(( img, ( nir/red ).unsqueeze(0) ))
        return img
    
    
# Disabled experiment: a small hand-written CNN kept as an unused string
# literal, so none of it executes. It is referenced only by the commented-out
# `model = GISModel()` line in train(). Note that conv1 takes 9 input
# channels even though the inline comment inside the literal says 3.
'''class GISModel(nn.Module):
    def __init__(self):
        super(GISModel, self).__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.drop = nn.Dropout(0.5)
        
        # 3 is input color channels for image
        self.conv1 = nn.Conv2d(9, 100, 5)
        self.conv2 = nn.Conv2d(100, 200, 5)
        self.fc1 = nn.Linear(61*61*200, 100)
        self.fc2 = nn.Linear(100, 200)
        self.fc3 = nn.Linear(200, 17)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 61*61*200)
        x = self.drop(F.relu(self.fc1(x)))
        x = F.relu(self.fc2(x))
        x = self.drop(F.relu(self.fc3(x)))
        
        return x'''
        

def showSamples(dataset, loader):
    """Plot the first five images produced by ``loader`` and print their tags.

    Args:
        dataset: Object exposing ``index_to_tags`` (label index -> tag name).
        loader: Iterable of ``(imgs, labels)`` batches.
    """
    to_pil = transforms.ToPILImage()
    shown = 0
    for imgs, labels in loader:
        for i in range(imgs.shape[0]):
            shown += 1
            plt.imshow(to_pil(imgs[i]).convert('RGB'))
            # Collect the names of every label position that is switched on.
            tag_names = [dataset.index_to_tags[j] for j in range(17) if labels[i][j] == 1.0]
            print(tag_names)
            plt.show()

            # Stop once the fifth sample has been displayed.
            if shown > 4:
                return

# Default probability cut-off above which a tag counts as predicted.
THRESHOLD = 0.3


def _binarize(logits, threshold):
    """Turn raw logits into a 0/1 tensor by thresholding their sigmoid."""
    probs = torch.sigmoid(logits)
    return (probs >= threshold).to(probs.dtype)


def getPredForAccuracy(originals, predicted, threshold=THRESHOLD):
    """Count element-wise agreement between thresholded predictions and labels.

    Args:
        originals: Tensor of 0/1 ground-truth labels.
        predicted: Tensor of raw logits, broadcastable against ``originals``.
        threshold: Probability at or above which a prediction counts as 1.

    Returns:
        Tuple ``(correct, total)``: a 0-dim tensor holding the number of
        matching entries, and the number of label entries as an int.
    """
    hard = _binarize(predicted, threshold)
    return (hard == originals).sum(), originals.numel()


def getPredForFBeta(predicted, originals, threshold=THRESHOLD):
    """Compute the F2 score of thresholded predictions over all entries.

    Note the argument order is ``(predicted, originals)``, the reverse of
    ``getPredForAccuracy``.

    Args:
        predicted: Tensor of raw logits.
        originals: Tensor of 0/1 ground-truth labels with the same shape.
        threshold: Probability at or above which a prediction counts as 1.

    Returns:
        The F2 score as a float. Returns 0.0 when the score is undefined
        (no true positives, false positives or false negatives at all)
        instead of raising ``ZeroDivisionError``.
    """
    pred_pos = _binarize(predicted, threshold) == 1
    actual_pos = originals == 1
    actual_neg = originals == 0

    tp = torch.sum(pred_pos & actual_pos).item()
    fp = torch.sum(pred_pos & actual_neg).item()
    fn = torch.sum(~pred_pos & actual_pos).item()

    # 5*p*r / (4*p + r) with p = tp/(tp+fp) and r = tp/(tp+fn) simplifies to
    # the expression below, which needs a single zero check.
    denominator = 5 * tp + 4 * fn + fp
    if denominator == 0:
        return 0.0
    return (5 * tp) / denominator

def calculateMean(loader):
    """Compute the mean pixel value of the first four channels over ``loader``.

    Args:
        loader: Sized iterable of ``(imgs, tags)`` batches, where ``imgs`` is
            indexed as (batch, channel, height, width).

    Returns:
        Tuple of four 0-dim tensors: the red, green, blue and NIR means, in
        that channel order. All four are NaN when ``loader`` yields nothing.
    """
    # Accumulate in float64 so that summing many pixels does not lose precision.
    channel_sums = torch.zeros(4, dtype=torch.float64)
    total_pixels = 0
    # Iterate the loader that was passed in, not a global one.
    for imgs, _ in tqdm(loader, total=len(loader), leave=False):
        channel_sums += imgs[:, :4].sum(dim=(0, 2, 3), dtype=torch.float64).cpu()
        # Use the real image size instead of assuming CROP x CROP.
        total_pixels += imgs.shape[0] * imgs.shape[2] * imgs.shape[3]

    means = channel_sums / total_pixels
    return means[0], means[1], means[2], means[3]


In [None]:
# Transform handed to PlanetDataset by train(); None selects its default pipeline.
transform = None


def train():
    """Fine-tune a 4-channel ResNet-50 on the tagged images and report metrics.

    The annotated data is split 80/20 into training and validation subsets.
    A pretrained ResNet-50 has all of its parameters frozen, after which its
    first convolution and final linear layer are replaced, so only those two
    new layers are trainable. After every epoch the validation accuracy and
    F2 score are printed together with the loss of the last training batch.

    Returns:
        Tuple ``(model, dataset)``: the trained model and the full
        (unsplit) ``PlanetDataset``.
    """
    dataset = PlanetDataset('../input/planets-dataset/planet/planet/train_classes.csv',
                            '../input/understanding-amazon-from-space-tar/train-tif-v2/train-tif-v2', transform,
                            extension='tif')

    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=8, shuffle=True, pin_memory=False)
    # Validation order does not affect the metrics, so no shuffling is needed.
    val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, num_workers=8, shuffle=False, pin_memory=False)

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = models.resnet50(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False

    # tif images have 4 channels; the layers created below are new, so they
    # are not covered by the freeze loop above and remain trainable.
    model.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    model.fc = nn.Linear(model.fc.in_features, 17)
    #model = GISModel()
    model.to(device)

    # Multi-label problem: one independent sigmoid/BCE term per tag.
    criterion = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(EPOCHS):
        model.train()
        for imgs, tags in tqdm(train_loader, total=len(train_loader), leave=False):
            imgs = imgs.to(device)
            tags = tags.to(device)

            outputs = model(imgs)
            loss = criterion(outputs, tags)

            optimizer.zero_grad()
            # backward() takes no argument for a scalar loss; passing the loss
            # itself would scale every gradient by the loss value.
            loss.backward()
            optimizer.step()

        # Evaluate in eval mode so batch-norm layers use their running stats.
        model.eval()
        correct_pred = 0
        total_pred = 0
        val_outputs = []
        val_tags = []
        with torch.no_grad():
            for imgs, tags in tqdm(val_loader, total=len(val_loader), leave=False):
                imgs = imgs.to(device)
                tags = tags.to(device)
                outputs = model(imgs)
                cpred, totpred = getPredForAccuracy(tags, outputs)
                correct_pred += int(cpred)
                total_pred += totpred

                # Collect per-batch results and concatenate once afterwards.
                val_outputs.append(outputs)
                val_tags.append(tags)

            if val_outputs:
                # Show the first sample of the last validation batch.
                sample_logits = model(imgs[0].unsqueeze(0))
                print('Single batch shape: ', imgs.shape)
                print('Predictions by model for single sample image: ', sample_logits)
                print('Predictions after applying Sigmoid: (We will threshold the probabilities and assign 1/0)', torch.sigmoid(sample_logits))
                print('Actual tags: ', tags[0])

        if val_outputs:
            accuracy = correct_pred / total_pred
            f2_score = getPredForFBeta(torch.cat(val_outputs), torch.cat(val_tags))
        else:
            # Empty validation split: nothing to score.
            accuracy = 0.0
            f2_score = 0.0

        print(f'Epoch: [{epoch+1} / {EPOCHS}] \t Loss: {loss.item():.4f} \t Accuracy (Val): {accuracy:.4f} \t F2 score (Val): {f2_score:.4f}')

    return model, dataset

# Run training when this cell executes; binds the trained model and the full
# dataset at module level.
model, dataset = train()
