# In [None]:
import os
import torch
from torch import nn, optim
from transformers import ViTForImageClassification
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

# Input image folders and checkpoint output folder for this experiment.
# NOTE(review): absolute, machine-specific Windows paths — adjust per machine.
_PROJECT_DIR = r"C:\Users\sowmiyasagadevan\OneDrive\Desktop\project"
_DATASET_DIR = rf"{_PROJECT_DIR}\dataset\vit seg"

train_dir = rf"{_DATASET_DIR}\train"
val_dir = rf"{_DATASET_DIR}\val"
save_dir = rf"{_PROJECT_DIR}\models\with seg"

# Make sure the checkpoint folder exists; no error if it is already there.
os.makedirs(save_dir, exist_ok=True)

# Shared preprocessing constants: the model input size and the per-channel
# mean/std used by Normalize (the common ImageNet statistics).
# NOTE(review): confirm these match the preprocessing the pretrained ViT
# checkpoint was trained with (see its preprocessor config).
_IMAGE_SIZE = (224, 224)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random augmentation (flip, rotation, colour
# jitter, random resized crop), then tensor conversion and normalisation.
train_transform = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomResizedCrop(_IMAGE_SIZE, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Validation pipeline: deterministic resize + tensor conversion + normalisation.
val_transform = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

class OilRefineryDataset(torch.utils.data.Dataset):
    """Two-class image dataset read from per-class sub-folders.

    ``data_dir`` must contain the sub-folders ``not_oil_refinery`` (label 0)
    and ``oil_refinery`` (label 1). Each image file directly inside one of
    those folders becomes one ``(image, label)`` sample.

    Args:
        data_dir: Root folder holding the two class sub-folders.
        transform: Optional callable applied to each loaded RGB PIL image.
        extensions: Case-insensitive filename suffixes accepted as images;
            ``None`` means ``IMG_EXTENSIONS``. Other directory entries
            (sub-folders, ``desktop.ini``/``Thumbs.db`` and similar files) are
            skipped here instead of failing later in ``__getitem__``.

    Raises:
        FileNotFoundError: if a class sub-folder does not exist.
    """

    # Default image suffixes; modelled on torchvision's ImageFolder defaults.
    IMG_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp",
    )

    def __init__(self, data_dir, transform=None, extensions=None):
        self.data_dir = data_dir
        self.transform = transform
        self.classes = ['not_oil_refinery', 'oil_refinery']
        self.image_paths = []
        self.labels = []

        if extensions is None:
            extensions = self.IMG_EXTENSIONS
        suffixes = tuple(ext.lower() for ext in extensions)

        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(self.data_dir, class_name)
            # os.listdir order is arbitrary; sort so sample order (and any
            # un-shuffled loader built on it) is reproducible.
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                if os.path.isfile(img_path) and img_name.lower().endswith(suffixes):
                    self.image_paths.append(img_path)
                    self.labels.append(label)

    def __len__(self):
        """Return the number of collected image samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded, converted to RGB and, if a transform was given,
        passed through it.
        """
        img_path = self.image_paths[idx]
        # The context manager closes the file handle as soon as the pixels have
        # been copied by convert(), instead of leaving it to garbage collection.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        label = self.labels[idx]

        if self.transform:
            img = self.transform(img)

        return img, label

# Build the train/validation datasets and wrap them in mini-batch loaders.
train_dataset = OilRefineryDataset(train_dir, transform=train_transform)
val_dataset = OilRefineryDataset(val_dir, transform=val_transform)

batch_size = 16
# The classifier head built below contains nn.BatchNorm1d. In training mode it
# cannot compute batch statistics from a single sample and raises an error.
# With shuffle=True and no drop_last, the last training batch has exactly one
# sample whenever len(train_dataset) % batch_size == 1. Drop the trailing batch
# only in that case, so at most one sample is skipped per epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=len(train_dataset) % batch_size == 1,
)
# Validation runs with the model in eval mode, so a 1-sample batch is fine here.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# Load the pretrained ViT classifier from a local checkpoint folder.
model_path = r"C:\Users\sowmiyasagadevan\OneDrive\Desktop\project\pretrained model\vit"
model = ViTForImageClassification.from_pretrained(model_path)

# Freeze the ViT backbone; only the replacement head below receives gradients.
model.vit.requires_grad_(False)

# New two-class head on the backbone's hidden size:
# Linear -> BatchNorm -> ReLU -> Dropout -> Linear (2 logits).
_head_layers = [
    nn.Linear(model.config.hidden_size, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, 2),
]
model.classifier = nn.Sequential(*_head_layers)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

optimizer = optim.AdamW(
    model.parameters(),
    lr=5e-5,
    weight_decay=0.05,
)
# Lower the LR by `factor` once the value passed to scheduler.step() stops
# decreasing for longer than `patience` steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=3,
    # NOTE(review): `verbose` is deprecated in newer PyTorch releases and may
    # emit a warning — confirm against the installed version.
    verbose=True,
)
criterion = nn.CrossEntropyLoss()

num_epochs = 10
# Best validation accuracy seen so far; used to decide when to checkpoint.
best_val_acc = 0.0

def _run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Run one full pass of ``model`` over ``loader`` and return its metrics.

    If ``optimizer`` is given, this is a training pass: the model is put in
    train mode and ``optimizer`` is stepped after every batch. Otherwise the
    model is put in eval mode and gradients are disabled.

    Args:
        model: Network to run. Its output may be a logits tensor or an object
            with a ``.logits`` attribute (e.g. a Hugging Face model output).
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device every batch is moved to.
        optimizer: Optimizer to step; ``None`` means evaluation only.
        desc: If given, show a tqdm progress bar with this label.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over batches, and the
        fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``loader`` yields no batches, rather than dividing by zero.
    """
    training = optimizer is not None
    model.train(training)
    batches = tqdm(loader, desc=desc) if desc is not None else loader

    loss_sum, n_correct, n_samples, n_batches = 0.0, 0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()

            outputs = model(images)
            logits = outputs.logits if hasattr(outputs, "logits") else outputs
            loss = criterion(logits, labels)

            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_samples += labels.size(0)
            n_batches += 1

    if n_batches == 0:
        raise ValueError("data loader yielded no batches; check the dataset directory")
    return loss_sum / n_batches, n_correct / n_samples


for epoch in range(num_epochs):
    epoch_tag = f"Epoch {epoch + 1}/{num_epochs}"

    train_loss, train_accuracy = _run_epoch(
        model, train_loader, criterion, device, optimizer=optimizer, desc=epoch_tag
    )
    print(f"{epoch_tag}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")

    val_loss, val_accuracy = _run_epoch(model, val_loader, criterion, device)
    print(f"{epoch_tag}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    # Drive the plateau scheduler with the mean validation loss.
    scheduler.step(val_loss)

    # Checkpoint only when validation accuracy strictly improves.
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        torch.save(model.state_dict(), os.path.join(save_dir, "best_vit_model.pth"))
        print(f"Best model saved with accuracy: {best_val_acc:.4f}")