<a href="https://colab.research.google.com/github/shahbazfareedchishti/FlaskWebApp/blob/main/modeltraining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Importing Libraries**

In [None]:
import csv
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

In [None]:
# Label set and hold-out fractions used when building the split manifests.
CLASSES = ["KaiYuan", "Speedboat", "UUV", "Unknown"]
VAL_RATIO, TEST_RATIO = 0.15, 0.15

# Seed Python's `random` module so sampling and shuffling are repeatable.
random.seed(42)

# Make sure the output directory exists (no error if it already does).
# OUT_DIR is expected to be defined before this cell runs.
os.makedirs(OUT_DIR, exist_ok=True)

# **Data Preprocessing**

In [None]:
def list_mels(root, class_name):
    """Return the non-empty ``.npy`` files found directly in ``root/class_name``.

    Args:
        root: Directory that contains one sub-folder per class.
        class_name: Name of the class sub-folder to scan.

    Returns:
        A list of paths (``root/class_name/<entry>``), sorted by entry name,
        for every entry whose name ends in ``.npy`` and whose size on disk is
        greater than zero.

    Raises:
        OSError: Propagated from ``os.listdir`` / ``os.path.getsize``, e.g.
            ``FileNotFoundError`` when the class folder does not exist.
    """
    folder = os.path.join(root, class_name)
    # os.listdir returns entries in arbitrary order; sorting makes the result
    # deterministic so that seeded sampling of this list is reproducible.
    candidates = (
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if name.endswith(".npy")  # mel spectrograms are stored as numpy arrays
    )
    # Zero-byte files cannot hold an array, so they are skipped here.
    return [path for path in candidates if os.path.getsize(path) > 0]

def sanity_check(files, cls):
    """Drop files that cannot be loaded or that hold unusable arrays.

    A file is treated as corrupt when ``np.load`` (or the NaN test) raises,
    when the loaded array is empty, or when it contains at least one NaN.

    Args:
        files: Iterable of paths to ``.npy`` files.
        cls: Class name, used only in the printed summary.

    Returns:
        A new list with the paths that passed the check, in their original
        order.
    """
    kept = []
    removed = 0
    for path in files:
        try:
            mel = np.load(path)
            corrupt = mel.size == 0 or bool(np.isnan(mel).any())
        except Exception:
            # Best-effort filter: any failure to load or inspect the file
            # simply marks it as corrupt instead of aborting the scan.
            corrupt = True
        if corrupt:
            removed += 1
        else:
            # Collect good files directly rather than filtering afterwards
            # with a per-file list-membership test.
            kept.append(path)
    if removed:
        print(f"Removed {removed} corrupt {cls} files")
    return kept

def write_manifest(rows, filename):
    """Write ``rows`` to ``OUT_DIR/filename`` as a CSV manifest.

    The file starts with a ``path,label`` header row followed by one line per
    entry of ``rows``.

    Args:
        rows: Sized collection of ``(path, label)`` pairs.
        filename: Name of the CSV file to create inside ``OUT_DIR``.
    """
    path = os.path.join(OUT_DIR, filename)
    # newline="" leaves line-ending handling to the csv module. The encoding
    # is pinned to UTF-8 so non-ASCII paths are written the same way on every
    # platform and match the default encoding pandas.read_csv reads with.
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["path", "label"])
        writer.writerows(rows)
    # The arrow in this message was previously stored as mis-decoded bytes.
    print(f"Saved {len(rows)} → {filename}")

In [None]:
# Gather, per class, the mel files that survive the corruption check.
# Result: a list of (class_name, [paths]) tuples in CLASSES order.
all_data = []
for cls in CLASSES:
    files = sanity_check(list_mels(ROOT, cls), cls)
    all_data.append((cls, files))
    print(f"{cls}: {len(files)} valid mels")

In [None]:
# Every class is down-sampled to the size of the smallest class.
min_count = min(len(paths) for _, paths in all_data)
print(f"Balancing to {min_count} samples per class")

# Draw min_count files per class (without replacement) and tag each path
# with its class label.
balanced = []
for cls, files in all_data:
    balanced.extend([(path, cls) for path in random.sample(files, min_count)])

# Mix the classes together before splitting.
random.shuffle(balanced)

# **Splitting**

In [None]:
# Separate the (path, label) pairs into two parallel lists.
X = [path for path, _ in balanced]
y = [label for _, label in balanced]

# Two-stage stratified split: first carve off val+test together, then divide
# that hold-out into validation and test parts.
holdout_ratio = VAL_RATIO + TEST_RATIO
X_train, X_tmp, y_train, y_tmp = train_test_split(
    X,
    y,
    test_size=holdout_ratio,
    stratify=y,
    random_state=42,
)
X_val, X_test, y_val, y_test = train_test_split(
    X_tmp,
    y_tmp,
    test_size=TEST_RATIO / holdout_ratio,
    stratify=y_tmp,
    random_state=42,
)

print("Train:", len(X_train), "Val:", len(X_val), "Test:", len(X_test))

In [None]:
def overlap_check(a, b):
    """Return how many distinct items occur in both ``a`` and ``b``."""
    shared = set(a).intersection(b)
    return len(shared)

# No file path may appear in more than one split; the pairs are checked in
# order and the assertion fails on the first pair that shares a file.
split_pairs = ((X_train, X_val), (X_train, X_test), (X_val, X_test))
assert all(overlap_check(left, right) == 0 for left, right in split_pairs)
print("Leakage check: OK (no shared files).")

In [None]:
# Pair every path with its label, one row list per split.
train_rows, val_rows, test_rows = (
    list(zip(paths, labels))
    for paths, labels in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

# Persist each split as its own CSV manifest inside OUT_DIR.
for rows, filename in (
    (train_rows, "train.csv"),
    (val_rows, "val.csv"),
    (test_rows, "test.csv"),
):
    write_manifest(rows, filename)

print("Dataset manifests ready at:", OUT_DIR)

# **Loading Mel Spectrograms**

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training hyperparameters.
NUM_CLASSES = 4  # adjust if different
BATCH_SIZE = 32
EPOCHS = 25
LR = 1e-4

In [None]:
class MelDataset(Dataset):
    """Dataset of mel-spectrogram ``.npy`` files listed in a CSV manifest.

    The manifest is read with ``pd.read_csv``. For each row, column 0 is the
    path passed to ``np.load`` and column 1 is the class label; the label
    column must also be addressable by the name ``label``.

    Args:
        csv_path: Path of the manifest CSV.
        transform: Optional callable applied to each sample's float32 tensor
            before it is returned.
        classes: Optional sequence listing every class label; a label's
            position in it becomes its integer index. Pass the same sequence
            to every split so that a split missing a class still uses the
            same indices as the others. When omitted, the mapping is built
            from the sorted unique labels found in this manifest alone.

    Raises:
        ValueError: If ``classes`` is given and the manifest contains a label
            that is not in it.
    """

    def __init__(self, csv_path, transform=None, classes=None):
        self.df = pd.read_csv(csv_path)
        self.transform = transform

        present = self.df['label'].unique()
        if classes is None:
            classes = sorted(present)
        else:
            classes = list(classes)
            known = set(classes)
            unknown = [label for label in present if label not in known]
            if unknown:
                raise ValueError(
                    f"{csv_path} contains labels not listed in classes: {unknown}"
                )
        self.label_to_idx = {label: idx for idx, label in enumerate(classes)}

    def __len__(self):
        """Return the number of rows in the manifest."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(tensor, label_index)``.

        A 2-D array gains a leading axis of size 1; a 3-D array has its last
        axis moved to the front (this presumes channel-last storage — confirm
        against the code that produced the files).

        Raises:
            ValueError: If the stored array is neither 2-D nor 3-D.
        """
        x = np.load(self.df.iloc[idx, 0])
        if x.ndim == 2:
            x = np.expand_dims(x, 0)
        elif x.ndim == 3:
            x = np.transpose(x, (2, 0, 1))
        else:
            raise ValueError(
                f"Expected a 2-D or 3-D array in {self.df.iloc[idx, 0]}, "
                f"got {x.ndim} dimension(s)"
            )
        x = torch.tensor(x, dtype=torch.float32)
        if self.transform:
            x = self.transform(x)

        label_str = self.df.iloc[idx, 1]
        label = self.label_to_idx[label_str]
        return x, label

In [None]:
def _build_tf(augment):
    """Compose the image pipeline; ``augment`` adds the random horizontal flip."""
    steps = [transforms.Resize((224, 224))]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=[0.5]*3, std=[0.5]*3))
    return transforms.Compose(steps)


# NOTE(review): ToTensor is documented for PIL images / ndarrays and Normalize
# is configured for three channels, whereas MelDataset.__getitem__ returns
# torch tensors; the datasets in this file are also built without a
# transform. Confirm whether these pipelines are still meant to be used.
train_tf = _build_tf(augment=True)   # resize, random flip, to-tensor, normalize
val_tf = _build_tf(augment=False)    # resize, to-tensor, normalize

# **Rechecking**

In [None]:
# Training split: the only loader that reshuffles its samples.
train_ds = MelDataset("/content/drive/MyDrive/mel_spectrograms_splits/train.csv")
train_dl = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE)

# Validation split, iterated in manifest order.
val_ds = MelDataset("/content/drive/MyDrive/mel_spectrograms_splits/val.csv")
val_dl = DataLoader(val_ds, shuffle=False, batch_size=BATCH_SIZE)

# Test split, iterated in manifest order.
test_ds = MelDataset("/content/drive/MyDrive/mel_spectrograms_splits/test.csv")
test_dl = DataLoader(test_ds, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
# Dry run over the whole training loader: every batch is loaded, moved to
# DEVICE and, when single-channel, tiled to three channels. Nothing is kept
# besides the last batch left in `imgs` / `labels`.
for imgs, labels in train_dl:
    imgs = imgs.to(DEVICE)
    labels = labels.to(DEVICE)
    single_channel = imgs.shape[1] == 1
    if single_channel:
        imgs = imgs.repeat(1, 3, 1, 1)  # replicate to RGB

In [None]:
# ResNet-18 initialised with the IMAGENET1K_V1 weights; its final fully
# connected layer is swapped for a fresh NUM_CLASSES-way linear head.
pretrained_weights = models.ResNet18_Weights.IMAGENET1K_V1
model = models.resnet18(weights=pretrained_weights)
head_inputs = model.fc.in_features
model.fc = nn.Linear(head_inputs, NUM_CLASSES)
model = model.to(DEVICE)

# Cross-entropy loss, with Adam over all model parameters at rate LR.
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

# **Model Training**

In [None]:
# Train for EPOCHS epochs; after each epoch evaluate on the validation loader
# and print accuracy and loss for both splits.
# EPOCHS is reassigned here to the same value the hyperparameter cell sets.
EPOCHS = 25
# NOTE(review): best_val_acc is initialised but neither read nor updated in
# the lines visible here — confirm whether checkpointing code follows.
best_val_acc = 0.0

for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    # Running sum of per-batch losses and count of correct predictions.
    total_loss, correct = 0.0, 0

    for imgs, labels in train_dl:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        # Single-channel batches are tiled to three channels before the
        # forward pass through the ResNet.
        if imgs.shape[1] == 1:
            imgs = imgs.repeat(1, 3, 1, 1)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # argmax over dim 1 picks the predicted class index per sample.
        correct += (out.argmax(1) == labels).sum().item()

    # Accuracy is per sample; loss is averaged over the number of batches.
    train_acc = correct / len(train_ds)
    avg_loss = total_loss / len(train_dl)

    # ---- validation pass (eval mode, no gradient tracking) ----
    model.eval()
    val_correct, val_loss = 0, 0.0
    with torch.no_grad():
        for imgs, labels in val_dl:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            if imgs.shape[1] == 1:
                imgs = imgs.repeat(1, 3, 1, 1)
            out = model(imgs)
            val_loss += criterion(out, labels).item()
            val_correct += (out.argmax(1) == labels).sum().item()

    # Same normalisation as for training: per sample / per batch.
    val_acc = val_correct / len(val_ds)
    val_loss /= len(val_dl)

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Acc: {train_acc:.3f} | Val Acc: {val_acc:.3f} | Train Loss: {avg_loss:.3f} | Val Loss: {val_loss:.3f}")