In [1]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import time
from tqdm import *

In [2]:
# Define the dataset class
class NumberDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, self.data.iloc[idx, 0])
        image = Image.open(img_name).convert('L')  # Convert to grayscale
        label = self.data.iloc[idx, 1]
        
        if self.transform:
            image = self.transform(image)
        
        # Convert label to tensor of individual digits
        label = torch.tensor([int(d) for d in str(label).zfill(6)], dtype=torch.long)
        
        return image, label

In [3]:
# Define the larger CNN model
class LargerNumberCNN(nn.Module):
    def __init__(self):
        super(LargerNumberCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 512, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(512 * 26 * 8, 256)
        self.fc2 = nn.Linear(256, 10 * 6)  # 10 classes for each of the 6 digits

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 512 * 26 * 8)  # Flatten the tensor
        x = self.dropout(torch.relu(self.fc1(x)))
        x = self.fc2(x)
        return x.view(-1, 6, 10)  # Reshape to (batch_size, 6, 10)

In [4]:

# Set up data transformations
transform = transforms.Compose([
    transforms.Resize((104, 32)),
    transforms.ToTensor(),
])

# Create dataset
dataset = NumberDataset(csv_file=os.path.join('generate_data', 'images','data.csv'), img_dir=os.path.join('generate_data', 'images'), transform=transform)

# Split dataset into train and test sets
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [5]:
# Initialize the model, loss function, and optimizer
model = LargerNumberCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training and evaluation loop
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

cuda


LargerNumberCNN(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(64, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=106496, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=60, bias=True)
)

In [6]:

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # 計算組合損失
        loss = 0
        for i in range(6):
            digit_loss = criterion(outputs[:, i, :], labels[:, i])
            loss += digit_loss

        # 添加正則化項
        l2_lambda = 0.001
        l2_reg = torch.tensor(0.).to(device)
        for param in model.parameters():
            l2_reg += torch.norm(param)
        loss += l2_lambda * l2_reg

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 2)
            total += labels.size(0) * labels.size(1)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {running_loss/len(train_loader):.4f}, Test Accuracy: {accuracy:.2f}%")

print("Training completed!")

100%|██████████| 625/625 [02:10<00:00,  4.79it/s]


Epoch 1/100, Train Loss: 13.8768, Test Accuracy: 10.07%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 2/100, Train Loss: 13.8158, Test Accuracy: 9.94%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 3/100, Train Loss: 13.8158, Test Accuracy: 9.99%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 4/100, Train Loss: 13.8158, Test Accuracy: 10.02%


100%|██████████| 625/625 [02:08<00:00,  4.85it/s]


Epoch 5/100, Train Loss: 13.8158, Test Accuracy: 9.98%


100%|██████████| 625/625 [02:07<00:00,  4.89it/s]


Epoch 6/100, Train Loss: 13.8158, Test Accuracy: 9.93%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 7/100, Train Loss: 13.8158, Test Accuracy: 9.90%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 8/100, Train Loss: 13.8158, Test Accuracy: 9.99%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 9/100, Train Loss: 13.8158, Test Accuracy: 9.98%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 10/100, Train Loss: 13.8158, Test Accuracy: 9.98%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 11/100, Train Loss: 13.8158, Test Accuracy: 9.97%


100%|██████████| 625/625 [02:07<00:00,  4.92it/s]


Epoch 12/100, Train Loss: 13.8158, Test Accuracy: 10.04%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 13/100, Train Loss: 13.8158, Test Accuracy: 10.00%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 14/100, Train Loss: 13.8158, Test Accuracy: 10.00%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 15/100, Train Loss: 13.8158, Test Accuracy: 9.90%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 16/100, Train Loss: 13.8158, Test Accuracy: 10.04%


100%|██████████| 625/625 [02:07<00:00,  4.88it/s]


Epoch 17/100, Train Loss: 13.8158, Test Accuracy: 10.07%


100%|██████████| 625/625 [02:06<00:00,  4.94it/s]


Epoch 18/100, Train Loss: 13.8158, Test Accuracy: 9.97%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 19/100, Train Loss: 13.8158, Test Accuracy: 10.06%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 20/100, Train Loss: 13.8158, Test Accuracy: 10.00%


100%|██████████| 625/625 [02:07<00:00,  4.89it/s]


Epoch 21/100, Train Loss: 13.8158, Test Accuracy: 9.96%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 22/100, Train Loss: 13.8158, Test Accuracy: 9.97%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 23/100, Train Loss: 13.8158, Test Accuracy: 9.92%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 24/100, Train Loss: 13.8158, Test Accuracy: 10.05%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 25/100, Train Loss: 13.8158, Test Accuracy: 10.06%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 26/100, Train Loss: 13.8158, Test Accuracy: 9.88%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 27/100, Train Loss: 13.8158, Test Accuracy: 10.06%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 28/100, Train Loss: 13.8158, Test Accuracy: 10.00%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 29/100, Train Loss: 13.8158, Test Accuracy: 9.98%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 30/100, Train Loss: 13.8158, Test Accuracy: 9.97%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 31/100, Train Loss: 13.8158, Test Accuracy: 10.10%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 32/100, Train Loss: 13.8158, Test Accuracy: 9.94%


100%|██████████| 625/625 [02:08<00:00,  4.87it/s]


Epoch 33/100, Train Loss: 13.8158, Test Accuracy: 9.90%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 34/100, Train Loss: 13.8158, Test Accuracy: 9.90%


100%|██████████| 625/625 [02:08<00:00,  4.87it/s]


Epoch 35/100, Train Loss: 13.8158, Test Accuracy: 9.91%


100%|██████████| 625/625 [02:07<00:00,  4.91it/s]


Epoch 36/100, Train Loss: 13.8158, Test Accuracy: 10.04%


100%|██████████| 625/625 [02:11<00:00,  4.75it/s]


Epoch 37/100, Train Loss: 13.8157, Test Accuracy: 10.01%


100%|██████████| 625/625 [02:09<00:00,  4.84it/s]


Epoch 38/100, Train Loss: 13.8158, Test Accuracy: 9.91%


100%|██████████| 625/625 [02:08<00:00,  4.87it/s]


Epoch 39/100, Train Loss: 13.8157, Test Accuracy: 9.94%


100%|██████████| 625/625 [02:07<00:00,  4.92it/s]


Epoch 40/100, Train Loss: 13.8158, Test Accuracy: 9.97%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 41/100, Train Loss: 13.8158, Test Accuracy: 10.02%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 42/100, Train Loss: 13.8158, Test Accuracy: 9.96%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 43/100, Train Loss: 13.8158, Test Accuracy: 10.08%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 44/100, Train Loss: 13.8158, Test Accuracy: 10.02%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 45/100, Train Loss: 13.8158, Test Accuracy: 9.89%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 46/100, Train Loss: 13.8158, Test Accuracy: 10.01%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 47/100, Train Loss: 13.8158, Test Accuracy: 10.04%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 48/100, Train Loss: 13.8158, Test Accuracy: 10.02%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 49/100, Train Loss: 13.8157, Test Accuracy: 9.96%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 50/100, Train Loss: 13.8158, Test Accuracy: 9.93%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 51/100, Train Loss: 13.8158, Test Accuracy: 9.94%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 52/100, Train Loss: 13.8158, Test Accuracy: 9.99%


100%|██████████| 625/625 [02:08<00:00,  4.85it/s]


Epoch 53/100, Train Loss: 13.8158, Test Accuracy: 10.00%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 54/100, Train Loss: 13.8158, Test Accuracy: 9.89%


100%|██████████| 625/625 [02:08<00:00,  4.87it/s]


Epoch 55/100, Train Loss: 13.8158, Test Accuracy: 9.98%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 56/100, Train Loss: 13.8157, Test Accuracy: 10.04%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 57/100, Train Loss: 13.8158, Test Accuracy: 10.03%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 58/100, Train Loss: 13.8158, Test Accuracy: 10.08%


100%|██████████| 625/625 [02:08<00:00,  4.85it/s]


Epoch 59/100, Train Loss: 13.8158, Test Accuracy: 10.05%


100%|██████████| 625/625 [02:07<00:00,  4.92it/s]


Epoch 60/100, Train Loss: 13.8158, Test Accuracy: 10.02%


100%|██████████| 625/625 [02:08<00:00,  4.87it/s]


Epoch 61/100, Train Loss: 13.8158, Test Accuracy: 10.00%


100%|██████████| 625/625 [02:07<00:00,  4.92it/s]


Epoch 62/100, Train Loss: 13.8157, Test Accuracy: 9.98%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 63/100, Train Loss: 13.8158, Test Accuracy: 9.94%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 64/100, Train Loss: 13.8158, Test Accuracy: 9.99%


100%|██████████| 625/625 [02:07<00:00,  4.89it/s]


Epoch 65/100, Train Loss: 13.8158, Test Accuracy: 9.99%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 66/100, Train Loss: 13.8158, Test Accuracy: 10.11%


100%|██████████| 625/625 [02:08<00:00,  4.87it/s]


Epoch 67/100, Train Loss: 13.8158, Test Accuracy: 10.03%


100%|██████████| 625/625 [02:06<00:00,  4.92it/s]


Epoch 68/100, Train Loss: 13.8158, Test Accuracy: 9.96%


100%|██████████| 625/625 [02:08<00:00,  4.87it/s]


Epoch 69/100, Train Loss: 13.8158, Test Accuracy: 10.02%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 70/100, Train Loss: 13.8158, Test Accuracy: 9.99%


100%|██████████| 625/625 [02:08<00:00,  4.87it/s]


Epoch 71/100, Train Loss: 13.8158, Test Accuracy: 10.00%


100%|██████████| 625/625 [02:07<00:00,  4.91it/s]


Epoch 72/100, Train Loss: 13.8158, Test Accuracy: 9.97%


100%|██████████| 625/625 [02:09<00:00,  4.84it/s]


Epoch 73/100, Train Loss: 13.8158, Test Accuracy: 10.01%


100%|██████████| 625/625 [02:07<00:00,  4.91it/s]


Epoch 74/100, Train Loss: 13.8158, Test Accuracy: 9.86%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 75/100, Train Loss: 13.8158, Test Accuracy: 9.93%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 76/100, Train Loss: 13.8158, Test Accuracy: 10.03%


100%|██████████| 625/625 [02:08<00:00,  4.85it/s]


Epoch 77/100, Train Loss: 13.8158, Test Accuracy: 10.00%


100%|██████████| 625/625 [02:07<00:00,  4.90it/s]


Epoch 78/100, Train Loss: 13.8158, Test Accuracy: 9.97%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 79/100, Train Loss: 13.8158, Test Accuracy: 10.00%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 80/100, Train Loss: 13.8158, Test Accuracy: 9.90%


100%|██████████| 625/625 [02:08<00:00,  4.87it/s]


Epoch 81/100, Train Loss: 13.8158, Test Accuracy: 9.92%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 82/100, Train Loss: 13.8158, Test Accuracy: 9.97%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 83/100, Train Loss: 13.8158, Test Accuracy: 9.94%


100%|██████████| 625/625 [02:06<00:00,  4.94it/s]


Epoch 84/100, Train Loss: 13.8158, Test Accuracy: 10.08%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 85/100, Train Loss: 13.8158, Test Accuracy: 10.05%


100%|██████████| 625/625 [02:07<00:00,  4.88it/s]


Epoch 86/100, Train Loss: 13.8158, Test Accuracy: 9.96%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 87/100, Train Loss: 13.8158, Test Accuracy: 9.97%


100%|██████████| 625/625 [02:07<00:00,  4.92it/s]


Epoch 88/100, Train Loss: 13.8158, Test Accuracy: 9.99%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 89/100, Train Loss: 13.8158, Test Accuracy: 10.06%


100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 90/100, Train Loss: 13.8158, Test Accuracy: 9.97%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 91/100, Train Loss: 13.8158, Test Accuracy: 9.96%


100%|██████████| 625/625 [02:06<00:00,  4.92it/s]


Epoch 92/100, Train Loss: 13.8158, Test Accuracy: 10.05%


100%|██████████| 625/625 [02:08<00:00,  4.88it/s]


Epoch 93/100, Train Loss: 13.8158, Test Accuracy: 10.06%


100%|██████████| 625/625 [02:07<00:00,  4.92it/s]


Epoch 94/100, Train Loss: 13.8158, Test Accuracy: 9.99%


100%|██████████| 625/625 [02:09<00:00,  4.84it/s]


Epoch 95/100, Train Loss: 13.8158, Test Accuracy: 10.05%


100%|██████████| 625/625 [02:07<00:00,  4.91it/s]


Epoch 96/100, Train Loss: 13.8158, Test Accuracy: 10.03%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 97/100, Train Loss: 13.8157, Test Accuracy: 9.98%


100%|██████████| 625/625 [02:07<00:00,  4.91it/s]


Epoch 98/100, Train Loss: 13.8158, Test Accuracy: 10.01%


100%|██████████| 625/625 [02:08<00:00,  4.86it/s]


Epoch 99/100, Train Loss: 13.8158, Test Accuracy: 9.99%


100%|██████████| 625/625 [02:07<00:00,  4.91it/s]


Epoch 100/100, Train Loss: 13.8158, Test Accuracy: 9.87%
Training completed!


In [7]:

# Save the trained model
torch.save(model.state_dict(), os.path.join("model", f"larger_model_{time.time()}.pth"))
print(f"Model saved as 'larger_model_{time.time()}.pth'")

Model saved as 'larger_model_1726217934.0204105.pth'
