In [1]:
import os
import sys

# Directory that presumably holds the local ``torchret`` package imported below.
# Set the FLASH_DIR environment variable to point elsewhere instead of editing
# this machine-specific absolute path.
FLASH_DIR = os.environ.get("FLASH_DIR", "/Users/parth/Desktop/Flash/")

# Guard so re-running this cell does not keep appending duplicate entries.
if FLASH_DIR not in sys.path:
    sys.path.append(FLASH_DIR)

In [1]:
import torch 
import torch.nn as nn
import numpy as np
import pandas as pd 
import neptune

from sklearn import metrics
from sklearn.model_selection import train_test_split

import timm
import albumentations
from neptune.types import File
from torchret import Model 
import plotly.graph_objects as go

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Every knob a reader might tune for this run; the whole dict is also logged
# to Neptune by ``DigitRecognizerModel.setup_logger``.
configs = {
    # Optimizer / cosine-annealing schedule (T_0 == epochs -> one cosine cycle).
    'lr' : 1e-4,
    'eta_min' : 1e-6,
    'T_0' : 20,
    'epochs' : 20,
    'step_scheduler_after' : 'epoch',

    'train_bs' : 256,
    'valid_bs' : 256,

    'num_workers' : 0,
    'pin_memory' : False,

    'model_name' : 'resnet10t',
    'pretrained' : True,
    'num_classes' : 10,
    'in_channels' : 1,
    # Prefer Apple MPS (the original target), then CUDA, then CPU, so the
    # notebook also runs on machines without an MPS backend.
    'device' : (
        'mps' if torch.backends.mps.is_available()
        else 'cuda' if torch.cuda.is_available()
        else 'cpu'
    ),

    'model_path' : 'digit-recognizer.pt',
    'save_best_model' : 'on_eval_metric',
    'save_on_metric' : 'accuracy',
    'save_model_at_every_epoch' : False,

    # Seed for the train/validation split and torch/numpy RNGs.
    'seed' : 42,
}

In [3]:
class DigitRecognizerDataset(torch.utils.data.Dataset):
    """Map-style dataset over a digit-recognizer DataFrame.

    ``df`` must hold a ``label`` column plus 28 * 28 = 784 pixel columns in
    row-major order. Pixels are converted once to a float32 array of shape
    (N, 28, 28). Each item is a dict with:

    * ``"images"``: float32 tensor of shape (1, 28, 28), raw pixel values;
    * ``"targets"``: scalar int64 tensor holding the label.

    NOTE(review): ``augmentations`` is stored but not applied in
    ``__getitem__``; wire it in once the pipeline matches the 1-channel input.
    """

    def __init__(self, df, augmentations):
        self.targets = df["label"].to_numpy()
        self.augmentations = augmentations

        pixels = df.drop(columns=["label"]).to_numpy(dtype=np.float32)
        # ``reshape((-1, 28, 28))`` alone would silently mis-shape data whose
        # column count is merely a divisor/multiple of 784, so check explicitly.
        if pixels.shape[1] != 28 * 28:
            raise ValueError(
                f"expected {28 * 28} pixel columns besides 'label', got {pixels.shape[1]}"
            )
        # Keep only the compact image array; also retaining the DataFrame would
        # hold a second copy of every pixel for the dataset's lifetime.
        self.images = pixels.reshape((-1, 28, 28))

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, item):
        # Add the leading channel axis expected by the model: (28, 28) -> (1, 28, 28).
        image = self.images[item][np.newaxis]
        return {
            "images": torch.tensor(image, dtype=torch.float),
            "targets": torch.tensor(self.targets[item], dtype=torch.long),
        }

In [4]:
def make_normalize_augs(mean=(0.449,), std=(0.226,), max_pixel_value=255.0):
    """Build an albumentations pipeline that only normalizes pixel intensities.

    Args:
        mean: Per-channel means in [0, 1] units (scaled by ``max_pixel_value``).
        std: Per-channel standard deviations in [0, 1] units.
        max_pixel_value: Value that maps to 1.0 before normalization.

    The defaults are single-channel: the average of the ImageNet RGB
    statistics (0.485/0.456/0.406 and 0.229/0.224/0.225). The model is built
    with ``in_chans=1``, and a 3-element mean/std would not broadcast against a
    28x28 grayscale image.

    Returns:
        An ``albumentations.Compose`` applying ``Normalize`` with probability 1.
    """
    return albumentations.Compose(
        [
            albumentations.Normalize(
                mean=mean,
                std=std,
                max_pixel_value=max_pixel_value,
                p=1.0,
            ),
        ],
        p=1.0,
    )


# Train and validation share identical, deterministic preprocessing for now.
# NOTE(review): DigitRecognizerDataset stores these but does not apply them.
train_augs = make_normalize_augs()
valid_augs = make_normalize_augs()

In [5]:
class DigitRecognizerModel(Model):
    """timm image classifier wrapped in a torchret ``Model`` for digit recognition.

    Architecture, loop options and optimizer settings come from the
    module-level ``configs`` dict. The ``*_logs`` hooks write to the Neptune
    run created by ``setup_logger``.
    """

    # Upload validation prediction examples on every N-th validation batch
    # (batch 0 of each validation pass always qualifies).
    log_images_every = 100
    # Upper bound on images uploaded per logged batch; sending a whole
    # 256-image batch to Neptune is slow and clutters the run.
    max_logged_images = 16

    def __init__(self):
        super().__init__()

        self.model = timm.create_model(
            model_name=configs['model_name'],
            pretrained=configs['pretrained'],
            in_chans=configs['in_channels'],
            num_classes=configs['num_classes'],
        )

        # Loop / checkpointing options; presumably read by the torchret base class.
        self.num_workers = configs['num_workers']
        self.pin_memory = configs['pin_memory']
        self.step_scheduler_after = configs['step_scheduler_after']

        self.model_path = configs['model_path']
        self.save_best_model = configs['save_best_model']
        self.save_on_metric = configs['save_on_metric']
        self.save_model_at_every_epoch = configs['save_model_at_every_epoch']

    def setup_logger(self):
        """Start a Neptune run and record ``configs`` under ``parameters``.

        No token lives in the notebook: with ``api_token`` omitted, Neptune
        reads it from the ``NEPTUNE_API_TOKEN`` environment variable.
        """
        self.run = neptune.init_run(
            project='abc/Digit-recog',
            capture_stdout=True,
            capture_stderr=True,
            capture_traceback=True,
            capture_hardware_metrics=True,
            source_files='*.ipynb',  # upload notebook files matching this glob
        )
        self.run['parameters'] = configs

    def valid_one_step_logs(self, batch_id, data, logits, loss, metrics):
        """Upload up to ``max_logged_images`` validation images with true/predicted labels.

        Fires on every ``log_images_every``-th validation batch. ``loss`` and
        ``metrics`` belong to the hook signature and are unused here.
        """
        # A fixed stride is used because ``batch_id % len(self.validloader)`` is
        # always below the number of validation batches (33 here), so comparing
        # it with 100 never triggered.
        if batch_id % self.log_images_every != 0:
            return
        run = getattr(self, 'run', None)
        if run is None:  # logger never set up (presumably fit(..., logger=False))
            return

        n = self.max_logged_images
        # (B, C, H, W) -> (B, H, W, C); drop only the channel axis so a batch
        # of one image keeps its batch dimension.
        images = data['images'][:n].detach().permute(0, 2, 3, 1).squeeze(-1).cpu().numpy()
        # Plain ints so the description reads "3", not "tensor(3, device=...)".
        labels = data['targets'][:n].cpu().tolist()
        preds = logits[:n].detach().argmax(dim=1).cpu().tolist()

        for image, label, pred in zip(images, labels, preds):
            # File.as_image expects float pixels in [0, 1]; min-max scale each image.
            lo, hi = float(image.min()), float(image.max())
            scaled = (image - lo) / (hi - lo) if hi > lo else np.zeros_like(image)
            description = f'true label : {label} prediction : {pred}'
            run["valid/prediction_example"].append(File.as_image(scaled), description=description)

    def train_one_epoch_logs(self, loss, monitor):
        self.run['train/loss'].append(loss)
        self.run['train/monitors'].append(monitor)

    def valid_one_epoch_logs(self, loss, monitor):
        self.run['valid/loss'].append(loss)
        self.run['valid/monitors'].append(monitor)

    def monitor_metrics(self, outputs, targets):
        """Batch accuracy and macro-F1 from logits ``outputs`` and integer ``targets``.

        Both metrics are returned as float32 scalar tensors so they share a dtype.
        """
        preds = outputs.detach().argmax(dim=1).cpu().numpy()
        targets = targets.detach().cpu().numpy()
        acc = metrics.accuracy_score(targets, preds)
        f1 = metrics.f1_score(targets, preds, average='macro')
        return {
            'accuracy': torch.tensor(acc, dtype=torch.float32),
            'f1_score': torch.tensor(f1, dtype=torch.float32),
        }

    def monitor_loss(self, outputs, targets):
        """Mean cross-entropy between logits ``outputs`` and class-index ``targets``."""
        return nn.CrossEntropyLoss()(outputs, targets)

    def fetch_optimizer(self):
        """SGD with momentum plus cosine annealing with warm restarts (T_mult=1)."""
        opt = torch.optim.SGD(
            self.parameters(),
            lr=configs['lr'],
            momentum=0.9,
        )
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            opt,
            T_0=configs['T_0'],
            eta_min=configs['eta_min'],
            T_mult=1,
        )
        return opt, sch

    def forward(self, images, targets=None):
        """Return ``(logits, loss, metrics)``; loss is 0 and metrics empty without targets."""
        logits = self.model(images)
        if targets is None:
            return logits, 0, {}
        # Named ``batch_metrics`` so it does not shadow the sklearn ``metrics`` module.
        loss = self.monitor_loss(logits, targets)
        batch_metrics = self.monitor_metrics(logits, targets)
        return logits, loss, batch_metrics

In [6]:
def main(csv_path='train.csv'):
    """Train the digit recognizer on ``csv_path`` with an 80/20 train/validation split.

    The split is stratified on ``label`` and seeded with ``configs['seed']``
    (default 42 if absent), so reruns validate on the same rows. numpy's and
    torch's global RNGs are seeded with the same value.
    """
    seed = configs.get('seed', 42)
    np.random.seed(seed)
    torch.manual_seed(seed)

    df = pd.read_csv(csv_path)
    train_df, valid_df = train_test_split(
        df,
        test_size=0.2,
        random_state=seed,
        stratify=df['label'],  # keep class proportions equal in both splits
    )

    train_dataset = DigitRecognizerDataset(df=train_df, augmentations=train_augs)
    valid_dataset = DigitRecognizerDataset(df=valid_df, augmentations=valid_augs)

    model = DigitRecognizerModel()
    model.fit(
        train_dataset,
        valid_dataset,
        train_bs=configs['train_bs'],
        valid_bs=configs['valid_bs'],
        device=configs['device'],
        epochs=configs['epochs'],
        logger=False,
    )

if __name__ == "__main__":
    main()

100%|██████████| 132/132 [00:22<00:00,  5.78it/s, accuracy=0.256, current_lr=0.0001, epoch=1, f1_score=0.246, loss=2.180834, stage=train]
100%|██████████| 33/33 [00:01<00:00, 17.35it/s, accuracy=0.439, epoch=1, f1_score=0.421, loss=1.711278, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.4385107078335502 accuracy


100%|██████████| 132/132 [00:23<00:00,  5.61it/s, accuracy=0.555, current_lr=9.94e-5, epoch=2, f1_score=0.541, loss=1.432079, stage=train]
100%|██████████| 33/33 [00:02<00:00, 11.64it/s, accuracy=0.639, epoch=2, f1_score=0.625, loss=1.194170, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.6387037436167399 accuracy


100%|██████████| 132/132 [00:21<00:00,  6.19it/s, accuracy=0.701, current_lr=9.76e-5, epoch=3, f1_score=0.692, loss=1.039973, stage=train]
100%|██████████| 33/33 [00:01<00:00, 16.95it/s, accuracy=0.749, epoch=3, f1_score=0.74, loss=0.910744, stage=eval] 


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.7489983981305902 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.35it/s, accuracy=0.779, current_lr=9.46e-5, epoch=4, f1_score=0.773, loss=0.799934, stage=train]
100%|██████████| 33/33 [00:01<00:00, 20.02it/s, accuracy=0.807, epoch=4, f1_score=0.801, loss=0.726126, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.8068090767571421 accuracy


100%|██████████| 132/132 [00:20<00:00,  6.51it/s, accuracy=0.827, current_lr=9.05e-5, epoch=5, f1_score=0.822, loss=0.647008, stage=train]
100%|██████████| 33/33 [00:01<00:00, 17.24it/s, accuracy=0.84, epoch=5, f1_score=0.835, loss=0.603217, stage=eval] 


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.8400531754349218 accuracy


100%|██████████| 132/132 [00:20<00:00,  6.30it/s, accuracy=0.853, current_lr=8.55e-5, epoch=6, f1_score=0.848, loss=0.544903, stage=train]
100%|██████████| 33/33 [00:01<00:00, 19.57it/s, accuracy=0.862, epoch=6, f1_score=0.858, loss=0.517981, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.8623980193427114 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.44it/s, accuracy=0.873, current_lr=7.96e-5, epoch=7, f1_score=0.869, loss=0.468792, stage=train]
100%|██████████| 33/33 [00:03<00:00, 10.33it/s, accuracy=0.878, epoch=7, f1_score=0.875, loss=0.454529, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.8780776515151515 accuracy


100%|██████████| 132/132 [00:19<00:00,  6.92it/s, accuracy=0.888, current_lr=7.3e-5, epoch=8, f1_score=0.883, loss=0.417407, stage=train]
100%|██████████| 33/33 [00:01<00:00, 17.93it/s, accuracy=0.891, epoch=8, f1_score=0.888, loss=0.409891, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.8908162153128422 accuracy


100%|██████████| 132/132 [00:19<00:00,  6.81it/s, accuracy=0.898, current_lr=6.58e-5, epoch=9, f1_score=0.895, loss=0.377299, stage=train]
100%|██████████| 33/33 [00:01<00:00, 18.52it/s, accuracy=0.897, epoch=9, f1_score=0.893, loss=0.379101, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.8965344555450209 accuracy


100%|██████████| 132/132 [00:18<00:00,  7.14it/s, accuracy=0.903, current_lr=5.82e-5, epoch=10, f1_score=0.9, loss=0.350856, stage=train]  
100%|██████████| 33/33 [00:01<00:00, 18.28it/s, accuracy=0.902, epoch=10, f1_score=0.898, loss=0.355237, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.902216273726839 accuracy


100%|██████████| 132/132 [00:19<00:00,  6.82it/s, accuracy=0.91, current_lr=5.05e-5, epoch=11, f1_score=0.906, loss=0.328866, stage=train] 
100%|██████████| 33/33 [00:02<00:00, 16.28it/s, accuracy=0.906, epoch=11, f1_score=0.903, loss=0.336978, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9060861016764785 accuracy


100%|██████████| 132/132 [00:20<00:00,  6.39it/s, accuracy=0.913, current_lr=4.28e-5, epoch=12, f1_score=0.911, loss=0.310399, stage=train]
100%|██████████| 33/33 [00:02<00:00, 11.49it/s, accuracy=0.909, epoch=12, f1_score=0.907, loss=0.323446, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9093458616372311 accuracy


100%|██████████| 132/132 [00:20<00:00,  6.42it/s, accuracy=0.918, current_lr=3.52e-5, epoch=13, f1_score=0.916, loss=0.298190, stage=train]
100%|██████████| 33/33 [00:01<00:00, 17.77it/s, accuracy=0.912, epoch=13, f1_score=0.91, loss=0.313337, stage=eval] 


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9116222316568549 accuracy


100%|██████████| 132/132 [00:19<00:00,  6.78it/s, accuracy=0.921, current_lr=2.8e-5, epoch=14, f1_score=0.918, loss=0.286649, stage=train]
100%|██████████| 33/33 [00:01<00:00, 18.64it/s, accuracy=0.914, epoch=14, f1_score=0.913, loss=0.304640, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.9144813526760448 accuracy


100%|██████████| 132/132 [00:20<00:00,  6.58it/s, accuracy=0.923, current_lr=2.14e-5, epoch=15, f1_score=0.92, loss=0.280121, stage=train] 
100%|██████████| 33/33 [00:01<00:00, 18.34it/s, accuracy=0.914, epoch=15, f1_score=0.911, loss=0.301539, stage=eval]
100%|██████████| 132/132 [00:17<00:00,  7.43it/s, accuracy=0.923, current_lr=1.55e-5, epoch=16, f1_score=0.921, loss=0.275765, stage=train]
100%|██████████| 33/33 [00:01<00:00, 18.37it/s, accuracy=0.917, epoch=16, f1_score=0.915, loss=0.294742, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.917103729464791 accuracy


100%|██████████| 132/132 [00:17<00:00,  7.46it/s, accuracy=0.924, current_lr=1.05e-5, epoch=17, f1_score=0.922, loss=0.273015, stage=train]
100%|██████████| 33/33 [00:01<00:00, 18.89it/s, accuracy=0.917, epoch=17, f1_score=0.915, loss=0.291873, stage=eval]
100%|██████████| 132/132 [00:17<00:00,  7.36it/s, accuracy=0.925, current_lr=6.4e-6, epoch=18, f1_score=0.922, loss=0.269056, stage=train]
100%|██████████| 33/33 [00:01<00:00, 17.93it/s, accuracy=0.916, epoch=18, f1_score=0.914, loss=0.294466, stage=eval]
100%|██████████| 132/132 [00:18<00:00,  7.17it/s, accuracy=0.925, current_lr=3.42e-6, epoch=19, f1_score=0.923, loss=0.268326, stage=train]
100%|██████████| 33/33 [00:01<00:00, 17.34it/s, accuracy=0.917, epoch=19, f1_score=0.916, loss=0.293591, stage=eval]
100%|██████████| 132/132 [00:18<00:00,  7.24it/s, accuracy=0.926, current_lr=1.61e-6, epoch=20, f1_score=0.923, loss=0.266349, stage=train]
100%|██████████| 33/33 [00:01<00:00, 18.90it/s, accuracy=0.917, epoch=20, f1_score=0.914, 