In [238]:
# Importing libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

from PIL import Image

import os
import glob
import time

In [239]:
# Device : cuda:0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print(device)

In [240]:
# Dataset processing
class CropDataloader(Dataset):

    def __init__(self, image_directory, mask_directory, transform = None ):

        super().__init__()

        # Take .jpg as original and .tif as label images

        self.images = sorted(glob.glob(os.path.join(image_directory, "*.jpg")))

        self.masks = sorted(glob.glob(os.path.join(mask_directory, "*.tif")))

        self.transform = transform


    def __len__(self):
        return len(self.images)
    

    def __getitem__(self, index):

        original_image = Image.open(self.images[index]).convert('L')

        mask_image = Image.open(self.masks[index])

        # Transform into the tensors

        tensor_image = self.transform(original_image)
        tensor_mask = self.transform(mask_image)

        return tensor_image, tensor_mask



# Image path
image_path = "/local/data/sdahal_p/Crop/data/train/original/"

mask_image_path = "/local/data/sdahal_p/Crop/data/train/mask/"

transform = transforms.Compose([
    # transforms.Resize((128, 128)),
    transforms.ToTensor()
])

# Dataloader for images

dataset = CropDataloader(image_path, mask_image_path, transform = transform)
dataloader = DataLoader(dataset, batch_size = 1, shuffle = True)


In [241]:
# Model

# Patch embedding

class PatchEmbedding(nn.Module):

    def __init__(self,model_dim, patch_size, number_of_patches, in_channels):
        super().__init__()

        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels = in_channels,
                out_channels = model_dim,
                kernel_size = patch_size,
                stride = patch_size,
            ),                  
            nn.Flatten(2))
        
        self.batch_size = 1
        
        self.cls_token = nn.Parameter(torch.randn(size=(self.batch_size, in_channels, model_dim)), requires_grad=True)

        # self.position_embeddings = nn.Parameter(torch.randn(size=(self.batch_size, number_of_patches + 1, model_dim)), requires_grad=True)

        self.position_embeddings = nn.Parameter(torch.randn(size=(self.batch_size, number_of_patches, model_dim)), requires_grad=True)

        self.dropout = nn.Dropout(p=0.05)

    def forward(self, x):

        # cls token

        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        x = self.patcher(x)

        x = x.permute(0, 2, 1)
        
        # x = torch.cat([cls_token, x], dim=1)
        
        x = self.position_embeddings + x 
        
        x = self.dropout(x)

        return x
    





In [242]:
class CropFormer(nn.Module):

    def __init__(self, model_dim, patch_size, number_of_patches, in_channels, encoders, num_heads, num_classes, original_img_size):
        super().__init__()

        self.embeddings_block = PatchEmbedding(model_dim, patch_size, number_of_patches, in_channels)

        encoder_layer= nn.TransformerEncoderLayer(d_model = model_dim, nhead = num_heads)

        self.encoder_blocks = torch.nn.TransformerEncoder(encoder_layer , num_layers = encoders)

        self.head = nn.Linear(model_dim, patch_size * patch_size)


        # self.linear_classifier = torch.nn.Linear(in_features=model_dim, out_features=num_classes)


    def forward(self, x):

        batch_size = x.shape[0]

        x =  self.embeddings_block(x)

        # print(x.size())

        x = self.encoder_blocks(x)

        # x = self.linear_classifier(x[:, 0, :])

        x = self.head(x)

        return x.view(batch_size, in_channels, original_img_size, original_img_size)


        return x


In [252]:

model_dim = 512
patch_size = 32
number_of_patches = 256
in_channels = 1

encoders = 6

num_heads = 8

num_classes = 4

original_img_size = 512

crop_model = CropFormer(model_dim, patch_size, number_of_patches, in_channels, encoders, num_heads, num_classes, original_img_size)

criterion = nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(crop_model.parameters(), lr=0.0001)


In [253]:

# Model, loss, and optimizer

# Training loop

crop_model.train(True)


num_epochs = 10

for epoch in range(num_epochs):

    for inputs, targets in dataloader:

        optimizer.zero_grad()

        # print(inputs.size())

        outputs = crop_model(inputs)

        # print(outputs.size())

        # print(targets.size())

        loss = criterion(outputs, targets)
        loss.backward()

        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

print("Training complete!")


Epoch [1/10], Loss: 0.8123
Epoch [2/10], Loss: 0.4967
Epoch [3/10], Loss: 0.6087
Epoch [4/10], Loss: 0.4717
Epoch [5/10], Loss: 0.5204
Epoch [6/10], Loss: 0.5018
Epoch [7/10], Loss: 0.3711
Epoch [8/10], Loss: 0.5843
Epoch [9/10], Loss: 0.4838
Epoch [10/10], Loss: 0.2800
Training complete!
