<a href="https://colab.research.google.com/github/xByEMPE/BYOL_MODEL_VND/blob/main/vnd_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet18
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import Callback
from PIL import Image
import os
import matplotlib.pyplot as plt

# ======================
# Dataset for SSL
# ======================
class CustomDataset(Dataset):
    """Unlabeled image dataset for self-supervised training.

    Collects every regular file directly inside ``image_dir`` whose name ends
    with one of ``extensions`` (case-insensitive) and serves each one as an RGB
    ``PIL.Image``, optionally passed through ``transform``.

    Args:
        image_dir: Directory to scan (non-recursive).
        transform: Optional callable applied to each loaded image.
        extensions: Accepted filename suffixes, including the leading dot.
    """

    def __init__(self, image_dir, transform=None, extensions=(".png", ".jpg", ".jpeg")):
        self.image_dir = image_dir
        self.transform = transform
        suffixes = tuple(ext.lower() for ext in extensions)
        # sorted(): os.listdir order is arbitrary, so sort for a reproducible
        # index -> file mapping. lower(): also accept "IMG.JPG". The leading dot
        # in the suffixes avoids matching names such as "xpng".
        self.image_paths = [
            os.path.join(image_dir, name)
            for name in sorted(os.listdir(image_dir))
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(image_dir, name))
        ]

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given."""
        img_path = self.image_paths[idx]
        # The context manager closes the file handle. convert() returns a new
        # image, so the result stays usable after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

# ======================
# Transformations
# ======================
# Augmentation pipeline for the single view produced per image: resize to
# 224x224, random flip, random crop covering 80-100% of the area resized back
# to 224, then normalise with the standard ImageNet channel mean/std.
transform_ssl = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.RandomResizedCrop(224, scale=(0.8, 1.0)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Folder of training images. Set the SSL_IMAGE_DIR environment variable (or edit
# the fallback below); the placeholder path does not exist and is only a reminder.
image_dir_ssl = os.environ.get("SSL_IMAGE_DIR", "cambiar ruta a imagenes con etiquetas")
if not os.path.isdir(image_dir_ssl):
    raise FileNotFoundError(
        f"Image directory not found: {image_dir_ssl!r}. "
        "Set the SSL_IMAGE_DIR environment variable or edit image_dir_ssl "
        "to point at your image folder."
    )

dataset_ssl = CustomDataset(image_dir=image_dir_ssl, transform=transform_ssl)
# Fail with a clear message instead of DataLoader's "num_samples=0" error.
if len(dataset_ssl) == 0:
    raise ValueError(f"No supported image files found in {image_dir_ssl!r}.")
data_loader_ssl = DataLoader(dataset_ssl, batch_size=32, shuffle=True)

# ======================
# BYOL Model
# ======================
class BYOL(pl.LightningModule):
    """BYOL-style self-supervised learner on a ResNet-18 backbone.

    Online branch: backbone -> projector -> predictor, trained by Adam.
    Target branch: backbone -> projector with the same architecture. It is
    initialised from the online weights, never receives gradients, and is moved
    toward the online weights by an exponential moving average after each
    training batch.

    Args:
        lr: Adam learning rate for the online branch.
        tau: EMA decay for the target branch (closer to 1 = slower-moving target).
    """

    def __init__(self, lr=1e-3, tau=0.996):
        super().__init__()
        self.lr = lr
        self.tau = tau

        # Online network. resnet18 ends in a 1000-way classifier; replace it with
        # Identity so the backbone emits the 512-d pooled features the projector
        # expects (Linear(512, ...) would otherwise fail on a 1000-d input).
        self.online_network = resnet18(pretrained=False)
        self.online_network.fc = nn.Identity()
        self.online_projector = self._mlp(512, 256, 128)
        # The predictor exists only on the online branch. Together with the EMA
        # target, this asymmetry is what BYOL relies on to avoid collapse.
        self.online_predictor = self._mlp(128, 256, 128)

        # Target network: identical architecture, started from the online weights.
        self.target_network = resnet18(pretrained=False)
        self.target_network.fc = nn.Identity()
        self.target_projector = self._mlp(512, 256, 128)
        self.target_network.load_state_dict(self.online_network.state_dict())
        self.target_projector.load_state_dict(self.online_projector.state_dict())

        # Freeze the whole target branch; it changes only through _update_target().
        for module in (self.target_network, self.target_projector):
            for param in module.parameters():
                param.requires_grad = False

    @staticmethod
    def _mlp(in_dim, hidden_dim, out_dim):
        """Build the Linear-ReLU-Linear head used for projectors and the predictor."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Return the online projection (backbone -> projector) of ``x``."""
        features = self.online_network(x)
        projections = self.online_projector(features)
        return projections

    @staticmethod
    def _random_view(images):
        """Sample one augmented view of an image batch.

        Each sample independently gets:
        - a horizontal flip with p=0.5;
        - a random crop covering 50-100% of the area, resized back to the input size;
        - a mild brightness/contrast jitter.

        It operates directly on the already-normalised tensors.

        Args:
            images: batch tensor indexed as (batch, channels, height, width).

        Returns:
            A new tensor with the same shape, dtype and device as ``images``.
        """
        batch_size, _, height, width = images.shape
        flip = torch.rand(batch_size, device=images.device) < 0.5
        images = torch.where(flip[:, None, None, None], images.flip(-1), images)

        views = torch.empty_like(images)
        for i in range(batch_size):
            side = torch.empty(1).uniform_(0.5, 1.0).sqrt().item()
            crop_h = max(1, round(height * side))
            crop_w = max(1, round(width * side))
            top = int(torch.randint(0, height - crop_h + 1, (1,)))
            left = int(torch.randint(0, width - crop_w + 1, (1,)))
            crop = images[i:i + 1, :, top:top + crop_h, left:left + crop_w]
            view = nn.functional.interpolate(
                crop, size=(height, width), mode="bilinear", align_corners=False
            )
            mean = view.mean(dim=(-2, -1), keepdim=True)
            contrast = torch.empty(1).uniform_(0.8, 1.2).item()
            brightness = torch.empty(1).uniform_(-0.2, 0.2).item()
            views[i] = ((view - mean) * contrast + mean + brightness)[0]
        return views

    def _split_views(self, batch):
        """Return two views of the batch.

        If the batch is a ``(view_1, view_2)`` pair of same-shaped tensors, those
        views are used as-is. Otherwise two views are sampled from the image
        tensor with ``_random_view``. For another sequence, its first element is
        taken as the images.
        """
        if isinstance(batch, (list, tuple)):
            if (
                len(batch) == 2
                and torch.is_tensor(batch[1])
                and batch[0].shape == batch[1].shape
            ):
                return batch[0], batch[1]
            batch = batch[0]
        return self._random_view(batch), self._random_view(batch)

    def training_step(self, batch, batch_idx):
        """Compute the symmetrised BYOL loss for one batch and log it per epoch."""
        view_1, view_2 = self._split_views(batch)

        # Online branch: backbone -> projector -> predictor.
        online_pred_1 = self.online_predictor(self.forward(view_1))
        online_pred_2 = self.online_predictor(self.forward(view_2))

        # Target branch: no gradients; its weights move only via the EMA update.
        with torch.no_grad():
            target_proj_1 = self.target_projector(self.target_network(view_1))
            target_proj_2 = self.target_projector(self.target_network(view_2))

        # Negative cosine similarity; each view predicts the other view's target.
        loss = -torch.mean(
            nn.functional.cosine_similarity(online_pred_1, target_proj_2, dim=-1)
            + nn.functional.cosine_similarity(online_pred_2, target_proj_1, dim=-1)
        )

        # The key must equal the "perdida_entrenamiento" monitor used by the
        # EarlyStopping and MetricLoggerCallback; log the epoch mean under it.
        self.log("perdida_entrenamiento", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """After each training batch, pull the target branch toward the online branch."""
        self._update_target()

    @torch.no_grad()
    def _update_target(self):
        """Apply target <- tau * target + (1 - tau) * online to every parameter."""
        pairs = (
            (self.online_network, self.target_network),
            (self.online_projector, self.target_projector),
        )
        for online, target in pairs:
            for online_param, target_param in zip(online.parameters(), target.parameters()):
                target_param.mul_(self.tau).add_(online_param.detach(), alpha=1.0 - self.tau)

    def configure_optimizers(self):
        """Optimise only the trainable (online) parameters with Adam."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable, lr=self.lr)
        return optimizer

# ======================
# Metric Callback for Plotting
# ======================
class MetricLoggerCallback(Callback):
    """Record one scalar metric per training epoch and plot it afterwards.

    Args:
        monitor: Key read from ``trainer.callback_metrics`` at the end of each
            training epoch. Defaults to ``"perdida_entrenamiento"``.
    """

    def __init__(self, monitor="perdida_entrenamiento"):
        super().__init__()
        self.monitor = monitor
        self.train_losses = []

    def on_train_epoch_end(self, trainer, pl_module):
        """Append the monitored metric's value for this epoch, if it was logged."""
        value = trainer.callback_metrics.get(self.monitor)
        if value is not None:
            # Lightning stores scalar tensors here; float() also accepts plain numbers.
            self.train_losses.append(float(value))

    def plot_metrics(self):
        """Plot the recorded values against 1-based epoch numbers."""
        epochs = range(1, len(self.train_losses) + 1)
        plt.figure(figsize=(10, 5))
        plt.plot(epochs, self.train_losses, marker="o", label="perdida de entrenamiento")
        plt.xlabel("Epocas")
        plt.ylabel("Perdida")
        plt.title("Perdida de entrenamiento por epocas")
        plt.legend()
        plt.show()

# ======================
# Early Stopping Callback
# ======================
# The model is monitored on a training metric and there is no validation loop,
# so the early-stopping check runs at the end of every training epoch.
early_stopping_callback = EarlyStopping(
    monitor="perdida_entrenamiento",  # must equal the key logged by BYOL.training_step
    patience=8,                       # stop after 8 epochs without improvement
    mode="min",                       # a lower loss is better
    check_on_train_epoch_end=True,
)

# Initialize Metric Logger
metric_logger = MetricLoggerCallback()

# ======================
# Training
# ======================
model = BYOL()
# `gpus=` was removed in Lightning 2.0. accelerator="auto" picks a GPU when one
# is available and otherwise falls back to CPU.
trainer = Trainer(
    max_epochs=50,
    accelerator="auto",
    devices=1,
    callbacks=[early_stopping_callback, metric_logger],
)
trainer.fit(model, data_loader_ssl)

# ======================
# Plot Metrics
# ======================
metric_logger.plot_metrics()




FileNotFoundError: [Errno 2] No such file or directory: 'path_to_images'