In [9]:
import os
import sys

# Directory that is put on sys.path so the local `torchret` package imported
# in the next cell can be found (presumably it lives here — confirm).
# NOTE(review): the default is a machine-specific absolute path; set the
# FLASH_DIR environment variable to point at your own checkout instead.
FLASH_DIR = os.environ.get('FLASH_DIR', '/Users/parth/Desktop/Flash/')

# Guard keeps this cell idempotent: re-running it does not stack duplicates.
if FLASH_DIR not in sys.path:
    sys.path.append(FLASH_DIR)

In [10]:
import os

import albumentations
import neptune
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import timm
import torch
import torch.nn as nn
from neptune.types import File
from sklearn import metrics
from sklearn.model_selection import train_test_split

from torchret import Model

In [11]:
# Single source of truth for every tunable knob in this notebook; the whole
# dict is also uploaded as run parameters when the Neptune logger is enabled.
configs = dict(
    # SGD learning rate and CosineAnnealingWarmRestarts settings
    # (consumed by DigitRecognizerModel.fetch_optimizer).
    lr=1e-4,
    eta_min=1e-6,
    T_0=20,
    # Training length passed to model.fit; scheduler stepping cadence is
    # presumably interpreted by torchret.
    epochs=20,
    step_scheduler_after='epoch',

    # Batch sizes passed to model.fit.
    train_bs=256,
    valid_bs=256,

    # Presumably forwarded to the torch DataLoaders by torchret — confirm.
    num_workers=0,
    pin_memory=False,

    # timm backbone definition and the device used for training.
    model_name='resnet10t',
    pretrained=True,
    num_classes=10,
    in_channels=1,
    device='mps',

    # Checkpointing options (presumably interpreted by torchret).
    model_path='digit-recognizer.pt',
    save_best_model='on_eval_metric',
    save_on_metric='accuracy',
    save_model_at_every_epoch=False,
)

In [12]:
class DigitRecognizerDataset(torch.utils.data.Dataset):
    """Map-style dataset over a digit-recognizer DataFrame.

    ``df`` must contain a ``label`` column; every remaining column is treated
    as a pixel value and each row is reshaped to a 28x28 image.

    Each item is a dict with:
      * ``"images"``  -- float32 tensor of shape (1, 28, 28)
      * ``"targets"`` -- int64 scalar tensor holding the label

    NOTE(review): ``augmentations`` is stored but never applied in
    ``__getitem__``, so images carry the raw, unscaled CSV pixel values —
    TODO confirm this is intended.
    """

    def __init__(self, df, augmentations):
        self.targets = df.label.values
        self.df = df.drop(columns=["label"])
        self.augmentations = augmentations

        # Decode all rows once up front instead of per item.
        pixels = self.df.to_numpy(dtype=np.float32)
        self.images = pixels.reshape((-1, 28, 28))

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, item):
        label = self.targets[item]
        # Prepend the channel axis: (28, 28) -> (1, 28, 28).
        image = self.images[item][np.newaxis, ...]
        return {
            "images": torch.tensor(image, dtype=torch.float),
            "targets": torch.tensor(label, dtype=torch.long),
        }

In [13]:
def _build_normalize_augs():
    """Return an albumentations pipeline that standardizes 0-255 grayscale pixels.

    The model is created with ``in_channels = 1``, so a single mean/std pair
    is used. Three-value (RGB) ImageNet statistics cannot broadcast against a
    (28, 28) image and would turn a (28, 28, 1) image into three channels.
    The values below are the per-channel averages of those ImageNet
    statistics (mean 0.485/0.456/0.406, std 0.229/0.224/0.225).
    """
    return albumentations.Compose(
        [
            albumentations.Normalize(
                mean=(0.449,),
                std=(0.226,),
                max_pixel_value=255.0,
                p=1.0,
            ),
        ],
        p=1.0,
    )


# Separate instances so train-only augmentations can be added to one pipeline
# without affecting the other.
# NOTE(review): DigitRecognizerDataset stores `augmentations` but, as written,
# does not apply them in __getitem__ — TODO confirm whether it should.
train_augs = _build_normalize_augs()
valid_augs = _build_normalize_augs()

In [14]:
class DigitRecognizerModel(Model):
    """timm image classifier for the digit-recognizer data, built on torchret's ``Model``.

    All hyper-parameters come from the module-level ``configs`` dict. The hook
    methods below (``setup_logger``, ``*_logs``, ``monitor_*``,
    ``fetch_optimizer``) are presumably invoked by ``torchret.Model.fit`` —
    TODO confirm against torchret.
    """

    def __init__(self):
        super().__init__()

        # Backbone + classification head; in_chans matches the single-channel images.
        self.model = timm.create_model(
            model_name=configs['model_name'],
            pretrained=configs['pretrained'],
            in_chans=configs['in_channels'],
            num_classes=configs['num_classes'],
        )

        # Loader / scheduler settings, presumably read by torchret.Model.fit.
        self.num_workers = configs['num_workers']
        self.pin_memory = configs['pin_memory']
        self.step_scheduler_after = configs['step_scheduler_after']

        # Checkpointing settings, presumably read by torchret.Model.fit.
        self.model_path = configs['model_path']
        self.save_best_model = configs['save_best_model']
        self.save_on_metric = configs['save_on_metric']
        self.save_model_at_every_epoch = configs['save_model_at_every_epoch']

    def setup_logger(self):
        """Start a Neptune run and record ``configs`` under ``parameters``.

        The API token is read from the ``NEPTUNE_API_TOKEN`` environment
        variable so no credential has to be written into the notebook.

        Raises:
            RuntimeError: if ``NEPTUNE_API_TOKEN`` is unset or empty.
        """
        api_token = os.environ.get('NEPTUNE_API_TOKEN')
        if not api_token:
            raise RuntimeError(
                'Set the NEPTUNE_API_TOKEN environment variable before enabling the Neptune logger.'
            )
        self.run = neptune.init_run(
            project='fenilsavani62/Digit-recog',
            api_token=api_token,
            capture_stdout=True,
            capture_stderr=True,
            capture_traceback=True,
            capture_hardware_metrics=True,
            # Glob of source files uploaded with the run (not cell outputs).
            source_files='*.ipynb',
        )
        self.run['parameters'] = configs

    def valid_one_step_logs(self, batch_id, data, logits, loss, metrics):
        """Upload one validation batch's images with true/predicted labels to Neptune.

        The example batch is index 100 when the loader has more than 100
        batches (as before), otherwise the last batch. The previous guard
        ``batch_id % len(validloader) == 100`` could never fire with fewer
        than 101 batches (the runs here use 33).
        """
        n_batches = len(self.validloader)
        if n_batches == 0:
            return
        log_batch = min(100, n_batches - 1)
        if batch_id % n_batches != log_batch:
            return

        # NCHW -> NHWC, then drop only the channel axis so a batch of one keeps
        # its batch dimension (a bare .squeeze() would remove it too).
        images = data['images'].detach().cpu().permute(0, 2, 3, 1).squeeze(-1)
        # Move labels to host numpy so descriptions show "3", not "tensor(3, device=...)".
        labels = data['targets'].detach().cpu().numpy()
        preds = np.argmax(logits.detach().cpu().numpy(), axis=1)

        # NOTE(review): pixel values are passed through unscaled; confirm the
        # value range File.as_image expects.
        for image, label, pred in zip(images, labels, preds):
            description = f'true label : {label} prediction : {pred}'
            self.run["valid/prediction_example"].append(File.as_image(image), description=description)

    def train_one_epoch_logs(self, loss, monitor):
        """Append the epoch's training loss and monitor values to the Neptune run."""
        self.run['train/loss'].append(loss)
        self.run['train/monitors'].append(monitor)

    def valid_one_epoch_logs(self, loss, monitor):
        """Append the epoch's validation loss and monitor values to the Neptune run."""
        self.run['valid/loss'].append(loss)
        self.run['valid/monitors'].append(monitor)

    def monitor_metrics(self, outputs, targets):
        """Batch accuracy and macro-F1 from logits ``outputs`` and class-index ``targets``.

        Returns:
            dict with 0-d float32 tensors under ``'accuracy'`` and ``'f1_score'``.
        """
        preds = np.argmax(outputs.detach().cpu().numpy(), axis=1)
        y_true = targets.detach().cpu().numpy()
        acc = metrics.accuracy_score(y_true, preds)
        f1 = metrics.f1_score(y_true, preds, average='macro')
        # Both as float32: F1 previously came back as float64, which is
        # inconsistent with accuracy and unsupported on the 'mps' backend.
        return {
            "accuracy": torch.tensor(acc, dtype=torch.float32),
            'f1_score': torch.tensor(f1, dtype=torch.float32),
        }

    def monitor_loss(self, outputs, targets):
        """Cross-entropy between logits ``outputs`` and class-index ``targets``."""
        return nn.CrossEntropyLoss()(outputs, targets)

    def fetch_optimizer(self):
        """Build SGD (momentum 0.9) plus a cosine warm-restart LR schedule from ``configs``.

        NOTE(review): the captured training log reports current_lr=0.0001 on
        every one of the 20 epochs, so the scheduler may not actually be
        stepped with step_scheduler_after='epoch' — confirm in torchret.
        """
        opt = torch.optim.SGD(
            self.parameters(),
            lr=configs['lr'],
            momentum=0.9,
        )
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            opt,
            T_0=configs['T_0'],
            eta_min=configs['eta_min'],
            T_mult=1,
        )
        return opt, sch

    def forward(self, images, targets=None):
        """Run the backbone; add loss and metrics when ``targets`` is given.

        Returns:
            ``(logits, loss, metrics_dict)``. Without targets, loss is ``0``
            and the metrics dict is empty.
        """
        logits = self.model(images)
        if targets is None:
            return logits, 0, {}
        loss = self.monitor_loss(logits, targets)
        # Named batch_metrics to avoid shadowing the sklearn `metrics` module.
        batch_metrics = self.monitor_metrics(logits, targets)
        return logits, loss, batch_metrics

In [15]:
def main(csv_path='train.csv', seed=None):
    """Train DigitRecognizerModel on a stratified 80/20 split of ``csv_path``.

    Args:
        csv_path: CSV with a ``label`` column plus pixel columns.
        seed: RNG seed for the split and torch/numpy; defaults to
            ``configs.get('seed', 42)`` so reruns are reproducible.
    """
    if seed is None:
        seed = configs.get('seed', 42)
    # Seed before building the model so the freshly initialised head and the
    # DataLoader shuffling are repeatable too.
    torch.manual_seed(seed)
    np.random.seed(seed)

    df = pd.read_csv(csv_path)
    # Stratify so both halves keep the same digit-class balance.
    train_df, valid_df = train_test_split(
        df, test_size=0.2, random_state=seed, stratify=df['label']
    )

    train_dataset = DigitRecognizerDataset(df=train_df, augmentations=train_augs)
    valid_dataset = DigitRecognizerDataset(df=valid_df, augmentations=valid_augs)

    model = DigitRecognizerModel()
    model.fit(
        train_dataset,
        valid_dataset,
        train_bs=configs['train_bs'],
        valid_bs=configs['valid_bs'],
        device=configs['device'],
        epochs=configs['epochs'],
        logger=False,
    )


if __name__ == "__main__":
    main()

100%|██████████| 132/132 [00:21<00:00,  6.12it/s, accuracy=0.229, current_lr=0.0001, epoch=1, f1_score=0.219, loss=2.261829, stage=train]
100%|██████████| 33/33 [00:01<00:00, 17.27it/s, accuracy=0.406, epoch=1, f1_score=0.393, loss=1.799884, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.4059130973888166 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.42it/s, accuracy=0.532, current_lr=0.0001, epoch=2, f1_score=0.518, loss=1.507485, stage=train]
100%|██████████| 33/33 [00:01<00:00, 19.44it/s, accuracy=0.626, epoch=2, f1_score=0.614, loss=1.262709, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.6260744459701307 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.49it/s, accuracy=0.693, current_lr=0.0001, epoch=3, f1_score=0.681, loss=1.084317, stage=train]
100%|██████████| 33/33 [00:01<00:00, 19.29it/s, accuracy=0.73, epoch=3, f1_score=0.72, loss=0.960774, stage=eval]  


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.7300772143132759 accuracy


100%|██████████| 132/132 [00:18<00:00,  7.16it/s, accuracy=0.772, current_lr=0.0001, epoch=4, f1_score=0.764, loss=0.833752, stage=train]
100%|██████████| 33/33 [00:01<00:00, 18.86it/s, accuracy=0.793, epoch=4, f1_score=0.786, loss=0.758647, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.7925681092522361 accuracy


100%|██████████| 132/132 [00:20<00:00,  6.38it/s, accuracy=0.82, current_lr=0.0001, epoch=5, f1_score=0.813, loss=0.669492, stage=train] 
100%|██████████| 33/33 [00:02<00:00, 15.65it/s, accuracy=0.828, epoch=5, f1_score=0.822, loss=0.622115, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.8282251610900416 accuracy


100%|██████████| 132/132 [00:18<00:00,  7.13it/s, accuracy=0.85, current_lr=0.0001, epoch=6, f1_score=0.845, loss=0.555259, stage=train] 
100%|██████████| 33/33 [00:01<00:00, 19.67it/s, accuracy=0.855, epoch=6, f1_score=0.85, loss=0.524166, stage=eval] 


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.8547858397165934 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.56it/s, accuracy=0.874, current_lr=0.0001, epoch=7, f1_score=0.87, loss=0.470227, stage=train] 
100%|██████████| 33/33 [00:01<00:00, 20.44it/s, accuracy=0.876, epoch=7, f1_score=0.872, loss=0.450679, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.8760653409090909 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.76it/s, accuracy=0.89, current_lr=0.0001, epoch=8, f1_score=0.886, loss=0.408348, stage=train] 
100%|██████████| 33/33 [00:01<00:00, 20.30it/s, accuracy=0.891, epoch=8, f1_score=0.888, loss=0.397884, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.8911804340102456 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.69it/s, accuracy=0.904, current_lr=0.0001, epoch=9, f1_score=0.901, loss=0.357748, stage=train]
100%|██████████| 33/33 [00:01<00:00, 20.59it/s, accuracy=0.903, epoch=9, f1_score=0.9, loss=0.353898, stage=eval]  


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9032725095748901 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.60it/s, accuracy=0.914, current_lr=0.0001, epoch=10, f1_score=0.91, loss=0.316879, stage=train] 
100%|██████████| 33/33 [00:01<00:00, 19.63it/s, accuracy=0.914, epoch=10, f1_score=0.911, loss=0.318337, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9141626610900416 accuracy


100%|██████████| 132/132 [00:18<00:00,  7.01it/s, accuracy=0.922, current_lr=0.0001, epoch=11, f1_score=0.919, loss=0.285896, stage=train]
100%|██████████| 33/33 [00:01<00:00, 18.56it/s, accuracy=0.921, epoch=11, f1_score=0.918, loss=0.291129, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.921082823565512 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.46it/s, accuracy=0.928, current_lr=0.0001, epoch=12, f1_score=0.926, loss=0.260992, stage=train]
100%|██████████| 33/33 [00:01<00:00, 19.01it/s, accuracy=0.924, epoch=12, f1_score=0.923, loss=0.270440, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9244791666666666 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.55it/s, accuracy=0.934, current_lr=0.0001, epoch=13, f1_score=0.932, loss=0.237613, stage=train]
100%|██████████| 33/33 [00:01<00:00, 16.96it/s, accuracy=0.93, epoch=13, f1_score=0.929, loss=0.252875, stage=eval] 


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9299879814639236 accuracy


100%|██████████| 132/132 [00:18<00:00,  6.95it/s, accuracy=0.939, current_lr=0.0001, epoch=14, f1_score=0.937, loss=0.217758, stage=train]
100%|██████████| 33/33 [00:01<00:00, 19.49it/s, accuracy=0.933, epoch=14, f1_score=0.932, loss=0.236808, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.933393429626118 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.57it/s, accuracy=0.943, current_lr=0.0001, epoch=15, f1_score=0.941, loss=0.202696, stage=train]
100%|██████████| 33/33 [00:01<00:00, 18.82it/s, accuracy=0.937, epoch=15, f1_score=0.935, loss=0.223561, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9369900931011547 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.60it/s, accuracy=0.947, current_lr=0.0001, epoch=16, f1_score=0.946, loss=0.189272, stage=train]
100%|██████████| 33/33 [00:01<00:00, 17.59it/s, accuracy=0.939, epoch=16, f1_score=0.938, loss=0.210925, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9392755681818182 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.37it/s, accuracy=0.95, current_lr=0.0001, epoch=17, f1_score=0.948, loss=0.179096, stage=train] 
100%|██████████| 33/33 [00:01<00:00, 19.04it/s, accuracy=0.942, epoch=17, f1_score=0.941, loss=0.199391, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9423805366862904 accuracy


100%|██████████| 132/132 [00:18<00:00,  7.27it/s, accuracy=0.954, current_lr=0.0001, epoch=18, f1_score=0.952, loss=0.167709, stage=train]
100%|██████████| 33/33 [00:01<00:00, 20.02it/s, accuracy=0.944, epoch=18, f1_score=0.942, loss=0.192539, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9440013116056268 accuracy


100%|██████████| 132/132 [00:21<00:00,  6.17it/s, accuracy=0.956, current_lr=0.0001, epoch=19, f1_score=0.955, loss=0.157917, stage=train]
100%|██████████| 33/33 [00:02<00:00, 15.63it/s, accuracy=0.947, epoch=19, f1_score=0.945, loss=0.185573, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9467875867178945 accuracy


100%|██████████| 132/132 [00:21<00:00,  6.22it/s, accuracy=0.958, current_lr=0.0001, epoch=20, f1_score=0.957, loss=0.148720, stage=train]
100%|██████████| 33/33 [00:01<00:00, 19.12it/s, accuracy=0.949, epoch=20, f1_score=0.947, loss=0.176911, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9485358397165934 accuracy
