# In [None]:
# SSL model: BYOL + masked image modeling (MIM) + contrastive learning
# loss = 1 * byol_loss + 0.6 * mim_loss + 0.6 * contrastive_loss

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from byol_pytorch import BYOL
from torch.utils.data import DataLoader
import random
import numpy as np
import kornia.augmentation as K
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets import DatasetFolder
from torchvision.datasets.folder import default_loader
import os
from torch.utils.data import Dataset
from PIL import Image
from tqdm import tqdm  # Import tqdm for progress bars

# Seed the NumPy, Python and torch RNGs with one fixed value so reruns draw the
# same random numbers. NOTE(review): deterministic cuDNN kernels are not
# requested here, so GPU runs may still differ slightly between executions.
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

# Prefer CUDA when it is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


class PlantDocDataset(Dataset):
    """Flat folder of unlabeled images returned as fixed-size RGB tensors.

    Only decoding, resizing and tensor conversion happen here; random
    augmentation is left to the GPU pipeline.

    Args:
        folder: directory scanned (non-recursively) for image files with a
            .png/.jpg/.jpeg/.bmp/.gif extension (case-insensitive).
        target_size: size passed to ``transforms.Resize`` for every image.
    """

    def __init__(self, folder, target_size=(384,384)):
        # We won't do heavy transforms here. We'll let Kornia handle it on GPU.
        valid_exts = {'.png', '.jpg', '.jpeg', '.bmp', '.gif'}

        # Sorted so the index -> file mapping (and therefore seeded shuffling)
        # is identical across runs; os.listdir order is arbitrary.
        self.image_paths = []
        for name in sorted(os.listdir(folder)):
            path = os.path.join(folder, name)
            # Skip directories that merely carry an image-like extension.
            if os.path.splitext(name)[1].lower() in valid_exts and os.path.isfile(path):
                self.image_paths.append(path)

        self.target_size = target_size
        print(f"Found {len(self.image_paths)} images in '{folder}'.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return image ``idx`` as a tensor of shape [C, H, W] on the CPU.

        Unreadable images are skipped by moving on to the next index (wrapping
        around at the end). Every image is tried at most once per call.

        Raises:
            IndexError: if ``idx`` is out of range (including an empty dataset).
            RuntimeError: if no image in the dataset can be loaded.
        """
        num_images = len(self.image_paths)
        path = self.image_paths[idx]  # out-of-range idx raises IndexError
        for _ in range(num_images):
            try:
                # `with` closes the underlying file; convert() returns a copy
                # that stays valid after the file is closed.
                with Image.open(path) as img:
                    image = img.convert('RGB')
                image = transforms.Resize(self.target_size)(image)  # Resize all images to fixed size
                return transforms.ToTensor()(image)  # shape [C,H,W] on CPU
            except Exception as e:
                # Deliberately broad: any decode failure just skips this file.
                print(f"Warning: Skipping corrupted image {path} - {e}")
                # Fallback: move to the next image in a circular manner
                idx = (idx + 1) % num_images
                path = self.image_paths[idx]
        raise RuntimeError(f"None of the {num_images} images in the dataset could be loaded.")

plant_dataset_folder = "../PlantDoc-Dataset/unlabel"  # set your path
dataset = PlantDocDataset(folder=plant_dataset_folder, target_size=(384,384))

batch_size = 16

# With drop_last=True, a folder holding fewer than `batch_size` images yields
# zero batches per epoch, and the per-epoch loss averages would then divide by
# zero. Fail early with a clear message instead.
if len(dataset) < batch_size:
    raise ValueError(
        f"Need at least {batch_size} images for one SSL batch, "
        f"found {len(dataset)} in '{plant_dataset_folder}'."
    )

train_loader_ssl = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,  # every batch has exactly batch_size images
    persistent_workers=True,  # keep worker processes alive across epochs
)

print("Data Loaders Ready for SSL!")


class GPUAugment(nn.Module):
    """Batched on-device augmentation followed by ImageNet normalization.

    Augmentations run in this order:
    1. random resized crop to ``image_size`` x ``image_size`` (scale 0.6-1.0);
    2. horizontal flip;
    3. rotation (``degrees=20.0``);
    4. solarization (p=0.2);
    5. grayscale (p=0.2);
    6. colour jitter (p=0.8);
    7. 3x3 Gaussian blur (p=0.3).

    Input is presumably a (B, C, H, W) float batch in [0, 1], as produced by
    the dataset's ``ToTensor``; confirm this for other callers.
    """

    def __init__(self, image_size=384):
        super().__init__()
        crop_hw = (image_size, image_size)
        pipeline = [
            K.RandomResizedCrop(crop_hw, scale=(0.6, 1.0)),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomRotation(degrees=20.0),
            # Solarization (not a plain darkening) applied to 20% of images.
            K.RandomSolarize(thresholds=0.5, p=0.2),
            # Dropping colour pushes the encoder toward shape cues.
            K.RandomGrayscale(p=0.2),
            K.ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
            K.RandomGaussianBlur((3, 3), sigma=(0.1, 2.0), p=0.3),
        ]
        self.augs = nn.Sequential(*pipeline)

        # Normalization is a separate stage so it always runs after augmenting.
        imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
        imagenet_std = torch.tensor([0.229, 0.224, 0.225])
        self.normalize = K.Normalize(mean=imagenet_mean, std=imagenet_std)

    def forward(self, x):
        return self.normalize(self.augs(x))

class MLPClassifier(nn.Module):
    """Two-layer MLP head: Linear(in_features, 512) -> ReLU -> Linear(512, num_classes).

    The layers live in ``self.classifier`` (an ``nn.Sequential``) so that
    state_dict keys stay ``classifier.0.*`` / ``classifier.2.*``.
    """

    def __init__(self, in_features, num_classes):
        super().__init__()
        hidden_width = 512
        layers = [
            nn.Linear(in_features, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.classifier(x)
        return logits

# Build the ResNet-101 backbone from ImageNet-1k pretrained weights; it is
# wrapped in BYOL a few lines below.
from torchvision.models import resnet101, ResNet101_Weights
resnet = resnet101(weights=ResNet101_Weights.IMAGENET1K_V1)

# Swap the stock classifier for the 2-layer MLP head (2048 -> 512 -> 27).
# The new head's weights are initialized from the torch RNG seeded at the top
# of the file, so keep these statements in this order for reproducibility.
# NOTE(review): BYOL is configured with hidden_layer='avgpool', so this head is
# presumably not trained by the SSL objective, yet it is included in every
# saved resnet.state_dict(); confirm 27 matches the downstream label count.
resnet.fc = MLPClassifier(in_features=2048, num_classes=27)


# Shared GPU augmentation pipeline (random augmentations + ImageNet
# normalization), moved to the training device once and reused for every view.
gpu_augment = GPUAugment(image_size=384).to(device)

def batch_augment_fn(batch):
    """Augment and normalize ``batch`` with the module-level ``gpu_augment``.

    ``batch`` is expected to be a (B, C, H, W) tensor already on the GPU; the
    augmented, normalized result is returned on the same device.
    """
    augmented = gpu_augment(batch)
    return augmented

# Wrap the backbone in BYOL. Both views use the same GPU augmentation function,
# and representations are read from the backbone's 'avgpool' layer.
# With use_momentum=True the target network is presumably an exponential moving
# average of the online network (decay 0.99). NOTE(review): byol_pytorch expects
# learner.update_moving_average() after each optimizer step to advance that
# average; confirm the training loop calls it.
learner = BYOL(
    net=resnet,
    image_size=384,  # NOTE(review): presumably only used by BYOL's internal dummy forward pass; inputs here are already 384x384
    hidden_layer='avgpool',
    augment_fn=batch_augment_fn,
    augment_fn2=batch_augment_fn,
    projection_hidden_size = 2048,
    moving_average_decay=0.99,
    use_momentum=True
).to(device)


##################################
# Define Loss Functions
##################################
# Masked Image Modeling (MIM)
import kornia.losses
class MIMLoss(nn.Module):
    """Masked-image loss mixing MSE and SSIM dissimilarity.

    Each call does the following:
    - draws one mask ratio uniformly from ``[min_mask, max_mask]`` using
      Python's ``random``;
    - masks every pixel independently with that probability, using the same
      mask for all channels of a pixel;
    - zeroes the masked pixels of ``x`` and compares the result with ``target``.

    The returned loss is ``alpha * MSE + (1 - alpha) * DSSIM``:
    - MSE is measured on ImageNet-normalized tensors.
    - DSSIM is kornia's ``SSIMLoss``, evaluated on the un-normalized inputs.
      ``SSIMLoss`` uses ``max_val=1.0`` by default, so the inputs are
      presumably in [0, 1] (the dataset's ``ToTensor`` output); confirm for
      other callers.

    NOTE(review): this module has no learnable parameters and ``x_masked`` is
    never passed through a model. When called as ``mim_loss_fn(images, images)``
    the loss therefore has no gradient path to the network. A real MIM objective
    needs a decoder that reconstructs ``x_masked`` before the comparison.

    Args:
        min_mask: lower bound of the per-call mask ratio.
        max_mask: upper bound of the per-call mask ratio.
        alpha: weight of the MSE term; ``1 - alpha`` weights the SSIM term.
    """

    def __init__(self, min_mask=0.3, max_mask=0.6, alpha=0.5):
        super().__init__()
        self.min_mask = min_mask
        self.max_mask = max_mask
        self.alpha = alpha
        self.mse_loss = nn.MSELoss()
        # kornia's SSIMLoss already returns the dissimilarity (1 - SSIM) / 2,
        # i.e. 0 for identical images; it must NOT be subtracted from 1 again.
        self.ssim_loss = kornia.losses.SSIMLoss(window_size=11)
        # Built once here rather than on every forward call.
        self.normalize = K.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def forward(self, x, target):
        """Return the scalar masked-image loss between ``x`` and ``target``.

        ``x`` is a 4-D (B, C, H, W) tensor and is not modified.
        """
        batch_size, _, height, width = x.shape
        mask_ratio = random.uniform(self.min_mask, self.max_mask)
        # One draw per pixel; the singleton channel dim broadcasts over C.
        mask = torch.rand(batch_size, 1, height, width, device=x.device) < mask_ratio
        x_masked = x.masked_fill(mask, 0.0)

        # Normalize after masking
        mse = self.mse_loss(self.normalize(x_masked), self.normalize(target))
        # SSIM in pixel space, where kornia's default max_val=1.0 applies.
        dssim = self.ssim_loss(x_masked, target)

        return self.alpha * mse + (1 - self.alpha) * dssim


# Contrastive Learning Loss (InfoNCE)
class ContrastiveLoss(nn.Module):
    """One-directional InfoNCE between two batches of embeddings.

    Embeddings are L2-normalized per row. Row ``k`` of ``z_i`` is the anchor,
    row ``k`` of ``z_j`` is its positive, and all other rows of ``z_j`` act as
    negatives. Cosine similarities are divided by ``temperature`` before the
    cross-entropy.
    """

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        anchors = F.normalize(z_i, dim=1)
        candidates = F.normalize(z_j, dim=1)
        similarity = anchors @ candidates.t()
        # The matching index is the positive for each anchor row.
        targets = torch.arange(anchors.size(0), device=anchors.device)
        return F.cross_entropy(similarity / self.temperature, targets)

# Auxiliary loss modules, moved to the training device.
mim_loss_fn = MIMLoss(min_mask=0.3, max_mask=0.6).to(device)
contrastive_loss_fn = ContrastiveLoss(temperature=0.1).to(device)

# Learning-rate bounds for the linear warmup (min_lr -> max_lr). The optimizer
# itself is created once, further down, with base lr=max_lr, and the warmup
# scheduler scales that base LR. A separate low-LR AdamW built here would be
# overwritten before use.
max_lr = 1e-4
min_lr = 1e-6

def warmup_lr_scheduler(optimizer, warmup_epochs=10, min_lr=1e-6, max_lr=1e-4):
    """Build a ``LambdaLR`` that warms the learning rate up linearly.

    The LR multiplier starts at ``min_lr / max_lr`` at step 0 and grows
    linearly to 1.0 at step ``warmup_epochs``. It stays at 1.0 afterwards, so
    you can switch to a different scheduler once warmup is done. The
    multiplier scales the optimizer's base LR, so the LR only spans
    ``[min_lr, max_lr]`` when the optimizer was created with ``lr=max_lr``.

    Args:
        optimizer: optimizer whose parameter groups are scheduled.
        warmup_epochs: number of scheduler steps the ramp lasts. Values <= 0
            disable warmup (constant multiplier 1.0).
        min_lr: LR at step 0, given a base LR equal to ``max_lr``.
        max_lr: LR reached at the end of warmup; must be positive.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` bound to ``optimizer``.

    Raises:
        ValueError: if ``max_lr`` is not positive.
    """
    if max_lr <= 0:
        raise ValueError(f"max_lr must be positive, got {max_lr}")
    start_factor = min_lr / max_lr

    def lr_lambda(epoch):
        if warmup_epochs <= 0:
            return 1.0  # no warmup requested: run at the base LR throughout
        # 0 at epoch 0, 1.0 from epoch == warmup_epochs onward.
        progress = min(epoch / warmup_epochs, 1.0)
        return start_factor + (1.0 - start_factor) * progress

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


# AdamW over every parameter registered on the BYOL learner. The base LR is
# max_lr; the warmup scheduler scales it down during the first epochs.
optimizer = torch.optim.AdamW(learner.parameters(), lr=max_lr, weight_decay=1e-5)

# Linear warmup from min_lr to max_lr over the first `warmup_epochs` epochs.
warmup_epochs = 10
warmup_scheduler = warmup_lr_scheduler(
    optimizer,
    warmup_epochs=warmup_epochs,
    min_lr=min_lr,
    max_lr=max_lr,
)

# Gradient Scaler for Mixed Precision.
# NOTE(review): the autocast region below requests float32, so mixed precision
# is presumably not active.
scaler = torch.cuda.amp.GradScaler()

# Gradients are accumulated over this many batches per optimizer step.
accum_steps = 2

# Fixed-name checkpoint read on startup. It is rewritten every time a periodic
# checkpoint is saved below, so an interrupted run can actually resume.
checkpoint_path = "checkpoint.pth"

# Resume training if a checkpoint exists
start_epoch = 0
best_loss = float('inf')
if os.path.exists(checkpoint_path):
    print("Resuming training from checkpoint:", checkpoint_path)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    learner.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    # Restore the warmup position; otherwise a run resumed mid-warmup would
    # restart the LR ramp from its first step. Older checkpoints lack this key.
    if 'scheduler_state_dict' in checkpoint:
        warmup_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_loss = checkpoint['best_loss']
    avg_loss = checkpoint['avg_loss']
    print(f"Resumed at epoch {start_epoch} with avg_loss {avg_loss:.4f}")

# Training Loop
epochs = 90  # Train

num_batches = len(train_loader_ssl)
if num_batches == 0:
    # Otherwise the per-epoch averages below would divide by zero.
    raise RuntimeError(
        "train_loader_ssl yields no batches; the dataset needs at least batch_size images."
    )

for epoch in range(start_epoch, epochs):
    learner.train()
    total_loss = 0.0
    total_mim_loss = 0.0
    total_contrastive_loss = 0.0

    optimizer.zero_grad()

    # Wrap the training data loader with tqdm for a progress bar
    pbar = tqdm(enumerate(train_loader_ssl), total=num_batches, desc=f"Epoch [{epoch+1}/{epochs}]")
    for batch_idx, images in pbar:
        images = images.to(device)

        with torch.cuda.amp.autocast(dtype=torch.float32):  # Force ResNet to use float32
            # Forward Pass (BYOL)
            byol_loss = learner(images)
            # Forward Pass (MIM). NOTE(review): both arguments are the raw
            # images and no model output is involved, so this term presumably
            # contributes no gradient to the network.
            mim_loss = mim_loss_fn(images, images)

            # Two fresh augmented views, projected by the online encoder, for
            # the InfoNCE term.
            augmented_one = batch_augment_fn(images)
            augmented_two = batch_augment_fn(images)
            z_i, _ = learner.online_encoder(augmented_one, return_projection=True)
            z_j, _ = learner.online_encoder(augmented_two, return_projection=True)

            contrastive_loss = contrastive_loss_fn(z_i, z_j)
            # Total Loss (Weighted Sum)
            loss = 1 * byol_loss + 0.6 * mim_loss + 0.6 * contrastive_loss

        # Scale by 1/accum_steps so the accumulated gradient is an average over
        # the window instead of a sum.
        scaler.scale(loss / accum_steps).backward()
        # Also step on the epoch's final batch, so an incomplete window is
        # applied rather than discarded by the next epoch's zero_grad().
        if (batch_idx + 1) % accum_steps == 0 or batch_idx + 1 == num_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            # byol_pytorch does not advance the momentum (target) encoder by
            # itself; it must be updated after every optimizer step.
            learner.update_moving_average()

        total_loss += loss.item()
        total_mim_loss += mim_loss.item()
        total_contrastive_loss += contrastive_loss.item()

        # Update tqdm progress bar with current loss values
        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'BYOL': f"{byol_loss.item():.4f}",
            'MIM': f"{mim_loss.item():.4f}",
            'Contrastive': f"{contrastive_loss.item():.4f}"
        })

    avg_loss = total_loss / num_batches
    avg_mim_loss = total_mim_loss / num_batches
    avg_contrastive_loss = total_contrastive_loss / num_batches

    print(f"Epoch [{epoch+1}/{epochs}] - Total Loss: {avg_loss:.4f} | MIM: {avg_mim_loss:.4f} | Contrastive: {avg_contrastive_loss:.4f}")

    # Advance the linear warmup once per epoch while it is still ramping.
    if epoch < warmup_epochs:
        warmup_scheduler.step()
        print(f"Warmup Scheduler Updated LR: {optimizer.param_groups[0]['lr']:.6f}")

    # Save the best model if current loss is lower
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(resnet.state_dict(), "byol_mim_contrastive_best.pth")
        print(f"Saved Best Model at epoch {epoch+1}!")

    # Save additional .pth and checkpoints every 5 epochs
    if (epoch + 1) % 5 == 0:
        torch.save(resnet.state_dict(), f"byol_mim_contrastive_epoch{epoch+1}.pth")
        checkpoint_name = f"checkpoint_epoch{epoch+1}.pth"
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': learner.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'scheduler_state_dict': warmup_scheduler.state_dict(),
            'best_loss': best_loss,
            'avg_loss': avg_loss,
        }
        torch.save(checkpoint, checkpoint_name)
        # The resume logic above reads only `checkpoint_path`; keep it current
        # so the per-epoch checkpoints are actually picked up on restart.
        torch.save(checkpoint, checkpoint_path)

print("Self-Supervised Pretraining Complete!")
