# In [None]:
## Multi-scale EfficientNet-B0 features (early stages 1 + 2 + 3) ✅ Temporal Transformer


import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import random
import copy
import matplotlib.pyplot as plt
from torchvision.models import efficientnet_b0


# ============================================================
# 1 Dataset Class
# ============================================================
class FishClipDataset(Dataset):
    """Clip-classification dataset laid out as ``root_dir/<class>/<clip>/<frame>``.

    Every immediate sub-directory of ``root_dir`` is a class; classes are indexed
    in sorted name order. Every sub-directory of a class folder is one clip. The
    frames of a clip are the files inside it whose names end with one of
    ``frame_extensions`` (case-insensitive), read in sorted file-name order.

    Args:
        root_dir: Root folder of one split (e.g. the ``train`` folder).
        transform: Optional callable applied to every frame (a PIL RGB image).
            It must return a tensor (e.g. ``transforms.ToTensor()``), because the
            frames of a clip are combined with ``torch.stack``.
        frame_extensions: File-name suffix, or tuple of suffixes, treated as
            frames. Defaults to ``(".jpg",)``.
    """

    def __init__(self, root_dir, transform=None, frame_extensions=(".jpg",)):
        self.root_dir = root_dir
        self.transform = transform
        if isinstance(frame_extensions, str):
            frame_extensions = (frame_extensions,)
        # Lower-cased so that e.g. "frame_001.JPG" is matched as well.
        self.frame_extensions = tuple(ext.lower() for ext in frame_extensions)
        self.samples = []

        # Only directories are classes: stray files such as desktop.ini or
        # .DS_Store would otherwise become bogus classes and make the
        # os.listdir() below fail.
        self.classes = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        for cls in self.classes:
            cls_path = os.path.join(root_dir, cls)
            # sorted() makes the sample order reproducible across filesystems.
            for clip_folder in sorted(os.listdir(cls_path)):
                clip_path = os.path.join(cls_path, clip_folder)
                if os.path.isdir(clip_path):
                    self.samples.append((clip_path, self.class_to_idx[cls]))

    def __len__(self):
        """Return the number of clips."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(clip_tensor, label, clip_path)`` for clip ``idx``.

        ``clip_tensor`` stacks the transformed frames along a new leading time
        axis, i.e. ``(T, C, H, W)`` when the transform yields ``(C, H, W)`` tensors.

        Raises:
            RuntimeError: if the clip folder contains no frame files.
        """
        clip_path, label = self.samples[idx]
        frames = sorted(
            f for f in os.listdir(clip_path)
            if f.lower().endswith(self.frame_extensions)
        )
        if not frames:
            # torch.stack([]) would fail here with a far less helpful message.
            raise RuntimeError(
                f"No frame files with extensions {self.frame_extensions} in clip folder: {clip_path}"
            )

        imgs = []
        for frame in frames:
            # The context manager closes the file promptly; convert() returns an
            # independent, fully loaded image.
            with Image.open(os.path.join(clip_path, frame)) as raw:
                img = raw.convert("RGB")
            if self.transform:
                img = self.transform(img)
            imgs.append(img)

        clip_tensor = torch.stack(imgs)  # (T, C, H, W)
        return clip_tensor, label, clip_path


# ============================================================
# 2 Multi-Scale EfficientNet + Transformer
# ============================================================
class MultiScaleEffiTrans(nn.Module):
    """Multi-scale per-frame EfficientNet-B0 features + temporal Transformer classifier.

    Only ``features[0:4]`` of EfficientNet-B0 are used: the stem plus the three
    stages producing 16, 24 and 40 channels. Each stage output is global-average
    pooled. The three vectors are concatenated (80 dims) and projected to
    ``embed_dim``, giving one token per frame. A Transformer encoder mixes the
    frame tokens, and the mean token is classified.

    Args:
        num_classes: Number of output logits.
        embed_dim: Token width; must be divisible by ``num_heads``.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        dim_feedforward: Hidden width of each encoder layer's MLP (default 256).
        pretrained: Load ImageNet ``IMAGENET1K_V1`` backbone weights (downloads
            them if not cached). Pass False to skip, e.g. when a checkpoint will
            be loaded anyway.
        temporal_pos_encoding: Add fixed sinusoidal position encodings to the
            frame tokens. Without them, the encoder followed by mean pooling is
            invariant to frame order. Off by default so existing checkpoints
            behave exactly as before. The encoding has no parameters or buffers,
            so the state_dict is the same either way.
    """

    def __init__(
        self,
        num_classes=2,
        embed_dim=256,
        num_heads=4,
        num_layers=2,
        dim_feedforward=256,
        pretrained=True,
        temporal_pos_encoding=False,
    ):
        super().__init__()

        weights = "IMAGENET1K_V1" if pretrained else None
        effi = torchvision.models.efficientnet_b0(weights=weights)

        # ---- EfficientNet stages (later stages and classifier are not kept) ----
        self.stem    = effi.features[0]      # 32 channels
        self.stage1  = effi.features[1]      # output: 16 channels
        self.stage2  = effi.features[2]      # output: 24 channels
        self.stage3  = effi.features[3]      # output: 40 channels

        # ---- Global pooling for each scale ----
        self.pool = nn.AdaptiveAvgPool2d(1)

        # 16 + 24 + 40 = 80
        self.feature_proj = nn.Linear(80, embed_dim)

        # ---- Transformer over the per-frame tokens ----
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.cls = nn.Linear(embed_dim, num_classes)
        self.temporal_pos_encoding = temporal_pos_encoding

    @staticmethod
    def _sinusoidal_positions(length, dim, device=None):
        """Return fixed sine/cosine position encodings, float32 of shape ``(length, dim)``.

        Even columns hold ``sin(pos * 10000^(-i/dim))`` and odd columns the matching
        cosines. Odd ``dim`` is supported.
        """
        position = torch.arange(length, dtype=torch.float32, device=device).unsqueeze(1)
        freqs = 10000.0 ** (-torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        pe = torch.zeros(length, dim, dtype=torch.float32, device=device)
        pe[:, 0::2] = torch.sin(position * freqs)
        # For odd dim there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * freqs[: dim // 2])
        return pe

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)``.

        Returns:
            Logits of shape ``(B, num_classes)``.
        """
        B, T, C, H, W = x.shape
        # reshape (unlike view) also accepts non-contiguous input, e.g. a permuted clip.
        x = x.reshape(B * T, C, H, W)

        # ---- EfficientNet forward, keeping every intermediate scale ----
        x = self.stem(x)
        f1 = self.stage1(x)            # 16ch
        f2 = self.stage2(f1)           # 24ch
        f3 = self.stage3(f2)           # 40ch

        # ---- Pool each level to a vector and fuse ----
        pooled = [self.pool(f).flatten(1) for f in (f1, f2, f3)]
        multi = torch.cat(pooled, dim=1)          # (B*T, 80)
        multi = self.feature_proj(multi)          # 80 → embed_dim

        # ---- Transformer: one token per frame ----
        tokens = multi.reshape(B, T, -1)
        if self.temporal_pos_encoding:
            pe = self._sinusoidal_positions(T, tokens.size(-1), tokens.device)
            tokens = tokens + pe.to(tokens.dtype)  # match dtype under autocast

        out = self.transformer(tokens).mean(dim=1)
        return self.cls(out)

# ============================================================
# 3 Training Function + Learning Curves
# ============================================================
def _evaluate(model, loader, criterion, device, use_amp):
    """Return ``(mean_batch_loss, accuracy_percent)`` of ``model`` on ``loader`` without gradients."""
    model.eval()
    loss_total, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for clips, labels, _ in loader:
            clips, labels = clips.to(device), labels.to(device)
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(clips)
                loss = criterion(outputs, labels)
            loss_total += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
    return loss_total / len(loader), 100 * correct / total


def _save_curve(train_values, val_values, train_label, val_label, ylabel, title, filename):
    """Plot one per-epoch train/val metric pair and save the figure to ``filename``."""
    epochs = range(1, len(train_values) + 1)  # epochs are reported 1-based
    plt.figure()
    plt.plot(epochs, train_values, label=train_label)
    plt.plot(epochs, val_values, label=val_label)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend()
    plt.title(title)
    plt.savefig(filename)
    plt.close()


def train_multiscale_effi_trans(train_dir, val_dir, test_dir, save_path, device, init_batch_size=1,
                                num_epochs=10, patience=5, lr=1e-4, weight_decay=1e-4):
    """Train ``MultiScaleEffiTrans`` with early stopping and return the best model.

    The model is trained with Adam and cross-entropy. Mixed precision is used
    only on CUDA devices. After every epoch the model is validated; whenever
    validation accuracy improves, the weights are written to ``save_path``.
    Training stops after ``patience`` epochs without improvement. The best
    weights are then restored, evaluated on the test split (if it has any clips),
    and learning curves are written to ``accuracy_curve.png`` and
    ``loss_curve.png`` in the current working directory.

    Args:
        train_dir, val_dir, test_dir: Split roots in the ``FishClipDataset`` layout.
        save_path: File that receives the best ``state_dict``.
        device: ``torch.device`` or device string (e.g. ``"cuda"``).
        init_batch_size: Batch size for all loaders.
        num_epochs: Maximum number of training epochs (default 10).
        patience: Epochs without val-accuracy improvement before stopping (default 5).
        lr, weight_decay: Adam hyper-parameters (defaults 1e-4, 1e-4).

    Returns:
        The model with the best-validation-accuracy weights loaded.

    Raises:
        ValueError: if the train or validation split contains no clips.
    """
    device = torch.device(device)
    # torch.cuda.amp only does anything on CUDA; on CPU it would just warn.
    use_amp = device.type == "cuda"

    # -------- transforms ---------
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    train_dataset = FishClipDataset(train_dir, transform)
    val_dataset   = FishClipDataset(val_dir, transform)
    test_dataset  = FishClipDataset(test_dir, transform)

    # Empty splits would otherwise surface later as a ZeroDivisionError.
    for split_name, split_dir, split_ds in (("train", train_dir, train_dataset),
                                            ("val", val_dir, val_dataset)):
        if len(split_ds) == 0:
            raise ValueError(f"The {split_name} split at {split_dir!r} contains no clips")

    batch_size = init_batch_size
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # -------- model --------
    model = MultiScaleEffiTrans().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # -------- early stop --------
    # Start below any reachable accuracy so epoch 1 is always checkpointed; with
    # 0 a run stuck at 0% would never save and a stale file could be reloaded.
    best_val_acc = -1.0
    best_state = copy.deepcopy(model.state_dict())
    epochs_without_improvement = 0

    # for learning curves
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(num_epochs):

        # ---------- TRAIN ----------
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for clips, labels, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            clips, labels = clips.to(device), labels.to(device)

            optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(clips)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)

        train_acc = 100 * correct / total
        train_loss = running_loss / len(train_loader)

        # ---------- VALID ----------
        val_loss, val_acc = _evaluate(model, val_loader, criterion, device, use_amp)

        print(f"Epoch {epoch+1}: Train Acc={train_acc:.2f}, Val Acc={val_acc:.2f}")

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        # ---------- Save best ----------
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = copy.deepcopy(model.state_dict())
            torch.save(best_state, save_path)
            print(f"💾 Saved Best Model (Val Acc {val_acc:.2f}%)")
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print("⏹ Early Stopping Triggered")
                break

    # Restore the best weights from memory (identical to what was saved to
    # save_path in this run; avoids re-reading the file and device mapping issues).
    model.load_state_dict(best_state)
    print(f"Training Complete. Best Val Acc = {best_val_acc:.2f}%")

    # ---------- TEST (previously loaded but never used) ----------
    if len(test_dataset) > 0:
        test_loss, test_acc = _evaluate(model, test_loader, criterion, device, use_amp)
        print(f"Test Acc = {test_acc:.2f}% (Test Loss = {test_loss:.4f})")
    else:
        print("Test split is empty; skipping test evaluation.")

    # ---------- Learning curves ----------
    _save_curve(train_accs, val_accs, "Train Accuracy", "Val Accuracy",
                "Accuracy (%)", "Accuracy Curve", "accuracy_curve.png")
    _save_curve(train_losses, val_losses, "Train Loss", "Val Loss",
                "Loss", "Loss Curve", "loss_curve.png")
    print("📊 Learning curves saved: accuracy_curve.png, loss_curve.png")

    return model


# ============================================================
# 4 Usage
# ============================================================
# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset splits and the checkpoint all live under the same desktop folder.
_desktop_root = r"C:\Users\STF8586\OneDrive - University of Derby\Desktop"
_dataset_root = _desktop_root + r"\S.K\Effi+Trans Dataset"

# One sub-folder per split: ...\train, ...\val, ...\test
train_dir, val_dir, test_dir = (f"{_dataset_root}\\{split}" for split in ("train", "val", "test"))
save_path = _desktop_root + r"\MultiScaleEffiTrans_bestT6.pth"

model = train_multiscale_effi_trans(
    train_dir, val_dir, test_dir, save_path, device, init_batch_size=1
)