<a href="https://colab.research.google.com/github/xByEMPE/BYOL_MODEL_VND/blob/main/vnd_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet18
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import Callback
from PIL import Image
import os
import matplotlib.pyplot as plt
from google.colab import drive

# Montar Google Drive
drive.mount('/content/drive', force_remount=True)

# ======================
# Dataset for SSL
# ======================
class CustomDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.image_paths = []

        # Recorre todas las subcarpetas en image_dir
        for root, _, files in os.walk(image_dir):
            for file in files:
                if file.endswith(('png', 'jpg', 'jpeg')):
                    self.image_paths.append(os.path.join(root, file))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

# ======================
# Transformations
# ======================
transform_ssl = T.Compose([
    T.Resize((224, 224)),  # Redimensiona todas las imágenes a 224x224
    T.RandomHorizontalFlip(),
    T.RandomResizedCrop(224, scale=(0.8, 1.0)),
    T.ToTensor(),  # Convierte la imagen a formato tensor (C, H, W)
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normaliza los valores RGB
])

# Rutas a las carpetas en Google Drive
train_image_dir = "/content/drive/My Drive/ssl_train/labels"  # Carpeta con subcarpetas de categorías
val_image_dir = "/content/drive/My Drive/ssl_train/no_labels"  # Carpeta de validación

# Datasets
train_dataset = CustomDataset(image_dir=train_image_dir, transform=transform_ssl)
val_dataset = CustomDataset(image_dir=val_image_dir, transform=transform_ssl)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# ======================
# BYOL Model
# ======================
class BYOL(pl.LightningModule):
    def __init__(self, lr=1e-3):
        super(BYOL, self).__init__()
        self.lr = lr

        # Online network
        self.online_network = resnet18(weights=None)
        self.online_network.fc = nn.Identity()  # Desactiva la última capa
        self.online_pool = nn.AdaptiveAvgPool2d((1, 1))  # Pooling global

        self.online_projector = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

        # Target network
        self.target_network = resnet18(weights=None)
        self.target_network.fc = nn.Identity()
        self.target_pool = nn.AdaptiveAvgPool2d((1, 1))  # Pooling global

        self.target_projector = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

        # Freeze target network
        for param in self.target_network.parameters():
            param.requires_grad = False

    def forward(self, x):
        features = self.online_network(x)  # Output de ResNet (batch_size, 512, H, W)
        pooled_features = self.online_pool(features).view(features.size(0), -1)  # (batch_size, 512)
        projections = self.online_projector(pooled_features)  # (batch_size, 128)
        return projections

    def training_step(self, batch, batch_idx):
        images = batch
        augmented_images_1 = images
        augmented_images_2 = images

        # Forward pass
        online_proj_1 = self.forward(augmented_images_1)
        online_proj_2 = self.forward(augmented_images_2)

        with torch.no_grad():
            target_proj_1 = self.target_projector(self.target_pool(self.target_network(augmented_images_1)).view(augmented_images_1.size(0), -1))
            target_proj_2 = self.target_projector(self.target_pool(self.target_network(augmented_images_2)).view(augmented_images_2.size(0), -1))

        # Compute loss (negative cosine similarity)
        loss = -torch.mean(
            nn.functional.cosine_similarity(online_proj_1, target_proj_2.detach(), dim=-1) +
            nn.functional.cosine_similarity(online_proj_2, target_proj_1.detach(), dim=-1)
        )

        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images = batch
        augmented_images_1 = images
        augmented_images_2 = images

        # Forward pass
        online_proj_1 = self.forward(augmented_images_1)
        online_proj_2 = self.forward(augmented_images_2)

        with torch.no_grad():
            target_proj_1 = self.target_projector(self.target_pool(self.target_network(augmented_images_1)).view(augmented_images_1.size(0), -1))
            target_proj_2 = self.target_projector(self.target_pool(self.target_network(augmented_images_2)).view(augmented_images_2.size(0), -1))

        # Compute loss (negative cosine similarity)
        loss = -torch.mean(
            nn.functional.cosine_similarity(online_proj_1, target_proj_2.detach(), dim=-1) +
            nn.functional.cosine_similarity(online_proj_2, target_proj_1.detach(), dim=-1)
        )

        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

# ======================
# Metric Callback for Plotting
# ======================
class MetricLoggerCallback(Callback):
    def __init__(self):
        self.train_losses = []
        self.val_losses = []

    def on_train_epoch_end(self, trainer, pl_module):
        # Log training loss
        loss = trainer.callback_metrics.get("train_loss")
        if loss is not None:
            self.train_losses.append(loss.item())

    def on_validation_epoch_end(self, trainer, pl_module):
        # Log validation loss
        loss = trainer.callback_metrics.get("val_loss")
        if loss is not None:
            self.val_losses.append(loss.item())

    def plot_metrics(self):
        # Plot training and validation loss
        plt.figure(figsize=(10, 5))
        plt.plot(self.train_losses, label="Train Loss")
        plt.plot(self.val_losses, label="Validation Loss")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.title("Training and Validation Loss over Epochs")
        plt.legend()
        plt.show()

# ======================
# Early Stopping Callback
# ======================
early_stopping_callback = EarlyStopping(
    monitor="val_loss",  # Monitor the validation loss
    patience=8,          # Stop after 8 epochs without improvement
    mode="min"           # Stop when the monitored metric stops decreasing
)

# Initialize Metric Logger
metric_logger = MetricLoggerCallback()

# ======================
# Training
# ======================
model = BYOL()
trainer = Trainer(
    max_epochs=50,
    accelerator="gpu",  # Usa la GPU de Colab
    devices=1,          # Usa una GPU
    callbacks=[early_stopping_callback, metric_logger]
)
trainer.fit(model, train_loader, val_loader)

# ======================
# Plot Metrics
# ======================
metric_logger.plot_metrics()


Mounted at /content/drive


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name             | Type              | Params | Mode 
---------------------------------------------------------------
0 | online_network   | ResNet            | 11.2 M | train
1 | online_pool      | AdaptiveAvgPool2d | 0      | train
2 | online_projector | Sequential        | 164 K  | train
3 | target_network   | ResNet            | 11.2 M | train
4 | target_pool      | AdaptiveAvgPool2d | 0      | train
5 | target_projector | Sequential        | 164 K  | train
---------------------------------------------------------------
11.5 M    Trainable params
11.2 M    Non-trainable params
22.7 M    Total params
90.726   

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

ValueError: Input dimension should be at least 3