In [26]:
import wandb
from pathlib import Path
import os
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import torch
from torch import nn
from datasets import load_dataset
from transformers import MaskFormerImageProcessor, MaskFormerForInstanceSegmentation
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchinfo import summary
import torchvision.models as models
from tqdm.auto import tqdm
import torchmetrics
from torchmetrics import Accuracy

In [2]:
os.environ['WANDB_NOTEBOOK_NAME']='EDA.ipynb'
os.environ["KMP_DUPLICATE_LIB_OK"]='TRUE'

In [3]:
Logging_Place= 'LOCAL'

if Logging_Place == 'LOCAL':
    os.environ['WANDB_BASE_URL']="http://10.24.1.19:8080"
    os.environ['WANDB_API_KEY']='local-96dd72cbb60b0155149e2f9dc985a636ffa77b28'
    ! wandb login --host=http://10.24.1.19:8080

elif Logging_Place =="CLOUD":
    os.environ['WANDB_BASE_URL']="https://api.wandb.ai"
    os.environ['WANDB_API_KEY']='be69ec2b537cb972b0106fabd867f0dc68c4468a'
    ! wandb login --host=https://api.wandb.ai

^C


wandb: Network error (HTTPError), entering retry loop.
wandb: W&B API key is configured. Use `wandb login --relogin` to force relogin


In [3]:
data_path=Path('COD10K-v2')
train_path=data_path / 'Train'
test_path= data_path / 'Test'

train_images_path= train_path / 'Images/Image'
train_labels_path= train_path / 'GT_Objects/GT_Object'

test_images_path= test_path / 'Images/Image'
test_labels_path= test_path / 'GT_Objects/GT_Object'

In [4]:
train_images_list=list(train_images_path.glob('*.jpg'))
print(train_images_list[0:5])
train_labels_list=list(train_labels_path.glob('*.png'))
print(train_labels_list[0:5])

[WindowsPath('COD10K-v2/Train/Images/Image/COD10K-CAM-1-Aquatic-1-BatFish-1.jpg'), WindowsPath('COD10K-v2/Train/Images/Image/COD10K-CAM-1-Aquatic-1-BatFish-3.jpg'), WindowsPath('COD10K-v2/Train/Images/Image/COD10K-CAM-1-Aquatic-1-BatFish-7.jpg'), WindowsPath('COD10K-v2/Train/Images/Image/COD10K-CAM-1-Aquatic-1-BatFish-8.jpg'), WindowsPath('COD10K-v2/Train/Images/Image/COD10K-CAM-1-Aquatic-1-BatFish-9.jpg')]
[WindowsPath('COD10K-v2/Train/GT_Objects/GT_Object/COD10K-CAM-1-Aquatic-1-BatFish-1.png'), WindowsPath('COD10K-v2/Train/GT_Objects/GT_Object/COD10K-CAM-1-Aquatic-1-BatFish-3.png'), WindowsPath('COD10K-v2/Train/GT_Objects/GT_Object/COD10K-CAM-1-Aquatic-1-BatFish-7.png'), WindowsPath('COD10K-v2/Train/GT_Objects/GT_Object/COD10K-CAM-1-Aquatic-1-BatFish-8.png'), WindowsPath('COD10K-v2/Train/GT_Objects/GT_Object/COD10K-CAM-1-Aquatic-1-BatFish-9.png')]


In [5]:
super_classes_list=[ i.stem.split('-')[3] for i in train_images_list]
super_classes_list=[*set(super_classes_list)]
print(super_classes_list)


['Amphibian', 'Aquatic', 'Other', 'Terrestrial', 'Background', 'Flying', 'Terrestial']


In [6]:
index=5
sub_classes_list=[]
for i in train_images_list:
    split_path=i.stem.split('-')
    count=len(split_path)
    if (index<count):
        sub_classes_list.append(split_path[index])
sub_classes_list=[*set(sub_classes_list)]
print(sub_classes_list)

['Indoor', 'Crocodile', 'Sheep', 'StickInsect', 'Mantis', 'Reccoon', 'Owl', 'Sky', 'Tiger', 'Katydid', 'Human', 'GhostPipefish', 'Flounder', 'Deer', 'Dragonfly', 'Sciuridae', 'Spider', 'Frogmouth', 'Pagurian', 'Pipefish', 'ScorpionFish', 'Worm', 'Bee', 'Bittern', 'Fish', 'Ocean', 'Mockingbird', 'Giraffe', 'Vegetation', 'Toad', 'Other', 'Cat', 'FrogFish', 'Wolf', 'Monkey', 'Beetle', 'Grouse', 'Heron', 'CrocodileFish', 'Bug', 'Moth', 'StarFish', 'Turtle', 'Caterpillar', 'Frog', 'Bat', 'Butterfly', 'Crab', 'BatFish', 'Duck', 'ClownFish', 'Cicada', 'Lion', 'Lizard', 'Grasshopper', 'Slug', 'SeaHorse', 'Octopus', 'Cheetah', 'Chameleon', 'Stingaree', 'Centipede', 'Gecko', 'Leopard', 'Bird', 'Owlfly', 'Sand', 'Kangaroo', 'LeafySeaDragon', 'Ant', 'Snake', 'Dog', 'Shrimp', 'Rabbit']


#####    Setting up dataloaders

In [8]:
class CamouflageDataset(Dataset):
    def __init__(
        self,
        image_directory:str,
        gt_directory:str,
        transform=None
    ):
        self.image_directory=image_directory
        self.gt_directory=gt_directory
        self.transforms = transform
        self.image_list=list(self.image_directory.glob('*.jpg'))
        self.gt_list=list(self.gt_directory.glob('*.png'))

    def __getitem__(self, idx):
        img=Image.open(self.image_list[idx])
        mask=Image.open(self.gt_list[idx])
                                              
        self.transforms = transforms.Compose([  transforms.Resize((512,512)),
                                                transforms.ToTensor()

                                             ])
        img=self.transforms(img)
        mask=self.transforms(mask)     
        return img, mask

    def __len__(self):
        return len(self.image_list)

In [9]:
train_dataset=CamouflageDataset(image_directory=train_images_path,
                                gt_directory=train_labels_path
                                )
test_dataset=CamouflageDataset(image_directory=test_images_path,
                                gt_directory=test_labels_path
                               )

In [10]:
train_dataloader=DataLoader(train_dataset,batch_size=1, shuffle=True)
test_dataloader=DataLoader(test_dataset,batch_size=1, shuffle=False)

In [11]:
train_features, train_labels = next(iter(test_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

Feature batch shape: torch.Size([1, 3, 512, 512])
Labels batch shape: torch.Size([1, 1, 512, 512])


In [12]:
print(len(train_dataset))
print(len(test_dataset))

6000
4000


In [13]:
from transformers import SegformerFeatureExtractor
from transformers import SegformerForSemanticSegmentation

In [14]:
model = SegformerForSemanticSegmentation.from_pretrained("segformer-b3-finetuned-ade-512-512")
print(model)

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)

In [15]:


model.decode_head.classifier=nn.Conv2d(768, 16, kernel_size=(1, 1), stride=(1, 1))
print(model.decode_head)

SegformerDecodeHead(
  (linear_c): ModuleList(
    (0): SegformerMLP(
      (proj): Linear(in_features=64, out_features=768, bias=True)
    )
    (1): SegformerMLP(
      (proj): Linear(in_features=128, out_features=768, bias=True)
    )
    (2): SegformerMLP(
      (proj): Linear(in_features=320, out_features=768, bias=True)
    )
    (3): SegformerMLP(
      (proj): Linear(in_features=512, out_features=768, bias=True)
    )
  )
  (linear_fuse): Conv2d(3072, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (batch_norm): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activation): ReLU()
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Conv2d(768, 16, kernel_size=(1, 1), stride=(1, 1))
)


In [16]:
for params in model.segformer.parameters():
    params.requires_grad=False

In [17]:
summary(model=model, input_size=(1,3,512,512), col_names=["input_size" , "output_size", "trainable"], col_width=20, depth=5)

Layer (type:depth-idx)                                                      Input Shape          Output Shape         Trainable
SegformerForSemanticSegmentation                                            [1, 3, 512, 512]     [1, 16, 128, 128]    Partial
├─SegformerModel: 1-1                                                       [1, 3, 512, 512]     [1, 64, 128, 128]    False
│    └─SegformerEncoder: 2-1                                                [1, 3, 512, 512]     [1, 64, 128, 128]    False
│    │    └─ModuleList: 3-10                                                --                   --                   False
│    │    │    └─SegformerOverlapPatchEmbeddings: 4-1                       [1, 3, 512, 512]     [1, 16384, 64]       False
│    │    │    │    └─Conv2d: 5-1                                           [1, 3, 512, 512]     [1, 64, 128, 128]    False
│    │    │    │    └─LayerNorm: 5-2                                        [1, 16384, 64]       [1, 16384, 64]       False
│ 

In [18]:
output_logits=model(train_features)

In [19]:
output=output_logits.logits

#### Model train loop

In [20]:
optimizer=torch.optim.Adam(params=model.parameters(), lr=1e-5)
loss_function=nn.BCELoss()

In [31]:
def train_step(dataloader: DataLoader = train_dataloader, 
          optimizer: torch.optim.Adam = optimizer,
          loss_function: nn.BCELoss =loss_function
          
          ):
          train_loss=0
          train_accuracy=0
          model.train()

          for batch, (image, label_image)  in tqdm(enumerate(dataloader)):
            output=model(image)
            output_logits=output.logits
            output_logits=output_logits.view(1,1,512,512)
            output_logits=torch.sigmoid(output_logits) 
            #print(output_logits.shape)
            loss=loss_function(output_logits,label_image)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss=train_loss+loss
            print(batch)

            if (batch % 10 ==0):
              train_loss=train_loss / 10
              print(f" train loss : {train_loss}")

              train_loss=0


            

          

In [32]:
train_step()

0it [00:00, ?it/s]

0
 train loss : 0.05755419656634331
1
2
3
4
5
6
7
8
9
10
 train loss : 0.5789338946342468
11
12
13
14
15
16
17
18
19
20
 train loss : 0.5713996887207031


KeyboardInterrupt: 