In [None]:
#Look at images
from pathlib import Path
from PIL import Image
import os
import random
import matplotlib.pyplot as plt
from mpmath.identification import transforms
from torchvision.ops.misc import interpolate

# Resolve the dataset root (Kaggle mount first, local copy as fallback) and preview one
# randomly chosen training image. Defines train_dir / test_dir / image_path_feature_list
# for later cells (only when the root directory exists).
random.seed(42)
LOCALPATH = Path('data/1920x1080/1920x1080')
KAGGLEPATH = Path('/kaggle/input/fs2020-runway-dataset/1920x1080/1920x1080')

target_dir_feature = KAGGLEPATH if os.path.exists(KAGGLEPATH) else LOCALPATH

if not target_dir_feature.exists():
    print('Directory not found')
else:
    train_dir = target_dir_feature / 'train'
    test_dir = target_dir_feature / 'test'
    image_path_feature_list = list(train_dir.glob('*.png'))

    if not image_path_feature_list:
        print(f'No images found in  {train_dir}')
    else:
        random_image_path = random.choice(image_path_feature_list)
        image = Image.open(random_image_path)

        plt.figure(figsize=(15, 15))
        plt.imshow(image)
        plt.title(random_image_path)
        plt.axis('off')
        plt.show()

In [None]:
# Look at the label masks: resolve the mask root (Kaggle mount first, local copy as fallback)
# and display one randomly chosen training mask. Defines train_masks_dir / test_masks_dir /
# image_mask_list for later cells (only when the root directory exists).
LOACLPATHMASKS = Path('data/labels/labels/areas/')
KAGGLEPATHMASKS = Path('/kaggle/input/fs2020-runway-dataset/labels/labels/areas/')
random.seed(42)

if os.path.exists(KAGGLEPATHMASKS):
    target_dir_masks = KAGGLEPATHMASKS
else:
    target_dir_masks = LOACLPATHMASKS

if target_dir_masks.exists():
    train_masks_dir = target_dir_masks/'train_labels_1920x1080/'
    test_masks_dir = target_dir_masks/'test_labels_1920x1080/'

    image_mask_list = list(train_masks_dir.glob('*.png'))
    if len(image_mask_list)>0:
        random_image_mask_path = random.choice(image_mask_list)
        image_mask = Image.open(random_image_mask_path)

        plt.figure(figsize=(15,15))
        plt.imshow(image_mask)
        plt.title(random_image_mask_path)
        plt.axis('off')
        plt.show()

    else:
        # Report the directory that was actually searched (train_masks_dir is defined above).
        print(f'No images found in  {train_masks_dir}')
else:
    print('Directory not found')

In [None]:
import torch

In [None]:
# Let cuDNN benchmark candidate convolution algorithms and reuse the fastest one. This pays off
# here because every input is resized to the same fixed size (288x512) by image_transform.
torch.backends.cudnn.benchmark = True

In [None]:
# Device-agnostic setup: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
import torch
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
device

In [None]:
import os
import cv2
import numpy as np
# Source colour masks: reuse the mask root resolved in the mask-preview cell (target_dir_masks,
# Kaggle mount when present, local copy otherwise) so this cell also works outside Kaggle.
mask_dir = os.path.join(target_dir_masks, "train_labels_1920x1080")
test_mask_dir = os.path.join(target_dir_masks, "test_labels_1920x1080")

# Output folders for the 0/255 binary masks. On Kaggle only /kaggle/working is writable;
# elsewhere write to a "working" folder relative to the notebook.
_output_root = "/kaggle/working" if os.path.isdir("/kaggle/working") else "working"
binary_mask_dir = os.path.join(_output_root, "binary_mask")
binary_test_mask_dir = os.path.join(_output_root, "binary_test_mask")

def convert2Binary(mask_dir, binary_mask_dir):
    """Convert every colour label mask in `mask_dir` into a single-channel 0/255 PNG.

    A pixel is foreground (255) when its grayscale value is non-zero, background (0) otherwise.
    Outputs are written to `binary_mask_dir` under the input's base name with a ``.png``
    extension. PNG is lossless, so the mask stays strictly binary, and RunwayDataset only globs
    ``*.png`` masks. Non-image files are skipped. Files OpenCV cannot decode are reported and
    skipped instead of crashing.

    Args:
        mask_dir: folder containing the source label masks.
        binary_mask_dir: output folder (created if missing).

    Returns:
        int: number of binary masks written.
    """
    os.makedirs(binary_mask_dir, exist_ok=True)
    written = 0
    for filename in sorted(os.listdir(mask_dir)):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        img_path = os.path.join(mask_dir, filename)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None (it does not raise) for unreadable or corrupt files.
            print(f"Skipping unreadable mask: {img_path}")
            continue

        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        binary = np.where(gray > 0, 255, 0).astype(np.uint8)

        out_name = os.path.splitext(filename)[0] + ".png"
        out_path = os.path.join(binary_mask_dir, out_name)
        # cv2.imwrite also reports failure through its return value rather than an exception.
        if not cv2.imwrite(out_path, binary):
            print(f"Failed to write binary mask: {out_path}")
            continue
        written += 1
    return written

# Binarise the test-split label masks.
convert2Binary(mask_dir=test_mask_dir, binary_mask_dir=binary_test_mask_dir)


In [None]:
# Binarise the training-split label masks.
convert2Binary(mask_dir=mask_dir, binary_mask_dir=binary_mask_dir)

In [None]:
# Transforms for the images and for the masks.
from torchvision import transforms

_JITTER = 0.2  # +/- strength of the brightness / contrast / saturation jitter

# Images: resize to (height, width) = (288, 512), apply light colour augmentation, then
# convert to a tensor normalised with the ImageNet mean / std.
image_transform = transforms.Compose([
    transforms.Resize((288, 512)),
    transforms.ColorJitter(
        brightness=[max(0, 1 - _JITTER), 1 + _JITTER],
        contrast=[max(0, 1 - _JITTER), 1 + _JITTER],
        saturation=[max(0, 1 - _JITTER), 1 + _JITTER],
        hue=0.02,
    ),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Masks: nearest-neighbour resize so label values are never blended, then ToTensor (no
# normalisation, no colour jitter).
mask_transform = transforms.Compose([
    transforms.Resize((288, 512), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

In [None]:
# View transformed images next to the originals.
# Normalize() shifts values outside [0, 1]. Undo it before plotting, otherwise imshow clips the
# values and the preview shows distorted colours.
random.seed(42)

_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Guard against folders with fewer than 3 images (random.sample would raise ValueError).
random_image_path = random.sample(image_path_feature_list, min(3, len(image_path_feature_list)))

for image_path in random_image_path:
    with Image.open(image_path) as raw_img:
        # Drop alpha / palette so the 3-channel Normalize in image_transform applies.
        img = raw_img.convert('RGB')

    fig, ax = plt.subplots(1, 2, figsize=(15, 15))

    ax[0].imshow(img)
    ax[0].set_title('Original Image')
    ax[0].axis('off')

    transformed_img = (image_transform(img) * _std + _mean).clamp(0, 1).permute(1, 2, 0)

    ax[1].imshow(transformed_img)
    ax[1].set_title('Transformed Image')
    ax[1].axis('off')
    plt.show()

In [None]:
# View transformed label masks next to the originals.
random.seed(42)

# Guard against folders with fewer than 3 masks (random.sample would raise ValueError).
random_mask_path = random.sample(image_mask_list, min(3, len(image_mask_list)))

for image_path in random_mask_path:
    with Image.open(image_path) as img:
        fig, ax = plt.subplots(1, 2, figsize=(15, 15))

        ax[0].imshow(img)
        ax[0].set_title('Original Image')
        ax[0].axis('off')

        # ToTensor gives (C, H, W). After permuting, a single-channel mask would be (H, W, 1),
        # which imshow rejects, so drop that trailing singleton. RGB(A) masks are unaffected.
        transformed_img = mask_transform(img).permute(1, 2, 0).squeeze(-1)

        ax[1].imshow(transformed_img, cmap='gray')
        ax[1].set_title('Transformed Image')
        ax[1].axis('off')
    plt.show()


In [None]:
import os
from pathlib import Path
from torch.utils.data import Dataset
from PIL import Image

class RunwayDataset(Dataset):
    """Runway segmentation dataset pairing images with binary masks.

    Images (``*.jpg`` and ``*.png``) and masks (``*.png``) are each sorted by path and paired by
    position, so the i-th image is matched with the i-th mask. NOTE(review): this assumes both
    folders' file names sort into the same order; confirm for this dataset.

    Args:
        train_dir: folder with the input images.
        mask_dir: folder with the mask images.
        image_transform: optional callable applied to the RGB PIL image.
        mask_transform: optional callable applied to the single-channel ("L") PIL mask.
    """

    def __init__(self, train_dir, mask_dir, image_transform=None, mask_transform=None):
        self.train_dir = Path(train_dir)
        self.image_transform = image_transform
        self.mask_dir = Path(mask_dir)
        self.mask_transform = mask_transform

        self.image_paths = sorted(list(self.train_dir.glob("*.jpg")) + list(self.train_dir.glob("*.png")))
        self.mask_paths = sorted(self.mask_dir.glob("*.png"))

        if len(self.image_paths) != len(self.mask_paths):
            print(f"⚠️ Warning: Found {len(self.image_paths)} images but {len(self.mask_paths)} masks!")

    def __len__(self):
        # Only positions that have both an image and a mask are usable. Reporting more would let
        # a DataLoader request an index past the shorter list and fail with IndexError.
        return min(len(self.image_paths), len(self.mask_paths))

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for position `idx`, with the optional transforms applied."""
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # convert() loads the pixel data, so the files can be closed immediately. "RGB" also
        # drops any alpha / palette channel, which a 3-channel Normalize would reject.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        with Image.open(mask_path) as msk:
            mask = msk.convert("L")

        if self.image_transform:
            image = self.image_transform(image)

        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask

In [None]:
# Evaluation images get the same resize + normalisation as training, but no ColorJitter:
# random colour augmentation on the test split would make test loss / Dice vary run to run.
from torchvision import transforms

test_image_transform = transforms.Compose([
    transforms.Resize((288, 512)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

train_dataset = RunwayDataset(
    train_dir=train_dir,
    mask_dir=binary_mask_dir,
    image_transform=image_transform,
    mask_transform=mask_transform,
)

test_dataset = RunwayDataset(
    train_dir=test_dir,
    mask_dir=binary_test_mask_dir,
    image_transform=test_image_transform,
    mask_transform=mask_transform,
)

# Sanity check one sample by its shapes rather than dumping the full tensors.
sample_image, sample_mask = train_dataset[1]
sample_image.shape, sample_mask.shape

In [None]:
#Load Data
from torch.utils.data import DataLoader

# Training batches are shuffled. drop_last keeps every batch at the full size of 16.
train_data_load = DataLoader(
    train_dataset,
    batch_size=16,
    num_workers=4,
    shuffle=True,
    pin_memory=True,
    drop_last=True
)
# The test loader is NOT shuffled: the prediction-viewer cell pairs positions in a test batch
# with test_dataset.image_paths, and a fixed order keeps test loss, Dice and saved sample
# predictions comparable from epoch to epoch.
test_data_load = DataLoader(
    test_dataset,
    batch_size=16,
    num_workers=4,
    shuffle=False,
    pin_memory=True
)

In [None]:
from torch import nn
#DoubleConvLayer
class DoubleConv2d(nn.Module):
    """Two (3x3 Conv -> BatchNorm -> ReLU) stages. padding=1 keeps the spatial size unchanged.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, the second out_channels -> out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.double_conv(x)
        return out


In [None]:
#Downgrade resolution
class Down(nn.Module):
    """Encoder step: 2x2 max-pool (halves H and W), then DoubleConv2d(in_channels -> out_channels)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        convs = DoubleConv2d(in_channels, out_channels)
        # Attribute name and layer order are kept so state_dict keys stay "max_pool.1.double_conv...".
        self.max_pool = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.max_pool(x)
        return downsampled

In [None]:
#Upscale the feature map and merge it with the skip connection from the same encoder level
from torch.nn import functional as F
class Up(nn.Module):
    """Decoder step: 2x transposed-conv upsampling, pad to the skip tensor's size, concat, DoubleConv2d.

    Args:
        in_channels: channels of the deeper input; it is upsampled to in_channels // 2 channels.
        out_channels: channels produced by the DoubleConv2d after concatenation.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Upsampling halves the channels. After concatenation with the skip tensor, DoubleConv2d
        # sees in_channels channels (assumes the skip has in_channels // 2 channels, as in UNet).
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv2d(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        skip = x2

        # When an encoder size was odd, pooling dropped a row/column. Pad the upsampled map
        # (split as evenly as possible) so it matches the skip tensor exactly.
        dy = skip.size()[2] - upsampled.size()[2]
        dx = skip.size()[3] - upsampled.size()[3]
        left, top = dx // 2, dy // 2
        padding = [left, dx - left, top, dy - top]  # F.pad order: last dim first
        upsampled = F.pad(upsampled, padding)

        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)


In [None]:
#feature mapping
class Out(nn.Module):
    """Output head: a 1x1 convolution mapping feature channels to per-pixel logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits

In [None]:
#Main model class --brain
class UNet(nn.Module):
    """Four-level U-Net: DoubleConv stem, 4 Down steps, 4 Up steps with skip connections, 1x1 head.

    Args:
        n_channels: channels of the input image (3 for RGB).
        n_classes: channels of the output logits (1 for binary segmentation).
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder: each Down halves the resolution and doubles the channel width.
        self.inc = DoubleConv2d(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # Bottleneck.
        self.down4 = Down(512, 1024)

        # Decoder: each Up doubles the resolution and merges the matching encoder output.
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)

        self.outc = Out(64, n_classes)

    def forward(self, x):
        # Going down: remember every encoder output that precedes a downsampling step.
        skips = []
        features = self.inc(x)
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(features)
            features = down(features)

        # Going up: consume the saved outputs deepest-first.
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            features = up(features, skip)

        # Final prediction (raw logits, no activation).
        return self.outc(features)

In [None]:
# Build the U-Net for RGB input and one-channel (binary) logits, then move it to `device`.
modelu = UNet(n_channels=3, n_classes=1)
modelu = modelu.to(device)
modelu

In [None]:
# Measure the real foreground/background imbalance of the training masks and use it as the
# pos_weight of BCEWithLogitsLoss, so the rare runway pixels are not drowned out.
# Note: this iterates train_data_load, so images are decoded/augmented too, and the
# drop_last=True loader skips the final partial batch.
total_pixels = 0
runway_pixels = 0.0

for _, mask in train_data_load:
    total_pixels += mask.numel()
    # Binary masks are 0/255 on disk and ToTensor scales them to 0/1, so the sum counts
    # foreground pixels.
    runway_pixels += mask.sum().item()

if runway_pixels == 0:
    raise ValueError(
        "No runway (foreground) pixels found in the training masks; cannot compute pos_weight. "
        "Check the binary mask conversion and mask_dir."
    )

background_pixels = total_pixels - runway_pixels
weight = background_pixels / runway_pixels

print(f"Calculated weight: {weight:.2f}")
pos_weight = torch.tensor([weight]).to(device)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

In [None]:
# class DiceLoss(nn.Module):
#     def __init__(self):
#         super().__init__()
    
#     def forward(self, pred, target):
#         pred = torch.sigmoid(pred)
#         smooth = 1.0
        
#         pred_flat = pred.view(-1)
#         target_flat = target.view(-1)
        
#         intersection = (pred_flat * target_flat).sum()
        
#         return 1 - ((2. * intersection + smooth) /
#                     (pred_flat.sum() + target_flat.sum() + smooth))

# Use it:
#loss_fn = DiceLoss()

In [None]:
# Optimizer: Adam over every U-Net parameter with a small learning rate.
# (The loss, loss_fn, is defined in the class-imbalance cell.)
optim = torch.optim.Adam(params=modelu.parameters(), lr=3e-5)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed over the whole batch at once.

    loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth), where p = sigmoid(logits).

    Args:
        smooth: small constant avoiding 0/0 when predictions and targets are both empty.
    """

    def __init__(self, smooth=1e-8):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        probs = torch.sigmoid(preds)
        overlap = (probs * targets).sum()
        denominator = probs.sum() + targets.sum() + self.smooth
        score = (2 * overlap + self.smooth) / denominator
        loss = 1 - score
        return loss.mean()
        
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Write a checkpoint object (e.g. a dict of state_dicts) to `filename` with torch.save."""
    banner = "SAVING CHECK POINT......"
    print(banner)
    torch.save(state, filename)

def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    """Save thresholded (0.5) predictions and ground-truth masks for every batch of `loader`.

    Writes ``pred_{idx}.png`` and ``mask_{idx}.png`` (image grids of the batch) into `folder`.
    The model is switched to eval mode for inference and its previous mode is restored afterwards.
    """
    # torchvision is not imported at notebook level; import locally so this function works.
    from torchvision.utils import save_image

    os.makedirs(folder, exist_ok=True)
    was_training = model.training
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device)
            with torch.no_grad():
                preds = (torch.sigmoid(model(x)) > 0.5).float()
            # Masks from mask_transform already carry a channel dim (B, 1, H, W). Only add one
            # for (B, H, W) masks; unconditionally unsqueezing yields a 5-D tensor.
            if y.dim() == 3:
                y = y.unsqueeze(1)
            save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
            save_image(y.float(), os.path.join(folder, f"mask_{idx}.png"))
    finally:
        model.train(was_training)

def load_checkpoint(checkpoint, model):
    """Restore `model` weights from a checkpoint dict that stores them under "state_dict"."""
    print("......LOADING CHECKPOINT")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

def save_sample_predictions(loader, model, folder="saved_images/", device="cuda", num_images_to_save=10, threshold=0.5):
    """Save side-by-side [input | ground truth | prediction] PNGs for the first batch of `loader`.

    The input is de-normalised with the ImageNet mean/std used by the image transforms. Masks and
    predictions are shown as black/white. Files are named ``comparison_{i}.png``.

    Args:
        loader: yields ``(images, masks)`` batches.
        model: segmentation model returning one-channel logits.
        folder: output directory (created if missing).
        device: device the batch is moved to.
        num_images_to_save: maximum number of samples from the batch to save.
        threshold: probability above which a pixel counts as runway.
    """
    # torchvision is not imported at notebook level; import locally so this function works.
    from torchvision.utils import save_image

    os.makedirs(folder, exist_ok=True)
    batch = next(iter(loader), None)
    if batch is None:
        print("Loader is empty; no predictions saved.")
        return

    was_training = model.training
    model.eval()
    try:
        x, y = batch
        x = x.to(device)
        y = y.to(device).float()
        if y.dim() == 3:  # accept (B, H, W) masks as well as (B, 1, H, W)
            y = y.unsqueeze(1)

        with torch.no_grad():
            preds = (torch.sigmoid(model(x)) > threshold).float()

        num_to_save = min(num_images_to_save, x.shape[0])
        print(f"Saving {num_to_save} prediction comparisons using threshold={threshold:.2f}")

        # Undo transforms.Normalize so the input image is viewable in [0, 1].
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(device)

        for i in range(num_to_save):
            image = (x[i] * std + mean).clamp(0, 1)
            # Broadcast the one-channel mask / prediction to 3 channels to sit next to the RGB input.
            mask = y[i].expand_as(image)
            pred = preds[i].expand_as(image)
            comparison = torch.cat([image, mask, pred], dim=2)  # concatenate along the width
            save_image(comparison, os.path.join(folder, f"comparison_{i}.png"))
    finally:
        model.train(was_training)

def check_accuracy(loader, model, device="cuda", threshold=0.5):
    """Print and return the mean per-batch Dice score of thresholded predictions on `loader`.

    Args:
        loader: yields ``(images, masks)`` batches. Masks may be (B, 1, H, W) or (B, H, W).
        model: segmentation model returning one-channel logits (B, 1, H, W).
        device: device the batches are moved to.
        threshold: probability above which a pixel counts as runway.

    Returns:
        float: mean Dice over batches (NaN when the loader yields no batches).

    The model runs in eval mode and its previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    dice_total = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                preds = (torch.sigmoid(model(x)) > threshold).float()
                # Align the mask with the prediction's shape. Unsqueezing an already (B, 1, H, W)
                # mask would broadcast (B, 1, H, W) x (B, 1, 1, H, W) into a (B, B, ...) tensor and
                # sum intersections across different samples. reshape fails loudly on a mismatch.
                y = y.to(device).float().reshape(preds.shape)
                intersection = (preds * y).sum()
                dice = (2 * intersection + 1e-8) / (preds.sum() + y.sum() + 1e-8)
                dice_total += dice.item()
                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0:
        print("DICE SCORE: no batches in loader")
        return float("nan")
    mean_dice = dice_total / num_batches
    print("DICE SCORE:", mean_dice)
    return mean_dice

In [None]:
EPOCHS = 100
REPORT_EVERY = 10  # epochs between Dice checks / saved sample predictions
from tqdm import tqdm
from torch.amp import GradScaler, autocast

# Mixed precision is only used on CUDA. On a CPU-only runtime, autocast and the scaler are
# disabled explicitly instead of warning and switching themselves off.
use_amp = device == "cuda"
scaler = GradScaler("cuda", enabled=use_amp)
best_test_loss = float("inf")

for epoch in range(EPOCHS):
    # ---- training ----
    modelu.train()
    running_loss = 0.0

    loop = tqdm(train_data_load, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for input_img, mask in loop:
        input_img = input_img.to(device)
        # (B, 1, H, W) -> (B, H, W) to line up with the squeezed single-class logits below.
        mask = mask.to(device).squeeze(1).float()

        optim.zero_grad()

        with autocast("cuda", enabled=use_amp):
            predic = modelu(input_img)
            loss = loss_fn(predic.squeeze(1), mask)

        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

        running_loss += loss.item()
        loop.set_postfix(loss=f"{loss.item():.4f}")

    avg_train_loss = running_loss / len(train_data_load)

    # ---- evaluation ----
    test_loss = 0.0
    modelu.eval()
    with torch.no_grad():
        for test_img, test_mask in test_data_load:
            test_img = test_img.to(device)
            test_mask = test_mask.to(device).squeeze(1).float()

            predic_test = modelu(test_img)
            tloss = loss_fn(predic_test.squeeze(1), test_mask)
            test_loss += tloss.item()

    avg_test_loss = test_loss / len(test_data_load)

    # Losses are printed after the first epoch (early sanity check) and on the same epochs as
    # validation, so each Dice report sits next to its losses.
    is_report_epoch = ((epoch + 1) % REPORT_EVERY == 0) or ((epoch + 1) == EPOCHS)
    if epoch == 0 or is_report_epoch:
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f} |")

    if is_report_epoch:
        print(f"\n--- Running Validation for Epoch {epoch+1} ---\n")
        check_accuracy(test_data_load, modelu, device)
        save_sample_predictions(test_data_load, modelu, folder="saved_images", device=device, threshold=0.65)
        print("\n---------PREDICTIONS SAVED---------\n")

    # Checkpoint whenever the test loss improves, so the best weights seen so far are on disk.
    # A single checkpoint at a fixed epoch would discard everything learned after it.
    if avg_test_loss < best_test_loss:
        best_test_loss = avg_test_loss
        checkpoint = {
            "state_dict": modelu.state_dict(),
            "optimizer": optim.state_dict(),
            "epoch": epoch + 1,
            "test_loss": avg_test_loss,
        }
        save_checkpoint(checkpoint)


In [None]:
# Show the ORIGINAL RGB image next to its ground-truth mask and the predicted mask.
from PIL import Image

SAMPLE_IDX = 10  # index into test_dataset of the sample to display

# Read the sample straight from the dataset instead of from a loader batch. The image path, the
# model input and the ground-truth mask are then guaranteed to be the same sample, whatever the
# loader's shuffle setting. Clamp the index for test sets with fewer than SAMPLE_IDX + 1 items.
sample_idx = min(SAMPLE_IDX, len(test_dataset) - 1)

modelu.eval()
with torch.no_grad():
    test_img, test_mask = test_dataset[sample_idx]

    # Add a batch dimension for the model, then drop batch/channel dims for plotting.
    pred = torch.sigmoid(modelu(test_img.unsqueeze(0).to(device))).squeeze().cpu()

    # Ground-truth mask as a 2-D array.
    mask = test_mask.squeeze()

# Load the untransformed RGB image and resize it to the model's (width, height) = (512, 288).
img_path = test_dataset.image_paths[sample_idx]
with Image.open(img_path) as raw_img:
    original_img = raw_img.convert("RGB").resize((512, 288))

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(original_img)
axes[0].set_title('Original RGB Image')
axes[0].axis('off')

axes[1].imshow(mask, cmap='gray')
axes[1].set_title('Ground Truth Mask')
axes[1].axis('off')

axes[2].imshow(pred, cmap='gray')
axes[2].set_title('Predicted Mask')
axes[2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Free GPU memory held by a previous training run.
# Warning: this deletes `modelu`, `optim` and `scaler`, so re-run the model, optimizer and
# training cells afterwards.
import torch
import gc

# Remove each name independently. With one try/del chain, a single missing name would skip
# deleting the others, and a bare except would also hide unrelated errors.
for _name in ("modelu", "optim", "scaler"):
    globals().pop(_name, None)

# Collect unreachable Python objects first so their CUDA tensors are actually freed, then
# return PyTorch's cached blocks to the driver.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU Memory allocated: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB")
    print(f"GPU Memory cached: {torch.cuda.memory_reserved(0)/1024**3:.2f} GB")
else:
    print("CUDA not available; no GPU memory to release.")