# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

# === CONFIG ===
# The two paths default to the original workstation locations but can be
# overridden with the AE_TRAIN_DIR / AE_MODEL_PATH environment variables, so
# the script can run on another machine without editing the source.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TRAIN_DIR = os.environ.get(
    "AE_TRAIN_DIR",
    r"C:\Users\shpigel-lab\Desktop\tiles\train",
)
MODEL_PATH = os.environ.get(
    "AE_MODEL_PATH",
    r"C:\Users\shpigel-lab\Documents\DL project\models\autoencoder.pth",
)
BATCH_SIZE = 32  # intended training batch size
EPOCHS = 50  # intended maximum number of training epochs
LEARNING_RATE = 1e-3  # learning rate passed to the Adam optimizer
IMAGE_SIZE = 128  # tiles are resized to IMAGE_SIZE x IMAGE_SIZE

# === TRANSFORM ===
# Preprocessing applied to every tile: resize to a fixed
# IMAGE_SIZE x IMAGE_SIZE square, then ToTensor() turns the PIL image into a
# float (C, H, W) tensor scaled to [0, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
    ]
)

# === DATASET CLASS ===
class PatchDataset(Dataset):
    """Unlabelled image-tile dataset: each sample is one image file as RGB.

    Every regular file directly inside ``folder`` whose name ends with one of
    ``extensions`` (compared case-insensitively, so ``.PNG`` also matches)
    becomes one sample. Paths are sorted so that sample indices are stable
    across runs and platforms; ``os.listdir`` order is arbitrary.

    Args:
        folder: Directory containing the tiles (not searched recursively).
        transform: Optional callable applied to the RGB PIL image, e.g. a
            torchvision ``Compose``. When ``None`` the PIL image is returned.
        extensions: A suffix or tuple of suffixes to accept.
            Defaults to ``(".png",)``.
    """

    def __init__(self, folder, transform=None, extensions=(".png",)):
        if isinstance(extensions, str):
            # A bare string would otherwise be iterated character by character.
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        self.paths = sorted(
            os.path.join(folder, fname)
            for fname in os.listdir(folder)
            if fname.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(folder, fname))
        )
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        # convert() loads the pixels into a new image, so the source file can
        # be closed as soon as the with-block ends.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img

# === MODEL ===
class Autoencoder(nn.Module):
    """Convolutional autoencoder with three stride-2 stages each way.

    The encoder applies three 3x3 stride-2 convolutions (3 -> 16 -> 32 -> 64
    channels), each halving the spatial size. The decoder applies three
    transposed convolutions (64 -> 32 -> 16 -> 3 channels), each doubling it
    back. For inputs whose height and width are divisible by 8, the output has
    the same shape as the input. A final Sigmoid bounds outputs to [0, 1].
    """

    def __init__(self):
        super().__init__()

        def down(c_in, c_out):
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)

        def up(c_in, c_out):
            return nn.ConvTranspose2d(
                c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1
            )

        # Layer positions define the state_dict keys (encoder.0/2/4,
        # decoder.0/2/4); keep the order so saved checkpoints still load.
        self.encoder = nn.Sequential(
            down(3, 16), nn.ReLU(),
            down(16, 32), nn.ReLU(),
            down(32, 64), nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            up(64, 32), nn.ReLU(),
            up(32, 16), nn.ReLU(),
            up(16, 3), nn.Sigmoid(),
        )

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)

# === LOAD DATA ===
train_dataset = PatchDataset(TRAIN_DIR, transform=transform)
if len(train_dataset) == 0:
    # DataLoader(shuffle=True) would otherwise fail with an opaque sampler
    # error about num_samples; report the real problem instead.
    raise FileNotFoundError(f"No .png tiles found in {TRAIN_DIR}")
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)

# === TRAINING ===
model = Autoencoder().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

NUM_EPOCHS = EPOCHS  # maximum number of epochs; early stopping may end sooner
patience = 3  # stop after this many consecutive epochs without improvement
best_loss = float("inf")
best_state = None  # in-memory copy of the best weights seen so far
epochs_without_improvement = 0

# Create the checkpoint directory before training: the first torch.save runs
# at the end of epoch 1, not at the end of the script.
model_dir = os.path.dirname(MODEL_PATH)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    running_loss = 0.0
    for imgs in tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS}"):
        imgs = imgs.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, imgs)
        loss.backward()
        optimizer.step()
        # Weight each batch by its size so a short final batch does not skew
        # the epoch mean.
        running_loss += loss.item() * imgs.size(0)

    epoch_loss = running_loss / len(train_dataset)
    print(f"Epoch {epoch}, Loss: {epoch_loss:.5f}")

    # === EARLY STOPPING LOGIC ===
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        epochs_without_improvement = 0
        # Keep a detached copy so the best weights survive later optimizer steps.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        torch.save(model.state_dict(), MODEL_PATH)
        print("✅ New best model saved.")
    else:
        epochs_without_improvement += 1
        print(f"⚠️ No improvement for {epochs_without_improvement} epoch(s).")

    if epochs_without_improvement >= patience:
        print(f"🛑 Early stopping triggered after {epoch} epochs.")
        break


# === SAVE MODEL ===
# Restore the best weights before the final save. Otherwise the last
# (non-improving) epoch's weights would overwrite the best checkpoint.
if best_state is not None:
    model.load_state_dict(best_state)
torch.save(model.state_dict(), MODEL_PATH)
print(f"✅ Model saved to: {MODEL_PATH}")
