In [None]:
!pip install bpemb timm

Collecting bpemb
  Downloading bpemb-0.3.6-py3-none-any.whl.metadata (19 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->timm)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->timm)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->timm)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl

In [None]:
!unzip /content/drive/MyDrive/Pneumothorax.zip -d /content/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/Pneumothorax/masks/5912_train_0_.png  
  inflating: /content/Pneumothorax/masks/5913_train_0_.png  
  inflating: /content/Pneumothorax/masks/5914_train_0_.png  
  inflating: /content/Pneumothorax/masks/5915_train_0_.png  
  inflating: /content/Pneumothorax/masks/5916_train_0_.png  
  inflating: /content/Pneumothorax/masks/5917_train_0_.png  
  inflating: /content/Pneumothorax/masks/5918_train_0_.png  
  inflating: /content/Pneumothorax/masks/5919_train_0_.png  
  inflating: /content/Pneumothorax/masks/591_test_0_.png  
  inflating: /content/Pneumothorax/masks/591_train_1_.png  
  inflating: /content/Pneumothorax/masks/5920_train_0_.png  
  inflating: /content/Pneumothorax/masks/5921_train_1_.png  
  inflating: /content/Pneumothorax/masks/5922_train_0_.png  
  inflating: /content/Pneumothorax/masks/5923_train_0_.png  
  inflating: /content/Pneumothorax/masks/5924_train_0_.png  
  inflating: /content/P

In [None]:
import os
import random
import time
import datetime
import numpy as np
import albumentations as A
import cv2
from glob import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from text2embed import Text2Embed
from utils import seeding, create_dir, print_and_save, shuffling, epoch_time, calculate_metrics, mask_to_bbox
from metrics import DiceLoss, DiceBCELoss, MultiClassBCE, FocalLoss
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import timm
import torch.nn.functional as F
import warnings
import torch.cuda.amp as amp
import torch.nn.init as init

class conv2d(nn.Module):
    """Bias-free 2-D convolution followed by batch norm and an optional ReLU.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        kernel_size: convolution kernel size (forwarded to ``nn.Conv2d``).
        padding: convolution padding (forwarded to ``nn.Conv2d``).
        dilation: convolution dilation (forwarded to ``nn.Conv2d``).
        act: when equal to ``True`` the ReLU is applied after batch norm.
    """

    def __init__(self, in_c, out_c, kernel_size=3, padding=1, dilation=1, act=True):
        super().__init__()
        self.act = act

        convolution = nn.Conv2d(
            in_c,
            out_c,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias=False,  # the batch norm that follows supplies the shift term
        )
        self.conv = nn.Sequential(convolution, nn.BatchNorm2d(out_c))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return conv -> batch-norm (-> ReLU) applied to ``x``."""
        out = self.conv(x)
        # Equality (not truthiness) is kept on purpose so that non-bool flags
        # behave exactly as they always have.
        return self.relu(out) if self.act == True else out

class channel_attention(nn.Module):
    """Channel gate driven by globally average- and max-pooled descriptors.

    Both pooled descriptors pass through the same two-layer 1x1-conv
    bottleneck; their sum is squashed with a sigmoid and multiplied onto the
    input feature map.

    Args:
        in_planes: number of channels of the incoming feature map.
        ratio: bottleneck reduction factor. The hidden width is
            ``in_planes // ratio``, clamped to at least one channel.

    Raises:
        ValueError: if ``ratio`` is not positive.
    """

    def __init__(self, in_planes, ratio=16):
        super(channel_attention, self).__init__()
        if ratio <= 0:
            raise ValueError(f"ratio must be positive, got {ratio}")

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # The reduction used to be hard-coded to 16, silently ignoring `ratio`.
        # The clamp avoids a zero-width bottleneck when in_planes < ratio.
        hidden = max(1, in_planes // ratio)
        self.fc1   = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(hidden, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by the learned attention weights."""
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return x * self.sigmoid(avg_out + max_out)


class spatial_attention(nn.Module):
    """Spatial gate computed from the channel-wise mean and max of the input.

    Args:
        kernel_size: size of the gating convolution; must be 3 or 7.
    """

    def __init__(self, kernel_size=7):
        super(spatial_attention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "same" padding for the two supported kernel sizes
        padding = 1 if kernel_size != 7 else 3

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by a single-channel sigmoid attention map."""
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        channel_max = torch.max(x, dim=1, keepdim=True)[0]
        descriptor = torch.cat([channel_mean, channel_max], dim=1)
        gate = self.sigmoid(self.conv1(descriptor))
        return x * gate

class dilated_conv(nn.Module):
    """Four parallel conv branches (1x1 plus dilations 6/12/18) with attention.

    Each branch ends in a channel-attention gate. The concatenated branches are
    fused by a 3x3 conv, added to a 1x1 projection of the input, passed through
    ReLU and finally gated by spatial attention.

    Args:
        in_c: number of input channels.
        out_c: number of output channels of every branch and of the block.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        def dilated_branch(rate):
            # padding == dilation keeps the spatial size for a 3x3 kernel
            return nn.Sequential(
                conv2d(in_c, out_c, kernel_size=(3, 3), padding=rate, dilation=rate),
                channel_attention(out_c),
            )

        self.c1 = nn.Sequential(
            conv2d(in_c, out_c, kernel_size=1, padding=0),
            channel_attention(out_c),
        )
        self.c2 = dilated_branch(6)
        self.c3 = dilated_branch(12)
        self.c4 = dilated_branch(18)
        self.c5 = conv2d(out_c*4, out_c, kernel_size=3, padding=1, act=False)  # branch fusion
        self.c6 = conv2d(in_c, out_c, kernel_size=1, padding=0, act=False)     # residual projection
        self.sa = spatial_attention()

    def forward(self, x):
        """Return the attended multi-dilation features for ``x``."""
        branch_outputs = [branch(x) for branch in (self.c1, self.c2, self.c3, self.c4)]
        fused = self.c5(torch.cat(branch_outputs, dim=1))
        shortcut = self.c6(x)
        return self.sa(self.relu(fused + shortcut))

class label_attention(nn.Module):
    """Gate feature-map channels with a vector derived from a label embedding.

    Args:
        in_c: two-element sequence ``[feature_channels, label_channels]``.
    """

    def __init__(self, in_c):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        feat_channels, label_channels = in_c[0], in_c[1]
        # Channel attention: two bias-free 1x1 convs mapping the label vector
        # onto one value per feature channel.
        self.c1 = nn.Sequential(
            nn.Conv2d(label_channels, feat_channels, kernel_size=1, padding=0, bias=False),
            nn.ReLU(),
            nn.Conv2d(feat_channels, feat_channels, kernel_size=1, padding=0, bias=False)
        )

    def forward(self, feats, label):
        """Return ``(channel_logits, gated_feats)``.

        ``label`` is a 2-D tensor ``(batch, label_channels)``. The logits are
        returned un-squashed with shape ``(batch, feature_channels)``; the
        features are multiplied by the sigmoid of those logits.
        """
        batch, width = label.shape
        logits = self.c1(label.reshape(batch, width, 1, 1))
        gated = feats * torch.sigmoid(logits)

        flat_logits = logits.reshape(logits.shape[0], logits.shape[1])
        return flat_logits, gated

class decoder_block(nn.Module):
    """Upsample, merge with a skip connection and refine with dense residuals.

    Args:
        in_c: channels of the tensor being upsampled.
        out_c: channels of the skip tensor and of the block output.
        scale: bilinear upsampling factor applied before the merge.
    """

    def __init__(self, in_c, out_c, scale=2):
        super().__init__()
        self.scale = scale
        self.relu = nn.ReLU(inplace=True)

        self.up = nn.Upsample(scale_factor=scale, mode="bilinear", align_corners=True)
        # 1x1 conv squeezing the concatenated (upsampled + skip) channels
        self.c1 = conv2d(in_c + out_c, out_c, kernel_size=1, padding=0)
        self.c2 = conv2d(out_c, out_c, act=False)
        self.c3 = conv2d(out_c, out_c, act=False)
        self.c4 = conv2d(out_c, out_c, kernel_size=1, padding=0, act=False)
        self.ca = channel_attention(out_c)
        self.sa = spatial_attention()

    def forward(self, x, skip):
        """Return the refined, attention-gated fusion of ``x`` and ``skip``."""
        upsampled = self.up(x)

        # The fixed-factor upsample may not land on the skip resolution;
        # resample to the skip's exact size when it does not.
        skip_size = skip.shape[2:]
        if upsampled.shape[2:] != skip_size:
            upsampled = F.interpolate(upsampled, size=skip_size, mode="bilinear", align_corners=True)

        base = self.c1(torch.cat([upsampled, skip], dim=1))

        # Each stage adds every earlier stage's output before the ReLU.
        stage1 = self.relu(self.c2(base) + base)
        stage2 = self.relu(self.c3(stage1) + stage1 + base)
        stage3 = self.relu(self.c4(stage2) + stage2 + stage1 + base)

        return self.sa(self.ca(stage3))

class output_block(nn.Module):
    """Final head: 2x bilinear upsampling followed by a 1x1 convolution.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.c1 = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0)

    def forward(self, x):
        """Return the projected, twice-as-large version of ``x`` (no activation)."""
        return self.c1(self.up(x))

class text_classifier(nn.Module):
    """Two bias-free MLP heads on top of globally average-pooled features.

    Args:
        in_c: number of channels of the incoming feature map.
        out_c: two-element sequence with the output width of the first
            (``num_lesions``) and second (``lesion_sizes``) head.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        def make_head(n_out):
            # in_c -> in_c // 8 -> n_out, no biases
            return nn.Sequential(
                nn.Linear(in_c, in_c//8, bias=False), nn.ReLU(),
                nn.Linear(in_c//8, n_out, bias=False)
            )

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = make_head(out_c[0])
        self.fc2 = make_head(out_c[1])

    def forward(self, feats):
        """Return ``(num_lesions, lesion_sizes)`` raw head outputs."""
        batch, channels = feats.shape[0], feats.shape[1]
        pooled = self.avg_pool(feats).view(batch, channels)
        return self.fc1(pooled), self.fc2(pooled)

class embedding_feature_fusion(nn.Module):
    """Weight label-word embeddings by predicted class probabilities and fuse them.

    The two sets of class scores are soft-maxed independently, concatenated
    and used to scale the matching rows of ``label``. The scaled embeddings
    are flattened and passed through two bias-free 1x1 convs with ReLU.

    Args:
        in_c: three-element sequence ``[n_count_classes, n_size_classes,
            embedding_dim]``; the fusion input width is
            ``(in_c[0] + in_c[1]) * in_c[2]``.
        out_c: width of the fused output vector.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        # Remembered so forward() can validate its input against the width the
        # first conv was actually built for (previously a hard-coded 1800).
        self.in_features = (in_c[0] + in_c[1]) * in_c[2]

        self.fc = nn.Sequential(
            nn.Conv2d(self.in_features, out_c, 1, bias=False), nn.ReLU(),
            nn.Conv2d(out_c, out_c, 1, bias=False), nn.ReLU()
        )

    def forward(self, num_lesions, lesion_sizes, label):
        """Return the fused text feature of shape ``(batch, out_c)``.

        Args:
            num_lesions: ``(batch, in_c[0])`` class scores.
            lesion_sizes: ``(batch, in_c[1])`` class scores.
            label: ``(batch, in_c[0] + in_c[1], embedding_dim)`` word embeddings.

        Raises:
            ValueError: if ``label`` has a different number of rows than there
                are class probabilities, or if the flattened embedding width
                differs from what the fusion layer was built for.
        """
        num_lesions_prob = torch.softmax(num_lesions, dim=1)
        lesion_sizes_prob = torch.softmax(lesion_sizes, dim=1)
        prob = torch.cat([num_lesions_prob, lesion_sizes_prob], dim=1)
        prob = prob.view(prob.shape[0], prob.shape[1], 1)

        if label.shape[1] != prob.shape[1]:
            raise ValueError(
                f"Shape mismatch: label channels ({label.shape[1]}) != prob channels ({prob.shape[1]})"
            )

        x = label * prob
        x = x.reshape(x.shape[0], -1, 1, 1)

        # A width mismatch used to be "repaired" by a brand-new, randomly
        # initialised and unregistered Conv2d built on every call, which fed
        # untrained noise downstream. Fail loudly instead.
        if x.shape[1] != self.in_features:
            raise ValueError(
                f"Flattened label embedding has {x.shape[1]} channels but the fusion layer "
                f"expects {self.in_features}; check the embedding dimension passed as in_c[2]."
            )

        x = self.fc(x)
        return x.view(x.shape[0], -1)

class multiscale_feature_aggregation(nn.Module):
    """Bring four feature maps to one resolution, fuse them and refine.

    Args:
        in_c: four-element sequence with the channel count of each input.
        out_c: channel count of every projection and of the output.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        self.up_2x2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.up_4x4 = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True)
        self.up_8x8 = nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)

        # One 1x1 projection per input, then a 1x1 fusion of the concatenation.
        self.c11 = conv2d(in_c[0], out_c, kernel_size=1, padding=0)
        self.c12 = conv2d(in_c[1], out_c, kernel_size=1, padding=0)
        self.c13 = conv2d(in_c[2], out_c, kernel_size=1, padding=0)
        self.c14 = conv2d(in_c[3], out_c, kernel_size=1, padding=0)
        self.c15 = conv2d(out_c * 4, out_c, kernel_size=1, padding=0)

        self.c2 = conv2d(out_c, out_c, act=False)
        self.c3 = conv2d(out_c, out_c, act=False)

    def forward(self, x1, x2, x3, x4):
        """Return the fused feature map at 2x the resolution of ``x3``.

        ``x1``, ``x2`` and ``x3`` are upsampled by 8, 4 and 2 respectively and
        are expected to land on the same spatial size; ``x4`` is resampled to
        that size whatever its own resolution is.
        """
        x1 = self.up_8x8(x1)
        x2 = self.up_4x4(x2)
        x3 = self.up_2x2(x3)
        # The target used to be a hard-coded (128, 128), which is only right
        # for 256x256 network inputs. Following x3 gives the identical result
        # there and keeps the concatenation valid for every other input size.
        x4 = F.interpolate(x4, size=x3.shape[2:], mode="bilinear", align_corners=True)

        x1 = self.c11(x1)
        x2 = self.c12(x2)
        x3 = self.c13(x3)
        x4 = self.c14(x4)

        x = torch.cat([x1, x2, x3, x4], dim=1)
        x = self.c15(x)

        # Two residual refinement stages; the second also re-adds the first input.
        s1 = x
        x = self.c2(x)
        x = self.relu(x + s1)

        s2 = x
        x = self.c3(x)
        x = self.relu(x + s2 + s1)

        return x

class ProgressiveDenoisingAttention1(nn.Module):
    """Spatial attention whose map is refined by a stack of conv stages.

    Args:
        channels: channel count of the input (and of every refinement stage).
        num_iterations: number of conv -> batch-norm -> ReLU refinement stages.
    """

    def __init__(self, channels, num_iterations=3):
        super(ProgressiveDenoisingAttention1, self).__init__()
        self.num_iterations = num_iterations
        self.channels = channels

        def refinement_stage():
            return nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(inplace=True)
            )

        # Stages are applied one after another in forward().
        stages = []
        for _ in range(num_iterations):
            stages.append(refinement_stage())
        self.conv_layers = nn.ModuleList(stages)

        # Collapses the refined features into a single-channel attention map.
        self.final_conv = nn.Conv2d(channels, 1, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by its refined single-channel attention map."""
        refined = x
        for stage in self.conv_layers:
            refined = stage(refined)

        gate = self.sigmoid(self.final_conv(refined))
        return x * gate

class UNetDenoise(nn.Module):
    """Small conv encoder/decoder: one stride-2 downsample, one stride-2 upsample.

    Despite the name there are no skip connections between the two halves.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        encoder_layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # halves H and W
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            # output_padding=1 makes the transposed conv exactly double H and W
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return the decoder output for ``x`` (non-negative, ends in ReLU)."""
        return self.decoder(self.encoder(x))

class TGAlesionSeg(nn.Module):
    """Text-guided segmentation network on a timm SE-ResNeXt50 encoder.

    The encoder stem and its four stages feed five ``dilated_conv`` +
    ``ProgressiveDenoisingAttention1`` branches. A classifier on the deepest
    features predicts two 3-way class scores, which weight the label-word
    embeddings into a text vector. Four decoder stages (decoder block ->
    label attention -> ``UNetDenoise``) are then aggregated across scales and
    projected to a single-channel mask.
    """

    def __init__(self):
        super().__init__()

        # Backbone: SE-ResNeXt50 (ImageNet weights are fetched by timm).
        backbone = timm.create_model('seresnext50_32x4d', pretrained=True)

        # The stem and each stage are registered separately so forward() can
        # tap every intermediate feature map.
        stem_layers = (backbone.conv1, backbone.bn1, backbone.act1, backbone.maxpool)
        self.layer0 = nn.Sequential(*stem_layers)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        width = 128  # channel count shared by every branch and decoder stage

        # Classifier on the 2048-channel layer4 output: two 3-way heads.
        self.text_classifier = text_classifier(2048, [3, 3])
        # 3 + 3 class probabilities, each weighting a 300-d word embedding.
        self.label_fc = embedding_feature_fusion([3, 3, 300], width)

        # Dilated conv branches: s1..s5 for the stem and layer1..layer4 outputs.
        for idx, encoder_channels in enumerate((64, 256, 512, 1024, 2048), start=1):
            setattr(self, f"s{idx}", dilated_conv(encoder_channels, width))

        # Progressive denoising attention: pda1..pda5, one per branch.
        for idx in range(1, 6):
            setattr(self, f"pda{idx}", ProgressiveDenoisingAttention1(width, num_iterations=3))

        # Denoising modules: denoise1..denoise4, one per decoder stage.
        for idx in range(1, 5):
            setattr(self, f"denoise{idx}", UNetDenoise(in_channels=width, out_channels=width))

        # Decoder: d1/a1 .. d4/a4 (decoder block plus its label attention).
        for idx in range(1, 5):
            setattr(self, f"d{idx}", decoder_block(width, width, scale=2))
            setattr(self, f"a{idx}", label_attention([width, width]))

        self.ag = multiscale_feature_aggregation([width, width, width, width], width)

        self.y1 = output_block(width, 1)

    def forward(self, image, label):
        """Run the network.

        Args:
            image: input image batch fed to the backbone stem.
            label: label-word embeddings consumed by ``embedding_feature_fusion``
                (six rows of 300 values per sample for this configuration).

        Returns:
            Tuple ``(y1, num_lesions, lesion_sizes)``: the un-activated
            single-channel mask prediction and the two raw classifier outputs.
        """
        # Encoder (channel counts: 64, 256, 512, 1024, 2048).
        enc1 = self.layer0(image)
        enc2 = self.layer1(enc1)
        enc3 = self.layer2(enc2)
        enc4 = self.layer3(enc3)
        enc5 = self.layer4(enc4)

        # Text branch: class scores from the deepest features, then the
        # probability-weighted embedding vector.
        num_lesions, lesion_sizes = self.text_classifier(enc5)
        text_vec = self.label_fc(num_lesions, lesion_sizes, label)

        # Dilated conv + progressive denoising attention on every encoder tap.
        skip1 = self.pda1(self.s1(enc1))
        skip2 = self.pda2(self.s2(enc2))
        skip3 = self.pda3(self.s3(enc3))
        skip4 = self.pda4(self.s4(enc4))
        skip5 = self.pda5(self.s5(enc5))

        # Decoder. Each label attention also returns channel logits, which are
        # accumulated into the text vector handed to the next stage.
        # Stage 1
        dec1 = self.d1(skip5, skip4)
        attn1, dec1 = self.a1(dec1, text_vec)
        dec1 = self.denoise1(dec1)

        # Stage 2
        dec2 = self.d2(dec1, skip3)
        text_sum = text_vec + attn1
        attn2, dec2 = self.a2(dec2, text_sum)
        dec2 = self.denoise2(dec2)

        # Stage 3
        dec3 = self.d3(dec2, skip2)
        text_sum = text_sum + attn2
        attn3, dec3 = self.a3(dec3, text_sum)
        dec3 = self.denoise3(dec3)

        # Stage 4 (its channel logits are not used any further)
        dec4 = self.d4(dec3, skip1)
        text_sum = text_sum + attn3
        _, dec4 = self.a4(dec4, text_sum)
        dec4 = self.denoise4(dec4)

        aggregated = self.ag(dec1, dec2, dec3, dec4)
        y1 = self.y1(aggregated)

        return y1, num_lesions, lesion_sizes
def prepare_input(res):
    """Build a dummy ``(image, label)`` input pair for ``TGAlesionSeg``.

    Args:
        res: unused; kept so the function can serve as an input-constructor
            callback that is handed a resolution.

    Returns:
        ``{"x": [image, label]}`` with a ``(1, 3, 256, 256)`` image and a
        ``(1, 6, 300)`` label tensor, both float32 zeros, on CUDA when
        available and on the CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Zeros instead of torch.FloatTensor(...): the latter returns uninitialised
    # memory that may contain NaN/inf.
    image = torch.zeros(1, 3, 256, 256, dtype=torch.float32, device=device)
    # Six rows, not five: the model weights 3 + 3 class probabilities against
    # one 300-d embedding each and rejects any other row count.
    label = torch.zeros(1, 6, 300, dtype=torch.float32, device=device)
    return dict(x=[image, label])

### --------------------------------------------------------------------------------------------------------------------------------------------------------

def load_names(path, file_path):
    """Read sample names from a list file and build image/mask paths.

    Args:
        path: dataset root containing the ``images`` and ``masks`` folders.
        file_path: text file with one sample name (no extension) per line.

    Returns:
        ``(images, masks)``: two parallel lists of ``.png`` paths.
    """
    # The context manager closes the handle (it used to leak). Filtering blank
    # lines replaces the old "drop the last split element" trick, which lost
    # the final sample whenever the file had no trailing newline.
    with open(file_path, "r") as handle:
        names = [line.strip() for line in handle if line.strip()]

    images = [os.path.join(path, "images", name) + ".png" for name in names]
    masks = [os.path.join(path, "masks", name) + ".png" for name in names]
    return images, masks

def label_dictionary():
    """Return the label vocabulary: three count words followed by three size words."""
    return {
        "lesion": ["zero", "one", "multiple", "small", "medium", "large"],
    }

def load_data(path):
    """Load the train/validation file lists and attach the label words.

    Args:
        path: dataset root containing ``train.txt`` and ``val.txt``.

    Returns:
        ``(train_x, train_y, train_label), (valid_x, valid_y, valid_label)``
        where every ``*_label`` entry is the same list of label words.
    """
    loaded = []
    for list_file in ("train.txt", "val.txt"):
        loaded.append(load_names(path, f"{path}/{list_file}"))
    (train_x, train_y), (valid_x, valid_y) = loaded

    label_dict = label_dictionary()
    print(label_dict)
    words = label_dict["lesion"]

    # Every sample gets the same word list (the list object itself is shared).
    train_label = [words] * len(train_x)
    valid_label = [words] * len(valid_x)

    return (train_x, train_y, train_label), (valid_x, valid_y, valid_label)

class DATASET(Dataset):
    """Segmentation dataset yielding image, label embeddings, mask and class targets.

    Each item is ``((image, label), (mask, num_lesions, lesion_sizes))``:
    the image as a channel-first float array scaled to [0, 1], the embeddings
    of the sample's label words, the mask with a leading channel axis scaled
    by 1/255, and two integer class targets derived from the mask's bounding
    boxes.

    Args:
        images_path: list of image file paths.
        labels_path: per-sample list of label words (not file paths).
        masks_path: list of mask file paths, parallel to ``images_path``.
        size: size tuple passed to ``cv2.resize``.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, images_path, labels_path, masks_path, size, transform=None):
        super().__init__()
        self.images_path = images_path
        self.labels_path = labels_path
        self.masks_path = masks_path
        self.size = size
        self.transform = transform
        self.n_samples = len(images_path)

        self.embed = Text2Embed()

    def visualize_bbox_on_mask(self, image, mask, bboxes, title="Mask with BBoxes"):
        """Show the image next to the mask with its bounding boxes drawn on it.

        Raises:
            ValueError: if ``image`` is not a 3-channel colour image.
        """
        if len(image.shape) != 3 or image.shape[2] != 3:
            raise ValueError(f"Expected a color image with 3 channels, but got shape {image.shape}")

        fig, ax = plt.subplots(1, 2, figsize=(12, 6))

        # Left panel: original image
        ax[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        ax[0].set_title("Original Image")
        ax[0].axis("off")

        # Right panel: mask with bounding boxes
        mask_with_bbox = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
        for bbox in bboxes:
            x1, y1, x2, y2 = bbox
            mask_with_bbox = cv2.rectangle(mask_with_bbox, (x1, y1), (x2, y2), (255, 0, 0), 2)

        ax[1].imshow(cv2.cvtColor(mask_with_bbox, cv2.COLOR_BGR2RGB))
        ax[1].set_title(title)
        ax[1].axis("off")

        plt.tight_layout()
        plt.show()

    def mask_to_text(self, mask, image=None, visualize=False):
        """Derive the lesion-count and lesion-size classes from a mask.

        Args:
            mask: 2-D mask passed to ``mask_to_bbox``.
            image: optional colour image, used only for visualisation.
            visualize: when true and ``image`` is given, plot the boxes.

        Returns:
            ``(num_lesions, lesion_sizes)`` as 0-d integer arrays. The count
            class is 0, 1 or 2 (two or more boxes); the size class is 0, 1 or
            2 by relative box area (< 0.007, < 0.04, otherwise).
        """
        bboxes = mask_to_bbox(mask)
        lesion_sizes = 0
        mask_area = mask.shape[0] * mask.shape[1]

        # NOTE(review): the size class is overwritten on every box, so the
        # LAST box decides it. Kept as is because existing checkpoints were
        # trained with these targets; confirm whether the largest box was meant.
        for bbox in bboxes:
            x1, y1, x2, y2 = bbox
            area = ((y2 - y1) * (x2 - x1)) / mask_area

            if area < 0.007:
                lesion_sizes = 0
            elif area < 0.04:
                lesion_sizes = 1
            else:
                lesion_sizes = 2

        # 0 boxes -> 0, 1 box -> 1, 2 or more -> 2
        num_lesions = min(len(bboxes), 2)

        # The guard used to read `if image is None`, which could only ever
        # call the plotting helper with a missing image (and crash). Plotting
        # is now an explicit opt-in so data loading never opens figures.
        if visualize and image is not None:
            self.visualize_bbox_on_mask(image, mask, bboxes, title="Mask with Bounding Boxes")

        return np.array(num_lesions), np.array(lesion_sizes)

    def __getitem__(self, index):
        """Load, augment, resize and normalise sample ``index``.

        Raises:
            FileNotFoundError: if the image or mask cannot be read.
            ValueError: if the resized image is not a 3-channel colour image.
        """
        image = cv2.imread(self.images_path[index], cv2.IMREAD_COLOR)
        mask = cv2.imread(self.masks_path[index], cv2.IMREAD_GRAYSCALE)

        if image is None or mask is None:
            raise FileNotFoundError(f"Could not read image or mask at index {index}. Check the file paths.")

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        image = cv2.resize(image, self.size)
        if len(image.shape) != 3 or image.shape[2] != 3:
            raise ValueError(f"Expected a color image with 3 channels after resizing, but got shape {image.shape}")

        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = image / 255.0

        # NOTE(review): default interpolation yields intermediate grey values
        # on mask edges, so the mask target is soft - confirm this is intended.
        mask = cv2.resize(mask, self.size)
        mask_copy = mask  # 2-D uint8 view used for the bounding-box targets
        mask = np.expand_dims(mask, axis=0)
        mask = mask / 255.0

        # No image is passed: it was only ever needed for the (disabled)
        # visualisation, so rebuilding a uint8 copy per sample was wasted work.
        num_lesions, lesion_sizes = self.mask_to_text(mask_copy)

        words = self.labels_path[index]
        label = np.array([self.embed.to_embed(word)[0] for word in words])
        return (image, label), (mask, num_lesions, lesion_sizes)

    def __len__(self):
        """Return the number of samples."""
        return self.n_samples


def _run_epoch(model, loader, loss_fn, device, optimizer=None):
    """Run one pass over ``loader``; optimise when ``optimizer`` is given.

    Shared implementation behind ``train`` and ``evaluate``, which used to be
    two near-identical copies of this loop.

    Returns:
        ``(mean_loss, [jaccard, f1, recall, precision])`` where every value is
        the mean over batches of the per-batch mean.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is what model.eval() does

    epoch_loss = 0.0
    metric_totals = [0.0, 0.0, 0.0, 0.0]  # jaccard, f1, recall, precision

    with torch.set_grad_enabled(is_training):
        for (x, l), (y1, y2, y3) in loader:
            x = x.to(device, dtype=torch.float32)
            l = l.to(device, dtype=torch.float32)
            y1 = y1.to(device, dtype=torch.float32)
            y2 = y2.to(device, dtype=torch.long)
            y3 = y3.to(device, dtype=torch.long)

            if is_training:
                optimizer.zero_grad()

            p1, p2, p3 = model(x, l)
            # NOTE(review): the classification losses receive probabilities,
            # not logits - confirm that is what the configured losses expect.
            p2 = torch.softmax(p2, dim=1)
            p3 = torch.softmax(p3, dim=1)

            loss = loss_fn[0](p1, y1) + loss_fn[1](p2, y2) + loss_fn[2](p3, y3)

            if is_training:
                loss.backward()
                optimizer.step()
            epoch_loss += loss.item()

            # Per-sample segmentation metrics, averaged within the batch.
            scores = [calculate_metrics(yt, yp) for yt, yp in zip(y1, p1)]
            for k in range(4):
                metric_totals[k] += np.mean([score[k] for score in scores])

    num_batches = len(loader)
    epoch_loss = epoch_loss / num_batches
    return epoch_loss, [total / num_batches for total in metric_totals]


def train(model, loader, optimizer, loss_fn, device):
    """Train ``model`` for one epoch.

    Args:
        model: callable as ``model(image, label)`` returning three outputs
            (mask prediction and two class-score tensors).
        loader: iterable of ``((x, l), (y1, y2, y3))`` batches with a length.
        optimizer: optimiser stepped once per batch.
        loss_fn: three losses applied to the mask and the two class outputs.
        device: device the batches are moved to.

    Returns:
        ``(epoch_loss, [jaccard, f1, recall, precision])``.
    """
    return _run_epoch(model, loader, loss_fn, device, optimizer=optimizer)


def evaluate(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader`` in eval mode with gradients disabled.

    Arguments and return value match ``train`` (minus the optimiser).
    """
    return _run_epoch(model, loader, loss_fn, device, optimizer=None)

if __name__ == "__main__":
    # Training entry point: builds the data pipeline and the model, optionally
    # resumes from `checkpoint_path`, then trains with validation-F1
    # checkpointing and early stopping.

    # Seeding
    seeding(42)

    # Directories
    create_dir("files")

    # Training logfile: created on the first run, appended to afterwards.
    train_log_path = "files/train_log.txt"
    if os.path.exists(train_log_path):
        print("Log file exists")
    else:
        with open(train_log_path, "w") as train_log:
            train_log.write("\n")

    # Record date & time
    datetime_object = str(datetime.datetime.now())
    print_and_save(train_log_path, datetime_object)
    print("")

    # Hyperparameters
    image_size = 256
    size = (image_size, image_size)
    batch_size = 16
    num_epochs = 70
    lr = 1e-5
    early_stopping_patience = 70
    checkpoint_path = "/content/checkpoint1 (15) (2).pth"
    path = "/content/Pneumothorax"
    # Best validation F1 of the run that produced the checkpoint; only applied
    # when that checkpoint actually loads. Update it to your last metric.
    resumed_best_f1 = 0.835

    data_str = f"Image Size: {size}\nBatch Size: {batch_size}\nLR: {lr}\nEpochs: {num_epochs}\n"
    data_str += f"Early Stopping Patience: {early_stopping_patience}\n"
    print_and_save(train_log_path, data_str)

    # Data augmentation: transforms (training split only)
    transform = A.Compose([
        A.Rotate(limit=15, p=0.5),  # Moderate rotations
        A.HorizontalFlip(p=0.5),    # Horizontal flipping
        A.VerticalFlip(p=0.2),      # Vertical flipping
        A.RandomBrightnessContrast(p=0.3),  # Adjust brightness and contrast
        A.CLAHE(p=0.3),  # Contrast Limited Adaptive Histogram Equalization
        A.ElasticTransform(alpha=1, sigma=50, p=0.3),  # Elastic transformations
        A.CoarseDropout(num_holes_range=(2, 5), hole_height_range=(5, 20), hole_width_range=(5, 20), fill="inpaint_ns", p=0.5)  # Coarse dropout
    ])

    # Dataset
    (train_x, train_y, train_label), (valid_x, valid_y, valid_label) = load_data(path)
    train_x, train_y, train_label = shuffling(train_x, train_y, train_label)
    data_str = f"Dataset Size:\nTrain: {len(train_x)} - Valid: {len(valid_x)}\n"
    print_and_save(train_log_path, data_str)

    # Dataset and loader
    train_dataset = DATASET(train_x, train_label, train_y, (image_size, image_size), transform=transform)
    valid_dataset = DATASET(valid_x, valid_label, valid_y, (image_size, image_size), transform=None)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4
    )

    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )

    # Model: fall back to the CPU instead of crashing when no GPU is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TGAlesionSeg()
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
    # Focal loss for the num_lesions and lesion_sizes heads.
    loss_fn = [DiceBCELoss(), FocalLoss(alpha=0.25, gamma=2), FocalLoss(alpha=0.25, gamma=2)]

    loss_name = "BCE Dice Loss"
    data_str = f"Optimizer: AdamW\nLoss: {loss_name}\n"
    print_and_save(train_log_path, data_str)

    # Best-effort resume: a missing or incompatible checkpoint falls back to
    # training from the freshly built model.
    resumed = False
    try:
        # map_location lets a GPU-saved checkpoint load on any device;
        # weights_only restricts unpickling to tensors (all a state dict has).
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        resumed = True
        print(f"Checkpoint loaded successfully from {checkpoint_path}. Resuming training.")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")

    # Training the model. The resumed threshold must not apply to a fresh
    # model, otherwise nothing is saved until it beats the old run.
    best_valid_metrics = resumed_best_f1 if resumed else 0.0
    early_stopping_count = 0

    # Logging (the optimiser really is AdamW; the log used to say "Adam").
    if resumed:
        data_str = "Resuming training from checkpoint."
    else:
        data_str = "Training from scratch (no checkpoint loaded)."
    data_str += f"\nOptimizer: AdamW\nLoss: {loss_name}\n"
    print_and_save(train_log_path, data_str)

    for epoch in range(num_epochs):
        start_time = time.time()

        train_loss, train_metrics = train(model, train_loader, optimizer, loss_fn, device)
        valid_loss, valid_metrics = evaluate(model, valid_loader, loss_fn, device)
        scheduler.step(valid_loss)

        valid_f1 = valid_metrics[1]
        if valid_f1 > best_valid_metrics:
            data_str = f"Valid F1 improved from {best_valid_metrics:2.4f} to {valid_f1:2.4f}. Saving checkpoint: {checkpoint_path}"
            print_and_save(train_log_path, data_str)

            best_valid_metrics = valid_f1
            torch.save(model.state_dict(), checkpoint_path)
            early_stopping_count = 0
        else:
            # An epoch that merely ties the best is not an improvement either.
            early_stopping_count += 1

        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        data_str = f"Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n"
        data_str += f"\tTrain Loss: {train_loss:.4f} - Jaccard: {train_metrics[0]:.4f} - F1: {train_metrics[1]:.4f} - Recall: {train_metrics[2]:.4f} - Precision: {train_metrics[3]:.4f}\n"
        data_str += f"\t Val. Loss: {valid_loss:.4f} - Jaccard: {valid_metrics[0]:.4f} - F1: {valid_metrics[1]:.4f} - Recall: {valid_metrics[2]:.4f} - Precision: {valid_metrics[3]:.4f}\n"
        print_and_save(train_log_path, data_str)

        if early_stopping_count >= early_stopping_patience:
            # The criterion is validation F1, not validation loss.
            data_str = f"Early stopping: validation F1 has not improved for the last {early_stopping_patience} epochs.\n"
            print_and_save(train_log_path, data_str)
            break


Log file exists
2025-03-06 23:02:06.232472

Image Size: (256, 256)
Batch Size: 16
LR: 1e-05
Epochs: 70
Early Stopping Patience: 70

{'polyp': ['zero', 'one', 'multiple', 'small', 'medium', 'large']}
Dataset Size:
Train: 9636 - Valid: 2409

Optimizer: AdamW
Loss: BCE Dice Loss



  model.load_state_dict(torch.load(checkpoint_path))


Checkpoint loaded successfully from /content/checkpoint1 (15) (2).pth. Resuming training.
Resuming training from checkpoint.
Optimizer: Adam
Loss: BCE Dice Loss

Epoch: 01 | Epoch Time: 8m 50s
	Train Loss: 0.4883 - Jaccard: 0.8447 - F1: 0.8664 - Recall: 0.8889 - Precision: 0.9324
	 Val. Loss: 0.6643 - Jaccard: 0.8131 - F1: 0.8317 - Recall: 0.8672 - Precision: 0.9194

Epoch: 02 | Epoch Time: 8m 51s
	Train Loss: 0.5006 - Jaccard: 0.8404 - F1: 0.8625 - Recall: 0.8890 - Precision: 0.9268
	 Val. Loss: 0.6414 - Jaccard: 0.8127 - F1: 0.8323 - Recall: 0.8754 - Precision: 0.9108

Epoch: 03 | Epoch Time: 8m 51s
	Train Loss: 0.4855 - Jaccard: 0.8445 - F1: 0.8669 - Recall: 0.8918 - Precision: 0.9279
	 Val. Loss: 0.6360 - Jaccard: 0.8099 - F1: 0.8298 - Recall: 0.8769 - Precision: 0.9037

Epoch: 04 | Epoch Time: 8m 51s
	Train Loss: 0.4897 - Jaccard: 0.8448 - F1: 0.8662 - Recall: 0.8871 - Precision: 0.9350
	 Val. Loss: 0.6498 - Jaccard: 0.8112 - F1: 0.8303 - Recall: 0.8741 - Precision: 0.9106

Epoch:

KeyboardInterrupt: 

In [None]:
# Log file exists
# 2025-03-04 18:49:08.021780

# Image Size: (256, 256)
# Batch Size: 16
# LR: 0.0001
# Epochs: 70
# Early Stopping Patience: 50

# {'lesion': ['zero', 'one', 'multiple', 'small', 'medium', 'large']}
# Dataset Size:
# Train: 9636 - Valid: 2409

# Optimizer: AdamW
# Loss: BCE Dice Loss

# Valid F1 improved from 0.0000 to 0.7648. Saving checkpoint: files/checkpoint1.pth
# Epoch: 01 | Epoch Time: 8m 50s
# 	Train Loss: 1.0457 - Jaccard: 0.6437 - F1: 0.6527 - Recall: 0.8131 - Precision: 0.7769
# 	 Val. Loss: 0.8658 - Jaccard: 0.7490 - F1: 0.7648 - Recall: 0.8488 - Precision: 0.8517

# Epoch: 02 | Epoch Time: 8m 48s
# 	Train Loss: 0.8626 - Jaccard: 0.7035 - F1: 0.7201 - Recall: 0.8435 - Precision: 0.8104
# 	 Val. Loss: 0.8967 - Jaccard: 0.6419 - F1: 0.6620 - Recall: 0.8943 - Precision: 0.6917

# Valid F1 improved from 0.7648 to 0.8106. Saving checkpoint: files/checkpoint1.pth
# Epoch: 03 | Epoch Time: 8m 49s
# 	Train Loss: 0.8022 - Jaccard: 0.7359 - F1: 0.7533 - Recall: 0.8493 - Precision: 0.8432
# 	 Val. Loss: 0.8014 - Jaccard: 0.7962 - F1: 0.8106 - Recall: 0.8408 - Precision: 0.9285

# Epoch: 04 | Epoch Time: 8m 48s
# 	Train Loss: 0.7560 - Jaccard: 0.7466 - F1: 0.7654 - Recall: 0.8551 - Precision: 0.8493
# 	 Val. Loss: 0.7356 - Jaccard: 0.7814 - F1: 0.7991 - Recall: 0.8614 - Precision: 0.8833

# Epoch: 05 | Epoch Time: 8m 49s
# 	Train Loss: 0.7527 - Jaccard: 0.7381 - F1: 0.7572 - Recall: 0.8558 - Precision: 0.8392
# 	 Val. Loss: 0.7579 - Jaccard: 0.7933 - F1: 0.8092 - Recall: 0.8484 - Precision: 0.9162

# Epoch: 06 | Epoch Time: 8m 49s
# 	Train Loss: 0.7210 - Jaccard: 0.7614 - F1: 0.7803 - Recall: 0.8575 - Precision: 0.8654
# 	 Val. Loss: 0.7591 - Jaccard: 0.6885 - F1: 0.7118 - Recall: 0.8953 - Precision: 0.7447

# Valid F1 improved from 0.8106 to 0.8183. Saving checkpoint: files/checkpoint1.pth
# Epoch: 07 | Epoch Time: 8m 49s
# 	Train Loss: 0.7110 - Jaccard: 0.7655 - F1: 0.7850 - Recall: 0.8612 - Precision: 0.8664
# 	 Val. Loss: 0.7249 - Jaccard: 0.8005 - F1: 0.8183 - Recall: 0.8634 - Precision: 0.9092

# Epoch: 08 | Epoch Time: 8m 49s
# 	Train Loss: 0.6799 - Jaccard: 0.7759 - F1: 0.7956 - Recall: 0.8641 - Precision: 0.8752
# 	 Val. Loss: 0.7066 - Jaccard: 0.7976 - F1: 0.8161 - Recall: 0.8668 - Precision: 0.9013

# Epoch: 09 | Epoch Time: 8m 48s
# 	Train Loss: 0.6758 - Jaccard: 0.7760 - F1: 0.7960 - Recall: 0.8651 - Precision: 0.8759
# 	 Val. Loss: 0.6908 - Jaccard: 0.7829 - F1: 0.8029 - Recall: 0.8730 - Precision: 0.8731

# Epoch: 10 | Epoch Time: 8m 49s
# 	Train Loss: 0.6796 - Jaccard: 0.7723 - F1: 0.7918 - Recall: 0.8624 - Precision: 0.8743
# 	 Val. Loss: 0.6963 - Jaccard: 0.7930 - F1: 0.8129 - Recall: 0.8746 - Precision: 0.8827

# Valid F1 improved from 0.8183 to 0.8188. Saving checkpoint: files/checkpoint1.pth
# Epoch: 11 | Epoch Time: 8m 49s
# 	Train Loss: 0.6454 - Jaccard: 0.7916 - F1: 0.8113 - Recall: 0.8659 - Precision: 0.8934
# 	 Val. Loss: 0.6828 - Jaccard: 0.7994 - F1: 0.8188 - Recall: 0.8757 - Precision: 0.8919

# Epoch: 12 | Epoch Time: 8m 48s
# 	Train Loss: 0.6580 - Jaccard: 0.7888 - F1: 0.8091 - Recall: 0.8681 - Precision: 0.8870
# 	 Val. Loss: 0.7341 - Jaccard: 0.7880 - F1: 0.8062 - Recall: 0.8611 - Precision: 0.8969

# Epoch: 13 | Epoch Time: 8m 48s
# 	Train Loss: 0.6451 - Jaccard: 0.7789 - F1: 0.7995 - Recall: 0.8682 - Precision: 0.8743
# 	 Val. Loss: 0.6762 - Jaccard: 0.7473 - F1: 0.7695 - Recall: 0.8825 - Precision: 0.8219

# Epoch: 14 | Epoch Time: 8m 48s
# 	Train Loss: 0.6304 - Jaccard: 0.7969 - F1: 0.8176 - Recall: 0.8702 - Precision: 0.8933
# 	 Val. Loss: 0.6900 - Jaccard: 0.7759 - F1: 0.7950 - Recall: 0.8756 - Precision: 0.8702

# Valid F1 improved from 0.8188 to 0.8219. Saving checkpoint: files/checkpoint1.pth
# Epoch: 15 | Epoch Time: 8m 49s
# 	Train Loss: 0.6258 - Jaccard: 0.7975 - F1: 0.8176 - Recall: 0.8696 - Precision: 0.8983
# 	 Val. Loss: 0.7451 - Jaccard: 0.8059 - F1: 0.8219 - Recall: 0.8578 - Precision: 0.9235

# Epoch: 16 | Epoch Time: 8m 49s
# 	Train Loss: 0.6616 - Jaccard: 0.7919 - F1: 0.8110 - Recall: 0.8649 - Precision: 0.8955
# 	 Val. Loss: 0.6994 - Jaccard: 0.7774 - F1: 0.7980 - Recall: 0.8811 - Precision: 0.8547

# Epoch: 17 | Epoch Time: 8m 48s
# 	Train Loss: 0.6376 - Jaccard: 0.7915 - F1: 0.8118 - Recall: 0.8712 - Precision: 0.8862
# 	 Val. Loss: 0.6703 - Jaccard: 0.7834 - F1: 0.8047 - Recall: 0.8835 - Precision: 0.8598

# Epoch: 18 | Epoch Time: 8m 48s
# 	Train Loss: 0.6265 - Jaccard: 0.7900 - F1: 0.8105 - Recall: 0.8713 - Precision: 0.8874
# 	 Val. Loss: 0.7101 - Jaccard: 0.7988 - F1: 0.8144 - Recall: 0.8555 - Precision: 0.9216

# Epoch: 19 | Epoch Time: 8m 49s
# 	Train Loss: 0.6350 - Jaccard: 0.7847 - F1: 0.8043 - Recall: 0.8667 - Precision: 0.8889
# 	 Val. Loss: 0.6936 - Jaccard: 0.7995 - F1: 0.8174 - Recall: 0.8668 - Precision: 0.9093

# Epoch: 20 | Epoch Time: 8m 48s
# 	Train Loss: 0.6111 - Jaccard: 0.7983 - F1: 0.8189 - Recall: 0.8735 - Precision: 0.8945
# 	 Val. Loss: 0.6642 - Jaccard: 0.7867 - F1: 0.8073 - Recall: 0.8742 - Precision: 0.8769

# Epoch: 21 | Epoch Time: 8m 48s
# 	Train Loss: 0.6099 - Jaccard: 0.8012 - F1: 0.8222 - Recall: 0.8744 - Precision: 0.8942
# 	 Val. Loss: 0.6786 - Jaccard: 0.7988 - F1: 0.8177 - Recall: 0.8637 - Precision: 0.9048

# Epoch: 22 | Epoch Time: 8m 48s
# 	Train Loss: 0.5996 - Jaccard: 0.8032 - F1: 0.8242 - Recall: 0.8754 - Precision: 0.8985
# 	 Val. Loss: 0.6903 - Jaccard: 0.7842 - F1: 0.8046 - Recall: 0.8741 - Precision: 0.8770

# Epoch: 23 | Epoch Time: 8m 48s
# 	Train Loss: 0.5972 - Jaccard: 0.8019 - F1: 0.8229 - Recall: 0.8756 - Precision: 0.8969
# 	 Val. Loss: 0.6759 - Jaccard: 0.7568 - F1: 0.7799 - Recall: 0.8945 - Precision: 0.8229

# Valid F1 improved from 0.8219 to 0.8235. Saving checkpoint: files/checkpoint1.pth
# Epoch: 24 | Epoch Time: 8m 49s
# 	Train Loss: 0.6012 - Jaccard: 0.8063 - F1: 0.8272 - Recall: 0.8761 - Precision: 0.9001
# 	 Val. Loss: 0.6708 - Jaccard: 0.8043 - F1: 0.8235 - Recall: 0.8779 - Precision: 0.8972

# Epoch: 25 | Epoch Time: 8m 48s
# 	Train Loss: 0.5761 - Jaccard: 0.8119 - F1: 0.8330 - Recall: 0.8766 - Precision: 0.9065
# 	 Val. Loss: 0.6732 - Jaccard: 0.7831 - F1: 0.8042 - Recall: 0.8758 - Precision: 0.8683

# Epoch: 26 | Epoch Time: 8m 49s
# 	Train Loss: 0.5771 - Jaccard: 0.8098 - F1: 0.8314 - Recall: 0.8797 - Precision: 0.9004
# 	 Val. Loss: 0.6822 - Jaccard: 0.7955 - F1: 0.8150 - Recall: 0.8675 - Precision: 0.8963

# Valid F1 improved from 0.8235 to 0.8279. Saving checkpoint: files/checkpoint1.pth
# Epoch: 27 | Epoch Time: 8m 49s
# 	Train Loss: 0.5567 - Jaccard: 0.8155 - F1: 0.8368 - Recall: 0.8794 - Precision: 0.9085
# 	 Val. Loss: 0.6613 - Jaccard: 0.8089 - F1: 0.8279 - Recall: 0.8676 - Precision: 0.9143

# Epoch: 28 | Epoch Time: 8m 48s
# 	Train Loss: 0.5464 - Jaccard: 0.8221 - F1: 0.8434 - Recall: 0.8801 - Precision: 0.9133
# 	 Val. Loss: 0.6515 - Jaccard: 0.8060 - F1: 0.8256 - Recall: 0.8719 - Precision: 0.9063

# Valid F1 improved from 0.8279 to 0.8292. Saving checkpoint: files/checkpoint1.pth
# Epoch: 29 | Epoch Time: 8m 49s
# 	Train Loss: 0.5354 - Jaccard: 0.8200 - F1: 0.8423 - Recall: 0.8853 - Precision: 0.9083
# 	 Val. Loss: 0.6520 - Jaccard: 0.8105 - F1: 0.8292 - Recall: 0.8691 - Precision: 0.9169

# Valid F1 improved from 0.8292 to 0.8298. Saving checkpoint: files/checkpoint1.pth
# Epoch: 30 | Epoch Time: 8m 49s
# 	Train Loss: 0.5301 - Jaccard: 0.8289 - F1: 0.8504 - Recall: 0.8814 - Precision: 0.9215
# 	 Val. Loss: 0.6510 - Jaccard: 0.8107 - F1: 0.8298 - Recall: 0.8688 - Precision: 0.9173

# Epoch: 31 | Epoch Time: 8m 48s
# 	Train Loss: 0.5266 - Jaccard: 0.8303 - F1: 0.8527 - Recall: 0.8878 - Precision: 0.9152
# 	 Val. Loss: 0.6489 - Jaccard: 0.8028 - F1: 0.8226 - Recall: 0.8731 - Precision: 0.9017

# Valid F1 improved from 0.8298 to 0.8300. Saving checkpoint: files/checkpoint1.pth
# Epoch: 32 | Epoch Time: 8m 49s
# 	Train Loss: 0.5084 - Jaccard: 0.8337 - F1: 0.8558 - Recall: 0.8859 - Precision: 0.9216
# 	 Val. Loss: 0.6672 - Jaccard: 0.8124 - F1: 0.8300 - Recall: 0.8643 - Precision: 0.9258

# Epoch: 33 | Epoch Time: 8m 48s
# 	Train Loss: 0.5126 - Jaccard: 0.8367 - F1: 0.8584 - Recall: 0.8852 - Precision: 0.9263
# 	 Val. Loss: 0.6753 - Jaccard: 0.8128 - F1: 0.8296 - Recall: 0.8601 - Precision: 0.9302

# Valid F1 improved from 0.8300 to 0.8301. Saving checkpoint: files/checkpoint1.pth
# Epoch: 34 | Epoch Time: 8m 49s
# 	Train Loss: 0.5211 - Jaccard: 0.8371 - F1: 0.8586 - Recall: 0.8849 - Precision: 0.9270
# 	 Val. Loss: 0.6675 - Jaccard: 0.8123 - F1: 0.8301 - Recall: 0.8661 - Precision: 0.9219

# Valid F1 improved from 0.8301 to 0.8303. Saving checkpoint: files/checkpoint1.pth
# Epoch: 35 | Epoch Time: 8m 49s
# 	Train Loss: 0.5252 - Jaccard: 0.8376 - F1: 0.8594 - Recall: 0.8847 - Precision: 0.9278
# 	 Val. Loss: 0.6453 - Jaccard: 0.8112 - F1: 0.8303 - Recall: 0.8759 - Precision: 0.9095

# Valid F1 improved from 0.8303 to 0.8318. Saving checkpoint: files/checkpoint1.pth
# Epoch: 36 | Epoch Time: 8m 49s
# 	Train Loss: 0.5114 - Jaccard: 0.8401 - F1: 0.8617 - Recall: 0.8868 - Precision: 0.9286
# 	 Val. Loss: 0.6584 - Jaccard: 0.8136 - F1: 0.8318 - Recall: 0.8696 - Precision: 0.9200

# Valid F1 improved from 0.8318 to 0.8322. Saving checkpoint: files/checkpoint1.pth
# Epoch: 37 | Epoch Time: 8m 49s
# 	Train Loss: 0.5149 - Jaccard: 0.8392 - F1: 0.8617 - Recall: 0.8908 - Precision: 0.9228
# 	 Val. Loss: 0.6599 - Jaccard: 0.8146 - F1: 0.8322 - Recall: 0.8634 - Precision: 0.9286

# Epoch: 38 | Epoch Time: 8m 49s
# 	Train Loss: 0.5036 - Jaccard: 0.8395 - F1: 0.8616 - Recall: 0.8885 - Precision: 0.9240
# 	 Val. Loss: 0.6449 - Jaccard: 0.8082 - F1: 0.8266 - Recall: 0.8731 - Precision: 0.9096

# Epoch: 39 | Epoch Time: 8m 48s
# 	Train Loss: 0.5044 - Jaccard: 0.8391 - F1: 0.8613 - Recall: 0.8885 - Precision: 0.9253
# 	 Val. Loss: 0.6503 - Jaccard: 0.8147 - F1: 0.8320 - Recall: 0.8644 - Precision: 0.9271

# Epoch: 40 | Epoch Time: 8m 48s
# 	Train Loss: 0.4962 - Jaccard: 0.8407 - F1: 0.8631 - Recall: 0.8914 - Precision: 0.9225
# 	 Val. Loss: 0.6431 - Jaccard: 0.8072 - F1: 0.8265 - Recall: 0.8739 - Precision: 0.9058

# Epoch: 41 | Epoch Time: 8m 48s
# 	Train Loss: 0.5069 - Jaccard: 0.8386 - F1: 0.8614 - Recall: 0.8907 - Precision: 0.9200
# 	 Val. Loss: 0.6388 - Jaccard: 0.8054 - F1: 0.8250 - Recall: 0.8771 - Precision: 0.9020

# Valid F1 improved from 0.8322 to 0.8325. Saving checkpoint: files/checkpoint1.pth
# Epoch: 42 | Epoch Time: 8m 49s
# 	Train Loss: 0.4853 - Jaccard: 0.8426 - F1: 0.8654 - Recall: 0.8906 - Precision: 0.9263
# 	 Val. Loss: 0.6655 - Jaccard: 0.8150 - F1: 0.8325 - Recall: 0.8608 - Precision: 0.9302

# Epoch: 43 | Epoch Time: 8m 48s
# 	Train Loss: 0.5040 - Jaccard: 0.8423 - F1: 0.8648 - Recall: 0.8898 - Precision: 0.9261
# 	 Val. Loss: 0.6525 - Jaccard: 0.8128 - F1: 0.8315 - Recall: 0.8687 - Precision: 0.9199

# Epoch: 44 | Epoch Time: 8m 49s
# 	Train Loss: 0.5020 - Jaccard: 0.8405 - F1: 0.8631 - Recall: 0.8922 - Precision: 0.9236
# 	 Val. Loss: 0.6392 - Jaccard: 0.7978 - F1: 0.8185 - Recall: 0.8792 - Precision: 0.8887

# Epoch: 45 | Epoch Time: 8m 48s
# 	Train Loss: 0.4916 - Jaccard: 0.8409 - F1: 0.8636 - Recall: 0.8908 - Precision: 0.9254
# 	 Val. Loss: 0.6414 - Jaccard: 0.8091 - F1: 0.8286 - Recall: 0.8746 - Precision: 0.9082

# Valid F1 improved from 0.8325 to 0.8398. Saving checkpoint: files/checkpoint1.pth
# Epoch: 46 | Epoch Time: 8m 49s
# 	Train Loss: 0.4863 - Jaccard: 0.8450 - F1: 0.8682 - Recall: 0.8929 - Precision: 0.9261
# 	 Val. Loss: 0.6526 - Jaccard: 0.8194 - F1: 0.8398 - Recall: 0.8656 - Precision: 0.9290

# Epoch: 47 | Epoch Time: 8m 48s
# 	Train Loss: 0.4857 - Jaccard: 0.8435 - F1: 0.8665 - Recall: 0.8927 - Precision: 0.9252
# 	 Val. Loss: 0.6395 - Jaccard: 0.8071 - F1: 0.8266 - Recall: 0.8773 - Precision: 0.9030

# Epoch: 48 | Epoch Time: 8m 48s
# 	Train Loss: 0.4904 - Jaccard: 0.8447 - F1: 0.8675 - Recall: 0.8935 - Precision: 0.9256
# 	 Val. Loss: 0.6395 - Jaccard: 0.8133 - F1: 0.8325 - Recall: 0.8741 - Precision: 0.9128

# Epoch: 49 | Epoch Time: 8m 48s
# 	Train Loss: 0.4750 - Jaccard: 0.8469 - F1: 0.8695 - Recall: 0.8931 - Precision: 0.9280
# 	 Val. Loss: 0.6381 - Jaccard: 0.8102 - F1: 0.8295 - Recall: 0.8761 - Precision: 0.9080

# Epoch: 50 | Epoch Time: 8m 48s
# 	Train Loss: 0.4780 - Jaccard: 0.8453 - F1: 0.8677 - Recall: 0.8920 - Precision: 0.9285
# 	 Val. Loss: 0.6537 - Jaccard: 0.8179 - F1: 0.8355 - Recall: 0.8653 - Precision: 0.9297

# Epoch: 51 | Epoch Time: 8m 49s
# 	Train Loss: 0.4754 - Jaccard: 0.8491 - F1: 0.8715 - Recall: 0.8909 - Precision: 0.9340
# 	 Val. Loss: 0.6527 - Jaccard: 0.8158 - F1: 0.8339 - Recall: 0.8669 - Precision: 0.9245

# Epoch: 52 | Epoch Time: 8m 48s
# 	Train Loss: 0.4751 - Jaccard: 0.8504 - F1: 0.8724 - Recall: 0.8896 - Precision: 0.9362
# 	 Val. Loss: 0.6480 - Jaccard: 0.8144 - F1: 0.8329 - Recall: 0.8696 - Precision: 0.9201

# Epoch: 53 | Epoch Time: 8m 48s
# 	Train Loss: 0.4858 - Jaccard: 0.8466 - F1: 0.8685 - Recall: 0.8899 - Precision: 0.9336
# 	 Val. Loss: 0.6540 - Jaccard: 0.8149 - F1: 0.8327 - Recall: 0.8673 - Precision: 0.9234


In [None]:
import os
import random
import time
import datetime
import numpy as np
import albumentations as A
import cv2
from glob import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from text2embed import Text2Embed
from utils import seeding, create_dir, print_and_save, shuffling, epoch_time, calculate_metrics, mask_to_bbox
from metrics import DiceLoss, DiceBCELoss, MultiClassBCE, FocalLoss, CascadingLoss
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import timm
# from resnet import seresnext50_32x4d
import torch.nn.functional as F
# from torch.amp import autocast, GradScaler
import warnings
import torch.cuda.amp as amp
import torch.nn.init as init

class conv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an optional ReLU.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel_size: Convolution kernel size (int or tuple).
        padding: Zero padding passed to the convolution.
        dilation: Dilation rate passed to the convolution.
        act: When equal to ``True`` the output goes through an in-place ReLU.
    """

    def __init__(self, in_c, out_c, kernel_size=3, padding=1, dilation=1, act=True):
        super().__init__()

        convolution = nn.Conv2d(
            in_c,
            out_c,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias=False,  # the batch norm that follows supplies the shift term
        )
        normalisation = nn.BatchNorm2d(out_c)

        self.conv = nn.Sequential(convolution, normalisation)
        self.relu = nn.ReLU(inplace=True)
        self.act = act

    def forward(self, x):
        """Return ``BN(Conv(x))``, passed through ReLU when activation is enabled."""
        features = self.conv(x)
        # Kept as an equality test against True so that only True-equal flags
        # (True / 1) enable the activation, exactly like the original check.
        return self.relu(features) if self.act == True else features

class channel_attention(nn.Module):
    """Channel gate computed from globally average- and max-pooled features.

    Both pooled descriptors go through the same two-layer 1x1-conv bottleneck;
    their sum is squashed with a sigmoid and used to rescale the input channels.

    Args:
        in_planes: Number of channels of the incoming feature map.
        ratio: Reduction factor of the bottleneck (hidden width is
            ``in_planes // ratio``). The default of 16 matches the width that
            was previously hard-coded, so existing checkpoints still load.
    """

    def __init__(self, in_planes, ratio=16):
        super(channel_attention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Honour the `ratio` argument instead of a fixed `// 16`.
        hidden_planes = in_planes // ratio
        self.fc1   = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled per channel by the learned gate (same shape as ``x``)."""
        x0 = x
        # The bottleneck weights are shared between the two pooled descriptors.
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return x0 * self.sigmoid(out)


class spatial_attention(nn.Module):
    """Spatial gate built from channel-wise mean and max maps.

    Args:
        kernel_size: Size of the gating convolution; only 3 and 7 are accepted.
    """

    def __init__(self, kernel_size=7):
        super(spatial_attention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "Same" padding for the two supported kernels: 7 -> 3, otherwise 1.
        same_padding = {7: 3}.get(kernel_size, 1)

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=same_padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by a single-channel sigmoid attention map."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True)[0]
        descriptor = torch.cat((channel_mean, channel_max), dim=1)
        gate = self.sigmoid(self.conv1(descriptor))
        return x * gate

class dilated_conv(nn.Module):
    """Multi-rate dilated convolution block with a projected shortcut.

    Four parallel branches (a 1x1 conv and three 3x3 convs with dilation
    6, 12 and 18), each followed by channel attention, are concatenated and
    fused by a 3x3 conv. The result is added to a 1x1 projection of the input,
    passed through ReLU and finally through spatial attention.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels of every branch and of the block.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        def dilated_branch(rate):
            # padding == dilation keeps the spatial size for a 3x3 kernel
            return nn.Sequential(
                conv2d(in_c, out_c, kernel_size=(3, 3), padding=rate, dilation=rate),
                channel_attention(out_c),
            )

        self.c1 = nn.Sequential(
            conv2d(in_c, out_c, kernel_size=1, padding=0),
            channel_attention(out_c),
        )
        self.c2 = dilated_branch(6)
        self.c3 = dilated_branch(12)
        self.c4 = dilated_branch(18)
        # Fusion of the four concatenated branches and the shortcut projection;
        # neither has its own activation because ReLU is applied after the sum.
        self.c5 = conv2d(out_c*4, out_c, kernel_size=3, padding=1, act=False)
        self.c6 = conv2d(in_c, out_c, kernel_size=1, padding=0, act=False)
        self.sa = spatial_attention()

    def forward(self, x):
        """Return the attended multi-rate features with ``out_c`` channels."""
        branch_outputs = [branch(x) for branch in (self.c1, self.c2, self.c3, self.c4)]
        fused = self.c5(torch.cat(branch_outputs, dim=1))
        shortcut = self.c6(x)
        return self.sa(self.relu(fused + shortcut))

class label_attention(nn.Module):
    """Channel gate for feature maps derived from a label feature vector.

    Args:
        in_c: Two-element sequence ``[feature_channels, label_channels]``.
    """

    def __init__(self, in_c):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        feat_channels, label_channels = in_c[0], in_c[1]
        # Two bias-free 1x1 convs mapping the label vector to one logit per
        # feature channel.
        self.c1 = nn.Sequential(
            nn.Conv2d(label_channels, feat_channels, kernel_size=1, padding=0, bias=False),
            nn.ReLU(),
            nn.Conv2d(feat_channels, feat_channels, kernel_size=1, padding=0, bias=False),
        )

    def forward(self, feats, label):
        """Gate ``feats`` with the label and return ``(logits, gated_feats)``.

        Args:
            feats: Feature map with ``in_c[0]`` channels.
            label: 2-D tensor ``(batch, in_c[1])``.

        Returns:
            The pre-sigmoid channel logits reshaped to ``(batch, in_c[0])`` and
            the feature map multiplied by the sigmoid of those logits.
        """
        batch, channels = label.shape
        logits = self.c1(label.reshape(batch, channels, 1, 1))
        gated = feats * torch.sigmoid(logits)
        return logits.reshape(logits.shape[0], logits.shape[1]), gated

class decoder_block(nn.Module):
    """Upsample, merge with a skip connection and refine with dense residuals.

    Args:
        in_c: Channels of the incoming (coarser) feature map.
        out_c: Channels of the skip connection and of the block output.
        scale: Upsampling factor applied to the incoming feature map.
    """

    def __init__(self, in_c, out_c, scale=2):
        super().__init__()
        self.scale = scale
        self.relu = nn.ReLU(inplace=True)

        self.up = nn.Upsample(scale_factor=scale, mode="bilinear", align_corners=True)
        # 1x1 conv that squeezes the concatenated [upsampled, skip] channels.
        self.c1 = conv2d(in_c + out_c, out_c, kernel_size=1, padding=0)
        # Residual branches: no activation of their own, ReLU follows each sum.
        self.c2 = conv2d(out_c, out_c, act=False)
        self.c3 = conv2d(out_c, out_c, act=False)
        self.c4 = conv2d(out_c, out_c, kernel_size=1, padding=0, act=False)
        self.ca = channel_attention(out_c)
        self.sa = spatial_attention()

    def forward(self, x, skip):
        """Return the refined features at the spatial size of ``skip``."""
        upsampled = self.up(x)

        # The fixed-ratio upsample may not land on the skip resolution; resize
        # explicitly when it does not.
        target_hw = skip.shape[2:]
        if upsampled.shape[2:] != target_hw:
            upsampled = F.interpolate(upsampled, size=target_hw, mode="bilinear", align_corners=True)

        stem = self.c1(torch.cat([upsampled, skip], dim=1))

        # Each stage adds the outputs of all previous stages before the ReLU.
        stage1 = self.relu(self.c2(stem) + stem)
        stage2 = self.relu(self.c3(stage1) + stage1 + stem)
        stage3 = self.relu(self.c4(stage2) + stage2 + stage1 + stem)

        return self.sa(self.ca(stage3))

class output_block(nn.Module):
    """Final head: 2x bilinear upsampling followed by a 1x1 convolution.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels (the head emits raw conv outputs,
            no activation is applied).
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.c1 = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0)

    def forward(self, x):
        """Return the projection of the 2x-upsampled input."""
        return self.c1(self.up(x))

class text_classifier(nn.Module):
    """Two classification heads on top of globally average-pooled features.

    Args:
        in_c: Number of channels of the incoming feature map.
        out_c: Two-element sequence with the number of classes of the first
            (lesion count) and second (lesion size) head.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        def make_head(num_classes):
            # Bias-free bottleneck MLP: in_c -> in_c // 8 -> num_classes
            return nn.Sequential(
                nn.Linear(in_c, in_c//8, bias=False),
                nn.ReLU(),
                nn.Linear(in_c//8, num_classes, bias=False),
            )

        self.fc1 = make_head(out_c[0])
        self.fc2 = make_head(out_c[1])

    def forward(self, feats):
        """Return ``(num_lesions, lesion_sizes)`` logits for a feature map."""
        batch, channels = feats.shape[0], feats.shape[1]
        pooled = self.avg_pool(feats).view(batch, channels)
        return self.fc1(pooled), self.fc2(pooled)

class embedding_feature_fusion(nn.Module):
    """Fuse predicted class probabilities with per-class word embeddings.

    Every word embedding is scaled by the softmax probability of its class,
    the scaled embeddings are flattened into one vector and projected by two
    bias-free 1x1 convolutions.

    Args:
        in_c: ``[num_count_classes, num_size_classes, embedding_dim]``.
        out_c: Size of the fused output vector.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        self.fc = nn.Sequential(
            nn.Conv2d((in_c[0]+in_c[1])*in_c[2], out_c, 1, bias=False), nn.ReLU(),
            nn.Conv2d(out_c, out_c, 1, bias=False), nn.ReLU()
        )

    def forward(self, num_lesions, lesion_sizes, label):
        """Return the fused label feature of shape ``(batch, out_c)``.

        Args:
            num_lesions: Logits ``(batch, in_c[0])``.
            lesion_sizes: Logits ``(batch, in_c[1])``.
            label: Word embeddings ``(batch, in_c[0] + in_c[1], in_c[2])``.

        Raises:
            ValueError: If the number of embeddings in ``label`` differs from
                the number of predicted classes, or if the flattened size does
                not match what the projection layers were built for.
        """
        num_lesions_prob = torch.softmax(num_lesions, dim=1)
        lesion_sizes_prob = torch.softmax(lesion_sizes, dim=1)

        prob = torch.cat([num_lesions_prob, lesion_sizes_prob], dim=1)
        prob = prob.view(prob.shape[0], prob.shape[1], 1)  # broadcast over the embedding dim
        if label.shape[1] != prob.shape[1]:
          raise ValueError(
              f"Shape mismatch: label channels ({label.shape[1]}) != prob channels ({prob.shape[1]})"
          )

        x = label * prob
        x = x.view(x.shape[0], -1, 1, 1)

        # Validate against the width the projection was actually built with.
        # The previous fallback created a brand-new, randomly initialised and
        # unregistered Conv2d on every call (untrained, non-deterministic) and
        # compared against a hard-coded 1800 regardless of `in_c`.
        expected_channels = self.fc[0].in_channels
        if x.shape[1] != expected_channels:
            raise ValueError(
                f"Shape mismatch: flattened label features ({x.shape[1]}) != "
                f"expected input channels ({expected_channels})"
            )

        x = self.fc(x)
        x = x.view(x.shape[0], -1)
        return x

class multiscale_feature_aggregation(nn.Module):
    """Bring four feature maps to one resolution, fuse and refine them.

    ``x1``, ``x2`` and ``x3`` are upsampled by fixed factors of 8, 4 and 2;
    ``x4`` is resized to whatever resolution ``x3`` reaches. Each map is
    projected to ``out_c`` channels, the four are concatenated, squeezed by a
    1x1 conv and refined by two residual 3x3 convs.

    Args:
        in_c: Sequence with the channel counts of ``x1`` .. ``x4``.
        out_c: Number of output channels.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        self.up_2x2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.up_4x4 = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True)
        self.up_8x8 = nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)

        # Per-input 1x1 projections and the 1x1 fusion of their concatenation.
        self.c11 = conv2d(in_c[0], out_c, kernel_size=1, padding=0)
        self.c12 = conv2d(in_c[1], out_c, kernel_size=1, padding=0)
        self.c13 = conv2d(in_c[2], out_c, kernel_size=1, padding=0)
        self.c14 = conv2d(in_c[3], out_c, kernel_size=1, padding=0)
        self.c15 = conv2d(out_c * 4, out_c, kernel_size=1, padding=0)

        # Residual refinement branches (ReLU is applied after each sum).
        self.c2 = conv2d(out_c, out_c, act=False)
        self.c3 = conv2d(out_c, out_c, act=False)

    def forward(self, x1, x2, x3, x4):
        """Return the aggregated feature map with ``out_c`` channels.

        Args:
            x1: Coarsest map, upsampled 8x.
            x2: Map upsampled 4x.
            x3: Map upsampled 2x.
            x4: Map resized to the resolution of the upsampled ``x3``.
        """
        x1 = self.up_8x8(x1)
        x2 = self.up_4x4(x2)
        x3 = self.up_2x2(x3)
        # Resize to the common target resolution instead of a fixed (128, 128):
        # identical for 256x256 network inputs, and it keeps the concatenation
        # below valid for other input sizes.
        x4 = F.interpolate(x4, size=x3.shape[2:], mode="bilinear", align_corners=True)

        x1 = self.c11(x1)
        x2 = self.c12(x2)
        x3 = self.c13(x3)
        x4 = self.c14(x4)

        x = torch.cat([x1, x2, x3, x4], dim=1)
        x = self.c15(x)

        s1 = x
        x = self.c2(x)
        x = self.relu(x + s1)

        s2 = x
        x = self.c3(x)
        x = self.relu(x + s2 + s1)

        return x

class ProgressiveDenoisingAttention1(nn.Module):
    """Single-channel attention gate produced by a stack of conv refinements.

    Args:
        channels: Number of channels of the input (kept by every stage).
        num_iterations: Number of Conv-BN-ReLU refinement stages.
    """

    def __init__(self, channels, num_iterations=3):
        super(ProgressiveDenoisingAttention1, self).__init__()
        self.num_iterations = num_iterations
        self.channels = channels

        # One Conv3x3 -> BatchNorm -> ReLU stage per iteration.
        stages = []
        for _ in range(num_iterations):
            stages.append(
                nn.Sequential(
                    nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(channels),
                    nn.ReLU(inplace=True),
                )
            )
        self.conv_layers = nn.ModuleList(stages)

        # Collapse the refined features into a one-channel map for the gate.
        self.final_conv = nn.Conv2d(channels, 1, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by the sigmoid attention map (broadcast over channels)."""
        refined = x
        for stage in self.conv_layers:
            refined = stage(refined)

        gate = self.sigmoid(self.final_conv(refined))
        return x * gate

class UNetDenoise(nn.Module):
    """Small convolutional encoder/decoder applied to a feature map.

    The encoder halves the spatial size with a strided convolution and the
    decoder restores it with a transposed convolution; there are no skip
    connections between the two.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the returned feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        encoder_layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # downsample x2
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # upsample x2
            nn.ReLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``decoder(encoder(x))``."""
        return self.decoder(self.encoder(x))

class EnhancedDilatedConv(nn.Module):
    """Dilated-conv block combined with a weighted branch on the inverted input.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        neg_weight: Scale applied to the inverted-input branch before it is
            added to the main branch.
    """

    def __init__(self, in_c, out_c, neg_weight=0.3):
        super().__init__()
        self.neg_weight = neg_weight
        self.main_conv = dilated_conv(in_c, out_c)

        # Branch fed with ``1 - x``: one dilated conv followed by channel and
        # spatial attention.
        inverted_branch = [
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=6, dilation=6),
            channel_attention(out_c),
            spatial_attention(),
        ]
        self.negative_conv = nn.Sequential(*inverted_branch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``ReLU(main(x) + neg_weight * negative(1 - x))``."""
        positive = self.main_conv(x)
        negative = self.negative_conv(1 - x)
        combined = positive + self.neg_weight * negative
        return self.relu(combined)


class TGAlesionSeg(nn.Module):
    """Text-guided lesion segmentation network on an SE-ResNeXt50 encoder.

    Encoder features feed (a) an auxiliary classifier predicting lesion-count
    and lesion-size logits and (b) a four-stage decoder. The classifier
    probabilities weight the supplied word embeddings into a label feature
    that gates the decoder features channel-wise at every stage.

    Args:
        pretrained: Forwarded to ``timm.create_model``. Defaults to ``True``
            (the previously hard-coded value). Pass ``False`` to skip loading
            pretrained backbone weights, e.g. when a full checkpoint is loaded
            right after construction.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        # Backbone: SE-ResNeXt50
        backbone = timm.create_model('seresnext50_32x4d', pretrained=pretrained)

        # Stem (conv + norm + activation + max-pool) and the four backbone stages.
        self.layer0 = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.act1,
            backbone.maxpool
        )
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Auxiliary classifier on the deepest (2048-channel) features: 3 count
        # classes + 3 size classes, fused with 6 word embeddings of size 300.
        self.text_classifier = text_classifier(2048, [3, 3])
        self.label_fc = embedding_feature_fusion([3, 3, 300], 128)

        # Dilated conv blocks: one per encoder level, all projecting to 128 channels.
        self.s1 = EnhancedDilatedConv(64, 128)
        self.s2 = EnhancedDilatedConv(256, 128)
        self.s3 = EnhancedDilatedConv(512, 128)
        self.s4 = EnhancedDilatedConv(1024, 128)
        self.s5 = EnhancedDilatedConv(2048, 128)

        # Progressive denoising attention, one per encoder level.
        self.pda1 = ProgressiveDenoisingAttention1(128, num_iterations=3)
        self.pda2 = ProgressiveDenoisingAttention1(128, num_iterations=3)
        self.pda3 = ProgressiveDenoisingAttention1(128, num_iterations=3)
        self.pda4 = ProgressiveDenoisingAttention1(128, num_iterations=3)
        self.pda5 = ProgressiveDenoisingAttention1(128, num_iterations=3)

        # Denoising modules, one per decoder stage.
        self.denoise1 = UNetDenoise(in_channels=128, out_channels=128)
        self.denoise2 = UNetDenoise(in_channels=128, out_channels=128)
        self.denoise3 = UNetDenoise(in_channels=128, out_channels=128)
        self.denoise4 = UNetDenoise(in_channels=128, out_channels=128)

        # Decoder: each stage is a decoder block followed by label attention.
        self.d1 = decoder_block(128, 128, scale=2)
        self.a1 = label_attention([128, 128])

        self.d2 = decoder_block(128, 128, scale=2)
        self.a2 = label_attention([128, 128])

        self.d3 = decoder_block(128, 128, scale=2)
        self.a3 = label_attention([128, 128])

        self.d4 = decoder_block(128, 128, scale=2)
        self.a4 = label_attention([128, 128])

        # Aggregates the four denoised decoder outputs.
        self.ag = multiscale_feature_aggregation([128, 128, 128, 128], 128)

        self.y1 = output_block(128, 1)

    def forward(self, image, label):
        """Segment ``image`` guided by the word embeddings in ``label``.

        Args:
            image: Image batch fed to the backbone stem.
            label: Word-embedding tensor. ``label_fc`` is configured for six
                words (3 count + 3 size classes) of dimension 300, i.e.
                ``(batch, 6, 300)``.

        Returns:
            Tuple ``(y1, num_lesions, lesion_sizes)``: single-channel mask
            logits (no sigmoid is applied here) and the logits of the two
            auxiliary classification heads.
        """
        # Backbone: SE-ResNeXt50. Channel counts below match the input widths
        # of s1..s5.
        # NOTE(review): the original comments listed spatial sizes h/2 .. h/32,
        # but layer0 already contains both conv1 and maxpool, so x1 and x2 are
        # presumably both at h/4 -- confirm against the timm model.
        x0 = image
        x1 = self.layer0(x0)    ## 64 channels
        x2 = self.layer1(x1)    ## 256 channels
        x3 = self.layer2(x2)    ## 512 channels
        x4 = self.layer3(x3)    ## 1024 channels
        x5 = self.layer4(x4)    ## 2048 channels

        # Auxiliary predictions and the label feature f0 of shape (batch, 128).
        num_lesions, lesion_sizes = self.text_classifier(x5)
        f0 = self.label_fc(num_lesions, lesion_sizes, label)

        # Dilated Conv + PDA
        s1 = self.pda1(self.s1(x1))
        s2 = self.pda2(self.s2(x2))
        s3 = self.pda3(self.s3(x3))
        s4 = self.pda4(self.s4(x4))
        s5 = self.pda5(self.s5(x5))

        # Decoder + Denoising. Each label_attention returns
        # (channel logits f_i, gated feature map a_i); the logits of earlier
        # stages are accumulated into the label feature of later stages.
        # Stage 1
        d1 = self.d1(s5, s4)
        f1, a1 = self.a1(d1, f0)
        a1_denoised = self.denoise1(a1)

        # Stage 2
        d2 = self.d2(a1_denoised, s3)
        f = f0 + f1
        f2, a2 = self.a2(d2, f)
        a2_denoised = self.denoise2(a2)

        # Stage 3
        d3 = self.d3(a2_denoised, s2)
        f = f0 + f1 + f2
        f3, a3 = self.a3(d3, f)
        a3_denoised = self.denoise3(a3)

        # Stage 4 (f4 is produced but not used afterwards)
        d4 = self.d4(a3_denoised, s1)
        f = f0 + f1 + f2 + f3
        f4, a4 = self.a4(d4, f)
        a4_denoised = self.denoise4(a4)

        ag = self.ag(a1_denoised, a2_denoised, a3_denoised, a4_denoised)
        y1 = self.y1(ag)

        return y1, num_lesions, lesion_sizes

def prepare_input(res):
    """Build dummy model inputs (e.g. for a complexity/FLOPs counter).

    Args:
        res: Requested input resolution; currently ignored, the image is always
            ``(1, 3, 256, 256)``.

    Returns:
        ``dict(x=[image, label])`` with a zero image of shape
        ``(1, 3, 256, 256)`` and zero label embeddings of shape ``(1, 6, 300)``.
        The six label rows match the six words of ``label_dictionary`` that the
        model's label fusion requires (the previous value of 5 made the model
        raise a shape-mismatch ``ValueError``).

    NOTE(review): the ``x=[...]`` key does not match
    ``TGAlesionSeg.forward(image, label)``; confirm how the caller unpacks it.
    """
    # Fall back to CPU instead of failing on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # torch.zeros rather than torch.FloatTensor(...): the latter returns
    # uninitialised memory that may contain NaN/Inf.
    x1 = torch.zeros(1, 3, 256, 256, dtype=torch.float32, device=device)
    x2 = torch.zeros(1, 6, 300, dtype=torch.float32, device=device)
    return dict(x = [x1, x2])

### --------------------------------------------------------------------------------------------------------------------------------------------------------

def load_names(path, file_path):
    """Read sample names from a listing file and build image/mask paths.

    Args:
        path: Dataset root containing the ``images`` and ``masks`` folders.
        file_path: Text file with one sample name (without extension) per line.

    Returns:
        Tuple ``(images, masks)`` of equally long lists of ``.png`` paths.
    """
    # The context manager closes the handle. Blank lines are skipped, so a
    # missing trailing newline no longer drops the last name and empty lines
    # no longer yield a bogus "<dir>/.png" entry.
    with open(file_path, "r") as f:
        data = [line.strip() for line in f]
    data = [name for name in data if name]

    images = [os.path.join(path,"images", name) + ".png" for name in data]
    masks = [os.path.join(path,"masks", name) + ".png" for name in data]
    return images, masks

def label_dictionary():
    """Return a new dict mapping ``"lesion"`` to its six descriptor words.

    The first three words describe the lesion count, the last three its size.
    """
    return {"lesion": ["zero", "one", "multiple", "small", "medium", "large"]}

def load_data(path):
    """Load the train/validation file lists and attach the label words to each sample.

    Args:
        path: Dataset root; ``train.txt`` and ``val.txt`` are read from it.

    Returns:
        tuple: ``(train_x, train_y, train_label), (valid_x, valid_y, valid_label)``
        where the ``*_label`` lists repeat the same word list once per sample.
    """
    train_x, train_y = load_names(path, f"{path}/train.txt")
    valid_x, valid_y = load_names(path, f"{path}/val.txt")

    label_dict = label_dictionary()
    print(label_dict)
    words = label_dict["lesion"]

    # Every sample shares the same word-list object.
    train_split = (train_x, train_y, [words] * len(train_x))
    valid_split = (valid_x, valid_y, [words] * len(valid_x))
    return train_split, valid_split

class DATASET(Dataset):
    """Dataset pairing images with masks, label-word embeddings and lesion classes.

    Each item is ``((image, label), (mask, num_lesions, lesion_sizes))``:

    * ``image``: resized to ``size``, channel-first, divided by 255.
    * ``label``: one embedding per word of ``labels_path[index]``, stacked.
    * ``mask``: resized to ``size``, with a leading channel axis, divided by 255.
    * ``num_lesions`` / ``lesion_sizes``: class indices derived from the bounding
      boxes that ``mask_to_bbox`` finds in the resized mask.
    """

    def __init__(self, images_path, labels_path, masks_path, size, transform=None):
        """
        Args:
            images_path: Sequence of image file paths.
            labels_path: Sequence whose i-th entry is the list of label words of sample i.
            masks_path: Sequence of mask file paths aligned with ``images_path``.
            size: Target size handed to ``cv2.resize`` for both image and mask.
            transform: Optional callable invoked as ``transform(image=..., mask=...)``
                and returning a mapping with ``"image"`` and ``"mask"`` entries.
        """
        super().__init__()
        self.images_path = images_path
        self.labels_path = labels_path
        self.masks_path = masks_path
        self.size = size
        self.transform = transform
        self.n_samples = len(images_path)

        self.embed = Text2Embed()

    def visualize_bbox_on_mask(self, image, mask, bboxes, title="Mask with BBoxes"):
        """Show the image next to the mask with the bounding boxes drawn on it.

        Args:
            image: Three-channel image array (converted BGR -> RGB for display).
            mask: Single-channel mask array.
            bboxes: Iterable of ``(x1, y1, x2, y2)`` boxes.
            title: Title of the mask panel.

        Raises:
            ValueError: If ``image`` is not a 3-D array with three channels.
        """
        if len(image.shape) != 3 or image.shape[2] != 3:
            raise ValueError(f"Expected a color image with 3 channels, but got shape {image.shape}")

        fig, ax = plt.subplots(1, 2, figsize=(12, 6))

        # Left panel: original image.
        ax[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        ax[0].set_title("Original Image")
        ax[0].axis("off")

        # Right panel: mask with the boxes drawn on top.
        mask_with_bbox = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
        for bbox in bboxes:
            x1, y1, x2, y2 = bbox
            mask_with_bbox = cv2.rectangle(mask_with_bbox, (x1, y1), (x2, y2), (255, 0, 0), 2)

        ax[1].imshow(cv2.cvtColor(mask_with_bbox, cv2.COLOR_BGR2RGB))
        ax[1].set_title(title)
        ax[1].axis("off")

        plt.tight_layout()
        plt.show()

    def mask_to_text(self, mask, image=None, visualize=False):
        """Derive the lesion-count and lesion-size class indices from a mask.

        Args:
            mask: 2-D mask array handed to ``mask_to_bbox``.
            image: Optional three-channel image, used only for visualisation.
            visualize: When true and ``image`` is given, plot the boxes.

        Returns:
            tuple[np.ndarray, np.ndarray]: 0-d arrays ``(num_lesions, lesion_sizes)``.
            ``num_lesions`` is 0, 1 or 2 (two or more boxes). ``lesion_sizes`` is
            0 (< 0.007 of the mask area), 1 (< 0.04) or 2 (>= 0.04).
        """
        bboxes = mask_to_bbox(mask)
        mask_area = mask.shape[0] * mask.shape[1]

        # NOTE(review): the size class of the *last* box wins, as in the original
        # loop -- confirm this is intended when several boxes are present.
        lesion_sizes = 0
        for x1, y1, x2, y2 in bboxes:
            area = ((y2 - y1) * (x2 - x1)) / mask_area
            if area < 0.007:
                lesion_sizes = 0
            elif area < 0.04:
                lesion_sizes = 1
            else:
                lesion_sizes = 2

        # 0 boxes -> 0, 1 box -> 1, two or more -> 2.
        num_lesions = min(len(bboxes), 2)

        # The previous guard was inverted (`if image is None`), so calling this
        # method without an image crashed inside the visualiser. Plotting is now
        # opt-in and only attempted when an image is actually available.
        if visualize and image is not None:
            self.visualize_bbox_on_mask(image, mask, bboxes, title="Mask with Bounding Boxes")

        return np.array(num_lesions), np.array(lesion_sizes)

    def __getitem__(self, index):
        """Load, augment and normalise sample ``index``.

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
            ValueError: If the resized image is not a three-channel array.
        """
        image = cv2.imread(self.images_path[index], cv2.IMREAD_COLOR)
        mask = cv2.imread(self.masks_path[index], cv2.IMREAD_GRAYSCALE)

        if image is None or mask is None:
            raise FileNotFoundError(f"Could not read image or mask at index {index}. Check the file paths.")

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        image = cv2.resize(image, self.size)
        if len(image.shape) != 3 or image.shape[2] != 3:
            raise ValueError(f"Expected a color image with 3 channels after resizing, but got shape {image.shape}")

        image = np.transpose(image, (2, 0, 1))
        image = image / 255.0

        # NOTE(review): the mask is resized with cv2.resize's default
        # interpolation; confirm soft edge values are acceptable for the targets.
        mask_resized = cv2.resize(mask, self.size)
        mask = np.expand_dims(mask_resized, axis=0) / 255.0

        # Class targets come from the un-normalised resized mask. No image is
        # passed because visualisation is never wanted while loading batches.
        num_lesions, lesion_sizes = self.mask_to_text(mask_resized)

        # One embedding per label word.
        label = np.array([self.embed.to_embed(word)[0] for word in self.labels_path[index]])

        return (image, label), (mask, num_lesions, lesion_sizes)

    def __len__(self):
        """Number of samples (length of ``images_path`` at construction)."""
        return self.n_samples


def train(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Called as ``model(x, l)``; must return three outputs ``(p1, p2, p3)``.
        loader: Sized iterable of ``((x, l), (y1, y2, y3))`` batches.
        optimizer: Optimiser stepped once per batch.
        loss_fn: Indexable with three callables. ``loss_fn[0]`` receives
            ``(p1, y1, softmax(p2), y2)``; ``loss_fn[1]`` and ``loss_fn[2]``
            receive ``(softmax(p2), y2)`` and ``(softmax(p3), y3)``.
        device: Device the batch tensors are moved to.

    Returns:
        tuple: ``(mean_loss, [jaccard, f1, recall, precision])``, each averaged
        over the number of batches.
    """
    model.train()
    running_loss = 0.0
    # Running sums of the per-batch means: jaccard, f1, recall, precision.
    running_scores = [0.0, 0.0, 0.0, 0.0]

    for (x, l), (y1, y2, y3) in loader:
        x = x.to(device, dtype=torch.float32)
        l = l.to(device, dtype=torch.float32)
        y1 = y1.to(device, dtype=torch.float32)
        y2 = y2.to(device, dtype=torch.long)
        y3 = y3.to(device, dtype=torch.long)

        optimizer.zero_grad()

        p1, count_out, size_out = model(x, l)
        # NOTE(review): the classification outputs are soft-maxed before being
        # handed to the losses -- confirm those losses expect probabilities.
        p2 = torch.softmax(count_out, dim=1)
        p3 = torch.softmax(size_out, dim=1)

        loss = (
            loss_fn[0](p1, y1, p2, y2)
            + loss_fn[1](p2, y2)
            + loss_fn[2](p3, y3)
        )

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Score every sample of the batch, then add the batch mean of each metric.
        sample_scores = [calculate_metrics(yt, yp) for yt, yp in zip(y1, p1)]
        for k in range(len(running_scores)):
            running_scores[k] += np.mean([score[k] for score in sample_scores])

    n_batches = len(loader)
    epoch_loss = running_loss / n_batches
    return epoch_loss, [total / n_batches for total in running_scores]

def evaluate(model, loader, loss_fn, device):
    """Score ``model`` on ``loader`` without updating any weights.

    Args:
        model: Called as ``model(x, l)``; must return three outputs ``(p1, p2, p3)``.
        loader: Sized iterable of ``((x, l), (y1, y2, y3))`` batches.
        loss_fn: Indexable with three callables, called as ``loss_fn[0](p1, y1)``,
            ``loss_fn[1](softmax(p2), y2)`` and ``loss_fn[2](softmax(p3), y3)``.
        device: Device the batch tensors are moved to.

    Returns:
        tuple: ``(mean_loss, [jaccard, f1, recall, precision])``, each averaged
        over the number of batches.
    """
    model.eval()

    running_loss = 0.0
    # Running sums of the per-batch means: jaccard, f1, recall, precision.
    running_scores = [0.0, 0.0, 0.0, 0.0]

    with torch.no_grad():
        for (x, l), (y1, y2, y3) in loader:
            x = x.to(device, dtype=torch.float32)
            l = l.to(device, dtype=torch.float32)
            y1 = y1.to(device, dtype=torch.float32)
            y2 = y2.to(device, dtype=torch.long)
            y3 = y3.to(device, dtype=torch.long)

            p1, count_out, size_out = model(x, l)
            p2 = torch.softmax(count_out, dim=1)
            p3 = torch.softmax(size_out, dim=1)

            loss = loss_fn[0](p1, y1) + loss_fn[1](p2, y2) + loss_fn[2](p3, y3)
            running_loss += loss.item()

            # Score every sample of the batch, then add the batch mean of each metric.
            sample_scores = [calculate_metrics(yt, yp) for yt, yp in zip(y1, p1)]
            for k in range(len(running_scores)):
                running_scores[k] += np.mean([score[k] for score in sample_scores])

    n_batches = len(loader)
    epoch_loss = running_loss / n_batches
    return epoch_loss, [total / n_batches for total in running_scores]


if __name__ == "__main__":
    # Training entry point: set up logging and data, build the model and the
    # optimiser, then run the epoch loop with F1-based checkpointing and
    # early stopping.

    # Seeding
    seeding(42)

    # Directories
    create_dir("files")

    # Training logfile: created on the first run, reused afterwards.
    train_log_path = "files/train_log.txt"
    if os.path.exists(train_log_path):
        print("Log file exists")
    else:
        # The context manager closes the handle even if the write fails.
        with open(train_log_path, "w") as train_log:
            train_log.write("\n")

    # Record Date & Time
    datetime_object = str(datetime.datetime.now())
    print_and_save(train_log_path, datetime_object)
    print("")

    # Hyperparameters
    image_size = 256
    size = (image_size, image_size)
    batch_size = 16
    num_epochs = 70
    lr = 1e-5
    early_stopping_patience = 50
    checkpoint_path = "files/checkpoint1.pth"
    path = "/content/Pneumothorax"

    data_str = f"Image Size: {size}\nBatch Size: {batch_size}\nLR: {lr}\nEpochs: {num_epochs}\n"
    data_str += f"Early Stopping Patience: {early_stopping_patience}\n"
    print_and_save(train_log_path, data_str)

    # Data augmentation: applied to the training split only.
    transform = A.Compose([
        A.Rotate(limit=15, p=0.5),  # Moderate rotations
        A.HorizontalFlip(p=0.5),    # Horizontal flipping
        A.VerticalFlip(p=0.2),      # Vertical flipping
        A.RandomBrightnessContrast(p=0.5),  # Adjust brightness and contrast
        A.CLAHE(p=0.6),  # Contrast Limited Adaptive Histogram Equalization
        A.ElasticTransform(alpha=1, sigma=50, p=0.3),  # Elastic transformations
        A.CoarseDropout(num_holes_range=(2, 5), hole_height_range=(5, 20), hole_width_range=(5, 20), fill="inpaint_ns", p=0.5)  # Coarse dropout
    ])

    # Dataset
    (train_x, train_y, train_label), (valid_x, valid_y, valid_label) = load_data(path)
    train_x, train_y, train_label = shuffling(train_x, train_y, train_label)
    data_str = f"Dataset Size:\nTrain: {len(train_x)} - Valid: {len(valid_x)}\n"
    print_and_save(train_log_path, data_str)

    # Dataset and loader
    train_dataset = DATASET(train_x, train_label, train_y, size, transform=transform)
    valid_dataset = DATASET(valid_x, valid_label, valid_y, size, transform=None)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4
    )

    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )

    # Model: fall back to the CPU when no GPU is present instead of crashing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TGAlesionSeg()
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # The scheduler is advanced once per epoch (see the loop below), so the
    # half-cycle length is expressed in epochs: 2 epochs up, 2 epochs down.
    scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer,
        base_lr=lr,
        max_lr=1e-4,
        step_size_up=2,
        mode='triangular',
        cycle_momentum=False
    )

    # Segmentation loss plus one focal loss per classification head.
    loss_fn = [
        CascadingLoss(beta_dice=8, beta_focal=7),  # Primary segmentation loss
        FocalLoss(alpha=0.25, gamma=2),  # Num lesions classification
        FocalLoss(alpha=0.25, gamma=2)   # lesion size classification
    ]

    # The log previously claimed "BCE Dice Loss", which is not what is configured above.
    loss_name = "CascadingLoss + FocalLoss (num lesions) + FocalLoss (lesion size)"
    data_str = f"Optimizer: AdamW\nLoss: {loss_name}\n"
    print_and_save(train_log_path, data_str)

    # Training the model
    best_valid_metrics = 0.0
    early_stopping_count = 0

    for epoch in range(num_epochs):
        start_time = time.time()

        train_loss, train_metrics = train(model, train_loader, optimizer, loss_fn, device)
        valid_loss, valid_metrics = evaluate(model, valid_loader, loss_fn, device)
        # CyclicLR.step() takes no metric; the old `step(valid_loss)` (left over
        # from ReduceLROnPlateau) fed the loss in as the epoch counter.
        scheduler.step()

        if valid_metrics[1] > best_valid_metrics:
            data_str = f"Valid F1 improved from {best_valid_metrics:2.4f} to {valid_metrics[1]:2.4f}. Saving checkpoint: {checkpoint_path}"
            print_and_save(train_log_path, data_str)

            best_valid_metrics = valid_metrics[1]
            torch.save(model.state_dict(), checkpoint_path)
            early_stopping_count = 0

        elif valid_metrics[1] < best_valid_metrics:
            # A tie with the best score neither resets nor advances the counter.
            early_stopping_count += 1

        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        data_str = f"Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n"
        data_str += f"\tTrain Loss: {train_loss:.4f} - Jaccard: {train_metrics[0]:.4f} - F1: {train_metrics[1]:.4f} - Recall: {train_metrics[2]:.4f} - Precision: {train_metrics[3]:.4f}\n"
        data_str += f"\t Val. Loss: {valid_loss:.4f} - Jaccard: {valid_metrics[0]:.4f} - F1: {valid_metrics[1]:.4f} - Recall: {valid_metrics[2]:.4f} - Precision: {valid_metrics[3]:.4f}\n"
        print_and_save(train_log_path, data_str)

        if early_stopping_count >= early_stopping_patience:
            # The monitored quantity is validation F1, not the loss.
            data_str = f"Early stopping: validation F1 stayed below its best for {early_stopping_patience} epochs.\n"
            print_and_save(train_log_path, data_str)
            break


2025-03-10 22:13:27.681866

Image Size: (256, 256)
Batch Size: 16
LR: 1e-05
Epochs: 70
Early Stopping Patience: 50

{'polyp': ['zero', 'one', 'multiple', 'small', 'medium', 'large']}
Dataset Size:
Train: 9636 - Valid: 2409

downloading https://nlp.h-its.org/bpemb/en/en.wiki.bpe.vs100000.model


100%|██████████| 1987533/1987533 [00:00<00:00, 2039043.56B/s]


downloading https://nlp.h-its.org/bpemb/en/en.wiki.bpe.vs100000.d300.w2v.bin.tar.gz


100%|██████████| 112159933/112159933 [00:07<00:00, 15093736.52B/s]
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/111M [00:00<?, ?B/s]

Optimizer: AdamW
Loss: BCE Dice Loss



In [None]:
import os
import random
import time
import datetime
import numpy as np
import albumentations as A
import cv2
from glob import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from text2embed import Text2Embed
from utils import seeding, create_dir, print_and_save, shuffling, epoch_time, calculate_metrics, mask_to_bbox
from metrics import DiceLoss, DiceBCELoss, MultiClassBCE, FocalLoss
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from resnet import resnet50
import torch.nn.functional as F
# from torch.amp import autocast, GradScaler
import warnings
import torch.cuda.amp as amp
import torch.nn.init as init

class conv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an optional ReLU.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel_size: Convolution kernel size.
        padding: Padding passed to the convolution.
        dilation: Dilation passed to the convolution.
        act: When truthy, ReLU is applied after batch normalisation.
    """

    def __init__(self, in_c, out_c, kernel_size=3, padding=1, dilation=1, act=True):
        super().__init__()
        self.act = act

        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``BN(conv(x))``, passed through ReLU when ``act`` is set."""
        x = self.conv(x)
        # Plain truthiness test instead of the non-idiomatic `== True` comparison.
        if self.act:
            x = self.relu(x)
        return x

class channel_attention(nn.Module):
    """Channel gate built from globally average- and max-pooled descriptors.

    Both pooled descriptors pass through the same two 1x1 convolutions; their
    sum is squashed with a sigmoid and multiplied onto the input.

    Args:
        in_planes: Number of input (and output) channels.
        ratio: Channel reduction factor of the bottleneck (``in_planes // ratio``).
    """

    def __init__(self, in_planes, ratio=16):
        super(channel_attention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # `ratio` used to be ignored in favour of a hard-coded 16; the default
        # keeps the previous layer sizes.
        hidden = in_planes // ratio
        self.fc1   = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(hidden, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Scale every channel of ``x`` by its gate value in (0, 1)."""
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return x * self.sigmoid(avg_out + max_out)


class spatial_attention(nn.Module):
    """Spatial gate computed from the channel-wise mean and max of the input.

    Args:
        kernel_size: Size of the gating convolution; only 3 and 7 are accepted.
    """

    def __init__(self, kernel_size=7):
        super(spatial_attention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # Padding that keeps the spatial size unchanged for the two allowed kernels.
        padding = 1 if kernel_size != 7 else 3

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Multiply ``x`` by a single-channel sigmoid map broadcast over channels."""
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        channel_max = torch.max(x, dim=1, keepdim=True)[0]
        stats = torch.cat([channel_mean, channel_max], dim=1)
        gate = self.sigmoid(self.conv1(stats))
        return x * gate

class dilated_conv(nn.Module):
    """Four parallel convolution branches (1x1 and 3x3 at dilation 6/12/18).

    Every branch ends in channel attention. The branches are concatenated,
    fused by a 3x3 convolution, added to a 1x1 projection of the input, passed
    through ReLU and finally through spatial attention.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        def atrous_branch(rate):
            # 3x3 convolution whose padding equals its dilation, so the spatial size is kept.
            return nn.Sequential(
                conv2d(in_c, out_c, kernel_size=(3, 3), padding=rate, dilation=rate),
                channel_attention(out_c),
            )

        self.c1 = nn.Sequential(conv2d(in_c, out_c, kernel_size=1, padding=0), channel_attention(out_c))
        self.c2 = atrous_branch(6)
        self.c3 = atrous_branch(12)
        self.c4 = atrous_branch(18)
        self.c5 = conv2d(out_c*4, out_c, kernel_size=3, padding=1, act=False)  # fuses the four branches
        self.c6 = conv2d(in_c, out_c, kernel_size=1, padding=0, act=False)     # shortcut projection
        self.sa = spatial_attention()

    def forward(self, x):
        """Return the attended multi-dilation features of ``x``."""
        branches = [branch(x) for branch in (self.c1, self.c2, self.c3, self.c4)]
        fused = self.c5(torch.cat(branches, dim=1))
        shortcut = self.c6(x)
        return self.sa(self.relu(fused + shortcut))

class label_attention(nn.Module):
    """Gate feature channels with a vector derived from a label embedding.

    Args:
        in_c: Pair ``[feature_channels, label_channels]``.
    """

    def __init__(self, in_c):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        # Two bias-free 1x1 convolutions map the label vector to one logit per feature channel.
        self.c1 = nn.Sequential(
            nn.Conv2d(in_c[1], in_c[0], kernel_size=1, padding=0, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_c[0], in_c[0], kernel_size=1, padding=0, bias=False)
        )

    def forward(self, feats, label):
        """Return ``(channel_logits, gated_feats)``.

        ``label`` must be 2-D ``(batch, label_channels)``. ``channel_logits`` is the
        pre-sigmoid vector of shape ``(batch, feature_channels)``; ``gated_feats`` is
        ``feats`` multiplied by the sigmoid of those logits.
        """
        batch, channels = label.shape
        logits = self.c1(label.reshape(batch, channels, 1, 1))
        gated = feats * torch.sigmoid(logits)
        return logits.reshape(logits.shape[0], logits.shape[1]), gated

class decoder_block(nn.Module):
    """Upsample, merge with a skip connection and refine with dense residual sums.

    Args:
        in_c: Channels of the tensor that is upsampled.
        out_c: Channels of the skip tensor and of the output.
        scale: Bilinear upsampling factor.
    """

    def __init__(self, in_c, out_c, scale=2):
        super().__init__()
        self.scale = scale
        self.relu = nn.ReLU(inplace=True)

        self.up = nn.Upsample(scale_factor=scale, mode="bilinear", align_corners=True)
        self.c1 = conv2d(in_c+out_c, out_c, kernel_size=1, padding=0)  # merges upsampled input and skip
        self.c2 = conv2d(out_c, out_c, act=False)
        self.c3 = conv2d(out_c, out_c, act=False)
        self.c4 = conv2d(out_c, out_c, kernel_size=1, padding=0, act=False)
        self.ca = channel_attention(out_c)
        self.sa = spatial_attention()

    def forward(self, x, skip):
        """Return the refined, attended fusion of the upsampled ``x`` and ``skip``."""
        merged = self.c1(torch.cat([self.up(x), skip], dim=1))

        # Each stage adds every earlier stage output before the ReLU.
        stage1 = self.relu(self.c2(merged) + merged)
        stage2 = self.relu(self.c3(stage1) + stage1 + merged)
        stage3 = self.relu(self.c4(stage2) + stage2 + stage1 + merged)

        return self.sa(self.ca(stage3))

class output_block(nn.Module):
    """Bilinear 2x upsampling followed by a 1x1 convolution.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.c1 = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0)

    def forward(self, x):
        """Return the projection of ``x`` at twice its spatial resolution."""
        return self.c1(self.up(x))

class text_classifier(nn.Module):
    """Two small classification heads on globally average-pooled features.

    Args:
        in_c: Number of channels of the incoming feature map.
        out_c: Pair ``[n_count_classes, n_size_classes]``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        def head(n_classes):
            # Bias-free two-layer perceptron with an in_c // 8 bottleneck.
            return nn.Sequential(
                nn.Linear(in_c, in_c//8, bias=False), nn.ReLU(),
                nn.Linear(in_c//8, n_classes, bias=False)
            )

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = head(out_c[0])
        self.fc2 = head(out_c[1])

    def forward(self, feats):
        """Return the raw outputs ``(num_lesions, lesion_sizes)`` of the two heads."""
        batch, channels = feats.shape[0], feats.shape[1]
        pooled = self.avg_pool(feats).view(batch, channels)
        return self.fc1(pooled), self.fc2(pooled)

class embedding_feature_fusion(nn.Module):
    """Weight label-word embeddings by predicted class probabilities and project them.

    Args:
        in_c: Triple ``[n_count_classes, n_size_classes, embedding_dim]``.
        out_c: Size of the returned feature vector.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        # Flattened size (words x embedding dim) that `fc` consumes.
        self.in_features = (in_c[0] + in_c[1]) * in_c[2]

        self.fc = nn.Sequential(
            nn.Conv2d(self.in_features, out_c, 1, bias=False), nn.ReLU(),
            nn.Conv2d(out_c, out_c, 1, bias=False), nn.ReLU()
        )

    def forward(self, num_lesions, lesion_sizes, label):
        """Return a ``(batch, out_c)`` feature vector.

        Args:
            num_lesions: ``(batch, n_count_classes)`` raw classifier outputs.
            lesion_sizes: ``(batch, n_size_classes)`` raw classifier outputs.
            label: ``(batch, n_words, embedding_dim)`` word embeddings, one word
                per class in the order count classes then size classes.

        Raises:
            ValueError: If the number of words differs from the number of classes,
                or if the flattened embeddings do not match what ``fc`` expects.
        """
        num_lesions_prob = torch.softmax(num_lesions, dim=1)
        lesion_sizes_prob = torch.softmax(lesion_sizes, dim=1)
        prob = torch.cat([num_lesions_prob, lesion_sizes_prob], dim=1)
        prob = prob.view(prob.shape[0], prob.shape[1], 1)
        if label.shape[1] != prob.shape[1]:
            raise ValueError(
                f"Shape mismatch: label channels ({label.shape[1]}) != prob channels ({prob.shape[1]})"
            )

        # Each word embedding is scaled by the probability of its class.
        x = label * prob
        x = x.reshape(x.shape[0], -1, 1, 1)
        # The previous code compared against a hard-coded 1800 and, on mismatch,
        # pushed `x` through a brand-new randomly initialised Conv2d on every call.
        # That layer was never trained or saved, and it broke every configuration
        # whose expected size is not 1800. Fail loudly instead.
        if x.shape[1] != self.in_features:
            raise ValueError(
                f"Shape mismatch: flattened label features ({x.shape[1]}) != expected ({self.in_features})"
            )

        x = self.fc(x)
        return x.view(x.shape[0], -1)

class multiscale_feature_aggregation(nn.Module):
    """Bring three feature maps to a common resolution and fuse them.

    Args:
        in_c: Channel counts of the three inputs, in call order.
        out_c: Number of output channels.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        self.up_2x2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.up_4x4 = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True)

        # One 1x1 projection per input, plus one that fuses their concatenation.
        self.c11 = conv2d(in_c[0], out_c, kernel_size=1, padding=0)
        self.c12 = conv2d(in_c[1], out_c, kernel_size=1, padding=0)
        self.c13 = conv2d(in_c[2], out_c, kernel_size=1, padding=0)
        self.c14 = conv2d(out_c*3, out_c, kernel_size=1, padding=0)

        # Residual refinement convolutions.
        self.c2 = conv2d(out_c, out_c, act=False)
        self.c3 = conv2d(out_c, out_c, act=False)

    def forward(self, x1, x2, x3):
        """Fuse ``x1`` (upsampled 4x), ``x2`` (upsampled 2x) and ``x3`` (as is)."""
        coarse = self.c11(self.up_4x4(x1))
        middle = self.c12(self.up_2x2(x2))
        fine = self.c13(x3)
        fused = self.c14(torch.cat([coarse, middle, fine], dim=1))

        # Two residual stages; the second also re-adds the fused input.
        refined = self.relu(self.c2(fused) + fused)
        return self.relu(self.c3(refined) + refined + fused)

class ProgressiveDenoisingAttention1(nn.Module):
    """Single-channel sigmoid attention refined by a stack of conv-BN-ReLU stages.

    Args:
        channels: Number of input (and output) channels.
        num_iterations: Number of conv-BN-ReLU refinement stages.
    """

    def __init__(self, channels, num_iterations=3):
        super(ProgressiveDenoisingAttention1, self).__init__()
        self.num_iterations = num_iterations
        self.channels = channels

        def refinement_stage():
            # 3x3 convolution that keeps the channel count and spatial size.
            return nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(inplace=True)
            )

        self.conv_layers = nn.ModuleList([refinement_stage() for _ in range(num_iterations)])

        # Collapses the refined features into one attention channel.
        self.final_conv = nn.Conv2d(channels, 1, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by the attention map computed from it."""
        refined = x
        for stage in self.conv_layers:
            refined = stage(refined)

        gate = self.sigmoid(self.final_conv(refined))
        return x * gate

class UNetDenoise(nn.Module):
    """Small convolutional encoder/decoder without skip connections.

    The encoder halves the spatial size with a stride-2 convolution; the decoder
    restores it with a stride-2 transposed convolution and ends in a ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        down_layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*down_layers)

        up_layers = [
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode then decode ``x``."""
        return self.decoder(self.encoder(x))

class TGAlesionSeg(nn.Module):
    """Text-guided lesion segmentation network on the first ResNet50 stages.

    ``forward(image, label)`` returns the segmentation output of ``output_block``
    together with the two raw classification outputs of ``text_classifier``
    (lesion count and lesion size).
    """

    def __init__(self):
        super().__init__()

        # Backbone: ResNet50, split into stem + three residual stages.
        backbone = resnet50()
        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1)
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3

        # Classification heads on the deepest stage, and fusion of their
        # probabilities with the six 300-d label-word embeddings.
        self.text_classifier = text_classifier(1024, [3, 3])
        self.label_fc = embedding_feature_fusion([3, 3, 300], 128)

        # Dilated-convolution skip branches: every stage is reduced to 128 channels.
        self.s1 = dilated_conv(64, 128)
        self.s2 = dilated_conv(256, 128)
        self.s3 = dilated_conv(512, 128)
        self.s4 = dilated_conv(1024, 128)

        # Progressive denoising attention, one per skip branch.
        self.pda1 = ProgressiveDenoisingAttention1(128, num_iterations=3)
        self.pda2 = ProgressiveDenoisingAttention1(128, num_iterations=3)
        self.pda3 = ProgressiveDenoisingAttention1(128, num_iterations=3)
        self.pda4 = ProgressiveDenoisingAttention1(128, num_iterations=3)

        # Encoder/decoder denoisers, one per decoder stage.
        self.denoise1 = UNetDenoise(in_channels=128, out_channels=128)
        self.denoise2 = UNetDenoise(in_channels=128, out_channels=128)
        self.denoise3 = UNetDenoise(in_channels=128, out_channels=128)

        # Decoder: each stage is a decoder block followed by label attention.
        self.d1 = decoder_block(128, 128, scale=2)
        self.a1 = label_attention([128, 128])

        self.d2 = decoder_block(128, 128, scale=2)
        self.a2 = label_attention([128, 128])

        self.d3 = decoder_block(128, 128, scale=2)
        self.a3 = label_attention([128, 128])

        # Aggregation of the three denoised decoder outputs and final projection.
        self.ag = multiscale_feature_aggregation([128, 128, 128], 128)

        self.y1 = output_block(128, 1)

    def forward(self, image, label):
        """Segment ``image`` under guidance of the ``label`` word embeddings.

        Returns:
            tuple: ``(y1, num_lesions, lesion_sizes)``.
        """
        # Backbone features. Per the original annotations the stages run at
        # 1/2, 1/4, 1/8 and 1/16 of the input resolution with 64, 256, 512 and
        # 1024 channels (this depends on the project's resnet50).
        x1 = self.layer0(image)
        x2 = self.layer1(x1)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)

        # Class predictions and the text feature derived from them.
        num_lesions, lesion_sizes = self.text_classifier(x4)
        text_feat = self.label_fc(num_lesions, lesion_sizes, label)

        # Dilated conv + PDA on every backbone stage.
        skip1 = self.pda1(self.s1(x1))
        skip2 = self.pda2(self.s2(x2))
        skip3 = self.pda3(self.s3(x3))
        skip4 = self.pda4(self.s4(x4))

        # Stage 1: decode, gate with the text feature, then denoise.
        attn1, gated1 = self.a1(self.d1(skip4, skip3), text_feat)
        denoised1 = self.denoise1(gated1)

        # Stage 2: the text feature accumulates the previous attention logits.
        text_feat = text_feat + attn1
        attn2, gated2 = self.a2(self.d2(denoised1, skip2), text_feat)
        denoised2 = self.denoise2(gated2)

        # Stage 3: the returned attention logits are not used any further.
        text_feat = text_feat + attn2
        _, gated3 = self.a3(self.d3(denoised2, skip1), text_feat)
        denoised3 = self.denoise3(gated3)

        aggregated = self.ag(denoised1, denoised2, denoised3)
        return self.y1(aggregated), num_lesions, lesion_sizes

def prepare_input(res):
    """Build a pair of dummy CUDA tensors wrapped in a keyword dictionary.

    Args:
        res: Accepted for the caller's convenience; it is not read here.

    Returns:
        dict: ``{"x": [image, label]}`` where ``image`` has shape (1, 3, 256, 256)
        and ``label`` has shape (1, 5, 300). Both come from the legacy
        ``torch.FloatTensor(*sizes)`` constructor, so their contents are not
        initialised to any particular value.
    """
    # NOTE(review): TGAlesionSeg.forward takes (image, label) and its fusion layer
    # rejects labels whose word count differs from six, whereas this helper packs
    # a five-word label under a single "x" key -- confirm whether it is still used.
    shapes = ((1, 3, 256, 256), (1, 5, 300))
    dummies = [torch.FloatTensor(*shape).cuda() for shape in shapes]
    return {"x": dummies}

### --------------------------------------------------------------------------------------------------------------------------------------------------------

def load_names(path, file_path):
    """Read sample names from a text file and build image and mask file paths.

    Args:
        path: Dataset root that contains the ``images`` and ``masks`` folders.
        file_path: Text file with one sample name (no extension) per line.

    Returns:
        tuple[list[str], list[str]]: ``.png`` image paths and ``.png`` mask
        paths, both in the order the names appear in the file.
    """
    # The context manager closes the handle; the previous version leaked it.
    with open(file_path, "r") as f:
        # Blank lines are skipped, so a missing trailing newline no longer drops
        # the last name and stray empty lines no longer yield a bare ".png" path.
        names = [line.strip() for line in f if line.strip()]

    images = [os.path.join(path, "images", name) + ".png" for name in names]
    masks = [os.path.join(path, "masks", name) + ".png" for name in names]
    return images, masks

def label_dictionary():
    """Return the fixed label vocabulary, keyed by ``"lesion"``.

    The list holds three count words followed by three size words.
    """
    return {"lesion": ["zero", "one", "multiple", "small", "medium", "large"]}

def load_data(path):
    """Load the train/validation file lists and attach the label words to each sample.

    Args:
        path: Dataset root; ``train.txt`` and ``val.txt`` are read from it.

    Returns:
        tuple: ``(train_x, train_y, train_label), (valid_x, valid_y, valid_label)``
        where the ``*_label`` lists repeat the same word list once per sample.
    """
    train_x, train_y = load_names(path, f"{path}/train.txt")
    valid_x, valid_y = load_names(path, f"{path}/val.txt")

    label_dict = label_dictionary()
    print(label_dict)
    words = label_dict["lesion"]

    # Every sample shares the same word-list object.
    train_split = (train_x, train_y, [words] * len(train_x))
    valid_split = (valid_x, valid_y, [words] * len(valid_x))
    return train_split, valid_split

class DATASET(Dataset):
    """Dataset pairing images with masks, label-word embeddings and lesion classes.

    Each item is ``((image, label), (mask, num_lesions, lesion_sizes))``:

    * ``image``: resized to ``size``, channel-first, divided by 255.
    * ``label``: one embedding per word of ``labels_path[index]``, stacked.
    * ``mask``: resized to ``size``, with a leading channel axis, divided by 255.
    * ``num_lesions`` / ``lesion_sizes``: class indices derived from the bounding
      boxes that ``mask_to_bbox`` finds in the resized mask.
    """

    def __init__(self, images_path, labels_path, masks_path, size, transform=None):
        """
        Args:
            images_path: Sequence of image file paths.
            labels_path: Sequence whose i-th entry is the list of label words of sample i.
            masks_path: Sequence of mask file paths aligned with ``images_path``.
            size: Target size handed to ``cv2.resize`` for both image and mask.
            transform: Optional callable invoked as ``transform(image=..., mask=...)``
                and returning a mapping with ``"image"`` and ``"mask"`` entries.
        """
        super().__init__()
        self.images_path = images_path
        self.labels_path = labels_path
        self.masks_path = masks_path
        self.size = size
        self.transform = transform
        self.n_samples = len(images_path)

        self.embed = Text2Embed()

    def visualize_bbox_on_mask(self, image, mask, bboxes, title="Mask with BBoxes"):
        """Show the image next to the mask with the bounding boxes drawn on it.

        Args:
            image: Three-channel image array (converted BGR -> RGB for display).
            mask: Single-channel mask array.
            bboxes: Iterable of ``(x1, y1, x2, y2)`` boxes.
            title: Title of the mask panel.

        Raises:
            ValueError: If ``image`` is not a 3-D array with three channels.
        """
        if len(image.shape) != 3 or image.shape[2] != 3:
            raise ValueError(f"Expected a color image with 3 channels, but got shape {image.shape}")

        fig, ax = plt.subplots(1, 2, figsize=(12, 6))

        # Left panel: original image.
        ax[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        ax[0].set_title("Original Image")
        ax[0].axis("off")

        # Right panel: mask with the boxes drawn on top.
        mask_with_bbox = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
        for bbox in bboxes:
            x1, y1, x2, y2 = bbox
            mask_with_bbox = cv2.rectangle(mask_with_bbox, (x1, y1), (x2, y2), (255, 0, 0), 2)

        ax[1].imshow(cv2.cvtColor(mask_with_bbox, cv2.COLOR_BGR2RGB))
        ax[1].set_title(title)
        ax[1].axis("off")

        plt.tight_layout()
        plt.show()

    def mask_to_text(self, mask, image=None, visualize=False):
        """Derive the lesion-count and lesion-size class indices from a mask.

        Args:
            mask: 2-D mask array handed to ``mask_to_bbox``.
            image: Optional three-channel image, used only for visualisation.
            visualize: When true and ``image`` is given, plot the boxes.

        Returns:
            tuple[np.ndarray, np.ndarray]: 0-d arrays ``(num_lesions, lesion_sizes)``.
            ``num_lesions`` is 0, 1 or 2 (two or more boxes). ``lesion_sizes`` is
            0 (< 0.007 of the mask area), 1 (< 0.04) or 2 (>= 0.04).
        """
        bboxes = mask_to_bbox(mask)
        mask_area = mask.shape[0] * mask.shape[1]

        # NOTE(review): the size class of the *last* box wins, as in the original
        # loop -- confirm this is intended when several boxes are present.
        lesion_sizes = 0
        for x1, y1, x2, y2 in bboxes:
            area = ((y2 - y1) * (x2 - x1)) / mask_area
            if area < 0.007:
                lesion_sizes = 0
            elif area < 0.04:
                lesion_sizes = 1
            else:
                lesion_sizes = 2

        # 0 boxes -> 0, 1 box -> 1, two or more -> 2.
        num_lesions = min(len(bboxes), 2)

        # The previous guard was inverted (`if image is None`), so calling this
        # method without an image crashed inside the visualiser. Plotting is now
        # opt-in and only attempted when an image is actually available.
        if visualize and image is not None:
            self.visualize_bbox_on_mask(image, mask, bboxes, title="Mask with Bounding Boxes")

        return np.array(num_lesions), np.array(lesion_sizes)

    def __getitem__(self, index):
        """Load, augment and normalise sample ``index``.

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
            ValueError: If the resized image is not a three-channel array.
        """
        image = cv2.imread(self.images_path[index], cv2.IMREAD_COLOR)
        mask = cv2.imread(self.masks_path[index], cv2.IMREAD_GRAYSCALE)

        if image is None or mask is None:
            raise FileNotFoundError(f"Could not read image or mask at index {index}. Check the file paths.")

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        image = cv2.resize(image, self.size)
        if len(image.shape) != 3 or image.shape[2] != 3:
            raise ValueError(f"Expected a color image with 3 channels after resizing, but got shape {image.shape}")

        image = np.transpose(image, (2, 0, 1))
        image = image / 255.0

        # NOTE(review): the mask is resized with cv2.resize's default
        # interpolation; confirm soft edge values are acceptable for the targets.
        mask_resized = cv2.resize(mask, self.size)
        mask = np.expand_dims(mask_resized, axis=0) / 255.0

        # Class targets come from the un-normalised resized mask. No image is
        # passed because visualisation is never wanted while loading batches.
        num_lesions, lesion_sizes = self.mask_to_text(mask_resized)

        # One embedding per label word.
        label = np.array([self.embed.to_embed(word)[0] for word in self.labels_path[index]])

        return (image, label), (mask, num_lesions, lesion_sizes)

    def __len__(self):
        """Number of samples (length of ``images_path`` at construction)."""
        return self.n_samples


def train(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``loader`` and report averaged results.

    Args:
        model: called as ``model(x, l)`` and expected to return three outputs.
        loader: iterable of ``((x, l), (y1, y2, y3))`` batches; must support ``len``.
        optimizer: optimiser stepped once per batch.
        loss_fn: indexable with three callables, applied to the three
            (prediction, target) pairs in order; their sum is back-propagated.
        device: device every batch tensor is moved to.

    Returns:
        ``(epoch_loss, [jaccard, f1, recall, precision])`` -- each value is the
        mean over batches of the per-batch mean.
    """
    model.train()
    running_loss = 0.0
    # Running sums of per-batch means, in the order Jaccard, F1, recall, precision.
    metric_sums = [0.0, 0.0, 0.0, 0.0]

    for (images, text), (masks, count_target, size_target) in loader:
        images = images.to(device, dtype=torch.float32)
        text = text.to(device, dtype=torch.float32)
        masks = masks.to(device, dtype=torch.float32)
        count_target = count_target.to(device, dtype=torch.long)
        size_target = size_target.to(device, dtype=torch.long)

        optimizer.zero_grad()

        mask_pred, count_pred, size_pred = model(images, text)
        # NOTE(review): both auxiliary heads are soft-maxed before reaching the
        # loss; if loss_fn[1]/loss_fn[2] normalise internally this is applied
        # twice -- confirm against their definitions.
        count_pred = torch.softmax(count_pred, dim=1)
        size_pred = torch.softmax(size_pred, dim=1)

        loss = (
            loss_fn[0](mask_pred, masks)
            + loss_fn[1](count_pred, count_target)
            + loss_fn[2](size_pred, size_target)
        )

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Score every sample of the batch, then fold the batch means into the sums.
        sample_scores = [calculate_metrics(truth, pred) for truth, pred in zip(masks, mask_pred)]
        for k in range(4):
            metric_sums[k] += np.mean([score[k] for score in sample_scores])

    n_batches = len(loader)
    epoch_loss = running_loss / n_batches
    return epoch_loss, [total / n_batches for total in metric_sums]

def evaluate(model, loader, loss_fn, device):
    """Score ``model`` on ``loader`` without updating its weights.

    Args:
        model: called as ``model(x, l)`` and expected to return three outputs.
        loader: iterable of ``((x, l), (y1, y2, y3))`` batches; must support ``len``.
        loss_fn: indexable with three callables, applied to the three
            (prediction, target) pairs in order and summed.
        device: device every batch tensor is moved to.

    Returns:
        ``(epoch_loss, [jaccard, f1, recall, precision])`` -- each value is the
        mean over batches of the per-batch mean.
    """
    model.eval()

    running_loss = 0.0
    # Running sums of per-batch means, in the order Jaccard, F1, recall, precision.
    metric_sums = [0.0, 0.0, 0.0, 0.0]

    with torch.no_grad():
        for (images, text), (masks, count_target, size_target) in loader:
            images = images.to(device, dtype=torch.float32)
            text = text.to(device, dtype=torch.float32)
            masks = masks.to(device, dtype=torch.float32)
            count_target = count_target.to(device, dtype=torch.long)
            size_target = size_target.to(device, dtype=torch.long)

            mask_pred, count_pred, size_pred = model(images, text)
            # NOTE(review): both auxiliary heads are soft-maxed before reaching
            # the loss; if loss_fn[1]/loss_fn[2] normalise internally this is
            # applied twice -- confirm against their definitions.
            count_pred = torch.softmax(count_pred, dim=1)
            size_pred = torch.softmax(size_pred, dim=1)

            loss = (
                loss_fn[0](mask_pred, masks)
                + loss_fn[1](count_pred, count_target)
                + loss_fn[2](size_pred, size_target)
            )
            running_loss += loss.item()

            # Score every sample of the batch, then fold the batch means into the sums.
            sample_scores = [calculate_metrics(truth, pred) for truth, pred in zip(masks, mask_pred)]
            for k in range(4):
                metric_sums[k] += np.mean([score[k] for score in sample_scores])

    n_batches = len(loader)
    epoch_loss = running_loss / n_batches
    return epoch_loss, [total / n_batches for total in metric_sums]

if __name__ == "__main__":
    # Script entry point: seed, set up logging, build the data pipeline and the
    # model, optionally warm-start from a checkpoint, then train with
    # checkpointing on best validation F1 and patience-based early stopping.

    # ---- Seeding ----
    seeding(42)

    # ---- Directories ----
    create_dir("files")

    # ---- Training log file ----
    train_log_path = "files/train_log.txt"
    if os.path.exists(train_log_path):
        print("Log file exists")
    else:
        # Context manager guarantees the handle is closed even if the write fails.
        with open(train_log_path, "w") as train_log:
            train_log.write("\n")

    # ---- Record date & time ----
    datetime_object = str(datetime.datetime.now())
    print_and_save(train_log_path, datetime_object)
    print("")

    # ---- Hyperparameters ----
    image_size = 256
    size = (image_size, image_size)
    batch_size = 16
    num_epochs = 70
    lr = 1e-4
    early_stopping_patience = 70
    checkpoint_path = "/content/checkpoint1 (7).pth"
    path = "/content/Pneumothorax"

    data_str = f"Image Size: {size}\nBatch Size: {batch_size}\nLR: {lr}\nEpochs: {num_epochs}\n"
    data_str += f"Early Stopping Patience: {early_stopping_patience}\n"
    print_and_save(train_log_path, data_str)

    # ---- Data augmentation (training set only) ----
    transform = A.Compose([
        A.Rotate(limit=15, p=0.5),  # Moderate rotations
        A.HorizontalFlip(p=0.5),    # Horizontal flipping
        A.VerticalFlip(p=0.2),      # Vertical flipping
        A.RandomBrightnessContrast(p=0.3),  # Adjust brightness and contrast
        A.CLAHE(p=0.3),  # Contrast Limited Adaptive Histogram Equalization
        A.ElasticTransform(alpha=1, sigma=50, p=0.3),  # Elastic transformations
        A.CoarseDropout(num_holes_range=(2, 5), hole_height_range=(5, 20), hole_width_range=(5, 20), fill="inpaint_ns", p=0.5)  # Coarse dropout
    ])

    # ---- Dataset ----
    (train_x, train_y, train_label), (valid_x, valid_y, valid_label) = load_data(path)
    train_x, train_y, train_label = shuffling(train_x, train_y, train_label)
    data_str = f"Dataset Size:\nTrain: {len(train_x)} - Valid: {len(valid_x)}\n"
    print_and_save(train_log_path, data_str)

    # ---- Dataset objects and loaders ----
    train_dataset = DATASET(train_x, train_label, train_y, size, transform=transform)
    valid_dataset = DATASET(valid_x, valid_label, valid_y, size, transform=None)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4
    )

    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )

    # ---- Model ----
    # Fall back to CPU instead of crashing when no GPU runtime is attached.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TGAlesionSeg()
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
    # loss_fn = [DiceBCELoss(), nn.CrossEntropyLoss(), nn.CrossEntropyLoss()]
    # Using Focal Loss for num_lesions & lesion_sizes
    loss_fn = [DiceBCELoss(), FocalLoss(alpha=0.25, gamma=2), FocalLoss(alpha=0.25, gamma=2)]

    loss_name = "BCE Dice Loss"
    data_str = f"Optimizer: AdamW\nLoss: {loss_name}\n"
    print_and_save(train_log_path, data_str)

    # ---- Optional warm start ----
    # Loading stays best-effort (a missing/incompatible file must not abort the
    # run), but the outcome is recorded so the log and the best-F1 baseline
    # below reflect what actually happened.
    checkpoint_loaded = False
    try:
        # map_location keeps a GPU-saved checkpoint loadable on the current
        # device; weights_only restricts unpickling to tensors/state dicts.
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        checkpoint_loaded = True
        print(f"Checkpoint loaded successfully from {checkpoint_path}. Resuming training.")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")

    # ---- Training the model ----
    # 0.8155 is the best validation F1 of the checkpoint being resumed (update
    # it when the checkpoint changes). Without a loaded checkpoint there is no
    # prior best, so start from 0 and let the first improvement be saved.
    best_valid_metrics = 0.8155 if checkpoint_loaded else 0.0
    early_stopping_count = 0

    # Logging
    if checkpoint_loaded:
        data_str = f"Resuming training from checkpoint.\nOptimizer: AdamW\nLoss: {loss_name}\n"
    else:
        data_str = f"No checkpoint loaded; training from current weights.\nOptimizer: AdamW\nLoss: {loss_name}\n"
    print_and_save(train_log_path, data_str)

    for epoch in range(num_epochs):
        start_time = time.time()

        train_loss, train_metrics = train(model, train_loader, optimizer, loss_fn, device)
        valid_loss, valid_metrics = evaluate(model, valid_loader, loss_fn, device)
        # The LR scheduler follows validation loss; checkpointing follows validation F1.
        scheduler.step(valid_loss)

        if valid_metrics[1] > best_valid_metrics:
            data_str = f"Valid F1 improved from {best_valid_metrics:2.4f} to {valid_metrics[1]:2.4f}. Saving checkpoint: {checkpoint_path}"
            print_and_save(train_log_path, data_str)

            best_valid_metrics = valid_metrics[1]
            torch.save(model.state_dict(), checkpoint_path)
            early_stopping_count = 0

        elif valid_metrics[1] < best_valid_metrics:
            # An exact tie neither resets nor advances the patience counter.
            early_stopping_count += 1

        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        data_str = f"Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n"
        data_str += f"\tTrain Loss: {train_loss:.4f} - Jaccard: {train_metrics[0]:.4f} - F1: {train_metrics[1]:.4f} - Recall: {train_metrics[2]:.4f} - Precision: {train_metrics[3]:.4f}\n"
        data_str += f"\t Val. Loss: {valid_loss:.4f} - Jaccard: {valid_metrics[0]:.4f} - F1: {valid_metrics[1]:.4f} - Recall: {valid_metrics[2]:.4f} - Precision: {valid_metrics[3]:.4f}\n"
        print_and_save(train_log_path, data_str)

        if early_stopping_count >= early_stopping_patience:
            # The stopping criterion is validation F1, so the message says so.
            data_str = f"Early stopping: validation F1 stops improving from last {early_stopping_patience} continuously.\n"
            print_and_save(train_log_path, data_str)
            break



2025-02-16 09:55:12.272669

Image Size: (256, 256)
Batch Size: 16
LR: 0.0001
Epochs: 70
Early Stopping Patience: 70

{'polyp': ['zero', 'one', 'multiple', 'small', 'medium', 'large']}
Dataset Size:
Train: 9636 - Valid: 2409

downloading https://nlp.h-its.org/bpemb/en/en.wiki.bpe.vs100000.model


100%|██████████| 1987533/1987533 [00:00<00:00, 2180635.77B/s]


downloading https://nlp.h-its.org/bpemb/en/en.wiki.bpe.vs100000.d300.w2v.bin.tar.gz


100%|██████████| 112159933/112159933 [00:05<00:00, 19427886.86B/s]
Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 222MB/s]


Optimizer: AdamW
Loss: BCE Dice Loss

Checkpoint loaded successfully from /content/checkpoint1 (7).pth. Resuming training.
Resuming training from checkpoint.
Optimizer: Adam
Loss: BCE Dice Loss



  model.load_state_dict(torch.load(checkpoint_path))


Epoch: 01 | Epoch Time: 11m 57s
	Train Loss: 0.7596 - Jaccard: 0.7367 - F1: 0.7549 - Recall: 0.8552 - Precision: 0.8426
	 Val. Loss: 0.8292 - Jaccard: 0.6834 - F1: 0.7029 - Recall: 0.8864 - Precision: 0.7540

Epoch: 02 | Epoch Time: 11m 58s
	Train Loss: 0.7639 - Jaccard: 0.7392 - F1: 0.7571 - Recall: 0.8532 - Precision: 0.8499
	 Val. Loss: 0.7556 - Jaccard: 0.7696 - F1: 0.7888 - Recall: 0.8697 - Precision: 0.8588

Epoch: 03 | Epoch Time: 11m 58s
	Train Loss: 0.7489 - Jaccard: 0.7445 - F1: 0.7619 - Recall: 0.8531 - Precision: 0.8552
	 Val. Loss: 0.7887 - Jaccard: 0.8000 - F1: 0.8138 - Recall: 0.8439 - Precision: 0.9316

Epoch: 04 | Epoch Time: 11m 59s
	Train Loss: 0.7352 - Jaccard: 0.7510 - F1: 0.7687 - Recall: 0.8542 - Precision: 0.8632
	 Val. Loss: 0.7296 - Jaccard: 0.7860 - F1: 0.8030 - Recall: 0.8540 - Precision: 0.9019

Epoch: 05 | Epoch Time: 11m 58s
	Train Loss: 0.7357 - Jaccard: 0.7361 - F1: 0.7560 - Recall: 0.8619 - Precision: 0.8350
	 Val. Loss: 0.7604 - Jaccard: 0.7993 - F1: 

KeyboardInterrupt: 

Epoch: 01 | Epoch Time: 11m 57s
	Train Loss: 0.7596 - Jaccard: 0.7367 - F1: 0.7549 - Recall: 0.8552 - Precision: 0.8426
	 Val. Loss: 0.8292 - Jaccard: 0.6834 - F1: 0.7029 - Recall: 0.8864 - Precision: 0.7540

Epoch: 02 | Epoch Time: 11m 58s
	Train Loss: 0.7639 - Jaccard: 0.7392 - F1: 0.7571 - Recall: 0.8532 - Precision: 0.8499
	 Val. Loss: 0.7556 - Jaccard: 0.7696 - F1: 0.7888 - Recall: 0.8697 - Precision: 0.8588

Epoch: 03 | Epoch Time: 11m 58s
	Train Loss: 0.7489 - Jaccard: 0.7445 - F1: 0.7619 - Recall: 0.8531 - Precision: 0.8552
	 Val. Loss: 0.7887 - Jaccard: 0.8000 - F1: 0.8138 - Recall: 0.8439 - Precision: 0.9316

Epoch: 04 | Epoch Time: 11m 59s
	Train Loss: 0.7352 - Jaccard: 0.7510 - F1: 0.7687 - Recall: 0.8542 - Precision: 0.8632
	 Val. Loss: 0.7296 - Jaccard: 0.7860 - F1: 0.8030 - Recall: 0.8540 - Precision: 0.9019

Epoch: 05 | Epoch Time: 11m 58s
	Train Loss: 0.7357 - Jaccard: 0.7361 - F1: 0.7560 - Recall: 0.8619 - Precision: 0.8350
	 Val. Loss: 0.7604 - Jaccard: 0.7993 - F1: 0.8150 - Recall: 0.8555 - Precision: 0.9164

Epoch: 06 | Epoch Time: 11m 58s
	Train Loss: 0.7427 - Jaccard: 0.7451 - F1: 0.7641 - Recall: 0.8600 - Precision: 0.8486
	 Val. Loss: 0.7418 - Jaccard: 0.7638 - F1: 0.7833 - Recall: 0.8784 - Precision: 0.8507

Epoch: 07 | Epoch Time: 11m 58s
	Train Loss: 0.7451 - Jaccard: 0.7387 - F1: 0.7574 - Recall: 0.8587 - Precision: 0.8419
	 Val. Loss: 0.7234 - Jaccard: 0.7682 - F1: 0.7876 - Recall: 0.8666 - Precision: 0.8617

Valid F1 improved from 0.8155 to 0.8168. Saving checkpoint: /content/checkpoint1 (7).pth
Epoch: 08 | Epoch Time: 11m 59s
	Train Loss: 0.7201 - Jaccard: 0.7560 - F1: 0.7751 - Recall: 0.8618 - Precision: 0.8573
	 Val. Loss: 0.7459 - Jaccard: 0.8020 - F1: 0.8168 - Recall: 0.8469 - Precision: 0.9322

Epoch: 09 | Epoch Time: 11m 58s
	Train Loss: 0.7056 - Jaccard: 0.7575 - F1: 0.7770 - Recall: 0.8628 - Precision: 0.8606
	 Val. Loss: 0.8386 - Jaccard: 0.6737 - F1: 0.6956 - Recall: 0.9047 - Precision: 0.7221

Epoch: 10 | Epoch Time: 11m 58s
	Train Loss: 0.7027 - Jaccard: 0.7562 - F1: 0.7758 - Recall: 0.8635 - Precision: 0.8582
	 Val. Loss: 0.7278 - Jaccard: 0.7347 - F1: 0.7545 - Recall: 0.8749 - Precision: 0.8208

Epoch: 11 | Epoch Time: 11m 58s
	Train Loss: 0.7231 - Jaccard: 0.7418 - F1: 0.7608 - Recall: 0.8583 - Precision: 0.8470
	 Val. Loss: 0.7282 - Jaccard: 0.7772 - F1: 0.7953 - Recall: 0.8649 - Precision: 0.8781

Epoch: 12 | Epoch Time: 11m 58s
	Train Loss: 0.6636 - Jaccard: 0.7668 - F1: 0.7879 - Recall: 0.8701 - Precision: 0.8594
	 Val. Loss: 0.6867 - Jaccard: 0.7840 - F1: 0.8035 - Recall: 0.8737 - Precision: 0.8745

Epoch: 13 | Epoch Time: 11m 58s
	Train Loss: 0.6449 - Jaccard: 0.7786 - F1: 0.7993 - Recall: 0.8695 - Precision: 0.8749
	 Val. Loss: 0.6815 - Jaccard: 0.7919 - F1: 0.8112 - Recall: 0.8720 - Precision: 0.8836

Valid F1 improved from 0.8168 to 0.8197. Saving checkpoint: /content/checkpoint1 (7).pth
Epoch: 14 | Epoch Time: 11m 59s
	Train Loss: 0.6567 - Jaccard: 0.7781 - F1: 0.7982 - Recall: 0.8669 - Precision: 0.8799
	 Val. Loss: 0.6793 - Jaccard: 0.8012 - F1: 0.8197 - Recall: 0.8677 - Precision: 0.9022

Epoch: 15 | Epoch Time: 11m 58s
	Train Loss: 0.6480 - Jaccard: 0.7818 - F1: 0.8025 - Recall: 0.8712 - Precision: 0.8778
	 Val. Loss: 0.6684 - Jaccard: 0.7851 - F1: 0.8056 - Recall: 0.8799 - Precision: 0.8659

Epoch: 16 | Epoch Time: 11m 58s
	Train Loss: 0.6392 - Jaccard: 0.7868 - F1: 0.8072 - Recall: 0.8693 - Precision: 0.8859
	 Val. Loss: 0.6711 - Jaccard: 0.7985 - F1: 0.8172 - Recall: 0.8692 - Precision: 0.8970

Epoch: 17 | Epoch Time: 11m 58s
	Train Loss: 0.6234 - Jaccard: 0.7890 - F1: 0.8103 - Recall: 0.8729 - Precision: 0.8849
	 Val. Loss: 0.6672 - Jaccard: 0.7937 - F1: 0.8131 - Recall: 0.8762 - Precision: 0.8820

Epoch: 18 | Epoch Time: 11m 58s
	Train Loss: 0.6276 - Jaccard: 0.7906 - F1: 0.8111 - Recall: 0.8698 - Precision: 0.8902
	 Val. Loss: 0.6651 - Jaccard: 0.7792 - F1: 0.8003 - Recall: 0.8851 - Precision: 0.8538

Valid F1 improved from 0.8197 to 0.8200. Saving checkpoint: /content/checkpoint1 (7).pth
Epoch: 19 | Epoch Time: 11m 59s
	Train Loss: 0.6273 - Jaccard: 0.7884 - F1: 0.8098 - Recall: 0.8758 - Precision: 0.8787
	 Val. Loss: 0.6686 - Jaccard: 0.8013 - F1: 0.8200 - Recall: 0.8680 - Precision: 0.8974

Epoch: 20 | Epoch Time: 11m 58s
	Train Loss: 0.6258 - Jaccard: 0.7885 - F1: 0.8096 - Recall: 0.8743 - Precision: 0.8821
	 Val. Loss: 0.6560 - Jaccard: 0.7842 - F1: 0.8052 - Recall: 0.8818 - Precision: 0.8617

Epoch: 21 | Epoch Time: 11m 58s
	Train Loss: 0.6222 - Jaccard: 0.7946 - F1: 0.8145 - Recall: 0.8703 - Precision: 0.8952
	 Val. Loss: 0.6625 - Jaccard: 0.7962 - F1: 0.8156 - Recall: 0.8722 - Precision: 0.8886

Epoch: 22 | Epoch Time: 11m 58s
	Train Loss: 0.6135 - Jaccard: 0.7986 - F1: 0.8196 - Recall: 0.8728 - Precision: 0.8960
	 Val. Loss: 0.6623 - Jaccard: 0.7816 - F1: 0.8024 - Recall: 0.8829 - Precision: 0.8582

Epoch: 23 | Epoch Time: 11m 58s
	Train Loss: 0.6107 - Jaccard: 0.7975 - F1: 0.8184 - Recall: 0.8742 - Precision: 0.8937
	 Val. Loss: 0.6608 - Jaccard: 0.7928 - F1: 0.8128 - Recall: 0.8787 - Precision: 0.8786

Epoch: 24 | Epoch Time: 11m 58s
	Train Loss: 0.6089 - Jaccard: 0.7957 - F1: 0.8171 - Recall: 0.8759 - Precision: 0.8870
	 Val. Loss: 0.6594 - Jaccard: 0.7821 - F1: 0.8028 - Recall: 0.8803 - Precision: 0.8610

Epoch: 25 | Epoch Time: 11m 58s
	Train Loss: 0.6073 - Jaccard: 0.7960 - F1: 0.8177 - Recall: 0.8765 - Precision: 0.8886
	 Val. Loss: 0.6585 - Jaccard: 0.7903 - F1: 0.8105 - Recall: 0.8765 - Precision: 0.8734

Epoch: 26 | Epoch Time: 11m 58s
	Train Loss: 0.6083 - Jaccard: 0.8015 - F1: 0.8227 - Recall: 0.8751 - Precision: 0.8957
	 Val. Loss: 0.6572 - Jaccard: 0.7961 - F1: 0.8158 - Recall: 0.8751 - Precision: 0.8831

Epoch: 27 | Epoch Time: 11m 58s
	Train Loss: 0.6245 - Jaccard: 0.7954 - F1: 0.8167 - Recall: 0.8750 - Precision: 0.8880
	 Val. Loss: 0.6542 - Jaccard: 0.7875 - F1: 0.8081 - Recall: 0.8803 - Precision: 0.8671

Epoch: 28 | Epoch Time: 11m 58s
	Train Loss: 0.6092 - Jaccard: 0.8009 - F1: 0.8221 - Recall: 0.8759 - Precision: 0.8945
	 Val. Loss: 0.6557 - Jaccard: 0.7968 - F1: 0.8162 - Recall: 0.8744 - Precision: 0.8862

Epoch: 29 | Epoch Time: 11m 58s
	Train Loss: 0.6092 - Jaccard: 0.7978 - F1: 0.8193 - Recall: 0.8769 - Precision: 0.8920
	 Val. Loss: 0.6532 - Jaccard: 0.7948 - F1: 0.8148 - Recall: 0.8775 - Precision: 0.8792

Epoch: 30 | Epoch Time: 11m 58s
	Train Loss: 0.6016 - Jaccard: 0.8010 - F1: 0.8223 - Recall: 0.8763 - Precision: 0.8933
	 Val. Loss: 0.6532 - Jaccard: 0.7975 - F1: 0.8171 - Recall: 0.8752 - Precision: 0.8847

Epoch: 31 | Epoch Time: 11m 58s
	Train Loss: 0.6045 - Jaccard: 0.8013 - F1: 0.8225 - Recall: 0.8754 - Precision: 0.8948
	 Val. Loss: 0.6552 - Jaccard: 0.7970 - F1: 0.8164 - Recall: 0.8740 - Precision: 0.8869

Epoch: 32 | Epoch Time: 11m 59s
	Train Loss: 0.6026 - Jaccard: 0.8025 - F1: 0.8233 - Recall: 0.8738 - Precision: 0.8984
	 Val. Loss: 0.6561 - Jaccard: 0.7950 - F1: 0.8147 - Recall: 0.8753 - Precision: 0.8828