In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd
from torchvision import transforms
from faster_vit import *
import os
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm 
# from custom_dataset import MoleculeDataset  # Assuming you have MoleculeDataset defined in custom_dataset.py

# Training hyper-parameters
BATCH_SIZE = 32          # samples per mini-batch handed to the DataLoaders
LEARNING_RATE = 1e-4     # learning rate given to the Adam optimizer
EPOCHS = 200             # number of passes over the training split

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Filesystem locations
SAVED_MODEL_NAME = "./model.pth"  # checkpoint file, rewritten after every epoch
CSV_FILE = "C:/Users/anshg/Python_shit/ML/ViTST/dataset/one_hot_encoded_odors.csv"
IMAGE_DIR = "C:/Users/anshg/Python_shit/ML/ViTST/dataset/Molecule_images"

# Image preprocessing applied to every molecule picture:
#   1. resize to 224x224 pixels,
#   2. convert the PIL image to a tensor,
#   3. normalise each RGB channel with the mean/std below (these are the
#      values commonly used with ImageNet-pretrained torchvision models).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocess_steps)

class MoleculeDataset(Dataset):
    """Map-style dataset pairing each feature row with the image ``<idx>.jpg``.

    Args:
        features: Indexable collection of per-sample rows; ``features[i]``
            must be convertible to a float32 tensor.  Row ``i`` belongs to
            the image file ``<image_dir>/<i>.jpg``.
        image_dir: Directory that holds the ``<idx>.jpg`` image files.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.
    """

    def __init__(self, features, image_dir, transform=None):
        self.features = features
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per feature row)."""
        return len(self.features)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Negative indices count from the end, like a regular sequence.

        Returns:
            A ``(image, feature)`` tuple: the (optionally transformed) RGB
            image and the matching feature row as a float32 tensor.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        size = len(self)
        # Validate before touching the disk: Python's fallback iteration
        # protocol (``for sample in dataset``) relies on IndexError to stop,
        # and an out-of-range index would otherwise surface as a confusing
        # "file not found" error for a non-existent image.
        if not -size <= idx < size:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {size}"
            )
        if idx < 0:
            # Keep the image name and the feature row in sync for negative
            # indices (features[-1] is valid, "-1.jpg" is not).
            idx = idx + size

        image_path = os.path.join(self.image_dir, f"{idx}.jpg")
        # The context manager releases the file handle even if decoding
        # fails; convert() returns an independent, fully loaded image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        feature = torch.tensor(self.features[idx], dtype=torch.float32)
        return image, feature

def train():
    """Train the model returned by ``faster_vit_0_224()`` and checkpoint it.

    The label rows are read from ``CSV_FILE`` and paired with the images in
    ``IMAGE_DIR``; the resulting dataset is split randomly 80/20 into a
    training and a validation part.  For each of the ``EPOCHS`` epochs the
    model is optimised with Adam on a cross-entropy loss, then the mean
    training loss, training accuracy and validation accuracy are printed
    and the weights are written to ``SAVED_MODEL_NAME`` (the same file is
    overwritten every epoch).

    Accuracy counts a sample as correct when the arg-max of the model
    output equals the arg-max of its target row.
    """

    def _count_hits(logits, labels):
        # Number of samples whose predicted class matches the target's
        # highest-valued column.
        return logits.argmax(dim=1).eq(labels.argmax(dim=1)).sum().item()

    # NOTE(review): no class count is passed here; confirm that the model's
    # output width matches the number of label columns in the CSV.
    net = faster_vit_0_224()

    # Every CSV row becomes one target vector (DataFrame -> NumPy array).
    label_rows = pd.read_csv(CSV_FILE).values
    full_set = MoleculeDataset(label_rows, IMAGE_DIR, transform)

    # 80 % of the samples for training, the remainder for validation.
    n_train = int(0.8 * len(full_set))
    n_test = len(full_set) - n_train
    train_split, test_split = torch.utils.data.random_split(
        full_set, [n_train, n_test]
    )

    train_loader = DataLoader(train_split, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_split, batch_size=BATCH_SIZE)

    adam = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    # NOTE(review): the targets are float rows taken straight from the CSV.
    # If a row can contain more than one 1 (multi-label), cross-entropy and
    # arg-max accuracy may not be the intended objective/metric; verify.
    criterion = nn.CrossEntropyLoss()

    # Move the model and the loss module to the selected device.
    net = net.to(DEVICE)
    criterion = criterion.to(DEVICE)

    for epoch_no in range(1, EPOCHS + 1):
        # ---- training pass ----
        net.train()
        loss_total = 0.0
        hits_train = 0

        for images, labels in train_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            adam.zero_grad()
            logits = net(images)
            batch_loss = criterion(logits, labels)
            batch_loss.backward()
            adam.step()

            loss_total += batch_loss.item()
            hits_train += _count_hits(logits, labels)

        # Loss is averaged per batch, accuracy per sample.
        mean_loss = loss_total / len(train_loader)
        train_acc = hits_train / len(train_loader.dataset)
        print(f"Epoch {epoch_no}/{EPOCHS} - Loss: {mean_loss:.4f} - Accuracy: {train_acc:.4f}")

        # ---- validation pass (no gradients needed) ----
        net.eval()
        with torch.no_grad():
            hits_test = sum(
                _count_hits(net(images.to(DEVICE)), labels.to(DEVICE))
                for images, labels in test_loader
            )

        print(f"Validation Accuracy: {hits_test / len(test_loader.dataset):.4f}")

        # Persist the latest weights so an interrupted run keeps its progress.
        torch.save(net.state_dict(), SAVED_MODEL_NAME)

# Start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    train()


  from .autonotebook import tqdm as notebook_tqdm
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Epoch 1/200 - Loss: 14.1185 - Accuracy: 0.0626
Validation Accuracy: 0.0345
Epoch 2/200 - Loss: 13.5010 - Accuracy: 0.0704
Validation Accuracy: 0.0740
Epoch 3/200 - Loss: 13.2144 - Accuracy: 0.0873
Validation Accuracy: 0.0674
Epoch 4/200 - Loss: 12.8603 - Accuracy: 0.0861
Validation Accuracy: 0.0789
Epoch 5/200 - Loss: 12.5723 - Accuracy: 0.1021
Validation Accuracy: 0.0625
Epoch 6/200 - Loss: 12.3182 - Accuracy: 0.1100
Validation Accuracy: 0.0970
Epoch 7/200 - Loss: 12.0178 - Accuracy: 0.1194
Validation Accuracy: 0.0822
Epoch 8/200 - Loss: 11.8429 - Accuracy: 0.1186
Validation Accuracy: 0.0970
Epoch 9/200 - Loss: 11.5968 - Accuracy: 0.1236
Validation Accuracy: 0.1151
Epoch 10/200 - Loss: 11.3421 - Accuracy: 0.1306
Validation Accuracy: 0.1250
Epoch 11/200 - Loss: 11.0254 - Accuracy: 0.1528
Validation Accuracy: 0.1053
Epoch 12/200 - Loss: 10.8100 - Accuracy: 0.1643
Validation Accuracy: 0.1151
Epoch 13/200 - Loss: 10.5124 - Accuracy: 0.1693
Validation Accuracy: 0.1234
Epoch 14/200 - Loss: 

KeyboardInterrupt: 