In [None]:
import os
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoFeatureExtractor, AutoModel
import torch.nn as nn

# Prefer the GPU when PyTorch can see one; otherwise run everything on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# === ViT feature extractor setup ===
# Preprocessor and encoder are loaded from the same checkpoint so the image
# preprocessing matches the backbone it feeds.
VIT_CHECKPOINT = "google/vit-base-patch16-224-in21k"
extractor = AutoFeatureExtractor.from_pretrained(VIT_CHECKPOINT)
vit_model = AutoModel.from_pretrained(VIT_CHECKPOINT)
vit_model = vit_model.to(device)

def extract_vit_features(image):
    """Embed a single PIL image as one pooled ViT feature vector.

    The image is forced to RGB and preprocessed by the module-level
    ``extractor``. It is then encoded by ``vit_model`` on ``device`` with
    gradient tracking disabled. The final hidden states are averaged over
    dim 1 (the token axis).

    Args:
        image: A PIL image in any mode.

    Returns:
        The pooled features as a CPU tensor. For a single image this is
        presumably shape (1, hidden_size); TODO confirm against the checkpoint.
    """
    rgb = image.convert("RGB")
    batch = extractor(images=rgb, return_tensors="pt")
    on_device = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        hidden = vit_model(**on_device).last_hidden_state
        pooled = hidden.mean(dim=1)
    return pooled.cpu()

# === Loop through dataset ===
def load_dataset_and_features(dataset_path, label_map=None):
    """Extract ViT features for every image under ``dataset_path``.

    Images are read from one sub-folder per class. By default ``real`` maps
    to label 0 and ``ai`` maps to label 1. A file that fails to open or embed
    is reported and skipped instead of aborting the whole run.

    Args:
        dataset_path: Directory containing the per-class sub-folders.
        label_map: Optional mapping of sub-folder name to integer label.
            Defaults to ``{"real": 0, "ai": 1}``.

    Returns:
        A tuple ``(features, labels)``:
        - ``features``: the per-image feature tensors concatenated along
          dim 0.
        - ``labels``: an integer tensor with one label per feature row.
        Samples are grouped by class in ``label_map`` order and are NOT
        shuffled. Within a class, files appear in sorted filename order.

    Raises:
        FileNotFoundError: If a class sub-folder does not exist.
        RuntimeError: If no image at all could be loaded.
    """
    if label_map is None:
        label_map = {"real": 0, "ai": 1}

    features = []
    labels = []
    for label_str, label_val in label_map.items():
        folder = os.path.join(dataset_path, label_str)
        # os.listdir returns entries in arbitrary order; sort for reproducible runs.
        filenames = sorted(os.listdir(folder))
        for filename in tqdm(filenames, desc=f"Loading {label_str}"):
            img_path = os.path.join(folder, filename)
            try:
                # The context manager releases the file handle even if decoding fails.
                with Image.open(img_path) as img:
                    image = img.convert("RGB")
                feat = extract_vit_features(image)
            except Exception as e:
                # Best-effort: a single corrupt or non-image file must not stop extraction.
                print(f"⚠️ Error with {filename}: {e}")
                continue
            features.append(feat)
            labels.append(label_val)

    if not features:
        # torch.cat([]) fails with an opaque message; explain the actual cause instead.
        raise RuntimeError(
            f"No images could be loaded from {dataset_path!r} "
            f"(expected sub-folders: {', '.join(label_map)})"
        )
    return torch.cat(features), torch.tensor(labels)

# Build the training set: one pooled ViT feature row per image, plus its label.
DATASET_DIR = "dataset"
X_train, y_train = load_dataset_and_features(DATASET_DIR)

# === Define your model ===
class FullModel(nn.Module):
    """MLP head that maps each input feature vector to a single logit.

    Each hidden layer is ``Linear -> BatchNorm1d -> ReLU -> Dropout``. The
    final ``Linear`` emits one raw logit with no sigmoid applied, for use with
    ``nn.BCEWithLogitsLoss``.

    The default arguments reproduce the original hard-coded
    768 -> 384 -> 128 -> 1 network. The layer indices, and therefore the
    ``state_dict`` keys ``net.0`` ... ``net.8``, are unchanged, so existing
    checkpoints still load.

    Args:
        input_dim: Width of each input feature vector.
        hidden_dims: Widths of the hidden layers, in order.
        dropout: Dropout probability applied after every hidden layer.
    """

    def __init__(
        self,
        input_dim: int = 768,
        hidden_dims: tuple[int, ...] = (384, 128),
        dropout: float = 0.3,
    ):
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers += [
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            in_features = width
        layers.append(nn.Linear(in_features, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, 1)`` for ``x`` of shape ``(batch, input_dim)``."""
        return self.net(x)


# === Train ===
# Only FullModel's parameters are optimized; X_train already holds extracted features.
EPOCHS = 5
BATCH_SIZE = 16

model = FullModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.BCEWithLogitsLoss()

# BCEWithLogitsLoss needs float targets; convert once instead of once per batch.
y_train_float = y_train.float()
num_samples = len(X_train)

for epoch in range(EPOCHS):
    model.train()
    # load_dataset_and_features returns every "real" sample before every "ai"
    # sample. Iterating in stored order would therefore give single-class
    # batches, so reshuffle every epoch.
    order = torch.randperm(num_samples)
    epoch_loss, seen = 0.0, 0
    for start in range(0, num_samples, BATCH_SIZE):
        idx = order[start:start + BATCH_SIZE]
        # BatchNorm1d raises on a 1-sample batch in training mode, so skip a
        # trailing singleton batch. Because of the shuffle, a different sample
        # is dropped each epoch.
        if len(idx) < 2:
            continue
        xb = X_train[idx].to(device)
        yb = y_train_float[idx].to(device)

        optimizer.zero_grad()
        # Use squeeze(-1), not squeeze(): keep the batch axis so logits and
        # targets always share the shape (batch,).
        logits = model(xb).squeeze(-1)
        loss = loss_fn(logits, yb)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * len(idx)
        seen += len(idx)
    mean_loss = epoch_loss / max(seen, 1)
    print(f"✅ Epoch {epoch+1} complete (mean loss {mean_loss:.4f})")

torch.save(model.state_dict(), "full_model.pt")
print("✅ Trained model saved as full_model.pt")
