<a href="https://colab.research.google.com/github/rulas99/thermal_anomaly_detection/blob/main/thermal_anomalies_DL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

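### Semantic segmentation of thermal hotspots with DeepLabV3

Fine-tunes a pretrained DeepLabV3-ResNet50 to label every pixel as background (0) or hotspot (1). The training code is adapted from [this tutorial](https://towardsdatascience.com/train-neural-net-for-semantic-segmentation-with-pytorch-in-50-lines-of-code-830c71a6544f).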

In [1]:
import os
import numpy as np
import cv2
import torchvision.models.segmentation
import torch
import torchvision.transforms as tf

from pandas import read_csv

In [2]:
Learning_Rate = 1e-5  # Adam learning rate
width = 500           # network input width in pixels (images and masks are resized to this)
height = 500          # network input height in pixels
batchSize = 7         # number of randomly drawn samples per training batch

# Seed the RNGs used later: np.random drives batch sampling in ReadRandomImage,
# torch drives initialisation of the replaced classifier layer.
# NOTE(review): cuDNN kernels may still be non-deterministic on GPU.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Two output maps: class 0 = background, class 1 = positive pixels
# (ReadRandomImage sets the target to 1 where the label image is 255).
NUM_CLASSES = 2

# DeepLabV3 with a ResNet-50 backbone and pretrained weights.
# NOTE(review): `pretrained=` is deprecated in recent torchvision in favour of
# `weights=`; migrate once the pinned torchvision version is confirmed.
Net = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)

# Replace the final 1x1 conv of the main head so it predicts NUM_CLASSES maps.
Net.classifier[4] = torch.nn.Conv2d(256, NUM_CLASSES, kernel_size=(1, 1), stride=(1, 1))

Net = Net.to(device)

optimizer = torch.optim.Adam(params=Net.parameters(), lr=Learning_Rate)

# Pixel-wise cross entropy on raw logits; targets are integer class ids per pixel.
criterion = torch.nn.CrossEntropyLoss()

In [4]:
# Sample index: one row per training image. ReadRandomImage reads the image path
# from `ori_path` and the label-mask path from `label_path`.
DATA_DIR = "/content/drive/MyDrive/segment_sentinel2_hotspot"
train = read_csv(os.path.join(DATA_DIR, "train.csv"))

# Fail fast here instead of deep inside the training loop
# (np.random.randint(0, 0) raises on an empty table; a missing column raises AttributeError).
missing_cols = {"ori_path", "label_path"} - set(train.columns)
assert not missing_cols, f"train.csv is missing columns: {sorted(missing_cols)}"
assert len(train) > 0, "train.csv has no rows"

In [5]:
# Image pipeline: HxWxC uint8 ndarray -> PIL image -> resized to (height, width)
# -> float tensor scaled to [0, 1].
# NOTE(review): ImageNet normalisation
# tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) is NOT applied;
# confirm that skipping it is intended for the pretrained backbone.
transformImg = tf.Compose(
    [
        tf.ToPILImage(),
        tf.Resize((height, width)),
        tf.ToTensor(),
    ]
)

# Annotation pipeline: nearest-neighbour resizing so the 0/1 label values are
# never blended into fractional values at class boundaries.
transformAnn = tf.Compose(
    [
        tf.ToPILImage(),
        tf.Resize((height, width), interpolation=tf.InterpolationMode.NEAREST),
        tf.ToTensor(),
    ]
)

In [6]:
def ReadRandomImage(idx=None):
    """Load one (image, annotation) pair listed in the `train` table.

    Args:
        idx: Positional row index into `train`. Defaults to None, in which case
            a row is drawn uniformly at random with ``np.random.randint``.

    Returns:
        (Img, AnnMap): ``transformImg`` applied to the image at
        ``train.ori_path``, and ``transformAnn`` applied to a float32 mask that
        is 1 where the label image equals 255 and 0 elsewhere. A missing or
        unreadable label file yields an all-zero (background-only) mask.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image file.
        ValueError: if the label image's height/width differ from the image's.
    """
    if idx is None:
        idx = np.random.randint(0, len(train))  # select a random row
    img_path = train.ori_path.iloc[idx]
    label_path = train.label_path.iloc[idx]

    thermal = cv2.imread(img_path)
    # cv2.imread returns None instead of raising on a bad path / unreadable file;
    # without this check the failure surfaces later as an opaque AttributeError.
    if thermal is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    # NOTE(review): cv2 loads channels in BGR order and ToPILImage treats a
    # 3-channel array as RGB, so the network sees channel-swapped input. This is
    # presumably kept as-is for consistency with existing checkpoints; confirm.

    masked = cv2.imread(label_path, 0)  # flag 0: load as single-channel grayscale
    AnnMap = np.zeros(thermal.shape[0:2], np.float32)
    if masked is not None:
        # Boolean indexing below requires identical spatial sizes.
        if masked.shape != AnnMap.shape:
            raise ValueError(
                f"Label {label_path} has shape {masked.shape}, "
                f"expected {AnnMap.shape} to match {img_path}"
            )
        AnnMap[masked == 255] = 1

    Img = transformImg(thermal)
    AnnMap = transformAnn(AnnMap)
    return Img, AnnMap


def LoadBatch():
    """Assemble one training batch of `batchSize` randomly drawn samples.

    Returns:
        images: float tensor of shape (batchSize, 3, height, width).
        ann: float tensor of shape (batchSize, height, width) holding the 0/1
            per-pixel labels produced by ReadRandomImage.
    """
    batch_images = torch.zeros([batchSize, 3, height, width])
    batch_masks = torch.zeros([batchSize, height, width])
    for slot in range(batchSize):
        sample_img, sample_mask = ReadRandomImage()
        batch_images[slot] = sample_img
        # sample_mask carries a leading channel dim of size 1; assignment drops it
        batch_masks[slot] = sample_mask
    return batch_images, batch_masks

In [None]:
# Fine-tune `Net` on random batches, logging metrics and saving a checkpoint
# every LOG_EVERY steps and after the final step.
N_ITERATIONS = 200  # total optimizer steps
LOG_EVERY = 10      # log + checkpoint interval (in steps)
MODEL_DIR = "/content/drive/MyDrive/segment_sentinel2_hotspot/Models"
os.makedirs(MODEL_DIR, exist_ok=True)  # torch.save does not create missing directories

Net.train()  # explicit, in case an earlier cell left the model in eval mode
for itr in range(N_ITERATIONS):
    images, ann = LoadBatch()
    images = images.to(device)  # Variable wrappers are unnecessary since PyTorch 0.4
    ann = ann.to(device)

    Pred = Net(images)['out']  # per-class logits, one map per class
    optimizer.zero_grad()
    Loss = criterion(Pred, ann.long())  # pixel-wise cross entropy
    Loss.backward()
    optimizer.step()

    is_last = itr == N_ITERATIONS - 1
    # Previously the last steps after the final multiple of 10 were never saved.
    if itr % LOG_EVERY == 0 or is_last:
        with torch.no_grad():
            # Fraction of correctly classified pixels over the whole batch.
            # NOTE(review): if positive pixels are rare, this is dominated by the
            # background class; consider IoU/F1 of class 1 as well.
            pixel_acc = (Pred.argmax(dim=1) == ann).float().mean().item()
        accM = round(pixel_acc, 5)
        loss = round(Loss.item(), 5)
        print(itr, f") Loss = {loss} -- Accuracy = {accM}")
        ckpt_name = f"model_{itr}.torch"
        print(f"Saving {ckpt_name}")
        torch.save(Net.state_dict(), os.path.join(MODEL_DIR, ckpt_name))

0 ) Loss = 0.81939 -- Accuracy = 0.09765
Saving 0.torch
10 ) Loss = 0.74898 -- Accuracy = 0.31476
Saving 10.torch
20 ) Loss = 0.71071 -- Accuracy = 0.46811
Saving 20.torch
30 ) Loss = 0.65425 -- Accuracy = 0.75519
Saving 30.torch
40 ) Loss = 0.61035 -- Accuracy = 0.83979
Saving 40.torch
50 ) Loss = 0.58145 -- Accuracy = 0.95638
Saving 50.torch
60 ) Loss = 0.55059 -- Accuracy = 0.96035
Saving 60.torch
70 ) Loss = 0.51686 -- Accuracy = 0.96542
Saving 70.torch
80 ) Loss = 0.49547 -- Accuracy = 0.99684
Saving 80.torch
90 ) Loss = 0.48958 -- Accuracy = 0.98169
Saving 90.torch
100 ) Loss = 0.4418 -- Accuracy = 0.99707
Saving 100.torch
110 ) Loss = 0.4357 -- Accuracy = 0.99865
Saving 110.torch
120 ) Loss = 0.43033 -- Accuracy = 0.99862
Saving 120.torch
130 ) Loss = 0.41902 -- Accuracy = 0.99905
Saving 130.torch
140 ) Loss = 0.39977 -- Accuracy = 0.99672
Saving 140.torch
150 ) Loss = 0.39665 -- Accuracy = 0.99874
Saving 150.torch
160 ) Loss = 0.36793 -- Accuracy = 0.99811
Saving 160.torch
170 