In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

In [2]:


# Define transformations
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert to 3-channel
    transforms.Resize((224, 224)),  # Resize to match pre-trained models
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset
folder_path = "docs"
dataset = datasets.ImageFolder(root=folder_path, transform=transform)

# Split into train (70%), validation (15%), and test (15%)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Get class names
class_names = dataset.classes
print("Classes:", class_names)

# Count the number of images per class
label_counts = {class_name: 0 for class_name in class_names}
for _, label in dataset.samples:
    label_counts[class_names[label]] += 1

# Print label distribution
print("\nTotal number of images per class:")
for class_name, count in label_counts.items():
    print(f"{class_name}: {count} images")


Classes: ['email', 'invoice', 'questionnaire']

Total number of images per class:
email: 291 images
invoice: 302 images
questionnaire: 311 images


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from tqdm import tqdm

# Define the model using ResNet18
class DocumentClassifier(nn.Module):
    def __init__(self, num_classes):
        super(DocumentClassifier, self).__init__()
        self.resnet = models.resnet18(pretrained=True)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
    
    def forward(self, x):
        return self.resnet(x)


In [5]:

# Instantiate model
num_classes = len(class_names)
model = DocumentClassifier(num_classes)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=1e-4)

# Learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    with tqdm(train_loader, desc=f'Epoch [{epoch+1}/{num_epochs}]', unit='batch') as pbar:
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix(loss=running_loss/len(train_loader))
    
    scheduler.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], loss: {running_loss/len(train_loader):.4f}")

print("Training complete")


Epoch [1/10]: 100%|██████████| 40/40 [01:20<00:00,  2.01s/batch, loss=0.499]


Epoch [1/10], loss: 0.4988


Epoch [2/10]: 100%|██████████| 40/40 [02:05<00:00,  3.13s/batch, loss=0.145] 


Epoch [2/10], loss: 0.1455


Epoch [3/10]: 100%|██████████| 40/40 [01:31<00:00,  2.28s/batch, loss=0.0958]


Epoch [3/10], loss: 0.0958


Epoch [4/10]: 100%|██████████| 40/40 [01:55<00:00,  2.89s/batch, loss=0.0577]


Epoch [4/10], loss: 0.0577


Epoch [5/10]: 100%|██████████| 40/40 [01:52<00:00,  2.81s/batch, loss=0.0284] 


Epoch [5/10], loss: 0.0284


Epoch [6/10]: 100%|██████████| 40/40 [01:55<00:00,  2.88s/batch, loss=0.0205] 


Epoch [6/10], loss: 0.0205


Epoch [7/10]: 100%|██████████| 40/40 [02:57<00:00,  4.44s/batch, loss=0.0271] 


Epoch [7/10], loss: 0.0271


Epoch [8/10]: 100%|██████████| 40/40 [03:42<00:00,  5.57s/batch, loss=0.0401] 


Epoch [8/10], loss: 0.0401


Epoch [9/10]: 100%|██████████| 40/40 [02:03<00:00,  3.09s/batch, loss=0.0138] 


Epoch [9/10], loss: 0.0138


Epoch [10/10]: 100%|██████████| 40/40 [01:21<00:00,  2.04s/batch, loss=0.0123] 

Epoch [10/10], loss: 0.0123
Training complete





In [8]:
torch.save(model, './checkpoints/model.pth')

In [9]:
loaded_model = torch.load('./checkpoints/model.pth', weights_only=False)


In [10]:
from torchmetrics.classification import Accuracy

loaded_model.eval()
accuracy = Accuracy(task = "multiclass", num_classes = num_classes).to(device)
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = loaded_model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total  += labels.size(0)

print(f"Test Accuracy : {100 * correct / total:.2f}%")

Test Accuracy : 87.59%


In [None]:
# import cv2
# import pytesseract

# pytesseract.pytesseract.tesseract_cmd = r"C:/Users/rohith.nardela/AppData/Local/Programs/Tesseract-OCR/tesseract.exe"  # Update path if needed

# def extract_invoice_data(image_path):
#     img = cv2.imread(image_path)
#     gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
#     text = pytesseract.image_to_string(gray)

#     # Extract fields using regex
#     import re
#     invoice_number = re.findall(r'Invoice\s*#\s*(\d+)', text)
#     date = re.findall(r'Date:\s*([\d-]+)', text)
#     amount = re.findall(r'Total\s*Amount:\s*\$?([\d,]+.\d{2})', text)

#     return {
#         "invoice_number": invoice_number[0] if invoice_number else None,
#         "date": date[0] if date else None,
#         "amount": amount[0] if amount else None,
#         "raw_text": text
#     }


In [7]:
# Run the app.py using uvicorn