# Vietnamese Handwriting Recognition with a CRNN Model

## Data Loader

In [17]:
from torch.utils.data import Dataset, DataLoader
from tensorflow.keras.preprocessing.sequence import pad_sequences
from torchvision.transforms import ToTensor, Resize, Compose
import torchvision.transforms.functional as F
import os
import numpy as np
from PIL import Image


# Kaggle input locations. The image folders are absolute paths, so
# os.path.join(root, <folder>) inside OCRDataset resolves to the folder itself.
train_folder_path = '/kaggle/input/handwritten-ocr/training_data/new_train' 
test_folder_path = '/kaggle/input/handwritten-ocr/public_test_data/new_public_test'
# Ground-truth transcriptions for the training images; read line by line by
# OCRDataset. Note it lives in a different Kaggle input than the images.
label_file_path = '/kaggle/input/handwriting/train_gt.txt'
root = '/kaggle/input/handwritten-ocr'


def encode_to_num(text, char_list):
    """Encode ``text`` as a list of 1-based positions in ``char_list``.

    Each character ``ch`` becomes ``char_list.index(ch) + 1`` so that 0 stays
    free for the CTC blank. A character missing from ``char_list`` raises
    ``ValueError`` (from ``list.index``).
    """
    return [char_list.index(ch) + 1 for ch in text]

class OCRDataset(Dataset):
    """Handwritten-word images, optionally paired with CTC-ready integer labels.

    In training mode every image is matched to its transcription in
    ``label_file_path`` by file name; in test mode only images are returned.

    Args:
        root: Directory joined in front of the module-level folder paths.
            Because those paths are absolute, ``os.path.join`` discards
            ``root``; it is kept for interface compatibility.
        train: If True, load ``train_folder_path`` together with labels;
            otherwise load ``test_folder_path`` without labels.
        transform: Optional callable applied to each grayscale PIL image.
        max_seq_len: Length every encoded label is padded/truncated to.

    Attributes:
        images_path: Image file paths, sorted by file name.
        labels: (train only) Encoded labels, aligned index-for-index with
            ``images_path``.
        char_list: (train only) Sorted characters found in the label file;
            character ``char_list[i]`` is encoded as ``i + 1`` (0 = CTC blank).

    Raises:
        ValueError: In training mode, if no image in the folder has an entry
            in the label file.
    """

    def __init__(self, root, train=True, transform=None, max_seq_len=32):
        self.train = train
        self.transform = transform
        self.max_seq_len = max_seq_len
        folder = os.path.join(root, train_folder_path if train else test_folder_path)
        # Sorted so the dataset order is reproducible across runs/filesystems.
        image_files = sorted(os.path.join(folder, name) for name in os.listdir(folder))

        if not train:
            self.images_path = image_files
            return

        # os.listdir order is arbitrary, so labels must be matched to images by
        # file name; pairing them by position silently mislabels the data.
        text_by_name = {}
        with open(label_file_path, encoding='utf-8') as f:
            for line in f:
                # "<file name> <transcription>"; keep the whole transcription
                # even if it contains spaces.
                fields = line.split(maxsplit=1)
                if len(fields) < 2:
                    continue  # blank or malformed line
                text_by_name[os.path.basename(fields[0])] = fields[1].strip()

        matched = [p for p in image_files if os.path.basename(p) in text_by_name]
        if not matched:
            raise ValueError(
                "No image in {} has a label in {}".format(folder, label_file_path))
        self.images_path = matched

        self.char_list = sorted(set().union(*text_by_name.values()))
        self.labels = [
            encode_to_num(text_by_name[os.path.basename(p)], self.char_list)
            for p in matched
        ]

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, idx):
        """Return ``image`` (test) or ``(image, padded_label, label_length)`` (train)."""
        # Context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(self.images_path[idx]) as img:
            image = img.convert('L')
        if self.transform:
            image = self.transform(image)
        if not self.train:
            return image

        label = self.labels[idx]
        # Right-pad with the CTC blank (0). An over-long label keeps its first
        # characters, and the reported length is clipped so it never exceeds
        # the padded tensor handed to the CTC loss.
        length = min(len(label), self.max_seq_len)
        padded_label = np.zeros(self.max_seq_len, dtype=np.int32)
        padded_label[:length] = label[:length]
        return image, padded_label, length
        


# Resize every word image to 64x128 (H x W) and convert to a [0, 1] tensor.
transform = Compose([
    Resize((64, 128)),
    ToTensor(),
])

train_dataloader = DataLoader(
    dataset=OCRDataset(root=train_folder_path, train=True, transform=transform),
    batch_size=8,
    num_workers=4,
    drop_last=True,
    shuffle=True,
)
# Inference loader: keep every test image (no drop_last) in a stable order
# (no shuffle) so predictions can be matched back to their files.
test_dataloader = DataLoader(
    dataset=OCRDataset(root=test_folder_path, train=False, transform=transform),
    batch_size=8,
    num_workers=4,
    drop_last=False,
    shuffle=False,
)


In [18]:

# if __name__ == '__main__':
# Inspect one training sample: image tensor shape and its padded label.
ocr = OCRDataset(root=root, train=True, transform=transform)
image, label, _ = ocr[100]
print(image.shape)
print(label)


torch.Size([1, 64, 128])
[ 51 131   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0]


In [19]:

    # Print only the first training batch: images, padded labels, label lengths.
    for images, labels, label_lengths in train_dataloader:
        print(images)
        print(labels)
        print(label_lengths)
        break

tensor([[[[0.6667, 0.6667, 0.6627,  ..., 0.6627, 0.6627, 0.6627],
          [0.6627, 0.6627, 0.6588,  ..., 0.6667, 0.6667, 0.6667],
          [0.6627, 0.6627, 0.6588,  ..., 0.6627, 0.6667, 0.6627],
          ...,
          [0.6784, 0.6627, 0.6510,  ..., 0.6431, 0.6471, 0.6510],
          [0.6784, 0.6667, 0.6549,  ..., 0.6431, 0.6471, 0.6510],
          [0.6784, 0.6706, 0.6588,  ..., 0.6392, 0.6431, 0.6471]]],


        [[[0.9294, 0.9294, 0.9255,  ..., 0.9059, 0.8941, 0.8902],
          [0.9255, 0.9294, 0.9294,  ..., 0.9137, 0.8980, 0.8863],
          [0.9216, 0.9176, 0.9294,  ..., 0.9098, 0.8941, 0.8863],
          ...,
          [0.8627, 0.8863, 0.9020,  ..., 0.9216, 0.9176, 0.9137],
          [0.8549, 0.8745, 0.9098,  ..., 0.9137, 0.9216, 0.9137],
          [0.8706, 0.8824, 0.9059,  ..., 0.9137, 0.9216, 0.9294]]],


        [[[0.6157, 0.6118, 0.6078,  ..., 0.6157, 0.6118, 0.6118],
          [0.6157, 0.6196, 0.6275,  ..., 0.6157, 0.6118, 0.6118],
          [0.6235, 0.6235, 0.6196,  ..

In [20]:
# Number of samples in each split. len(loader) * batch_size is not a sample
# count: it drops the partial batch when drop_last=True.
print(len(train_dataloader.dataset))
print(len(test_dataloader.dataset))

103000
33000


## Build Model

In [21]:
import torch
import torch.nn as nn
# from torchsummary import summary

class CRNN(nn.Module):
    """Convolutional-recurrent network producing per-step CTC log-probabilities.

    Seven conv blocks extract features, ``fc1`` projects each time step to 512
    features, two bidirectional LSTMs process the sequence, and ``fc2`` +
    ``LogSoftmax`` give per-class log-probabilities.

    Args:
        time_steps: Number of steps the conv feature map is reshaped into.
        num_classes: Output classes, including the CTC blank.
        drop_out_rate: Dropout used in conv4 and conv6 (conv7 uses a fixed 0.25).

    Shape derivation for a (B, 1, 64, 128) input:
        conv1 and conv2 each halve H and W -> (B, 128, 16, 32); the (1, 2)
        pools in conv4/conv6 halve only W (MaxPool2d kernels are (kH, kW))
        -> (B, 512, 16, 8) after conv7. ``fc1`` expects 4096 = 512*16*8 / 16
        features per step, so as written the model only works with 64x128
        inputs and ``time_steps=16``.

    NOTE(review): ``forward`` reshapes (B, C, H, W) directly into
    (B, time_steps, -1) without a permute, so each "time step" is a contiguous
    slab of 32 channels over the whole image rather than one image column.
    A conventional CRNN permutes so time runs along the width; confirm whether
    this layout is intentional before relying on the sequence order.
    """
    def __init__(self, time_steps, num_classes, drop_out_rate = 0.35):
        super().__init__()
        self.time_steps = time_steps
        #CNN
        # -> (B, 64, 32, 64) for a (B, 1, 64, 128) input
        self.conv1 = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding='same', bias=True),
        nn.BatchNorm2d(num_features=64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=(2,2))
        )
        # -> (B, 128, 16, 32)
        self.conv2 = nn.Sequential(
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding='same', bias=True),
        nn.BatchNorm2d(num_features=128),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=(2,2))
        )
        # -> (B, 256, 16, 32); no pooling
        self.conv3 = nn.Sequential(
        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding='same', bias=True),
        nn.BatchNorm2d(num_features=256),
        nn.ReLU(),
        )
        # -> (B, 256, 16, 16); the (1, 2) pool halves the width only
        self.conv4 = nn.Sequential(
        nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding='same', bias=True),
        nn.Dropout(drop_out_rate),
        nn.BatchNorm2d(num_features=256),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=(1,2), stride=(1,2))
        )
        # -> (B, 512, 16, 16)
        self.conv5 = nn.Sequential(
        nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding='same', bias=True),
        nn.BatchNorm2d(num_features=512),
        nn.ReLU(),
        )
        # -> (B, 512, 16, 8); again halves the width only
        self.conv6 = nn.Sequential(
        nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding='same', bias=True),
        nn.Dropout(drop_out_rate),
        nn.BatchNorm2d(num_features=512),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=(1,2), stride=(1,2))     
        )   
        # -> (B, 512, 16, 8); 2x2 kernel with 'same' padding keeps the size
        self.conv7 = nn.Sequential(
        nn.Conv2d(in_channels=512, out_channels=512, kernel_size=2, padding='same', bias=True),
        nn.Dropout(0.25),
        nn.BatchNorm2d(num_features=512),
        nn.ReLU()
        )

        # 4096 = 512*16*8 / time_steps with time_steps=16 (see class docstring)
        self.fc1 = nn.Sequential(
        nn.Linear(4096, 512),
        nn.ReLU())

        #RNN
        # Bidirectional with hidden 256 -> 512 features per step, matching the
        # next layer's input_size.
        self.rnn1 = nn.LSTM(input_size=512, hidden_size=256, bidirectional=True, batch_first=True)
        self.rnn2 = nn.LSTM(input_size=512, hidden_size=256, bidirectional=True, batch_first=True)
        #FC
        self.fc2 = nn.Linear(512, num_classes)
        #Softmax
        # Log-probabilities over the class axis, the form nn.CTCLoss consumes.
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self,x):
        """Return log-probabilities shaped (time_steps, batch, num_classes)."""
        x = self.conv1(x)
        x = self.conv2(x)        
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        
        #CNN to RNN
        # (B, C, H, W) -> (B, time_steps, C*H*W/time_steps) by raw memory order
        # (no permute; see NOTE(review) in the class docstring).
        x = x.reshape(x.shape[0], self.time_steps, -1)  # reshape (batch_size, seq_length, -1) # 16 = time_steps
        x = self.fc1(x)
       
        # [0] keeps the per-step outputs and drops the (h, c) state tuple.
        x = self.rnn1(x)[0]
        x = self.rnn2(x)[0]
        x = self.fc2(x)
        x = self.softmax(x)
        # (B, T, C) -> (T, B, C): time-major layout expected by nn.CTCLoss.
        x = x.permute(1,0,2)
        return x

In [22]:
# Shape smoke test: a random batch of eight 1x64x128 images through the CRNN.
input_data = torch.rand(8, 1, 64, 128)

# Choose the device first; calling .cuda() unconditionally crashes on
# CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CRNN(time_steps=16, num_classes=188).to(device)
input_data = input_data.to(device)

# NOTE(review): the recorded OOM reported only ~144 MiB reserved by PyTorch
# out of 14.76 GiB, so most GPU memory was held elsewhere (another process or
# library) - check that before shrinking the model.
# Only the output shape is needed, so skip autograd bookkeeping.
with torch.no_grad():
    result = model(input_data)
print(result.shape)

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 14.76 GiB total capacity; 140.29 MiB already allocated; 5.75 MiB free; 144.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## Train model

In [None]:
!pip install mltu

In [None]:
import os
from datetime import datetime

from mltu.configs import BaseModelConfigs

class ModelConfigs(BaseModelConfigs):
    """Hyper-parameters and paths for the CRNN training run.

    Extends mltu's ``BaseModelConfigs``; the training cell reads these values
    as plain instance attributes.
    """
    def __init__(self):
        super().__init__()
        self.trained_models = 'trained_model'  # directory for checkpoint files
        # Passed to OCRDataset as `root`; the dataset joins it with absolute
        # folder paths, so its value does not change which files are read.
        self.root = 'data'
        self.height = 64  # resize target (pixels)
        self.width = 128  # resize target (pixels)
        # Also used as the CRNN's time_steps and as every CTC input length.
        # NOTE(review): OCRDataset pads labels to 32, longer than this.
        self.max_label_len = 16
        self.epochs = 100        
        self.batch_size = 8
        self.learning_rate = 1e-3  # Adam learning rate
        # NOTE(review): not read by the visible training code; the loaders
        # hard-code num_workers=4.
        self.train_workers = 4
        self.logging = 'tensorboard'  # SummaryWriter log directory
        self.checkpoint = None  # optional checkpoint path to resume from

In [None]:
# pip install torchtools


In [None]:
def validation(model, device, valid_loader, loss_function, *, output_lengths=None,
               writer=None, epoch=0, optimizer=None, trained_models=None,
               best_loss=float('inf')):
    """Evaluate ``model`` on ``valid_loader`` and return the mean loss.

    Args:
        model: Network returning log-probabilities shaped
            (time, batch, classes), as ``CRNN.forward`` does.
        device: Device the image and label batches are moved to.
        valid_loader: Iterable of ``(images, padded_labels, label_lengths)``.
        loss_function: CTC-style loss called as
            ``loss_function(log_probs, targets, input_lengths, target_lengths)``.
        output_lengths: Input lengths passed to the loss. If None, every sample
            uses the model's full output length ``predictions.size(0)``.
        writer: Optional TensorBoard writer; logs ``Val/Loss`` at ``epoch``.
        epoch: Zero-based epoch index used for logging and the checkpoint.
        optimizer: Optional optimizer whose state is stored in the checkpoint.
            Without it the checkpoint has no ``"optimizer"`` entry.
        trained_models: Directory for ``last_crnn.pt`` / ``best_crnn.pt``.
            If None, no checkpoint is written.
        best_loss: Best validation loss so far; ``best_crnn.pt`` is written
            only when the new mean loss is lower or equal.

    Returns:
        Mean loss over all validation batches, as a Python float.

    Raises:
        ValueError: If ``valid_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, padded_labels, label_lengths in valid_loader:
            images = images.to(device)
            padded_labels = padded_labels.to(device)
            predictions = model(images)
            if output_lengths is None:
                input_lengths = torch.full(
                    (predictions.size(1),), predictions.size(0), dtype=torch.long)
            else:
                input_lengths = output_lengths
            loss = loss_function(predictions, padded_labels, input_lengths, label_lengths)
            total_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError("valid_loader yielded no batches")
    # Average over every batch rather than reporting only the last one.
    loss_value = total_loss / num_batches

    if writer is not None:
        writer.add_scalar("Val/Loss", loss_value, epoch)

    if trained_models is not None:
        # Lower loss is better (the old `>=` saved the *worse* model).
        is_best = loss_value <= best_loss
        checkpoint = {
            "epoch": epoch + 1,
            "best_loss": min(loss_value, best_loss),
            "model": model.state_dict(),
        }
        if optimizer is not None:
            checkpoint["optimizer"] = optimizer.state_dict()
        torch.save(checkpoint, "{}/last_crnn.pt".format(trained_models))
        if is_best:
            torch.save(checkpoint, "{}/best_crnn.pt".format(trained_models))

    print('Validate', loss_value)
    return loss_value


In [None]:
pip install tensorboard

In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# import torch.utils.data as data
# from torchvision import datasets, transforms
# import os
# import torch
# import torch.nn as nn
# # from OCRDataset import OCRDataset
# from torchvision.transforms import ToTensor, Resize, Compose, RandomAffine, ColorJitter
# from torchvision.transforms import ToTensor, Resize, Compose
# from torch.utils.data import DataLoader
# from torch.utils.data import random_split
# # from model import CRNN
# import itertools
# import numpy as np
# from argparse import ArgumentParser
# # from config import ModelConfigs
# from tqdm.autonotebook import tqdm
# from torch.utils.tensorboard import SummaryWriter
# import shutil
# import warnings
# warnings.simplefilter("ignore")

# # augment data
# augment_transform= Compose([RandomAffine(
#                                             degrees=(-5, 5),
#                                             scale=(0.5, 1.05), 
#                                             shear=10),
#                                             ColorJitter(
#                                                         brightness=0.5, 
#                                                         contrast=0.5,
#                                                         saturation=0.5,
#                                                         hue=0.5)
#                            ])

# def validation(model, device, val_dataloader, loss_function):
#     model.eval()
#     loss_value = 0
#     for iter, (images, padded_labels, label_lenghts) in enumerate(val_dataloader):
#         images = images.to(device)
#         padded_labels = padded_labels.to(device)
#         with torch.no_grad():
#             predictions = model(images)  
#             loss_value = criterion(predictions, padded_labels, output_lengths, label_lenghts)
#     writer.add_scalar("Val/Loss", loss_value, epoch)
#     checkpoint = {
#         "epoch": epoch + 1,
#         "best_loss" : best_loss,
#         "model": model.state_dict(),
#         "optimizer": optimizer.state_dict()
#     }
#     torch.save(checkpoint, "{}/last_crnn.pt".format(trained_models))

#     if loss_value >= best_loss:
#         torch.save(checkpoint, "{}/best_crnn.pt".format(trained_models))
#         best_loss = loss_value
#     return loss_value

# def traindata(device, model, start_epoch, epochs, optimizer, loss_function , train_loader, valid_loader):
#     # Early stopping
#     last_loss = 100
#     patience = 2
#     triggertimes = 0
    
#     for epoch in range(start_epoch, epochs):
#         model.train()
#         progress_bar = tqdm(train_loader, colour="green")
#         for iter, (images, padded_labels, label_lenghts) in enumerate(train_loader):
#             images = augment_transform(images)
#             images = images.to(device)
#             padded_labels = padded_labels.to(device)
#             #forward
#             outputs = model(images)
#             loss_value = loss_function(outputs, padded_labels, output_lengths, label_lenghts)
#             if torch.isinf(loss_value):
#                 print(outputs)
#                 exit()
#             progress_bar.set_description("Epoch {}/{}. Iteration {}/{}. Loss{:3f}".format(epoch+1, num_epochs, iter+1, num_iters, loss_value))
#             writer.add_scalar("Train/Loss", loss_value, epoch*num_iters+iter)
#             #backward
#             optimizer.zero_grad()
#             loss_value.backward()  
#             optimizer.step()


#             # Show progress
#             if iter % 100 == 0 or iter == len(train_loader):
#                 print('[{}/{}, {}/{}] loss: {:.8}'.format(epoch, epochs, iter, len(train_loader), loss_value.item()))

#         # Early stopping
#         current_loss = validation(model, device, valid_loader, loss_function)
#         print('The Current Loss:', current_loss)

#         if current_loss > last_loss:
#             trigger_times += 1
#             print('Trigger Times:', trigger_times)

#             if trigger_times >= patience:
#                 print('Early stopping!\nStart to test process.')
#                 return model

#         else:
#             print('trigger times: 0')
#             trigger_times = 0

#         last_loss = current_loss

#     return model


# def words_from_labels(labels, char_list):
#     """
#     converts the list of encoded integer labels to word strings like eg. [12,10,29] returns CAT 
#     """
#     txt=[]
#     for ele in labels:
#         if ele == 0: # CTC blank space
#             txt.append("")
#         else:
#             #print(letters[ele])
#             txt.append(char_list[ele+1])
#     return "".join(txt)

# def decode_batch(test_func, word_batch): #take only a sequence once a time
#     """
#     Takes the Batch of Predictions and decodes the Predictions by Best Path Decoding and Returns the Output
#     """
#     out = test_func([word_batch])[0] #returns the predicted output matrix of the model
#     ret = []
#     for j in range(out.shape[0]):
#         out_best = list(np.argmax(out[j, :], 1))
#         out_best = [k for k, g in itertools.groupby(out_best)]
#         outstr = words_from_labels(out_best)
#         ret.append(outstr)
#     return ret

# # device
# if torch.cuda.is_available():
#         device = torch.device("cuda")
# else:
#     device = torch.device("cpu")


# configs = ModelConfigs()
# root = configs.root
# num_epochs = configs.epochs
# batch_size = configs.batch_size
# max_label_len = configs.max_label_len
# height = configs.height
# width = configs.width
# learning_rate = configs.learning_rate
# logging = configs.logging
# trained_models = configs.trained_models
# checkpoint = configs.checkpoint

# transform = Compose([
#         Resize((height,width)),
#         ToTensor(),
#          ])

# #split train/val dataset
# dataset = OCRDataset(root = train_folder_path, train=True, transform=transform)  # Replace with your dataset
# train_size = int(0.9 * len(dataset))
# val_size = len(dataset) - train_size
# train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# train_dataloader = DataLoader(
#     dataset=train_dataset,
#     batch_size=batch_size,
#     num_workers=4,
#     drop_last=True,
#     shuffle=True
# )

# val_dataloader = DataLoader(
#     dataset=val_dataset,
#     batch_size=batch_size,
#     num_workers=4,
#     drop_last=True,
#     shuffle=True
# )
# # if not os.path.isdir(logging):
# #     shutil.rmtree(logging)
# # if not os.path.isdir(trained_models):
# #     os.mkdir(trained_models)
# writer = SummaryWriter(logging)
# # Model architecture
# # class CRNN(nn.Module):

# char_list = dataset.char_list
# model = CRNN(time_steps=max_label_len, num_classes=len(char_list)+1).to(device)
# criterion = nn.CTCLoss(blank=0)
# output_lengths = torch.full(size=(batch_size,), fill_value=max_label_len, dtype=torch.long)
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# best_loss = 0
# if checkpoint:
#     checkpoint = torch.load(checkpoint)
#     start_epoch = checkpoint['epoch']
#     best_loss = checkpoint['best_loss']  
#     model.load_state_dict(checkpoint["model"])
#     optimizer.load_state_dict(checkpoint["optimizer"])
# else:
#     start_epoch = 0  
# num_iters = len(train_dataloader)

# traindata(device, model, start_epoch, num_epochs, optimizer, criterion , train_dataloader, val_dataloader)

# validation(model, device, val_dataloader, criterion)


# Add callbacks

## Save point

In [None]:
import torch

def save_checkpoint(model, optimizer, epoch, val_loss, checkpoint_path='checkpoint.pth'):
    """Write model/optimizer state plus epoch and validation loss to disk.

    The file at ``checkpoint_path`` holds a dict with keys ``epoch``,
    ``model_state_dict``, ``optimizer_state_dict`` and ``val_loss``.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        val_loss=val_loss,
    )
    torch.save(state, checkpoint_path)

# Usage during training loop:
# After evaluating validation loss
# save_checkpoint(model, optimizer, epoch, val_loss, 'checkpoint.pth')

## Early Stopping

In [None]:
class EarlyStopping:
    """Signal when a monitored validation loss stops improving.

    Call the instance once per epoch with the latest validation loss. A loss
    counts as an improvement only if it is lower than the best seen so far by
    more than ``min_delta``. After ``patience`` consecutive calls without
    improvement, ``early_stop`` becomes True.

    Args:
        patience: Non-improving calls tolerated before stopping.
        min_delta: Minimum decrease that counts as an improvement.
        restore_best_weights: Stored for the caller to consult; this class
            never touches model weights itself.
    """

    def __init__(self, patience=20, min_delta=1e-8, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and return True once training should stop.

        Returning the flag lets callers write ``if early_stopping(loss):``;
        it previously returned None, so such checks never fired.
        """
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            # First observation or a real improvement: reset the patience window.
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

# # Usage during training loop:
# early_stopping = EarlyStopping(patience=20, min_delta=1e-8, restore_best_weights=True)
# if early_stopping(val_loss):
#     print("Early stopping triggered.")
#     if early_stopping.restore_best_weights:
#         # Restore the model to the best state
#         checkpoint = torch.load('checkpoint.pth')
#         model.load_state_dict(checkpoint['model_state_dict'])
#         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


## ReduceLR

In [None]:

# # Create a learning rate scheduler
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=10, verbose=True)

# # Usage during training loop:
# scheduler.step(val_loss)  # Adjust learning rate based on validation loss


In [None]:
import os
import torch
import torch.nn as nn
from torchvision.transforms import ToTensor, Resize, Compose, RandomAffine, ColorJitter
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import itertools
import numpy as np
from argparse import ArgumentParser
from tqdm.autonotebook import tqdm
from torch.utils.tensorboard import SummaryWriter
import shutil
import warnings
from torch.optim.lr_scheduler import ReduceLROnPlateau

# NOTE(review): this silences every warning for the rest of the session
# (including torch/torchvision deprecation warnings); consider filtering by
# category or message instead.
warnings.simplefilter("ignore")

def words_from_labels(labels, char_list):
    """Convert encoded integer labels back into a string.

    Labels use the encoding of ``encode_to_num``: value ``k`` stands for
    ``char_list[k - 1]`` and 0 is the CTC blank, which is dropped. For example
    ``[1, 0, 3]`` with ``char_list=['a', 'b', 'c']`` decodes to ``'ac'``.

    Args:
        labels: Iterable of integer class ids (Python or NumPy integers).
        char_list: The character list the labels were encoded with.

    Returns:
        The decoded string.
    """
    # Encoding stores index + 1, so decoding subtracts 1 (the old `ele + 1`
    # picked a character two positions off and could run past the end).
    return "".join(char_list[int(ele) - 1] for ele in labels if int(ele) != 0)


def decode_batch(test_func, word_batch, char_list=None):
    """Greedy (best-path) CTC decoding of a batch of predictions.

    Args:
        test_func: Callable taking ``[word_batch]`` and returning a sequence
            whose first element is the prediction matrix. It is indexed as
            ``out[j, t, c]``, i.e. assumed (batch, time, classes) - TODO
            confirm against the model (``CRNN.forward`` returns
            (time, batch, classes)).
        word_batch: Input batch forwarded to ``test_func``.
        char_list: Character list used to encode the labels. Required; it
            defaults to None only so the old two-argument call still parses.

    Returns:
        List of decoded strings, one per batch element.

    Raises:
        ValueError: If ``char_list`` is not provided.
    """
    if char_list is None:
        raise ValueError("decode_batch needs char_list to map class ids to characters")
    out = np.asarray(test_func([word_batch])[0])
    ret = []
    for j in range(out.shape[0]):
        best_path = np.argmax(out[j], axis=1)
        # Collapse consecutive repeats; words_from_labels then drops blanks.
        collapsed = [k for k, _ in itertools.groupby(best_path)]
        ret.append(words_from_labels(collapsed, char_list))
    return ret

if __name__ == '__main__':
    # Training entry point: split the labelled data 90/10, train the CRNN with
    # CTC loss, validate every epoch, keep last/best checkpoints, and stop
    # early once the validation loss plateaus.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    configs = ModelConfigs()
    root = configs.root
    # NOTE: debug override of configs.epochs; switch back for a full run.
    num_epochs = 1

    batch_size = configs.batch_size
    max_label_len = configs.max_label_len
    height = configs.height
    width = configs.width
    learning_rate = configs.learning_rate
    logging = configs.logging
    trained_models = configs.trained_models
    checkpoint = configs.checkpoint

    transform = Compose([
        Resize((height, width)),
        ToTensor(),
    ])

    # NOTE(review): applied to whole batch tensors; torchvision likely samples
    # one set of random parameters per call, so every image in a batch would
    # share the same augmentation - confirm if per-image variety is wanted.
    augment_transform = Compose([RandomAffine(
        degrees=(-5, 5),
        scale=(0.5, 1.05),
        shear=10),
        ColorJitter(
            brightness=0.5,
            contrast=0.5,
            saturation=0.5,
            hue=0.5)])

    dataset = OCRDataset(root=root, train=True, transform=transform)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    # Fixed seed so a resumed run validates on the same held-out images
    # instead of leaking them into training.
    train_dataset, val_dataset = random_split(
        dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

    # drop_last=True on both loaders keeps every batch at batch_size, which
    # the fixed-size output_lengths tensor below relies on.
    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=4,
        drop_last=True,
        shuffle=True
    )

    val_dataloader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=4,
        drop_last=True,
        shuffle=False
    )

    # Start each run with an empty TensorBoard directory. (The previous check
    # was inverted and called rmtree only when the directory did NOT exist.)
    if os.path.isdir(logging):
        shutil.rmtree(logging)
    os.makedirs(trained_models, exist_ok=True)
    writer = SummaryWriter(logging)

    char_list = dataset.char_list
    # +1 class for the CTC blank (index 0).
    model = CRNN(time_steps=max_label_len, num_classes=len(char_list) + 1).to(device)
    criterion = nn.CTCLoss(blank=0)
    # Every sample uses all max_label_len output steps of the model.
    output_lengths = torch.full(size=(batch_size,), fill_value=max_label_len, dtype=torch.long)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
    early_stopping = EarlyStopping(patience=20, min_delta=1e-8, restore_best_weights=True)

    best_loss = float('inf')
    if checkpoint:
        checkpoint = torch.load(checkpoint, map_location=device)
        start_epoch = checkpoint['epoch']
        best_loss = checkpoint['best_loss']
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        start_epoch = 0
    num_iters = len(train_dataloader)
    last_ckpt_path = "{}/last_crnn.pt".format(trained_models)
    best_ckpt_path = "{}/best_crnn.pt".format(trained_models)

    for epoch in range(start_epoch, num_epochs):
        model.train()
        progress_bar = tqdm(train_dataloader, colour="green")
        # Iterate the tqdm wrapper itself so the bar actually advances.
        for step, (images, padded_labels, label_lengths) in enumerate(progress_bar):
            images = augment_transform(images)
            images = images.to(device)
            padded_labels = padded_labels.to(device)
            outputs = model(images)
            loss_value = criterion(outputs, padded_labels, output_lengths, label_lengths)
            if torch.isinf(loss_value):
                # NOTE(review): likely a target that cannot fit in the model's
                # output steps - confirm. Stop loudly rather than exit().
                print(outputs)
                raise RuntimeError("Infinite CTC loss at epoch {}, iteration {}".format(
                    epoch + 1, step + 1))
            progress_bar.set_description("Epoch {}/{}. Iteration {}/{}. Loss {:.3f}".format(
                epoch + 1, num_epochs, step + 1, num_iters, loss_value.item()))
            writer.add_scalar("Train/Loss", loss_value.item(), epoch * num_iters + step)
            optimizer.zero_grad()
            loss_value.backward()
            optimizer.step()

        # Mean loss over the whole validation split (previously only the last
        # batch's loss was kept).
        model.eval()
        val_total = 0.0
        val_batches = 0
        with torch.no_grad():
            for images, padded_labels, label_lengths in val_dataloader:
                images = images.to(device)
                padded_labels = padded_labels.to(device)
                predictions = model(images)
                val_total += criterion(predictions, padded_labels, output_lengths,
                                       label_lengths).item()
                val_batches += 1
        if val_batches == 0:
            raise RuntimeError("Validation loader yielded no batches")
        val_loss = val_total / val_batches
        writer.add_scalar("Val/Loss", val_loss, epoch)

        # Update learning rate scheduler
        scheduler.step(val_loss)

        is_best = val_loss < best_loss
        if is_best:
            best_loss = val_loss
        checkpoint = {
            "epoch": epoch + 1,
            "best_loss": best_loss,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict()
        }
        torch.save(checkpoint, last_ckpt_path)
        if is_best:
            torch.save(checkpoint, best_ckpt_path)

        print('Validate', val_loss)

        # Check the flag attribute so this works whether or not __call__
        # returns it.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            if early_stopping.restore_best_weights and os.path.isfile(best_ckpt_path):
                best_checkpoint = torch.load(best_ckpt_path, map_location=device)
                model.load_state_dict(best_checkpoint["model"])
                optimizer.load_state_dict(best_checkpoint["optimizer"])
            break

    writer.close()


In [None]:
# model = traindata(device, model, start_epoch, num_epochs, optimizer, criterion , train_dataloader, val_dataloader)

# model

In [None]:
# validation(model, device, val_dataloader, criterion)
# 

In [None]:
# Exponentially weighted averages are a type of moving average that give more weight and significance to recent data points. This is in contrast to simple moving averages, which give equal weight to all data points within a specified period.