In [58]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms.transforms import Resize
from torchvision import transforms
from tqdm import tqdm

In [59]:
class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()
  
    self.conv1= nn.Conv2d(3,32,(3,3),(1,1),(1,1))
    self.conv2= nn.Conv2d(32,32,(3,3),(1,1),(1,1))
    self.conv3= nn.Conv2d(32,64,(3,3),(1,1),(1,1))
    self.conv4= nn.Conv2d(64,64,(3,3),(1,1),(1,1))

    self.fc1=nn.Linear(64*4*4,512)
    self.fc2=nn.Linear(512,10)

  def forward(self,x):
    x=F.relu(self.conv1(x))
    x=F.max_pool2d(x,kernel_size=(2,2))

    x=F.relu(self.conv2(x))
    x=F.max_pool2d(x,kernel_size=(2,2))

    x=F.relu(self.conv3(x))
    x=F.max_pool2d(x,kernel_size=(2,2))

    x=F.relu(self.conv4(x))
    

    x=torch.flatten(x,start_dim=1)

    x = self.fc1(x)
    x = torch.dropout(x, 0.4, train=True)

    x = self.fc2(x)
    output = torch.softmax(x, dim=1)

    return output


In [60]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model().to(device)

HYPER PARAMETER

In [61]:
batch_size = 64
epochs = 20
lr= 0.001

In [62]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [63]:
transform = transforms.Compose([
                              transforms.RandomRotation(10),
                              transforms.Resize((32,32)),
                              transforms.ToTensor(),
                              transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
dataset = torchvision.datasets.ImageFolder(root="/content/drive/MyDrive/MNIST_persian", transform = transform)
train_data = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

COMPILE

In [64]:
optimizer=torch.optim.Adam(model.parameters(),lr=lr)
loss_function=nn.CrossEntropyLoss()

def calc_acc(preds, labels):
  _, pred_max = torch.max(preds, 1)
  acc = torch.sum(pred_max == labels.data, dtype=torch.float64) / len(preds)
  return acc

In [65]:
model.train()

for epoch in range(epochs):
  train_loss = 0.0
  train_acc = 0.0
  
  for images, labels in tqdm(train_data):
    images, labels = images.to(device), labels.to(device)

    optimizer.zero_grad()
    preds = model(images)

    loss = loss_function(preds, labels)
    loss.backward()

    optimizer.step()

    train_loss += loss
    train_acc += calc_acc(preds, labels)

  total_loss = train_loss / len(train_data)
  total_acc = train_acc / len(train_data)
  print(f"Epoch: {epoch+1}, Loss: {total_loss}, Accuracy: {total_acc}")


100%|██████████| 19/19 [02:50<00:00,  8.95s/it]


Epoch: 1, Loss: 2.30310320854187, Accuracy: 0.09978070175438596


100%|██████████| 19/19 [00:03<00:00,  6.22it/s]


Epoch: 2, Loss: 2.3005213737487793, Accuracy: 0.10992324561403508


100%|██████████| 19/19 [00:03<00:00,  6.26it/s]


Epoch: 3, Loss: 2.1975960731506348, Accuracy: 0.2658991228070175


100%|██████████| 19/19 [00:03<00:00,  6.19it/s]


Epoch: 4, Loss: 1.9825963973999023, Accuracy: 0.48821271929824556


100%|██████████| 19/19 [00:03<00:00,  6.11it/s]


Epoch: 5, Loss: 1.887609601020813, Accuracy: 0.5910087719298245


100%|██████████| 19/19 [00:03<00:00,  6.09it/s]


Epoch: 6, Loss: 1.8563064336776733, Accuracy: 0.6030701754385965


100%|██████████| 19/19 [00:03<00:00,  6.22it/s]


Epoch: 7, Loss: 1.8318507671356201, Accuracy: 0.6389802631578947


100%|██████████| 19/19 [00:03<00:00,  6.20it/s]


Epoch: 8, Loss: 1.7606861591339111, Accuracy: 0.7044956140350876


100%|██████████| 19/19 [00:03<00:00,  6.16it/s]


Epoch: 9, Loss: 1.7499403953552246, Accuracy: 0.7124451754385965


100%|██████████| 19/19 [00:03<00:00,  6.15it/s]


Epoch: 10, Loss: 1.7116265296936035, Accuracy: 0.750548245614035


100%|██████████| 19/19 [00:03<00:00,  6.15it/s]


Epoch: 11, Loss: 1.68999183177948, Accuracy: 0.7763157894736842


100%|██████████| 19/19 [00:03<00:00,  6.22it/s]


Epoch: 12, Loss: 1.6816753149032593, Accuracy: 0.7839912280701754


100%|██████████| 19/19 [00:03<00:00,  6.26it/s]


Epoch: 13, Loss: 1.6706513166427612, Accuracy: 0.7971491228070176


100%|██████████| 19/19 [00:03<00:00,  6.20it/s]


Epoch: 14, Loss: 1.6638129949569702, Accuracy: 0.8061951754385964


100%|██████████| 19/19 [00:03<00:00,  6.25it/s]


Epoch: 15, Loss: 1.6379406452178955, Accuracy: 0.825657894736842


100%|██████████| 19/19 [00:03<00:00,  6.21it/s]


Epoch: 16, Loss: 1.6569170951843262, Accuracy: 0.8086622807017544


100%|██████████| 19/19 [00:03<00:00,  6.30it/s]


Epoch: 17, Loss: 1.6444634199142456, Accuracy: 0.8215460526315789


100%|██████████| 19/19 [00:03<00:00,  6.27it/s]


Epoch: 18, Loss: 1.6288739442825317, Accuracy: 0.835252192982456


100%|██████████| 19/19 [00:03<00:00,  6.16it/s]


Epoch: 19, Loss: 1.6335337162017822, Accuracy: 0.8333333333333334


100%|██████████| 19/19 [00:03<00:00,  6.20it/s]

Epoch: 20, Loss: 1.6164368391036987, Accuracy: 0.8453947368421052





In [66]:
torch.save(model.state_dict(), "mnist_persian.pth")