<a href="https://colab.research.google.com/github/zahra-zarrabi/Mnist_Persian/blob/main/mnist_persian.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from tqdm import tqdm
import torch.utils.data as data

In [None]:
def calc_acc(preds,labels):
  """Fraction of samples whose arg-max prediction equals the label.

  Args:
    preds: 2-D tensor of per-class scores, one row per sample (logits or
      probabilities both work, only the arg-max is used).
    labels: 1-D tensor of integer class indices, one per row of `preds`.

  Returns:
    0-dim float64 tensor in [0, 1]; NaN when the batch is empty.
  """
  # argmax(dim=1) picks the first maximal index on ties, same as torch.max.
  hits = preds.argmax(dim=1) == labels
  return hits.to(torch.float64).mean()

In [None]:
class Model(nn.Module):
  """Two-stage CNN classifier.

  Layout: conv3x3(32) -> ReLU -> maxpool2 -> conv3x3(64) -> ReLU -> maxpool2
  -> flatten -> fc(128) -> ReLU -> dropout -> fc(num_classes).

  Args:
    in_channels: number of input image channels (default 3).
    num_classes: number of output classes (default 10).
    image_size: side length of the square input images; fc1's input size is
      derived from it and must match the Resize used in the data pipeline
      (default 70, giving 64 * 17 * 17 = 18496 features).
    dropout: dropout probability after fc1, applied only in training mode.

  forward(x) returns raw logits of shape (batch, num_classes). No softmax is
  applied because nn.CrossEntropyLoss applies log-softmax itself; use
  torch.softmax(logits, dim=1) when probabilities are needed.
  """
  def __init__(self, in_channels=3, num_classes=10, image_size=70, dropout=0.3):
    super().__init__()
    self.cnn1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding=1)
    self.cnn2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)

    # padding=1 keeps H/W through each conv; each 2x2 max-pool floors them by half.
    pooled = image_size // 2 // 2
    self.fc1 = nn.Linear(64 * pooled * pooled, 128)
    self.fc2 = nn.Linear(128, num_classes)
    # Plain float attribute (not a parameter), so state_dict keys are unchanged.
    self.dropout_p = dropout

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.cnn1(x)), kernel_size=(2, 2))
    x = F.max_pool2d(F.relu(self.cnn2(x)), kernel_size=(2, 2))
    x = torch.flatten(x, start_dim=1)
    x = F.relu(self.fc1(x))
    # training=self.training disables dropout after model.eval(), so
    # validation/inference is deterministic.
    x = F.dropout(x, p=self.dropout_p, training=self.training)
    return self.fc2(x)
  

In [None]:
# Seed torch's global RNG before building the model so weight init, DataLoader
# shuffling and dropout masks repeat across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model().to(device)

In [None]:
# Training hyperparameters: samples per batch, passes over the data, Adam step size.
batch_size, epochs, lr = 64, 30, 1e-3

In [None]:
# Images are read with ImageFolder (one sub-folder per class) from Google Drive.
DATA_ROOT = '/content/drive/MyDrive/MNIST_persian/MNIST_persian'
TRAIN_FRACTION = 0.8
SPLIT_SEED = 42

transform = transforms.Compose([
    # The model's fc1 input size depends on this resolution.
    transforms.Resize((70, 70)),
    transforms.ToTensor(),
])

dataset = torchvision.datasets.ImageFolder(root=DATA_ROOT, transform=transform)
train_size = int(len(dataset) * TRAIN_FRACTION)
val_size = len(dataset) - train_size
# A dedicated seeded generator yields the same train/val partition on every
# run, so validation metrics are reproducible and not leaked across re-runs.
train_data, val_data = data.random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED))

train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [None]:
# "Compile" step: cross-entropy objective, optimized with Adam over every model parameter.
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
def train(model, train_data_loader,epoch):
  """Run one training epoch and print the epoch's mean loss and accuracy.

  Relies on the notebook-level `optimizer`, `loss_function`, `device` and
  `calc_acc`.

  Args:
    model: network to train; switched to training mode here.
    train_data_loader: iterable of (images, labels) batches.
    epoch: current epoch index (unused; kept for call compatibility).

  Returns:
    (mean_loss, accuracy) as Python floats. Both are weighted by batch size,
    so a smaller final batch does not skew the averages. Both are 0.0 for an
    empty loader.
  """
  model.train()
  loss_sum = 0.0
  correct = 0.0
  n_seen = 0
  for images, labels in tqdm(train_data_loader):
    images = images.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()

    preds_train = model(images)
    loss_train = loss_function(preds_train, labels)
    loss_train.backward()
    optimizer.step()

    batch_n = labels.size(0)
    # .item() converts to a Python float; accumulating the loss tensor itself
    # would keep chaining autograd history across the whole epoch.
    loss_sum += loss_train.item() * batch_n
    correct += calc_acc(preds_train, labels).item() * batch_n
    n_seen += batch_n

  total_loss = loss_sum / max(n_seen, 1)
  total_acc = correct / max(n_seen, 1)
  print(f"loss_train:{total_loss},accuracy_train:{total_acc}")
  return total_loss, total_acc


In [None]:
def test(model, val_data_loader,epoch):
  """Evaluate `model` once over the validation loader and print mean loss/accuracy.

  Relies on the notebook-level `loss_function`, `device` and `calc_acc`.

  Args:
    model: network to evaluate; switched to eval mode here.
    val_data_loader: iterable of (images, labels) batches.
    epoch: current epoch index (unused; kept for call compatibility).

  Returns:
    (mean_loss, accuracy) as Python floats, weighted by batch size. Both are
    0.0 for an empty loader.
  """
  model.eval()
  loss_sum = 0.0
  correct = 0.0
  n_seen = 0
  # Evaluation needs no gradients; no_grad skips building autograd graphs and
  # keeping activations alive for every batch.
  with torch.no_grad():
    for images, labels in tqdm(val_data_loader):
      images = images.to(device)
      labels = labels.to(device)

      preds_test = model(images)
      loss_test = loss_function(preds_test, labels)

      batch_n = labels.size(0)
      loss_sum += loss_test.item() * batch_n
      correct += calc_acc(preds_test, labels).item() * batch_n
      n_seen += batch_n

  total_loss = loss_sum / max(n_seen, 1)
  total_acc = correct / max(n_seen, 1)
  print(f"loss_eval:{total_loss},accuracy_eval:{total_acc}")
  return total_loss, total_acc


In [None]:
# Each epoch is one full training pass followed by one validation pass.
for epoch in range(epochs):
  print(f'Epoch:{epoch}')
  train(model=model, train_data_loader=train_data_loader, epoch=epoch)
  test(model=model, val_data_loader=val_data_loader, epoch=epoch)

Epoch:0


100%|██████████| 15/15 [10:20<00:00, 41.34s/it]


loss_train:2.3046417236328125,accuracy_train:0.10104166666666667


100%|██████████| 4/4 [02:32<00:00, 38.01s/it]


loss_eval:2.301353693008423,accuracy_eval:0.10416666666666667
Epoch:1


100%|██████████| 15/15 [00:07<00:00,  1.89it/s]


loss_train:2.3000383377075195,accuracy_train:0.13333333333333333


100%|██████████| 4/4 [00:01<00:00,  3.70it/s]


loss_eval:2.294072389602661,accuracy_eval:0.22916666666666669
Epoch:2


100%|██████████| 15/15 [00:07<00:00,  1.89it/s]


loss_train:2.275603771209717,accuracy_train:0.21770833333333334


100%|██████████| 4/4 [00:01<00:00,  3.79it/s]


loss_eval:2.2430338859558105,accuracy_eval:0.23177083333333331
Epoch:3


100%|██████████| 15/15 [00:07<00:00,  1.89it/s]


loss_train:2.1961429119110107,accuracy_train:0.271875


100%|██████████| 4/4 [00:01<00:00,  3.61it/s]


loss_eval:2.1481096744537354,accuracy_eval:0.3606770833333333
Epoch:4


100%|██████████| 15/15 [00:07<00:00,  1.90it/s]


loss_train:2.077285051345825,accuracy_train:0.434375


100%|██████████| 4/4 [00:01<00:00,  3.74it/s]


loss_eval:2.039973735809326,accuracy_eval:0.48046875
Epoch:5


100%|██████████| 15/15 [00:07<00:00,  1.91it/s]


loss_train:1.9597737789154053,accuracy_train:0.5479166666666667


100%|██████████| 4/4 [00:01<00:00,  3.81it/s]


loss_eval:1.9404597282409668,accuracy_eval:0.5729166666666666
Epoch:6


100%|██████████| 15/15 [00:07<00:00,  1.90it/s]


loss_train:1.881927728652954,accuracy_train:0.6145833333333334


100%|██████████| 4/4 [00:01<00:00,  3.81it/s]


loss_eval:1.84578537940979,accuracy_eval:0.6692708333333334
Epoch:7


100%|██████████| 15/15 [00:08<00:00,  1.67it/s]


loss_train:1.8189990520477295,accuracy_train:0.6770833333333334


100%|██████████| 4/4 [00:01<00:00,  2.83it/s]


loss_eval:1.838819146156311,accuracy_eval:0.6627604166666666
Epoch:8


100%|██████████| 15/15 [00:09<00:00,  1.56it/s]


loss_train:1.7969741821289062,accuracy_train:0.6822916666666666


100%|██████████| 4/4 [00:01<00:00,  3.18it/s]


loss_eval:1.7973158359527588,accuracy_eval:0.69921875
Epoch:9


100%|██████████| 15/15 [00:11<00:00,  1.32it/s]


loss_train:1.7729779481887817,accuracy_train:0.7166666666666667


100%|██████████| 4/4 [00:01<00:00,  2.87it/s]


loss_eval:1.8107765913009644,accuracy_eval:0.6627604166666666
Epoch:10


100%|██████████| 15/15 [00:09<00:00,  1.57it/s]


loss_train:1.7567237615585327,accuracy_train:0.7270833333333333


100%|██████████| 4/4 [00:01<00:00,  3.70it/s]


loss_eval:1.8110809326171875,accuracy_eval:0.671875
Epoch:11


100%|██████████| 15/15 [00:09<00:00,  1.52it/s]


loss_train:1.736576795578003,accuracy_train:0.740625


100%|██████████| 4/4 [00:01<00:00,  2.59it/s]


loss_eval:1.7710667848587036,accuracy_eval:0.6966145833333334
Epoch:12


100%|██████████| 15/15 [00:10<00:00,  1.43it/s]


loss_train:1.7341886758804321,accuracy_train:0.7479166666666667


100%|██████████| 4/4 [00:01<00:00,  2.39it/s]


loss_eval:1.7882518768310547,accuracy_eval:0.6979166666666666
Epoch:13


100%|██████████| 15/15 [00:09<00:00,  1.60it/s]


loss_train:1.7274798154830933,accuracy_train:0.74375


100%|██████████| 4/4 [00:02<00:00,  1.98it/s]


loss_eval:1.7768439054489136,accuracy_eval:0.7018229166666666
Epoch:14


100%|██████████| 15/15 [00:11<00:00,  1.33it/s]


loss_train:1.73273503780365,accuracy_train:0.7427083333333333


100%|██████████| 4/4 [00:01<00:00,  2.74it/s]


loss_eval:1.769434928894043,accuracy_eval:0.7057291666666666
Epoch:15


100%|██████████| 15/15 [00:11<00:00,  1.25it/s]


loss_train:1.710880994796753,accuracy_train:0.7645833333333333


100%|██████████| 4/4 [00:01<00:00,  3.76it/s]


loss_eval:1.7520548105239868,accuracy_eval:0.7265625
Epoch:16


100%|██████████| 15/15 [00:09<00:00,  1.58it/s]


loss_train:1.7043172121047974,accuracy_train:0.76875


100%|██████████| 4/4 [00:01<00:00,  3.25it/s]


loss_eval:1.7640527486801147,accuracy_eval:0.7200520833333334
Epoch:17


100%|██████████| 15/15 [00:10<00:00,  1.44it/s]


loss_train:1.7033119201660156,accuracy_train:0.7729166666666667


100%|██████████| 4/4 [00:01<00:00,  3.75it/s]


loss_eval:1.7646996974945068,accuracy_eval:0.703125
Epoch:18


100%|██████████| 15/15 [00:08<00:00,  1.72it/s]


loss_train:1.7024482488632202,accuracy_train:0.7729166666666667


100%|██████████| 4/4 [00:01<00:00,  3.14it/s]


loss_eval:1.761509895324707,accuracy_eval:0.7083333333333334
Epoch:19


100%|██████████| 15/15 [00:09<00:00,  1.64it/s]


loss_train:1.6988489627838135,accuracy_train:0.7739583333333333


100%|██████████| 4/4 [00:01<00:00,  2.52it/s]


loss_eval:1.75214421749115,accuracy_eval:0.7044270833333334
Epoch:20


100%|██████████| 15/15 [00:08<00:00,  1.85it/s]


loss_train:1.6944843530654907,accuracy_train:0.7791666666666667


100%|██████████| 4/4 [00:01<00:00,  3.56it/s]


loss_eval:1.7157655954360962,accuracy_eval:0.7747395833333334
Epoch:21


100%|██████████| 15/15 [00:08<00:00,  1.76it/s]


loss_train:1.6821684837341309,accuracy_train:0.790625


100%|██████████| 4/4 [00:01<00:00,  3.77it/s]


loss_eval:1.7310465574264526,accuracy_eval:0.7161458333333334
Epoch:22


100%|██████████| 15/15 [00:10<00:00,  1.39it/s]


loss_train:1.679027795791626,accuracy_train:0.7927083333333333


100%|██████████| 4/4 [00:01<00:00,  3.28it/s]


loss_eval:1.7175936698913574,accuracy_eval:0.7682291666666666
Epoch:23


100%|██████████| 15/15 [00:11<00:00,  1.30it/s]


loss_train:1.656499981880188,accuracy_train:0.8208333333333333


100%|██████████| 4/4 [00:01<00:00,  2.74it/s]


loss_eval:1.7177624702453613,accuracy_eval:0.7526041666666666
Epoch:24


100%|██████████| 15/15 [00:11<00:00,  1.30it/s]


loss_train:1.6534007787704468,accuracy_train:0.815625


100%|██████████| 4/4 [00:01<00:00,  2.08it/s]


loss_eval:1.7113016843795776,accuracy_eval:0.7591145833333334
Epoch:25


100%|██████████| 15/15 [00:09<00:00,  1.51it/s]


loss_train:1.6543910503387451,accuracy_train:0.8166666666666667


100%|██████████| 4/4 [00:01<00:00,  3.73it/s]


loss_eval:1.729932188987732,accuracy_eval:0.7317708333333334
Epoch:26


100%|██████████| 15/15 [00:08<00:00,  1.77it/s]


loss_train:1.6422721147537231,accuracy_train:0.8322916666666667


100%|██████████| 4/4 [00:01<00:00,  3.64it/s]


loss_eval:1.717791199684143,accuracy_eval:0.7604166666666666
Epoch:27


100%|██████████| 15/15 [00:09<00:00,  1.55it/s]


loss_train:1.6353294849395752,accuracy_train:0.8375


100%|██████████| 4/4 [00:01<00:00,  3.64it/s]


loss_eval:1.6797819137573242,accuracy_eval:0.7786458333333334
Epoch:28


100%|██████████| 15/15 [00:10<00:00,  1.39it/s]


loss_train:1.6094032526016235,accuracy_train:0.8697916666666666


100%|██████████| 4/4 [00:01<00:00,  3.37it/s]


loss_eval:1.6882514953613281,accuracy_eval:0.7981770833333334
Epoch:29


100%|██████████| 15/15 [00:10<00:00,  1.40it/s]


loss_train:1.6160534620285034,accuracy_train:0.8489583333333334


100%|██████████| 4/4 [00:01<00:00,  3.28it/s]

loss_eval:1.6980441808700562,accuracy_eval:0.7578125





In [None]:
# Persist only the learned weights; reload with Model() + load_state_dict().
torch.save(obj=model.state_dict(), f="/content/drive/MyDrive/MNIST_persian/model_mnist.pth")